Skip to content

Commit

Permalink
Fix GFile downloads (#1666)
Browse files Browse the repository at this point in the history
The previous version appears to have been broken:

- Used a local variable that was never set.
- Used a utility from Keras that did not exist.
- Did not handle not-found errors properly.
  • Loading branch information
mattdangerw committed Jun 17, 2024
1 parent a149634 commit a0d5cd4
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions keras_nlp/src/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from packaging.version import parse

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.backend import config
from keras_nlp.src.backend import config as backend_config
from keras_nlp.src.backend import keras
from keras_nlp.src.utils.keras_utils import print_msg
Expand Down Expand Up @@ -155,14 +154,19 @@ def get_file(preset, path):

elif scheme in tf.io.gfile.get_registered_schemes():
url = os.path.join(preset, path)
subdir = preset.replace("://", "_").replace("-", "_")
subdir = preset.replace("://", "_").replace("-", "_").replace("/", "_")
filename = os.path.basename(path)
subdir = os.path.join(subdir, os.path.dirname(path))
return copy_gfile_to_cache(
filename,
url,
cache_subdir=os.path.join("models", subdir),
)
try:
return copy_gfile_to_cache(
filename,
url,
cache_subdir=os.path.join("models", subdir),
)
except tf.errors.NotFoundError as e:
raise FileNotFoundError(
f"`{path}` doesn't exist in preset directory `{preset}`.",
) from e
elif scheme == HF_SCHEME:
if huggingface_hub is None:
raise ImportError(
Expand Down Expand Up @@ -211,16 +215,16 @@ def get_file(preset, path):

def copy_gfile_to_cache(filename, url, cache_subdir):
"""Much of this is adapted from get_file of keras core."""
if cache_subdir is None:
cache_dir = config.keras_home()

datadir_base = os.path.expanduser(cache_dir)
if not os.access(datadir_base, os.W_OK):
datadir_base = os.path.join("/tmp", ".keras")
datadir = os.path.join(datadir_base, cache_subdir)
os.makedirs(datadir, exist_ok=True)
if "KERAS_HOME" in os.environ:
cachdir_base = os.environ.get("KERAS_HOME")
else:
cachdir_base = os.path.expanduser(os.path.join("~", ".keras"))
if not os.access(cachdir_base, os.W_OK):
cachdir_base = os.path.join("/tmp", ".keras")
cachedir = os.path.join(cachdir_base, cache_subdir)
os.makedirs(cachedir, exist_ok=True)

fpath = os.path.join(datadir, filename)
fpath = os.path.join(cachedir, filename)
if not os.path.exists(fpath):
print_msg(f"Downloading data from {url}")
tf.io.gfile.copy(url, fpath)
Expand Down

0 comments on commit a0d5cd4

Please sign in to comment.