Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed from_preset file downloading to use GFile when able #1665

Merged
merged 7 commits into from
Jun 15, 2024
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions keras_nlp/src/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,18 @@
from packaging.version import parse

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.backend import config
from keras_nlp.src.backend import config as backend_config
from keras_nlp.src.backend import keras
from keras_nlp.src.utils.keras_utils import print_msg

try:
import tensorflow as tf
except ImportError:
raise ImportError(
"To use `keras_nlp`, please install Tensorflow: `pip install tensorflow`. "
"The TensorFlow package is required for data preprocessing with any backend."
)

try:
import kagglehub
Expand All @@ -43,6 +53,10 @@
GS_PREFIX = "gs://"
HF_PREFIX = "hf://"

KAGGLE_SCHEME = "kaggle"
GS_SCHEME = "gs"
HF_SCHEME = "hf"

TOKENIZER_ASSET_DIR = "assets/tokenizer"

# Config file names.
Expand Down Expand Up @@ -99,13 +113,15 @@ def get_file(preset, path):
)
if preset in BUILTIN_PRESETS:
preset = BUILTIN_PRESETS[preset]["kaggle_handle"]
if preset.startswith(KAGGLE_PREFIX):

scheme = preset.split("://")[0].lower()
VarunS1997 marked this conversation as resolved.
Show resolved Hide resolved
if scheme == KAGGLE_SCHEME:
if kagglehub is None:
raise ImportError(
"`from_preset()` requires the `kagglehub` package. "
"Please install with `pip install kagglehub`."
)
kaggle_handle = preset.removeprefix(KAGGLE_PREFIX)
kaggle_handle = preset.removeprefix(KAGGLE_SCHEME + "://")
num_segments = len(kaggle_handle.split("/"))
if num_segments not in (4, 5):
raise ValueError(
Expand Down Expand Up @@ -134,25 +150,23 @@ def get_file(preset, path):
else:
raise ValueError(message)

elif preset.startswith(GS_PREFIX):
elif scheme in tf.io.gfile.get_registered_schemes():
url = os.path.join(preset, path)
url = url.replace(GS_PREFIX, "https://storage.googleapis.com/")
subdir = preset.replace(GS_PREFIX, "gs_")
subdir = subdir.replace("/", "_").replace("-", "_")
subdir = preset.replace("://", "_").replace("-", "_")
filename = os.path.basename(path)
subdir = os.path.join(subdir, os.path.dirname(path))
return keras.utils.get_file(
return copy_gfile_to_cache(
filename,
url,
cache_subdir=os.path.join("models", subdir),
)
elif preset.startswith(HF_PREFIX):
elif scheme == HF_SCHEME:
if huggingface_hub is None:
raise ImportError(
f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. "
"Please install with `pip install huggingface_hub`."
)
hf_handle = preset.removeprefix(HF_PREFIX)
hf_handle = preset.removeprefix(HF_SCHEME + "://")
try:
return huggingface_hub.hf_hub_download(
repo_id=hf_handle, filename=path
Expand Down Expand Up @@ -192,6 +206,25 @@ def get_file(preset, path):
)


def copy_gfile_to_cache(filename, url, cache_subdir):
"""Much of this is adapted from get_file of keras core."""
if cache_subdir is None:
cache_dir = config.keras_home()

datadir_base = os.path.expanduser(cache_dir)
if not os.access(datadir_base, os.W_OK):
datadir_base = os.path.join("/tmp", ".keras")
datadir = os.path.join(datadir_base, cache_subdir)
os.makedirs(datadir, exist_ok=True)

fpath = os.path.join(datadir, filename)
if not os.path.exists(fpath):
print_msg(f"Downloading data from {url}")
tf.io.gfile.copy(url, fpath)

return fpath


def check_file_exists(preset, path):
try:
get_file(preset, path)
Expand Down
Loading