
[Feat] Activation offloading for distributed lora recipe #1645

Merged

Conversation


@Jackmin801 Jackmin801 commented Sep 21, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

  • Change docstring to reflect arg change activation_checkpoint -> enable_activation_checkpoint
  • Add arg enable_activation_offloading and logic to do activation offloading
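
For illustration, here is a minimal sketch of the offloading idea using PyTorch's built-in torch.autograd.graph.save_on_cpu saved-tensor hooks. This is not the recipe's actual implementation (the PR ports a stream-based offloading context manager from the single-device recipe); the toy model and flag below are placeholders.

import contextlib

import torch
import torch.nn as nn

# Toy stand-in for the LoRA transformer; the real recipe wraps the torchtune model.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

enable_activation_offloading = True  # mirrors the new recipe flag

# When enabled, tensors saved for backward are kept in pinned CPU memory and
# copied back to the GPU only when the backward pass needs them.
offload_ctx = (
    torch.autograd.graph.save_on_cpu(pin_memory=True)
    if enable_activation_offloading
    else contextlib.nullcontext()
)

with offload_ctx:
    out = model(x)    # activations are offloaded to CPU as they are saved
out.sum().backward()  # saved activations are copied back to the GPU here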

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings
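
For reference, a hypothetical invocation showing how a user would toggle the new flag from the command line (the flag name enable_activation_offloading comes from this PR; the other overrides are illustrative and mirror the benchmarking commands later in this thread):

tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora \
  enable_activation_checkpointing=True \
  enable_activation_offloading=True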


pytorch-bot bot commented Sep 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1645

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7be67c4 with merge base 4e69db8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @Jackmin801!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@facebook-github-bot facebook-github-bot added the CLA Signed label Sep 21, 2024
@Jackmin801 Jackmin801 changed the title from "port changes from single device case" to "[Feat] Activation" Sep 21, 2024
@Jackmin801 Jackmin801 changed the title from "[Feat] Activation" to "[Feat] Activation offloading for distributed lora recipe" Sep 21, 2024

Jackmin801 commented Sep 22, 2024

While we were working on this PR, we observed a weird behaviour in the trace.

[Profiler trace screenshots]
Why do aten::_scaled_dot_product_cudnn_attention_backward and its forward counterpart take so much time on the CPU?

Steps to reproduce

1. Clone repo at this branch

git clone --branch feat-activation-offloading-streams-dist git@github.com:Jackmin801/torchtune
cd torchtune

2. Dev Install with torch nightly

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
pip install torchao
pip install matplotlib pandas numpy
pip install -e ".[dev]"

3. Login to huggingface and download model

huggingface-cli login --token <HF_TOKEN>
tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/*.pth"

4. Run single device

#tune run --nnodes 1 --nproc-per-node 2 lora_finetune_distributed --config recipes/configs/llama3_1/8B_lora.yaml \
tune run lora_finetune_single_device --config recipes/configs/llama3_1/8B_lora.yaml \
  optimizer_in_bwd=False \
  enable_activation_offloading=False \
  enable_activation_checkpointing=True \
  compile=False \
  model.lora_attn_modules="[q_proj,v_proj]" \
  model.apply_lora_to_mlp=False \
  model.apply_lora_to_output=False \
  model.lora_rank=8 \
  model.lora_alpha=16 \
  dataset.source=Yukang/LongAlpaca-12k \
  dataset.packed=False \
  dataset.split=train[:10%] \
  dataset.train_on_input=True \
  tokenizer.max_seq_len=8192 \
  metric_logger=torchtune.training.metric_logging.StdoutLogger \
  metric_logger.project=recipe_profiling \
  log_every_n_steps=1 \
  log_peak_memory_stats=True \
  gradient_accumulation_steps=1 \
  max_steps_per_epoch=4 \
  epochs=1 \
  batch_size=2 \
  metric_logger.name=llama3__qlora__seqlen_8192__act_ckpt_True__act_off_True__bs2 \
  profiler.enabled=True \
  profiler.profile_memory=True \
  profiler.with_stack=True \
  profiler.wait_steps=0 \
  profiler.warmup_steps=2 \
  profiler.active_steps=2 \
  profiler.num_cycles=1

5. Observe the trace

Weird right?

Observations

@yf225 figured out that there's a warning from cuDNN SDPA about memory layout that is likely the culprit:

/data/users/weif/pytorch/torch/autograd/graph.py:826: UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides(), attempting to materialize a grad_output with matching strides... (Triggered internally at /data/users/weif/pytorch/aten/src/ATen/native/cudnn/MHA.cpp:674.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
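
One way to check whether the cuDNN SDPA kernel is responsible would be to pin scaled dot-product attention to a different backend. A rough debugging sketch (not part of this PR) using torch.nn.attention.sdpa_kernel, available in recent PyTorch builds; the tensor shapes below are illustrative:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes: batch 2, 16 heads, 8192 tokens, head dim 64.
q = torch.randn(2, 16, 8192, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Restrict SDPA to the FlashAttention backend so the cuDNN kernel (and its
# grad_output stride-mismatch path) is never selected.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

out.sum().backward()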

@janeyx99 @joecummings @yf225

@Jackmin801 Jackmin801 marked this pull request as ready for review September 22, 2024 01:01
@Jackmin801 Jackmin801 force-pushed the feat-activation-offloading-streams-dist branch from 662271b to 7886acc Compare September 22, 2024 01:03
@Jackmin801

ah. will lint haha. wanted to start hacking first

@Jackmin801 Jackmin801 force-pushed the feat-activation-offloading-streams-dist branch from e0c60b4 to 7be67c4 Compare September 26, 2024 00:56
@felipemello1

Thanks for the PR!! Would you be able to run it with activation offloading on and off for ~25-50 steps and compare loss, memory, and TPS?

If you use Weights & Biases, it's probably the easiest way to take a screenshot (you can log in with your Gmail, get the token, pip install wandb, wandb login, and insert your token).

Then you can run your config like this:
tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora max_steps_per_epoch=50 epochs=1 metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True ...other params...

Please use dataset.packed=True, compile=True, enable_activation_checkpointing=True, and enable_activation_offloading=True/False.

If loss and TPS are about the same, but memory is much lower, this should be easy to approve.


Jackmin801 commented Sep 29, 2024

[W&B comparison screenshots: loss, peak memory, and tokens/sec with offloading on vs. off]

Loss seems to be roughly the same (the difference is probably caused by flash attention scheduling). Memory usage is reduced by ~20% with a TPS reduction of ~2% on 2 PCIe A100s.

Command:

tune run --nproc_per_node 2 lora_finetune_distributed \
  --config recipes/configs/llama3_1/8B_lora.yaml \
  enable_activation_offloading=True \
  enable_activation_checkpointing=True \
  compile=True \
  dataset.source=Yukang/LongAlpaca-12k \
  dataset.packed=True \
  tokenizer.max_seq_len=8192 \
  metric_logger=torchtune.training.metric_logging.WandBLogger \
  metric_logger.project=torchtune-activation-offloading \
  log_every_n_steps=1 \
  log_peak_memory_stats=True \
  max_steps_per_epoch=50 \
  epochs=1 \
  gradient_accumulation_steps=1 \
  batch_size=2

@felipemello1

will get to this tomorrow. Thanks for posting the results!


@felipemello1 felipemello1 left a comment


lgtm! Thanks for the PR!

@felipemello1 felipemello1 merged commit 7af77c7 into pytorch:main Oct 3, 2024
17 checks passed
@Jackmin801 Jackmin801 deleted the feat-activation-offloading-streams-dist branch October 4, 2024 09:16
@joecummings joecummings mentioned this pull request Oct 9, 2024
mori360 pushed a commit to mori360/torchtune that referenced this pull request Oct 14, 2024