Added adapter_only option to LoRA #1220
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1220.
Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit eb7d41d with merge base f0a15c5. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @spider-man-tm! Thank you for your pull request and welcome to our community.
Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.
Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations, and the pull request will be tagged accordingly.
If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
To the reviewer: please let me know if there are any missing parts in my implementation or if there are any parts that should be removed. It's perfectly fine if the maintainers edit it directly.
Codecov Report
Attention: patch coverage summary below.
Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1220      +/-   ##
==========================================
+ Coverage   67.81%   70.19%   +2.37%
==========================================
  Files         219      220       +1
  Lines        9908     9957      +49
==========================================
+ Hits         6719     6989     +270
+ Misses       3189     2968     -221

☔ View full report in Codecov by Sentry.
recipes/lora_dpo_distributed.py (Outdated)

        intermediate_checkpoint=intermediate_checkpoint,
    )
    # If the option was True, save only the adapter except for the last epoch
    is_intermediate_epoch = epoch + 1 < self.total_epochs
Thanks so much for your contribution : )
Small q: do we want to support the case where merged weights are never saved, even in the final epoch, and let the user merge them? Might be helpful when working with particularly large models. @joecummings
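For readers unfamiliar with the merge step being discussed: folding a LoRA adapter back into the base model amounts to adding the scaled low-rank product to each adapted weight. A minimal, library-agnostic sketch (function and argument names are illustrative, not torchtune's API):

import torch

@torch.no_grad()
def merge_lora_weight(
    base_weight: torch.Tensor,  # (out_features, in_features)
    lora_a: torch.Tensor,       # (rank, in_features)
    lora_b: torch.Tensor,       # (out_features, rank)
    alpha: float,
    rank: int,
) -> torch.Tensor:
    # Standard LoRA merge: W' = W + (alpha / rank) * (B @ A)
    return base_weight + (alpha / rank) * (lora_b @ lora_a)

In practice a user would apply something like this per adapted linear layer and then save the resulting full state dict.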
Yep, I think @SalmanMohammadi's point is in line with what I was thinking here. This does mean that we need to do a better job explaining how to merge weights or what exactly to do with them afterwards.
Concretely, should we fix adapter_only=self._save_adapter_only here, and add something to the effect of:

if not is_intermediate_epoch:
    log.info(
        "Saving final model checkpoint. Please note that you have set save_adapter_only=True, so only adapter weights will be saved."
        "You will need to merge the adapter weights into your base model for further use. See {where do we point them for now??}"
    )

EDIT: could we also move this into the checkpointer's save_checkpoint?
# adapter_only option
# Set to True to save only the adapter weights for intermediate epochs.
# For the final epoch, the entire model weights will be saved regardless of this option.
adapter_only: False
nit: Consider save_adapter_weights_only instead of adapter_only for clarity.
I know it's different from the param name on the checkpointer, but the extra context isn't needed there b/c you already know you're dealing with saving a checkpoint.
cc @SalmanMohammadi and @pbontrager for thoughts.
recipes/lora_dpo_distributed.py (Outdated)

@@ -143,6 +143,7 @@ def __init__(self, cfg: DictConfig) -> None:
        self.global_step = 0

        self._resume_from_checkpoint = cfg.resume_from_checkpoint
        self._adapter_only = cfg.get("adapter_only", False)
Can you add a logging INFO statement saying that only the adapter weights will be saved so the user knows right away?
I see your excellent point, and raise you another (dubious quality) point.
In line with moving towards outsourcing config validation/handholding into utilities from recipes, why not have save_adapter_weights_only as a property of the checkpointer? We could then throw an INFO in the checkpointer constructor.
Your config would then be:

checkpointer:
  ...
  save_adapter_weights_only: True

This would fit if we're not doing any additional logic checks on saving the adapter weights only, i.e. if save_adapter_weights_only: True, we don't merge weights in the last epoch.
edit: sorry, as usual, I'm adding complexity to everything I touch
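A rough sketch of the variant proposed here, i.e. the flag living on the checkpointer itself with an INFO log in the constructor (class and attribute names are hypothetical, not the actual torchtune checkpointer):

import logging

log = logging.getLogger(__name__)

class CheckpointerWithCtorFlag:
    """Illustrative only: save_adapter_weights_only passed at construction time."""

    def __init__(self, output_dir: str, save_adapter_weights_only: bool = False) -> None:
        self._output_dir = output_dir
        self._save_adapter_weights_only = save_adapter_weights_only
        if self._save_adapter_weights_only:
            # Surface the behavior up front so the user knows right away.
            log.info("save_adapter_weights_only=True: only adapter weights will be saved.")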
Hmmmm, this does make sense to me. I'd much rather have the error propagated from the checkpointer, rather than the recipe file.
I think the complexity arises in how we actually instantiate the checkpointer. I don't think we want save_adapter_weights_only as a param on the initialization of the checkpointer b/c the user should be able to take the component and save adapters or the whole file if they want without re-initializing it. Therefore, it has to be a param on the save method.
So if we don't want it being passed to the constructor, we would have to parse out the save_adapter_weights_only from the config file before creating the checkpointer, which feels a little messy. Maybe we push the warning to the save method of the checkpointer? The only issue with this is that then the user could accidentally have save_adapter_weights_only=True, train their whole model, and then save the checkpoints without actually wanting that feature. But that might be too hand-holdy to worry about?
Thanks for coming on this journey of vomiting all my thoughts on this PR. I think my TL;DR is that this config variable should be separate from the instantiation of the checkpointer, but maybe we push the logging info to the save method instead of in the main recipe.
Thoughts?
maybe we push the logging info to the save method instead of in the main recipe.
The save_checkpoint method of the checkpointer, right? I agree here - we also validate adapter_only there.
Generally agree with @joecummings here.. the checkpointer class itself should mostly care about stuff that's global across a given fine-tuning run (input and output checkpoint formats). The flag to save only adapter weights is a more local thing since it can in theory vary depending on where we are in training. So imo it makes sense to expose it only in save_checkpoint and not in init.
A separate (but still relevant) point: our checkpointer naming still seems to imply it's only used for full fine-tuning. This is confusing and another reason why it'd look weird to put save_adapter_weights_only in its init. (I don't think that's a good reason to keep save_adapter_weights_only outside the checkpointer init; I actually think we should just rename the checkpointer instead.)
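A condensed sketch of the shape the thread converges on: the flag becomes a per-call argument to save_checkpoint rather than constructor state (simplified and hypothetical; the real checkpointer also handles checkpoint formats, recipe state, and intermediate checkpoints):

import logging
from typing import Any, Dict

import torch

log = logging.getLogger(__name__)

class CheckpointerWithSaveFlag:
    def __init__(self, output_dir: str) -> None:
        # Only run-global concerns (e.g. where checkpoints are written) live here.
        self._output_dir = output_dir

    def save_checkpoint(
        self,
        state_dict: Dict[str, Any],
        epoch: int,
        adapter_only: bool = False,
    ) -> None:
        if adapter_only:
            # Logging/validation for the adapter-only path lives in the checkpointer,
            # so every recipe gets it for free.
            log.info("adapter_only=True: saving only the adapter weights for epoch %d.", epoch)
            # Crude key filter for illustration only; a real implementation would use
            # an explicit list of adapter parameter names.
            state_dict = {k: v for k, v in state_dict.items() if "lora" in k}
        torch.save(state_dict, f"{self._output_dir}/checkpoint_epoch_{epoch}.pt")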
This is looking great - thanks so much :)
Just a couple comments.
Thank you to all reviewers. I have made the following revisions:
Feel free to review the changes and provide further feedback!
Really appreciate you adding this! Just two minor things then I think this is good to go
if not is_intermediate_epoch:
    log.info(
        "Saving final model checkpoint."
        "Please note that you have set save_adapter_weights_only=True, so only adapter weights will be saved."
        "You need to merge the adapter weights into your base model for further use. "
        f"See {type(self._checkpointer).__name__}"
    )
Maybe I'm being dense but is this log correct? I don't see where we actually check that save_adapter_weights_only=True prior to logging this (similar comment in the other recipes too).
Good catch. Sorry, my original suggestion missed this.
Do we still think it's a good idea to move this into the checkpointer, rather than recipes?
Indeed, rather than writing similar code each time we create a new recipe, it might be better to handle this in the checkpointer. What do you think?
commit: e8e8757
Indeed, rather than writing similar code each time we create a new recipe, it might be better to handle this in the checkpointer. What do you think?
Good point, I'm inclined to agree with this (kinda similar to my other comment, it's nice to keep the recipe files themselves as clean as possible)
recipes/lora_dpo_distributed.py (Outdated)

@@ -143,6 +143,9 @@ def __init__(self, cfg: DictConfig) -> None:
        self.global_step = 0

        self._resume_from_checkpoint = cfg.resume_from_checkpoint
        self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
        log.info(f"save_adapter_weights_only: {self._save_adapter_weights_only}")
This is more of a nit since I know @joecummings gave contrary advice here already, but I don't like logging config fields in the recipe like this. We already log the full config, no need to just directly re-log an individual config field unless there is something more non-trivial happening (e.g. we use some feature that depends on a particular version of PyTorch or something).
Thank you for your review! I have made the corrections based on your feedback. 7dec6ef
You're absolutely right, this could indeed confuse the user. I've now modified the code to log the message based on the value of the option:

if not is_intermediate_epoch:
    log.info("Saving final epoch checkpoint.")
    if self._save_adapter_weights_only:
        log.info(
            "Please note that you have set save_adapter_weights_only=True, so only adapter weights will be saved."
            "You need to merge the adapter weights into your base model for further use. "
            f"See {type(self._checkpointer).__name__}"
        )
    else:
        log.info(
            "The full model checkpoint, including all weights and configurations, has been saved successfully."
            "You can now use this checkpoint for further training or inference."
        )
I've removed the log output for that section. I agree that it's best not to deviate too much from the overall logging pattern for the other options.
I moved the log output to the checkpointer.
LGTM. Thanks so much for your contribution, and your patience in addressing our comments : )
@SalmanMohammadi |
Awesome work!
Context
What is the purpose of this PR?
Changelog
In #1210, the "adapter_only" (boolean) option was added to the save_checkpoint method of each Checkpointer class. When this option is set to True, only the adapter weights are saved instead of the entire model weights.
This PR applies that change to LoRA fine-tuning.
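As a rough illustration of how the recipe-level flag plugs into that save_checkpoint option (names and signatures follow the discussion above and are assumptions, not the merged implementation):

from typing import Any, Dict, Protocol


class CheckpointerLike(Protocol):
    # Minimal interface assumed by this sketch, mirroring the call sites in the diffs above.
    def save_checkpoint(
        self,
        state_dict: Dict[str, Any],
        epoch: int,
        intermediate_checkpoint: bool = False,
        adapter_only: bool = False,
    ) -> None: ...


def save_epoch_checkpoint(
    checkpointer: CheckpointerLike,
    state_dict: Dict[str, Any],
    epoch: int,
    total_epochs: int,
    save_adapter_weights_only: bool,
) -> None:
    # The recipe only forwards the flag; adapter-only logging/validation is
    # left to the checkpointer, as agreed in the review thread.
    is_intermediate_epoch = epoch + 1 < total_epochs
    checkpointer.save_checkpoint(
        state_dict,
        epoch=epoch,
        intermediate_checkpoint=is_intermediate_epoch,
        adapter_only=save_adapter_weights_only,
    )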
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)
- install pre-commit hooks (pre-commit install)
- run unit tests (pytest tests)
- run recipe/integration tests (pytest tests -m integration_test)