Loading mistral reward model checkpoints #911
New file (weight-conversion utilities for Mistral reward models):

@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch

from torchtune.models.convert_weights import get_mapped_key

_MISTRAL_REWARD = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight",
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight",
    "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
    "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale",
    "model.norm.weight": "norm.scale",
    "score.weight": "output.weight",
}


def mistral_reward_hf_to_tune(
    state_dict: Dict[str, torch.Tensor],
    num_heads: int = 32,
    num_kv_heads: int = 32,
    dim: int = 4096,
    head_dim: int = None,
) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from HF's format to TorchTune's format, which contains the weights
    of a Mistral reward model.
    State dicts from multiple checkpoint files should be consolidated into a single state dict
    before calling this function.
    The logic is identical to :func:`~torchtune.models.convert_weights.hf_to_tune`, but with a different mapping.

    An example of an HF-format state dict can be found in the
    ``Ray2333/reward-model-Mistral-7B-instruct-Unified-Feedback`` repo on the HF Hub.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
        num_heads (int): Number of heads in the model.
        num_kv_heads (int): Number of heads in the key/value projection layers.
        dim (int): Dimension of the model.
        head_dim (int): Dimension of the head. If not provided, it will be calculated
            as dim // num_heads.

    Returns:
        Dict[str, torch.Tensor]: State dict in TorchTune's format.
    """
    converted_state_dict = {}
    if head_dim is None:
        head_dim = dim // num_heads

    def _permute(t, n_heads):
        return (
            t.view(n_heads, 2, head_dim // 2, dim)
            .transpose(1, 2)
            .reshape((head_dim * n_heads), dim)
        )

    for key, value in state_dict.items():
        if "rotary_emb.inv_freq" not in key:  # Skip loading the position embeddings
            new_key = get_mapped_key(key, _MISTRAL_REWARD)
            if "q_proj" in key:
                value = _permute(value, num_heads)
            elif "k_proj" in key:
                value = _permute(value, num_kv_heads)
            converted_state_dict[new_key] = value
    return converted_state_dict


def mistral_reward_tune_to_hf(
    state_dict: Dict[str, torch.Tensor],
    num_heads: int = 32,
    num_kv_heads: int = 32,
    dim: int = 4096,
) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from TorchTune's format to Hugging Face's format for a Mistral reward model.

    This function takes a state dictionary in TorchTune's format, which contains the weights of a
    Mistral reward model, and converts it into a format that can be loaded into a Hugging Face model.
    The logic is identical to :func:`~torchtune.models.convert_weights.tune_to_hf`, but with a different mapping.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in TorchTune's format.
        num_heads (int, optional): Number of heads in the model. Defaults to 32.
        num_kv_heads (int, optional): Number of heads in the key/value projection layers. Defaults to 32.
        dim (int, optional): Dimension of the model. Defaults to 4096.

    Returns:
        Dict[str, torch.Tensor]: State dict in Hugging Face's format.
    """
    converted_state_dict = {}
    inverted_mapping_dict = {v: k for k, v in _MISTRAL_REWARD.items()}
    head_dim = dim // num_heads

    def _permute(t, n_heads):
        return (
            t.view(n_heads, head_dim // 2, 2, dim)
            .transpose(1, 2)
            .reshape((head_dim * n_heads), dim)
        )

    for key, value in state_dict.items():
        new_key = get_mapped_key(key, inverted_mapping_dict)
        if "q_proj" in key:
            value = _permute(value, num_heads)
        elif "k_proj" in key:
            value = _permute(value, num_kv_heads)
        converted_state_dict[new_key] = value

    return converted_state_dict
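To make the mapping concrete, here is a minimal round-trip sketch with tiny placeholder tensors. The shapes, layer count, and dictionary contents are illustrative and not from the PR; only the two converter functions and their keyword arguments come from the diff above.

# Round-trip sketch (placeholder shapes): HF keys -> TorchTune keys -> HF keys.
import torch

from torchtune.models.mistral import (
    mistral_reward_hf_to_tune,
    mistral_reward_tune_to_hf,
)

num_heads = num_kv_heads = 4
dim = 16
head_dim = dim // num_heads
hf_sd = {
    "model.embed_tokens.weight": torch.randn(32, dim),
    "model.layers.0.self_attn.q_proj.weight": torch.randn(num_heads * head_dim, dim),
    "model.layers.0.self_attn.k_proj.weight": torch.randn(num_kv_heads * head_dim, dim),
    "model.norm.weight": torch.randn(dim),
    "score.weight": torch.randn(1, dim),  # single-score reward head
}

tune_sd = mistral_reward_hf_to_tune(hf_sd, num_heads=num_heads, num_kv_heads=num_kv_heads, dim=dim)
# The reward head is remapped to the classifier output, per _MISTRAL_REWARD.
assert "output.weight" in tune_sd and "score.weight" not in tune_sd

back = mistral_reward_tune_to_hf(tune_sd, num_heads=num_heads, num_kv_heads=num_kv_heads, dim=dim)
# tune_to_hf's _permute is the inverse of hf_to_tune's, so q/k weights survive unchanged.
assert torch.equal(back["model.layers.0.self_attn.q_proj.weight"], hf_sd["model.layers.0.self_attn.q_proj.weight"])
assert set(back) == set(hf_sd)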
Changes to the HF checkpointer (imports and load_checkpoint):

@@ -15,6 +15,10 @@
 from torchtune import utils

 from torchtune.models import convert_weights
+from torchtune.models.mistral import (
+    mistral_reward_hf_to_tune,
+    mistral_reward_tune_to_hf,
+)
 from torchtune.models.phi3 import phi3_hf_to_tune, phi3_tune_to_hf
 from torchtune.utils._checkpointing._checkpointer_utils import (
     get_path,

@@ -384,6 +388,13 @@ def load_checkpoint(self) -> Dict[str, Any]:

         if self._model_type == ModelType.PHI3_MINI:
             converted_state_dict[utils.MODEL_KEY] = phi3_hf_to_tune(merged_state_dict)
+        elif self._model_type == ModelType.MISTRAL_REWARD:
+            converted_state_dict[utils.MODEL_KEY] = mistral_reward_hf_to_tune(
+                merged_state_dict,
+                num_heads=self._config["num_attention_heads"],
+                num_kv_heads=self._config["num_key_value_heads"],
+                dim=self._config["hidden_size"],
+            )
         else:
             converted_state_dict[utils.MODEL_KEY] = convert_weights.hf_to_tune(
                 merged_state_dict,
@@ -426,6 +437,13 @@ def save_checkpoint(
         # convert the state_dict back to hf format; do this inplace
         if self._model_type == ModelType.PHI3_MINI:
             state_dict[utils.MODEL_KEY] = phi3_tune_to_hf(state_dict[utils.MODEL_KEY])
+        elif self._model_type == ModelType.MISTRAL_REWARD:
+            state_dict[utils.MODEL_KEY] = mistral_reward_tune_to_hf(
+                state_dict[utils.MODEL_KEY],
+                num_heads=self._config["num_attention_heads"],
+                num_kv_heads=self._config["num_key_value_heads"],
+                dim=self._config["hidden_size"],
+            )
         else:
             state_dict[utils.MODEL_KEY] = convert_weights.tune_to_hf(
                 state_dict[utils.MODEL_KEY],

[Review comment on the new elif branch] same comment as above
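On the save path the keys go back to HF names, so the written checkpoint keeps the reward head under ``score.weight``, the head name used by HF sequence-classification Mistral models such as the reward checkpoint referenced in the docstring. A tiny sketch with placeholder tensors; only the two converter calls come from this PR.

import torch

from torchtune.models.mistral import mistral_reward_hf_to_tune, mistral_reward_tune_to_hf

dim, num_heads = 16, 4
hf_sd = {"score.weight": torch.randn(1, dim), "model.norm.weight": torch.randn(dim)}

tune_sd = mistral_reward_hf_to_tune(hf_sd, num_heads=num_heads, num_kv_heads=num_heads, dim=dim)
saved_sd = mistral_reward_tune_to_hf(tune_sd, num_heads=num_heads, num_kv_heads=num_heads, dim=dim)

# The head round-trips back to the HF key, so HF tooling can reload the saved file.
assert set(saved_sd) == {"score.weight", "model.norm.weight"}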
Changes to the ModelType enum:

@@ -24,6 +24,7 @@ class ModelType(Enum):
     LLAMA3 = "llama3"
     MISTRAL = "mistral"
     PHI3_MINI = "phi3_mini"
+    MISTRAL_REWARD = "mistral_reward"


 def get_path(input_dir: Path, filename: str, missing_ok: bool = False) -> Path:

[Review comment on the MISTRAL_REWARD line] maybe this could be mistral_classifier?
[Reply] I think
[Reply] I'll leave as-is for now, we can always generalise later if there's use cases.
[Review comment] Nice catch!
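For reference, a standalone sketch of the enum member added above and how it resolves from its name or its string value. The members listed are only those visible in the diff hunk, and the lookups are illustrative; they are not a claim about how the checkpointer itself resolves model_type.

from enum import Enum


class ModelType(Enum):
    # Only the members visible in the diff hunk above.
    LLAMA3 = "llama3"
    MISTRAL = "mistral"
    PHI3_MINI = "phi3_mini"
    MISTRAL_REWARD = "mistral_reward"


assert ModelType["MISTRAL_REWARD"] is ModelType.MISTRAL_REWARD  # lookup by member name
assert ModelType("mistral_reward") is ModelType.MISTRAL_REWARD  # lookup by value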