[modeling_utils] respect original dtype in _get_resized_lm_head (#14181)
* respect dtype in _get_resized_lm_head

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* consistency

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
stas00 and sgugger authored Oct 28, 2021
1 parent 88cd82e commit 123cce6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/transformers/modeling_utils.py
@@ -789,9 +789,8 @@ def _get_resized_embeddings(
             )

         # Build new embeddings
-        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(
-            self.device, dtype=old_embeddings.weight.dtype
-        )
+        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
+        new_embeddings.to(self.device, dtype=old_embeddings.weight.dtype)

         # initialize all new embeddings (in particular added tokens)
         self._init_weights(new_embeddings)
@@ -862,7 +861,8 @@ def _get_resized_lm_head(
         # Build new lm head
         new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
         has_new_lm_head_bias = old_lm_head.bias is not None
-        new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias).to(self.device)
+        new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
+        new_lm_head = new_lm_head.to(self.device, dtype=old_lm_head.weight.dtype)

         # initialize new lm head (in particular added tokens)
         self._init_weights(new_lm_head)
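For context, here is a minimal usage sketch (not part of the commit) of the behavior being fixed: before this change, `_get_resized_lm_head` called `.to(self.device)` without a dtype, so resizing the vocabulary of a model loaded in float16 rebuilt an untied lm head in float32. Note that `nn.Module.to()` modifies the module in place and returns it, so the non-reassigning form in `_get_resized_embeddings` and the reassigning form in `_get_resized_lm_head` behave identically.

import torch
from transformers import AutoModelForCausalLM

# Load a model in half precision.
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)

# Growing the vocabulary goes through _get_resized_embeddings and, for
# models whose lm head is not tied to the input embeddings, through
# _get_resized_lm_head as well.
model.resize_token_embeddings(model.config.vocab_size + 8)

# With this commit the resized weights keep the original float16 dtype;
# previously an untied lm head came back in float32 because only the
# device, not the dtype, was passed to .to().
assert model.get_input_embeddings().weight.dtype == torch.float16
assert model.get_output_embeddings().weight.dtype == torch.float16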
