
pretrained model #164

Closed
minmummax opened this issue Jan 4, 2019 · 4 comments

Comments

@minmummax

Does the downloaded pretrained model include the word embeddings?
I do not see any embeddings in your code.
Please help.

@rodgzilla
Contributor

rodgzilla commented Jan 7, 2019

All the code related to word embeddings is located here: https://github.com/huggingface/pytorch-pretrained-BERT/blob/8da280ebbeca5ebd7561fd05af78c65df9161f92/pytorch_pretrained_bert/modeling.py#L172-L200

If you want to access the pretrained embeddings, the easiest thing to do would be to load a pretrained model and extract its embedding matrices.
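For reference, a minimal sketch of that approach (a hypothetical snippet, assuming the pytorch_pretrained_bert package is installed and the bert-base-uncased weights can be downloaded):

from pytorch_pretrained_bert import BertModel

# Loading a pretrained model also loads its embedding matrices.
model = BertModel.from_pretrained('bert-base-uncased')

# The three embedding tables live under model.embeddings (see BertEmbeddings below).
word_matrix = model.embeddings.word_embeddings.weight               # (vocab_size, hidden_size)
position_matrix = model.embeddings.position_embeddings.weight       # (max_position_embeddings, hidden_size)
token_type_matrix = model.embeddings.token_type_embeddings.weight   # (type_vocab_size, hidden_size)

print(word_matrix.shape)  # e.g. torch.Size([30522, 768]) for bert-base-uncased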

@minmummax
Author

All the code related to word embeddings is located here

pytorch-pretrained-BERT/pytorch_pretrained_bert/modeling.py

Lines 172 to 200 in 8da280e

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

If you want to access the pretrained embeddings, the easiest thing to do would be to load a pretrained model and extract its embedding matrices.

Oh, I have looked at this code over the past few days, and from it I think it does not use the pretrained embedding parameters. Also, what do you mean by loading a pretrained model and extracting its embeddings? Is it from the original release?

@rodgzilla
Contributor

In [1]: from pytorch_pretrained_bert import BertModel                                                                                                                                    

In [2]: model = BertModel.from_pretrained('bert-base-uncased')                                                                                                                           

In [3]: model.embeddings.word_embeddings                                                                                                                                                 
Out[3]: Embedding(30522, 768)

This field of the BertEmbeddings class contains the pretrained embeddings. It gets set by calling BertModel.from_pretrained.
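Building on that snippet, here is a minimal sketch (assuming the BertTokenizer from the same package) of looking up a single token's pretrained vector:

from pytorch_pretrained_bert import BertModel, BertTokenizer

model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Map a token to its vocabulary id, then index the pretrained word embedding table.
token_id = tokenizer.convert_tokens_to_ids(['hello'])[0]
vector = model.embeddings.word_embeddings.weight[token_id]  # shape: (hidden_size,) = (768,)

Note that this is the raw word embedding; in the forward pass, BertEmbeddings adds the position and token_type embeddings and applies LayerNorm and dropout on top of it.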

@thomwolf
Member

thomwolf commented Jan 7, 2019

Thanks Gregory, that's the way to go indeed!

thomwolf closed this as completed Jan 7, 2019