
Discrepancy in Training Loss Behavior with Gradient Accumulation using DeepSpeed #34694

Closed

kmchiti opened this issue Nov 12, 2024 · 6 comments

Comments

@kmchiti commented Nov 12, 2024

System Info

Accelerate version: 1.1.0
transformers version: 4.46.2
DeepSpeed version: 0.14.4
Platform: Linux 5.15.0-101-generic #111-Ubuntu SMP x86_64 GNU/Linux
Python version: 3.10.14
PyTorch version (GPU?): 2.1.2+cu118 True
GPU type: NVIDIA A100

Who can help?

@muellerzr

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The script below is a simplified example of training a small model with the Hugging Face Trainer. It loads and tokenizes a dataset, initializes a model and tokenizer, and configures the Trainer with different settings for gradient accumulation and DeepSpeed.

import argparse
import torch
from datasets import load_dataset
from transformers import (set_seed,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForLanguageModeling,
                          LlamaForCausalLM,
                          LlamaConfig,
                          AutoTokenizer
                          )


DEEPSPEED_CONFIG = {
  "zero_optimization": {
    "stage": 0
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto"
}

TRAIN_ARGS = {'output_dir': './test_GA',
              'bf16': True,
              'learning_rate': 6e-4,
              'lr_scheduler_type': 'cosine',
              'max_steps': 200,
              'optim': 'adamw_torch',
              'weight_decay': 0.1,
              'per_device_train_batch_size': 128,
              'gradient_accumulation_steps': 1,
              'logging_steps': 1,
              'report_to': 'none'}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--bs', default=128, type=int, help='batch size')
    parser.add_argument('--ga', default=1, type=int, help='number of gradient accumulation step')
    parser.add_argument('--deepspeed', action='store_true', help='use deepspeed')
    args = parser.parse_args()

    set_seed(42)
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False

    # Initialize dataset
    CONTEXT_LENGTH = 512  # Small context length as specified
    def preprocess_data(examples, tokenizer, max_length=CONTEXT_LENGTH):
        """Tokenizes the input data and truncates/pads to the max context length."""
        return tokenizer(examples["sentence"], truncation=True, padding="max_length", max_length=max_length, add_special_tokens=True)

    # Load the dataset from Hugging Face
    dataset = load_dataset("ptb_text_only", trust_remote_code=True, split='train')

    # Load and configure the tokenizer
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", add_prefix_space=True, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token

    # Preprocess the dataset
    column_names = list(dataset.features)
    train_dataset = dataset.map(lambda x: preprocess_data(x, tokenizer), batched=True, remove_columns=column_names)

    # Initialize model
    model_cfg = LlamaConfig(max_position_embeddings=CONTEXT_LENGTH, hidden_size=512, num_attention_heads=8, num_hidden_layers=4,
                            vocab_size=tokenizer.vocab_size, eos_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id)
    model = LlamaForCausalLM(model_cfg)
    
    # Initialize trainer
    if args.deepspeed:
        TRAIN_ARGS.update({"deepspeed": DEEPSPEED_CONFIG})
    TRAIN_ARGS.update({"per_device_train_batch_size": args.bs, "gradient_accumulation_steps": args.ga})
    trainer = Trainer(model=model, args=TrainingArguments(**TRAIN_ARGS), train_dataset=train_dataset,
                      data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False))
    trainer.train()
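
As a side note, the per-step losses that the Trainer logs can be dumped for offline comparison across runs; a minimal sketch that could be appended at the end of the __main__ block (the output filename is illustrative):

    # Save the logged training metrics (loss per logging step, learning rate, etc.)
    # so runs with different --bs/--ga/--deepspeed settings can be compared later.
    import json
    with open(f"loss_bs{args.bs}_ga{args.ga}_ds{int(args.deepspeed)}.json", "w") as f:
        json.dump(trainer.state.log_history, f)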

Expected behavior

For a fixed effective batch size (per_device_train_batch_size × gradient_accumulation_steps), the training loss should remain consistent across different gradient accumulation settings, both with and without DeepSpeed enabled. However, the figure below shows a divergence when DeepSpeed is enabled:

[Figure "gradient_accumulation_issue": training-loss curves for the different settings; the runs with DeepSpeed enabled diverge from the others.]
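
One common source of such discrepancies under gradient accumulation is loss normalization: when micro-batches contain different numbers of non-padded tokens, averaging the per-micro-batch mean losses is not the same as taking the mean over the whole batch. A minimal sketch with synthetic logits and labels (not tied to the script above) that illustrates the mismatch:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
VOCAB, IGNORE = 32, -100
logits = torch.randn(4, 10, VOCAB)        # 4 sequences, 10 positions each
labels = torch.randint(0, VOCAB, (4, 10))
labels[2:, 6:] = IGNORE                   # last two sequences are partly padding

def mean_loss(lg, lb):
    # reduction="mean" averages over the non-ignored tokens only
    return F.cross_entropy(lg.reshape(-1, VOCAB), lb.reshape(-1), ignore_index=IGNORE)

full_batch = mean_loss(logits, labels)                               # one big batch
accumulated = torch.stack([mean_loss(logits[i:i + 2], labels[i:i + 2])
                           for i in (0, 2)]).mean()                  # two micro-batches, averaged
print(full_batch.item(), accumulated.item())  # the two values differ

Whether this particular normalization is what changes when DeepSpeed handles the accumulation here is for the maintainers to confirm; the sketch only shows the general effect.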

@kmchiti added the bug label Nov 12, 2024
@LysandreJik (Member)

Thanks for the bug report @kmchiti, we're taking a look

@effortprogrammer

Do you have any plans for a fix?

@Yangr116 commented Dec 9, 2024

Any update?

@muellerzr (Contributor)

#35157 should help fix this

@michaelroyzen

Try updating to 4.46.3; a gradient accumulation bug affecting DeepSpeed was patched there: #34645
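
To rule out a stale install, a quick (illustrative) check of the versions actually imported in the environment:

import transformers, deepspeed
print("transformers:", transformers.__version__)
print("deepspeed:", deepspeed.__version__)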

github-actions bot commented Jan 8, 2025

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
