Convert a safetensor checkpoint from Hugging Face hub #1662

Merged: 32 commits merged into master from aritra/hf-port on Jun 24, 2024.
The diff below shows changes from 5 of the 32 commits.

Commits (32)
72ce619
chore: adding gemma and llama3
ariG23498 Jun 5, 2024
8f2fe93
chore: adding init
ariG23498 Jun 5, 2024
a3fdc06
chore: removing hard coded values
ariG23498 Jun 6, 2024
3235592
chore: using backbone properties
ariG23498 Jun 6, 2024
0bb47ad
chore: reformat
ariG23498 Jun 7, 2024
4994ca8
chore: review changes
ariG23498 Jun 10, 2024
606fcd7
chore: removing einops with custom np operations
ariG23498 Jun 11, 2024
3eec438
fix: variable name
ariG23498 Jun 11, 2024
225219b
check: none type for reshape and transpose patterns
ariG23498 Jun 11, 2024
7f42b2d
chore: fixing the nesting of reshape and transpose patterns
ariG23498 Jun 11, 2024
59aeb70
fixing nesting of patterns
ariG23498 Jun 11, 2024
47c1ea3
chore: gemma weight rearrange fix
ariG23498 Jun 11, 2024
8f90f9b
chore: adding a hook function to reshape and transpose the hf tensors…
ariG23498 Jun 15, 2024
f016fbb
fix: variable to assign
ariG23498 Jun 15, 2024
09d8689
fix: gemma port
ariG23498 Jun 15, 2024
99588a1
Merge branch 'master' into aritra/hf-port
ariG23498 Jun 15, 2024
767ee2a
Merge branch 'master' into aritra/hf-port
ariG23498 Jun 18, 2024
07183d5
chore: adding tests
ariG23498 Jun 18, 2024
cc969bc
review comments
ariG23498 Jun 19, 2024
56e0dfc
adding safetensors as a dep
ariG23498 Jun 19, 2024
57a3d33
Merge branch 'master' into aritra/hf-port
ariG23498 Jun 20, 2024
22109a2
chore: adding jax memory cleanup
ariG23498 Jun 20, 2024
e021465
utf 8 encoding
ariG23498 Jun 20, 2024
7d0cfad
chore: changing tests
ariG23498 Jun 20, 2024
e99f98e
chore: fixing tests
ariG23498 Jun 21, 2024
f61a9fa
fix tests
ariG23498 Jun 21, 2024
5a29dc0
chore: adding guard rails for None types
ariG23498 Jun 21, 2024
85c2586
Merge branch 'master' into aritra/hf-port
ariG23498 Jun 24, 2024
c34241b
Trigger Build
ariG23498 Jun 24, 2024
851fd69
review suggestions
ariG23498 Jun 24, 2024
cf1ff29
fix raising ValueError
ariG23498 Jun 24, 2024
b06c6e4
fix error message
ariG23498 Jun 24, 2024
9 changes: 7 additions & 2 deletions keras_nlp/src/models/backbone.py
@@ -20,15 +20,16 @@
 from keras_nlp.src.utils.preset_utils import CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
 from keras_nlp.src.utils.preset_utils import check_config_class
+from keras_nlp.src.utils.preset_utils import check_format
 from keras_nlp.src.utils.preset_utils import get_file
 from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
 from keras_nlp.src.utils.preset_utils import list_presets
 from keras_nlp.src.utils.preset_utils import list_subclasses
 from keras_nlp.src.utils.preset_utils import load_serialized_object
 from keras_nlp.src.utils.preset_utils import save_metadata
 from keras_nlp.src.utils.preset_utils import save_serialized_object
-from keras_nlp.src.utils.preset_utils import validate_metadata
 from keras_nlp.src.utils.python_utils import classproperty
+from keras_nlp.src.utils.transformers_utils import load_transformers_backbone


 @keras_nlp_export("keras_nlp.models.Backbone")
@@ -198,7 +199,11 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
         )
         ```
         """
-        validate_metadata(preset)
+        format = check_format(preset)
+
+        if format == "huggingface":
+            return load_transformers_backbone(cls, preset, load_weights)
+
         preset_cls = check_config_class(preset)
         if not issubclass(preset_cls, cls):
             raise ValueError(
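With this routing in place, `from_preset` can point directly at a checkpoint that ships `model.safetensors` instead of a Keras preset. A minimal usage sketch; the directory name is illustrative and assumes a local Hugging Face Gemma checkpoint (config.json, the safetensors shards, and the index file):

import keras_nlp

# Hypothetical local directory holding a Hugging Face Gemma checkpoint.
# Because it contains model.safetensors(.index.json), check_format
# returns "huggingface" and from_preset hands off to
# load_transformers_backbone instead of the Keras preset loader.
backbone = keras_nlp.models.GemmaBackbone.from_preset(
    "./gemma-2b-hf", load_weights=True
)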
9 changes: 7 additions & 2 deletions keras_nlp/src/models/preprocessor.py
@@ -21,11 +21,11 @@
 from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import check_config_class
 from keras_nlp.src.utils.preset_utils import check_file_exists
+from keras_nlp.src.utils.preset_utils import check_format
 from keras_nlp.src.utils.preset_utils import list_presets
 from keras_nlp.src.utils.preset_utils import list_subclasses
 from keras_nlp.src.utils.preset_utils import load_serialized_object
 from keras_nlp.src.utils.preset_utils import save_serialized_object
-from keras_nlp.src.utils.preset_utils import validate_metadata
 from keras_nlp.src.utils.python_utils import classproperty


@@ -127,7 +127,12 @@ def from_preset(
         )
         ```
         """
-        validate_metadata(preset)
+        format = check_format(preset)
+
+        if format == "huggingface":
+            tokenizer = cls.tokenizer_cls.from_preset(preset)
+            return cls(tokenizer=tokenizer, **kwargs)
+
         if cls == Preprocessor:
             raise ValueError(
                 "Do not call `Preprocessor.from_preset()` directly. Instead call a "
9 changes: 7 additions & 2 deletions keras_nlp/src/models/task.py
@@ -29,13 +29,13 @@
 from keras_nlp.src.utils.preset_utils import TASK_WEIGHTS_FILE
 from keras_nlp.src.utils.preset_utils import check_config_class
 from keras_nlp.src.utils.preset_utils import check_file_exists
+from keras_nlp.src.utils.preset_utils import check_format
 from keras_nlp.src.utils.preset_utils import get_file
 from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
 from keras_nlp.src.utils.preset_utils import list_presets
 from keras_nlp.src.utils.preset_utils import list_subclasses
 from keras_nlp.src.utils.preset_utils import load_serialized_object
 from keras_nlp.src.utils.preset_utils import save_serialized_object
-from keras_nlp.src.utils.preset_utils import validate_metadata
 from keras_nlp.src.utils.python_utils import classproperty


@@ -213,7 +213,12 @@ def from_preset(
         )
         ```
         """
-        validate_metadata(preset)
+        format = check_format(preset)
+
+        if format == "huggingface":
+            backbone = cls.backbone_cls.from_preset(preset)
+            preprocessor = cls.preprocessor_cls.from_preset(preset)
+            return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)
+
         if cls == Task:
             raise ValueError(
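Note the recursion in the Hugging Face branch: a task has no task-level weights to load from a safetensors checkpoint, so it simply composes the pieces built by the other loaders in this PR. A sketch of the delegation chain, with illustrative Gemma class names and the same hypothetical directory as above:

import keras_nlp

# Task.from_preset on a safetensors preset builds its parts instead of
# deserializing a saved task:
#   GemmaCausalLM.from_preset(preset)
#     -> backbone_cls.from_preset(preset)        # load_transformers_backbone
#     -> preprocessor_cls.from_preset(preset)    # builds cls(tokenizer=...)
#          -> tokenizer_cls.from_preset(preset)  # load_transformers_tokenizer
causal_lm = keras_nlp.models.GemmaCausalLM.from_preset("./gemma-2b-hf")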
8 changes: 6 additions & 2 deletions keras_nlp/src/tokenizers/tokenizer.py
@@ -20,14 +20,15 @@
 from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
 from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import check_config_class
+from keras_nlp.src.utils.preset_utils import check_format
 from keras_nlp.src.utils.preset_utils import get_file
 from keras_nlp.src.utils.preset_utils import list_presets
 from keras_nlp.src.utils.preset_utils import list_subclasses
 from keras_nlp.src.utils.preset_utils import load_serialized_object
 from keras_nlp.src.utils.preset_utils import save_serialized_object
 from keras_nlp.src.utils.preset_utils import save_tokenizer_assets
-from keras_nlp.src.utils.preset_utils import validate_metadata
 from keras_nlp.src.utils.python_utils import classproperty
+from keras_nlp.src.utils.transformers_utils import load_transformers_tokenizer


 @keras_nlp_export(
@@ -215,7 +216,10 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from
         tokenizer.detokenize([5, 6, 7, 8, 9])
         ```
         """
-        validate_metadata(preset)
+        format = check_format(preset)
+        if format == "huggingface":
+            return load_transformers_tokenizer(cls, preset)
+
         preset_cls = check_config_class(
             preset, config_file=TOKENIZER_CONFIG_FILE
         )
8 changes: 7 additions & 1 deletion keras_nlp/src/utils/preset_utils.py
@@ -500,7 +500,12 @@ def load_config(preset, config_file=CONFIG_FILE):
     return config


-def validate_metadata(preset):
+def check_format(preset):
+    if check_file_exists(preset, "model.safetensors") or check_file_exists(
+        preset, "model.safetensors.index.json"
+    ):
+        return "huggingface"
+
     if not check_file_exists(preset, METADATA_FILE):
         raise FileNotFoundError(
             f"The preset directory `{preset}` doesn't have a file named `{METADATA_FILE}`. "
@@ -513,6 +518,7 @@ def validate_metadata(preset):
             f"`{METADATA_FILE}` in the preset directory `{preset}` doesn't have `keras_version`. "
             "Please verify that the model you are trying to load is a Keras model."
         )
+    return "keras"


 def load_serialized_object(
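`check_format` is now the single gate that both validates a preset and classifies it: safetensors files mean Hugging Face, a metadata.json carrying `keras_version` means Keras, anything else raises. A short sketch of that contract against illustrative local directories:

from keras_nlp.src.utils.preset_utils import check_format

check_format("./gemma-2b-hf")     # model.safetensors present -> "huggingface"
check_format("./gemma-2b-keras")  # metadata.json with keras_version -> "keras"
check_format("./not-a-preset")    # no metadata.json -> FileNotFoundError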
6 changes: 3 additions & 3 deletions keras_nlp/src/utils/preset_utils_test.py
@@ -26,7 +26,7 @@
 from keras_nlp.src.utils.preset_utils import CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import METADATA_FILE
 from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
-from keras_nlp.src.utils.preset_utils import validate_metadata
+from keras_nlp.src.utils.preset_utils import check_format


 class PresetUtilsTest(TestCase):
@@ -102,7 +102,7 @@ def test_missing_metadata(self):
         with self.assertRaisesRegex(
             FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
         ):
-            validate_metadata(preset_dir)
+            check_format(preset_dir)

     def test_incorrect_metadata(self):
         temp_dir = self.get_temp_dir()
@@ -114,4 +114,4 @@ def test_incorrect_metadata(self):
         json.dump(data, f)

         with self.assertRaisesRegex(ValueError, "doesn't have `keras_version`"):
-            validate_metadata(preset_dir)
+            check_format(preset_dir)
13 changes: 13 additions & 0 deletions keras_nlp/src/utils/transformers_model_utils/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
60 changes: 60 additions & 0 deletions keras_nlp/src/utils/transformers_model_utils/hf_common_port.py
@@ -0,0 +1,60 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import einops
from safetensors import safe_open


def set_keras_weights(
    safetensor_files,
    safetensor_config,
    keras_weight,
    hf_weight_keys,
    rearrange_patterns=None,
    rearrange_dims=None,
):
    """
    Set Keras model weights from SafeTensors file.

    Args:
        safetensor_files (dict): Dictionary of SafeTensor file paths.
        safetensor_config (dict): Configuration for SafeTensors.
        keras_weight (keras.layers.Layer): Keras layer to set the weights for.
        hf_weight_keys (str or list): Key(s) for the Hugging Face weight(s).
        rearrange_patterns (str or list, optional): Pattern(s) for rearranging dimensions using einops.
        rearrange_dims (dict, optional): Dimensions for rearranging using einops.
    """
    if isinstance(hf_weight_keys, str):
        hf_weight_keys = [hf_weight_keys]
    if rearrange_patterns and isinstance(rearrange_patterns, str):
        rearrange_patterns = [rearrange_patterns] * len(hf_weight_keys)
    elif not rearrange_patterns:
        rearrange_patterns = [None] * len(hf_weight_keys)

    tensors = []
    for hf_weight_key, rearrange_pattern in zip(
        hf_weight_keys, rearrange_patterns
    ):
        safetensor_file = safetensor_files[
            safetensor_config["weight_map"][hf_weight_key]
        ]
        with safe_open(safetensor_file, framework="np") as f:
            tensor = f.get_tensor(hf_weight_key)
            if rearrange_pattern:
                tensor = einops.rearrange(
                    tensor,
                    rearrange_pattern,
                    **rearrange_dims if rearrange_dims else {}
                )
            tensors.append(tensor)
    keras_weight.set_weights(tensors)
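To make the rearrange patterns used by the Gemma port concrete, here is what "(a c) b -> a b c" does to a dummy attention projection. The shapes are illustrative (8 heads, head_dim 32, hidden 512), not read from any checkpoint, and the remark about the Keras kernel layout is an inference from the q_proj mapping below:

import einops
import numpy as np

num_heads, head_dim, hidden = 8, 32, 512

# HF stores q_proj as a fused (num_heads * head_dim, hidden) matrix.
hf_q = np.zeros((num_heads * head_dim, hidden), dtype="float32")

# "(a c) b -> a b c" with a=num_heads splits the fused head axis and
# reorders to (heads, hidden, head_dim), matching the per-head kernel
# shape of the attention EinsumDense layers it is assigned to.
keras_q = einops.rearrange(hf_q, "(a c) b -> a b c", a=num_heads)
assert keras_q.shape == (num_heads, hidden, head_dim)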
158 changes: 158 additions & 0 deletions keras_nlp/src/utils/transformers_model_utils/hf_gemma_port.py
@@ -0,0 +1,158 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from keras_nlp.src.utils.preset_utils import get_file
from keras_nlp.src.utils.preset_utils import load_config
from keras_nlp.src.utils.transformers_model_utils.hf_common_port import (
    set_keras_weights,
)


def load_gemma_backbone(cls, preset, load_weights):
    """
    Load and initialize the Gemma backbone model.

    Args:
        cls (class): Keras model class.
        preset (str): Preset configuration name.
        load_weights (bool): Whether to load the weights.

    Returns:
        backbone: Initialized Keras model backbone.
    """
    backbone_config = load_config(preset, "config.json")

    backbone = cls(
        vocabulary_size=backbone_config["vocab_size"],
        num_layers=backbone_config["num_hidden_layers"],
        num_query_heads=backbone_config["num_attention_heads"],
        num_key_value_heads=backbone_config["num_key_value_heads"],
        hidden_dim=backbone_config["hidden_size"],
        intermediate_dim=backbone_config["intermediate_size"] * 2,
        head_dim=backbone_config["head_dim"],
    )

    if load_weights:
        safetensor_config = load_config(preset, "model.safetensors.index.json")
        safetensor_files = {
            fname: get_file(preset, fname)
            for fname in set(safetensor_config["weight_map"].values())
        }
        port_weight = partial(
            set_keras_weights,
            safetensor_files=safetensor_files,
            safetensor_config=safetensor_config,
        )

        # Embedding layer
        port_weight(
            keras_weight=backbone.get_layer("token_embedding"),
            hf_weight_keys="model.embed_tokens.weight",
        )

        # Attention blocks
        for i in range(backbone.num_layers):
            # Norm layers
            port_weight(
                keras_weight=backbone.get_layer(
                    f"decoder_block_{i}"
                ).pre_attention_norm,
                hf_weight_keys=f"model.layers.{i}.input_layernorm.weight",
            )
            port_weight(
                keras_weight=backbone.get_layer(
                    f"decoder_block_{i}"
                ).pre_ffw_norm,
                hf_weight_keys=f"model.layers.{i}.post_attention_layernorm.weight",
            )

            # Attention layers
            port_weight(
                keras_weight=backbone.get_layer(
                    f"decoder_block_{i}"
                ).attention.query_dense,
                hf_weight_keys=f"model.layers.{i}.self_attn.q_proj.weight",
                rearrange_patterns="(a c) b -> a b c",
                rearrange_dims={"a": backbone.num_query_heads},
            )
            port_weight(
                keras_weight=backbone.get_layer(
                    f"decoder_block_{i}"
                ).attention.key_dense,
                hf_weight_keys=f"model.layers.{i}.self_attn.k_proj.weight",
                rearrange_patterns="(a c) b -> a b c",
                rearrange_dims={"a": backbone.num_key_value_heads},
            )
            port_weight(
                keras_weight=backbone.get_layer(
                    f"decoder_block_{i}"
                ).attention.value_dense,
                hf_weight_keys=f"model.layers.{i}.self_attn.v_proj.weight",
                rearrange_patterns="(a c) b -> a b c",
                rearrange_dims={"a": backbone.num_key_value_heads},
            )
            port_weight(
                keras_weight=backbone.get_layer(
                    f"decoder_block_{i}"
                ).attention.output_dense,
                hf_weight_keys=f"model.layers.{i}.self_attn.o_proj.weight",
                rearrange_patterns="c (a b) -> a b c",
                rearrange_dims={"a": backbone.num_query_heads},
            )

            # MLP layers
            port_weight(
                keras_weight=backbone.get_layer(
                    f"decoder_block_{i}"
                ).gating_ffw,
                hf_weight_keys=f"model.layers.{i}.mlp.gate_proj.weight",
                rearrange_patterns="b a -> a b",
            )
            port_weight(
                keras_weight=backbone.get_layer(
                    f"decoder_block_{i}"
                ).gating_ffw_2,
                hf_weight_keys=f"model.layers.{i}.mlp.up_proj.weight",
                rearrange_patterns="b a -> a b",
            )
            port_weight(
                keras_weight=backbone.get_layer(
                    f"decoder_block_{i}"
                ).ffw_linear,
                hf_weight_keys=f"model.layers.{i}.mlp.down_proj.weight",
                rearrange_patterns="b a -> a b",
            )

        # Final normalization layer
        port_weight(
            keras_weight=backbone.get_layer("final_normalization"),
            hf_weight_keys="model.norm.weight",
        )

    return backbone


def load_gemma_tokenizer(cls, preset):
    """
    Load the Gemma tokenizer.

    Args:
        cls (class): Tokenizer class.
        preset (str): Preset configuration name.

    Returns:
        tokenizer: Initialized tokenizer.
    """
    return cls(get_file(preset, "tokenizer.model"))
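An end-to-end sketch of how the two loaders compose, using the same hypothetical checkpoint directory as the earlier examples. One mapping worth flagging: `intermediate_dim` is passed as `intermediate_size * 2`, presumably because the KerasNLP Gemma decoder block splits its feedforward width across the two gating projections, so the Keras hyperparameter is the doubled value.

from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone
from keras_nlp.src.models.gemma.gemma_tokenizer import GemmaTokenizer
from keras_nlp.src.utils.transformers_model_utils.hf_gemma_port import (
    load_gemma_backbone,
    load_gemma_tokenizer,
)

# Hypothetical directory with config.json, model.safetensors.index.json,
# the safetensors shards, and tokenizer.model.
preset = "./gemma-2b-hf"

backbone = load_gemma_backbone(GemmaBackbone, preset, load_weights=True)
tokenizer = load_gemma_tokenizer(GemmaTokenizer, preset)
token_ids = tokenizer.tokenize("What is Keras?")  # 1D tensor of token ids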