
Loading mistral reward model checkpoints #911

Merged

Conversation

SalmanMohammadi
Collaborator

See #812 (comment) for context.

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

I've added some lightweight logic in torchtune/utils/_checkpointing/_checkpointer.py to support loading a Mistral reward model from Hugging Face, and added a test case to ensure model weights are converted correctly.
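The core of the change is a single key rename: Hugging Face sequence-classification checkpoints store the scalar head under `score.weight`, which is remapped to `lm_head.weight` before the regular Mistral conversion runs (see the diff further down). A minimal, runnable sketch of that remap, with a hypothetical helper name and placeholder values standing in for tensors:

```python
def remap_reward_head(state_dict):
    """Rename the Hugging Face classifier-head key ("score.weight") to the
    key the generic converter expects ("lm_head.weight").

    Hypothetical helper for illustration; the PR initially performed this
    rename inline in the checkpointer.
    """
    if "score.weight" in state_dict:
        state_dict["lm_head.weight"] = state_dict.pop("score.weight")
    return state_dict


# Placeholder strings stand in for real weight tensors.
sd = remap_reward_head(
    {"score.weight": "<1 x hidden_dim tensor>",
     "model.layers.0.self_attn.q_proj.weight": "<tensor>"}
)
print(sorted(sd))
```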

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any of these, just ask and we will happily help.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

pytorch-bot bot commented May 1, 2024

🔗 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/911

✅ No failures as of commit bded16a with merge base f819b4b. (This comment was generated automatically by Dr. CI.)

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 1, 2024
@@ -24,6 +24,7 @@ class ModelType(Enum):
LLAMA3 = "llama3"
MISTRAL = "mistral"
PHI3_MINI = "phi3_mini"
MISTRAL_REWARD = "mistral_reward"
Collaborator Author
maybe this could be mistral_classifier?

Contributor

I think mistral_reward makes sense unless you think there will be classifier checkpoints we'll need to load?

Collaborator Author

I'll leave it as-is for now; we can always generalise later if there are use cases.

@@ -384,6 +384,14 @@ def load_checkpoint(self) -> Dict[str, Any]:

if self._model_type == ModelType.PHI3_MINI:
converted_state_dict[utils.MODEL_KEY] = phi3_hf_to_tune(merged_state_dict)
elif self._model_type == ModelType.MISTRAL_REWARD:
merged_state_dict["lm_head.weight"] = merged_state_dict.pop("score.weight")
Contributor

It does seem very tempting to add this one-line change here, but these special cases will start to add up and make this checkpointer really hard to read.

I do need to redesign this component a bit. In the meantime, an alternative (granted, it's more lines of code with some copy-paste) would be to define a custom convert_weights.py within the mistral folder with new mistral_reward_hf_to_tune and mistral_reward_tune_to_hf functions that do this mapping (and any future changes you might need for reward modelling). An example is what we did for phi3. As I mentioned, it's more copy-paste code, but I can definitely see this function growing as you add more heads etc., and I'm afraid the checkpointer will bloat up because of it.

@ebsmothers let me know what you think about this.
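A rough sketch of what the suggested dedicated converters might look like. The function names come from the comment above; everything else is an assumption for illustration, and the generic Mistral backbone converters are stubbed as identity functions so the sketch is self-contained:

```python
# Hypothetical sketch of the dedicated mistral_reward converters suggested
# above. In torchtune, the real hf_to_tune/tune_to_hf would handle the
# backbone key mapping; here they are identity stubs so the sketch runs.

def hf_to_tune(state_dict, **kwargs):  # stand-in for the generic converter
    return dict(state_dict)

def tune_to_hf(state_dict, **kwargs):  # stand-in for the generic converter
    return dict(state_dict)

def mistral_reward_hf_to_tune(state_dict, **kwargs):
    """Rename the HF classifier head ("score.weight") to "lm_head.weight",
    then delegate to the generic HF -> torchtune conversion."""
    sd = dict(state_dict)
    if "score.weight" in sd:
        sd["lm_head.weight"] = sd.pop("score.weight")
    return hf_to_tune(sd, **kwargs)

def mistral_reward_tune_to_hf(state_dict, **kwargs):
    """Inverse mapping, so saved checkpoints load cleanly with HF tooling."""
    sd = tune_to_hf(dict(state_dict), **kwargs)
    if "lm_head.weight" in sd:
        sd["score.weight"] = sd.pop("lm_head.weight")
    return sd
```

Keeping the reward-specific renames in their own module keeps the checkpointer's branch down to a single function call per model type, which is the readability win the comment is after.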

Collaborator Author

I think this makes sense to me.

@@ -426,6 +434,16 @@ def save_checkpoint(
# convert the state_dict back to hf format; do this inplace
if self._model_type == ModelType.PHI3_MINI:
state_dict[utils.MODEL_KEY] = phi3_tune_to_hf(state_dict[utils.MODEL_KEY])
elif self._model_type == ModelType.MISTRAL_REWARD:
Contributor

same comment as above

@@ -126,7 +126,7 @@ def hf_to_tune(
repo in HF (https://huggingface.co/meta-llama/Llama-2-7b-hf).

Args:
state_dict (Dict[str, torch.Tensor]): State dict in Meta's format.
state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
Contributor

Nice catch!

Contributor

@kartikayk kartikayk left a comment

Generally looks good, thank you!

@kartikayk kartikayk merged commit 057709e into pytorch:main May 4, 2024
29 checks passed
@SalmanMohammadi SalmanMohammadi deleted the convert-hf-mistral-reward-models branch July 20, 2024 22:02