
Activation offloading for full finetuning + fix tied embedding #1847

Merged: 21 commits into pytorch:main on Oct 30, 2024

Conversation

felipemello1 (Contributor)

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Changelog

  • Added activation offloading for single/distributed full finetuning
  • Found that we didn't address the tied embedding issue. Fixed it by adding a dummy nn.Module so that tensor hooks work with it
  • Moved the context manager logic to a function
  • Made the recipes more aligned

Test plan

[Test results were shared as screenshots; not reproduced here.]

pytorch-bot bot commented Oct 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1847

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7853938 with merge base d3039da:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 15, 2024
Comment on lines -154 to -168
# training attributes
self._enable_activation_checkpointing = cfg.enable_activation_checkpointing
self._enable_activation_offloading = cfg.get(
    "enable_activation_offloading", False
)
if self._enable_activation_offloading and self._device.type != "cuda":
    raise RuntimeError(
        "enable_activation_offloading should only be enabled for training on CUDA"
    )
Contributor Author:

Removed from __init__; this is handled in setup_model.

Contributor:

Wait, I'm confused... we are still doing this in __init__, no?

Contributor Author:

I added it back after this:

[screenshot omitted]

Comment on lines 242 to 246
opt_state_dict=(
    checkpoint_dict[training.OPT_KEY]
    if self._resume_from_checkpoint
    else None
),
Contributor Author:

Pre-commit hook formatting change.

Comment on lines 552 to 560
collate_fn=(
    partial(
        collate_fn,
        padding_idx=self._tokenizer.pad_id,
        ignore_idx=self._loss_fn.ignore_index,
    )
    if not packed
    else padded_collate_packed
),
Contributor Author:

Pre-commit hook formatting change.

Comment on lines 464 to 477
if enable_activation_offloading:
    if self._device.type != "cuda":
        raise RuntimeError(
            "enable_activation_offloading should only be True for training on CUDA"
        )
    if not enable_activation_checkpointing:
        raise RuntimeError(
            "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
        )
elif enable_activation_checkpointing:
    log.info(
        "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
        "Enabling activation offloading should reduce memory further."
    )
Contributor Author:

The only thing I don't like about this is that it could fail much faster if we added it to __init__. But I like that the checks are near the code where they matter, so I'm not sure which one to pick.

Contributor:

I think failing much faster is better

Comment on lines +479 to +481
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
    model, enable_activation_offloading
)
Contributor Author:

Added a function that handles both getting the offloading context and returning a no-op when offloading is disabled.
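
For readers skimming the thread, a rough sketch of the contract such a helper follows (the actual implementation is quoted later in this conversation). The function name and the save_on_cpu stand-in below are illustrative assumptions, not torchtune's code:

import contextlib

import torch
import torch.nn as nn


def get_offloading_ctx_sketch(model: nn.Module, enable_activation_offloading: bool):
    # Illustrative only: always return a context manager the recipe can wrap the
    # forward pass in, so the call site stays branch-free.
    # `model` is only used by the real helper (to register output-head hooks).
    if not enable_activation_offloading:
        return contextlib.nullcontext()
    # Enabled path: PyTorch's built-in save_on_cpu hooks stand in here for
    # torchtune's OffloadActivations, which additionally skips the output head.
    return torch.autograd.graph.save_on_cpu(pin_memory=False)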

Comment on lines 469 to 471
if not enable_activation_checkpointing:
    raise RuntimeError(
        "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
Contributor Author:

There's no point in running offloading when activation checkpointing is off; it's extremely slow.

@@ -9,13 +9,33 @@
import torch.nn.functional as F


class Linear(nn.Module):
Contributor Author:

The docstring explains why I had to add this.

Contributor:

Sorry I'm a bit confused by this change on two fronts:

(1) Does this not change the key names in the state dict?
(2) Now that we have a module again, how do we not wind up right back where we started with to_empty?

felipemello1 (Contributor Author) commented Oct 16, 2024:

(1) No. The Linear nn.Module doesn't have any weights; the weight is only passed into forward:

class Linear(nn.Module):
    def forward(self, x, weight):
        return F.linear(x, weight)

TiedLinear is still a regular Python class, and the key name is still model.TiedLinear.tok_embedding.weight.

(2) I tested it with FSDP (ran the script and added an assertion in the training loop), and confirmed that the memory pointers are the same in model.tok_embedding.weight and model.output.weight. So things are fine. Is that what you were referencing too?
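
To make the fix concrete, here is a small self-contained reconstruction of the idea; names mirror the discussion above, but the details are an approximation of the torchtune code, not the exact source:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Linear(nn.Module):
    # Weight-less module: the weight is passed in at call time, so there are no
    # parameters to (re)initialize, but forward hooks can still attach to it.
    def forward(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        return F.linear(x, weight)


class TiedLinear:
    # Plain Python class that reuses the token-embedding weight as the output head.
    def __init__(self, tied_module: nn.Module):
        self.tied_module = tied_module
        self.linear = Linear()

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x, self.tied_module.weight)


# Quick check: logits get the vocab size, and the dummy Linear owns no parameters,
# so state-dict keys and FSDP initialization are unaffected.
emb = nn.Embedding(10, 4)
head = TiedLinear(emb)
logits = head(emb(torch.randint(0, 10, (2, 3))))
assert logits.shape == (2, 3, 10)
assert len(list(head.linear.parameters())) == 0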

felipemello1 (Contributor Author):

Will add the option to all the configs for full finetuning.

@joecummings mentioned this pull request on Oct 15, 2024.
@@ -83,7 +83,7 @@ dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: True
enable_activation_offloading: True # True reduces memory
Contributor:

nit but do we want to say "True reduces memory" even if we already set it to True?

Contributor:

Also I'm curious why we choose to enable it here (other than the fact that it was already enabled). Seems like the general rule of thumb is to enable for low-memory configs? But this one I'm not clear on

Contributor Author:

Great catch! I intended to only set it to True in low-memory configs. Not sure how this one happened.

With that being said, I think that we should keep the comment, even when it's already True. Do you disagree?

@@ -68,6 +68,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False # True reduces memory
Contributor:

Can you remind me.. did we test activation offloading on the vision models?

Contributor Author:

I didn't include it in my tests. Let me do it tomorrow.

back during the backward pass. As always, there is a tradeoff--these savings in memory can
come at the cost of training performance and CPU resources. To recover some runtime cost,
we've added an option to enable offloading on a different stream to permit overlapping with
the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907
Contributor:

Minor point but in two days 2.5 will be stable so we may not need this comment about nightlies by the time this lands anyways (fine to keep it in, just pointing it out)

Contributor Author:

agreed


@@ -34,6 +34,7 @@ def _get_test_config_overrides(self):
"batch_size=4",
"dtype=fp32",
"enable_activation_checkpointing=False",
"enable_activation_offloading=False",
Contributor:

Should we set it to True in at least one of our test cases somewhere?

Contributor Author:

We should. I will see if I can quickly add it tomorrow.


Comment on lines 353 to 401
if enable_activation_offloading:
    activations_handling_ctx = OffloadActivations()

    # Below is our hack to disable offloading the last output Linear in every
    # step, as the cost for offloading the activation and then soon after bringing
    # it back is expensive. Moreover, due to heuristics in our streaming API,
    # we actually use more memory if we offload it as it interferes with chunkedCE.
    output_head_detected = False
    if hasattr(model, "output"):
        noop_ctx = NoOpManager()
        if isinstance(model.output, nn.Module):
            model.output.register_forward_pre_hook(
                lambda *args: noop_ctx.__enter__()
            )
            model.output.register_forward_hook(
                lambda *args: noop_ctx.__exit__(), always_call=True
            )
            output_head_detected = True
        elif isinstance(model.output, TiedLinear):
            model.output.linear.register_forward_pre_hook(
                lambda *args: noop_ctx.__enter__()
            )
            model.output.linear.register_forward_hook(
                lambda *args: noop_ctx.__exit__(), always_call=True
            )
            output_head_detected = True

    elif hasattr(model, "decoder"):
        noop_ctx = NoOpManager()
        if isinstance(model.decoder, nn.Module):
            model.decoder.output.register_forward_pre_hook(
                lambda *args: noop_ctx.__enter__()
            )
            model.decoder.output.register_forward_hook(
                lambda *args: noop_ctx.__exit__(), always_call=True
            )
            output_head_detected = True

    if not output_head_detected:
        log.warning(
            "During activation offloading, no output head was detected. "
            "If your model has an output head, it will be offloaded. "
            "This usually greatly slows training, given the large vocabulary size. "
            "To change this behavior, set your output head as model.output and make it "
            "an nn.Module."
        )

else:
    activations_handling_ctx = contextlib.nullcontext()
felipemello1 (Contributor Author) commented Oct 16, 2024:

This covers all of our cases, but I don't like how it looks too much. In the future, change it to identify the tensor size and, if it's larger than a threshold, make it a no-op.
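
For context, a stripped-down, hypothetical illustration of the hook trick above. It assumes NoOpManager behaves like a saved-tensors context with identity pack/unpack hooks, and uses PyTorch's built-in save_on_cpu in place of OffloadActivations:

import torch
import torch.nn as nn


class NoOpSavedTensors(torch.autograd.graph.saved_tensors_hooks):
    # Stand-in for NoOpManager: while active, tensors saved for backward are kept
    # as-is, overriding the enclosing offloading context.
    def __init__(self):
        super().__init__(lambda t: t, lambda t: t)


model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 100))
output_head = model[-1]
noop_ctx = NoOpSavedTensors()

# Enter the no-op context right before the output head's forward and leave it right
# after, so the large vocab-sized activation is never offloaded and fetched back.
output_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
output_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)

# Everything saved outside the no-op window is offloaded to CPU and restored in backward.
with torch.autograd.graph.save_on_cpu(pin_memory=False):
    loss = model(torch.randn(2, 8)).sum()
loss.backward()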

@@ -173,6 +175,7 @@ def test_training_state_on_resume(
resume_from_checkpoint=True \
metric_logger.filename={log_file} \
enable_activation_checkpointing=True \
enable_activation_offloading=True \
Contributor Author:

I enabled it for some tests in this file only. It tests LoRA and distributed. Ideally, we should test it against many other parameters, like compile, vision, and tied embeddings. I won't address those in this PR; this should be part of the testing improvement, IMO.

@@ -569,7 +613,8 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")

logits = self._model(**batch)
with self.activations_handling_ctx:
Collaborator:

Curiosity brings me here. If I gather correctly, every model forward pass for which we want to offload activations needs to sit inside this context manager?

Contributor Author:

Correct!
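
To illustrate the answer for anyone skimming: only forward passes run under the context have their saved activations offloaded. A minimal sketch, using PyTorch's built-in save_on_cpu hooks as a stand-in for the offloading context the recipes build:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
batch = torch.randn(4, 16)

# Activations saved for backward inside this block live on CPU and are copied back
# during backward; forwards executed outside the block are unaffected.
with torch.autograd.graph.save_on_cpu(pin_memory=False):
    out = model(batch)

out.sum().backward()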

Comment on lines +404 to +420
elif hasattr(model, "decoder"):
    # TODO: it errors out. Needs debugging.
    # assert_size_stride(rsqrt_2, (4, 32, 1601, 1), (52224, 1632, 1, 1))
    # AssertionError: expected size 4==4, stride 51232==52224 at dim=0;
    # # expected size 32==32, stride 1601==1632 at dim=1
    raise NotImplementedError(
        "Multimodal model does not support activation offloading yet. Please set enable_activation_offloading=False"
    )
    # if isinstance(model.decoder, nn.Module):
    #     model.decoder.output.register_forward_pre_hook(
    #         lambda *args: noop_ctx.__enter__()
    #     )
    #     model.decoder.output.register_forward_hook(
    #         lambda *args: noop_ctx.__exit__(), always_call=True
    #     )
    #     output_head_detected = True

Contributor Author:

This needs debugging in a follow-up PR.

Contributor:

Haha, can you remove the commented-out code?

joecummings (Contributor) left a review comment:

We're not savages, remove the commented-out code.

@felipemello1 felipemello1 merged commit e99b890 into pytorch:main Oct 30, 2024
17 checks passed
@felipemello1 felipemello1 deleted the offloading_single_device branch October 30, 2024 23:53