
Deepspeed model breaks after model.resize_token_embeddings #2211

Closed
2 of 4 tasks
erap129 opened this issue Dec 4, 2023 · 3 comments
erap129 commented Dec 4, 2023

System Info

- `Accelerate` version: 0.22.0
- Platform: Linux-5.4.0-153-generic-x86_64-with-glibc2.27
- Python version: 3.8.16
- Numpy version: 1.24.4
- PyTorch version (GPU?): 2.0.1+cu117 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- System RAM: 125.51 GB
- GPU type: NVIDIA GeForce RTX 3090
- `Accelerate` default config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: DEEPSPEED
        - use_cpu: False
        - debug: True
        - num_processes: 2
        - machine_rank: 0
        - num_machines: 1
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - deepspeed_config: {'deepspeed_config_file': '/app/code/fsdp_experimentation/zero_stage2_config.json', 'zero3_init_flag': True}
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

In my use case I need to call model.resize_token_embeddings(len(tokenizer)) after calling accelerator.prepare.
Why? Because I want to be able to load an accelerator state with accelerator.load_state, and for that I need to call accelerator.prepare beforehand (based on this issue - #285).
I am fairly sure that calling resize_token_embeddings afterwards breaks the DeepSpeed initialization that accelerator.prepare performed for the embedding layer.
If I remove model.resize_token_embeddings, or move it before the accelerator.prepare call, the script works.

Here is a minimal reproducible example:

train.py

import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from accelerate import Accelerator

# Initialize a small model and tokenizer
model_name = "distilbert-base-uncased"
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForSequenceClassification.from_pretrained(model_name)

# Initialize Accelerator (the DeepSpeed/ZeRO-3 settings come from the launch config)
accelerator = Accelerator()
model, optimizer, train_dataloader = accelerator.prepare(model, torch.optim.Adam(model.parameters()), [])

# Add tokens to the tokenizer
new_tokens = ['[NEW_TOKEN1]', '[NEW_TOKEN2]']
num_added_toks = tokenizer.add_tokens(new_tokens)

# Resize token embeddings - Expected to cause the issue
model.resize_token_embeddings(len(tokenizer))

# Dummy forward pass to test if everything works
inputs = tokenizer("Hello, this is a test", return_tensors="pt").to(accelerator.device)
outputs = model(**inputs)
print(outputs)

deepspeed_config.yaml

compute_environment: LOCAL_MACHINE
deepspeed_config:
 deepspeed_config_file: zero3_minimal_example.json
 zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
num_machines: 1
num_processes: 1
use_cpu: false

zero3_minimal_example.json

{
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": "auto"
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": 1,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": false
}

Run this:
accelerate launch --config_file deepspeed_config.yaml train.py

This is the error I get (in my full project the error is slightly different: something about the model's parameters missing a ds_params field):

Traceback (most recent call last):
  File "minimal_deepspeed_tokenizer_problem.py", line 23, in <module>
    outputs = model(**inputs)
  File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1768, in forward
    loss = self.module(*inputs, **kwargs)
  File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py", line 789, in forward
    distilbert_output = self.distilbert(
  File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py", line 607, in forward
    embeddings = self.embeddings(input_ids, inputs_embeds)  # (bs, seq_length, dim)
  File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py", line 120, in forward
    input_embeds = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
  File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/lib/python3.8/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D
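The final frame of the traceback (torch.embedding raising "'weight' must be 2-D") is consistent with the embedding weight still being a ZeRO-3 partitioned placeholder, i.e. a 0-element tensor rather than a (vocab_size, dim) matrix, at forward time. A minimal PyTorch-only sketch of just that symptom, assuming nothing beyond torch itself:

```python
import torch
import torch.nn.functional as F

# Under ZeRO-3, a parameter that has been partitioned away (and is never
# re-gathered before forward) appears as a 0-element 1-D tensor instead of
# the full (vocab_size, dim) embedding matrix.
placeholder = torch.empty(0)

try:
    F.embedding(torch.tensor([0]), placeholder)
except RuntimeError as e:
    message = str(e)

# The same error the traceback above ends with
assert "2-D" in message
```

This is only the symptom, not the full mechanism: with zero3_init_flag enabled, the freshly allocated embedding from resize_token_embeddings is partitioned at construction but is unknown to the already-built engine, so it is never gathered before the forward pass.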

Expected behavior

I expect print(outputs) to work and the script not to crash.

@muellerzr
Collaborator

cc @pacman100

@pacman100
Contributor

Hello, whether training or resuming, please resize the embedding layer before preparing the model; that way it is independent of load_state. Please let us know if that resolves the issue.
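The ordering matters because accelerator.prepare() is the point at which ZeRO-3 takes ownership of the model's parameters; anything allocated afterwards is invisible to the engine. A toy, PyTorch-only sketch of that ownership model (ToyZero3 is a hypothetical stand-in, not the real DeepSpeed engine or Accelerate API):

```python
import torch
import torch.nn as nn

class ToyZero3:
    """Hypothetical stand-in for a ZeRO-3 engine: it shards the parameters it
    is handed at prepare() time and re-materializes only those at forward."""
    def __init__(self, module: nn.Module):
        self.module = module
        # snapshot of every parameter the engine was given at prepare() time
        self._owned = {n: p.data.clone() for n, p in module.named_parameters()}
        for p in module.parameters():
            p.data = torch.empty(0)  # "partitioned away" placeholder

    def __call__(self, *args):
        # gather only the parameters the engine knows about
        for n, p in self.module.named_parameters():
            if n in self._owned:
                p.data = self._owned[n]
        return self.module(*args)

# Safe ordering: grow the embedding to its final size BEFORE prepare().
emb = nn.Embedding(10, 4)
bigger = nn.Embedding(12, 4)           # like model.resize_token_embeddings(12)
with torch.no_grad():
    bigger.weight[:10] = emb.weight    # keep the old rows
engine = ToyZero3(bigger)              # prepare() now sees the final shape
out = engine(torch.tensor([11]))      # looking up a brand-new token id works
assert out.shape == (1, 4)
```

In the real script this corresponds to calling tokenizer.add_tokens(...) and model.resize_token_embeddings(len(tokenizer)) before accelerator.prepare(model, ...), so the engine partitions and manages the already-resized weight.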


github-actions bot commented Jan 3, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
