Loading mistral reward model checkpoints #911
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/911
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit bded16a with merge base f819b4b.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -24,6 +24,7 @@ class ModelType(Enum):
     LLAMA3 = "llama3"
     MISTRAL = "mistral"
     PHI3_MINI = "phi3_mini"
+    MISTRAL_REWARD = "mistral_reward"
maybe this could be mistral_classifier?
I think mistral_reward makes sense unless you think there will be classifier checkpoints we'll need to load?
I'll leave it as-is for now; we can always generalise later if there are use cases.
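For reference, the new enum value would then be selected through the checkpointer's model_type field in a recipe. A rough sketch of how that might look (directory, shard names, and argument values are placeholders, not taken from this PR):

```python
# Illustrative only: paths, shard names, and argument values are placeholders.
from torchtune.utils import FullModelHFCheckpointer

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/RM-Mistral-7B",
    checkpoint_files=[
        "pytorch_model-00001-of-00002.bin",
        "pytorch_model-00002-of-00002.bin",
    ],
    output_dir="/tmp/RM-Mistral-7B",
    model_type="MISTRAL_REWARD",  # the value added to ModelType in this PR
)
state_dict = checkpointer.load_checkpoint()
```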
@@ -384,6 +384,14 @@ def load_checkpoint(self) -> Dict[str, Any]:

         if self._model_type == ModelType.PHI3_MINI:
             converted_state_dict[utils.MODEL_KEY] = phi3_hf_to_tune(merged_state_dict)
+        elif self._model_type == ModelType.MISTRAL_REWARD:
+            merged_state_dict["lm_head.weight"] = merged_state_dict.pop("score.weight")
It does seem very tempting to add this one-line change here, but this will start to add up and make this checkpointer really hard to read.
I do need to redesign this component a bit. In the meantime, an alternative (granted, it's more lines of code with some copy-paste) would be to define a custom convert_weights.py within the mistral folder with new mistral_reward_hf_to_tune and mistral_reward_tune_to_hf functions which do this mapping (and any future changes you might need for reward modeling). An example is what we did for phi3. As I mentioned, it's more copy-paste code, but I can definitely see this function growing as you add more heads etc., and I'm afraid the checkpointer will bloat up because of it.
@ebsmothers let me know what you think about this.
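To make the suggestion concrete, here is a minimal sketch of what such a per-model converter could look like, assuming it wraps the generic hf_to_tune / tune_to_hf helpers in torchtune/models/convert_weights.py (module path, function names, and default shape arguments are illustrative, not the PR's implementation):

```python
# Hypothetical torchtune/models/mistral/_convert_weights.py (sketch only)
from typing import Dict

import torch

from torchtune.models import convert_weights


def mistral_reward_hf_to_tune(
    state_dict: Dict[str, torch.Tensor],
    num_heads: int = 32,
    num_kv_heads: int = 8,
    dim: int = 4096,
) -> Dict[str, torch.Tensor]:
    # HF's MistralForSequenceClassification stores the reward head under "score.weight";
    # rename it so the generic converter maps it onto torchtune's output projection.
    state_dict["lm_head.weight"] = state_dict.pop("score.weight")
    return convert_weights.hf_to_tune(
        state_dict, num_heads=num_heads, num_kv_heads=num_kv_heads, dim=dim
    )


def mistral_reward_tune_to_hf(
    state_dict: Dict[str, torch.Tensor],
    num_heads: int = 32,
    num_kv_heads: int = 8,
    dim: int = 4096,
) -> Dict[str, torch.Tensor]:
    # Inverse direction: convert back to HF format, then restore the "score.weight" key.
    hf_state_dict = convert_weights.tune_to_hf(
        state_dict, num_heads=num_heads, num_kv_heads=num_kv_heads, dim=dim
    )
    hf_state_dict["score.weight"] = hf_state_dict.pop("lm_head.weight")
    return hf_state_dict
```

This would keep the rename out of the checkpointer and give reward-model-specific changes a place to grow, mirroring the phi3 setup.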
I think this makes sense to me.
@@ -426,6 +434,16 @@ def save_checkpoint(
         # convert the state_dict back to hf format; do this inplace
         if self._model_type == ModelType.PHI3_MINI:
             state_dict[utils.MODEL_KEY] = phi3_tune_to_hf(state_dict[utils.MODEL_KEY])
+        elif self._model_type == ModelType.MISTRAL_REWARD:
same comment as above
@@ -126,7 +126,7 @@ def hf_to_tune(
     repo in HF (https://huggingface.co/meta-llama/Llama-2-7b-hf).

     Args:
-        state_dict (Dict[str, torch.Tensor]): State dict in Meta's format.
+        state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
Nice catch!
Generally looks good, thank you!
See #812 (comment) for context.
Context
What is the purpose of this PR?
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
I've added some lightweight logic in torchtune/utils/_checkpointing/_checkpointer.py to support loading a Mistral reward model from Hugging Face, and added a test case to ensure model weights are converted correctly.
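As a rough illustration of the property the conversion needs to guarantee, assuming the standard converter in torchtune/models/convert_weights.py is reused after the key rename; this is only a sketch, not the PR's actual test:

```python
# Sanity-check sketch: after the rename done at load time, the HF reward head
# should land on torchtune's output projection. Shapes here are toy values.
import torch

from torchtune.models import convert_weights

hf_sd = {"score.weight": torch.randn(1, 16)}          # HF reward head (num_labels=1)
hf_sd["lm_head.weight"] = hf_sd.pop("score.weight")   # rename applied in load_checkpoint
tune_sd = convert_weights.hf_to_tune(hf_sd, num_heads=1, num_kv_heads=1, dim=16)
assert "output.weight" in tune_sd and "score.weight" not in tune_sd
```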
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these, just ask and we will happily help.)
- pre-commit install
- pytest tests
- pytest tests -m integration_test