fixing some issues with our support for 70/405B models #941

Merged
merged 1 commit on Sep 26, 2024
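For a quick picture of the flow this PR fixes, here is a hypothetical end-to-end driver. The keyword-only signature convert_hf_checkpoint(*, checkpoint_dir, model_name) is assumed from the script below, and the paths are examples only:

from pathlib import Path

from scripts.convert_hf_checkpoint import convert_hf_checkpoint

# Download the HF repo first (see scripts/download.py below), then convert the
# shards (safetensors or .bin) into a single model.pth next to the checkpoint.
convert_hf_checkpoint(
    checkpoint_dir=Path("checkpoints/meta-llama/Meta-Llama-3.1-70B"),
    model_name="Llama-3.1-70B",  # must match a key known to ModelArgs.from_name
)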
161 changes: 76 additions & 85 deletions scripts/convert_hf_checkpoint.py
@@ -8,9 +8,10 @@
import json
import re
import shutil
import sys
from pathlib import Path
from typing import Optional

from safetensors.torch import load_file as load_safetensors_file
import torch

from torchao._models.llama.model import ModelArgs
@@ -24,63 +25,49 @@ def convert_hf_checkpoint(
) -> None:
if model_name is None:
model_name = checkpoint_dir.name

# Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
# need to be copied into model.pth.
# Llama 3 70B can't be easily merged into one model.pth file, though, since the names of
# the weights in the state dict are the same in each consolidated.NN.pth file. Thus, it is
# not currently supported.
# Along with this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken
is_llama3 = "Llama-3" in model_name
if is_llama3:
# Check if we have multiple original/consolidated.NN.pth files and report error
# if we do for Llama 3.
original_dir = checkpoint_dir / "original"
pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
if len(bin_files) > 1:
raise ValueError(
f"Multiple consolidated.NN.pth files found in {original_dir}. "
"Merging them into one model.pth file is not supported for Llama 3.")


config = ModelArgs.from_name(model_name)
print(f"Model config {config.__dict__}")

# Load the json file containing weight mapping
if not is_llama3:
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"

assert model_map_json.is_file()

with open(model_map_json) as json_map:
bin_index = json.load(json_map)

weight_map = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
else:
# There is no separate pytorch_model.bin.index.json file for llama3.
# Instead, we will just use all original/consolidated.NN.pth files.
weight_map = None
original_dir = checkpoint_dir / "original"
pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}

model_map_json_safetensors = checkpoint_dir / 'model.safetensors.index.json'
model_map_json_pytorch = checkpoint_dir / "pytorch_model.bin.index.json"
model_map_json = None

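# Prefer the safetensors index when it exists; otherwise fall back to the
# legacy pytorch_model.bin index.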
try:
assert model_map_json_safetensors.is_file()
model_map_json = model_map_json_safetensors
print(f"Found safetensors index at {model_map_json_safetensors}")
except AssertionError:
print(f"{model_map_json_safetensors} not found")
if model_map_json is None:
try:
assert model_map_json_pytorch.is_file()
model_map_json = model_map_json_pytorch
print(f"Found pytorch index at {model_map_json_pytorch}")
except AssertionError:
print(f"{model_map_json_pytorch} not found")

if model_map_json is None:
    raise Exception("No model map found!")

with open(model_map_json) as json_map:
bin_index = json.load(json_map)

weight_map = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
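# The index's "weight_map" maps each parameter name to the shard file that stores it.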
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}

def permute(w, n_head):
dim = config.dim
@@ -92,40 +79,44 @@ def permute(w, n_head):

merged_result = {}
for file in sorted(bin_files):
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
merged_result.update(state_dict)
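# Shards may be safetensors or legacy .bin/.pth files; dispatch on the filename.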
if "safetensors" in str(file):
state_dict = load_safetensors_file(str(file), device="cpu")
merged_result.update(state_dict)
else:
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
merged_result.update(state_dict)
final_result = {}
if weight_map is not None:
for key, value in merged_result.items():
if "layers" in key:
abstract_key = re.sub(r'(\d+)', '{}', key)
layer_num = re.search(r'\d+', key).group(0)
new_key = weight_map[abstract_key]
if new_key is None:
continue
new_key = new_key.format(layer_num)
else:
new_key = weight_map[key]

final_result[new_key] = value

for key in tuple(final_result.keys()):
if "wq" in key:
q = final_result[key]
k = final_result[key.replace("wq", "wk")]
v = final_result[key.replace("wq", "wv")]
q = permute(q, config.n_head)
k = permute(k, config.n_local_heads)
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
else:
final_result = merged_result
for key, value in merged_result.items():
if "layers" in key:
abstract_key = re.sub(r'(\d+)', '{}', key)
layer_num = re.search(r'\d+', key).group(0)
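# e.g. "model.layers.17.mlp.up_proj.weight" becomes the abstract key
# "model.layers.{}.mlp.up_proj.weight" with layer_num "17"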
new_key = weight_map[abstract_key]
if new_key is None:
continue
new_key = new_key.format(layer_num)
else:
new_key = weight_map[key]

final_result[new_key] = value

for key in tuple(final_result.keys()):
if "wq" in key:
q = final_result[key]
k = final_result[key.replace("wq", "wk")]
v = final_result[key.replace("wq", "wv")]
q = permute(q, config.n_head)
k = permute(k, config.n_local_heads)
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
torch.save(final_result, checkpoint_dir / "model.pth")
if is_llama3:
original_dir = checkpoint_dir / "original"
if 'llama-3-' in model_name.lower() or 'llama-3.1-' in model_name.lower():
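# For 405B, the original checkpoint (and tokenizer.model) live under original/mp16/.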
if 'llama-3.1-405b' in model_name.lower():
original_dir = checkpoint_dir / "original" / "mp16"
else:
original_dir = checkpoint_dir / "original"
tokenizer_model = original_dir / "tokenizer.model"
tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
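The body of permute is collapsed above. For context, here is a minimal self-contained sketch of the gpt-fast-style helper this script derives from; it is not the exact collapsed body, and head_dim is assumed to equal dim // n_head:

import torch

def permute(w: torch.Tensor, n_head: int, dim: int, head_dim: int) -> torch.Tensor:
    # Reorder HF's interleaved rotary-embedding layout into the half-split
    # layout the model expects: shape (n_head * head_dim, dim) in and out.
    return (
        w.view(n_head, 2, head_dim // 2, dim)
        .transpose(1, 2)
        .reshape(head_dim * n_head, dim)
    )

In the script itself, permute closes over config, which is why only w and n_head are passed at the call sites above.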
2 changes: 1 addition & 1 deletion scripts/download.py
@@ -15,7 +15,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
from huggingface_hub import snapshot_download
os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
try:
snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors")
snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token)
except HTTPError as e:
if e.response.status_code == 401:
print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
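With the ignore_patterns="*.safetensors" filter removed, snapshot_download now also fetches the safetensors shards that the converter's safetensors-index path expects. A hypothetical invocation, where the repo id and token are placeholders:

from scripts.download import hf_download  # module path assumed from this repo layout

# Fetches everything in the repo, including *.safetensors shards and original/ files.
hf_download(repo_id="meta-llama/Meta-Llama-3.1-70B", hf_token="hf_xxx")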
8 changes: 7 additions & 1 deletion torchao/_models/llama/model.py
@@ -72,7 +72,13 @@ def from_name(cls, name: str):
"stories15M": dict(n_layer=6, n_head=6, dim=288),
"stories110M": dict(n_layer=12, n_head=12, dim=768),
"Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
"Llama-3.1-8B": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, use_scaled_rope=True)
"Llama-3.1-8B": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000, use_scaled_rope=True),
"Llama-3.1-70B": dict(block_size=131072, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000,
use_scaled_rope=True
),
"Llama-3.1-405B": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
use_scaled_rope=True
),
}

# this is a model specific variable that controls whether index_put is used for the kv_cache update,
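The two new entries make the 70B and 405B geometries resolvable by name. A small sketch of the expected behavior; the substring-matching fallback in from_name is assumed from this file's gpt-fast lineage, and the directory name is an example:

from torchao._models.llama.model import ModelArgs

# An HF-style directory name such as "Meta-Llama-3.1-70B" is expected to
# resolve to the "Llama-3.1-70B" config via from_name's name matching.
config = ModelArgs.from_name("Meta-Llama-3.1-70B")
print(config.n_layer, config.n_head, config.dim)  # 80 64 8192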