diff --git a/tests/torchtune/utils/test_checkpointer.py b/tests/torchtune/utils/test_checkpointer.py
index d8d09452ea..6063eaac4b 100644
--- a/tests/torchtune/utils/test_checkpointer.py
+++ b/tests/torchtune/utils/test_checkpointer.py
@@ -13,7 +13,7 @@
 import torch
 from torch import randn
 
-from torchtune.models import llama2
+from torchtune.models import llama2, mistral
 from torchtune.utils._checkpointing import FullModelHFCheckpointer
 from torchtune.utils._checkpointing._checkpointer_utils import safe_torch_load
 from torchtune.utils.seed import set_seed
@@ -294,6 +294,153 @@ def test_save_load_checkpoint_multiple_file(
         assert len(output_state_dict_2.keys()) + 1 == len(orig_state_dict_2.keys())
 
 
+class TestHFMistralRewardModelFullModelCheckpointer:
+    @pytest.fixture
+    def weight_dtype(self):
+        return torch.float16
+
+    @pytest.fixture
+    def state_dict(self, weight_dtype):
+        """
+        State dict for an HF-format Mistral reward model checkpoint. This state dict is
+        "complete" and can be loaded into a TorchTune model once correctly converted.
+        """
+        state_dict = {
+            "model.embed_tokens.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype),
+            "model.layers.0.input_layernorm.weight": randn(_DIM, dtype=weight_dtype),
+            "model.layers.0.self_attn.q_proj.weight": randn(
+                _DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.self_attn.k_proj.weight": randn(
+                _DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.self_attn.v_proj.weight": randn(
+                _DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.self_attn.o_proj.weight": randn(
+                _DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.post_attention_layernorm.weight": randn(
+                _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.mlp.gate_proj.weight": randn(
+                _HIDDEN_DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.mlp.down_proj.weight": randn(
+                _DIM, _HIDDEN_DIM, dtype=weight_dtype
+            ),
+            "model.layers.0.mlp.up_proj.weight": randn(
+                _HIDDEN_DIM, _DIM, dtype=weight_dtype
+            ),
+            "model.norm.weight": randn(_DIM, dtype=weight_dtype),
+            "score.weight": randn(1, _DIM, dtype=weight_dtype),
+        }
+        return state_dict
+
+    @pytest.fixture
+    def mistral_reward_model_hf_checkpoint(self, tmp_path, state_dict):
+        """
+        Fixture which creates a checkpoint file for the Mistral reward model. The
+        state dict follows the HF format.
+
+        The state dict supports testing for a single-file checkpoint.
+        Multi-file checkpoints are already tested for Llama2.
+        * The checkpoint contains layer0 + embed + norm + score keys
+          and can be tested in isolation
+
+        The model corresponds to the following config:
+        * num_layers: 1
+        * num_heads: 4
+        * num_kv_heads: 4
+        * embed_dim: 64
+        * max_seq_len: 128
+        * num_classes: 1
+        * intermediate_dim: 256
+
+        """
+        checkpoint_file = tmp_path / "mistral_reward_model_hf_checkpoint.pt"
+
+        torch.save(state_dict, checkpoint_file)
+
+        config = {
+            "hidden_size": 64,
+            "num_attention_heads": 4,
+            "num_key_value_heads": 4,
+            "num_classes": 1,
+        }
+        config_file = Path.joinpath(tmp_path, "config.json")
+        with config_file.open("w") as f:
+            json.dump(config, f)
+
+        return checkpoint_file
+
+    @pytest.fixture
+    def single_file_checkpointer(
+        self, mistral_reward_model_hf_checkpoint, tmp_path
+    ) -> FullModelHFCheckpointer:
+        checkpoint_file = mistral_reward_model_hf_checkpoint
+        return FullModelHFCheckpointer(
+            checkpoint_dir=tmp_path,
+            checkpoint_files=[checkpoint_file],
+            model_type="MISTRAL_REWARD",
+            output_dir=tmp_path,
+        )
+
+    def test_load_save_checkpoint_single_file(
+        self,
+        single_file_checkpointer: FullModelHFCheckpointer,
+        mistral_reward_model_hf_checkpoint: Path,
+    ):
+        """
+        Test the ``load_checkpoint`` and ``save_checkpoint`` methods of the
+        FullModelHFCheckpointer for a single checkpoint file of a Mistral reward model.
+
+        We test:
+        * ``load_checkpoint`` loads the right set of keys
+        * Internal state of the checkpointer is correctly updated
+        * Converted checkpoint can be loaded into the ``mistral_classifier`` TorchTune implementation
+        * Saved checkpoint keys match the original checkpoint
+        """
+        # Read the state dict directly from file using torch.load. This will be the state
+        # dict we test against
+        checkpoint_file = mistral_reward_model_hf_checkpoint
+        orig_state_dict = safe_torch_load(checkpoint_file)
+
+        # Converted state dict from the checkpointer
+        state_dict = single_file_checkpointer.load_checkpoint()
+        # Check that we've loaded all the keys
+        assert len(state_dict["model"].keys()) == len(orig_state_dict.keys())
+
+        # The keys in the original state dict should match up with the keys in the weight_map
+        for key in orig_state_dict.keys():
+            if "inv_freq" in key:
+                continue
+            assert key in single_file_checkpointer._weight_map
+
+        # Loading the state dict into the model implementation should work correctly
+        model = mistral.mistral_classifier(
+            num_classes=1,
+            vocab_size=_VOCAB_SIZE,
+            num_layers=1,
+            num_heads=_NUM_HEADS,
+            num_kv_heads=_NUM_KV_HEADS,
+            embed_dim=_DIM,
+            intermediate_dim=_HIDDEN_DIM,
+            max_seq_len=128,
+        )
+        model.load_state_dict(state_dict["model"])
+
+        single_file_checkpointer.save_checkpoint(state_dict, epoch=1)
+
+        # Reload the output checkpoint file and compare to the original checkpoint. This
+        # assumes we know what the name of the file is. This is fine; breaking this logic
+        # is something we should capture through this test
+        output_file = Path.joinpath(checkpoint_file.parent, "hf_model_0001_1.pt")
+        output_state_dict = safe_torch_load(output_file)
+
+        assert len(output_state_dict.keys()) == len(orig_state_dict.keys())
+
+
 class TestCheckpointerUtils:
     @pytest.fixture
     def model_checkpoint(self, tmp_path):
diff --git a/torchtune/models/convert_weights.py b/torchtune/models/convert_weights.py
index 6555c5d43e..8d9a542b87 100644
--- a/torchtune/models/convert_weights.py
+++ b/torchtune/models/convert_weights.py
@@ -126,7 +126,7 @@ def hf_to_tune(
         repo in HF (https://huggingface.co/meta-llama/Llama-2-7b-hf).
 
     Args:
-        state_dict (Dict[str, torch.Tensor]): State dict in Meta's format.
+        state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
         num_heads (int): Number of heads in the model.
         num_kv_heads (int): Number of heads in the key/value projection layers.
         dim (int): Dimension of the model.
@@ -176,7 +176,7 @@ def tune_to_hf(
         dim (int): Dimension of the model.
 
     Returns:
-        Dict[str, torch.Tensor]: State dict in Meta's format.
+        Dict[str, torch.Tensor]: State dict in HF's format.
     """
     converted_state_dict = {}
     inverted_mapping_dict = {v: k for k, v in _FROM_HF.items()}
diff --git a/torchtune/models/mistral/__init__.py b/torchtune/models/mistral/__init__.py
index ad04776412..a97c92ee93 100644
--- a/torchtune/models/mistral/__init__.py
+++ b/torchtune/models/mistral/__init__.py
@@ -5,6 +5,10 @@
 # LICENSE file in the root directory of this source tree.
 
 from ._component_builders import lora_mistral, mistral, mistral_classifier
+from ._convert_weights import (  # noqa
+    mistral_reward_hf_to_tune,
+    mistral_reward_tune_to_hf,
+)
 from ._model_builders import (
     lora_mistral_7b,
     lora_mistral_classifier,
diff --git a/torchtune/models/mistral/_convert_weights.py b/torchtune/models/mistral/_convert_weights.py
new file mode 100644
index 0000000000..f2cd1ceb71
--- /dev/null
+++ b/torchtune/models/mistral/_convert_weights.py
@@ -0,0 +1,121 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+
+from torchtune.models.convert_weights import get_mapped_key
+
+_MISTRAL_REWARD = {
+    "model.embed_tokens.weight": "tok_embeddings.weight",
+    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
+    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
+    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight",
+    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
+    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight",
+    "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight",
+    "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
+    "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale",
+    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale",
+    "model.norm.weight": "norm.scale",
+    "score.weight": "output.weight",
+}
+
+
+def mistral_reward_hf_to_tune(
+    state_dict: Dict[str, torch.Tensor],
+    num_heads: int = 32,
+    num_kv_heads: int = 32,
+    dim: int = 4096,
+    head_dim: int = None,
+) -> Dict[str, torch.Tensor]:
+    """
+    Convert a Mistral reward model state dict from HF's format to
+    TorchTune's format.
+    State dicts from multiple checkpoint files should be consolidated into a single state dict
+    before calling this function.
+    The logic is identical to :func:`~torchtune.models.convert_weights.hf_to_tune`, but with a different mapping.
+
+    An example of an HF-format state dict can be found in the
+    ``Ray2333/reward-model-Mistral-7B-instruct-Unified-Feedback`` repo on HF.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
+        num_heads (int): Number of heads in the model.
+        num_kv_heads (int): Number of heads in the key/value projection layers.
+        dim (int): Dimension of the model.
+        head_dim (int): Dimension of the head. If not provided, it will be calculated
+            as dim // num_heads.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in TorchTune's format.
+    """
+    converted_state_dict = {}
+    if head_dim is None:
+        head_dim = dim // num_heads
+
+    def _permute(t, n_heads):
+        return (
+            t.view(n_heads, 2, head_dim // 2, dim)
+            .transpose(1, 2)
+            .reshape((head_dim * n_heads), dim)
+        )
+
+    for key, value in state_dict.items():
+        if "rotary_emb.inv_freq" not in key:  # Skip loading the position embeddings
+            new_key = get_mapped_key(key, _MISTRAL_REWARD)
+            if "q_proj" in key:
+                value = _permute(value, num_heads)
+            elif "k_proj" in key:
+                value = _permute(value, num_kv_heads)
+            converted_state_dict[new_key] = value
+    return converted_state_dict
+
+
+def mistral_reward_tune_to_hf(
+    state_dict: Dict[str, torch.Tensor],
+    num_heads: int = 32,
+    num_kv_heads: int = 32,
+    dim: int = 4096,
+) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from TorchTune's format to Hugging Face's format for a Mistral reward model.
+
+    This function takes a state dictionary in TorchTune's format, which contains the weights of a Mistral reward model,
+    and converts it into a format that can be loaded into a Hugging Face model.
+    The logic is identical to :func:`~torchtune.models.convert_weights.tune_to_hf`, but with a different mapping.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in TorchTune's format.
+        num_heads (int, optional): Number of heads in the model. Defaults to 32.
+        num_kv_heads (int, optional): Number of heads in the key/value projection layers. Defaults to 32.
+        dim (int, optional): Dimension of the model. Defaults to 4096.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Hugging Face's format.
+
+    """
+    converted_state_dict = {}
+    inverted_mapping_dict = {v: k for k, v in _MISTRAL_REWARD.items()}
+    head_dim = dim // num_heads
+
+    def _permute(t, n_heads):
+        return (
+            t.view(n_heads, head_dim // 2, 2, dim)
+            .transpose(1, 2)
+            .reshape((head_dim * n_heads), dim)
+        )
+
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, inverted_mapping_dict)
+        if "q_proj" in key:
+            value = _permute(value, num_heads)
+        elif "k_proj" in key:
+            value = _permute(value, num_kv_heads)
+        converted_state_dict[new_key] = value
+
+    return converted_state_dict
diff --git a/torchtune/utils/_checkpointing/_checkpointer.py b/torchtune/utils/_checkpointing/_checkpointer.py
index 760aa3fedd..653c7da321 100644
--- a/torchtune/utils/_checkpointing/_checkpointer.py
+++ b/torchtune/utils/_checkpointing/_checkpointer.py
@@ -15,6 +15,10 @@
 
 from torchtune import utils
 from torchtune.models import convert_weights
+from torchtune.models.mistral import (
+    mistral_reward_hf_to_tune,
+    mistral_reward_tune_to_hf,
+)
 from torchtune.models.phi3 import phi3_hf_to_tune, phi3_tune_to_hf
 from torchtune.utils._checkpointing._checkpointer_utils import (
     get_path,
@@ -384,6 +388,13 @@ def load_checkpoint(self) -> Dict[str, Any]:
 
         if self._model_type == ModelType.PHI3_MINI:
             converted_state_dict[utils.MODEL_KEY] = phi3_hf_to_tune(merged_state_dict)
+        elif self._model_type == ModelType.MISTRAL_REWARD:
+            converted_state_dict[utils.MODEL_KEY] = mistral_reward_hf_to_tune(
+                merged_state_dict,
+                num_heads=self._config["num_attention_heads"],
+                num_kv_heads=self._config["num_key_value_heads"],
+                dim=self._config["hidden_size"],
+            )
         else:
             converted_state_dict[utils.MODEL_KEY] = convert_weights.hf_to_tune(
                 merged_state_dict,
@@ -426,6 +437,13 @@ def save_checkpoint(
         # convert the state_dict back to hf format; do this inplace
         if self._model_type == ModelType.PHI3_MINI:
             state_dict[utils.MODEL_KEY] = phi3_tune_to_hf(state_dict[utils.MODEL_KEY])
+        elif self._model_type == ModelType.MISTRAL_REWARD:
+            state_dict[utils.MODEL_KEY] = mistral_reward_tune_to_hf(
+                state_dict[utils.MODEL_KEY],
+                num_heads=self._config["num_attention_heads"],
+                num_kv_heads=self._config["num_key_value_heads"],
+                dim=self._config["hidden_size"],
+            )
         else:
             state_dict[utils.MODEL_KEY] = convert_weights.tune_to_hf(
                 state_dict[utils.MODEL_KEY],
diff --git a/torchtune/utils/_checkpointing/_checkpointer_utils.py b/torchtune/utils/_checkpointing/_checkpointer_utils.py
index d28567756b..d24f6ac8fb 100644
--- a/torchtune/utils/_checkpointing/_checkpointer_utils.py
+++ b/torchtune/utils/_checkpointing/_checkpointer_utils.py
@@ -24,6 +24,7 @@ class ModelType(Enum):
     LLAMA3 = "llama3"
     MISTRAL = "mistral"
     PHI3_MINI = "phi3_mini"
+    MISTRAL_REWARD = "mistral_reward"
 
 
 def get_path(input_dir: Path, filename: str, missing_ok: bool = False) -> Path:
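
A minimal usage sketch of the checkpoint path enabled by this change, mirroring the new test above. The directory and file names are hypothetical; the checkpoint directory is assumed to contain the HF-format weights file and a config.json providing hidden_size, num_attention_heads, and num_key_value_heads, as in the test fixture.

    from torchtune.utils._checkpointing import FullModelHFCheckpointer

    # Hypothetical paths: any directory holding the HF weights file and its config.json
    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir="/tmp/mistral_reward_ckpt",
        checkpoint_files=["mistral_reward_model_hf_checkpoint.pt"],
        model_type="MISTRAL_REWARD",
        output_dir="/tmp/mistral_reward_ckpt",
    )

    # load_checkpoint() merges the checkpoint file(s) and converts HF keys
    # (e.g. "score.weight") to TorchTune keys (e.g. "output.weight") via
    # mistral_reward_hf_to_tune, using num_heads/num_kv_heads/dim from config.json.
    state_dict = checkpointer.load_checkpoint()

    # The converted weights can be loaded into the TorchTune mistral_classifier model;
    # save_checkpoint() converts them back to HF format via mistral_reward_tune_to_hf.
    checkpointer.save_checkpoint(state_dict, epoch=0)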