-
Notifications
You must be signed in to change notification settings - Fork 441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prevent OOM during checkpoint save on colab for llama3-8b qlora recipe #1315
Changes from all commits
9194409
c3beb8a
7a825d8
b0b9031
327cb10
923b3b3
14a147b
30f3b1e
731d062
6dffac5
1106e47
2a4f19b
cc4bd24
d17098c
8f1f633
f87ca63
8390868
a119a4d
fc15271
278e89e
f309bca
244896b
7b094df
147b8b4
05250f5
c2829db
8a16c76
91f7d43
b3364f6
645891b
e44507a
31f8a20
c808a60
922c73a
eba1ffa
dc8c6b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -104,3 +104,6 @@ profiler: | |
warmup_steps: 5 | ||
active_steps: 2 | ||
num_cycles: 1 | ||
|
||
# For colab use True | ||
low_cpu_ram: False |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ | |
TransformerSelfAttentionLayer, | ||
) | ||
|
||
from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook | ||
from torchtune.modules.common_utils import _register_reparametrize_state_dict_hooks | ||
|
||
from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear | ||
|
||
|
@@ -256,9 +256,7 @@ def lora_llama3( | |
if quantize_base: | ||
# For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly | ||
# so as to not increase peak memory | ||
model._register_state_dict_hook( | ||
partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True) | ||
) | ||
_register_reparametrize_state_dict_hooks(model) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doesn't have to be done in this PR, but we can think about adding this for other models that would have a similar memory situation when running QLoRA (Llama 3.1 8B is an obvious choice, but there are a handful of other similarly-sized models supported in our repo that could benefit from this) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will do in followup |
||
|
||
return model | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally we want to only run state_dict post hooks once, so we should reuse the
state_dict
this does change the semantic slightly though -- before
adapter_*.pt
contained weights tagged with CUDA, but now it contains weights tagged with CPU.Not sure whether the old behavior was intended/whether this change is ok (but no CI seems to fail :D) Also when loading, we
map_location="cpu"
regardlessThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I missed this comment before. But yeah I think this change makes sense, I don't think there's any reason we need to require CUDA weights. And as you point out since we load on CPU when resuming it was probably never really an issue. Plus not re-running the state dict post hooks is a nice bonus.