grad accum in LoRA distributed recipe #644
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/644
✅ No failures as of commit 9d6b4c6 with merge base 32d66df. This comment was automatically generated by Dr. CI and updates every 15 minutes.
recipes/lora_finetune_distributed.py (Outdated)
@@ -468,6 +472,17 @@ def save_checkpoint(
            intermediate_checkpoint=intermediate_checkpoint,
        )

    def _should_update_weights(self, current_iteration: int) -> bool:
This should be a utility at this point since it's in every recipe
Yeah I just don't know where to put it tbh. None of the existing files feel relevant and I don't wanna create a new file just for this. Lmk if you have suggestions here
Maybe torchtune/utils/training.py or torchtune/utils/optim_utils.py? I don't see any harm in making a new file
Tbh I'm inclined to go the other way and say that we should just do this inline. I definitely do not want to add even more indirection than we already have. And this is literally just doing a mod check
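For illustration, a minimal self-contained sketch of the inline mod check being described; the model, data, and hyperparameters below are assumptions for the sake of a runnable example, not the code in this PR:

import torch
from torch import nn

# Minimal sketch of the inline mod check discussed above. Everything here
# (model, data, hyperparameters) is made up for illustration.
gradient_accumulation_steps = 4
model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
dataloader = [(torch.randn(4, 8), torch.randint(0, 2, (4,))) for _ in range(20)]

for idx, (inputs, labels) in enumerate(dataloader):
    loss = loss_fn(model(inputs), labels)
    # Scale the loss so the accumulated gradient matches one large-batch update.
    (loss / gradient_accumulation_steps).backward()
    # The inline mod check: only update weights every
    # `gradient_accumulation_steps` iterations.
    if (idx + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)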
Also @joecummings @kartikayk or @rohan-varma if you have thoughts here
Do it inline.
hm yeah, that's cleaner than I thought it would be. excellent choice
Very clean - thanks for adding this!
How motivated are you to fix this? :) I think it should be a two-line change tbh:
I think this should work, and then we'll no longer have the discrepancy between tqdm and our wandb logs. Can happen in a follow-up as well.
I will probably just punt to a follow-up if it's all the same.
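The suggested two-line change itself isn't shown above. As a rough, hypothetical illustration of the idea, the progress bar could be sized and advanced in optimizer steps rather than raw iterations, so that it lines up with the step-based wandb logs; the numbers and placeholder comments below are assumptions, not the PR's code:

from tqdm import tqdm

# Hypothetical illustration only: one progress-bar tick per optimizer step,
# not per iteration.
num_iterations = 20
gradient_accumulation_steps = 4

pbar = tqdm(total=num_iterations // gradient_accumulation_steps)
for idx in range(num_iterations):
    # ... forward/backward and loss scaling would happen here ...
    if (idx + 1) % gradient_accumulation_steps == 0:
        # ... optimizer.step() / zero_grad() would happen here ...
        pbar.update(1)  # advance only when weights are updated, matching the logged step
pbar.close()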
Context
Implemented following the approach in full_finetune_distributed.py.

Run without grad accumulation
We can see from wandb that there are 20 iterations logged.

Run with grad accumulation
Since tqdm logs the iteration number, not the step number, we can see that both cases run for the expected number of iterations.
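To make the iteration-versus-step distinction concrete, here is a small illustrative count (the numbers are assumptions, not taken from the wandb runs above): with gradient accumulation, tqdm still advances once per iteration, while only every Nth iteration produces an optimizer step.

# Illustrative counting only.
def count_iterations_and_steps(num_batches: int, grad_accum_steps: int):
    iterations, optim_steps = 0, 0
    for idx in range(num_batches):
        iterations += 1                       # what tqdm advances on
        if (idx + 1) % grad_accum_steps == 0:
            optim_steps += 1                  # what shows up as a logged step
    return iterations, optim_steps

print(count_iterations_and_steps(20, 1))  # (20, 20): no grad accumulation
print(count_iterations_and_steps(20, 4))  # (20, 5): same iterations, fewer optimizer steps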