Loading mistral reward model checkpoints #911

Merged
149 changes: 148 additions & 1 deletion tests/torchtune/utils/test_checkpointer.py
@@ -13,7 +13,7 @@
import torch
from torch import randn

-from torchtune.models import llama2
+from torchtune.models import llama2, mistral
from torchtune.utils._checkpointing import FullModelHFCheckpointer
from torchtune.utils._checkpointing._checkpointer_utils import safe_torch_load
from torchtune.utils.seed import set_seed
@@ -294,6 +294,153 @@ def test_save_load_checkpoint_multiple_file(
        assert len(output_state_dict_2.keys()) + 1 == len(orig_state_dict_2.keys())


class TestHFMistralRewardModelFullModelCheckpointer:
    @pytest.fixture
    def weight_dtype(self):
        return torch.float16

    @pytest.fixture
    def state_dict(self, weight_dtype):
        """
        State dict for an HF-format Mistral reward model checkpoint. This state dict is
        "complete" and can be loaded into a TorchTune model once correctly converted.
        """
        state_dict = {
            "model.embed_tokens.weight": randn(_VOCAB_SIZE, _DIM, dtype=weight_dtype),
            "model.layers.0.input_layernorm.weight": randn(_DIM, dtype=weight_dtype),
            "model.layers.0.self_attn.q_proj.weight": randn(
                _DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.self_attn.k_proj.weight": randn(
                _DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.self_attn.v_proj.weight": randn(
                _DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.self_attn.o_proj.weight": randn(
                _DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.post_attention_layernorm.weight": randn(
                _DIM, dtype=weight_dtype
            ),
            "model.layers.0.mlp.gate_proj.weight": randn(
                _HIDDEN_DIM, _DIM, dtype=weight_dtype
            ),
            "model.layers.0.mlp.down_proj.weight": randn(
                _DIM, _HIDDEN_DIM, dtype=weight_dtype
            ),
            "model.layers.0.mlp.up_proj.weight": randn(
                _HIDDEN_DIM, _DIM, dtype=weight_dtype
            ),
            "model.norm.weight": randn(_DIM, dtype=weight_dtype),
            "score.weight": randn(1, _DIM, dtype=weight_dtype),
        }
        return state_dict

    @pytest.fixture
    def mistral_reward_model_hf_checkpoint(self, tmp_path, state_dict):
        """
        Fixture which creates a checkpoint file for the Mistral reward model. The
        state dict follows the HF_FORMAT for the checkpoint format.

        The state dict supports testing for a single-file checkpoint;
        multi-file checkpoints are already tested for Llama2.
        * The checkpoint contains layer0 + embed + norm + score keys
          and can be tested in isolation.

        The model corresponds to the following config:
        * num_layers: 1
        * num_heads: 4
        * num_kv_heads: 4
        * embed_dim: 64
        * max_seq_len: 128
        * num_classes: 1
        * intermediate_dim: 256
        """
        checkpoint_file = tmp_path / "mistral_reward_model_hf_checkpoint.pt"

        torch.save(state_dict, checkpoint_file)

        config = {
            "hidden_size": 64,
            "num_attention_heads": 4,
            "num_key_value_heads": 4,
            "num_classes": 1,
        }
        config_file = Path.joinpath(tmp_path, "config.json")
        with config_file.open("w") as f:
            json.dump(config, f)

        return checkpoint_file

    @pytest.fixture
    def single_file_checkpointer(
        self, mistral_reward_model_hf_checkpoint, tmp_path
    ) -> FullModelHFCheckpointer:
        checkpoint_file = mistral_reward_model_hf_checkpoint
        return FullModelHFCheckpointer(
            checkpoint_dir=tmp_path,
            checkpoint_files=[checkpoint_file],
            model_type="MISTRAL_REWARD",
            output_dir=tmp_path,
        )

    def test_load_save_checkpoint_single_file(
        self,
        single_file_checkpointer: FullModelHFCheckpointer,
        mistral_reward_model_hf_checkpoint: Path,
    ):
        """
        Test the ``load_checkpoint`` and ``save_checkpoint`` methods within the
        FullModelHFCheckpointer for a single checkpoint file for a Mistral reward model.

        We test that:
        * ``load_checkpoint`` loads the right sets of keys
        * Internal state of the checkpointer is correctly updated
        * The converted checkpoint can be loaded into the ``mistral_classifier`` TorchTune implementation
        * Saved checkpoint keys match the original checkpoint
        """
        # Read the state dict directly from file. This will be the state
        # dict we test against.
        checkpoint_file = mistral_reward_model_hf_checkpoint
        orig_state_dict = safe_torch_load(checkpoint_file)

        # Converted state dict from the checkpointer
        state_dict = single_file_checkpointer.load_checkpoint()
        # Check that we've loaded all the keys
        assert len(state_dict["model"].keys()) == len(orig_state_dict.keys())

        # The keys in the original state dict should match up with the keys in the weight_map
        for key in orig_state_dict.keys():
            if "inv_freq" in key:
                continue
            assert key in single_file_checkpointer._weight_map

        # Loading the state dict into the model implementation should work correctly
        model = mistral.mistral_classifier(
            num_classes=1,
            vocab_size=_VOCAB_SIZE,
            num_layers=1,
            num_heads=_NUM_HEADS,
            num_kv_heads=_NUM_KV_HEADS,
            embed_dim=_DIM,
            intermediate_dim=_HIDDEN_DIM,
            max_seq_len=128,
        )
        model.load_state_dict(state_dict["model"])

        single_file_checkpointer.save_checkpoint(state_dict, epoch=1)

        # Reload the output checkpoint file and compare to the original checkpoint. This
        # assumes we know the name of the output file; that's fine, since a change that
        # breaks this naming logic should be caught by this test.
        output_file = Path.joinpath(checkpoint_file.parent, "hf_model_0001_1.pt")
        output_state_dict = safe_torch_load(output_file)

        assert len(output_state_dict.keys()) == len(orig_state_dict.keys())

class TestCheckpointerUtils:
    @pytest.fixture
    def model_checkpoint(self, tmp_path):
4 changes: 2 additions & 2 deletions torchtune/models/convert_weights.py
@@ -126,7 +126,7 @@ def hf_to_tune(
    repo in HF (https://huggingface.co/meta-llama/Llama-2-7b-hf).

    Args:
-        state_dict (Dict[str, torch.Tensor]): State dict in Meta's format.
+        state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
Contributor: Nice catch!

        num_heads (int): Number of heads in the model.
        num_kv_heads (int): Number of heads in the key/value projection layers.
        dim (int): Dimension of the model.
@@ -176,7 +176,7 @@ def tune_to_hf(
        dim (int): Dimension of the model.

    Returns:
-        Dict[str, torch.Tensor]: State dict in Meta's format.
+        Dict[str, torch.Tensor]: State dict in HF's format.
"""
converted_state_dict = {}
inverted_mapping_dict = {v: k for k, v in _FROM_HF.items()}
4 changes: 4 additions & 0 deletions torchtune/models/mistral/__init__.py
@@ -5,6 +5,10 @@
# LICENSE file in the root directory of this source tree.

from ._component_builders import lora_mistral, mistral, mistral_classifier
from ._convert_weights import ( # noqa
    mistral_reward_hf_to_tune,
    mistral_reward_tune_to_hf,
)
from ._model_builders import (
    lora_mistral_7b,
    lora_mistral_classifier,
121 changes: 121 additions & 0 deletions torchtune/models/mistral/_convert_weights.py
@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional

import torch

from torchtune.models.convert_weights import get_mapped_key

_MISTRAL_REWARD = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight",
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight",
"model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.mlp_norm.scale",
"model.norm.weight": "norm.scale",
"score.weight": "output.weight",
}
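# Note: this mapping mirrors the base Mistral HF mapping; the reward-model-specific
# entry is "score.weight" -> "output.weight", i.e. the scalar classification head
# takes the place of the usual lm_head.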


def mistral_reward_hf_to_tune(
    state_dict: Dict[str, torch.Tensor],
    num_heads: int = 32,
    num_kv_heads: int = 32,
    dim: int = 4096,
    head_dim: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
"""
Convert a state dict from HF's format to TorchTune's format, which contains the weights
of a Mistral reward model.
State dicts from multiple checkpoint files should be consolidated into a single state dict
before calling this function.
The logic is identical to :func:`~torchtune.models.convert_weights.hf_to_tune`, but with a different mapping.

Eg of HF-format state dict can be found in the ``Ray2333/reward-model-Mistral-7B-instruct-Unified-Feedback``
repo in HF.

Args:
state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
num_heads (int): Number of heads in the model.
num_kv_heads (int): Number of heads in the key/value projection layers.
dim (int): Dimension of the model.
head_dim (int): Dimension of the head. If not provided, it will be calculated
as dim // num_heads.

Returns:
Dict[str, torch.Tensor]: State dict in TorchTune's format.
"""
    converted_state_dict = {}
    if head_dim is None:
        head_dim = dim // num_heads
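    # HF checkpoints store the q/k projection rows with the rotary-embedding dimensions
    # in half-split order; _permute below rearranges them into the interleaved order
    # that TorchTune's RoPE implementation expects.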

    def _permute(t, n_heads):
        return (
            t.view(n_heads, 2, head_dim // 2, dim)
            .transpose(1, 2)
            .reshape((head_dim * n_heads), dim)
        )

    for key, value in state_dict.items():
        if "rotary_emb.inv_freq" not in key:  # Skip loading the position embeddings
            new_key = get_mapped_key(key, _MISTRAL_REWARD)
            if "q_proj" in key:
                value = _permute(value, num_heads)
            elif "k_proj" in key:
                value = _permute(value, num_kv_heads)
            converted_state_dict[new_key] = value
    return converted_state_dict


def mistral_reward_tune_to_hf(
    state_dict: Dict[str, torch.Tensor],
    num_heads: int = 32,
    num_kv_heads: int = 32,
    dim: int = 4096,
) -> Dict[str, torch.Tensor]:
"""
Convert a state dict from TorchTune's format to Hugging Face's format for a Mistral reward model.

This function takes a state dictionary in TorchTune's format, which contains the weights of a Mistral reward model,
and converts it into a format that can be loaded into a Hugging Face model.
The logic is identical to :func:`~torchtune.models.convert_weights.tune_to_hf`, but with a different mapping.

Args:
state_dict (Dict[str, torch.Tensor]): State dict in TorchTune's format.
num_heads (int, optional): Number of heads in the model. Defaults to 32.
num_kv_heads (int, optional): Number of heads in the key/value projection layers. Defaults to 32.
dim (int, optional): Dimension of the model. Defaults to 4096.

Returns:
Dict[str, torch.Tensor]: State dict in Hugging Face's format.

"""
    converted_state_dict = {}
    inverted_mapping_dict = {v: k for k, v in _MISTRAL_REWARD.items()}
    head_dim = dim // num_heads
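    # Inverse of the permutation in mistral_reward_hf_to_tune: _permute below restores
    # the half-split rotary ordering that HF checkpoints use for q/k projection weights.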

    def _permute(t, n_heads):
        return (
            t.view(n_heads, head_dim // 2, 2, dim)
            .transpose(1, 2)
            .reshape((head_dim * n_heads), dim)
        )

    for key, value in state_dict.items():
        new_key = get_mapped_key(key, inverted_mapping_dict)
        if "q_proj" in key:
            value = _permute(value, num_heads)
        elif "k_proj" in key:
            value = _permute(value, num_kv_heads)
        converted_state_dict[new_key] = value

    return converted_state_dict
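
For anyone trying the converters out, here is a minimal round-trip sketch (not part of the PR; the tiny single-layer dimensions are illustrative assumptions):

import torch
from torchtune.models.mistral import (
    mistral_reward_hf_to_tune,
    mistral_reward_tune_to_hf,
)

# Toy config: 1 layer, dim=8, 2 heads (head_dim=4), 2 KV heads, hidden_dim=16.
dim, heads = 8, 2
hf_sd = {
    "model.embed_tokens.weight": torch.randn(32, dim),
    "model.layers.0.self_attn.q_proj.weight": torch.randn(dim, dim),
    "model.layers.0.self_attn.k_proj.weight": torch.randn(dim, dim),
    "model.layers.0.self_attn.v_proj.weight": torch.randn(dim, dim),
    "model.layers.0.self_attn.o_proj.weight": torch.randn(dim, dim),
    "model.layers.0.mlp.gate_proj.weight": torch.randn(16, dim),
    "model.layers.0.mlp.up_proj.weight": torch.randn(16, dim),
    "model.layers.0.mlp.down_proj.weight": torch.randn(dim, 16),
    "model.layers.0.input_layernorm.weight": torch.randn(dim),
    "model.layers.0.post_attention_layernorm.weight": torch.randn(dim),
    "model.norm.weight": torch.randn(dim),
    "score.weight": torch.randn(1, dim),
}

tune_sd = mistral_reward_hf_to_tune(hf_sd, num_heads=heads, num_kv_heads=heads, dim=dim)
back = mistral_reward_tune_to_hf(tune_sd, num_heads=heads, num_kv_heads=heads, dim=dim)
# The q/k permutations are exact inverses, so the round trip reproduces every tensor.
assert all(torch.equal(hf_sd[k], back[k]) for k in hf_sd)
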
18 changes: 18 additions & 0 deletions torchtune/utils/_checkpointing/_checkpointer.py
@@ -15,6 +15,10 @@
from torchtune import utils

from torchtune.models import convert_weights
from torchtune.models.mistral import (
    mistral_reward_hf_to_tune,
    mistral_reward_tune_to_hf,
)
from torchtune.models.phi3 import phi3_hf_to_tune, phi3_tune_to_hf
from torchtune.utils._checkpointing._checkpointer_utils import (
    get_path,
@@ -384,6 +388,13 @@ def load_checkpoint(self) -> Dict[str, Any]:

        if self._model_type == ModelType.PHI3_MINI:
            converted_state_dict[utils.MODEL_KEY] = phi3_hf_to_tune(merged_state_dict)
        elif self._model_type == ModelType.MISTRAL_REWARD:
            converted_state_dict[utils.MODEL_KEY] = mistral_reward_hf_to_tune(
                merged_state_dict,
                num_heads=self._config["num_attention_heads"],
                num_kv_heads=self._config["num_key_value_heads"],
                dim=self._config["hidden_size"],
            )
        else:
            converted_state_dict[utils.MODEL_KEY] = convert_weights.hf_to_tune(
                merged_state_dict,
@@ -426,6 +437,13 @@ def save_checkpoint(
        # convert the state_dict back to hf format; do this inplace
        if self._model_type == ModelType.PHI3_MINI:
            state_dict[utils.MODEL_KEY] = phi3_tune_to_hf(state_dict[utils.MODEL_KEY])
        elif self._model_type == ModelType.MISTRAL_REWARD:
Contributor: same comment as above

            state_dict[utils.MODEL_KEY] = mistral_reward_tune_to_hf(
                state_dict[utils.MODEL_KEY],
                num_heads=self._config["num_attention_heads"],
                num_kv_heads=self._config["num_key_value_heads"],
                dim=self._config["hidden_size"],
            )
        else:
            state_dict[utils.MODEL_KEY] = convert_weights.tune_to_hf(
                state_dict[utils.MODEL_KEY],
1 change: 1 addition & 0 deletions torchtune/utils/_checkpointing/_checkpointer_utils.py
@@ -24,6 +24,7 @@ class ModelType(Enum):
    LLAMA3 = "llama3"
    MISTRAL = "mistral"
    PHI3_MINI = "phi3_mini"
    MISTRAL_REWARD = "mistral_reward"
Collaborator (Author): maybe this could be mistral_classifier?

Contributor: I think mistral_reward makes sense unless you think there will be classifier checkpoints we'll need to load?

Collaborator (Author): I'll leave as-is for now, we can always generalise later if there's use cases.
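
For reference, a minimal sketch of how the new enum value selects the reward-model conversion path in the HF checkpointer (the directory and file names below are placeholder assumptions):

from torchtune.utils._checkpointing import FullModelHFCheckpointer

# Hypothetical paths; checkpoint_dir must also contain the model's config.json.
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/mistral_reward",
    checkpoint_files=["pytorch_model.bin"],
    model_type="MISTRAL_REWARD",  # routes through mistral_reward_hf_to_tune / _tune_to_hf
    output_dir="/tmp/mistral_reward",
)
state_dict = checkpointer.load_checkpoint()  # state_dict["model"] is now in TorchTune format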



def get_path(input_dir: Path, filename: str, missing_ok: bool = False) -> Path: