Embedding Lookup in TensorFlow

Understanding how tf.nn.embedding_lookup works can seem more complicated than it is. Perhaps a simple example will help: all it does is look up the embedding values for a given list of indices.

Let’s say we have these embeddings in 3-dimensional space for a vocabulary of 4 items.

# Embedding with 3 dimensions for a vocabulary of 4 items
embedding = [
    [0.36808, 0.20834, -0.22319],
    [0.7503, 0.71623, -0.27033],
    [0.042523, -0.21172, 0.044739],
    [0.17698, 0.065221, 0.28548]
]
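
In a real model you typically wouldn’t hard-code these values; the embedding matrix would live in a trainable variable that the network learns during training. A minimal sketch of that setup (not part of the example above, and the variable names are just for illustration) might look like this.

import tensorflow as tf

vocab_size = 4
embedding_dim = 3

# Trainable embedding matrix, initialized with small random values
tf_embedding = tf.Variable(
    tf.random_uniform([vocab_size, embedding_dim], -1.0, 1.0))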

We can then look up the embeddings for the first and third items like this.

import tensorflow as tf

tf_embedding = tf.constant(embedding, dtype=tf.float32)

# Look up the rows for indices 0 and 2
index_to_lookup = [0, 2]
lookup = tf.nn.embedding_lookup(tf_embedding, index_to_lookup)

with tf.Session() as sess:
  print(sess.run(lookup))

This will print:

[[ 0.36808   0.20834  -0.22319 ]
 [ 0.042523 -0.21172   0.044739]]
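
If it helps the intuition, the lookup is really just row selection from the embedding matrix. The same result can be reproduced with plain NumPy fancy indexing (shown here purely as an illustration, outside of TensorFlow).

import numpy as np

embedding = np.array([
    [0.36808, 0.20834, -0.22319],
    [0.7503, 0.71623, -0.27033],
    [0.042523, -0.21172, 0.044739],
    [0.17698, 0.065221, 0.28548]
])

# Fancy indexing picks out rows 0 and 2, just like embedding_lookup
print(embedding[[0, 2]])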