Prevent OOM during checkpoint save on colab for llama3-8b qlora recipe #1315
Conversation
This is awesome! Left a few basic questions, but overall really excited to see that we'll be able to provide proper end-to-end Colab support with these changes. One high-level comment is that we should think about writing a utility for the portion in ~L520-L550. Then we can gate it behind a config like `low_memory_save` or something like that. But that's more of a UX thing, happy to help out there if needed.
```python
# Construct the full state dict with LoRA weights merged into base LLM weights
merged_state_dict = get_merged_lora_ckpt(
    state_dict,
    rank=self._lora_rank,
    alpha=self._lora_alpha,
    dest_state_dict=dest_state_dict,
)
```
(Another noob q:) So after this, we can save the checkpoint normally (e.g. in the call to `save_checkpoint` on L580), even though we are now saving something like `{"model": dest_state_dict, ... (other stuff)}` to a separate file from the mmapped one?
We should think of `dest_state_dict` as being backed by disk (and yes, we are writing it once to disk in a `torch.save` format in this whole bit). The final `save_checkpoint` on L580 re-saves a new checkpoint with `{"model": merged_state_dict, ... (other stuff)}` to a new checkpoint file. So we are saving `dest_state_dict` twice.

We could potentially do something smarter than what we're doing right now to avoid this "re-save", but the code changes would be more invasive than they are now, given that the checkpointers seem to do some remapping, so I refrained from doing that for a v0. wdyt?
```python
# Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice
# Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys
adapter_key_filter = lambda x: x in self.adapter_params
adapter_state_dict = {
    k: v for k, v in state_dict.items() if adapter_key_filter(k)
}
```
Ideally we want to run the state_dict post hooks only once, so we should reuse the `state_dict` here. This does change the semantics slightly though -- before, `adapter_*.pt` contained weights tagged with CUDA, but now it contains weights tagged with CPU. Not sure whether the old behavior was intended / whether this change is ok (but no CI seems to fail :D). Also, when loading, we use `map_location="cpu"` regardless.
Sorry I missed this comment before. But yeah I think this change makes sense, I don't think there's any reason we need to require CUDA weights. And as you point out since we load on CPU when resuming it was probably never really an issue. Plus not re-running the state dict post hooks is a nice bonus.
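For reference, the resume path mentioned above boils down to something like the following sketch (the file name is illustrative; in the recipe the checkpointer handles this), which is why the device the adapter weights were saved from never matters:

```python
import torch

# map_location="cpu" ignores the device tag stored in the checkpoint, so
# adapter weights saved from CPU or CUDA load identically when resuming.
adapter_state_dict = torch.load("adapter_0.pt", map_location="cpu")
```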
```diff
@@ -23,6 +23,7 @@ model:
   apply_lora_to_output: False
   lora_rank: 8
   lora_alpha: 16
+  low_cpu_ram: False
```
Given that we have to modify the state dict hook, I see why it makes sense to put this in the model config. But it does feel a bit weird to me, since it's not really a property of the model (it's more that the model is a convenient place for us to know that we're gonna have to upcast NF4 tensors).

I wonder if we can instead define a standalone config field, parse it in the recipe with e.g. `low_cpu_ram = cfg.get("low_cpu_ram", False)`, then use that to overwrite the `reparametrize_as_dtype_state_dict_post_hook`. Maybe a bit hacky, but we can at least assert that the expected state dict hook is there before replacing it, to ensure that we aren't adding this onto any old non-QLoRA model. Is that obviously worse? (Mainly I want to avoid our model classes having to know or care about low-level details of how they're gonna be checkpointed.)
Hmm, I think to manually remove and re-register the hook we would need the handle returned by `_register_state_dict_hook`. Is the way I updated this to patch the hook ok with you?
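For context, this is the generic `torch.nn.Module` pattern being referred to (not code from this PR): removing a previously registered state-dict hook requires keeping the handle returned at registration time, which the model builders don't hand back to the recipe.

```python
from functools import partial

# Registration returns a RemovableHandle; without keeping it around, the hook
# cannot later be detached and swapped for the low-RAM variant from the recipe.
handle = model._register_state_dict_hook(
    partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True)
)
# ... later ...
handle.remove()
```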
```python
if sys.platform == "win32":
    raise RuntimeError(
        "low_cpu_ram=True not supported on Windows."
    )
else:
    raise RuntimeError("low_cpu_ram=True requires torch.__version__ >= 2.5.0.dev20240830.")
```
Similar comment here: ideally we can do these checks in the recipe (or in a utility) rather than in the builder of the model
torchtune/modules/common_utils.py
```python
# mmap.MAP_SHARED is not supported on Windows but this change targets colab.
if hasattr(torch.serialization, "skip_data") and not sys.platform == "win32":
```
Is the first half of this check just a proxy for a particular torch version? If so, maybe better to just directly gate on that (with `torch_version_ge` or something).
`torch_version_ge` seems to cause a circular import when imported in this file :/ so I'm just using `torch.__version__`.
Oh yeah, we are working to fix that; `torch.__version__` is good too.
A few more small comments and questions, but otherwise this looks good to go!
```python
if cfg.get("low_cpu_ram", False):
    common_utils._use_low_cpu_ram = True
```
Does this have to be in `recipe_main`? Can we instead do it somewhere inside the recipe class (before the model gets instantiated)? Also, I would add a one-line comment explaining this.
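For illustration, the suggestion could look roughly like the following inside the recipe class; the class and method names mirror `lora_finetune_single_device.py`, but the exact placement is the open question here, so treat this as a sketch rather than the PR's code.

```python
from torchtune.modules import common_utils


class LoRAFinetuneRecipeSingleDevice:
    def setup(self, cfg) -> None:
        # Enable the low-RAM checkpoint-save path before the model is built so
        # that hook registration in the model builder sees the flag.
        if cfg.get("low_cpu_ram", False):
            common_utils._use_low_cpu_ram = True
        # ... existing setup continues (checkpointer, model, optimizer, data, ...)
```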
torchtune/modules/common_utils.py
```python
# mmap.MAP_SHARED is not supported on Windows but this change targets colab.
if torch.__version__ >= "2.5.0.dev20240906" and not sys.platform == "win32":
```
Just a sanity check here: is this if/else just to ensure that no one tries to directly import the `_low_ram_reparametrize_as_dtype_state_dict_post_hook` API in an unsupported environment? Mainly asking because we have the equivalent checks in `_register_reparametrize_state_dict_hooks` now.
You're right, removing this if/else.
```diff
-    model._register_state_dict_hook(
-        partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True)
-    )
+    _register_reparametrize_state_dict_hooks(model)
```
Doesn't have to be done in this PR, but we can think about adding this for other models that would have a similar memory situation when running QLoRA (Llama 3.1 8B is an obvious choice, but there are a handful of other similarly-sized models supported in our repo that could benefit from this)
Will do in a followup.
torchtune/modules/common_utils.py
```python
# Create a state_dict on disk with space reserved for storage bytes
# Then load with mmap and MAP_SHARED (can writeback to disk file)
dest_state_dict_path = "/tmp/fake_state_dict.pt"
with torch.serialization.skip_data(materialize_fake_tensors=True):
```
noob q: what does `materialize_fake_tensors` mean in this context?
It means that FakeTensors in the object passed to `torch.save` will be treated as if they were real tensors. The implication is that `torch.load` will load a real tensor (not a FakeTensor) on the FakeTensor's device, with storage allocated but uninitialized (zeros).
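To make the mechanism concrete, here is a minimal sketch of the reserve-then-mmap trick being discussed. It assumes a torch nightly that provides `torch.serialization.skip_data`; the path, shape, and dtype are illustrative, and the actual hook builds the template from the model's state dict using FakeTensors, which is where `materialize_fake_tensors=True` comes in.

```python
import mmap

import torch

dest_state_dict_path = "/tmp/fake_state_dict.pt"  # illustrative path
template = {"weight": torch.empty(4096, 4096, dtype=torch.bfloat16)}

# Write the checkpoint layout (metadata plus reserved storage bytes) without
# serializing any tensor data.
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(template, dest_state_dict_path)

# MAP_SHARED (unavailable on Windows) lets writes to the mmapped tensors be
# flushed back to the file on disk.
torch.serialization.set_default_mmap_options(mmap.MAP_SHARED)

# The tensors in dest_state_dict are backed by the file on disk, so filling
# them with the upcast weights does not hold the whole state dict in CPU RAM.
dest_state_dict = torch.load(dest_state_dict_path, mmap=True, weights_only=True)
```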
torchtune/modules/common_utils.py
```python
# In-place update the original state_dict object. Although the private state dict
# post hook supports out-of-place behavior, its semantics are actually buggy. We
# eventually want to use the public state_dict post hook, which does not support
# out-of-place behavior.
for k in state_dict.keys():
    state_dict[k] = dest_state_dict[k]
```
I think I have some misunderstanding here. If we inplace update the state dict to the upcasted version of the weights, why won't it cause an OOM?
When we do `sd = torch.load(..., mmap=True)`, the storages of the tensors in `sd` are mmap-backed. `state_dict[k] = dest_state_dict[k]` does not access any pages of the storage of the tensor given by `dest_state_dict[k]`, so the storage is not materialized and no OOM will happen.
Oh yeah nvm I get it now, I was not thinking it through carefully enough. Thanks for the explanation!
Codecov Report

Attention: Patch coverage is …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1315      +/-   ##
==========================================
- Coverage   27.22%   27.18%   -0.04%
==========================================
  Files         286      286
  Lines       13828    13869      +41
==========================================
+ Hits         3764     3770       +6
- Misses      10064    10099      +35
```

☔ View full report in Codecov by Sentry.
Thank you for enabling this, can't wait to put together some torchtune Colab notebooks for our users!
Context
This PR prevents OOM during checkpoint save on Colab for the following recipe:

```
tune run recipes/lora_finetune_single_device.py --config recipes/configs/llama3/8B_qlora_single_device.yaml
```
I do believe it is possible to do better (perf improvements for this Colab case) if we refactor the checkpointing logic, but that would be a more invasive refactor imo; for now this is the most minimally invasive change that unblocks this use case.
Changelog

- Added `_low_ram_reparametrize_as_dtype_state_dict_post_hook` and `_register_reparametrize_state_dict_hooks`, which toggles between the regular reparametrize state_dict hook and this one based on `low_cpu_ram` -- when `low_cpu_ram` is `True`, `_register_reparametrize_state_dict_hooks` registers the `low_ram` version of the hook (see the sketch after this changelog).

Old changelog (keeping around for posterity)

- Set a seed in `8B_qlora_single_device.yaml` to make dataloader samples (and hence weights) deterministic
- [for testing velocity purposes] Changed the number of epochs to 8, reduced `max_steps_per_epoch` --> 20 and `gradient_accumulation_steps` --> 2 to make `save_checkpoint` be called sooner
- [for loss curves] Some changes to lora_finetune_single_device.py to mimic a user doing `resume_from_checkpoint` after each epoch (while preserving the logger)
- [Not to be landed in torchtune, for PoC purposes] Patched `FakeTensor.__reduce_ex__`, which is needed to ensure the `write_record_metadata` utility to create empty checkpoints is called (requires changes in "Prototype changes to create fake checkpoints with empty storages" pytorch#133272). I need to figure out how to land this piece :)
- Skip registration of `reparametrize_as_dtype_state_dict_post_hook` for llama3
- Showed an example of how to do the corresponding save (using mmap to prevent OOM) in `lora_finetune_single_device.py:save_checkpoint`
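A minimal sketch of the toggle described in the first changelog item, using the hook names from this PR; the real implementation in `torchtune/modules/common_utils.py` may differ in signature and defaults, and both hook functions are assumed to be defined earlier in that same module.

```python
from functools import partial

import torch.nn as nn

# Module-level switch, flipped by the recipe when the config sets low_cpu_ram: True.
_use_low_cpu_ram: bool = False


def _register_reparametrize_state_dict_hooks(module: nn.Module) -> None:
    """Attach either the regular or the low-RAM reparametrize post hook."""
    if _use_low_cpu_ram:
        # Added in this PR: writes the upcast bf16 weights into an mmap-backed file.
        hook = _low_ram_reparametrize_as_dtype_state_dict_post_hook
    else:
        # Pre-existing hook: upcasts NF4 weights in CPU RAM.
        hook = reparametrize_as_dtype_state_dict_post_hook
    module._register_state_dict_hook(partial(hook, offload_to_cpu=True))
```

The llama3 QLoRA builder then only needs the single `_register_reparametrize_state_dict_hooks(model)` call shown in the diff earlier in this thread.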
Test plan
Sanity check
Ran the `tune run` command on devgpu (with only the changes in `8B_qlora_single_device.yaml` to set the seed and decrease steps per epoch) and verified with a small snippet that the generated `meta_model_0.pt` is the same before and after the changes in this PR (see the sketch below).
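The exact snippet isn't included here; a minimal version of such a comparison could look like the following (paths are illustrative):

```python
import torch

before = torch.load("before/meta_model_0.pt", map_location="cpu")
after = torch.load("after/meta_model_0.pt", map_location="cpu")

# Same keys, and every tensor exactly equal element-wise.
assert before.keys() == after.keys()
for name in before:
    assert torch.equal(before[name], after[name]), f"mismatch in {name}"
print("checkpoints match")
```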
Verified that Colab does not OOM

https://colab.research.google.com/drive/1y7Az78ATauK7gkewZkcMO3cNgVWm1233?usp=sharing
Loss Curves
The validation was run on commit abdbd7, which has special logic to mimic `resume_from_checkpoint` for each epoch.

Config is per the changes in `recipes/configs/llama3/8B_qlora_single_device.yaml` (8 epochs with 20 steps per epoch and gradient accumulation every 2 steps) on 6cf31b6. For the reloading-checkpoint case I modified `lora_finetune_single_device.py:recipe_main` to mimic `resume_from_checkpoint` after each epoch.

Devgpu:
Colab: