Loading mistral reward model checkpoints (#911)
SalmanMohammadi authored May 4, 2024
1 parent d36e818 commit 057709e
Showing 6 changed files with 294 additions and 3 deletions.
149 changes: 148 additions & 1 deletion tests/torchtune/utils/test_checkpointer.py
@@ -13,7 +13,7 @@
import torch
from torch import randn

-from torchtune.models import llama2
+from torchtune.models import llama2, mistral
from torchtune.utils._checkpointing import FullModelHFCheckpointer
from torchtune.utils._checkpointing._checkpointer_utils import safe_torch_load
from torchtune.utils.seed import set_seed
@@ -294,6 +294,153 @@ def test_save_load_checkpoint_multiple_file(
assert len(output_state_dict_2.keys()) + 1 == len(orig_state_dict_2.keys())


class TestHFMistralRewardModelFullModelCheckpointer:
@pytest.fixture
def weight_dtype(self):
return torch.float16

@pytest.fixture
def state_dict(self, weight_dtype):
"""
State dict for an HF-format Mistral reward model checkpoint. This state dict is
"complete" and can be loaded into a TorchTune model once correctly converted.
"""
state_dict = {
"model.embed_tokens.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype),
"model.layers.0.input_layernorm.weight": randn(_DIM, dtype=weight_dtype),
"model.layers.0.self_attn.q_proj.weight": randn(
_DIM, _DIM, dtype=weight_dtype
),
"model.layers.0.self_attn.k_proj.weight": randn(
_DIM, _DIM, dtype=weight_dtype
),
"model.layers.0.self_attn.v_proj.weight": randn(
_DIM, _DIM, dtype=weight_dtype
),
"model.layers.0.self_attn.o_proj.weight": randn(
_DIM, _DIM, dtype=weight_dtype
),
"model.layers.0.post_attention_layernorm.weight": randn(
_DIM, dtype=weight_dtype
),
"model.layers.0.mlp.gate_proj.weight": randn(
_HIDDEN_DIM, _DIM, dtype=weight_dtype
),
"model.layers.0.mlp.down_proj.weight": randn(
_DIM, _HIDDEN_DIM, dtype=weight_dtype
),
"model.layers.0.mlp.up_proj.weight": randn(
_HIDDEN_DIM, _DIM, dtype=weight_dtype
),
"model.norm.weight": randn(_DIM, dtype=weight_dtype),
"score.weight": randn(1, _DIM, dtype=weight_dtype),
}
return state_dict

@pytest.fixture
def mistral_reward_model_hf_checkpoint(self, tmp_path, state_dict):
"""
Fixture which creates a checkpoint file for the Mistral reward model. The
state dict follows the HF_FORMAT checkpoint format. The state dict supports
testing a single-file checkpoint; multi-file checkpoints are already tested
for Llama2.
* The checkpoint contains layer0 + embed + norm + score keys
and can be tested in isolation
The model corresponds to the following config:
* num_layers: 1
* num_heads: 4
* num_kv_heads: 4
* embed_dim: 64
* max_seq_len: 128
* num_classes: 1
* intermediate_dim: 256
"""
checkpoint_file = tmp_path / "mistral_reward_model_hf_checkpoint.pt"

torch.save(state_dict, checkpoint_file)

config = {
"hidden_size": 64,
"num_attention_heads": 4,
"num_key_value_heads": 4,
"num_classes": 1,
}
config_file = Path.joinpath(tmp_path, "config.json")
with config_file.open("w") as f:
json.dump(config, f)

return checkpoint_file

@pytest.fixture
def single_file_checkpointer(
self, mistral_reward_model_hf_checkpoint, tmp_path
) -> FullModelHFCheckpointer:
checkpoint_file = mistral_reward_model_hf_checkpoint
return FullModelHFCheckpointer(
checkpoint_dir=tmp_path,
checkpoint_files=[checkpoint_file],
model_type="MISTRAL_REWARD",
output_dir=tmp_path,
)

def test_load_save_checkpoint_single_file(
self,
single_file_checkpointer: FullModelHFCheckpointer,
mistral_reward_model_hf_checkpoint: Path,
):
"""
Test the ``load_checkpoint`` and ``save_checkpoint`` methods within the
FullModelHFCheckpointer for a single checkpoint file for a mistral reward model.
We test:
* ``load_checkpoint`` loads the right sets of keys
* Internal state of the checkpointer is correctly updated
* Converted checkpoint can be loaded into the `mistral_classifier` TorchTune implementation
* Saved checkpoint keys match the original checkpoint
"""
# Read the state dict directly from file using torch.load. This will be the state
# dict we test against
checkpoint_file = mistral_reward_model_hf_checkpoint
orig_state_dict = safe_torch_load(checkpoint_file)

# Converted state dict from the checkpointer
state_dict = single_file_checkpointer.load_checkpoint()
# Check that we've loaded all the keys
assert len(state_dict["model"].keys()) == len(orig_state_dict.keys())

# the keys in the original state dict should match up with the keys in the weight_map
for key in orig_state_dict.keys():
if "inv_freq" in key:
continue
assert key in single_file_checkpointer._weight_map

# loading the state dict into the model implementation should work correctly
model = mistral.mistral_classifier(
num_classes=1,
vocab_size=_VOCAB_SIZE,
num_layers=1,
num_heads=_NUM_HEADS,
num_kv_heads=_NUM_KV_HEADS,
embed_dim=_DIM,
intermediate_dim=_HIDDEN_DIM,
max_seq_len=128,
)
model.load_state_dict(state_dict["model"])

single_file_checkpointer.save_checkpoint(state_dict, epoch=1)

# Reload the output checkpoint file and compare it to the original checkpoint. This
# assumes we know the name of the output file. That's fine: if the naming logic
# breaks, this test should catch it.
output_file = Path.joinpath(checkpoint_file.parent, "hf_model_0001_1.pt")
output_state_dict = safe_torch_load(output_file)

assert len(output_state_dict.keys()) == len(orig_state_dict.keys())


class TestCheckpointerUtils:
@pytest.fixture
def model_checkpoint(self, tmp_path):
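For orientation, a condensed, standalone version of what this test exercises is sketched below: build the tiny single-layer HF-format state dict from the fixture, convert it with the new `mistral_reward_hf_to_tune` helper (which the checkpointer calls internally), and load the result into `mistral_classifier`. Dimensions mirror the fixture's config; the vocab size is an illustrative placeholder, since the test's `_VOCAB_SIZE` constant is defined outside the hunk shown above.

```python
from torch import randn

from torchtune.models.mistral import mistral_classifier, mistral_reward_hf_to_tune

# Dims matching the fixture's config; VOCAB_SIZE is a placeholder.
VOCAB_SIZE, DIM, HIDDEN_DIM, NUM_HEADS, NUM_KV_HEADS = 100, 64, 256, 4, 4

hf_state_dict = {
    "model.embed_tokens.weight": randn(VOCAB_SIZE, DIM),
    "model.layers.0.input_layernorm.weight": randn(DIM),
    "model.layers.0.self_attn.q_proj.weight": randn(DIM, DIM),
    "model.layers.0.self_attn.k_proj.weight": randn(DIM, DIM),
    "model.layers.0.self_attn.v_proj.weight": randn(DIM, DIM),
    "model.layers.0.self_attn.o_proj.weight": randn(DIM, DIM),
    "model.layers.0.post_attention_layernorm.weight": randn(DIM),
    "model.layers.0.mlp.gate_proj.weight": randn(HIDDEN_DIM, DIM),
    "model.layers.0.mlp.down_proj.weight": randn(DIM, HIDDEN_DIM),
    "model.layers.0.mlp.up_proj.weight": randn(HIDDEN_DIM, DIM),
    "model.norm.weight": randn(DIM),
    "score.weight": randn(1, DIM),
}

# Rename HF keys (and permute q/k projections) into TorchTune's scheme.
tune_state_dict = mistral_reward_hf_to_tune(
    hf_state_dict, num_heads=NUM_HEADS, num_kv_heads=NUM_KV_HEADS, dim=DIM
)

# The converted dict should load cleanly into a classifier built with the same config.
model = mistral_classifier(
    num_classes=1,
    vocab_size=VOCAB_SIZE,
    num_layers=1,
    num_heads=NUM_HEADS,
    num_kv_heads=NUM_KV_HEADS,
    embed_dim=DIM,
    intermediate_dim=HIDDEN_DIM,
    max_seq_len=128,
)
model.load_state_dict(tune_state_dict)
```
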
4 changes: 2 additions & 2 deletions torchtune/models/convert_weights.py
@@ -126,7 +126,7 @@ def hf_to_tune(
repo in HF (https://huggingface.co/meta-llama/Llama-2-7b-hf).
Args:
-state_dict (Dict[str, torch.Tensor]): State dict in Meta's format.
+state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
num_heads (int): Number of heads in the model.
num_kv_heads (int): Number of heads in the key/value projection layers.
dim (int): Dimension of the model.
@@ -176,7 +176,7 @@ def tune_to_hf(
dim (int): Dimension of the model.
Returns:
-Dict[str, torch.Tensor]: State dict in Meta's format.
+Dict[str, torch.Tensor]: State dict in HF's format.
"""
converted_state_dict = {}
inverted_mapping_dict = {v: k for k, v in _FROM_HF.items()}
4 changes: 4 additions & 0 deletions torchtune/models/mistral/__init__.py
@@ -5,6 +5,10 @@
# LICENSE file in the root directory of this source tree.

from ._component_builders import lora_mistral, mistral, mistral_classifier
from ._convert_weights import ( # noqa
mistral_reward_hf_to_tune,
mistral_reward_tune_to_hf,
)
from ._model_builders import (
lora_mistral_7b,
lora_mistral_classifier,
121 changes: 121 additions & 0 deletions torchtune/models/mistral/_convert_weights.py
@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch

from torchtune.models.convert_weights import get_mapped_key

_MISTRAL_REWARD = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight",
"model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale",
"model.norm.weight": "norm.scale",
"score.weight": "output.weight",
}


def mistral_reward_hf_to_tune(
state_dict: Dict[str, torch.Tensor],
num_heads: int = 32,
num_kv_heads: int = 32,
dim: int = 4096,
head_dim: int = None,
) -> Dict[str, torch.Tensor]:
"""
Convert a Mistral reward model state dict from HF's format to TorchTune's format.
State dicts from multiple checkpoint files should be consolidated into a single state dict
before calling this function.
The logic is identical to :func:`~torchtune.models.convert_weights.hf_to_tune`, but with a different mapping.
An example of an HF-format state dict can be found in the ``Ray2333/reward-model-Mistral-7B-instruct-Unified-Feedback``
repo on the HF Hub.
Args:
state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
num_heads (int): Number of heads in the model.
num_kv_heads (int): Number of heads in the key/value projection layers.
dim (int): Dimension of the model.
head_dim (int): Dimension of the head. If not provided, it will be calculated
as dim // num_heads.
Returns:
Dict[str, torch.Tensor]: State dict in TorchTune's format.
"""
converted_state_dict = {}
if head_dim is None:
head_dim = dim // num_heads

def _permute(t, n_heads):
return (
t.view(n_heads, 2, head_dim // 2, dim)
.transpose(1, 2)
.reshape((head_dim * n_heads), dim)
)

for key, value in state_dict.items():
if "rotary_emb.inv_freq" not in key: # Skip loading the position embeddings
new_key = get_mapped_key(key, _MISTRAL_REWARD)
if "q_proj" in key:
value = _permute(value, num_heads)
elif "k_proj" in key:
value = _permute(value, num_kv_heads)
converted_state_dict[new_key] = value
return converted_state_dict


def mistral_reward_tune_to_hf(
state_dict: Dict[str, torch.Tensor],
num_heads: int = 32,
num_kv_heads: int = 32,
dim: int = 4096,
) -> Dict[str, torch.Tensor]:
"""
Convert a state dict from TorchTune's format to Hugging Face's format for a Mistral reward model.
This function takes a state dictionary in TorchTune's format, which contains the weights of a Mistral reward model,
and converts it into a format that can be loaded into a Hugging Face model.
The logic is identical to :func:`~torchtune.models.convert_weights.tune_to_hf`, but with a different mapping.
Args:
state_dict (Dict[str, torch.Tensor]): State dict in TorchTune's format.
num_heads (int, optional): Number of heads in the model. Defaults to 32.
num_kv_heads (int, optional): Number of heads in the key/value projection layers. Defaults to 32.
dim (int, optional): Dimension of the model. Defaults to 4096.
Returns:
Dict[str, torch.Tensor]: State dict in Hugging Face's format.
"""
converted_state_dict = {}
inverted_mapping_dict = {v: k for k, v in _MISTRAL_REWARD.items()}
head_dim = dim // num_heads

def _permute(t, n_heads):
return (
t.view(n_heads, head_dim // 2, 2, dim)
.transpose(1, 2)
.reshape((head_dim * n_heads), dim)
)

for key, value in state_dict.items():
new_key = get_mapped_key(key, inverted_mapping_dict)
if "q_proj" in key:
value = _permute(value, num_heads)
elif "k_proj" in key:
value = _permute(value, num_kv_heads)
converted_state_dict[new_key] = value

return converted_state_dict
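
As a quick illustration of the relationship between the two helpers above, the sketch below (a hedged example using the small dimensions from the test, not part of the commit) round-trips a few HF-format tensors through `mistral_reward_hf_to_tune` and back through `mistral_reward_tune_to_hf`; the q/k permutations cancel and the remaining weights pass through untouched.

```python
import torch

from torchtune.models.mistral import (
    mistral_reward_hf_to_tune,
    mistral_reward_tune_to_hf,
)

# Tiny dims for illustration: 4 heads of head_dim 16, so q/k weights are (64, 64).
num_heads = num_kv_heads = 4
dim = 64

hf_sd = {
    "model.layers.0.self_attn.q_proj.weight": torch.randn(dim, dim),
    "model.layers.0.self_attn.k_proj.weight": torch.randn(dim, dim),
    "score.weight": torch.randn(1, dim),
}

# HF -> TorchTune: keys are renamed and q/k rows are permuted.
tune_sd = mistral_reward_hf_to_tune(
    hf_sd, num_heads=num_heads, num_kv_heads=num_kv_heads, dim=dim
)
assert set(tune_sd) == {
    "layers.0.attn.q_proj.weight",
    "layers.0.attn.k_proj.weight",
    "output.weight",
}

# TorchTune -> HF: the permutation is undone, so the round trip is exact.
back = mistral_reward_tune_to_hf(
    tune_sd, num_heads=num_heads, num_kv_heads=num_kv_heads, dim=dim
)
for key, value in hf_sd.items():
    assert torch.equal(back[key], value)
```
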
18 changes: 18 additions & 0 deletions torchtune/utils/_checkpointing/_checkpointer.py
@@ -15,6 +15,10 @@
from torchtune import utils

from torchtune.models import convert_weights
from torchtune.models.mistral import (
mistral_reward_hf_to_tune,
mistral_reward_tune_to_hf,
)
from torchtune.models.phi3 import phi3_hf_to_tune, phi3_tune_to_hf
from torchtune.utils._checkpointing._checkpointer_utils import (
get_path,
@@ -384,6 +388,13 @@ def load_checkpoint(self) -> Dict[str, Any]:

if self._model_type == ModelType.PHI3_MINI:
converted_state_dict[utils.MODEL_KEY] = phi3_hf_to_tune(merged_state_dict)
elif self._model_type == ModelType.MISTRAL_REWARD:
converted_state_dict[utils.MODEL_KEY] = mistral_reward_hf_to_tune(
merged_state_dict,
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
)
else:
converted_state_dict[utils.MODEL_KEY] = convert_weights.hf_to_tune(
merged_state_dict,
@@ -426,6 +437,13 @@ def save_checkpoint(
# convert the state_dict back to hf format; do this inplace
if self._model_type == ModelType.PHI3_MINI:
state_dict[utils.MODEL_KEY] = phi3_tune_to_hf(state_dict[utils.MODEL_KEY])
elif self._model_type == ModelType.MISTRAL_REWARD:
state_dict[utils.MODEL_KEY] = mistral_reward_tune_to_hf(
state_dict[utils.MODEL_KEY],
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
)
else:
state_dict[utils.MODEL_KEY] = convert_weights.tune_to_hf(
state_dict[utils.MODEL_KEY],
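End to end, a recipe would normally drive this path through the checkpointer rather than calling the conversion helpers directly. A rough sketch under stated assumptions: the HF reward-model weights and their config.json have been downloaded to a local directory, the shard file names below are placeholders, and the model arguments shown are the standard Mistral-7B values — in practice, read them from the checkpoint's config.json.

```python
from torchtune import utils
from torchtune.models.mistral import mistral_classifier
from torchtune.utils._checkpointing import FullModelHFCheckpointer

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/data/mistral_reward_model",  # placeholder path
    checkpoint_files=[  # placeholder shard names
        "pytorch_model-00001-of-00002.bin",
        "pytorch_model-00002-of-00002.bin",
    ],
    model_type="MISTRAL_REWARD",
    output_dir="/data/mistral_reward_model/output",
)

# Keys are renamed (and q/k permuted) via mistral_reward_hf_to_tune under the hood.
checkpoint = checkpointer.load_checkpoint()

# Model args assumed to match the checkpoint's config.json (standard Mistral-7B here).
model = mistral_classifier(
    num_classes=1,
    vocab_size=32000,
    num_layers=32,
    num_heads=32,
    num_kv_heads=8,
    embed_dim=4096,
    intermediate_dim=14336,
    max_seq_len=32768,
)
model.load_state_dict(checkpoint[utils.MODEL_KEY])

# ... fine-tune ...

# Saving converts back to HF key names via mistral_reward_tune_to_hf.
checkpoint[utils.MODEL_KEY] = model.state_dict()
checkpointer.save_checkpoint(checkpoint, epoch=0)
```
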
1 change: 1 addition & 0 deletions torchtune/utils/_checkpointing/_checkpointer_utils.py
@@ -24,6 +24,7 @@ class ModelType(Enum):
LLAMA3 = "llama3"
MISTRAL = "mistral"
PHI3_MINI = "phi3_mini"
MISTRAL_REWARD = "mistral_reward"


def get_path(input_dir: Path, filename: str, missing_ok: bool = False) -> Path:
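The new enum member is what the `model_type="MISTRAL_REWARD"` string used in the test resolves to. A small illustration of the usual `Enum` lookups (the import path mirrors the test file; exactly how the checkpointer performs the lookup is not shown in this diff):

```python
from torchtune.utils._checkpointing._checkpointer_utils import ModelType

# Lookup by member name, as suggested by the uppercase string passed in the test above.
assert ModelType["MISTRAL_REWARD"] is ModelType.MISTRAL_REWARD
# Lookup by value also works, since the member's value is the lowercase string.
assert ModelType("mistral_reward") is ModelType.MISTRAL_REWARD
assert ModelType.MISTRAL_REWARD.value == "mistral_reward"
```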
