fix 3.1 rope init for compile #1544
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1544
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 3d25ee0 with merge base df29d8a.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Neat!
with torch.device(x.device):
    self.rope_init()

raise RuntimeError(
    "RoPE cache is not built. Please call rope_init() first."
)
this feels a bit inconvenient. Do we have to raise the error?

I also think that we should update the other positional embedding to follow the same pattern. They don't have is_cache_built. In the recipes, we should add a check (see the sketch below):

if hasattr(m, "rope_init") and not m.is_cache_built

currently it is just:

if hasattr(m, "rope_init")
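A minimal sketch of the recipe-side check being proposed here, assuming a `model.modules()` traversal and the `is_cache_built` flag from the discussion above; the helper name is hypothetical, and the `getattr` default is an assumption to cover positional embeddings that don't expose the flag:

```python
import torch.nn as nn

def reinit_rope_caches(model: nn.Module) -> None:
    """Re-run RoPE init only where the cache wasn't built during __init__
    (e.g. because the module was constructed on meta device)."""
    for m in model.modules():
        # Modules without an is_cache_built flag are treated as not built,
        # which matches the existing unconditional hasattr check.
        if hasattr(m, "rope_init") and not getattr(m, "is_cache_built", False):
            m.rope_init()
```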
> this feels a bit inconvenient. Do we have to raise the error?

I think it's safest. Was talking with @joecummings about this a bit; previously we had an issue where RoPE was not getting initialized and so there were garbage values left over from the to_empty call. This was extremely hard to debug (Joe can tell you all about that one).

> I also think that we should update the other positional embedding to follow the same pattern.

Updating the other one is fine, but I might save that for now since it has a larger blast radius and isn't blocking anything like this is.

> if hasattr(m, "rope_init") and m.is_cache_built

I think it'd be not m.is_cache_built, right? Either way I don't think there's any real cost to calling it twice. But I plan to move this init logic into a common util soon, so we can make the change as part of that move.
Currently our Llama 3.1 doesn't work with compile because we init the RoPE cache in the first forward. For our single-device recipe we can init the cache as part of `__init__` no problem, but for our distributed recipe we load on meta device. In that case we cannot init the cache in `__init__` because the RoPE scaling factors require data to exist in intermediate tensors (which they won't on meta device). So our distributed recipes currently initialize the RoPE cache directly in the recipe here, after sharding the model.

So we just need a way to skip calling `rope_init` from a meta device context. Actually we've solved this problem hackily before, see e.g. here. This change is basically identical to that, but this time `rope_init` always gets called. Then once the frequency tensor is constructed, we break out early if it's on meta device and rely on the recipe to manually call `rope_init` after FSDP sharding.
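For concreteness, a hedged sketch of the shape of this change, assuming a simplified RoPE module; the real torchtune `Llama3ScaledRoPE` differs in its scaling math and buffer handling, and only the early meta-device return, the `is_cache_built` flag, and the forward-pass error are taken from this PR discussion:

```python
import torch
from torch import nn

class ScaledRoPESketch(nn.Module):
    """Simplified illustration of the init pattern; not the torchtune implementation."""

    def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 500_000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        self.is_cache_built = False
        # rope_init always runs here; on meta device it bails out early below.
        self.rope_init()

    def rope_init(self) -> None:
        theta = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
        )
        # The Llama 3.1 frequency scaling needs real values in this intermediate
        # tensor, which don't exist on meta device, so break out early and rely
        # on the recipe to call rope_init() again after FSDP sharding.
        if theta.is_meta:
            return
        # (Llama 3.1 scaling of theta would happen here.)
        seq_idx = torch.arange(self.max_seq_len, dtype=theta.dtype, device=theta.device)
        idx_theta = torch.einsum("i,j->ij", seq_idx, theta)
        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
        self.register_buffer("cache", cache, persistent=False)
        self.is_cache_built = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fail loudly instead of silently using uninitialized (garbage) values
        # left over from to_empty().
        if not self.is_cache_built:
            raise RuntimeError(
                "RoPE cache is not built. Please call rope_init() first."
            )
        # ... apply rotary embeddings to x using self.cache ...
        return x
```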
Test plan

Apart from green CI, checking that loss curves are the same on a few different configs:
Single-device LoRA 8B
Distributed LoRA 8B
Single-device FFT 8B