Activation offloading for fullfinetuning + fix tied embedding #1847
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1847
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 7853938 with merge base d3039da.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
# training attributes
self._enable_activation_checkpointing = cfg.enable_activation_checkpointing
self._enable_activation_offloading = cfg.get(
    "enable_activation_offloading", False
)
if self._enable_activation_offloading and self._device.type != "cuda":
    raise RuntimeError(
        "enable_activation_offloading should only be enabled for training on CUDA"
    )
Removed from __init__. This is handled in setup_model.
Wait, I'm confused.. we are still doing this in __init__, no?
recipes/full_finetune_distributed.py (outdated)
opt_state_dict=(
    checkpoint_dict[training.OPT_KEY]
    if self._resume_from_checkpoint
    else None
),
pre-commit hook
recipes/full_finetune_distributed.py (outdated)
collate_fn=(
    partial(
        collate_fn,
        padding_idx=self._tokenizer.pad_id,
        ignore_idx=self._loss_fn.ignore_index,
    )
    if not packed
    else padded_collate_packed
),
pre-commit hook
recipes/full_finetune_distributed.py (outdated)
if enable_activation_offloading:
    if self._device.type != "cuda":
        raise RuntimeError(
            "enable_activation_offloading should only be True for training on CUDA"
        )
    if not enable_activation_checkpointing:
        raise RuntimeError(
            "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
        )
elif enable_activation_checkpointing:
    log.info(
        "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
        "Enabling activation offloading should reduce memory further."
    )
The only thing that I don't like about this is that it could fail much faster if we added it to the init. But I like that the checks are near the code where it matters, so I am not sure which one to pick.
I think failing much faster is better
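For concreteness, the "fail fast" variant under discussion could be a small check run from __init__ before any checkpoint loading or model setup. This is just a sketch with a hypothetical helper name, not part of this PR:

import torch

def _validate_activation_offloading_cfg(
    device_type: str,
    enable_activation_checkpointing: bool,
    enable_activation_offloading: bool,
) -> None:
    # Hypothetical helper: called from __init__ so misconfiguration fails
    # before any expensive setup work happens.
    if enable_activation_offloading:
        if device_type != "cuda":
            raise RuntimeError(
                "enable_activation_offloading should only be True for training on CUDA"
            )
        if not enable_activation_checkpointing:
            raise RuntimeError(
                "enable_activation_offloading should only be True when "
                "enable_activation_checkpointing is True"
            )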
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
    model, enable_activation_offloading
)
Added a function to handle the no-op case / getting the context manager.
recipes/full_finetune_distributed.py (outdated)
if not enable_activation_checkpointing:
    raise RuntimeError(
        "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
    )
No point in running offloading when AC is off. It's extremely slow.
@@ -9,13 +9,33 @@
import torch.nn.functional as F


class Linear(nn.Module):
The docstring explains why I had to add this.
Sorry I'm a bit confused by this change on two fronts:
(1) Does this not change the key names in the state dict?
(2) Now that we have a module again, how do we not wind up right back where we started with to_empty?
(1) No. The Linear nn.Module doesn't have any weights. The weight is passed in the forward only.

class Linear(nn.Module):
    def forward(self, x, weight):
        return F.linear(x, weight)

TiedLinear is still a regular Python class, and the key name is still model.TiedLinear.tok_embedding.weight
(2) I tested it with FSDP (ran the script and added an assertion in the training loop), and confirmed that the memory pointers are the same in model.tok_embedding.weight and model.output.weight. So things are fine. Is that what you were referencing too?
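To make that concrete, here is roughly the shape of the pattern as I understand it; a sketch, not the exact torchtune code, with the names mirroring the PR's modules:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Linear(nn.Module):
    # Stateless: no parameters of its own, so it adds no new state-dict keys,
    # but being an nn.Module means forward hooks (e.g. the NoOpManager ones)
    # can be registered on it.
    def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        return F.linear(x, weight)


class TiedLinear:
    # Still a plain Python class; it just borrows the embedding's weight at call time,
    # so the only weight in the state dict remains the tied embedding's.
    def __init__(self, tied_module: nn.Module):
        self.tied_module = tied_module  # e.g. the token embedding
        self.linear = Linear()

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x, self.tied_module.weight)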
Will add the option to all the configs for full finetuning
@@ -83,7 +83,7 @@ dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: True
enable_activation_offloading: True # True reduces memory
nit but do we want to say "True reduces memory" even if we already set it to True?
Also I'm curious why we choose to enable it here (other than the fact that it was already enabled). Seems like the general rule of thumb is to enable for low-memory configs? But this one I'm not clear on
Great catch! I intended to only set it to True in low-memory configs. Not sure how this one happened.
With that being said, I think that we should keep the comment, even when it's True already. Do you disagree?
@@ -68,6 +68,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False # True reduces memory
Can you remind me.. did we test activation offloading on the vision models?
I didn't include it in my tests. Let me do it tomorrow.
recipes/full_finetune_distributed.py (outdated)
back during the backward pass. As always, there is a tradeoff--these savings in memory can
come at the cost of training performance and CPU resources. To recover some runtime cost,
we've added an option to enable offloading on a different stream to permit overlapping with
the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907
Minor point but in two days 2.5 will be stable so we may not need this comment about nightlies by the time this lands anyways (fine to keep it in, just pointing it out)
agreed
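For reference, a minimal sketch of what that docstring describes, assuming the stream-based offloading is exposed as a use_streams flag on OffloadActivations (the flag name and the toy model here are my assumptions):

import torch
import torch.nn as nn
from torchtune.training import OffloadActivations

# Toy model and input purely for illustration; requires a CUDA device.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda()
x = torch.randn(4, 16, device="cuda", requires_grad=True)

# Offload saved activations to CPU on a separate stream so the copies can
# overlap with compute (use_streams flag name is an assumption on my part).
with OffloadActivations(use_streams=True):
    out = model(x)  # activations are stashed on CPU as they are saved
out.sum().backward()  # activations are brought back during the backward pass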
@@ -34,6 +34,7 @@ def _get_test_config_overrides(self):
"batch_size=4",
"dtype=fp32",
"enable_activation_checkpointing=False",
"enable_activation_offloading=False",
Should we set it to True in at least one of our test cases somewhere?
We should. I will see if I can quickly add it tomorrow.
if enable_activation_offloading:
    activations_handling_ctx = OffloadActivations()

    # Below is our hack to disable offloading the last output Linear in every
    # step, as the cost for offloading the activation and then soon after bringing
    # it back is expensive. Moreover, due to heuristics in our streaming API,
    # we actually use more memory if we offload it as it interferes with chunkedCE.
    output_head_detected = False
    if hasattr(model, "output"):
        noop_ctx = NoOpManager()
        if isinstance(model.output, nn.Module):
            model.output.register_forward_pre_hook(
                lambda *args: noop_ctx.__enter__()
            )
            model.output.register_forward_hook(
                lambda *args: noop_ctx.__exit__(), always_call=True
            )
            output_head_detected = True
        elif isinstance(model.output, TiedLinear):
            model.output.linear.register_forward_pre_hook(
                lambda *args: noop_ctx.__enter__()
            )
            model.output.linear.register_forward_hook(
                lambda *args: noop_ctx.__exit__(), always_call=True
            )
            output_head_detected = True

    elif hasattr(model, "decoder"):
        noop_ctx = NoOpManager()
        if isinstance(model.decoder, nn.Module):
            model.decoder.output.register_forward_pre_hook(
                lambda *args: noop_ctx.__enter__()
            )
            model.decoder.output.register_forward_hook(
                lambda *args: noop_ctx.__exit__(), always_call=True
            )
            output_head_detected = True

    if not output_head_detected:
        log.warning(
            "During activation offloading, no output head was detected. "
            "If your model has an output head, it will be offloaded. "
            "This usually greatly slows training, given the large vocabulary size. "
            "To change this behavior, set your output head as model.output and make it "
            "an nn.Module."
        )

else:
    activations_handling_ctx = contextlib.nullcontext()
This covers all of our cases, but I don't like too much how it looks. In the future, change it to identify tensor size, and if larger than a threshold, make it a no-op.
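A rough sketch of that follow-up idea (the helper name and threshold are hypothetical, not part of this PR), reusing the same NoOpManager hook pattern as above, with per-module parameter count standing in for "tensor size":

import torch.nn as nn
# Assumes NoOpManager is importable from torchtune.training.
from torchtune.training import NoOpManager


def register_noop_for_large_modules(
    model: nn.Module, noop_ctx: NoOpManager, threshold_numel: int = 100_000_000
) -> None:
    # Skip offloading for any module whose own parameters exceed a size
    # threshold, instead of special-casing model.output.
    for module in model.modules():
        numel = sum(p.numel() for p in module.parameters(recurse=False))
        if numel > threshold_numel:
            module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
            module.register_forward_hook(
                lambda *args: noop_ctx.__exit__(), always_call=True
            )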
@@ -173,6 +175,7 @@ def test_training_state_on_resume(
resume_from_checkpoint=True \
metric_logger.filename={log_file} \
enable_activation_checkpointing=True \
enable_activation_offloading=True \
I enabled it for some tests in this file only. It tests LoRA and distributed. Ideally, we should have it vs. many other parameters, like compile, vision, and tied embeddings. I won't address those in this PR. This should be part of the testing improvement, IMO.
@@ -569,7 +613,8 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")

logits = self._model(**batch)
with self.activations_handling_ctx:
Curiosity brings me here. If I gather correctly, every model forward pass for which we want to offload activations needs to sit inside this context manager?
correct!
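For anyone landing here later, a minimal usage sketch of that pattern; get_act_offloading_ctx_manager is the helper added in this PR, while the tiny model and batch below are made up for illustration and need a CUDA device:

import torch
import torch.nn as nn
from torchtune import training

# Toy stand-ins for the recipe's model and batch.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)).cuda()
batch = torch.randn(2, 8, device="cuda", requires_grad=True)

ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading=True)

# Only forward passes run inside the context manager have their activations offloaded.
with ctx:
    out = model(batch)

loss = out.sum()
loss.backward()  # offloaded activations are fetched back on demand during backward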
elif hasattr(model, "decoder"):
    # TODO: it errors out. Needs debugging.
    # assert_size_stride(rsqrt_2, (4, 32, 1601, 1), (52224, 1632, 1, 1))
    # AssertionError: expected size 4==4, stride 51232==52224 at dim=0;
    # expected size 32==32, stride 1601==1632 at dim=1
    raise NotImplementedError(
        "Multimodal model does not support activation offloading yet. Please set enable_activation_offloading=False"
    )
    # if isinstance(model.decoder, nn.Module):
    #     model.decoder.output.register_forward_pre_hook(
    #         lambda *args: noop_ctx.__enter__()
    #     )
    #     model.decoder.output.register_forward_hook(
    #         lambda *args: noop_ctx.__exit__(), always_call=True
    #     )
    #     output_head_detected = True
This needs debugging in a follow-up PR.
haha, can you remove the commented-out code?
We're not savages, remove the commented-out code.
Context
What is the purpose of this PR? Is it to
Changelog
Test plan