[Train] Improve lazy checkpointing (ray-project#32233)
This PR improves Train lazy checkpointing so that it also works with NFS setups. Previously, the decision to use lazy checkpointing depended on whether the Train worker actor was on the same node as the Trainable actor. With the new logic, the Trainable actor instead drops a marker file in the Trial's directory; if a worker actor can detect that file, it can access the same directory as the Trainable actor, so lazy checkpointing is safe to use (sketched below).

This PR also fixes lazy checkpointing env var propagation.
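
For illustration, the marker handshake boils down to two small filesystem operations. The helper names below (drop_marker, can_use_lazy_checkpointing) are hypothetical, not Ray APIs; only the marker file name comes from this PR:

```python
from pathlib import Path

LAZY_CHECKPOINT_MARKER_FILE = ".lazy_checkpoint_marker"


def drop_marker(trial_dir: str) -> Path:
    # Trainable side: create an empty marker file in the trial directory.
    marker = Path(trial_dir) / LAZY_CHECKPOINT_MARKER_FILE
    marker.touch()
    return marker


def can_use_lazy_checkpointing(trial_dir: str) -> bool:
    # Worker side: if the marker is visible, this worker shares the trial
    # directory with the Trainable (same node or a shared mount like NFS),
    # so checkpoints can be passed by path instead of being serialized.
    return (Path(trial_dir) / LAZY_CHECKPOINT_MARKER_FILE).exists()
```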

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
Yard1 authored and edoakes committed Mar 22, 2023
1 parent 3795f75 commit e777228
Showing 8 changed files with 65 additions and 8 deletions.
10 changes: 10 additions & 0 deletions python/ray/air/constants.py
@@ -45,3 +45,13 @@
 COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV = (
     "TRAIN_COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING"
 )
+
+# Integer value which if set will disable lazy checkpointing
+# (avoiding unnecessary serialization if worker is on the same node
+# as Trainable)
+DISABLE_LAZY_CHECKPOINTING_ENV = "TRAIN_DISABLE_LAZY_CHECKPOINTING"
+
+# Name of the marker dropped by the Trainable. If a worker detects
+# the presence of the marker in the trial dir, it will use lazy
+# checkpointing.
+LAZY_CHECKPOINT_MARKER_FILE = ".lazy_checkpoint_marker"
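
As a usage sketch (assuming only the constants defined above), a user could opt out of lazy checkpointing by setting the environment variable on the driver before starting the run:

```python
import os

# Value of DISABLE_LAZY_CHECKPOINTING_ENV above; per its comment the
# value is read as an integer, so "1" disables lazy checkpointing.
os.environ["TRAIN_DISABLE_LAZY_CHECKPOINTING"] = "1"
```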
14 changes: 14 additions & 0 deletions python/ray/air/session.py
@@ -7,6 +7,7 @@
 from ray.air.constants import SESSION_MISUSE_LOG_ONCE_KEY
 from ray.train.session import _TrainSessionImpl
 from ray.util import log_once
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.data import DatasetIterator
@@ -37,6 +38,7 @@ def wrapper(*args, **kwargs):
     return inner
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
     """Report metrics and optionally save a checkpoint.
@@ -90,6 +92,7 @@ def train_func():
     _get_session().report(metrics, checkpoint=checkpoint)
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_checkpoint() -> Optional[Checkpoint]:
     """Access the session's last checkpoint to resume from if applicable.
@@ -140,30 +143,35 @@ def train_func():
     return _get_session().loaded_checkpoint
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_experiment_name() -> str:
     """Experiment name for the corresponding trial."""
     return _get_session().experiment_name
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_trial_name() -> str:
     """Trial name for the corresponding trial."""
     return _get_session().trial_name
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_trial_id() -> str:
     """Trial id for the corresponding trial."""
     return _get_session().trial_id
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_trial_resources() -> "PlacementGroupFactory":
     """Trial resources for the corresponding trial."""
     return _get_session().trial_resources
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_trial_dir() -> str:
     """Log directory corresponding to the trial directory for a Tune session.
@@ -186,6 +194,7 @@ def train_func():
     return _get_session().trial_dir
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse(default_value=1)
 def get_world_size() -> int:
     """Get the current world size (i.e. total number of workers) for this run.
@@ -216,6 +225,7 @@ def train_loop_per_worker(config):
     return session.world_size
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse(default_value=0)
 def get_world_rank() -> int:
     """Get the world rank of this worker.
@@ -249,6 +259,7 @@ def train_loop_per_worker():
     return session.world_rank
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse(default_value=0)
 def get_local_rank() -> int:
     """Get the local rank of this worker (rank of the worker on its node).
@@ -281,6 +292,7 @@ def train_loop_per_worker():
     return session.local_rank
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse(default_value=0)
 def get_local_world_size() -> int:
     """Get the local world size of this node (i.e. number of workers on this node).
@@ -311,6 +323,7 @@ def get_local_world_size() -> int:
     return session.local_world_size
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse(default_value=0)
 def get_node_rank() -> int:
     """Get the rank of this node (i.e. the position of this node among all nodes).
@@ -341,6 +354,7 @@ def get_node_rank() -> int:
     return session.node_rank
 
 
+@PublicAPI(stability="beta")
 @_warn_session_misuse()
 def get_dataset_shard(
     dataset_name: Optional[str] = None,
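
With the @PublicAPI(stability="beta") annotations, these helpers are now documented public API. A minimal sketch of a training loop exercising a few of them (illustrative user code, not part of this diff):

```python
from ray.air import session
from ray.air.checkpoint import Checkpoint


def train_loop_per_worker(config: dict):
    # Rank/size helpers, now public (beta) API.
    rank = session.get_world_rank()
    world_size = session.get_world_size()

    for epoch in range(config.get("num_epochs", 2)):
        loss = 1.0 / (epoch + rank + 1)  # placeholder metric
        # report() sends metrics (and optionally a checkpoint) back to Train;
        # lazy checkpointing governs how that checkpoint crosses processes.
        session.report(
            {"loss": loss, "epoch": epoch, "world_size": world_size},
            checkpoint=Checkpoint.from_dict({"epoch": epoch}),
        )
```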
30 changes: 29 additions & 1 deletion python/ray/train/_internal/checkpoint.py
@@ -1,8 +1,9 @@
+import os
 import logging
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Type, Union
 
-from ray.air import Checkpoint, CheckpointConfig
+from ray.air import Checkpoint, CheckpointConfig, session
 from ray.air._internal.checkpoint_manager import CheckpointStorage
 from ray.air._internal.checkpoint_manager import (
     _CheckpointManager as CommonCheckpointManager,
@@ -16,6 +17,7 @@
     TUNE_CHECKPOINT_ID,
     TUNE_INSTALLED,
     CHECKPOINT_METADATA_KEY,
+    LAZY_CHECKPOINT_MARKER_FILE,
 )
 
 if TUNE_INSTALLED:
@@ -209,6 +211,24 @@ def latest_checkpoint_id(self) -> Optional[int]:


 class TuneCheckpointManager(CheckpointManager):
+    def __init__(
+        self,
+        run_dir: Optional[Path] = None,
+        checkpoint_strategy: Optional[CheckpointConfig] = None,
+    ):
+        super().__init__(run_dir, checkpoint_strategy)
+
+        # Name of the marker dropped by the Trainable. If a worker detects
+        # the presence of the marker in the trial dir, it will use lazy
+        # checkpointing.
+        self._lazy_marker_path = None
+        if tune.is_session_enabled():
+            self._lazy_marker_path = (
+                Path(session.get_trial_dir()) / LAZY_CHECKPOINT_MARKER_FILE
+            )
+            with open(self._lazy_marker_path, "w"):
+                pass
+
     def _load_checkpoint(
         self, checkpoint_to_load: Optional[Union[Dict, str, Path, Checkpoint]]
     ) -> Optional[Union[Dict, Checkpoint]]:
@@ -247,6 +267,14 @@ def next_checkpoint_path(self) -> Optional[Path]:
     def _get_next_checkpoint_path(self) -> Optional[Path]:
         return None
 
+    def __del__(self):
+        try:
+            assert self._lazy_marker_path
+            os.remove(str(self._lazy_marker_path))
+        except Exception:
+            pass
+        return super().__del__()
+
 
 def _construct_checkpoint_path_name(checkpoint_id: int) -> str:
     return f"checkpoint_{checkpoint_id:06d}"
4 changes: 3 additions & 1 deletion python/ray/train/_internal/session.py
@@ -8,6 +8,7 @@
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum, auto
+from pathlib import Path
 from typing import Callable, Dict, Optional, Type, Union
 
 import ray
@@ -25,6 +26,7 @@
     TIME_TOTAL_S,
     TIMESTAMP,
     CHECKPOINT_METADATA_KEY,
+    LAZY_CHECKPOINT_MARKER_FILE,
 )
 from ray.train.error import SessionMisuseError
 from ray.train.session import _TrainSessionImpl
@@ -300,7 +302,7 @@ def checkpoint(self, checkpoint: Checkpoint):
             checkpoint
             and self.enable_lazy_checkpointing
             and checkpoint._local_path
-            and self.get_current_ip() == self.trial_info.driver_ip
+            and (Path(self.trial_info.logdir) / LAZY_CHECKPOINT_MARKER_FILE).exists()
         ):
             metadata.update({CHECKPOINT_METADATA_KEY: checkpoint._metadata})
             checkpoint = str(checkpoint._local_path)
6 changes: 2 additions & 4 deletions python/ray/train/constants.py
@@ -13,6 +13,8 @@
     TRAIN_DATASET_KEY,
     WILDCARD_KEY,
     COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
+    DISABLE_LAZY_CHECKPOINTING_ENV,
+    LAZY_CHECKPOINT_MARKER_FILE,
 )
 
 # Autofilled session.report() metrics. Keys should be consistent with Tune.
@@ -64,10 +66,6 @@
 # PACK to SPREAD. 1 for True, 0 for False.
 TRAIN_ENABLE_WORKER_SPREAD_ENV = "TRAIN_ENABLE_WORKER_SPREAD"
 
-# Integer value which if set will disable lazy checkpointing
-# (avoiding unnecessary serialization if worker is on the same node
-# as Trainable)
-DISABLE_LAZY_CHECKPOINTING_ENV = "TRAIN_DISABLE_LAZY_CHECKPOINTING"
 
 # Blacklist virtualized networking.
 DEFAULT_NCCL_SOCKET_IFNAME = "^lo,docker,veth"
@@ -43,7 +43,6 @@ def checkpoint_train_func():
("dict", True),
("dir", True),
("lazy_dir", True),
("dir", False),
("lazy_dir", False),
)

6 changes: 5 additions & 1 deletion python/ray/tune/execution/ray_trial_executor.py
@@ -16,7 +16,10 @@
 from ray.actor import ActorHandle
 from ray.air import Checkpoint, AcquiredResources, ResourceRequest
 from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
-from ray.air.constants import COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV
+from ray.air.constants import (
+    COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
+    DISABLE_LAZY_CHECKPOINTING_ENV,
+)
 from ray.air.execution import ResourceManager
 from ray.air.execution.resources.placement_group import (
     PlacementGroupResourceManager,
@@ -46,6 +49,7 @@
"PL_DISABLE_FORK": "1"
}
ENV_VARS_TO_PROPAGATE = {
DISABLE_LAZY_CHECKPOINTING_ENV,
COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
"TUNE_CHECKPOINT_CLOUD_RETRY_NUM",
"TUNE_CHECKPOINT_CLOUD_RETRY_WAIT_TIME_S",
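
This is the env var propagation fix from the commit message: adding DISABLE_LAZY_CHECKPOINTING_ENV to ENV_VARS_TO_PROPAGATE makes a driver-side setting reach the trial actors. The propagation pattern amounts to roughly the following (collect_env_for_trial_actor is a hypothetical helper; Tune applies the result through the actors' runtime environment):

```python
import os
from typing import Dict, Set

ENV_VARS_TO_PROPAGATE: Set[str] = {
    "TRAIN_DISABLE_LAZY_CHECKPOINTING",
    "TRAIN_COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING",
    "TUNE_CHECKPOINT_CLOUD_RETRY_NUM",
    "TUNE_CHECKPOINT_CLOUD_RETRY_WAIT_TIME_S",
}


def collect_env_for_trial_actor() -> Dict[str, str]:
    # Forward only allow-listed variables that are actually set on the
    # driver, so every trial actor sees the same checkpointing knobs.
    return {k: os.environ[k] for k in ENV_VARS_TO_PROPAGATE if k in os.environ}
```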
2 changes: 2 additions & 0 deletions python/ray/tune/syncer.py
@@ -27,6 +27,7 @@
     delete_at_uri,
     is_non_local_path_uri,
 )
+from ray.air.constants import LAZY_CHECKPOINT_MARKER_FILE
 from ray.exceptions import RayActorError
 from ray.tune import TuneError
 from ray.tune.callback import Callback
@@ -57,6 +58,7 @@
"./checkpoint_tmp*",
"./save_to_object*",
"./rank_*",
f"./{LAZY_CHECKPOINT_MARKER_FILE}",
]


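
The marker only signals local filesystem visibility, so the syncer must never upload it; hence the new entry in the exclude list. The patterns act like globs matched against paths relative to the trial dir, roughly as below (a sketch, not Tune's actual matching code):

```python
from fnmatch import fnmatch

EXCLUDE_PATTERNS = [
    "./checkpoint_tmp*",
    "./save_to_object*",
    "./rank_*",
    "./.lazy_checkpoint_marker",
]


def is_excluded(rel_path: str) -> bool:
    # rel_path is relative to the trial directory, hence the "./" prefix.
    return any(fnmatch(rel_path, pattern) for pattern in EXCLUDE_PATTERNS)


assert is_excluded("./.lazy_checkpoint_marker")
assert not is_excluded("./checkpoint_000001")
```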
