Added adapter_only option to LoRA #1220

Merged
merged 13 commits into pytorch:main on Jul 29, 2024

Conversation

@spider-man-tm (Contributor) commented Jul 25, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Changelog

In #1210, the "adapter_only" (boolean) option was added to the save_checkpoint method of each Checkpointer class. When this option is set to True, only the adapter weights are saved instead of the entire model weights.
This PR applies that change to LoRA fine-tuning.
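
To make the change concrete, here is a minimal sketch of what this looks like at a recipe's checkpoint-saving step. The helper name and the recipe object are illustrative, not the exact code in the diff; only the shape of the call is taken from the PR.

# Sketch only: a LoRA recipe forwarding the adapter_only flag that #1210
# added to the checkpointer's save_checkpoint.
def save_epoch_checkpoint(recipe, checkpoint_dict: dict, epoch: int) -> None:
    is_intermediate_epoch = epoch + 1 < recipe.total_epochs
    recipe._checkpointer.save_checkpoint(
        checkpoint_dict,
        epoch=epoch,
        intermediate_checkpoint=is_intermediate_epoch,
        # As first proposed in this PR: adapter weights only for intermediate
        # epochs, full model weights on the final epoch regardless of the flag.
        adapter_only=recipe._adapter_only and is_intermediate_epoch,
    )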

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

pytorch-bot (bot) commented Jul 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1220

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit eb7d41d with merge base f0a15c5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @spider-man-tm!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@facebook-github-bot added the "CLA Signed" label (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.) on Jul 25, 2024
@spider-man-tm marked this pull request as ready for review on July 25, 2024 09:52
@spider-man-tm (Contributor, Author):

To the reviewer,

Please let me know if there are any missing parts in my implementation or if there are any parts that should be removed. It's perfectly fine if the maintainer edits it directly.

@codecov-commenter

Codecov Report

Attention: Patch coverage is 0% with 20 lines in your changes missing coverage. Please review.

Project coverage is 70.19%. Comparing base (7eb89e2) to head (aeedd92).
Report is 2 commits behind head on main.

Files                                    Patch %   Lines
recipes/lora_dpo_distributed.py          0.00%     5 Missing ⚠️
recipes/lora_dpo_single_device.py        0.00%     5 Missing ⚠️
recipes/lora_finetune_distributed.py     0.00%     5 Missing ⚠️
recipes/lora_finetune_single_device.py   0.00%     5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1220      +/-   ##
==========================================
+ Coverage   67.81%   70.19%   +2.37%     
==========================================
  Files         219      220       +1     
  Lines        9908     9957      +49     
==========================================
+ Hits         6719     6989     +270     
+ Misses       3189     2968     -221     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

intermediate_checkpoint=intermediate_checkpoint,
)
# If the option was True, save only the adapter except for the last epoch
is_intermediate_epoch = epoch + 1 < self.total_epochs
Collaborator:
Thanks so much for your contribution : )
Small q: do we want to support the case where merged weights are never saved, even in the final epoch, and let the user merge them? Might be helpful when working with particularly large models. @joecummings
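
For reference, merging LoRA adapter weights back into a base linear layer is the standard update W ← W + (alpha / rank) · B · A. Below is a minimal plain-PyTorch sketch of that merge; the key names are illustrative and this is not any specific torchtune utility.

import torch

def merge_lora_weights(
    base_weight: torch.Tensor,   # [out_dim, in_dim] frozen base weight
    lora_a: torch.Tensor,        # [rank, in_dim] LoRA A matrix
    lora_b: torch.Tensor,        # [out_dim, rank] LoRA B matrix
    alpha: float,
    rank: int,
) -> torch.Tensor:
    """Return base_weight with the LoRA update folded in: W + (alpha / rank) * B @ A."""
    return base_weight + (alpha / rank) * (lora_b @ lora_a)

# Usage sketch: fold each adapter pair into the matching base weight.
# (Hypothetical key names; real checkpoints use model-specific naming.)
base_sd = {"layer.weight": torch.randn(16, 32)}
adapter_sd = {
    "layer.lora_a.weight": torch.randn(4, 32),
    "layer.lora_b.weight": torch.randn(16, 4),
}
base_sd["layer.weight"] = merge_lora_weights(
    base_sd["layer.weight"],
    adapter_sd["layer.lora_a.weight"],
    adapter_sd["layer.lora_b.weight"],
    alpha=8.0,
    rank=4,
)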

Contributor:

Yep, I think @SalmanMohammadi's point is in line with what I was thinking here. This does mean that we need to do a better job explaining how to merge weights or what exactly to do with them afterwards.

@SalmanMohammadi (Collaborator) commented Jul 25, 2024

Concretely, should we fix adapter_only=self._save_adapter_only here, and add something to the effect of:

if not is_intermediate_epoch:
    log.info(
        "Saving final model checkpoint. Please note that you have set save_adapter_only=True, so only adapter weights will be saved." 
        "You will need to merge the adapter weights into your base model for further use. See {where do we point them for now??}")

EDIT: could we also move this into the checkpointer's save_checkpoint?

# adapter_only option
# Set to True to save only the adapter weights for intermediate epochs.
# For the final epoch, the entire model weights will be saved regardless of this option.
adapter_only: False
Contributor:

nit: Consider save_adapter_weights_only instead of adapter_only for clarity.

I know it's different from the param name on the checkpointer, but the extra context isn't needed there b/c you already know you're dealing with saving a checkpoint.

cc @SalmanMohammadi and @pbontrager for thoughts.

@@ -143,6 +143,7 @@ def __init__(self, cfg: DictConfig) -> None:
self.global_step = 0

self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._adapter_only = cfg.get("adapter_only", False)
Contributor:

Can you add a logging INFO statement saying that only the adapter weights will be saved so the user knows right away?

@SalmanMohammadi (Collaborator) commented Jul 25, 2024

I see your excellent point, and raise you another (dubious quality) point.

In line with moving towards outsourcing config validation/handholding into utilities from recipes, why not have save_adapter_weights_only as a property of the checkpointer? We could then throw an INFO in the checkpointer constructor.

Your config will then be:

checkpointer:
  ...
  save_adapter_weights_only: True

This would fit if we're not doing any additional logic checks on saving the adapter weights only, i.e. if save_adapter_weights_only: True we don't merge weights in the last epoch.

edit: sorry, as usual, I'm adding complexity to everything I touch

@joecummings (Contributor) commented Jul 25, 2024

Hmmmm, this does make sense to me. I'd much rather have the error propagated from the checkpointer, rather than the recipe file.

I think the complexity arises in how we actually instantiate the checkpointer. I don't think we want save_adapter_weights_only as a param on the initialization of the checkpointer b/c the user should be able to take the component and save adapters or the whole file if they want without re-initializing it. Therefore, it has to be a param on the save method.

So if we don't want it being passed to the constructor, we would have to parse out the save_adapter_weights_only from the config file before creating the checkpointer, which feels a little messy. Maybe we push the warning to the save method of the checkpointer? The only issue with this is that then the user could accidentally have the save_adapter_weights_only=True, train their whole model, and then save the checkpoints without actually wanting that feature. But that might be too hand-holdy to worry about?

Thanks for coming on this journey of vomiting all my thoughts on this PR. I think my TL;DR is that this config variable should be separate from the instantiation of the checkpointer, but maybe we push the logging info to the save method instead of in the main recipe.

Thoughts?

Collaborator:

maybe we push the logging info to the save method instead of in the main recipe.

The save_checkpoint method of the checkpointer, right? I agree here - we also validate adapter_only there.

Contributor:

Generally agree with @joecummings here.. the checkpointer class itself should mostly care about stuff that's global across a given fine-tuning run (input and output checkpoint formats). The flag to save only adapter weights is a more local thing since it can in theory vary depending on where we are in training. So imo it makes sense to expose only in save_checkpoint and not in init.

A separate (but still relevant) point: our checkpointer naming still seems to imply it's only used for full fine-tuning. This is confusing and another reason why it'd look weird to put save_adapter_weights_only in its init. (I don't think that's a good reason to keep save_adapter_weights_only outside the checkpointer init, I actually think we should just rename the checkpointer instead)
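
A rough sketch of the shape this converges on, using a simplified stand-in class rather than the real checkpointer API: the flag is accepted by save_checkpoint instead of __init__, and the user-facing notice is logged there so recipes don't have to duplicate it.

import logging

log = logging.getLogger(__name__)


class CheckpointerSketch:
    """Simplified stand-in for the checkpointer classes (not the real torchtune API)."""

    def __init__(self, output_dir: str) -> None:
        # Only run-global settings live here (paths, formats), not per-save flags.
        self._output_dir = output_dir

    def save_checkpoint(
        self,
        state_dict: dict,
        epoch: int,
        intermediate_checkpoint: bool = False,
        adapter_only: bool = False,
    ) -> None:
        # The per-save decision, and the warning about it, live on the save call,
        # so the same checkpointer instance can save adapters or full weights.
        if adapter_only:
            log.info(
                "Saving adapter weights only; merge them into the base model before inference."
            )
        # ... write state_dict (or just its adapter keys) under self._output_dir ...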

@joecummings (Contributor) left a comment:

This is looking great - thanks so much :)

Just a couple comments.

@spider-man-tm (Contributor, Author):

Thank you to all reviewers. I have made the following revisions:

  • The name of the option adapter_only was changed to save_adapter_weights_only. While adapter_only is understandable within the save_checkpoint method, it becomes ambiguous when used independently.
  • To avoid complications, the option is retained as a method option for save_checkpoint rather than being included in the Checkpointer constructor.
  • Adjusted the code to maintain consistency by removing the special case where all weights were saved only for the last epoch.
  • Added log output to inform the user that only adapter weights are being saved.
    • Due to a line-too-long error when running pre-commit run --all-files, the log message is split across 4 lines.

Feel free to review the changes and provide further feedback!
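
Taken together, these revisions mean the recipe-side flow looks roughly like the sketch below (an illustrative stand-in, not the exact diff): the renamed flag is read once from the config and passed through on every epoch, with no last-epoch special case, while the checkpointer emits the adapter-weights-only notice itself.

class LoRARecipeSketch:
    """Illustrative stand-in for the LoRA recipe classes touched by this PR."""

    def __init__(self, cfg, checkpointer, total_epochs: int) -> None:
        self._checkpointer = checkpointer
        self.total_epochs = total_epochs
        # Renamed option, read once from the recipe config (defaults to False).
        self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)

    def save_checkpoint(self, checkpoint_dict: dict, epoch: int) -> None:
        # Passed through on every epoch; no special handling for the final epoch.
        self._checkpointer.save_checkpoint(
            checkpoint_dict,
            epoch=epoch,
            intermediate_checkpoint=(epoch + 1 < self.total_epochs),
            adapter_only=self._save_adapter_weights_only,
        )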

@ebsmothers (Contributor) left a comment:

Really appreciate you adding this! Just two minor things then I think this is good to go

Comment on lines 528 to 534
if not is_intermediate_epoch:
    log.info(
        "Saving final model checkpoint."
        "Please note that you have set save_adapter_weights_only=True, so only adapter weights will be saved."
        "You need to merge the adapter weights into your base model for further use. "
        f"See {type(self._checkpointer).__name__}"
    )
Contributor:

Maybe I'm being dense but is this log correct? I don't see where we actually check that save_adapter_weights_only=True prior to logging this (similar comment in the other recipes too)

Collaborator:

Good catch. Sorry, my original suggestion missed this.

Do we still think it's a good idea to move this into the checkpointer, rather than recipes?

@spider-man-tm (Contributor, Author) commented Jul 27, 2024

@SalmanMohammadi

Indeed, rather than writing similar code each time we create a new recipe, it might be better to handle this in the checkpointer. What do you think?

commit: e8e8757

Contributor:

Indeed, rather than writing similar code each time we create a new recipe, it might be better to handle this in the checkpointer. What do you think?

Good point, I'm inclined to agree with this (kinda similar to my other comment, it's nice to keep the recipe files themselves as clean as possible)

@@ -143,6 +143,9 @@ def __init__(self, cfg: DictConfig) -> None:
self.global_step = 0

self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
log.info(f"save_adapter_weights_only: {self._save_adapter_weights_only}")
Contributor:

This is more of a nit since I know @joecummings gave contrary advice here already, but I don't like logging config fields in the recipe like this. We already log the full config, no need to just directly re-log an individual config field unless there is something more non-trivial happening (e.g. we use some feature that depends on a particular version of PyTorch or something).

@spider-man-tm (Contributor, Author) commented Jul 27, 2024

@ebsmothers

Thank you for your review! I have made the corrections based on your feedback. 7dec6ef

Maybe I'm being dense but is this log correct? I don't see where we actually check that save_adapter_weights_only=True prior to logging this

You’re absolutely right, this could indeed confuse the user. I’ve now modified the code to log the message based on the value of the option.

            if not is_intermediate_epoch:
                log.info("Saving final epoch checkpoint.")
                if self._save_adapter_weights_only:
                    log.info(
                        "Please note that you have set save_adapter_weights_only=True, so only adapter weights will be saved. "
                        "You need to merge the adapter weights into your base model for further use. "
                        f"See {type(self._checkpointer).__name__}"
                    )
                else:
                    log.info(
                        "The full model checkpoint, including all weights and configurations, has been saved successfully. "
                        "You can now use this checkpoint for further training or inference."
                    )

We already log the full config, no need to just directly re-log an individual config field unless there is something more non-trivial happening

I’ve removed the log output for that section. I agree that it’s best not to deviate too much from the overall logging pattern for the other options.

@spider-man-tm (Contributor, Author):

I moved the log output to the checkpointer.
e8e8757

@SalmanMohammadi (Collaborator):

LGTM. Thanks so much for your contribution, and your patience in addressing our comments : )

@spider-man-tm (Contributor, Author):

@SalmanMohammadi
PyTorch is my favorite ML package, and I'm happy to have contributed to it. Thanks so much!

@joecummings (Contributor) left a comment:

Awesome work!

@joecummings merged commit f34b5b0 into pytorch:main on Jul 29, 2024
29 checks passed