Skip to content

Commit

Permalink
fix(datasets): Use put() and get() instead of copy in `TensorFl…
Browse files Browse the repository at this point in the history
…owModelDataset`'s `_save` and `_load` methods.

Signed-off-by: gitgud5000 <17186026+gitgud5000@users.noreply.github.com>
  • Loading branch information
gitgud5000 committed Sep 22, 2024
1 parent 552b973 commit 5e35f95
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _load(self) -> tf.keras.Model:
# We assume .keras
path = str(PurePath(tempdir) / TEMPORARY_KERAS_FILE) # noqa: PLW2901

self._fs.copy(load_path, path)
self._fs.get(load_path, path)

# Pass the local temporary directory/file path to keras.load_model
device_name = self._load_args.pop("tf_device", None)
Expand All @@ -169,7 +169,7 @@ def _save(self, data: tf.keras.Model) -> None:

# Use fsspec to take from local tempfile directory/file and
# put in ArbitraryFileSystem
self._fs.copy(path, save_path)
self._fs.put(path, save_path)

def _exists(self) -> bool:
try:
Expand Down

0 comments on commit 5e35f95

Please sign in to comment.