Using Gensim Embeddings with Keras and Tensorflow
So you trained a Word2Vec, Doc2Vec or FastText embedding model using Gensim, and now you want to use the result in a Keras / Tensorflow pipeline. How do you connect the two?
Use this function:
```python
from tensorflow.keras.layers import Embedding


def gensim_to_keras_embedding(model, train_embeddings=False):
    """Get a Keras 'Embedding' layer with weights set from Word2Vec model's learned word embeddings.

    Parameters
    ----------
    model : gensim.models.Word2Vec / Doc2Vec / FastText
        Trained Gensim model containing the learned embeddings.
    train_embeddings : bool
        If False, the returned weights are frozen and stopped from being updated.
        If True, the weights can / will be further updated in Keras.

    Returns
    -------
    `keras.layers.Embedding`
        Embedding layer, to be used as input to deeper network layers.

    """
    keyed_vectors = model.wv  # structure holding the result of training
    weights = keyed_vectors.vectors  # vectors themselves, a 2D numpy array
    index_to_key = keyed_vectors.index_to_key  # which row in `weights` corresponds to which word?

    layer = Embedding(
        input_dim=weights.shape[0],
        output_dim=weights.shape[1],
        weights=[weights],
        trainable=train_embeddings,
    )
    return layer
```
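As a quick end-to-end illustration, here is a minimal sketch. It assumes Gensim 4.x and a TensorFlow 2 / Keras 2 setup that still accepts the `weights=` constructor argument used above; the toy corpus and the downstream classifier are made up for the example:

```python
from gensim.models import Word2Vec
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling1D

# Tiny made-up corpus, just to have something to train on.
sentences = [
    ["human", "interface", "computer"],
    ["survey", "user", "computer", "system", "response", "time"],
    ["graph", "minors", "trees"],
]

# Train a small Word2Vec model with Gensim.
w2v = Word2Vec(sentences, vector_size=8, min_count=1, epochs=10)

# Plug the frozen Gensim embeddings into a simple Keras classifier.
keras_model = Sequential([
    gensim_to_keras_embedding(w2v, train_embeddings=False),
    GlobalAveragePooling1D(),
    Dense(1, activation="sigmoid"),
])
keras_model.compile(optimizer="adam", loss="binary_crossentropy")
```

With `train_embeddings=False`, the Word2Vec weights stay fixed while the rest of the Keras model trains; set it to `True` to fine-tune them on the downstream task.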
So, in other words:

- The trained weights are in `model.wv.vectors`, a 2D matrix of shape `(number of words, dimensionality of word vectors)`.
- The mapping between the word indices in this matrix (integers) and the words themselves (strings) is in `model.wv.index_to_key`; the sketch below shows how to use this mapping to encode input text.
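Keras sees only integer indices, so input text has to be encoded with that mapping before it reaches the Embedding layer. A minimal sketch, reusing the hypothetical `w2v` model from the example above (`key_to_index` is the word-to-index counterpart of `index_to_key` in Gensim 4.x):

```python
import numpy as np

kv = w2v.wv  # reusing the `w2v` model from the sketch above

# Integer index -> word: the row order of the weight matrix.
print(kv.index_to_key[:5])

# Word -> integer index: the reverse mapping, used to encode input text.
token_ids = np.array([[kv.key_to_index[word] for word in ["human", "computer", "interface"]]])

# These integer ids are what the Keras Embedding layer expects as input.
embedded = gensim_to_keras_embedding(w2v)(token_ids)
print(embedded.shape)  # (1, 3, 8) == (batch, sequence length, vector_size)
```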
Note: The code talks about "keys" instead of "words", because various embedding models can in principle be used with non-word inputs. For example, in Doc2Vec the keys are "document tags". The algorithms don't really care what the interpretation of the key string is – it's an opaque identifier, relevant only in co-occurrence patterns with other keys.
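For instance, in Doc2Vec the trained key vectors for the document tags live in `model.dv` rather than `model.wv`, but they come in the same `KeyedVectors` structure. A small hypothetical sketch, with tags and documents made up for illustration:

```python
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

# Made-up tagged documents; the tags ("doc_0", "doc_1") are the "keys" here.
documents = [
    TaggedDocument(words=["human", "interface", "computer"], tags=["doc_0"]),
    TaggedDocument(words=["graph", "minors", "trees"], tags=["doc_1"]),
]
d2v = Doc2Vec(documents, vector_size=8, min_count=1, epochs=10)

# Same structure as model.wv: a 2D weight matrix plus an index-to-key mapping,
# except the keys are document tags instead of words.
print(d2v.dv.vectors.shape)  # (2, 8)
print(d2v.dv.index_to_key)   # ['doc_0', 'doc_1']
```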