Convert a safetensor checkpoint from Hugging Face hub #1662

Merged · 32 commits from aritra/hf-port into master · Jun 24, 2024
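A minimal sketch of what this PR enables, assuming the `hf://` handle convention KerasNLP uses for Hugging Face Hub presets (the exact model handle here is illustrative, not taken from the PR):

import keras_nlp

# Load a Transformers-format (safetensors) checkpoint directly from the
# Hugging Face Hub. The converter reads config.json and the safetensors
# weights, then builds the matching KerasNLP backbone.
backbone = keras_nlp.models.GemmaBackbone.from_preset("hf://google/gemma-2b")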
Commits

72ce619 chore: adding gemma and llama3 (ariG23498, Jun 5, 2024)
8f2fe93 chore: adding init (ariG23498, Jun 5, 2024)
a3fdc06 chore: removing hard coded values (ariG23498, Jun 6, 2024)
3235592 chore: using backbone properties (ariG23498, Jun 6, 2024)
0bb47ad chore: reformat (ariG23498, Jun 7, 2024)
4994ca8 chore: review changes (ariG23498, Jun 10, 2024)
606fcd7 chore: removing einops with custom np operations (ariG23498, Jun 11, 2024)
3eec438 fix: variable name (ariG23498, Jun 11, 2024)
225219b check: none type for reshape and transpose patterns (ariG23498, Jun 11, 2024)
7f42b2d chore: fixing the nesting of reshape and transpose patterns (ariG23498, Jun 11, 2024)
59aeb70 fixing nesting of patterns (ariG23498, Jun 11, 2024)
47c1ea3 chore: gemma weight rearrange fix (ariG23498, Jun 11, 2024)
8f90f9b chore: adding a hook function to reshape and transpose the hf tensors… (ariG23498, Jun 15, 2024)
f016fbb fix: variable to assign (ariG23498, Jun 15, 2024)
09d8689 fix: gemma port (ariG23498, Jun 15, 2024)
99588a1 Merge branch 'master' into aritra/hf-port (ariG23498, Jun 15, 2024)
767ee2a Merge branch 'master' into aritra/hf-port (ariG23498, Jun 18, 2024)
07183d5 chore: adding tests (ariG23498, Jun 18, 2024)
cc969bc review comments (ariG23498, Jun 19, 2024)
56e0dfc adding safetensors as a dep (ariG23498, Jun 19, 2024)
57a3d33 Merge branch 'master' into aritra/hf-port (ariG23498, Jun 20, 2024)
22109a2 chore: adding jax memory cleanup (ariG23498, Jun 20, 2024)
e021465 utf 8 encoding (ariG23498, Jun 20, 2024)
7d0cfad chore: changing tests (ariG23498, Jun 20, 2024)
e99f98e chore: fixing tests (ariG23498, Jun 21, 2024)
f61a9fa fix tests (ariG23498, Jun 21, 2024)
5a29dc0 chore: adding guard rails for None types (ariG23498, Jun 21, 2024)
85c2586 Merge branch 'master' into aritra/hf-port (ariG23498, Jun 24, 2024)
c34241b Trigger Build (ariG23498, Jun 24, 2024)
851fd69 review suggestions (ariG23498, Jun 24, 2024)
cf1ff29 fix raising ValueError (ariG23498, Jun 24, 2024)
b06c6e4 fix error message (ariG23498, Jun 24, 2024)
1 change: 1 addition & 0 deletions keras_nlp/src/utils/preset_utils.py
@@ -59,6 +59,7 @@
 
 # Config file names.
 CONFIG_FILE = "config.json"
+HF_CONFIG_FILE = "config.json"
 TOKENIZER_CONFIG_FILE = "tokenizer.json"
 TASK_CONFIG_FILE = "task.json"
 PREPROCESSOR_CONFIG_FILE = "preprocessor.json"
12 changes: 10 additions & 2 deletions keras_nlp/src/utils/transformers/convert.py
@@ -29,7 +29,11 @@ def load_transformers_backbone(cls, preset, load_weights):
         return load_gemma_backbone(cls, preset, load_weights)
     if cls.__name__ == "Llama3Backbone":
         return load_llama3_backbone(cls, preset, load_weights)
-    raise ValueError(f"No conversion huggingface/transformers to {cls}")
+    raise ValueError(
+        f"{cls} has not been ported from the Hugging Face format yet. "
+        "Please check Hugging Face Hub for the Keras model. "
+        "Models in Keras format should end with `-keras` (e.g. google/gemma-2b-keras)."
+    )
@@ -39,4 +43,8 @@ def load_transformers_tokenizer(cls, preset):
         return load_gemma_tokenizer(cls, preset)
     if cls.__name__ == "Llama3Tokenizer":
         return load_llama3_tokenizer(cls, preset)
-    raise ValueError(f"No conversion huggingface/transformers to {cls}")
+    raise ValueError(
+        f"{cls} has not been ported from the Hugging Face format yet. "
+        "Please check Hugging Face Hub for the Keras model. "
+        "Models in Keras format should end with `-keras` (e.g. google/gemma-2b-keras)."
+    )
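The dispatch above keys on `cls.__name__`, so supporting a new architecture means adding one branch plus a converter module. A hypothetical sketch (MistralBackbone and load_mistral_backbone are placeholders, not part of this PR):

def load_transformers_backbone(cls, preset, load_weights):
    # Each supported architecture gets an explicit branch; anything else
    # falls through to the ValueError shown in the diff above.
    if cls.__name__ == "GemmaBackbone":
        return load_gemma_backbone(cls, preset, load_weights)
    if cls.__name__ == "Llama3Backbone":
        return load_llama3_backbone(cls, preset, load_weights)
    if cls.__name__ == "MistralBackbone":  # hypothetical new branch
        return load_mistral_backbone(cls, preset, load_weights)
    raise ValueError(f"{cls} has not been ported from the Hugging Face format yet.")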
6 changes: 4 additions & 2 deletions keras_nlp/src/utils/transformers/convert_gemma.py
@@ -15,6 +15,8 @@
 
 import numpy as np
 
+from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE
+from keras_nlp.src.utils.preset_utils import SAFETENSOR_CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import get_file
 from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
 from keras_nlp.src.utils.preset_utils import load_config
@@ -33,7 +35,7 @@ def load_gemma_backbone(cls, preset, load_weights):
     Returns:
         backbone: Initialized Keras model backbone.
     """
-    transformers_config = load_config(preset, "config.json")
+    transformers_config = load_config(preset, HF_CONFIG_FILE)
 
     backbone = cls(
         vocabulary_size=transformers_config["vocab_size"],
@@ -50,7 +52,7 @@
 
     jax_memory_cleanup(backbone)
     # Code to port the weights from safetensors into the keras nlp model
-    safetensor_config = load_config(preset, "model.safetensors.index.json")
+    safetensor_config = load_config(preset, SAFETENSOR_CONFIG_FILE)
     safetensor_files = {
         fname: get_file(preset, fname)
         for fname in set(safetensor_config["weight_map"].values())
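For context, `model.safetensors.index.json` (now referenced via `SAFETENSOR_CONFIG_FILE`) maps each parameter name to the shard file that stores it, which is what the `weight_map` lookup above relies on. A standalone sketch of that lookup, assuming local file paths and a standard Transformers parameter name:

import json

from safetensors import safe_open

# Read the index that sharded safetensors checkpoints ship with:
# weight_map is a dict of {parameter_name: shard_filename}.
with open("model.safetensors.index.json", encoding="utf-8") as f:
    weight_map = json.load(f)["weight_map"]

def get_tensor(name):
    # Open only the shard that holds this tensor and read it as numpy.
    with safe_open(weight_map[name], framework="np") as shard:
        return shard.get_tensor(name)

# "model.embed_tokens.weight" is the usual Transformers embedding name;
# treat the exact key as an assumption here.
embeddings = get_tensor("model.embed_tokens.weight")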
6 changes: 4 additions & 2 deletions keras_nlp/src/utils/transformers/convert_llama3.py
@@ -15,6 +15,8 @@
 
 import numpy as np
 
+from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE
+from keras_nlp.src.utils.preset_utils import SAFETENSOR_CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import get_file
 from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
 from keras_nlp.src.utils.preset_utils import load_config
@@ -33,7 +35,7 @@ def load_llama3_backbone(cls, preset, load_weights):
     Returns:
         backbone: Initialized Keras model backbone.
     """
-    transformers_config = load_config(preset, "config.json")
+    transformers_config = load_config(preset, HF_CONFIG_FILE)
 
     backbone = cls(
         vocabulary_size=transformers_config["vocab_size"],
@@ -49,7 +51,7 @@
 
     jax_memory_cleanup(backbone)
     # Code to port the weights from safetensors into the keras nlp model
-    safetensor_config = load_config(preset, "model.safetensors.index.json")
+    safetensor_config = load_config(preset, SAFETENSOR_CONFIG_FILE)
    safetensor_files = {
         fname: get_file(preset, fname)
         for fname in set(safetensor_config["weight_map"].values())
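The truncated `backbone = cls(...)` call follows the same pattern as the Gemma converter: Transformers config keys are translated into KerasNLP constructor arguments. A sketch of the likely mapping (only `vocab_size` is visible in the diff; every other key and argument name below is a standard-convention assumption, not quoted from the PR):

backbone = cls(
    vocabulary_size=transformers_config["vocab_size"],
    num_layers=transformers_config["num_hidden_layers"],  # assumed key
    num_query_heads=transformers_config["num_attention_heads"],  # assumed key
    num_key_value_heads=transformers_config["num_key_value_heads"],  # assumed key
    hidden_dim=transformers_config["hidden_size"],  # assumed key
    intermediate_dim=transformers_config["intermediate_size"],  # assumed key
)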