Convert a safetensor checkpoint from Hugging Face hub #1662

Merged: 32 commits, Jun 24, 2024

Changes from 18 commits

Commits (32)
72ce619
chore: adding gemma and llama3
ariG23498 Jun 5, 2024
8f2fe93
chore: adding init
ariG23498 Jun 5, 2024
a3fdc06
chore: removing hard coded values
ariG23498 Jun 6, 2024
3235592
chore: using backbone properties
ariG23498 Jun 6, 2024
0bb47ad
chore: reformat
ariG23498 Jun 7, 2024
4994ca8
chore: review changes
ariG23498 Jun 10, 2024
606fcd7
chore: removing einops with custom np operations
ariG23498 Jun 11, 2024
3eec438
fix: variable name
ariG23498 Jun 11, 2024
225219b
check: none type for reshape and transpose patterns
ariG23498 Jun 11, 2024
7f42b2d
chore: fixing the nesting of reshape and transpose patterns
ariG23498 Jun 11, 2024
59aeb70
fixing nesting of patterns
ariG23498 Jun 11, 2024
47c1ea3
chore: gemma weight rearrange fix
ariG23498 Jun 11, 2024
8f90f9b
chore: adding a hook function to reshape and transpose the hf tensors…
ariG23498 Jun 15, 2024
f016fbb
fix: variable to assign
ariG23498 Jun 15, 2024
09d8689
fix: gemma port
ariG23498 Jun 15, 2024
99588a1
Merge branch 'master' into aritra/hf-port
ariG23498 Jun 15, 2024
767ee2a
Merge branch 'master' into aritra/hf-port
ariG23498 Jun 18, 2024
07183d5
chore: adding tests
ariG23498 Jun 18, 2024
cc969bc
review comments
ariG23498 Jun 19, 2024
56e0dfc
adding safetensors as a dep
ariG23498 Jun 19, 2024
57a3d33
Merge branch 'master' into aritra/hf-port
ariG23498 Jun 20, 2024
22109a2
chore: adding jax memory cleanup
ariG23498 Jun 20, 2024
e021465
utf 8 encoding
ariG23498 Jun 20, 2024
7d0cfad
chore: changing tests
ariG23498 Jun 20, 2024
e99f98e
chore: fixing tests
ariG23498 Jun 21, 2024
f61a9fa
fix tests
ariG23498 Jun 21, 2024
5a29dc0
chore: adding guard rails for None types
ariG23498 Jun 21, 2024
85c2586
Merge branch 'master' into aritra/hf-port
ariG23498 Jun 24, 2024
c34241b
Trigger Build
ariG23498 Jun 24, 2024
851fd69
review suggestions
ariG23498 Jun 24, 2024
cf1ff29
fix raising ValueError
ariG23498 Jun 24, 2024
b06c6e4
fix error message
ariG23498 Jun 24, 2024
9 changes: 7 additions & 2 deletions keras_nlp/src/models/backbone.py
@@ -20,15 +20,16 @@
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
from keras_nlp.src.utils.preset_utils import check_format
from keras_nlp.src.utils.preset_utils import get_file
from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import load_serialized_object
from keras_nlp.src.utils.preset_utils import save_metadata
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.preset_utils import validate_metadata
from keras_nlp.src.utils.python_utils import classproperty
from keras_nlp.src.utils.transformers.convert import load_transformers_backbone


@keras_nlp_export("keras_nlp.models.Backbone")
@@ -198,7 +199,11 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
)
```
"""
validate_metadata(preset)
format = check_format(preset)

if format == "transformers":
return load_transformers_backbone(cls, preset, load_weights)

preset_cls = check_config_class(preset)
if not issubclass(preset_cls, cls):
raise ValueError(
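With this change, `Backbone.from_preset()` dispatches to the transformers loader whenever `check_format` detects a safetensors checkpoint. A minimal usage sketch, assuming the preset argument points at a hypothetical local directory holding a Hugging Face Gemma checkpoint (`config.json` plus `model.safetensors*`); handle formats other than local directories depend on what `get_file`/`check_file_exists` accept in this branch:

import keras_nlp

# Hypothetical local directory containing a Hugging Face Gemma checkpoint
# (config.json, model.safetensors.index.json, safetensors shards,
# tokenizer.model). check_format() sees the safetensors files and routes
# through load_transformers_backbone().
preset_dir = "./gemma-2b-hf"

backbone = keras_nlp.models.GemmaBackbone.from_preset(preset_dir)
backbone.summary()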
9 changes: 7 additions & 2 deletions keras_nlp/src/models/preprocessor.py
@@ -21,11 +21,11 @@
from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
from keras_nlp.src.utils.preset_utils import check_file_exists
from keras_nlp.src.utils.preset_utils import check_format
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import load_serialized_object
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.preset_utils import validate_metadata
from keras_nlp.src.utils.python_utils import classproperty


@@ -127,7 +127,12 @@ def from_preset(
)
```
"""
validate_metadata(preset)
format = check_format(preset)

if format == "transformers":
tokenizer = cls.tokenizer_cls.from_preset(preset)
return cls(tokenizer=tokenizer, **kwargs)

if cls == Preprocessor:
raise ValueError(
"Do not call `Preprocessor.from_preset()` directly. Instead call a "
9 changes: 7 additions & 2 deletions keras_nlp/src/models/task.py
@@ -29,13 +29,13 @@
from keras_nlp.src.utils.preset_utils import TASK_WEIGHTS_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
from keras_nlp.src.utils.preset_utils import check_file_exists
from keras_nlp.src.utils.preset_utils import check_format
from keras_nlp.src.utils.preset_utils import get_file
from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import load_serialized_object
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.preset_utils import validate_metadata
from keras_nlp.src.utils.python_utils import classproperty


@@ -213,7 +213,12 @@ def from_preset(
)
```
"""
validate_metadata(preset)
format = check_format(preset)

if format == "transformers":
backbone = cls.backbone_cls.from_preset(preset)
preprocessor = cls.preprocessor_cls.from_preset(preset)
return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)

if cls == Task:
raise ValueError(
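The task-level path composes the other two branches: `from_preset` on a task builds the preprocessor (from the tokenizer) and the backbone from the transformers files and wires them together. A sketch, reusing the same hypothetical checkpoint directory as in the backbone example above:

import keras_nlp

# Hypothetical directory; same Hugging Face Gemma checkpoint as above.
preset_dir = "./gemma-2b-hf"

# Internally: check_format() -> "transformers", then
# backbone = cls.backbone_cls.from_preset(...),
# preprocessor = cls.preprocessor_cls.from_preset(...),
# and finally cls(backbone=..., preprocessor=...).
causal_lm = keras_nlp.models.GemmaCausalLM.from_preset(preset_dir)
causal_lm.generate("Hello", max_length=16)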
8 changes: 6 additions & 2 deletions keras_nlp/src/tokenizers/tokenizer.py
@@ -20,14 +20,15 @@
from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
from keras_nlp.src.utils.preset_utils import check_format
from keras_nlp.src.utils.preset_utils import get_file
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import load_serialized_object
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.preset_utils import save_tokenizer_assets
from keras_nlp.src.utils.preset_utils import validate_metadata
from keras_nlp.src.utils.python_utils import classproperty
from keras_nlp.src.utils.transformers.convert import load_transformers_tokenizer


@keras_nlp_export(
@@ -215,7 +216,10 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from
tokenizer.detokenize([5, 6, 7, 8, 9])
```
"""
validate_metadata(preset)
format = check_format(preset)
if format == "transformers":
return load_transformers_tokenizer(cls, preset)

preset_cls = check_config_class(
preset, config_file=TOKENIZER_CONFIG_FILE
)
8 changes: 7 additions & 1 deletion keras_nlp/src/utils/preset_utils.py
@@ -546,7 +546,12 @@ def load_config(preset, config_file=CONFIG_FILE):
return config


def validate_metadata(preset):
def check_format(preset):
if check_file_exists(preset, "model.safetensors") or check_file_exists(
preset, "model.safetensors.index.json"
):
return "transformers"

if not check_file_exists(preset, METADATA_FILE):
raise FileNotFoundError(
f"The preset directory `{preset}` doesn't have a file named `{METADATA_FILE}`. "
@@ -559,6 +564,7 @@ def validate_metadata(preset):
f"`{METADATA_FILE}` in the preset directory `{preset}` doesn't have `keras_version`. "
"Please verify that the model you are trying to load is a Keras model."
)
return "keras"


def load_serialized_object(
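In other words, `check_format` keys entirely off which files are present in the preset. A small sketch of the contract, using hypothetical directory names:

from keras_nlp.src.utils.preset_utils import check_format

# Directory with metadata.json (a Keras preset):
check_format("./gemma-2b-keras-preset")   # -> "keras"

# Directory with model.safetensors or model.safetensors.index.json:
check_format("./gemma-2b-hf")             # -> "transformers"

# Neither safetensors nor metadata.json -> FileNotFoundError;
# metadata.json without keras_version -> ValueError,
# as exercised by the updated tests in preset_utils_test.py.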
6 changes: 3 additions & 3 deletions keras_nlp/src/utils/preset_utils_test.py
@@ -26,7 +26,7 @@
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
from keras_nlp.src.utils.preset_utils import METADATA_FILE
from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import validate_metadata
from keras_nlp.src.utils.preset_utils import check_format


class PresetUtilsTest(TestCase):
@@ -102,7 +102,7 @@ def test_missing_metadata(self):
with self.assertRaisesRegex(
FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
):
validate_metadata(preset_dir)
check_format(preset_dir)

def test_incorrect_metadata(self):
temp_dir = self.get_temp_dir()
@@ -114,4 +114,4 @@ def test_incorrect_metadata(self):
json.dump(data, f)

with self.assertRaisesRegex(ValueError, "doesn't have `keras_version`"):
validate_metadata(preset_dir)
check_format(preset_dir)
13 changes: 13 additions & 0 deletions keras_nlp/src/utils/transformers/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
38 changes: 38 additions & 0 deletions keras_nlp/src/utils/transformers/convert.py
@@ -0,0 +1,38 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert huggingface models to KerasNLP."""


from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_backbone
from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_tokenizer
from keras_nlp.src.utils.transformers.convert_llama3 import load_llama3_backbone
from keras_nlp.src.utils.transformers.convert_llama3 import (
load_llama3_tokenizer,
)


def load_transformers_backbone(cls, preset, load_weights):
if cls.__name__ == "GemmaBackbone":
return load_gemma_backbone(cls, preset, load_weights)
if cls.__name__ == "Llama3Backbone":
return load_llama3_backbone(cls, preset, load_weights)
raise ValueError(f"No conversion huggingface/transformers to {cls}")
Review comment (Member):

If a user doesn't know that conversion is required to load a transformers checkpoint in Keras and tries to load a transformers checkpoint that has no conversion, they'll end up here, right? Similar to #1574. In that case, I think it'd be nice to have an error message helping the user know that if conversion is not supported, they can switch to loading a Keras checkpoint if available.

Reply (Collaborator, Author):

I have modified the ValueError message. Let me know if that was what you wanted.



def load_transformers_tokenizer(cls, preset):
if cls.__name__ == "GemmaTokenizer":
return load_gemma_tokenizer(cls, preset)
if cls.__name__ == "Llama3Tokenizer":
return load_llama3_tokenizer(cls, preset)
raise ValueError(f"No conversion huggingface/transformers to {cls}")
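The exact wording of the friendlier error landed in later commits ("fix raising ValueError", "fix error message") that are not part of this 18-commit view. A hedged sketch of the kind of message the reviewer asked for, illustrative only and not the final text:

# Illustrative only; the merged PR uses its own wording.
raise ValueError(
    f"{cls} has no conversion from the huggingface/transformers format. "
    "If this model is also published as a Keras preset, load that preset "
    "instead, or open an issue to request a converter."
)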
182 changes: 182 additions & 0 deletions keras_nlp/src/utils/transformers/convert_gemma.py
@@ -0,0 +1,182 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

import numpy as np

from keras_nlp.src.utils.preset_utils import get_file
from keras_nlp.src.utils.preset_utils import load_config
from keras_nlp.src.utils.transformers.safetensor_utils import set_keras_weight


def load_gemma_backbone(cls, preset, load_weights):
"""
Load and initialize the Gemma backbone model.

Args:
cls (class): Keras model class.
preset (str): Preset configuration name.
load_weights (bool): Whether to load the weights.

Returns:
backbone: Initialized Keras model backbone.
"""
transformers_config = load_config(preset, "config.json")
Review comment (Member):

We have a constant for config.json here. We have a plan to change the name of this file in the future, so using the constant would make future changes easier.

Reply (Collaborator, Author):

Here the config.json comes from the Hugging Face repository. I have added another constant to support this file name and am now using the constant. Does the current implementation look good?


backbone = cls(
vocabulary_size=transformers_config["vocab_size"],
num_layers=transformers_config["num_hidden_layers"],
num_query_heads=transformers_config["num_attention_heads"],
num_key_value_heads=transformers_config["num_key_value_heads"],
hidden_dim=transformers_config["hidden_size"],
intermediate_dim=transformers_config["intermediate_size"] * 2,
head_dim=transformers_config["head_dim"],
)

if load_weights:
safetensor_config = load_config(preset, "model.safetensors.index.json")
safetensor_files = {
fname: get_file(preset, fname)
for fname in set(safetensor_config["weight_map"].values())
}
port_weight = partial(
set_keras_weight,
safetensor_files=safetensor_files,
safetensor_config=safetensor_config,
)

# Embedding layer
port_weight(
keras_variable=backbone.get_layer("token_embedding").variables[0],
hf_weight_key="model.embed_tokens.weight",
)

# Attention blocks
for i in range(backbone.num_layers):
decoder_layer = backbone.get_layer(f"decoder_block_{i}")
# Norm layers
port_weight(
keras_variable=decoder_layer.pre_attention_norm.variables[0],
hf_weight_key=f"model.layers.{i}.input_layernorm.weight",
)
port_weight(
keras_variable=decoder_layer.pre_ffw_norm.variables[0],
hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
)

# Attention layers
port_weight(
keras_variable=decoder_layer.attention.query_dense.variables[0],
hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight",
# rearrange_patterns="(a c) b -> a b c",
# rearrange_dims={"a": backbone.num_query_heads},
hook_fn=lambda hf_tensor, keras_shape: np.transpose(
np.reshape(
hf_tensor,
(keras_shape[0], keras_shape[2], keras_shape[1]),
),
axes=(0, 2, 1),
),
)
port_weight(
keras_variable=decoder_layer.attention.key_dense.variables[0],
hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight",
# rearrange_patterns="(a c) b -> a b c",
# rearrange_dims={"a": backbone.num_key_value_heads},
hook_fn=lambda hf_tensor, keras_shape: np.transpose(
np.reshape(
hf_tensor,
(keras_shape[0], keras_shape[2], keras_shape[1]),
),
axes=(0, 2, 1),
),
)
port_weight(
keras_variable=decoder_layer.attention.value_dense.variables[0],
hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight",
# rearrange_patterns="(a c) b -> a b c",
# rearrange_dims={"a": backbone.num_key_value_heads},
hook_fn=lambda hf_tensor, keras_shape: np.transpose(
np.reshape(
hf_tensor,
(keras_shape[0], keras_shape[2], keras_shape[1]),
),
axes=(0, 2, 1),
),
)
port_weight(
keras_variable=decoder_layer.attention.output_dense.variables[
0
],
hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
# rearrange_patterns="c (a b) -> a b c",
# rearrange_dims={"a": backbone.num_query_heads},
hook_fn=lambda hf_tensor, keras_shape: np.transpose(
np.reshape(
hf_tensor,
(keras_shape[2], keras_shape[0], keras_shape[1]),
),
axes=(1, 2, 0),
),
)

# MLP layers
port_weight(
keras_variable=decoder_layer.gating_ffw.variables[0],
hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight",
# rearrange_patterns="b a -> a b",
hook_fn=lambda hf_tensor, _: np.transpose(
hf_tensor, axes=(1, 0)
),
)
port_weight(
keras_variable=decoder_layer.gating_ffw_2.variables[0],
hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight",
# rearrange_patterns="b a -> a b",
hook_fn=lambda hf_tensor, _: np.transpose(
hf_tensor, axes=(1, 0)
),
)
port_weight(
keras_variable=decoder_layer.ffw_linear.variables[0],
hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight",
# rearrange_patterns="b a -> a b",
hook_fn=lambda hf_tensor, _: np.transpose(
hf_tensor, axes=(1, 0)
),
)

# Final normalization layer
port_weight(
keras_variable=backbone.get_layer("final_normalization").variables[
0
],
hf_weight_key="model.norm.weight",
)

return backbone


def load_gemma_tokenizer(cls, preset):
"""
Load the Gemma tokenizer.

Args:
cls (class): Tokenizer class.
preset (str): Preset configuration name.

Returns:
tokenizer: Initialized tokenizer.
"""
return cls(get_file(preset, "tokenizer.model"))
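The `hook_fn` lambdas reproduce the einops patterns left in the comments (the einops dependency was dropped in commit 606fcd7). A standalone check, with made-up dimensions, that the reshape-then-transpose used for the q/k/v projections matches the "(a c) b -> a b c" rearrangement:

import numpy as np

# Made-up dimensions: a = number of heads, b = hidden_dim, c = head_dim.
a, b, c = 2, 6, 3
hf_tensor = np.arange(a * c * b).reshape(a * c, b)  # HF layout: ((a c), b)
keras_shape = (a, b, c)                             # Keras kernel layout: (a, b, c)

# The hook_fn from the PR: reshape to (a, c, b), then swap the last two axes.
ported = np.transpose(
    np.reshape(hf_tensor, (keras_shape[0], keras_shape[2], keras_shape[1])),
    axes=(0, 2, 1),
)

# The einops pattern "(a c) b -> a b c", written out element by element.
expected = np.empty((a, b, c), dtype=hf_tensor.dtype)
for i in range(a):
    for j in range(b):
        for k in range(c):
            expected[i, j, k] = hf_tensor[i * c + k, j]

assert np.array_equal(ported, expected)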