diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index d0231d5bd18f92..612100146954ab 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -146,7 +146,7 @@ class BertEmbeddings(nn.Module): def __init__(self, config): super().__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)