Skip to content

Commit

Permalink
Fix resize embedding with Deepspeed (#32192)
Browse files Browse the repository at this point in the history
fix resize when deepspeed
  • Loading branch information
zucchini-nlp authored and LysandreJik committed Jul 24, 2024
1 parent a2b6a00 commit 4672b4d
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2131,13 +2131,23 @@ def _get_resized_embeddings(
# Replace weights in old_embeddings and return to maintain the same embedding type.
# This ensures correct functionality when a Custom Embedding class is passed as input.
# The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
old_embeddings.weight.data = new_embeddings.weight.data
old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

params = [old_embeddings.weight, new_embeddings.weight]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
old_embeddings.weight.data = new_embeddings.weight.data
old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]

# If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
# will be set to `None` in the resized embeddings.
if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
old_embeddings.padding_idx = None
# If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
# will be set to `None` in the resized embeddings.
if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
old_embeddings.padding_idx = None
else:
old_embeddings.weight.data = new_embeddings.weight.data
old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
old_embeddings.padding_idx = None

return old_embeddings

Expand Down

0 comments on commit 4672b4d

Please sign in to comment.