问题在于:每一个变量(variable)都要通过 protobuf 序列化,而 protobuf 存在单条消息 2GB 的大小限制。因此当 embedding 过大时,需要把它按行拆分成 N 等份,
使得每一个 variable 都在 2GB 以下;
1 # !/usr/bin/env/python
2 # coding=utf-8
3 import tensorflow as tf
4 import numpy as np
5
def split_into_shards(table, num_shards):
    """Split a 2-D table row-wise into equally sized, contiguous shards.

    Contiguous row blocks are used so that shard ``k`` holds the ids
    ``[k * rows_per_shard, (k + 1) * rows_per_shard)``; this is the layout
    ``tf.nn.embedding_lookup`` expects with ``partition_strategy="div"``.

    Args:
        table: NumPy array whose first axis indexes the embedding ids.
        num_shards: Number of pieces to cut ``table`` into.

    Returns:
        A list of ``num_shards`` arrays (views into ``table``), in id order.

    Raises:
        ValueError: If ``num_shards`` is not positive, or the number of rows
            is not an exact multiple of ``num_shards``.
    """
    if num_shards < 1:
        raise ValueError("num_shards must be a positive integer")
    num_rows = table.shape[0]
    if num_rows % num_shards != 0:
        raise ValueError(
            "number of rows (%d) must be divisible by num_shards (%d)"
            % (num_rows, num_shards)
        )
    # Floor division keeps the bounds integral; a plain `/` yields a float on
    # Python 3 and float slice indices raise TypeError.
    rows_per_shard = num_rows // num_shards
    return [
        table[i * rows_per_shard:(i + 1) * rows_per_shard]
        for i in range(num_shards)
    ]


def main():
    """Build a sharded embedding table and look up a small batch of ids."""
    input_ids = tf.placeholder(dtype=tf.int32, shape=[None, None])

    num_shards = 3
    # Toy 9 x 3 "embedding" whose row r is [3r, 3r + 1, 3r + 2], so the looked
    # up rows are easy to recognise in the printed output.
    embedding_table = np.arange(27).reshape(9, 3)

    # One variable per shard keeps each individual variable small.
    weights = [
        tf.get_variable("words-%02d" % i, initializer=tf.constant(shard))
        for i, shard in enumerate(split_into_shards(embedding_table, num_shards))
    ]

    # Passing the list of shards makes the lookup treat them as one logical
    # table; "div" assigns contiguous id ranges to consecutive shards.
    input_embedding = tf.nn.embedding_lookup(
        weights, input_ids, partition_strategy="div")

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    print(sess.run(weights))

    print(sess.run(input_embedding,
                   feed_dict={input_ids: [[1, 2], [3, 0], [8, 2], [5, 1]]}))


if __name__ == "__main__":
    main()
结果为:
[array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]), array([[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]), array([[18, 19, 20],
[21, 22, 23],
[24, 25, 26]])]
[[[ 3 4 5]
[ 6 7 8]]
[[ 9 10 11]
[ 0 1 2]]
[[24 25 26]
[ 6 7 8]]
[[15 16 17]
[ 3 4 5]]]