Vicuna Models checkpoints transfer script (#1657)
* Add Vicuna tokenizer and preset

* Add vicuna tokenizer and preset

* Sort the imports with isort

* fix lint errors

* Add vicuna preset to llama2

* remove separate vicuna checkpoint script

* indentation fix
sineeli committed Jun 7, 2024
1 parent 30b34d3 commit 50e0414
Showing 1 changed file with 132 additions and 110 deletions.
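For context, the script converts a Hugging Face Llama or Vicuna checkpoint into a KerasNLP preset, validating numerics along the way. A typical invocation after this change might look as follows (illustrative: the preset name comes from PRESET_MAP in the diff, and the upload target, including the "vicuna" model segment, is a placeholder following the formats documented in the --upload_link flag):

    python tools/checkpoint_conversion/convert_llama_checkpoints.py \
        --preset vicuna_1.5_7b_en \
        --validate_dtype bfloat16 \
        --save_dtype bfloat16 \
        --upload_link kaggle://<KAGGLE_USERNAME>/vicuna/keras/vicuna_1.5_7b_en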
tools/checkpoint_conversion/convert_llama_checkpoints.py (242 changes: 132 additions & 110 deletions)
@@ -11,156 +11,156 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import os
+import shutil
+import tempfile
 import traceback
 
-import numpy as np
-import torch
-from absl import app
-from absl import flags
-from keras import ops
-from transformers import AutoTokenizer
-from transformers import LlamaForCausalLM
+os.environ["KERAS_BACKEND"] = "torch"
+
+import numpy as np  # noqa: E402
+import torch  # noqa: E402
+from absl import app  # noqa: E402
+from absl import flags  # noqa: E402
+from keras import ops  # noqa: E402
+from transformers import AutoTokenizer  # noqa: E402
+from transformers import LlamaForCausalLM  # noqa: E402
 
-from keras_nlp import upload_preset
-from keras_nlp.models import LlamaBackbone
-from keras_nlp.models import LlamaCausalLMPreprocessor
-from keras_nlp.models import LlamaTokenizer
+from keras_nlp import upload_preset  # noqa: E402
+from keras_nlp.models import LlamaBackbone  # noqa: E402
+from keras_nlp.models import LlamaCausalLMPreprocessor  # noqa: E402
+from keras_nlp.models import LlamaTokenizer  # noqa: E402
 
 PRESET_MAP = {
     "llama2_7b_en": "meta-llama/Llama-2-7b-hf",
     "llama2_instruct_7b_en": "meta-llama/Llama-2-7b-chat-hf",
+    "vicuna_1.5_7b_en": "lmsys/vicuna-7b-v1.5",
 }
 
+torch_dtype_map = {
+    "float32": torch.float32,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+}
+
 FLAGS = flags.FLAGS
 flags.DEFINE_string(
     "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
 )
 
+flags.DEFINE_string(
+    name="validate_dtype",
+    default="bfloat16",
+    help=(
+        "The dtype of the two models while validating numerics. "
+        "Can be 'float32', 'float16', or 'bfloat16'."
+    ),
+)
+
+flags.DEFINE_string(
+    name="save_dtype",
+    default="bfloat16",
+    help=(
+        "The dtype in which the converted model is saved. "
+        "Can be 'float32', 'float16', or 'bfloat16'."
+    ),
+)
+
+flags.DEFINE_string(
+    name="upload_link",
+    default=None,
+    help=(
+        "The link to upload the model. Can be in these formats: "
+        "`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`, "
+        "`hf://[<HF_USERNAME>/]<MODEL>`"
+    ),
+)
+
 
 def convert_checkpoints(keras_nlp_model, hf_model):
     config = hf_model.config
 
     keras_nlp_model.token_embedding.embeddings.assign(
-        hf_model.model.embed_tokens.weight.detach().cpu().numpy()
+        hf_model.model.embed_tokens.weight
     )
 
     for i in range(keras_nlp_model.num_layers):
         keras_nlp_model.transformer_layers[
             i
         ]._self_attention_layer._key_dense.set_weights(
             [
-                hf_model.model.layers[i]
-                .self_attn.k_proj.weight.T.reshape(
+                hf_model.model.layers[i].self_attn.k_proj.weight.T.reshape(
                     config.hidden_size,
                     config.num_key_value_heads,
                     config.hidden_size // config.num_attention_heads,
                 )
-                .detach()
-                .cpu()
-                .numpy()
             ]
         )
         keras_nlp_model.transformer_layers[
             i
         ]._self_attention_layer._query_dense.set_weights(
             [
-                hf_model.model.layers[i]
-                .self_attn.q_proj.weight.T.reshape(
+                hf_model.model.layers[i].self_attn.q_proj.weight.T.reshape(
                     config.hidden_size,
                     config.num_attention_heads,
                     config.hidden_size // config.num_attention_heads,
                 )
-                .detach()
-                .cpu()
-                .numpy()
             ]
         )
         keras_nlp_model.transformer_layers[
             i
         ]._self_attention_layer._value_dense.set_weights(
             [
-                hf_model.model.layers[i]
-                .self_attn.v_proj.weight.T.reshape(
+                hf_model.model.layers[i].self_attn.v_proj.weight.T.reshape(
                     config.hidden_size,
                     config.num_key_value_heads,
                     config.hidden_size // config.num_attention_heads,
                 )
-                .detach()
-                .cpu()
-                .numpy()
             ]
         )
         keras_nlp_model.transformer_layers[
             i
         ]._self_attention_layer._output_dense.set_weights(
             [
-                hf_model.model.layers[i]
-                .self_attn.o_proj.weight.T.reshape(
+                hf_model.model.layers[i].self_attn.o_proj.weight.T.reshape(
                     config.num_attention_heads,
                     config.hidden_size // config.num_attention_heads,
                     config.hidden_size,
                 )
-                .detach()
-                .cpu()
-                .numpy()
             ]
         )
         keras_nlp_model.transformer_layers[
             i
         ]._self_attention_layernorm.set_weights(
-            [
-                hf_model.model.layers[i]
-                .input_layernorm.weight.detach()
-                .cpu()
-                .numpy()
-            ]
+            [hf_model.model.layers[i].input_layernorm.weight]
         )
         keras_nlp_model.transformer_layers[
            i
         ]._feedforward_intermediate_dense.set_weights(
-            [
-                hf_model.model.layers[i]
-                .mlp.up_proj.weight.T.detach()
-                .cpu()
-                .numpy()
-            ]
+            [hf_model.model.layers[i].mlp.up_proj.weight.T]
         )
         keras_nlp_model.transformer_layers[
             i
         ]._feedforward_output_dense.set_weights(
-            [
-                hf_model.model.layers[i]
-                .mlp.down_proj.weight.T.detach()
-                .cpu()
-                .numpy()
-            ]
+            [hf_model.model.layers[i].mlp.down_proj.weight.T]
         )
         keras_nlp_model.transformer_layers[
             i
         ]._feedforward_gate_dense.set_weights(
-            [
-                hf_model.model.layers[i]
-                .mlp.gate_proj.weight.T.detach()
-                .cpu()
-                .numpy()
-            ]
+            [hf_model.model.layers[i].mlp.gate_proj.weight.T]
        )
         keras_nlp_model.transformer_layers[
             i
         ]._feedforward_layernorm.set_weights(
-            [
-                hf_model.model.layers[i]
-                .post_attention_layernorm.weight.detach()
-                .cpu()
-                .numpy()
-            ]
+            [hf_model.model.layers[i].post_attention_layernorm.weight.detach()]
         )
 
-    keras_nlp_model.layer_norm.set_weights(
-        [hf_model.model.norm.weight.detach().cpu().numpy()]
-    )
+    keras_nlp_model.layer_norm.set_weights([hf_model.model.norm.weight])
     keras_nlp_model.token_embedding.reverse_embeddings.assign(
-        hf_model.lm_head.weight.T.detach().cpu().numpy()
+        hf_model.lm_head.weight.T
    )
 
 
@@ -176,7 +176,7 @@ def test_model(
     hf_outputs = hf_model(
         **hf_model_tokenizer(["What is Keras?"], return_tensors="pt")
     )
-    hf_output_logits = hf_outputs.logits.detach().cpu().numpy()
+    hf_output_logits = ops.convert_to_numpy(hf_outputs.logits)
 
     keras_nlp_preprocessor = LlamaCausalLMPreprocessor(keras_nlp_tokenizer)
     keras_nlp_output = keras_nlp_model(
@@ -187,7 +187,7 @@
     )
     keras_nlp_logits = ops.convert_to_numpy(keras_nlp_logits)
 
-    # High tolerance since bfloat16 is used as the default dtype for Llama
+    # High tolerance when bfloat16 is used as the default dtype for Llama
     try:
         np.testing.assert_allclose(
             keras_nlp_logits, hf_output_logits, atol=1e-4
@@ -201,7 +201,7 @@
 
 def test_tokenizer(keras_nlp_tokenizer, hf_tokenizer):
     hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt")
-    hf_output = hf_output["input_ids"].detach().cpu().numpy()
+    hf_output = ops.convert_to_numpy(hf_output["input_ids"])
     keras_nlp_preprocessor = LlamaCausalLMPreprocessor(keras_nlp_tokenizer)
     keras_nlp_output = keras_nlp_preprocessor(
         ["What is Keras?"], sequence_length=6
@@ -219,55 +219,77 @@ def main(_):
             f"of {','.join(PRESET_MAP.keys())}"
         )
     preset = FLAGS.preset
+    upload_link = FLAGS.upload_link
     hf_preset = PRESET_MAP[preset]
+    torch_dtype = torch_dtype_map.get(FLAGS.validate_dtype)
 
-    # === Load the Huggingface model ===
-    hf_model = LlamaForCausalLM.from_pretrained(
-        hf_preset, torch_dtype=torch.bfloat16
-    )
-    hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
-    hf_model.eval()
-    print("\n-> Huggingface model and tokenizer loaded")
-
-    # === Load the KerasNLP model ===
-    backbone_kwargs = dict(
-        vocabulary_size=hf_model.config.vocab_size,
-        hidden_dim=hf_model.config.hidden_size,
-        num_layers=hf_model.config.num_hidden_layers,
-        num_query_heads=hf_model.config.num_attention_heads,
-        num_key_value_heads=hf_model.config.num_key_value_heads,
-        intermediate_dim=hf_model.config.intermediate_size,
-        layer_norm_epsilon=hf_model.config.rms_norm_eps,
-        rope_max_wavelength=hf_model.config.rope_theta,
-        dtype="bfloat16",
-    )
-    keras_nlp_model = LlamaBackbone(**backbone_kwargs)
-
-    # === Get the tokenizer from the Huggingface model ===
-    tokenizer_path = hf_tokenizer.vocab_file
-    keras_nlp_tokenizer = LlamaTokenizer(tokenizer_path)
-    print("\n-> Keras 3 model and tokenizer loaded.")
-
-    # === Port the weights ===
-    convert_checkpoints(keras_nlp_model, hf_model)
-    print("\n-> Weight transfer done.")
-
-    # === Check that the models and tokenizers outputs match ===
-    test_tokenizer(keras_nlp_tokenizer, hf_tokenizer)
-    test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer)
-    print("\n-> Tests passed!")
+    # === Create the temporary save directories ===
+    temp_dir = tempfile.mkdtemp()
 
-    keras_nlp_model.save_to_preset(preset)
-    print("\n-> Saved the model preset in float16")
+    try:
+        # === Load the Huggingface model ===
+        hf_model = LlamaForCausalLM.from_pretrained(
+            hf_preset, torch_dtype=torch_dtype
+        )
+        hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
+        hf_model.eval()
+        print(
+            f"\n-> Huggingface model and tokenizer loaded with dtype: {FLAGS.validate_dtype}"
+        )
+
+        # === Load the KerasNLP model ===
+        backbone_kwargs = dict(
+            vocabulary_size=hf_model.config.vocab_size,
+            hidden_dim=hf_model.config.hidden_size,
+            num_layers=hf_model.config.num_hidden_layers,
+            num_query_heads=hf_model.config.num_attention_heads,
+            num_key_value_heads=hf_model.config.num_key_value_heads,
+            intermediate_dim=hf_model.config.intermediate_size,
+            layer_norm_epsilon=hf_model.config.rms_norm_eps,
+            rope_max_wavelength=hf_model.config.rope_theta,
+            dtype=FLAGS.validate_dtype,
+        )
+        keras_nlp_model = LlamaBackbone(**backbone_kwargs)
+
+        # === Get the tokenizer from the Huggingface model ===
+        tokenizer_path = hf_tokenizer.vocab_file
+        keras_nlp_tokenizer = LlamaTokenizer(tokenizer_path)
+        print("\n-> Keras 3 model and tokenizer loaded.")
+
+        # === Port the weights ===
+        convert_checkpoints(keras_nlp_model, hf_model)
+        print("\n-> Weight transfer done.")
+
+        # === Check that the models and tokenizers outputs match ===
+        test_tokenizer(keras_nlp_tokenizer, hf_tokenizer)
+        test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer)
+        print("\n-> Tests passed!")
+
+        keras_nlp_model.save_weights(os.path.join(temp_dir, "model.weights.h5"))
+        print(f"\n-> Saved the model weights in {FLAGS.validate_dtype}")
+
+        del keras_nlp_model, hf_model
+        gc.collect()
+
+        # === Save the weights again in user defined dtype ===
+        backbone_kwargs["dtype"] = FLAGS.save_dtype
+        keras_nlp_model = LlamaBackbone(**backbone_kwargs)
+        keras_nlp_model.load_weights(os.path.join(temp_dir, "model.weights.h5"))
 
-    # === Save the tokenizer ===
-    keras_nlp_tokenizer.save_to_preset(preset)
-    print("\n-> Saved the tokenizer")
+        # === Save the model ===
+        keras_nlp_model.save_to_preset(preset)
+        print(f"\n-> Saved the model preset in {FLAGS.save_dtype}")
 
-    # === Upload the preset ===
-    uri = f"kaggle://keras/llama2/keras/{preset}"
-    upload_preset(uri, preset)
-    print("-> Uploaded the preset!")
+        # === Save the tokenizer ===
+        keras_nlp_tokenizer.save_to_preset(preset)
+        print("\n-> Saved the tokenizer")
+
+        # == Upload preset ==
+        if upload_link is not None:
+            upload_preset(upload_link, preset)
+            print("-> Uploaded the preset!")
+    finally:
+        shutil.rmtree(temp_dir)
 
 
 if __name__ == "__main__":
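A note on the import shuffle at the top of the diff: os.environ["KERAS_BACKEND"] = "torch" must execute before keras or keras_nlp is imported, because Keras 3 selects its backend at import time; the # noqa: E402 markers silence the module-level-import-not-at-top lint warning that this ordering triggers. Running on the torch backend is also what lets the conversion hand torch tensors directly to assign()/set_weights(), dropping the .detach().cpu().numpy() round-trips. A minimal standalone sketch of the constraint (not part of the script):

    import os

    # The backend is fixed when keras is first imported, so this environment
    # variable has to be set before that import; setting it later is a no-op.
    os.environ["KERAS_BACKEND"] = "torch"

    import keras  # noqa: E402 (deliberately imported after the env var)

    assert keras.config.backend() == "torch"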

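The .T.reshape(...) pattern in convert_checkpoints bridges two weight layouts: torch's nn.Linear stores its kernel as (out_features, in_features), while the backbone's per-head attention projections take (hidden_size, num_heads, head_dim), as the reshape targets in the diff show. A shape-only sketch with toy dimensions (not Llama's real configuration):

    import torch

    # Toy sizes standing in for config.hidden_size etc.
    hidden_size = 8
    num_attention_heads = 2
    head_dim = hidden_size // num_attention_heads

    # nn.Linear weight: (out_features, in_features).
    q_proj_weight = torch.randn(num_attention_heads * head_dim, hidden_size)

    # Transpose to (in_features, out_features), then split the output
    # features into (num_heads, head_dim), mirroring the set_weights calls.
    kernel = q_proj_weight.T.reshape(hidden_size, num_attention_heads, head_dim)
    assert kernel.shape == (hidden_size, num_attention_heads, head_dim)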
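The restructured main() also introduces a two-stage save: weights are checkpointed to a temporary directory in --validate_dtype, the large torch and Keras models are freed with del and gc.collect(), and a fresh backbone built with --save_dtype reloads the checkpoint before save_to_preset. The same pattern in miniature, with a toy model standing in for LlamaBackbone (all names here are illustrative, and the bfloat16 round-trip assumes the same Keras 3 weights-file support the script itself relies on):

    import os
    import tempfile

    os.environ["KERAS_BACKEND"] = "torch"

    import keras  # noqa: E402


    def build_backbone(dtype):
        # Toy stand-in for LlamaBackbone(**backbone_kwargs).
        model = keras.Sequential([keras.layers.Dense(2, dtype=dtype)])
        model.build((None, 4))
        return model


    temp_dir = tempfile.mkdtemp()
    weights_path = os.path.join(temp_dir, "model.weights.h5")

    # Stage 1: build in the validation dtype and checkpoint to disk.
    model = build_backbone("bfloat16")
    model.save_weights(weights_path)

    # Stage 2: rebuild in the save dtype and reload the checkpoint.
    model = build_backbone("float32")
    model.load_weights(weights_path)

This keeps peak memory bounded, since only one copy of the backbone is alive at a time, which matters at 7B parameters.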