
fix 3.1 rope init for compile #1544

Merged: 2 commits merged into pytorch:main from the 31-rope-init branch on Sep 11, 2024
Conversation

@ebsmothers (Contributor) commented on Sep 11, 2024:

Currently our Llama 3.1 doesn't work with compile because we init the RoPE cache in the first forward.

For our single-device recipe we can init the cache as part of __init__ no problem, but for our distributed recipes we load on meta device. In that case we cannot init the cache in __init__ because the RoPE scaling factors require data to exist in intermediate tensors (which won't exist on meta device). So our distributed recipes currently initialize the RoPE cache directly in the recipe here, after sharding the model.

So we just need a way to skip the cache construction when rope_init is called from a meta-device context. We've actually solved this problem hackily before, see e.g. here. This change is basically identical to that, except that rope_init now always gets called: once the frequency tensor is constructed, we break out early if it's on meta device and rely on the recipe to call rope_init manually after FSDP sharding.
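For context, a minimal sketch of the pattern described above, under stated assumptions: this is not the actual torchtune Llama3ScaledRoPE code; the class name, default values, and the cache construction are illustrative, and the Llama 3.1 frequency scaling itself is omitted.

```python
import torch
import torch.nn as nn


class ScaledRoPESketch(nn.Module):
    # Hedged sketch of the init pattern described above, not torchtune's
    # actual implementation.

    def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 500_000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        self.is_cache_built = False
        # Always build the cache at construction time so compile never has to
        # trace a first-forward cache build.
        self.rope_init()

    def rope_init(self) -> None:
        freqs = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
        )
        # Under a meta-device context we cannot compute the 3.1 scaling
        # factors (they need real tensor data), so break out early and let
        # the recipe call rope_init() again after FSDP sharding has
        # materialized the module on a real device.
        if freqs.is_meta:
            return
        # (The Llama 3.1 frequency scaling would be applied to freqs here.)
        seq_idx = torch.arange(self.max_seq_len, dtype=freqs.dtype, device=freqs.device)
        idx_theta = torch.einsum("i,j->ij", seq_idx, freqs)
        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
        self.register_buffer("cache", cache, persistent=False)
        self.is_cache_built = True
```

If the cache was never built (the module was constructed on meta device and the recipe never re-ran rope_init), forward can then raise the RuntimeError shown in the review thread below instead of silently using garbage values.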

Test plan

Apart from green CI, checking that loss curves are the same on a few different configs:

- Single-device LoRA 8B: [screenshot of matching loss curves, 2024-09-11]
- Distributed LoRA 8B: [screenshot of matching loss curves, 2024-09-11]
- Single-device FFT 8B: [screenshot of matching loss curves, 2024-09-11]

pytorch-bot (bot) commented on Sep 11, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1544

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3d25ee0 with merge base df29d8a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Sep 11, 2024
@joecummings (Contributor) left a comment:

Neat!

@ebsmothers marked this pull request as ready for review on September 11, 2024 at 20:31
Review comment on these lines of the diff:

    with torch.device(x.device):
        self.rope_init()

    raise RuntimeError(
        "RoPE cache is not built. Please call rope_init() first."
    )
@felipemello1 (Contributor) commented on Sep 11, 2024:

This feels a bit inconvenient. Do we have to raise the error?

I also think that we should update the other positional embedding to follow the same pattern. They don't have is_cache_built. In the recipes, we should add a check:

    if hasattr(m, init_rope) and not m.is_cache_bult

Currently it is just:

    if hasattr(m, init_rope)

@ebsmothers (Contributor, Author) replied:

> This feels a bit inconvenient. Do we have to raise the error?

I think it's safest. I was talking with @joecummings about this a bit; previously we had an issue where RoPE was not getting initialized, so there were garbage values left over from the to_empty call. This was extremely hard to debug (Joe can tell you all about that one).

> I also think that we should update the other positional embedding to follow the same pattern.

Updating the other one is fine, but I might save that for now since it has a larger blast radius and isn't blocking anything the way this is.

> if hasattr(m, init_rope) and m.is_cache_bult

I think it'd be not m.is_cache_built, right? Either way I don't think there's any real cost to calling it twice. But I plan to move this init logic into a common util soon, so we can make the change as part of that move.
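As an aside, here is a rough sketch of what the recipe-side re-initialization discussed in this thread could look like. The helper name is made up for illustration, and the hasattr / is_cache_built check mirrors the suggestion above rather than torchtune's actual recipe code:

```python
import torch.nn as nn


def reinit_rope_after_sharding(model: nn.Module) -> None:
    # Illustrative helper, not torchtune's actual recipe code: after FSDP
    # sharding has materialized real tensors, rebuild any RoPE caches that
    # were skipped while the model was constructed on the meta device.
    for m in model.modules():
        # Guarding on is_cache_built follows the suggestion in this thread;
        # per the reply above, calling rope_init() twice is also harmless.
        if hasattr(m, "rope_init") and not getattr(m, "is_cache_built", False):
            m.rope_init()
```

A recipe would call something like this once, right after wrapping the model with FSDP and before the first forward pass.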

@ebsmothers merged commit 221031a into pytorch:main on Sep 11, 2024 (17 checks passed)
@ebsmothers deleted the 31-rope-init branch on September 11, 2024 at 21:18