
Full finetune < 16GB #527

Merged · 37 commits merged from full_ft_mem into main on Mar 29, 2024

Conversation

@rohan-varma (Member) commented on Mar 19, 2024

Context

  • We'd like to enable a variant of the full finetune that trains in under 16GB of GPU memory, for users with consumer-grade GPUs that have limited GPU RAM. This PR enables the full finetune to fit within that memory budget.

Changelog

  • Add support for running the optimizer step in the backward pass, following this tutorial: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html, in _setup_optimizer. A minimal sketch of the pattern appears after the memory figures below.
  • Add OptimInBackwardWrapper to checkpoint the optimizer states when running in backward, plus unit tests.
  • Update the relevant portions of the training loop, e.g. not calling step or zero_grad when the optimizer runs in backward.
  • Disable optimizer state-dict checkpointing for optimizer in backward, as this is nontrivial to implement and not supported yet.
  • Warn when the user loads an optimizer state checkpoint while optimizer in backward is enabled; in this case, the checkpoint is not loaded.
  • Change the configuration to use the PagedAdamW optimizer; with this, memory stays under 16GB:

GPU peak memory allocation: 13.993540096GB, GPU peak memory reserved: 15.974006784GB, GPU peak memory active: 13.993540096GB
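
For readers unfamiliar with the technique, here is a minimal sketch of the optimizer-in-backward pattern from the linked tutorial. The model and lr names below are illustrative stand-ins, not the recipe's actual code; the recipe builds these objects from the config.

import torch

model = torch.nn.Linear(10, 10)  # stand-in for the model being finetuned
lr = 2e-5

# One optimizer per parameter, so each parameter can be stepped as soon as
# its gradient has been accumulated.
optim_dict = {p: torch.optim.AdamW([p], lr=lr) for p in model.parameters()}

def optim_step(param) -> None:
    # Step and immediately free the gradient, so full gradients for the
    # whole model never need to be resident at the same time.
    optim_dict[param].step()
    optim_dict[param].zero_grad()

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optim_step)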

Test plan

  • Run with bnb PagedAdamW + optim in backward + full bf16: tune full_finetune_single_device --config recipes/configs/full_finetune_single_device.yaml &> out &

Memory does not spike above 16GB reserved:

Memory Stats:
  GPU peak memory allocation: 13.92 GB
  GPU peak memory reserved: 15.97 GB
  GPU peak memory active: 13.92 GB

Loss curves:
(loss curve screenshot attached to the PR, not reproduced here)

Eval result:

truthfulqa_mc2: {'acc,none': 0.4372944697596623, 'acc_stderr,none': 0.014862338279863522, 'alias': 'truthfulqa_mc2'}

Checkpoint save/load of optimizer states

  • Run and save recipe state: tune full_finetune_single_device --config recipes/configs/full_finetune_single_device.yaml max_steps_per_epoch=1 device=cuda:1
  • Load recipe state, including optimizer state, back in for training: set resume_from_checkpoint: True and recipe_checkpoint: /tmp/llama2/recipe_state.pt, then run: tune full_finetune_single_device --config recipes/configs/full_finetune_single_device.yaml device=cuda:1 --> training resumes correctly

Discussion points

  • Configuring bnb optimizers. Currently this is done directly in the config: the user can simply swap torch.optim.AdamW for bitsandbytes.optim.PagedAdamW, for example. We can definitely add some documentation around this, but I'm wondering whether we need to do anything further for configuration.
  • Currently I'm mostly testing with PagedAdamW, but 8-bit Adam from bnb works as well and can be configured similarly. PagedAdamW gives larger memory savings (~14GB peak memory vs. 27GB for 8-bit Adam).
  • Checkpointing of optimizer in backward. For a fully functional recipe, checkpointing optimizer states when running the optimizer in backward is critical. We could do this with distributed's _apply_optimizer_in_backward API plus distributed_state_dict, though the downsides are that the latter API is only available in more recent versions of torch and the former API is private. On the other hand, implementing checkpointing on top of the technique in https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html is less well understood from a technical perspective: it hasn't been done before, and there are tricky things to figure out, such as how to key the optimizers by a stable identifier (e.g. a UUID) rather than a numerical index, which won't work when there are multiple optimizers (see the sketch after this list). Using the former implementation would also be useful since it is better tested and designed than anything we'd build ourselves.
  • UX / exposing this memory-efficient config. Currently I've just modified the config in place, but we could easily land a separate memory-efficient config. We could even author an entirely separate recipe for users who don't want to enable optimizer in backward and deal with the UX complexity of using it.
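
To make the keying concern above concrete, here is a hedged sketch of one way in-backward optimizer state could be saved and restored by parameter name rather than by positional index. The helper names are illustrative, not the exact code landed in this PR.

from typing import Dict

import torch

def optim_in_bwd_state_dict(
    model: torch.nn.Module,
    optim_dict: Dict[torch.nn.Parameter, torch.optim.Optimizer],
) -> Dict[str, dict]:
    # Key each per-parameter optimizer by the parameter's fully qualified name
    # so the checkpoint is stable across runs and optimizer counts.
    param_to_name = {p: name for name, p in model.named_parameters()}
    return {param_to_name[p]: opt.state_dict() for p, opt in optim_dict.items()}

def load_optim_in_bwd_state_dict(
    model: torch.nn.Module,
    optim_dict: Dict[torch.nn.Parameter, torch.optim.Optimizer],
    state_dict: Dict[str, dict],
) -> None:
    param_to_name = {p: name for name, p in model.named_parameters()}
    for p, opt in optim_dict.items():
        opt.load_state_dict(state_dict[param_to_name[p]])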

Follow up work

  • Documentation for using bitsandbytes optimizers; update documentation on memory efficiency in general.


pytorch-bot bot commented Mar 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/527

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6c9731d with merge base ec89eb0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 19, 2024
netlify bot commented Mar 19, 2024

Deploy Preview for torchtune-preview ready!

🔨 Latest commit: 57030f2
🔍 Latest deploy log: https://app.netlify.com/sites/torchtune-preview/deploys/6600dd6518f3700008df8a7d
😎 Deploy Preview: https://deploy-preview-527--torchtune-preview.netlify.app

@@ -32,8 +32,9 @@ resume_from_checkpoint: False
 batch_size: 2
 epochs: 3
 optimizer:
-  _component_: torch.optim.SGD
+  _component_: bitsandbytes.optim.PagedAdamW
Member Author:

need to add this to requirements, docs, etc

lr: 2e-5
optimizer_in_bwd: True
Member Author:

Will likely ship this as False and add docs / tutorials since we don't support loading in optimizer states yet.

Member Author:

Nvm, we do support it now.

@@ -22,6 +22,59 @@
)


class OptimizerInBackwardWrapper:
Contributor:

I'd move this to _checkpointer/_checkpoint_utils.py since I plan to eventually nuke this file

Contributor @RdoubleA left a comment:

If there are opportunities to make the optimizer wrapper have an identical API to a normal optimizer, we can hide a lot of logic within the wrapper class and we won't need to check if self._optimizer_in_bwd in so many locations. I'm worried this will confuse new users, especially since this is one of the main flagship recipes.

_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}
_component_: torchtune.utils.metric_logging.WandBLogger
project: foo
Contributor:

if this is user facing, let's provide a real project name?

# Training cfg
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

# TODO: find a better place / way to perform validation of args that don't yet
# compose with each other.
Contributor:

After the dataclasses were removed we lost a convenient place to perform validation. My thinking is that we provide a convenient utility function (or functions) that does this for users, which they just import in their recipe, but I'm curious about your thoughts.

Member Author:

That could be one approach; a possible downside is that the helper becomes a monolithic dumping ground where we check various configs, devolving into a large swath of if statements. On the other hand, if we don't do something like this, the validation gets spelled out in each recipe, which increases code bloat and maintenance overhead (we'd have to copy things into each recipe).

Contributor:

Some piece of this is just not trying to do too much in a single recipe, right? This is part of the reason we split single-device and distributed recipes to begin with.

We can still do config validation; we just have to check that fields are defined rather than naively checking field values. Personally I would be in favor of some config validation utilities defined on a per-recipe basis under configs/ somewhere, coupled with clear documentation in the recipe class of which intersections of features are not supported.

Member Author:

This makes sense. I've added some documentation in the recipe class for now. @RdoubleA , let's chat about Evan's suggestion here?

recipes/full_finetune_single_device.py (outdated review thread, resolved)
log.info("Optimizer is initialized.")
return optimizer
if optimizer_in_bwd:
self._optimizer_in_bwd = True
Contributor:

set this in init?

optim_dict[param].zero_grad()

for p in self._model.parameters():
p.register_post_accumulate_grad_hook(optim_step)
Contributor:

I would say put the optimizer in backward logic in a separate utility, it uses some non-intuitive logic that may confuse users

Member Author:

I agree and see the reasoning in general, but IMO a downside of something like a utils.setup_optim_in_backward(optim_config, model) is that it has the side effect of registering hooks on the model. I want things that modify state to be as explicit as possible, so I could do something like register_optimizer_in_backward_hooks and make_optim_checkpoint_wrapper - more utilities / components than a monolithic thing that configures the entire optimizer in backward. @kartikayk, @ebsmothers what do you think?

raise RuntimeError(
"Failed loading in-backward optimizer checkpoints."
"Please make sure run being restored from was using in-backward optimizer."
f"Original error {str(e)}"
Contributor:

the from e should take care of surfacing the original error

Member Author @rohan-varma commented Mar 25, 2024:

I actually tried this and it didn't work - i.e., I didn't see the exception from e surfaced in this error. Not sure if I just messed something up.

Member Author:

Let me continue checking this though

if not self._optimizer_in_bwd:
ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict()
else:
ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict()
Contributor:

if the APIs will be the same for the optim ckpt wrapper, you could just call it self._optimizer and remove the if else

Member Author:

I'm personally not an advocate for this, for the reason explained in the other comment. If folks feel this is better UX, though, I'm happy to add it.

Contributor:

I'm inclined to agree with @RdoubleA. We definitely don't want to hide the actual behavior too much, but in this case we already have state_dict and other APIs defined, so we might as well cut down on branching (aside from where it's really needed, like in train). Honestly, I'm not taking a super strong stance here, so fine either way.

Member Author:

Makes sense - let's discuss in follow up PRs.

"lr": self._optimizer.param_groups[0]["lr"],
# NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently
# true since we don't expose the ability to configure this yet.
"lr": list(self._optim_ckpt_wrapper.optim_map.values())[
Contributor:

can you make a method or property in the wrapper that retrieves this so this can be cleaner?

Member Author:

Great idea!
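
For concreteness, here's a hedged sketch of what such a property could look like on the wrapper. The method name and structure are illustrative, not necessarily what landed in the PR.

from typing import Any, Dict

import torch

class OptimizerInBackwardWrapper:
    # Sketch only: holds one optimizer per parameter, keyed by parameter name.
    def __init__(self, optim_map: Dict[str, torch.optim.Optimizer]):
        self.optim_map = optim_map

    def get_optim_key(self, key: str) -> Any:
        # All wrapped optimizers are configured identically, so reading the
        # value from any one of them is sufficient.
        return list(self.optim_map.values())[0].param_groups[0][key]

# Usage at the logging call site, replacing the explicit indexing above:
#   "lr": self._optim_ckpt_wrapper.get_optim_key("lr"),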

@rohan-varma (Member Author) commented, quoting @RdoubleA:

> If there are opportunities to make the optimizer wrapper have an identical API to a normal optimizer, we can hide a lot of logic within the wrapper class and we won't need to check if self._optimizer_in_bwd in so many locations. I'm worried this will confuse new users, especially since this is one of the main flagship recipes.

So I'm actually reluctant to have the optimizer wrapper expose an identical API to a normal optimizer. That would mean it needs to implement a .step(), which would likely be a no-op since the step actually happens in the backward pass. In general I'm a bit reluctant to add these sorts of pass-through abstractions that are just no-ops, since the user needs to click through the implementation to find that it doesn't do anything (as expected).

I do agree that all the if/else'ing here is not ideal.

Comment on lines 12 to 19
# --config llama2/7B_full_single_device \
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune --nnodes 1 --nproc_per_node 1 full_finetune_single_device \
# --config llama2/7B_full_single_device \
# checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
Contributor:

missing _low_memory in these

batch_size: 2
epochs: 1
optimizer:
_component_: bitsandbytes.optim.PagedAdamW
Contributor:

bitsandbytes is not officially in our core requirements, right? What's our plan for handling things gracefully here?

Member Author:

Good call out, I think we'll simply add bitsandbytes to our core deps.

Member Author:

Added bitsandbytes as a core dep

Member Author:

Removed bnb as a core dep after discussion
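
Since bitsandbytes ended up as an optional dependency, graceful handling could look something like the following. This is a hedged sketch under that assumption, not necessarily the check torchtune actually ships; the helper name and message are illustrative.

import importlib.util

def check_optional_dependency(module_name: str, hint: str = "") -> None:
    # Fail early with an actionable message when a config requests a component
    # from a package that is not installed (e.g. bitsandbytes.optim.PagedAdamW).
    if importlib.util.find_spec(module_name) is None:
        raise ModuleNotFoundError(
            f"This config requires '{module_name}', which is not installed. "
            f"Install it with: pip install {module_name}. {hint}"
        )

# Example: before instantiating the optimizer from the config
# check_optional_dependency("bitsandbytes", "Needed for PagedAdamW / 8-bit Adam.")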

Contributor @ebsmothers left a comment:

Overall the changes look good and the memory reduction is great. Agree with your point in the summary, we should definitely put together a quick tutorial on some of the memory-saving tricks being applied here.

Two asks before stamping: (1) figure out a plan for the bitsandbytes dep (right now we are not doing any checks anywhere, as far as I can tell), and (2) tests 😃

@@ -15,6 +15,101 @@ class ModelType(Enum):
LLAMA2 = "llama2"


class OptimizerInBackwardWrapper:

Contributor @ebsmothers left a comment:

Looks great! One comment on the recipe test, otherwise just a couple nits

@rohan-varma merged commit 5cbc196 into main on Mar 29, 2024
20 checks passed
@chauhang (Contributor) commented Mar 29, 2024:

@rohan-varma Is there a way to achieve the same using torchao? BnB support on AMD requires complex setup, compile from source did not work for me.

tcapelle pushed a commit to tcapelle/torchtune that referenced this pull request Apr 5, 2024
@joecummings joecummings deleted the full_ft_mem branch April 11, 2024 15:40