Figuring out how tf.nn.embedding_lookup works can be harder than it needs to be. Perhaps a simple example will help. All it does is look up the embedding vectors (rows of the embedding matrix) for a given list of indices.
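Conceptually, the lookup replaces each index with the matching row of the embedding matrix. Below is a minimal plain-Python sketch of that idea. It assumes a single embedding matrix stored as a list of rows. `lookup_rows` and the toy `table` are hypothetical names made up for illustration and are not part of TensorFlow. The sketch also accepts nested index lists, so a [2, 2] batch of ids yields a [2, 2, 2] result. This matches the shape(ids) + shape(params)[1:] output shape TensorFlow gives for a single params tensor.

```python
# A plain-Python sketch of what an embedding lookup does (no TensorFlow).
# `lookup_rows` is a hypothetical helper written for illustration only.

def lookup_rows(table, ids):
    """Replace every integer index in `ids` with row `table[index]`.

    `ids` may be a single int or an arbitrarily nested list of ints, so the
    result's shape is shape(ids) followed by the row length.
    """
    if isinstance(ids, int):
        return list(table[ids])  # copy of the selected row
    return [lookup_rows(table, i) for i in ids]


# Toy table (hypothetical values): vocabulary of 3 items, 2 dimensions each.
table = [
    [1.0, 2.0],  # item 0
    [3.0, 4.0],  # item 1
    [5.0, 6.0],  # item 2
]

# A batch of two "sentences", each two item ids long: shape [2, 2].
batch_ids = [[0, 2], [1, 1]]
result = lookup_rows(table, batch_ids)
print(result)  # [[[1.0, 2.0], [5.0, 6.0]], [[3.0, 4.0], [3.0, 4.0]]]
```

Repeated ids simply return the same row again, which is why item 1 appears twice in the second "sentence".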
Let’s say we have the following embeddings in a 3-dimensional space for a vocabulary of 4 items.
# Embedding matrix: one row per vocabulary item (4 items), 3 dimensions per row
embedding = [
    [0.36808, 0.20834, -0.22319],
    [0.7503, 0.71623, -0.27033],
    [0.042523, -0.21172, 0.044739],
    [0.17698, 0.065221, 0.28548]
]
We can then look up the embeddings for the first and third items (indices 0 and 2) like this:
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

tf_embedding = tf.constant(embedding, dtype=tf.float32)
with tf.Session() as sess:
    index_to_lookup = [0, 2]  # first and third items
    lookup = tf.nn.embedding_lookup(tf_embedding, index_to_lookup)
    print(sess.run(lookup))
This prints:
[[ 0.36808   0.20834  -0.22319 ]
 [ 0.042523 -0.21172   0.044739]]
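For a single embedding tensor, the lookup amounts to selecting rows by index. TensorFlow describes embedding_lookup as a generalization of tf.gather. The sketch below reproduces the printed result in plain Python without TensorFlow, using the same embedding values as above. It is only a check of the idea, not how TensorFlow implements the operation internally. The values also differ slightly in precision, since the TensorFlow version uses float32.

```python
# Same embedding as above: 4 vocabulary items, 3 dimensions each.
embedding = [
    [0.36808, 0.20834, -0.22319],
    [0.7503, 0.71623, -0.27033],
    [0.042523, -0.21172, 0.044739],
    [0.17698, 0.065221, 0.28548],
]

index_to_lookup = [0, 2]

# Looking up embeddings is just picking row i for every index i.
lookup = [embedding[i] for i in index_to_lookup]
print(lookup)  # [[0.36808, 0.20834, -0.22319], [0.042523, -0.21172, 0.044739]]
```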