Full finetune < 16GB #527
Conversation
@@ -32,8 +32,9 @@ resume_from_checkpoint: False
 batch_size: 2
 epochs: 3
 optimizer:
-  _component_: torch.optim.SGD
+  _component_: bitsandbytes.optim.PagedAdamW
need to add this to requirements, docs, etc
  lr: 2e-5
optimizer_in_bwd: True
Will likely ship this as False and add docs / tutorials since we don't support loading in optimizer states yet.
Nvm, we do support it now.
torchtune/utils/checkpoint.py (Outdated)
@@ -22,6 +22,59 @@
)


class OptimizerInBackwardWrapper:
I'd move this to _checkpointer/_checkpoint_utils.py since I plan to eventually nuke this file
If there are opportunities to make the optimizer wrapper have identical APIs to a normal optimizer, we can hide a lot of logic within the wrapper class and we won't need to check if self._optimizer_in_bwd in so many locations. I'm worried this will confuse new users, especially since this is one of the main flagship recipes.
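For reference, a minimal sketch of the kind of drop-in wrapper being suggested here. This is hypothetical, and assumes step and zero_grad become no-ops (since stepping already happens inside the backward hooks) and that the wrapper keeps a mapping from parameter name to its per-parameter optimizer:

```python
class OptimizerInBackwardWrapper:
    """Optimizer-like facade so recipes don't have to branch on self._optimizer_in_bwd."""

    def __init__(self, optim_map: dict):
        # optim_map: parameter name -> optimizer that owns only that parameter
        self.optim_map = optim_map

    def step(self) -> None:
        # No-op: each per-parameter optimizer already steps inside its
        # post-accumulate-grad hook during backward.
        pass

    def zero_grad(self) -> None:
        # No-op: gradients are zeroed right after each in-backward step.
        pass

    def state_dict(self) -> dict:
        return {name: opt.state_dict() for name, opt in self.optim_map.items()}

    def load_state_dict(self, state_dict: dict) -> None:
        for name, opt in self.optim_map.items():
            opt.load_state_dict(state_dict[name])
```

(Whether silently no-op'ing step/zero_grad is actually desirable is debated further down in this thread.)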
_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}
_component_: torchtune.utils.metric_logging.WandBLogger
project: foo
if this is user facing, let's provide a real project name?
# Training cfg
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

# TODO: find a better place / way to perform validation of args that don't yet
# compose with each other.
After the dataclasses were removed we lost a convenient place to perform validation. My thinking is we provide a utility function (or functions) that does this for users, which they just import in their recipe - but curious about your thoughts.
That could be one approach; a possible downside is that the helper becomes a monolithic dumping ground where we're checking various configs and it turns into a large swath of if statements. On the other hand, if we don't do something like this we'll just have it spelled out in each recipe, which will increase code bloat and maintenance overhead (we would have to copy things into each recipe).
Some piece of this is just not trying to do too much in a single recipe, right? This is part of the reason we split single-device and distributed recipes to begin with.
We can still do config validation, just have to check that fields are defined instead of naively just checking values of fields. Personally I would be in favor of some config validation utilities defined on a per-recipe basis under configs/ somewhere, but coupled with clear documentation in the recipe class of which intersections of features are not supported.
This makes sense. I've added some documentation in the recipe class for now. @RdoubleA, let's chat about Evan's suggestion here?
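As a concrete illustration of the utility-function idea discussed above, a minimal sketch. The function name is hypothetical and cfg is assumed to be the recipe's config object; the gradient-accumulation check is just one example of two features that don't compose:

```python
def validate_full_finetune_cfg(cfg) -> None:
    # Hypothetical per-recipe validation helper: fail fast on feature
    # combinations the recipe does not support, instead of erroring
    # deep inside training.
    if cfg.optimizer_in_bwd and cfg.gradient_accumulation_steps > 1:
        raise ValueError(
            "Gradient accumulation is not supported with optimizer_in_bwd=True, "
            "since each parameter's optimizer steps during the backward pass."
        )
```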
log.info("Optimizer is initialized.")
return optimizer
if optimizer_in_bwd:
    self._optimizer_in_bwd = True
set this in init?
    optim_dict[param].zero_grad()

for p in self._model.parameters():
    p.register_post_accumulate_grad_hook(optim_step)
I would say put the optimizer-in-backward logic in a separate utility; it uses some non-intuitive logic that may confuse users.
I agree and see the reasoning in general, but IMO a downside of something like a utils.setup_optim_in_backward(optim_config, model) is that there is a side effect of hooks getting registered on the model. I want things that modify state to be as explicit as possible, so I could do something like register_optimizer_in_backward_hooks and make_optim_checkpoint_wrapper - more utilities / components than a monolithic thing that configures the entire optimizer in backward. @kartikayk, @ebsmothers what do you think?
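A sketch of what a register_optimizer_in_backward_hooks-style utility could look like (hypothetical signature, following the pattern in the PyTorch optimizer-step-in-backward tutorial linked in the PR summary):

```python
import torch


def register_optimizer_in_backward_hooks(
    model: torch.nn.Module, optim_cls, **optim_kwargs
) -> dict:
    # One optimizer per parameter; each steps as soon as that parameter's
    # gradient has been accumulated, so the full set of gradients never
    # needs to be resident in memory at once.
    optim_dict = {p: optim_cls([p], **optim_kwargs) for p in model.parameters()}

    def optim_step(param) -> None:
        optim_dict[param].step()
        optim_dict[param].zero_grad()

    for p in model.parameters():
        p.register_post_accumulate_grad_hook(optim_step)

    # Returned so a separate make_optim_checkpoint_wrapper-style helper can
    # build the checkpointing wrapper from it.
    return optim_dict
```

Usage would look something like `optim_dict = register_optimizer_in_backward_hooks(model, torch.optim.SGD, lr=2e-5)`.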
raise RuntimeError(
    "Failed loading in-backward optimizer checkpoints."
    "Please make sure run being restored from was using in-backward optimizer."
    f"Original error {str(e)}"
the from e should take care of surfacing the original error
I actually tried this, and it didn't - i.e. I didn't see the exception from e in this error. Not sure if I just messed something up.
Let me continue checking this though
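For reference on the from e question, a minimal standalone repro (not from the PR): when the chained exception propagates uncaught, Python prints the original KeyError traceback, then "The above exception was the direct cause of the following exception", then the RuntimeError. One possible reason it wasn't visible here is that the chain only appears in the full traceback; str() of the new exception does not include the original message, which is what the explicit f"Original error {str(e)}" works around.

```python
try:
    {}["opt_state"]  # stand-in for the failing optimizer-state lookup
except KeyError as e:
    raise RuntimeError("Failed loading in-backward optimizer checkpoints.") from e
```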
if not self._optimizer_in_bwd:
    ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict()
else:
    ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict()
if the APIs will be the same for the optim ckpt wrapper, you could just call it self._optimizer and remove the if/else
I'm personally not an advocate for this due to the reason explained in the other comment. If folks feel like this is better UX though, I'm happy to just add it.
I'm inclined to agree with @RdoubleA. Definitely don't want to hide the actual behavior too much, but in this case we already have state_dict and other APIs defined, so we might as well cut down on branching (aside from where it's really needed, like in train). But honestly I'm not taking a super strong stance here, so fine either way.
Makes sense - let's discuss in follow up PRs.
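If that follow-up lands and the wrapper exposes the same state_dict() API (as in the sketch earlier in the thread), the branch above would collapse to a single line; this is hypothetical, reusing the names from the diff:

```python
# Works whether self._optimizer is a torch.optim.Optimizer or the in-backward wrapper
ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict()
```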
"lr": self._optimizer.param_groups[0]["lr"], | ||
# NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently | ||
# true since we don't expose the ability to configure this yet. | ||
"lr": list(self._optim_ckpt_wrapper.optim_map.values())[ |
can you make a method or property in the wrapper that retrieves this so this can be cleaner?
Great idea!
So I'm actually reluctant to have the optimizer wrapper have identical APIs to a normal optimizer, since this would mean it would need to implement the full optimizer interface. I do agree that all the if/else'ing here is not ideal.
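Extending the wrapper sketch from earlier in the thread, the property @RdoubleA asks for might look like this (hypothetical), which would let the metric-logging line read self._optim_ckpt_wrapper.lr:

```python
    @property
    def lr(self) -> float:
        # NOTE: assumes all per-parameter optimizers share the same LR, which
        # holds today since a single LR is set in the config.
        return next(iter(self.optim_map.values())).param_groups[0]["lr"]
```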
#   --config llama2/7B_full_single_device \
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune --nnodes 1 --nproc_per_node 1 full_finetune_single_device \
#   --config llama2/7B_full_single_device \
#   checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
missing _low_memory in these
batch_size: 2
epochs: 1
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW
bitsandbytes is not officially in our core requirements, right? What's our plan for handling things gracefully here?
Good call out, I think we'll simply add bitsandbytes to our core deps.
Added bitsandbytes as a core dep
Removed bnb as a core dep after discussion
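Since bnb ended up not being a core dependency, one common pattern for failing gracefully when the low-memory config is used without it installed (a sketch, not necessarily what the recipe ended up doing):

```python
try:
    import bitsandbytes  # noqa: F401  (optional dependency for PagedAdamW)
except ImportError as e:
    raise ImportError(
        "This config uses bitsandbytes.optim.PagedAdamW but bitsandbytes is not "
        "installed. Install it with `pip install bitsandbytes`, or override the "
        "optimizer, e.g. optimizer._component_=torch.optim.AdamW."
    ) from e
```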
Overall the changes look good and the memory reduction is great. Agree with your point in the summary; we should definitely put together a quick tutorial on some of the memory-saving tricks being applied here.
Two asks before stamping: (1) figure out a plan for bitsandbytes dep (right now we are not doing any checks anywhere as far as I can tell), and (2) tests 😃
@@ -15,6 +15,101 @@ class ModelType(Enum):
    LLAMA2 = "llama2"


class OptimizerInBackwardWrapper:
cc @janeyx99
Looks great! One comment on the recipe test, otherwise just a couple nits
@rohan-varma Is there a way to achieve the same using torchao? BnB support on AMD requires complex setup; compiling from source did not work for me.
Context

Changelog
- Changes to _setup_optimizer.
- OptimInBackwardWrapper to checkpoint the optimizer states when running in backward + unittests.
- Handling of step and zero_grad when using optimizer in backward.

Test plan
- tune full_finetune_single_device --config recipes/configs/full_finetune_single_device.yaml &> out &
- Memory does not spike above 16GB reserved.
- Loss curves:
- Eval result:
- Checkpoint save/load of optimizer states:
  - tune full_finetune_single_device --config recipes/configs/full_finetune_single_device.yaml max_steps_per_epoch=1 device=cuda:1
  - Set resume_from_checkpoint: True and recipe_checkpoint: /tmp/llama2/recipe_state.pt, then run tune full_finetune_single_device --config recipes/configs/full_finetune_single_device.yaml device=cuda:1 --> training resume works.

Discussion points
- Swapping torch.optim.AdamW for bitsandbytes.optim.PagedAdamW, for example. Can definitely have some documentation around this, but wondering if we need to do anything further for configuration.
- Whether to use the _apply_optimizer_in_backward API + distributed_state_dict, though the downsides are that the latter API is only available in more recent versions of torch, and the former API is private. OTOH, implementing checkpointing using the technique in https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html is not well understood from a technical perspective (it hasn't been done before, and there are tricky things to figure out, like how to key the optimizers by a UUID instead of a numerical index, which won't work when there are multiple optimizers). Using the former implementation is also useful since it is better tested and designed than building something ourselves.

Follow up work