-
Notifications
You must be signed in to change notification settings - Fork 441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Llama3-8b memory efficient full finetune #990
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/990
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit c99b4d7 with merge base 3883081 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Few high level questions:
|
I'd like to avoid this especially if we stick with the current recipe. Since this recipe is used for other workloads, users who change optimizer via config for those workloads would be surprised that their changes don't have effect since we hardcode here. |
@RdoubleA Yeah the UX concerns definitely make sense. I think this config can just be renamed appropriately to match what we have for llama2. Open to authoring a separate recipe or moving this to a helper function - feel free to let me know what you and @ebsmothers think or if any additional input is needed from me, thanks! |
Refactored to simply use PagedAdamW8bit after @ebsmothers suggestion! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great! Really happy to see we were able to get good memory with PagedAdamW8bit. We should see if it helps on the Llama2 memory-efficient config at all too (ofc not as urgent since we already have reasonable peak memory there)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dang how did all that become one line
Context
What is the purpose of this PR? Is it to
TL;DR: This PR saves ~46% peak memory for llama3-8b single device full finetune while keeping performance at parity to current offering, just by switching Adamw8bit -> PagedAdamw8bit. @ebsmothers reminded me that this exists after I took a much more complicated approach.
Changelog
Test plan
Current 8B full single device:
8B_full_single_device using PagedAdamW:
8B full single device with this PR (PagedAdamW8bit):
For comparison, current llama2-7b 7B_full_low_memory:
TL;DR: This PR reduces peak memory by ~46% while maintaining approximately the same perf, getting us to a < 24 GB full finetune.
Loss curves are the same (comparing today's baseline versus with these changes) -
Follow-ups