-
Notifications
You must be signed in to change notification settings - Fork 27.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Trainer
] Correct behavior of _load_best_model
for PEFT models
#24103
[Trainer
] Correct behavior of _load_best_model
for PEFT models
#24103
Conversation
Trainer
] Correct behavior of _load_best_model
Trainer
] Correct behavior of _load_best_model
for PEFT models
The documentation is not available anymore as the PR was closed or merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @younesbelkada for simplifying trainer usage with PEFT in terms of saving/loading as this has been a reason for numerous issues 🚀. Left few comments.
src/transformers/trainer.py
Outdated
@@ -2177,11 +2177,18 @@ def _load_best_model(self): | |||
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") | |||
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) | |||
best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME) | |||
adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.bin") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it can also be safetensor ckpts, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe adding best_safe_adapter_model_path
should serve the purpose?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
perfect, will refactor that a bit
src/transformers/trainer.py
Outdated
else: | ||
# We can't do pure 8bit training using transformers. | ||
logger.warning("Could not loading a quantized checkpoint.") | ||
has_been_loaded = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should this be removed now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is needed so that it can be used in the block below for the check, otherwise it will throw an error similar as #24096
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AH sorry I see what you meant, yes will remove it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
proposed something in bf31c5e
- add ST format as well
src/transformers/trainer.py
Outdated
best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.bin") | ||
best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, "adapter_model.safetensors") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Those two should be in constants (like WEIGHTS_NAME
) as they are now used several time across the file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
makes sense, just added it!
…24103) * v1 * some refactor - add ST format as well * fix * add `ADAPTER_WEIGHTS_NAME` & `ADAPTER_SAFE_WEIGHTS_NAME`
…uggingface#24103) * v1 * some refactor - add ST format as well * fix * add `ADAPTER_WEIGHTS_NAME` & `ADAPTER_SAFE_WEIGHTS_NAME`
What does this PR do?
Fixes #24096
This PR fixes the bugs related with PEFT models and
load_best_model_at_end
. It also refactors a bit the current logic to extend it generally to all LoRA models, not only 8-bit base models + LoRA.Repro script
cc @sgugger @pacman100