Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed from_preset file downloading to use GFile when able #1665

Merged
merged 7 commits into from
Jun 15, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 62 additions & 17 deletions keras_nlp/src/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@
import json
import os
import re
import shutil
import urllib

import tensorflow as tf
VarunS1997 marked this conversation as resolved.
Show resolved Hide resolved

from absl import logging
from keras.src.utils import io_utils
VarunS1997 marked this conversation as resolved.
Show resolved Hide resolved
from keras.src.utils.file_utils import path_to_string
from packaging.version import parse

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.backend import config as backend_config
from keras_nlp.src.backend import config as backend_config, config
from keras_nlp.src.backend import keras

try:
Expand Down Expand Up @@ -134,14 +140,16 @@ def get_file(preset, path):
else:
raise ValueError(message)

elif preset.startswith(GS_PREFIX):
elif any(preset.lower().startswith(scheme + "://") for scheme in tf.io.gfile.get_registered_schemes()):
VarunS1997 marked this conversation as resolved.
Show resolved Hide resolved
url = os.path.join(preset, path)
url = url.replace(GS_PREFIX, "https://storage.googleapis.com/")
subdir = preset.replace(GS_PREFIX, "gs_")
subdir = preset
for scheme in tf.io.gfile.get_registered_schemes():
VarunS1997 marked this conversation as resolved.
Show resolved Hide resolved
if (subdir.lower().startswith(scheme + "://")):
subdir = subdir.replace(scheme + "://", scheme + "_")
subdir = subdir.replace("/", "_").replace("-", "_")
filename = os.path.basename(path)
subdir = os.path.join(subdir, os.path.dirname(path))
return keras.utils.get_file(
return load_preset_from_gcs(
filename,
url,
cache_subdir=os.path.join("models", subdir),
Expand Down Expand Up @@ -192,6 +200,43 @@ def get_file(preset, path):
)


def load_preset_from_gcs(fname, url, cache_subdir):
VarunS1997 marked this conversation as resolved.
Show resolved Hide resolved
"""Much of this is adapted from get_file of keras core."""
if url is None:
VarunS1997 marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
'Please specify the "url" argument (URL of the file '
"to download)."
)

if cache_subdir is None:
cache_dir = config.keras_home()

datadir_base = os.path.expanduser(cache_dir)
if not os.access(datadir_base, os.W_OK):
datadir_base = os.path.join("/tmp", ".keras")
datadir = os.path.join(datadir_base, cache_subdir)
os.makedirs(datadir, exist_ok=True)

fname = path_to_string(fname)
VarunS1997 marked this conversation as resolved.
Show resolved Hide resolved
if not fname:
VarunS1997 marked this conversation as resolved.
Show resolved Hide resolved
fname = os.path.basename(urllib.parse.urlsplit(url).path)
if not fname:
raise ValueError(
"Can't parse the file name from the origin provided: "
f"'{url}'."
"Please specify the `fname` as the input param."
)
fpath = os.path.join(datadir, fname)

if not os.path.exists(fpath):
io_utils.print_msg(f"Downloading data from {url}")
VarunS1997 marked this conversation as resolved.
Show resolved Hide resolved

with tf.io.gfile.GFile(url, "rb") as preset_file:
shutil.copyfileobj(preset_file, fpath)

return fpath


def check_file_exists(preset, path):
try:
get_file(preset, path)
Expand Down Expand Up @@ -243,10 +288,10 @@ def save_tokenizer_assets(tokenizer, preset):


def save_serialized_object(
layer,
preset,
config_file=CONFIG_FILE,
config_to_skip=[],
layer,
VarunS1997 marked this conversation as resolved.
Show resolved Hide resolved
preset,
config_file=CONFIG_FILE,
config_to_skip=[],
):
check_keras_3()
make_preset_dir(preset)
Expand Down Expand Up @@ -412,9 +457,9 @@ def delete_model_card(preset):

@keras_nlp_export("keras_nlp.upload_preset")
def upload_preset(
uri,
preset,
allow_incomplete=False,
uri,
preset,
allow_incomplete=False,
):
"""Upload a preset directory to a model hub.

Expand Down Expand Up @@ -516,18 +561,18 @@ def validate_metadata(preset):


def load_serialized_object(
preset,
config_file=CONFIG_FILE,
config_overrides={},
preset,
config_file=CONFIG_FILE,
config_overrides={},
):
config = load_config(preset, config_file)
config["config"] = {**config["config"], **config_overrides}
return keras.saving.deserialize_keras_object(config)


def check_config_class(
preset,
config_file=CONFIG_FILE,
preset,
config_file=CONFIG_FILE,
):
"""Validate a preset is being loaded on the correct class."""
config_path = get_file(preset, config_file)
Expand Down
Loading