Skip to content

Commit

Permalink
fixing some issues with our support for 70/405B models (#941)
Browse files Browse the repository at this point in the history
Summary: download and convert scripts needed to be updated alongside
model.py config files

Test Plan: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-70B/model.pth

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
HDCharles authored Sep 26, 2024
1 parent 4b5b5ee commit 637ed13
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 87 deletions.
161 changes: 76 additions & 85 deletions scripts/convert_hf_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import json
import re
import shutil
import sys
from pathlib import Path
from typing import Optional

from safetensors.torch import load_file as load_safetensors_file
import torch

from torchao._models.llama.model import ModelArgs
Expand All @@ -24,63 +25,49 @@ def convert_hf_checkpoint(
) -> None:
if model_name is None:
model_name = checkpoint_dir.name

# Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
# need to be copied into model.pth.
# Llama 3 70B can't be easily merged into one model.pth file, though, since names of the
# weights is state dict are the same in each consolidated.NN.pth file. Thus, it is not
# currently supported.
# Along this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken
is_llama3 = "Llama-3" in model_name
if is_llama3:
# Check if we have multiple original/consolidated.NN.pth files and report error
# if we do for Llama 3.
original_dir = checkpoint_dir / "original"
pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
if len(bin_files) > 1:
raise ValueError(
f"Multiple consolidated.NN.pth files found in {original_dir}. "
"Merging them into one model.pth file is not supported for Llama 3.")


config = ModelArgs.from_name(model_name)
print(f"Model config {config.__dict__}")

# Load the json file containing weight mapping
if not is_llama3:
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"

assert model_map_json.is_file()

with open(model_map_json) as json_map:
bin_index = json.load(json_map)

weight_map = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
else:
# There is no separate pytorch_model.bin.index.json file for llama3.
# Instead, we will just use all original/consolidated.NN.pth files.
# so, we use model.safetensors.index.json
weight_map = None
original_dir = checkpoint_dir / "original"
pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}

model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
model_map_json = None

try:
assert model_map_json_safetensors.is_file()
model_map_json = model_map_json_safetensors
print(f"Found safetensors index at {model_map_json_safetensors}")
except AssertionError:
print(f"{model_map_json_safetensors} not found")
if model_map_json is None:
try:
assert model_map_json_pytorch.is_file()
model_map_json = model_map_json_pytorch
print(f"Found pytorch index at {model_map_json_pytorch}")
except AssertionError:
print(f"{model_map_json_pytorch} not found")

if model_map_json is None: raise Exception("No model map found!")

with open(model_map_json) as json_map:
bin_index = json.load(json_map)

weight_map = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}

def permute(w, n_head):
dim = config.dim
Expand All @@ -92,40 +79,44 @@ def permute(w, n_head):

merged_result = {}
for file in sorted(bin_files):
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
merged_result.update(state_dict)
if "safetensors" in str(file):
state_dict = load_safetensors_file(str(file), device="cpu")
merged_result.update(state_dict)
else:
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
merged_result.update(state_dict)
final_result = {}
if weight_map is not None:
for key, value in merged_result.items():
if "layers" in key:
abstract_key = re.sub(r'(\d+)', '{}', key)
layer_num = re.search(r'\d+', key).group(0)
new_key = weight_map[abstract_key]
if new_key is None:
continue
new_key = new_key.format(layer_num)
else:
new_key = weight_map[key]

final_result[new_key] = value

for key in tuple(final_result.keys()):
if "wq" in key:
q = final_result[key]
k = final_result[key.replace("wq", "wk")]
v = final_result[key.replace("wq", "wv")]
q = permute(q, config.n_head)
k = permute(k, config.n_local_heads)
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
else:
final_result = merged_result
for key, value in merged_result.items():
if "layers" in key:
abstract_key = re.sub(r'(\d+)', '{}', key)
layer_num = re.search(r'\d+', key).group(0)
new_key = weight_map[abstract_key]
if new_key is None:
continue
new_key = new_key.format(layer_num)
else:
new_key = weight_map[key]

final_result[new_key] = value

for key in tuple(final_result.keys()):
if "wq" in key:
q = final_result[key]
k = final_result[key.replace("wq", "wk")]
v = final_result[key.replace("wq", "wv")]
q = permute(q, config.n_head)
k = permute(k, config.n_local_heads)
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
torch.save(final_result, checkpoint_dir / "model.pth")
if is_llama3:
original_dir = checkpoint_dir / "original"
if 'llama-3-' in model_name.lower() or 'llama-3.1-' in model_name.lower():
if 'llama-3.1-405b' in model_name.lower():
original_dir = checkpoint_dir / "original" / "mp16"
else:
original_dir = checkpoint_dir / "original"
tokenizer_model = original_dir / "tokenizer.model"
tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
Expand Down
2 changes: 1 addition & 1 deletion scripts/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
from huggingface_hub import snapshot_download
os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
try:
snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors")
snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token)
except HTTPError as e:
if e.response.status_code == 401:
print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
Expand Down
8 changes: 7 additions & 1 deletion torchao/_models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,13 @@ def from_name(cls, name: str):
"stories15M": dict(n_layer=6, n_head=6, dim=288),
"stories110M": dict(n_layer=12, n_head=12, dim=768),
"Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
"Llama-3.1-8B": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, use_scaled_rope=True)
"Llama-3.1-8B": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, use_scaled_rope=True),
"Llama-3.1-70B": dict(block_size=131072, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000,
use_scaled_rope=True
),
"Llama-3.1-405B": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
use_scaled_rope=True
),
}

# this is a model specific variable that controls whether index_put is used for the kv_cache update,
Expand Down

0 comments on commit 637ed13

Please sign in to comment.