Mamba2 conversion script for original models #32580

Merged (17 commits) on Aug 29, 2024
@@ -15,55 +15,179 @@
"""This script can be used to convert checkpoints provided in the `mamba2_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""

import argparse
import json
from functools import partial
from os import path
from typing import Dict, Optional

import torch
from safetensors import safe_open
from safetensors.torch import save_model

from transformers import GPTNeoXTokenizerFast, LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM
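# Example invocation (illustrative only: the script name and the paths are
# placeholders; the flags are the ones defined in the argparse section at the
# bottom of this file):
#
#   python convert_mamba2_checkpoint.py -i /path/to/original_checkpoint_dir -m mamba_ssm -p bf16 -o /path/to/hf_output_dir
#
# For `codestral` checkpoints, additionally pass `-t /path/to/tokenizer/file`.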


def load_state_dict_from_safetensors(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
    # Load weights and config from paths
    original_state_dict = {}
    with safe_open(path.join(mamba2_checkpoint_path, ckpt_name), framework="pt") as f:
        for k in f.keys():
            newk = k.removeprefix("model.")
            original_state_dict[newk] = f.get_tensor(k).clone()
    return original_state_dict


def load_state_dict_from_torch(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
    return torch.load(path.join(mamba2_checkpoint_path, ckpt_name), map_location="cpu")


def convert_ssm_config_to_hf_config(config_ssm: Dict, mamba2_model_dict: Dict) -> Mamba2Config:
    """Convert a Mamba2Config from mamba_ssm to a Mamba2Config from here."""
    hf_config = Mamba2Config()

    # Switch to a different dict depending on model type
    config_dict = mamba2_model_dict

    # Set important values from config and recalculate other resulting entries
    hf_config.hidden_size = config_ssm[config_dict["hidden_size"]]
    hf_config.num_heads = (hf_config.hidden_size * hf_config.expand) // hf_config.head_dim
    hf_config.num_hidden_layers = config_ssm[config_dict["num_hidden_layers"]]
    hf_config.n_groups = config_ssm.get(config_dict["n_groups"], 1)
    hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
    hf_config.bos_token_id = config_dict["bos_token_id"]
    hf_config.pad_token_id = config_dict["pad_token_id"]
    hf_config.eos_token_id = config_dict["eos_token_id"]

    # Pad the vocab size to a multiple (usually 16; 32 is also common in other models)
    vocab_size = config_ssm["vocab_size"]
    pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
    if (vocab_size % pad_vocab_size_multiple) != 0:
        vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
    hf_config.vocab_size = vocab_size

    return hf_config
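# Worked example of the vocab padding in convert_ssm_config_to_hf_config above
# (the concrete numbers are illustrative; the configs of the original mamba2
# checkpoints use values of this kind): vocab_size=50277 with
# pad_vocab_size_multiple=16 leaves a remainder of 5, so 16 - 5 = 11 padding
# ids are appended and hf_config.vocab_size becomes 50288.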


def load_and_save_tokenizer(
    mamba2_model_type: str,
    output_dir: str,
    tokenizer_model_path: Optional[str] = None,
) -> None:
    tokenizer = None

    # Load tokenizer
    if tokenizer_model_path is not None and mamba2_model_type == "codestral":
        tokenizer_class = LlamaTokenizerFast
        tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
    elif mamba2_model_type == "mamba_ssm":
        tokenizer = GPTNeoXTokenizerFast.from_pretrained("state-spaces/mamba-130m-hf", padding_side="left")

    # Save tokenizer
    if tokenizer is not None:
        tokenizer.save_pretrained(output_dir)


_MAMBA2_MODELS_DICT = {
    "codestral": {
        "hidden_size": "dim",
        "num_hidden_layers": "n_layers",
        "n_groups": "n_groups",
        "bos_token_id": 0,
        "pad_token_id": 1,
        "eos_token_id": 2,
        "config_name": "params.json",
        "load_state_dict": partial(load_state_dict_from_safetensors, ckpt_name="consolidated.safetensors"),
        "load_and_save_tokenizer": partial(load_and_save_tokenizer, "codestral"),
    },
    "mamba_ssm": {
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layer",
        "n_groups": "ngroups",
        "bos_token_id": 0,
        "pad_token_id": 0,
        "eos_token_id": 0,
        "config_name": "config.json",
        "load_state_dict": partial(load_state_dict_from_torch, ckpt_name="pytorch_model.bin"),
        "load_and_save_tokenizer": partial(load_and_save_tokenizer, "mamba_ssm"),
    },
}


def convert_mamba2_checkpoint_file_to_huggingface_model_file(
    mamba2_checkpoint_path: str,
    mamba2_model_type: str,
    precision: str,
    output_dir: str,
    tokenizer_model_path: Optional[str] = None,
) -> None:
    mamba2_model_dict = _MAMBA2_MODELS_DICT[mamba2_model_type]

    # Load and save config based on name
    config_path = path.join(mamba2_checkpoint_path, mamba2_model_dict["config_name"])
    with open(config_path, "r", encoding="utf-8") as json_file:
        config = json.load(json_file)
    hf_config = convert_ssm_config_to_hf_config(config_ssm=config, mamba2_model_dict=mamba2_model_dict)
    hf_config.save_pretrained(output_dir)

    # Load state dict of the original model and transfer to hf model
    original_state_dict = mamba2_model_dict["load_state_dict"](mamba2_checkpoint_path=mamba2_checkpoint_path)
    hf_model = Mamba2ForCausalLM(hf_config)
    hf_model.load_state_dict(original_state_dict)

    # Save the converted model in the requested precision
    dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
    save_model(hf_model.to(dtype), path.join(output_dir, "model.safetensors"), metadata={"format": "pt"})

    # Load and save tokenizer
    mamba2_model_dict["load_and_save_tokenizer"](output_dir=output_dir, tokenizer_model_path=tokenizer_model_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--mamba2_checkpoint_directory",
        type=str,
        required=True,
        help="Path to a directory containing the `pytorch_model.bin` or `.safetensors` mamba2_ssm checkpoint file to be converted.",
    )
    parser.add_argument(
        "-m",
        "--mamba2_model_type",
        type=str,
        default="mamba_ssm",
        required=True,
        choices=("codestral", "mamba_ssm"),
        help="The model type the conversion will be performed on. Can choose from either `codestral` or `mamba_ssm`.",
    )
Review comment (Contributor Author) on lines 154 to +163:
User-dependent choice on conversion. For now, it defaults to `mamba_ssm`; `codestral` already has a repo that can be used directly, so I'd expect more people to use this script locally for the paper models.

    parser.add_argument(
        "-p",
        "--precision",
        type=str,
        default="fp16",
        required=True,
        choices=("fp32", "fp16", "bf16"),
        help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
    )
    parser.add_argument(
        "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
    )
    parser.add_argument(
        "-t",
        "--tokenizer_model_path",
        type=str,
        default=None,
        required=False,
        help="Path to a `codestral` tokenizer file.",
    )
    args = parser.parse_args()

    convert_mamba2_checkpoint_file_to_huggingface_model_file(
        args.mamba2_checkpoint_directory,
        args.mamba2_model_type,
        args.precision,
        args.output_dir,
        args.tokenizer_model_path,
    )
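A minimal sketch (not part of this PR; the paths are hypothetical) of loading the converted output back with standard transformers APIs, assuming a `mamba_ssm` conversion so that the GPT-NeoX tokenizer was saved into the output directory by the script:

from transformers import AutoTokenizer, Mamba2ForCausalLM

output_dir = "/path/to/hf_output_dir"  # hypothetical: the directory passed via -o above

# The script wrote config.json, model.safetensors and the tokenizer files there,
# so both objects can be restored directly from disk.
tokenizer = AutoTokenizer.from_pretrained(output_dir)
model = Mamba2ForCausalLM.from_pretrained(output_dir)

inputs = tokenizer("The Mamba2 architecture is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))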