diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 1f6bb2c7a07768..04f61d9d4c6efe 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -789,9 +789,8 @@ def _get_resized_embeddings(
         )
 
         # Build new embeddings
-        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(
-            self.device, dtype=old_embeddings.weight.dtype
-        )
+        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
+        new_embeddings.to(self.device, dtype=old_embeddings.weight.dtype)
 
         # initialize all new embeddings (in particular added tokens)
         self._init_weights(new_embeddings)
@@ -862,7 +861,8 @@ def _get_resized_lm_head(
         # Build new lm head
         new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
         has_new_lm_head_bias = old_lm_head.bias is not None
-        new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias).to(self.device)
+        new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
+        new_lm_head = new_lm_head.to(self.device, dtype=old_lm_head.weight.dtype)
 
         # initialize new lm head (in particular added tokens)
         self._init_weights(new_lm_head)
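
Note (not part of the patch): `torch.nn.Module.to()` converts a module's parameters in place and returns the module itself, so the construct-then-convert form above is equivalent to the previous chained call whether or not the return value is reassigned. The second hunk also passes `dtype=old_lm_head.weight.dtype`, which the old `.to(self.device)` call omitted, so the resized LM head now keeps the old head's dtype instead of defaulting to `torch.float32`. A minimal, self-contained sketch of the in-place semantics:

```python
import torch
import torch.nn as nn

# nn.Module.to() converts parameters in place and returns the same module
# object, so splitting construction from conversion is behavior-preserving.
emb = nn.Embedding(8, 4)
ret = emb.to(torch.device("cpu"), dtype=torch.float16)

assert ret is emb                          # .to() returns self for modules
assert emb.weight.dtype == torch.float16   # conversion happened in place

# Contrast with tensors, where .to() is out of place and the return value
# must be kept:
t = torch.zeros(2)
t16 = t.to(torch.float16)
assert t.dtype == torch.float32 and t16.dtype == torch.float16
```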