Skip to content

Commit

Permalink
[AIR] Make Checkpoint.get_preprocessor faster (ray-project#32350)
Browse files Browse the repository at this point in the history
Previously, `get_preprocessor` would always serialize the Checkpoint into a dictionary first. This is incredibly wasteful and causes huge memory usage and runtime with large directory-based Checkpoints. This PR changes the logic to first see if a directory Checkpoint should be loaded into a dictionary or not in order to obtain the preprocessor.

Context: I had ran into it when trying to do predictions with 25 GB Hugging Face model. `HuggingFacePredictor` calls `get_preprocessor` internally, and that takes ages to complete and almost caused an OOM for me - and all of that is unnecessary as the preprocessor has to be loaded from a file anyway.

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
  • Loading branch information
Yard1 authored and edoakes committed Mar 22, 2023
1 parent 28d0670 commit 3795f75
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 7 deletions.
40 changes: 33 additions & 7 deletions python/ray/air/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,17 +764,26 @@ def __fspath__(self):
def get_preprocessor(self) -> Optional["Preprocessor"]:
"""Return the saved preprocessor, if one exists."""

if self._override_preprocessor:
return self._override_preprocessor

# The preprocessor will either be stored in an in-memory dict or
# written to storage. In either case, it will use the PREPROCESSOR_KEY key.

# First try converting to dictionary.
checkpoint_dict = self.to_dict()
preprocessor = checkpoint_dict.get(PREPROCESSOR_KEY, None)

if preprocessor is None:
# Fallback to reading from directory.
# If this is a pure directory checkpoint (not a dict checkpoint saved to dir),
# then do not convert to dictionary as that takes a lot of time and memory.
if self.uri:
with self.as_directory() as checkpoint_path:
preprocessor = load_preprocessor_from_dir(checkpoint_path)
if _is_persisted_directory_checkpoint(checkpoint_path):
# If this is a persisted directory checkpoint, then we load the
# files from the temp directory created by the context.
# That way we avoid having to download the files twice.
loaded_checkpoint = self.from_directory(checkpoint_path)
preprocessor = _get_preprocessor(loaded_checkpoint)
else:
preprocessor = load_preprocessor_from_dir(checkpoint_path)
else:
preprocessor = _get_preprocessor(self)

return preprocessor

Expand Down Expand Up @@ -859,3 +868,20 @@ def _make_dir(path: str, acquire_del_lock: bool = True) -> None:
open(del_lock_path, "a").close()

os.makedirs(path, exist_ok=True)


def _is_persisted_directory_checkpoint(path: str) -> bool:
return Path(path, _DICT_CHECKPOINT_FILE_NAME).exists()


def _get_preprocessor(checkpoint: "Checkpoint") -> Optional["Preprocessor"]:
# First try converting to dictionary.
checkpoint_dict = checkpoint.to_dict()
preprocessor = checkpoint_dict.get(PREPROCESSOR_KEY, None)

if preprocessor is None:
# Fallback to reading from directory.
with checkpoint.as_directory() as checkpoint_path:
preprocessor = load_preprocessor_from_dir(checkpoint_path)

return preprocessor
7 changes: 7 additions & 0 deletions python/ray/air/tests/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,13 @@ def testDirCheckpointSetPreprocessor(self):
preprocessor = checkpoint.get_preprocessor()
assert preprocessor.multiplier == 1

# Also check that loading from dir works
new_checkpoint_dir = os.path.join(tmpdir, "new_checkpoint")
checkpoint.to_directory(new_checkpoint_dir)
checkpoint = Checkpoint.from_directory(new_checkpoint_dir)
preprocessor = checkpoint.get_preprocessor()
assert preprocessor.multiplier == 1

def testDirCheckpointSetPreprocessorAsDict(self):
with tempfile.TemporaryDirectory() as tmpdir:
preprocessor = DummyPreprocessor(1)
Expand Down

0 comments on commit 3795f75

Please sign in to comment.