Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update hf weight conversion script to llama 3 #551

Merged
merged 3 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ This is the reverse conversion for `convert_llama_weights_to_hf.py` script from
- Copy file params.json from the official llama download into that directory.
- Run the conversion script. `model-path` can be a Hugging Face hub model or a local hf model directory.
```
python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Llama-2-70b-chat-hf --output-dir test70B --model-size 70B
python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Meta-Llama-3-70B-Instruct --output-dir test70B --model-size 70B
```

## Step 1: Run inference
Checkout the official llama inference [repo](https://github.com/facebookresearch/llama). Test using chat or text completion.
Checkout the official llama 3 inference [repo](https://github.com/meta-llama/llama3). Test using chat or text completion.
```
torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_2_dir}/tokenizer.model
torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_3_dir}/tokenizer.model
```

For validation, please compare the converted weights with official llama 2 weights
```
python compare_llama_weights.py test70B ${llama_2_70b_chat_dir}
python compare_llama_weights.py test70B ${Llama-3-70B-Instruct_dir}
```
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,25 @@ def main() -> None:
assert len(one) == len(
two
), "shard should have the same length: {} != {}".format(len(one), len(two))
one = sorted(one.items(), key=lambda x: x[0])
two = sorted(two.items(), key=lambda x: x[0])

for _, (v, w) in enumerate(zip(one.items(), two.items())):
for _, (v, w) in enumerate(zip(one, two)):
assert v[0] == w[0], "{} != {}".format(v[0], w[0])
assert v[1].shape == w[1].shape, "tensor {} shape {} != {}".format(
v[0], v[1].shape, w[1].shape
)

delta = (v[1] - w[1]).abs().max().item()
deltas.append((i, v[0], delta))
deltas.append((i, v[0], delta, w[1].abs().mean().item()))
del one
del two
gc.collect()

deltas = sorted(deltas, key=lambda x: x[-1], reverse=True)
deltas = sorted(deltas, key=lambda x: x[-2], reverse=True)
print("Top 10 largest deltas:")
for i, k, v in deltas[:10]:
print(f" shard {i} {k}: {v}")
for i, k, delta, value in deltas[:10]:
print(f" shard {i} {k}: {delta} vs {value}")


if __name__ == "__main__":
Expand Down
22 changes: 14 additions & 8 deletions src/llama_recipes/tools/convert_hf_weights_to_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

NUM_SHARDS = {
"7B": 1,
"8B": 1,
"13B": 2,
"34B": 4,
"30B": 4,
Expand All @@ -30,15 +31,12 @@ def write_model(model_path, model_size, output_base_path):
n_heads_per_shard = n_heads // num_shards
dim = params["dim"]
dims_per_head = dim // n_heads
base = 10000.0
inv_freq = (
1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
).to(dtype)
llama_version = 3 if params.get("vocab_size") == 128256 else 2

if "n_kv_heads" in params:
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
key_value_dim = dim // num_key_value_heads
num_local_key_value_heads = num_key_value_heads // num_shards
key_value_dim = dims_per_head * num_key_value_heads
else: # compatibility with other checkpoints
num_key_value_heads = n_heads
num_local_key_value_heads = n_heads_per_shard
Expand Down Expand Up @@ -72,7 +70,10 @@ def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
for i, tensor in enumerate(tensors):
state_dict[i][name] = tensor.clone()

insert_chunk("tok_embeddings.weight", loaded["model.embed_tokens.weight"], 1)
concat_dim = 0 if llama_version == 3 else 1
insert_chunk(
"tok_embeddings.weight", loaded["model.embed_tokens.weight"], concat_dim
)
insert("norm.weight", loaded["model.norm.weight"])
insert_chunk("output.weight", loaded["lm_head.weight"], 0)

Expand Down Expand Up @@ -136,7 +137,12 @@ def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
f"layers.{layer_i}.ffn_norm.weight",
loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
)
insert("rope.freqs", inv_freq)
if llama_version != 3:
base = 10000.0
inv_freq = (
1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
).to(dtype)
insert("rope.freqs", inv_freq)

for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
torch.save(
Expand Down
Loading