From 637ed1390e3a6cda1bcd78d0f94971754e86241a Mon Sep 17 00:00:00 2001
From: HDCharles <39544797+HDCharles@users.noreply.github.com>
Date: Thu, 26 Sep 2024 03:33:30 -0400
Subject: [PATCH] fixing some issues with our support for 70/405B models (#941)

Summary: the download and convert scripts needed to be updated alongside
the model.py config files to handle the 70B and 405B checkpoints.

Test Plan: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-70B/model.pth

Reviewers:

Subscribers:

Tasks:

Tags:
---
 scripts/convert_hf_checkpoint.py | 161 +++++++++++++++----------------
 scripts/download.py              |   2 +-
 torchao/_models/llama/model.py   |   8 +-
 3 files changed, 84 insertions(+), 87 deletions(-)

diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py
index 7b0f76903..3098c818b 100644
--- a/scripts/convert_hf_checkpoint.py
+++ b/scripts/convert_hf_checkpoint.py
@@ -8,9 +8,10 @@
 import json
 import re
 import shutil
+import sys
 from pathlib import Path
 from typing import Optional
-
+from safetensors.torch import load_file as load_safetensors_file
 import torch
 
 from torchao._models.llama.model import ModelArgs
@@ -24,63 +25,49 @@ def convert_hf_checkpoint(
 ) -> None:
     if model_name is None:
         model_name = checkpoint_dir.name
-
-    # Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
-    # need to be copied into model.pth.
-    # Llama 3 70B can't be easily merged into one model.pth file, though, since names of the
-    # weights is state dict are the same in each consolidated.NN.pth file. Thus, it is not
-    # currently supported.
-    # Along this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken
-    is_llama3 = "Llama-3" in model_name
-    if is_llama3:
-        # Check if we have multiple original/consolidated.NN.pth files and report error
-        # if we do for Llama 3.
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
-        if len(bin_files) > 1:
-            raise ValueError(
-                f"Multiple consolidated.NN.pth files found in {original_dir}. "
-                "Merging them into one model.pth file is not supported for Llama 3.")
-
-
     config = ModelArgs.from_name(model_name)
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    if not is_llama3:
-        model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
-
-        assert model_map_json.is_file()
-
-        with open(model_map_json) as json_map:
-            bin_index = json.load(json_map)
-
-        weight_map = {
-            "model.embed_tokens.weight": "tok_embeddings.weight",
-            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-            'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
-            'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
-            "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-            "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
-            "model.norm.weight": "norm.weight",
-            "lm_head.weight": "output.weight",
-        }
-        bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
-    else:
-        # There is no separate pytorch_model.bin.index.json file for llama3.
-        # Instead, we will just use all original/consolidated.NN.pth files.
-        # so, we use model.safetensors.index.json
-        weight_map = None
-        original_dir = checkpoint_dir / "original"
-        pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
-        bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}
-
+    model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
+    model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
+    model_map_json = None
+
+    try:
+        assert model_map_json_safetensors.is_file()
+        model_map_json = model_map_json_safetensors
+        print(f"Found safetensors index at {model_map_json_safetensors}")
+    except AssertionError:
+        print(f"{model_map_json_safetensors} not found")
+    if model_map_json is None:
+        try:
+            assert model_map_json_pytorch.is_file()
+            model_map_json = model_map_json_pytorch
+            print(f"Found pytorch index at {model_map_json_pytorch}")
+        except AssertionError:
+            print(f"{model_map_json_pytorch} not found")
+
+    if model_map_json is None: raise Exception("No model map found!")
+
+    with open(model_map_json) as json_map:
+        bin_index = json.load(json_map)
+
+    weight_map = {
+        "model.embed_tokens.weight": "tok_embeddings.weight",
+        "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+        "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+        "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+        "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+        'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
+        'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
+        "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+        "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+        "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+        "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+        "model.norm.weight": "norm.weight",
+        "lm_head.weight": "output.weight",
+    }
+    bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_head):
         dim = config.dim
@@ -92,40 +79,44 @@ def permute(w, n_head):
 
     merged_result = {}
     for file in sorted(bin_files):
-        state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
-        merged_result.update(state_dict)
+        if "safetensors" in str(file):
+            state_dict = load_safetensors_file(str(file), device="cpu")
+            merged_result.update(state_dict)
+        else:
+            state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
+            merged_result.update(state_dict)
     final_result = {}
-    if weight_map is not None:
-        for key, value in merged_result.items():
-            if "layers" in key:
-                abstract_key = re.sub(r'(\d+)', '{}', key)
-                layer_num = re.search(r'\d+', key).group(0)
-                new_key = weight_map[abstract_key]
-                if new_key is None:
-                    continue
-                new_key = new_key.format(layer_num)
-            else:
-                new_key = weight_map[key]
-
-            final_result[new_key] = value
-
-        for key in tuple(final_result.keys()):
-            if "wq" in key:
-                q = final_result[key]
-                k = final_result[key.replace("wq", "wk")]
-                v = final_result[key.replace("wq", "wv")]
-                q = permute(q, config.n_head)
-                k = permute(k, config.n_local_heads)
-                final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-                del final_result[key]
-                del final_result[key.replace("wq", "wk")]
-                del final_result[key.replace("wq", "wv")]
-    else:
-        final_result = merged_result
+    for key, value in merged_result.items():
+        if "layers" in key:
+            abstract_key = re.sub(r'(\d+)', '{}', key)
+            layer_num = re.search(r'\d+', key).group(0)
+            new_key = weight_map[abstract_key]
+            if new_key is None:
+                continue
+            new_key = new_key.format(layer_num)
+        else:
+            new_key = weight_map[key]
+
+        final_result[new_key] = value
+
+    for key in tuple(final_result.keys()):
+        if "wq" in key:
+            q = final_result[key]
+            k = final_result[key.replace("wq", "wk")]
+            v = final_result[key.replace("wq", "wv")]
+            q = permute(q, config.n_head)
+            k = permute(k, config.n_local_heads)
+            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
+            del final_result[key]
+            del final_result[key.replace("wq", "wk")]
+            del final_result[key.replace("wq", "wv")]
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
-    if is_llama3:
-        original_dir = checkpoint_dir / "original"
+    if 'llama-3-' in model_name.lower() or 'llama-3.1-' in model_name.lower():
+        if 'llama-3.1-405b' in model_name.lower():
+            original_dir = checkpoint_dir / "original" / "mp16"
+        else:
+            original_dir = checkpoint_dir / "original"
         tokenizer_model = original_dir / "tokenizer.model"
         tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
         print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
diff --git a/scripts/download.py b/scripts/download.py
index 3fc89e712..571e03adb 100644
--- a/scripts/download.py
+++ b/scripts/download.py
@@ -15,7 +15,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
     from huggingface_hub import snapshot_download
     os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
     try:
-        snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors")
+        snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token)
     except HTTPError as e:
         if e.response.status_code == 401:
             print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py
index 92448b599..de1f31197 100644
--- a/torchao/_models/llama/model.py
+++ b/torchao/_models/llama/model.py
@@ -72,7 +72,13 @@ def from_name(cls, name: str):
     "stories15M": dict(n_layer=6, n_head=6, dim=288),
     "stories110M": dict(n_layer=12, n_head=12, dim=768),
     "Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
-    "Llama-3.1-8B": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, use_scaled_rope=True)
+    "Llama-3.1-8B": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, use_scaled_rope=True),
+    "Llama-3.1-70B": dict(block_size=131072, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000,
+        use_scaled_rope=True
+    ),
+    "Llama-3.1-405B": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
+        use_scaled_rope=True
+    ),
 }
 
 # this is a model specific variable that controls whether index_put is used for the kv_cache update,
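
For reference, the index-file resolution added in convert_hf_checkpoint.py boils down to a fixed lookup order: prefer model.safetensors.index.json, fall back to pytorch_model.bin.index.json, and fail if neither exists. A minimal standalone sketch of that order (the helper name resolve_model_map is illustrative only, not part of the patch):

    from pathlib import Path

    def resolve_model_map(checkpoint_dir: Path) -> Path:
        # Prefer the safetensors index; fall back to the legacy torch .bin index.
        for candidate in (
            checkpoint_dir / "model.safetensors.index.json",
            checkpoint_dir / "pytorch_model.bin.index.json",
        ):
            if candidate.is_file():
                print(f"Found weight index at {candidate}")
                return candidate
        raise FileNotFoundError(f"No model map found in {checkpoint_dir}")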
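
Similarly, the shard-loading change dispatches on file type: .safetensors shards are read with safetensors.torch.load_file, everything else with torch.load using mmap and weights_only. A small sketch of the same branch as a standalone helper (load_shard is an illustrative name):

    import torch
    from safetensors.torch import load_file as load_safetensors_file

    def load_shard(path: str) -> dict:
        # safetensors shards load directly to CPU tensors
        if path.endswith(".safetensors"):
            return load_safetensors_file(path, device="cpu")
        # .bin/.pth shards: memory-map the file and only deserialize tensor data
        return torch.load(path, map_location="cpu", mmap=True, weights_only=True)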
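
Finally, the two new transformer_configs entries make the 70B and 405B configs constructible by name. Assuming ModelArgs.from_name matches the checkpoint directory name against these keys as the existing code does, a quick sanity check might look like:

    from torchao._models.llama.model import ModelArgs

    args = ModelArgs.from_name("Llama-3.1-70B")
    # From the config added above: 80 layers, 64 heads, model dim 8192
    print(args.n_layer, args.n_head, args.dim)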