diff --git a/python/ray/air/constants.py b/python/ray/air/constants.py
index 12dca587419f..1accc998eebd 100644
--- a/python/ray/air/constants.py
+++ b/python/ray/air/constants.py
@@ -45,3 +45,13 @@
 COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV = (
     "TRAIN_COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING"
 )
+
+# Integer value which if set will disable lazy checkpointing
+# (avoiding unnecessary serialization if worker is on the same node
+# as Trainable)
+DISABLE_LAZY_CHECKPOINTING_ENV = "TRAIN_DISABLE_LAZY_CHECKPOINTING"
+
+# Name of the marker dropped by the Trainable. If a worker detects
+# the presence of the marker in the trial dir, it will use lazy
+# checkpointing.
+LAZY_CHECKPOINT_MARKER_FILE = ".lazy_checkpoint_marker"
diff --git a/python/ray/air/session.py b/python/ray/air/session.py
index 6fe5f3bb7ccd..b8747f56952b 100644
--- a/python/ray/air/session.py
+++ b/python/ray/air/session.py
@@ -7,6 +7,7 @@
 from ray.air.constants import SESSION_MISUSE_LOG_ONCE_KEY
 from ray.train.session import _TrainSessionImpl
 from ray.util import log_once
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data import DatasetIterator
@@ -37,6 +38,7 @@ def wrapper(*args, **kwargs):
     return inner
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
     """Report metrics and optionally save a checkpoint.
@@ -90,6 +92,7 @@ def train_func():
     _get_session().report(metrics, checkpoint=checkpoint)
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_checkpoint() -> Optional[Checkpoint]:
     """Access the session's last checkpoint to resume from if applicable.
@@ -140,30 +143,35 @@ def train_func():
     return _get_session().loaded_checkpoint
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_experiment_name() -> str:
     """Experiment name for the corresponding trial."""
     return _get_session().experiment_name
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_trial_name() -> str:
     """Trial name for the corresponding trial."""
     return _get_session().trial_name
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_trial_id() -> str:
     """Trial id for the corresponding trial."""
     return _get_session().trial_id
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_trial_resources() -> "PlacementGroupFactory":
     """Trial resources for the corresponding trial."""
     return _get_session().trial_resources
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_trial_dir() -> str:
     """Log directory corresponding to the trial directory for a Tune session.
@@ -186,6 +194,7 @@ def train_func():
     return _get_session().trial_dir
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse(default_value=1)
 def get_world_size() -> int:
     """Get the current world size (i.e. total number of workers) for this run.
@@ -216,6 +225,7 @@ def train_loop_per_worker(config):
     return session.world_size
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse(default_value=0)
 def get_world_rank() -> int:
     """Get the world rank of this worker.
@@ -249,6 +259,7 @@ def train_loop_per_worker():
     return session.world_rank
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse(default_value=0)
 def get_local_rank() -> int:
     """Get the local rank of this worker (rank of the worker on its node).
@@ -281,6 +292,7 @@ def train_loop_per_worker():
     return session.local_rank
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse(default_value=0)
 def get_local_world_size() -> int:
     """Get the local rank of this worker (rank of the worker on its node).
@@ -311,6 +323,7 @@ def get_local_world_size() -> int:
     return session.local_world_size
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse(default_value=0)
 def get_node_rank() -> int:
     """Get the local rank of this worker (rank of the worker on its node).
@@ -341,6 +354,7 @@ def get_node_rank() -> int:
     return session.node_rank
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_dataset_shard(
     dataset_name: Optional[str] = None,
diff --git a/python/ray/train/_internal/checkpoint.py b/python/ray/train/_internal/checkpoint.py
index 39bcddc8fbcb..a85f05d7c915 100644
--- a/python/ray/train/_internal/checkpoint.py
+++ b/python/ray/train/_internal/checkpoint.py
@@ -1,8 +1,9 @@
+import os
 import logging
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Type, Union
 
-from ray.air import Checkpoint, CheckpointConfig
+from ray.air import Checkpoint, CheckpointConfig, session
 from ray.air._internal.checkpoint_manager import CheckpointStorage
 from ray.air._internal.checkpoint_manager import (
     _CheckpointManager as CommonCheckpointManager,
@@ -16,6 +17,7 @@
     TUNE_CHECKPOINT_ID,
     TUNE_INSTALLED,
     CHECKPOINT_METADATA_KEY,
+    LAZY_CHECKPOINT_MARKER_FILE,
 )
 
 if TUNE_INSTALLED:
@@ -209,6 +211,24 @@ def latest_checkpoint_id(self) -> Optional[int]:
 
 
 class TuneCheckpointManager(CheckpointManager):
+    def __init__(
+        self,
+        run_dir: Optional[Path] = None,
+        checkpoint_strategy: Optional[CheckpointConfig] = None,
+    ):
+        super().__init__(run_dir, checkpoint_strategy)
+
+        # Name of the marker dropped by the Trainable. If a worker detects
+        # the presence of the marker in the trial dir, it will use lazy
+        # checkpointing.
+        self._lazy_marker_path = None
+        if tune.is_session_enabled():
+            self._lazy_marker_path = (
+                Path(session.get_trial_dir()) / LAZY_CHECKPOINT_MARKER_FILE
+            )
+            with open(self._lazy_marker_path, "w"):
+                pass
+
     def _load_checkpoint(
         self, checkpoint_to_load: Optional[Union[Dict, str, Path, Checkpoint]]
     ) -> Optional[Union[Dict, Checkpoint]]:
@@ -247,6 +267,14 @@ def next_checkpoint_path(self) -> Optional[Path]:
     def _get_next_checkpoint_path(self) -> Optional[Path]:
         return None
 
+    def __del__(self):
+        try:
+            assert self._lazy_marker_path
+            os.remove(str(self._lazy_marker_path))
+        except Exception:
+            pass
+        return super().__del__()
+
 
 def _construct_checkpoint_path_name(checkpoint_id: int) -> str:
     return f"checkpoint_{checkpoint_id:06d}"
diff --git a/python/ray/train/_internal/session.py b/python/ray/train/_internal/session.py
index a623dc449a45..369261901f46 100644
--- a/python/ray/train/_internal/session.py
+++ b/python/ray/train/_internal/session.py
@@ -8,6 +8,7 @@
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum, auto
+from pathlib import Path
 from typing import Callable, Dict, Optional, Type, Union
 
 import ray
@@ -25,6 +26,7 @@
     TIME_TOTAL_S,
     TIMESTAMP,
     CHECKPOINT_METADATA_KEY,
+    LAZY_CHECKPOINT_MARKER_FILE,
 )
 from ray.train.error import SessionMisuseError
 from ray.train.session import _TrainSessionImpl
@@ -300,7 +302,7 @@ def checkpoint(self, checkpoint: Checkpoint):
             checkpoint
             and self.enable_lazy_checkpointing
             and checkpoint._local_path
-            and self.get_current_ip() == self.trial_info.driver_ip
+            and (Path(self.trial_info.logdir) / LAZY_CHECKPOINT_MARKER_FILE).exists()
         ):
             metadata.update({CHECKPOINT_METADATA_KEY: checkpoint._metadata})
             checkpoint = str(checkpoint._local_path)
diff --git a/python/ray/train/constants.py b/python/ray/train/constants.py
index 71b499befeb0..4b6b1ac61bce 100644
--- a/python/ray/train/constants.py
+++ b/python/ray/train/constants.py
@@ -13,6 +13,8 @@
     TRAIN_DATASET_KEY,
     WILDCARD_KEY,
     COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
+    DISABLE_LAZY_CHECKPOINTING_ENV,
+    LAZY_CHECKPOINT_MARKER_FILE,
 )
 
 # Autofilled session.report() metrics. Keys should be consistent with Tune.
@@ -64,10 +66,6 @@
 # PACK to SPREAD. 1 for True, 0 for False.
 TRAIN_ENABLE_WORKER_SPREAD_ENV = "TRAIN_ENABLE_WORKER_SPREAD"
 
-# Integer value which if set will disable lazy checkpointing
-# (avoiding unnecessary serialization if worker is on the same node
-# as Trainable)
-DISABLE_LAZY_CHECKPOINTING_ENV = "TRAIN_DISABLE_LAZY_CHECKPOINTING"
 
 # Blacklist virtualized networking.
 DEFAULT_NCCL_SOCKET_IFNAME = "^lo,docker,veth"
diff --git a/python/ray/train/tests/test_data_parallel_trainer_checkpointing.py b/python/ray/train/tests/test_data_parallel_trainer_checkpointing.py
index c2caeedb1839..d4fdca28b848 100644
--- a/python/ray/train/tests/test_data_parallel_trainer_checkpointing.py
+++ b/python/ray/train/tests/test_data_parallel_trainer_checkpointing.py
@@ -43,7 +43,6 @@ def checkpoint_train_func():
     ("dict", True),
     ("dir", True),
     ("lazy_dir", True),
-    ("dir", False),
     ("lazy_dir", False),
 )
 
diff --git a/python/ray/tune/execution/ray_trial_executor.py b/python/ray/tune/execution/ray_trial_executor.py
index 878360e0b4be..2a95537f45f2 100644
--- a/python/ray/tune/execution/ray_trial_executor.py
+++ b/python/ray/tune/execution/ray_trial_executor.py
@@ -16,7 +16,10 @@
 from ray.actor import ActorHandle
 from ray.air import Checkpoint, AcquiredResources, ResourceRequest
 from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
-from ray.air.constants import COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV
+from ray.air.constants import (
+    COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
+    DISABLE_LAZY_CHECKPOINTING_ENV,
+)
 from ray.air.execution import ResourceManager
 from ray.air.execution.resources.placement_group import (
     PlacementGroupResourceManager,
@@ -46,6 +49,7 @@
     "PL_DISABLE_FORK": "1"
 }
 ENV_VARS_TO_PROPAGATE = {
+    DISABLE_LAZY_CHECKPOINTING_ENV,
     COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
     "TUNE_CHECKPOINT_CLOUD_RETRY_NUM",
     "TUNE_CHECKPOINT_CLOUD_RETRY_WAIT_TIME_S",
diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py
index dc4a739c66a9..11fcc64b0d71 100644
--- a/python/ray/tune/syncer.py
+++ b/python/ray/tune/syncer.py
@@ -27,6 +27,7 @@
     delete_at_uri,
     is_non_local_path_uri,
 )
+from ray.air.constants import LAZY_CHECKPOINT_MARKER_FILE
 from ray.exceptions import RayActorError
 from ray.tune import TuneError
 from ray.tune.callback import Callback
@@ -57,6 +58,7 @@
     "./checkpoint_tmp*",
     "./save_to_object*",
     "./rank_*",
+    f"./{LAZY_CHECKPOINT_MARKER_FILE}",
 ]
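
Taken together, the patch replaces the old heuristic for lazy checkpointing ("does the worker's IP match the Trainable driver's IP?") with a marker-file handshake: `TuneCheckpointManager` drops `.lazy_checkpoint_marker` in the trial directory, a worker reports its checkpoint as a bare local path only when it can actually see that marker, the `TRAIN_DISABLE_LAZY_CHECKPOINTING` env var (now propagated to workers via `ENV_VARS_TO_PROPAGATE`) forces the serializing path, and the marker is excluded from syncing. Checking marker visibility tests the real precondition for passing a path (a shared view of the trial dir) rather than the IP-equality proxy.

Below is a minimal, self-contained sketch of that handshake outside of Ray. The two constants are copied from the diff; `drop_marker` and `to_report` are hypothetical stand-ins for `TuneCheckpointManager.__init__` and `_TrainSession.checkpoint`, not Ray APIs:

```python
import os
from pathlib import Path
from tempfile import TemporaryDirectory

# Values copied from python/ray/air/constants.py in the diff above.
LAZY_CHECKPOINT_MARKER_FILE = ".lazy_checkpoint_marker"
DISABLE_LAZY_CHECKPOINTING_ENV = "TRAIN_DISABLE_LAZY_CHECKPOINTING"


def drop_marker(trial_dir: Path) -> Path:
    """Trainable side: drop the marker so co-located workers go lazy."""
    marker = trial_dir / LAZY_CHECKPOINT_MARKER_FILE
    marker.touch()
    return marker


def to_report(trial_dir: Path, local_checkpoint: Path):
    """Worker side: decide what to ship back for a checkpoint.

    Mirrors the condition in _internal/session.py: lazy checkpointing must
    not be disabled via the env var, and the marker must be visible in the
    trial dir (i.e. the worker shares a filesystem with the Trainable).
    """
    disabled = bool(int(os.environ.get(DISABLE_LAZY_CHECKPOINTING_ENV, "0")))
    if not disabled and (trial_dir / LAZY_CHECKPOINT_MARKER_FILE).exists():
        return "path", str(local_checkpoint)  # lazy: no serialization
    return "bytes", local_checkpoint.read_bytes()  # eager fallback


if __name__ == "__main__":
    with TemporaryDirectory() as tmp:
        trial_dir = Path(tmp)
        ckpt = trial_dir / "checkpoint_000001"
        ckpt.write_bytes(b"weights")

        drop_marker(trial_dir)
        assert to_report(trial_dir, ckpt)[0] == "path"

        # Setting the env var (propagated to workers in the real patch)
        # forces the eager, serializing path.
        os.environ[DISABLE_LAZY_CHECKPOINTING_ENV] = "1"
        assert to_report(trial_dir, ckpt)[0] == "bytes"
```

The `__del__` hook and the new entry in the `syncer.py` exclude patterns handle the marker's lifecycle: it is removed when the checkpoint manager goes away and is never synced, so a stale marker on remote storage cannot trick a later worker into lazy mode.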