Convert a safetensor checkpoint from Hugging Face hub #1662

Merged · 32 commits · Jun 24, 2024

Changes from all commits

Commits (32):
72ce619
chore: adding gemma and llama3
ariG23498 Jun 5, 2024
8f2fe93
chore: adding init
ariG23498 Jun 5, 2024
a3fdc06
chore: removing hard coded values
ariG23498 Jun 6, 2024
3235592
chore: using backbone properties
ariG23498 Jun 6, 2024
0bb47ad
chore: reformat
ariG23498 Jun 7, 2024
4994ca8
chore: review changes
ariG23498 Jun 10, 2024
606fcd7
chore: removing einops with custom np operations
ariG23498 Jun 11, 2024
3eec438
fix: variable name
ariG23498 Jun 11, 2024
225219b
check: none type for reshape and transpose patterns
ariG23498 Jun 11, 2024
7f42b2d
chore: fixing the nesting of reshape and transpose patterns
ariG23498 Jun 11, 2024
59aeb70
fixing nesting of patterns
ariG23498 Jun 11, 2024
47c1ea3
chore: gemma weight rearrange fix
ariG23498 Jun 11, 2024
8f90f9b
chore: adding a hook function to reshape and transpose the hf tensors…
ariG23498 Jun 15, 2024
f016fbb
fix: variable to assign
ariG23498 Jun 15, 2024
09d8689
fix: gemma port
ariG23498 Jun 15, 2024
99588a1
Merge branch 'master' into aritra/hf-port
ariG23498 Jun 15, 2024
767ee2a
Merge branch 'master' into aritra/hf-port
ariG23498 Jun 18, 2024
07183d5
chore: adding tests
ariG23498 Jun 18, 2024
cc969bc
review comments
ariG23498 Jun 19, 2024
56e0dfc
adding safetensors as a dep
ariG23498 Jun 19, 2024
57a3d33
Merge branch 'master' into aritra/hf-port
ariG23498 Jun 20, 2024
22109a2
chore: adding jax memory cleanup
ariG23498 Jun 20, 2024
e021465
utf 8 encoding
ariG23498 Jun 20, 2024
7d0cfad
chore: changing tests
ariG23498 Jun 20, 2024
e99f98e
chore: fixing tests
ariG23498 Jun 21, 2024
f61a9fa
fix tests
ariG23498 Jun 21, 2024
5a29dc0
chore: adding guard rails for None types
ariG23498 Jun 21, 2024
85c2586
Merge branch 'master' into aritra/hf-port
ariG23498 Jun 24, 2024
c34241b
Trigger Build
ariG23498 Jun 24, 2024
851fd69
review suggestions
ariG23498 Jun 24, 2024
cf1ff29
fix raising ValueError
ariG23498 Jun 24, 2024
b06c6e4
fix error message
ariG23498 Jun 24, 2024
9 changes: 7 additions & 2 deletions keras_nlp/src/models/backbone.py
@@ -20,15 +20,16 @@
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
from keras_nlp.src.utils.preset_utils import check_format
from keras_nlp.src.utils.preset_utils import get_file
from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import load_serialized_object
from keras_nlp.src.utils.preset_utils import save_metadata
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.preset_utils import validate_metadata
from keras_nlp.src.utils.python_utils import classproperty
from keras_nlp.src.utils.transformers.convert import load_transformers_backbone


@keras_nlp_export("keras_nlp.models.Backbone")
@@ -173,7 +174,11 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
)
```
"""
validate_metadata(preset)
format = check_format(preset)

if format == "transformers":
return load_transformers_backbone(cls, preset, load_weights)

preset_cls = check_config_class(preset)
if not issubclass(preset_cls, cls):
raise ValueError(
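With this change, `Backbone.from_preset` dispatches to the transformers loader whenever `check_format` detects a safetensors checkpoint. A minimal usage sketch for one of the ported architectures; the `google/gemma-2b` repo name is an assumption, and Hub access plus any gated-model authentication is required:

```
import keras_nlp

# "hf://" presets resolve on the Hugging Face Hub; check_format() sees
# model.safetensors and routes the call through load_transformers_backbone().
backbone = keras_nlp.models.GemmaBackbone.from_preset("hf://google/gemma-2b")
backbone.summary()
```
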
4 changes: 1 addition & 3 deletions keras_nlp/src/models/backbone_test.py
@@ -50,9 +50,7 @@ def test_from_preset(self):
def test_from_preset_errors(self):
with self.assertRaises(ValueError):
GPT2Backbone.from_preset("bert_tiny_en_uncased", load_weights=False)
with self.assertRaisesRegex(
FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
):
with self.assertRaises(ValueError):
# No loading on a non-keras model.
Backbone.from_preset("hf://google-bert/bert-base-uncased")

11 changes: 9 additions & 2 deletions keras_nlp/src/models/preprocessor.py
@@ -22,11 +22,11 @@
from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
from keras_nlp.src.utils.preset_utils import check_file_exists
from keras_nlp.src.utils.preset_utils import check_format
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import load_serialized_object
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.preset_utils import validate_metadata
from keras_nlp.src.utils.python_utils import classproperty


@@ -128,7 +128,14 @@ def from_preset(
)
```
"""
validate_metadata(preset)
format = check_format(preset)

if format == "transformers":
if cls.tokenizer_cls is None:
raise ValueError("Tokenizer class is None")
tokenizer = cls.tokenizer_cls.from_preset(preset)
return cls(tokenizer=tokenizer, **kwargs)

if cls == Preprocessor:
raise ValueError(
"Do not call `Preprocessor.from_preset()` directly. Instead call a "
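On the transformers path the preprocessor is built from just the converted tokenizer, with any remaining constructor arguments forwarded via `**kwargs`. A short sketch, not part of this diff, assuming the Gemma port and the `google/gemma-2b` checkpoint:

```
import keras_nlp

# The tokenizer is converted from the Hugging Face checkpoint; kwargs such as
# sequence_length are passed straight to the preprocessor constructor.
preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor.from_preset(
    "hf://google/gemma-2b", sequence_length=128
)
```
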
5 changes: 1 addition & 4 deletions keras_nlp/src/models/preprocessor_test.py
@@ -27,7 +27,6 @@
RobertaPreprocessor,
)
from keras_nlp.src.tests.test_case import TestCase
from keras_nlp.src.utils.preset_utils import METADATA_FILE
from keras_nlp.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.src.utils.preset_utils import check_config_class
@@ -67,9 +66,7 @@ def test_from_preset_errors(self):
with self.assertRaises(ValueError):
# No loading on an incorrect class.
BertPreprocessor.from_preset("gpt2_base_en")
with self.assertRaisesRegex(
FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
):
with self.assertRaises(ValueError):
# No loading on a non-keras model.
Preprocessor.from_preset("hf://google-bert/bert-base-uncased")

14 changes: 12 additions & 2 deletions keras_nlp/src/models/task.py
@@ -28,13 +28,13 @@
from keras_nlp.src.utils.preset_utils import TASK_WEIGHTS_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
from keras_nlp.src.utils.preset_utils import check_file_exists
from keras_nlp.src.utils.preset_utils import check_format
from keras_nlp.src.utils.preset_utils import get_file
from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import load_serialized_object
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.preset_utils import validate_metadata
from keras_nlp.src.utils.python_utils import classproperty


@@ -187,7 +187,17 @@ def from_preset(
)
```
"""
validate_metadata(preset)
format = check_format(preset)

if format == "transformers":
if cls.backbone_cls is None:
raise ValueError("Backbone class is None")
if cls.preprocessor_cls is None:
raise ValueError("Preprocessor class is None")

backbone = cls.backbone_cls.from_preset(preset)
preprocessor = cls.preprocessor_cls.from_preset(preset)
return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)

if cls == Task:
raise ValueError(
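The task path composes a backbone and a preprocessor converted from the same Hugging Face preset. An end-to-end sketch under the same assumptions (Gemma port, Hub access, assumed `google/gemma-2b` repo name):

```
import keras_nlp

# Builds the Gemma backbone and preprocessor from the safetensors checkpoint,
# then wraps them in the task model.
causal_lm = keras_nlp.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")
causal_lm.generate("What is KerasNLP?", max_length=64)
```
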
4 changes: 1 addition & 3 deletions keras_nlp/src/models/task_test.py
@@ -79,9 +79,7 @@ def test_from_preset_errors(self):
with self.assertRaises(ValueError):
# No loading on an incorrect class.
BertClassifier.from_preset("gpt2_base_en", load_weights=False)
with self.assertRaisesRegex(
FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
):
with self.assertRaises(ValueError):
# No loading on a non-keras model.
CausalLM.from_preset("hf://google-bert/bert-base-uncased")

8 changes: 6 additions & 2 deletions keras_nlp/src/tokenizers/tokenizer.py
@@ -20,14 +20,15 @@
from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
from keras_nlp.src.utils.preset_utils import check_format
from keras_nlp.src.utils.preset_utils import get_file
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import load_serialized_object
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.preset_utils import save_tokenizer_assets
from keras_nlp.src.utils.preset_utils import validate_metadata
from keras_nlp.src.utils.python_utils import classproperty
from keras_nlp.src.utils.transformers.convert import load_transformers_tokenizer


@keras_nlp_export(
@@ -215,7 +216,10 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from
tokenizer.detokenize([5, 6, 7, 8, 9])
```
"""
validate_metadata(preset)
format = check_format(preset)
if format == "transformers":
return load_transformers_tokenizer(cls, preset)

preset_cls = check_config_class(
preset, config_file=TOKENIZER_CONFIG_FILE
)
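`Tokenizer.from_preset` gains the same branch: when the preset is a transformers checkpoint, the tokenizer is converted directly via `load_transformers_tokenizer`. A sketch using the other ported architecture, Llama 3; the `meta-llama/Meta-Llama-3-8B` repo name is an assumption and the repository is gated, so authentication is required:

```
import keras_nlp

tokenizer = keras_nlp.models.Llama3Tokenizer.from_preset(
    "hf://meta-llama/Meta-Llama-3-8B"
)
token_ids = tokenizer("The quick brown fox.")
text = tokenizer.detokenize(token_ids)
```
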
5 changes: 1 addition & 4 deletions keras_nlp/src/tokenizers/tokenizer_test.py
@@ -31,7 +31,6 @@
from keras_nlp.src.models.roberta.roberta_tokenizer import RobertaTokenizer
from keras_nlp.src.tests.test_case import TestCase
from keras_nlp.src.tokenizers.tokenizer import Tokenizer
from keras_nlp.src.utils.preset_utils import METADATA_FILE
from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
@@ -70,9 +69,7 @@ def test_from_preset(self):
def test_from_preset_errors(self):
with self.assertRaises(ValueError):
GPT2Tokenizer.from_preset("bert_tiny_en_uncased")
with self.assertRaisesRegex(
FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
):
with self.assertRaises(ValueError):
# No loading on a non-keras model.
Tokenizer.from_preset("hf://google-bert/bert-base-uncased")

19 changes: 14 additions & 5 deletions keras_nlp/src/utils/preset_utils.py
@@ -59,16 +59,19 @@

# Config file names.
CONFIG_FILE = "config.json"
HF_CONFIG_FILE = "config.json"
TOKENIZER_CONFIG_FILE = "tokenizer.json"
TASK_CONFIG_FILE = "task.json"
PREPROCESSOR_CONFIG_FILE = "preprocessor.json"
METADATA_FILE = "metadata.json"
SAFETENSOR_CONFIG_FILE = "model.safetensors.index.json"

README_FILE = "README.md"

# Weight file names.
MODEL_WEIGHTS_FILE = "model.weights.h5"
TASK_WEIGHTS_FILE = "task.weights.h5"
SAFETENSOR_FILE = "model.safetensors"

# Global state for preset registry.
BUILTIN_PRESETS = {}
@@ -324,7 +327,7 @@ def _validate_tokenizer(preset, allow_incomplete=False):
)
config_path = get_file(preset, TOKENIZER_CONFIG_FILE)
try:
with open(config_path) as config_file:
with open(config_path, encoding="utf-8") as config_file:
config = json.load(config_file)
except Exception as e:
raise ValueError(
@@ -357,7 +360,7 @@ def _validate_backbone(preset):
f"`{CONFIG_FILE}` is missing from the preset directory `{preset}`."
)
try:
with open(config_path) as config_file:
with open(config_path, encoding="utf-8") as config_file:
json.load(config_file)
except Exception as e:
raise ValueError(
@@ -530,12 +533,17 @@ def upload_preset(

def load_config(preset, config_file=CONFIG_FILE):
config_path = get_file(preset, config_file)
with open(config_path) as config_file:
with open(config_path, encoding="utf-8") as config_file:
config = json.load(config_file)
return config


def validate_metadata(preset):
def check_format(preset):
if check_file_exists(preset, SAFETENSOR_FILE) or check_file_exists(
preset, SAFETENSOR_CONFIG_FILE
):
return "transformers"
Comment on lines +542 to +545

ariG23498 (Collaborator, Author):

@mattdangerw this does not account for the case where a Hugging Face repository has both safetensors and an h5 file, as you can see in google-bert/bert-base-uncased.

The current tests fail because there is no port written for BERT yet. Do you think a check for METADATA_FILE first, and then a check for safetensors, would make more sense here?

mattdangerw (Member):

To me it seems like everything works as intended here; we just need to tweak the tests.

We have four tests that were added to show a friendly error if someone tries to load a transformers checkpoint we do not support:

https://github.com/search?q=repo%3Akeras-team%2Fkeras-nlp+%22hf%3A%2F%2Fgoogle-bert%2Fbert-base-uncased%22&type=code

I think we just need to change the error expectation: instead of a FileNotFoundError we will get a ValueError. Eventually, we might want to delete those tests and consolidate all of this "transformers testing" into the new directory you are adding, but that can be a follow-up I think.

ariG23498 (Collaborator, Author):

So no changes here, right?

mattdangerw (Member), Jun 20, 2024:

I think we still need to update the error here (and the three equivalent checks in the search above) to a ValueError; maybe just ditch the message regex.

with self.assertRaisesRegex(
    FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
):

if not check_file_exists(preset, METADATA_FILE):
raise FileNotFoundError(
f"The preset directory `{preset}` doesn't have a file named `{METADATA_FILE}`, "
@@ -548,6 +556,7 @@ def validate_metadata(preset):
f"`{METADATA_FILE}` in the preset directory `{preset}` doesn't have `keras_version`. "
"Please verify that the model you are trying to load is a Keras model."
)
return "keras"


def load_serialized_object(
@@ -566,7 +575,7 @@ def check_config_class(
):
"""Validate a preset is being loaded on the correct class."""
config_path = get_file(preset, config_file)
with open(config_path) as config_file:
with open(config_path, encoding="utf-8") as config_file:
config = json.load(config_file)
return keras.saving.get_registered_object(config["registered_name"])

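`check_format` replaces `validate_metadata`: it first looks for safetensors files and reports "transformers", and otherwise falls back to the original metadata validation and reports "keras". A small illustration of the expected return values; it assumes both presets resolve and can be reached (including any Hub authentication for gated repos):

```
from keras_nlp.src.utils.preset_utils import check_format

# A Hugging Face repo that ships model.safetensors is detected as transformers;
# a Keras preset with a valid metadata.json is detected as keras.
print(check_format("hf://google/gemma-2b"))  # "transformers"
print(check_format("gemma_2b_en"))           # "keras"
```
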
6 changes: 3 additions & 3 deletions keras_nlp/src/utils/preset_utils_test.py
@@ -26,7 +26,7 @@
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
from keras_nlp.src.utils.preset_utils import METADATA_FILE
from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import validate_metadata
from keras_nlp.src.utils.preset_utils import check_format


class PresetUtilsTest(TestCase):
@@ -100,7 +100,7 @@ def test_missing_metadata(self):
with self.assertRaisesRegex(
FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
):
validate_metadata(preset_dir)
check_format(preset_dir)

def test_incorrect_metadata(self):
temp_dir = self.get_temp_dir()
@@ -112,4 +112,4 @@ def test_incorrect_metadata(self):
json.dump(data, f)

with self.assertRaisesRegex(ValueError, "doesn't have `keras_version`"):
validate_metadata(preset_dir)
check_format(preset_dir)
13 changes: 13 additions & 0 deletions keras_nlp/src/utils/transformers/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
48 changes: 48 additions & 0 deletions keras_nlp/src/utils/transformers/convert.py
@@ -0,0 +1,48 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert huggingface models to KerasNLP."""


from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_backbone
from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_tokenizer
from keras_nlp.src.utils.transformers.convert_llama3 import load_llama3_backbone
from keras_nlp.src.utils.transformers.convert_llama3 import (
load_llama3_tokenizer,
)


def load_transformers_backbone(cls, preset, load_weights):
if cls is None:
raise ValueError("Backbone class is None")
if cls.__name__ == "GemmaBackbone":
return load_gemma_backbone(cls, preset, load_weights)
if cls.__name__ == "Llama3Backbone":
return load_llama3_backbone(cls, preset, load_weights)
raise ValueError(
f"{cls} has not been ported from the Hugging Face format yet. "
"Please check Hugging Face Hub for the Keras model. "
)


def load_transformers_tokenizer(cls, preset):
if cls is None:
raise ValueError("Tokenizer class is None")
if cls.__name__ == "GemmaTokenizer":
return load_gemma_tokenizer(cls, preset)
if cls.__name__ == "Llama3Tokenizer":
return load_llama3_tokenizer(cls, preset)
raise ValueError(
f"{cls} has not been ported from the Hugging Face format yet. "
"Please check Hugging Face Hub for the Keras model. "
)
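
The dispatch is intentionally narrow: only Gemma and Llama 3 have converters in this PR, and any other architecture raises a `ValueError` pointing users back to the Keras models on the Hub. A short sketch mirroring the updated tests:

```
import keras_nlp

# BERT is not ported yet, so the safetensors checkpoint is rejected with a
# ValueError instead of the old FileNotFoundError.
try:
    keras_nlp.models.Backbone.from_preset("hf://google-bert/bert-base-uncased")
except ValueError as e:
    print(e)
```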