From c96f8b4f9f944ecff676f784d5044f6e4251f034 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 7 Jul 2023 18:40:28 -0700 Subject: [PATCH] [1/n] Lightweight Ray AIR API refactor (#36706) This PR removes some circularities in the Ray AIR import system so we can put the training related functions into `ray.train`. It introduces a training context and makes report, get_dataset_shard, Checkpoint, Result, and the following configs: - CheckpointConfig - DataConfig - FailureConfig - RunConfig - ScalingConfig available in `ray.train`. No user facing changes yet, the old APIs still work. Going forward, it will be most consistent / symmetrical if these things are included in the following way: ```python from ray import train, tune, serve # Pick the subset that is needed # Include what you need from the following: from ray.train import CheckpointConfig, DataConfig, FailureConfig, RunConfig, ScalingConfig # ... def train_func(): dataset_shard = train.get_dataset_shard("train") world_size = train.get_context().get_world_size() # ... train.report(...) trainer = train.torch.TorchTrainer( train_func, scaling_config=ScalingConfig(num_workers=2), ) result = trainer.fit() ``` We have many examples in https://github.com/ray-project/ray/pull/37123 on how this looks like in actual code. Signed-off-by: Bhavpreet Singh --- doc/source/train/api/api.rst | 19 + python/ray/air/checkpoint.py | 10 +- python/ray/air/session.py | 407 +-------------------- python/ray/train/__init__.py | 14 + python/ray/train/_internal/checkpoint.py | 3 +- python/ray/train/_internal/session.py | 399 +++++++++++++++++++- python/ray/train/base_trainer.py | 2 +- python/ray/train/context.py | 86 +++++ python/ray/train/data_parallel_trainer.py | 2 +- python/ray/train/torch/config.py | 2 +- python/ray/train/torch/train_loop_utils.py | 2 +- 11 files changed, 529 insertions(+), 417 deletions(-) create mode 100644 python/ray/train/context.py diff --git a/doc/source/train/api/api.rst b/doc/source/train/api/api.rst index 4405fd87e63c..7bd0100e1e7e 100644 --- a/doc/source/train/api/api.rst +++ b/doc/source/train/api/api.rst @@ -52,6 +52,25 @@ Train Backend Base Classes ~train.backend.Backend ~train.backend.BackendConfig +Ray Train Config +---------------- + +.. autosummary:: + + ~ray.train.DataConfig + + +Ray Train Loop +-------------- + +.. autosummary:: + :toctree: doc/ + + ~train.context.TrainContext + ~train.get_context + ~train.get_dataset_shard + ~train.report + .. _train-integration-api: .. _train-framework-specific-ckpts: diff --git a/python/ray/air/checkpoint.py b/python/ray/air/checkpoint.py index 76268f91950c..7f022fcc7f61 100644 --- a/python/ray/air/checkpoint.py +++ b/python/ray/air/checkpoint.py @@ -331,7 +331,7 @@ def from_bytes(cls, data: bytes) -> "Checkpoint": data: Data object containing pickled checkpoint data. Returns: - Checkpoint: checkpoint object. + ray.air.checkpoint.Checkpoint: checkpoint object. """ bytes_data = pickle.loads(data) if isinstance(bytes_data, dict): @@ -360,7 +360,7 @@ def from_dict(cls, data: dict) -> "Checkpoint": data: Dictionary containing checkpoint data. Returns: - Checkpoint: checkpoint object. + ray.air.checkpoint.Checkpoint: checkpoint object. """ state = {} if _METADATA_KEY in data: @@ -455,7 +455,7 @@ def from_directory(cls, path: Union[str, os.PathLike]) -> "Checkpoint": Checkpoint). Returns: - Checkpoint: checkpoint object. + ray.air.checkpoint.Checkpoint: checkpoint object. 
""" state = {} @@ -474,7 +474,7 @@ def from_directory(cls, path: Union[str, os.PathLike]) -> "Checkpoint": @classmethod @DeveloperAPI def from_checkpoint(cls, other: "Checkpoint") -> "Checkpoint": - """Create a checkpoint from a generic :class:`Checkpoint`. + """Create a checkpoint from a generic :class:`ray.air.checkpoint.Checkpoint`. This method can be used to create a framework-specific checkpoint from a generic :class:`Checkpoint` object. @@ -715,7 +715,7 @@ def from_uri(cls, uri: str) -> "Checkpoint": uri: Source location URI to read data from. Returns: - Checkpoint: checkpoint object. + ray.air.checkpoint.Checkpoint: checkpoint object. """ state = {} try: diff --git a/python/ray/air/session.py b/python/ray/air/session.py index fcd8de0e21b8..d7d4e87a40cc 100644 --- a/python/ray/air/session.py +++ b/python/ray/air/session.py @@ -1,405 +1,2 @@ -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional -import warnings -import functools - -from ray.air._internal.session import _get_session -from ray.air.checkpoint import Checkpoint -from ray.air.constants import SESSION_MISUSE_LOG_ONCE_KEY -from ray.util import log_once -from ray.util.annotations import PublicAPI - -if TYPE_CHECKING: - from ray.data import DataIterator - from ray.tune.execution.placement_groups import PlacementGroupFactory - - -def _warn_session_misuse(default_value: Any = None): - """Warns if fn is being used outside of session and returns ``default_value``.""" - - def inner(fn: Callable): - fn_name = fn.__name__ - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - session = _get_session() - if not session: - if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"): - warnings.warn( - f"`{fn_name}` is meant to only be " - "called inside a function that is executed by a Tuner" - f" or Trainer. Returning `{default_value}`." - ) - return default_value - return fn(*args, **kwargs) - - return wrapper - - return inner - - -@PublicAPI(stability="beta") -@_warn_session_misuse() -def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None: - """Report metrics and optionally save a checkpoint. - - Each invocation of this method will automatically increment the underlying - iteration number. The physical meaning of this "iteration" is defined by - user (or more specifically the way they call ``report``). - It does not necessarily map to one epoch. - - This API is the canonical way to report metrics from Tune and Train, and - replaces the legacy ``tune.report``, ``with tune.checkpoint_dir``, - ``train.report`` and ``train.save_checkpoint`` calls. - - Note on directory checkpoints: AIR will take ownership of checkpoints passed - to ``report()`` by moving them to a new path. The original directory will no - longer be accessible to the caller after the report call. - - Example: - .. code-block: python - - from ray.air import session - from ray.air.checkpoint import Checkpoint - from ray.air.config import ScalingConfig - - ######## Using it in the *per worker* train loop (TrainSession) ####### - def train_func(): - model = build_model() - model.save("my_model", overwrite=True) - session.report( - metrics={"foo": "bar"}, - checkpoint=Checkpoint.from_directory(temp_dir.name) - ) - # Air guarantees by this point, you can safely write new stuff to - # "my_model" directory. 
- - scaling_config = ScalingConfig(num_workers=2) - trainer = TensorflowTrainer( - train_loop_per_worker=train_func, scaling_config=scaling_config - ) - result = trainer.fit() - # If you navigate to result.checkpoint's path, you will find the - content of ``model.save()`` under it. - # If you have `SyncConfig` configured, the content should also - # show up in the corresponding cloud storage path. - - Args: - metrics: The metrics you want to report. - checkpoint: The optional checkpoint you want to report. - """ - - _get_session().report(metrics, checkpoint=checkpoint) - - -@PublicAPI(stability="beta") -@_warn_session_misuse() -def get_checkpoint() -> Optional[Checkpoint]: - """Access the session's last checkpoint to resume from if applicable. - - Returns: - Checkpoint object if the session is currently being resumed. - Otherwise, return None. - - .. code-block:: python - - ######## Using it in the *per worker* train loop (TrainSession) ###### - from ray.air import session - from ray.air.checkpoint import Checkpoint - from ray.air.config import ScalingConfig - def train_func(): - ckpt = session.get_checkpoint() - if ckpt: - with ckpt.as_directory() as loaded_checkpoint_dir: - import tensorflow as tf - - model = tf.keras.models.load_model(loaded_checkpoint_dir) - else: - model = build_model() - - model.save("my_model", overwrite=True) - session.report( - metrics={"iter": 1}, - checkpoint=Checkpoint.from_directory("my_model") - ) - - scaling_config = ScalingConfig(num_workers=2) - trainer = TensorflowTrainer( - train_loop_per_worker=train_func, scaling_config=scaling_config - ) - result = trainer.fit() - - # trainer2 will pick up from the checkpoint saved by trainer1. - trainer2 = TensorflowTrainer( - train_loop_per_worker=train_func, - scaling_config=scaling_config, - # this is ultimately what is accessed through - # ``Session.get_checkpoint()`` - resume_from_checkpoint=result.checkpoint, - ) - result2 = trainer2.fit() - """ - - return _get_session().loaded_checkpoint - - -@PublicAPI(stability="beta") -@_warn_session_misuse() -def get_experiment_name() -> str: - """Experiment name for the corresponding trial.""" - return _get_session().experiment_name - - -@PublicAPI(stability="beta") -@_warn_session_misuse() -def get_trial_name() -> str: - """Trial name for the corresponding trial.""" - return _get_session().trial_name - - -@PublicAPI(stability="beta") -@_warn_session_misuse() -def get_trial_id() -> str: - """Trial id for the corresponding trial.""" - return _get_session().trial_id - - -@PublicAPI(stability="beta") -@_warn_session_misuse() -def get_trial_resources() -> "PlacementGroupFactory": - """Trial resources for the corresponding trial.""" - return _get_session().trial_resources - - -@PublicAPI(stability="beta") -@_warn_session_misuse() -def get_trial_dir() -> str: - """Log directory corresponding to the trial directory for a Tune session. - If calling from a Train session, this will give the trial directory of its parent - Tune session. - - .. code-block:: python - - from ray import tune - from ray.air import session - - def train_func(): - # Example: - # >>> session.get_trial_dir() - # ~/ray_results// - - tuner = tune.Tuner(train_func) - tuner.fit() - """ - return _get_session().trial_dir - - -@PublicAPI(stability="beta") -@_warn_session_misuse(default_value=1) -def get_world_size() -> int: - """Get the current world size (i.e. total number of workers) for this run. - - .. 
code-block:: python - - import time - from ray.air import session - from ray.air.config import ScalingConfig - - def train_loop_per_worker(config): - assert session.get_world_size() == 4 - - train_dataset = ray.data.from_items( - [{"x": x, "y": x + 1} for x in range(32)]) - trainer = TensorflowTrainer(train_loop_per_worker, - scaling_config=ScalingConfig(num_workers=1), - datasets={"train": train_dataset}) - trainer.fit() - """ - session = _get_session() - if not hasattr(session, "world_size"): - raise RuntimeError( - "`get_world_size` can only be called for TrainSession! " - "Make sure you only use that in `train_loop_per_worker` function" - "that is passed into `DataParallelTrainer`." - ) - return session.world_size - - -@PublicAPI(stability="beta") -@_warn_session_misuse(default_value=0) -def get_world_rank() -> int: - """Get the world rank of this worker. - - .. code-block:: python - - import time - from ray.air import session - from ray.air.config import ScalingConfig - - def train_loop_per_worker(): - for iter in range(100): - time.sleep(1) - if session.get_world_rank() == 0: - print("Worker 0") - - train_dataset = ray.data.from_items( - [{"x": x, "y": x + 1} for x in range(32)]) - trainer = TensorflowTrainer(train_loop_per_worker, - scaling_config=ScalingConfig(num_workers=1), - datasets={"train": train_dataset}) - trainer.fit() - """ - session = _get_session() - if not hasattr(session, "world_rank"): - raise RuntimeError( - "`get_world_rank` can only be called for TrainSession! " - "Make sure you only use that in `train_loop_per_worker` function" - "that is passed into `DataParallelTrainer`." - ) - return session.world_rank - - -@PublicAPI(stability="beta") -@_warn_session_misuse(default_value=0) -def get_local_rank() -> int: - """Get the local rank of this worker (rank of the worker on its node). - - .. code-block:: python - - import time - from ray.air import session - from ray.air.config import ScalingConfig - - def train_loop_per_worker(): - if torch.cuda.is_available(): - torch.cuda.set_device(session.get_local_rank()) - ... - - train_dataset = ray.data.from_items( - [{"x": x, "y": x + 1} for x in range(32)]) - trainer = TensorflowTrainer(train_loop_per_worker, - scaling_config=ScalingConfig(num_workers=1), - datasets={"train": train_dataset}) - trainer.fit() - """ - session = _get_session() - if not hasattr(session, "local_rank"): - raise RuntimeError( - "`get_local_rank` can only be called for TrainSession! " - "Make sure you only use that in `train_loop_per_worker` function" - "that is passed into `DataParallelTrainer`." - ) - return session.local_rank - - -@PublicAPI(stability="beta") -@_warn_session_misuse(default_value=0) -def get_local_world_size() -> int: - """Get the local rank of this worker (rank of the worker on its node). - - Example: - >>> import ray - >>> from ray.air import session - >>> from ray.air.config import ScalingConfig - >>> from ray.train.torch import TorchTrainer - >>> - >>> def train_loop_per_worker(): - ... return session.get_local_world_size() - >>> - >>> train_dataset = ray.data.from_items( - ... [{"x": x, "y": x + 1} for x in range(32)]) - >>> trainer = TorchTrainer(train_loop_per_worker, - ... scaling_config=ScalingConfig(num_workers=1), - ... datasets={"train": train_dataset}) - >>> trainer.fit() # doctest: +SKIP - """ - session = _get_session() - if not hasattr(session, "local_world_size"): - raise RuntimeError( - "`get_local_world_size` can only be called for TrainSession! 
" - "Make sure you only use that in `train_loop_per_worker` function" - "that is passed into `DataParallelTrainer`." - ) - return session.local_world_size - - -@PublicAPI(stability="beta") -@_warn_session_misuse(default_value=0) -def get_node_rank() -> int: - """Get the local rank of this worker (rank of the worker on its node). - - Example: - >>> import ray - >>> from ray.air import session - >>> from ray.air.config import ScalingConfig - >>> from ray.train.torch import TorchTrainer - >>> - >>> def train_loop_per_worker(): - ... return session.get_node_rank() - >>> - >>> train_dataset = ray.data.from_items( - ... [{"x": x, "y": x + 1} for x in range(32)]) - >>> trainer = TorchTrainer(train_loop_per_worker, - ... scaling_config=ScalingConfig(num_workers=1), - ... datasets={"train": train_dataset}) - >>> trainer.fit() # doctest: +SKIP - """ - session = _get_session() - if not hasattr(session, "node_rank"): - raise RuntimeError( - "`get_node_rank` can only be called for TrainSession! " - "Make sure you only use that in `train_loop_per_worker` function" - "that is passed into `DataParallelTrainer`." - ) - return session.node_rank - - -@PublicAPI(stability="beta") -@_warn_session_misuse() -def get_dataset_shard( - dataset_name: Optional[str] = None, -) -> Optional["DataIterator"]: - """Returns the :class:`ray.data.DataIterator` shard for this worker. - - Call :meth:`~ray.data.DataIterator.iter_torch_batches` or - :meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the - appropriate framework-specific data type. - - .. code-block:: python - - import ray - from ray import train - from ray.air import session - from ray.air.config import ScalingConfig - - def train_loop_per_worker(): - model = Net() - for iter in range(100): - # Trainer will automatically handle sharding. - data_shard = session.get_dataset_shard("train") - for batch in data_shard.iter_torch_batches(): - # ... - return model - - train_dataset = ray.data.from_items( - [{"x": x, "y": x + 1} for x in range(32)]) - trainer = TorchTrainer(train_loop_per_worker, - scaling_config=ScalingConfig(num_workers=2), - datasets={"train": train_dataset}) - trainer.fit() - - Args: - dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then - specifies which dataset shard to return. - - Returns: - The ``DataIterator`` shard to use for this worker. - If no dataset is passed into Trainer, then return None. - """ - session = _get_session() - if not hasattr(session, "get_dataset_shard"): - raise RuntimeError( - "`get_dataset_shard` can only be called for TrainSession! " - "Make sure you only use that in `train_loop_per_worker` function" - "that is passed into `DataParallelTrainer`." 
- ) - return session.get_dataset_shard(dataset_name) +from ray.air._internal.session import _get_session # noqa: F401 +from ray.train._internal.session import * # noqa: F401,F403 diff --git a/python/ray/train/__init__.py b/python/ray/train/__init__.py index 388c6480fe1c..6933175d45e9 100644 --- a/python/ray/train/__init__.py +++ b/python/ray/train/__init__.py @@ -1,15 +1,29 @@ from ray._private.usage import usage_lib from ray.train.backend import BackendConfig from ray.train.data_config import DataConfig +from ray.train.context import get_context from ray.train.constants import TRAIN_DATASET_KEY +from ray.train._internal.session import get_dataset_shard, report from ray.train.trainer import TrainingIterator +from ray.air import Checkpoint +from ray.air.config import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig +from ray.air.result import Result usage_lib.record_library_usage("train") __all__ = [ + "get_context", + "get_dataset_shard", + "report", "BackendConfig", + "Checkpoint", + "CheckpointConfig", "DataConfig", + "FailureConfig", + "Result", + "RunConfig", + "ScalingConfig", "TrainingIterator", "TRAIN_DATASET_KEY", ] diff --git a/python/ray/train/_internal/checkpoint.py b/python/ray/train/_internal/checkpoint.py index 82c25febb8b0..ee8e08cbb2f3 100644 --- a/python/ray/train/_internal/checkpoint.py +++ b/python/ray/train/_internal/checkpoint.py @@ -3,12 +3,13 @@ from pathlib import Path from typing import Callable, Dict, List, Optional, Type, Union -from ray.air import Checkpoint, CheckpointConfig, session +from ray.air import Checkpoint, CheckpointConfig from ray.air._internal.checkpoint_manager import CheckpointStorage from ray.air._internal.checkpoint_manager import ( _CheckpointManager as CommonCheckpointManager, ) from ray.air._internal.checkpoint_manager import _TrackedCheckpoint +from ray.train._internal import session from ray.train._internal.session import TrainingResult from ray.train._internal.utils import construct_path from ray.train.constants import ( diff --git a/python/ray/train/_internal/session.py b/python/ray/train/_internal/session.py index 023dcd0a4cfa..e46e26e8fc83 100644 --- a/python/ray/train/_internal/session.py +++ b/python/ray/train/_internal/session.py @@ -8,17 +8,20 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum, auto +import functools from pathlib import Path import shutil -from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union import warnings import ray +from ray.air._internal.session import _get_session from ray.air._internal.util import StartTraceback, RunnerThread from ray.air.checkpoint import Checkpoint from ray.air.constants import ( _RESULT_FETCH_TIMEOUT, _ERROR_FETCH_TIMEOUT, + SESSION_MISUSE_LOG_ONCE_KEY, TIMESTAMP, TIME_THIS_ITER_S, ) @@ -36,7 +39,7 @@ ) from ray.train.error import SessionMisuseError -from ray.util.annotations import DeveloperAPI +from ray.util.annotations import DeveloperAPI, PublicAPI from ray.util.debug import log_once @@ -551,3 +554,395 @@ def set_accelerator(accelerator: Accelerator) -> None: if session.accelerator is not None: raise RuntimeError("Cannot change accelerator once set.") session.accelerator = accelerator + + +def _warn_session_misuse(default_value: Any = None): + """Warns if fn is being used outside of session and returns ``default_value``.""" + + def inner(fn: Callable): + fn_name = fn.__name__ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + session = 
_get_session() + if not session: + if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"): + warnings.warn( + f"`{fn_name}` is meant to only be " + "called inside a function that is executed by a Tuner" + f" or Trainer. Returning `{default_value}`." + ) + return default_value + return fn(*args, **kwargs) + + return wrapper + + return inner + + +@PublicAPI(stability="beta") +@_warn_session_misuse() +def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None: + """Report metrics and optionally save a checkpoint. + + Each invocation of this method will automatically increment the underlying + iteration number. The physical meaning of this "iteration" is defined by + user (or more specifically the way they call ``report``). + It does not necessarily map to one epoch. + + This API is the canonical way to report metrics from Tune and Train, and + replaces the legacy ``tune.report``, ``with tune.checkpoint_dir``, + ``train.report`` and ``train.save_checkpoint`` calls. + + Note on directory checkpoints: AIR will take ownership of checkpoints passed + to ``report()`` by moving them to a new path. The original directory will no + longer be accessible to the caller after the report call. + + Example: + .. code-block: python + + from ray.air import session + from ray.air.checkpoint import Checkpoint + from ray.air.config import ScalingConfig + + ######## Using it in the *per worker* train loop (TrainSession) ####### + def train_func(): + model = build_model() + model.save("my_model", overwrite=True) + session.report( + metrics={"foo": "bar"}, + checkpoint=Checkpoint.from_directory(temp_dir.name) + ) + # Air guarantees by this point, you can safely write new stuff to + # "my_model" directory. + + scaling_config = ScalingConfig(num_workers=2) + trainer = TensorflowTrainer( + train_loop_per_worker=train_func, scaling_config=scaling_config + ) + result = trainer.fit() + # If you navigate to result.checkpoint's path, you will find the + content of ``model.save()`` under it. + # If you have `SyncConfig` configured, the content should also + # show up in the corresponding cloud storage path. + + Args: + metrics: The metrics you want to report. + checkpoint: The optional checkpoint you want to report. + """ + + _get_session().report(metrics, checkpoint=checkpoint) + + +@PublicAPI(stability="beta") +@_warn_session_misuse() +def get_checkpoint() -> Optional[Checkpoint]: + """Access the session's last checkpoint to resume from if applicable. + + Returns: + Checkpoint object if the session is currently being resumed. + Otherwise, return None. + + .. code-block:: python + + ######## Using it in the *per worker* train loop (TrainSession) ###### + from ray.air import session + from ray.air.checkpoint import Checkpoint + from ray.air.config import ScalingConfig + def train_func(): + ckpt = session.get_checkpoint() + if ckpt: + with ckpt.as_directory() as loaded_checkpoint_dir: + import tensorflow as tf + + model = tf.keras.models.load_model(loaded_checkpoint_dir) + else: + model = build_model() + + model.save("my_model", overwrite=True) + session.report( + metrics={"iter": 1}, + checkpoint=Checkpoint.from_directory("my_model") + ) + + scaling_config = ScalingConfig(num_workers=2) + trainer = TensorflowTrainer( + train_loop_per_worker=train_func, scaling_config=scaling_config + ) + result = trainer.fit() + + # trainer2 will pick up from the checkpoint saved by trainer1. 
+ trainer2 = TensorflowTrainer( + train_loop_per_worker=train_func, + scaling_config=scaling_config, + # this is ultimately what is accessed through + # ``Session.get_checkpoint()`` + resume_from_checkpoint=result.checkpoint, + ) + result2 = trainer2.fit() + """ + + return _get_session().loaded_checkpoint + + +@PublicAPI(stability="beta") +@_warn_session_misuse() +def get_experiment_name() -> str: + """Experiment name for the corresponding trial.""" + return _get_session().experiment_name + + +@PublicAPI(stability="beta") +@_warn_session_misuse() +def get_trial_name() -> str: + """Trial name for the corresponding trial.""" + return _get_session().trial_name + + +@PublicAPI(stability="beta") +@_warn_session_misuse() +def get_trial_id() -> str: + """Trial id for the corresponding trial.""" + return _get_session().trial_id + + +@PublicAPI(stability="beta") +@_warn_session_misuse() +def get_trial_resources() -> "PlacementGroupFactory": + """Trial resources for the corresponding trial.""" + return _get_session().trial_resources + + +@PublicAPI(stability="beta") +@_warn_session_misuse() +def get_trial_dir() -> str: + """Log directory corresponding to the trial directory for a Tune session. + If calling from a Train session, this will give the trial directory of its parent + Tune session. + + .. code-block:: python + + from ray import tune + from ray.air import session + + def train_func(): + # Example: + # >>> session.get_trial_dir() + # ~/ray_results// + + tuner = tune.Tuner(train_func) + tuner.fit() + """ + return _get_session().trial_dir + + +@PublicAPI(stability="beta") +@_warn_session_misuse(default_value=1) +def get_world_size() -> int: + """Get the current world size (i.e. total number of workers) for this run. + + .. code-block:: python + + import time + from ray.air import session + from ray.air.config import ScalingConfig + + def train_loop_per_worker(config): + assert session.get_world_size() == 4 + + train_dataset = ray.data.from_items( + [{"x": x, "y": x + 1} for x in range(32)]) + trainer = TensorflowTrainer(train_loop_per_worker, + scaling_config=ScalingConfig(num_workers=1), + datasets={"train": train_dataset}) + trainer.fit() + """ + session = _get_session() + if not hasattr(session, "world_size"): + raise RuntimeError( + "`get_world_size` can only be called for TrainSession! " + "Make sure you only use that in `train_loop_per_worker` function" + "that is passed into `DataParallelTrainer`." + ) + return session.world_size + + +@PublicAPI(stability="beta") +@_warn_session_misuse(default_value=0) +def get_world_rank() -> int: + """Get the world rank of this worker. + + .. code-block:: python + + import time + from ray.air import session + from ray.air.config import ScalingConfig + + def train_loop_per_worker(): + for iter in range(100): + time.sleep(1) + if session.get_world_rank() == 0: + print("Worker 0") + + train_dataset = ray.data.from_items( + [{"x": x, "y": x + 1} for x in range(32)]) + trainer = TensorflowTrainer(train_loop_per_worker, + scaling_config=ScalingConfig(num_workers=1), + datasets={"train": train_dataset}) + trainer.fit() + """ + session = _get_session() + if not hasattr(session, "world_rank"): + raise RuntimeError( + "`get_world_rank` can only be called for TrainSession! " + "Make sure you only use that in `train_loop_per_worker` function" + "that is passed into `DataParallelTrainer`." 
+ ) + return session.world_rank + + +@PublicAPI(stability="beta") +@_warn_session_misuse(default_value=0) +def get_local_rank() -> int: + """Get the local rank of this worker (rank of the worker on its node). + + .. code-block:: python + + import time + from ray.air import session + from ray.air.config import ScalingConfig + + def train_loop_per_worker(): + if torch.cuda.is_available(): + torch.cuda.set_device(session.get_local_rank()) + ... + + train_dataset = ray.data.from_items( + [{"x": x, "y": x + 1} for x in range(32)]) + trainer = TensorflowTrainer(train_loop_per_worker, + scaling_config=ScalingConfig(num_workers=1), + datasets={"train": train_dataset}) + trainer.fit() + """ + session = _get_session() + if not hasattr(session, "local_rank"): + raise RuntimeError( + "`get_local_rank` can only be called for TrainSession! " + "Make sure you only use that in `train_loop_per_worker` function" + "that is passed into `DataParallelTrainer`." + ) + return session.local_rank + + +@PublicAPI(stability="beta") +@_warn_session_misuse(default_value=0) +def get_local_world_size() -> int: + """Get the local world size of this node (i.e. number of workers on this node). + + Example: + >>> import ray + >>> from ray.air import session + >>> from ray.air.config import ScalingConfig + >>> from ray.train.torch import TorchTrainer + >>> + >>> def train_loop_per_worker(): + ... return session.get_local_world_size() + >>> + >>> train_dataset = ray.data.from_items( + ... [{"x": x, "y": x + 1} for x in range(32)]) + >>> trainer = TorchTrainer(train_loop_per_worker, + ... scaling_config=ScalingConfig(num_workers=1), + ... datasets={"train": train_dataset}) + >>> trainer.fit() # doctest: +SKIP + """ + session = _get_session() + if not hasattr(session, "local_world_size"): + raise RuntimeError( + "`get_local_world_size` can only be called for TrainSession! " + "Make sure you only use that in `train_loop_per_worker` function" + "that is passed into `DataParallelTrainer`." + ) + return session.local_world_size + + +@PublicAPI(stability="beta") +@_warn_session_misuse(default_value=0) +def get_node_rank() -> int: + """Get the rank of this node. + + Example: + >>> import ray + >>> from ray.air import session + >>> from ray.air.config import ScalingConfig + >>> from ray.train.torch import TorchTrainer + >>> + >>> def train_loop_per_worker(): + ... return session.get_node_rank() + >>> + >>> train_dataset = ray.data.from_items( + ... [{"x": x, "y": x + 1} for x in range(32)]) + >>> trainer = TorchTrainer(train_loop_per_worker, + ... scaling_config=ScalingConfig(num_workers=1), + ... datasets={"train": train_dataset}) + >>> trainer.fit() # doctest: +SKIP + """ + session = _get_session() + if not hasattr(session, "node_rank"): + raise RuntimeError( + "`get_node_rank` can only be called for TrainSession! " + "Make sure you only use that in `train_loop_per_worker` function" + "that is passed into `DataParallelTrainer`." + ) + return session.node_rank + + +@PublicAPI(stability="beta") +@_warn_session_misuse() +def get_dataset_shard( + dataset_name: Optional[str] = None, +) -> Optional["DataIterator"]: + """Returns the :class:`ray.data.DataIterator` shard for this worker. + + Call :meth:`~ray.data.DataIterator.iter_torch_batches` or + :meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the + appropriate framework-specific data type. + + .. 
code-block:: python + + import ray + from ray import train + from ray.air import session + from ray.air.config import ScalingConfig + + def train_loop_per_worker(): + model = Net() + for iter in range(100): + # Trainer will automatically handle sharding. + data_shard = session.get_dataset_shard("train") + for batch in data_shard.iter_torch_batches(): + # ... + return model + + train_dataset = ray.data.from_items( + [{"x": x, "y": x + 1} for x in range(32)]) + trainer = TorchTrainer(train_loop_per_worker, + scaling_config=ScalingConfig(num_workers=2), + datasets={"train": train_dataset}) + trainer.fit() + + Args: + dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then + specifies which dataset shard to return. + + Returns: + The ``DataIterator`` shard to use for this worker. + If no dataset is passed into Trainer, then return None. + """ + session = _get_session() + if not hasattr(session, "get_dataset_shard"): + raise RuntimeError( + "`get_dataset_shard` can only be called for TrainSession! " + "Make sure you only use that in `train_loop_per_worker` function" + "that is passed into `DataParallelTrainer`." + ) + return session.get_dataset_shard(dataset_name) diff --git a/python/ray/train/base_trainer.py b/python/ray/train/base_trainer.py index 034cfedf1917..af1399bd27d1 100644 --- a/python/ray/train/base_trainer.py +++ b/python/ray/train/base_trainer.py @@ -19,9 +19,9 @@ from ray.air._internal import usage as air_usage from ray.air._internal.usage import AirEntrypoint from ray.air.checkpoint import Checkpoint -from ray.air import session from ray.air.config import RunConfig, ScalingConfig from ray.air.result import Result +from ray.train._internal import session from ray.train.constants import TRAIN_DATASET_KEY from ray.util import PublicAPI from ray.util.annotations import DeveloperAPI diff --git a/python/ray/train/context.py b/python/ray/train/context.py new file mode 100644 index 000000000000..e055dd359f95 --- /dev/null +++ b/python/ray/train/context.py @@ -0,0 +1,86 @@ +import threading +from typing import TYPE_CHECKING, Optional + +from ray.air import Checkpoint +from ray.train._internal import session +from ray.util.annotations import PublicAPI + + +if TYPE_CHECKING: + from ray.tune.execution.placement_groups import PlacementGroupFactory + + +# The context singleton on this process. 
+_default_context: "Optional[TrainContext]" = None +_context_lock = threading.Lock() + + +def _copy_doc(copy_func): + def wrapped(func): + func.__doc__ = copy_func.__doc__ + return func + + return wrapped + + +@PublicAPI(stability="beta") +class TrainContext: + """Context for Ray training executions.""" + + @_copy_doc(session.get_checkpoint) + def get_checkpoint(self) -> Optional[Checkpoint]: + return session.get_checkpoint() + + @_copy_doc(session.get_experiment_name) + def get_experiment_name(self) -> str: + return session.get_experiment_name() + + @_copy_doc(session.get_trial_name) + def get_trial_name(self) -> str: + return session.get_trial_name() + + @_copy_doc(session.get_trial_id) + def get_trial_id(self) -> str: + return session.get_trial_id() + + @_copy_doc(session.get_trial_resources) + def get_trial_resources(self) -> "PlacementGroupFactory": + return session.get_trial_resources() + + @_copy_doc(session.get_trial_dir) + def get_trial_dir(self) -> str: + return session.get_trial_dir() + + @_copy_doc(session.get_world_size) + def get_world_size(self) -> int: + return session.get_world_size() + + @_copy_doc(session.get_world_rank) + def get_world_rank(self) -> int: + return session.get_world_rank() + + @_copy_doc(session.get_local_rank) + def get_local_rank(self) -> int: + return session.get_local_rank() + + @_copy_doc(session.get_local_world_size) + def get_local_world_size(self) -> int: + return session.get_local_world_size() + + @_copy_doc(session.get_node_rank) + def get_node_rank(self) -> int: + return session.get_node_rank() + + +@PublicAPI(stability="beta") +def get_context() -> TrainContext: + """Get or create a singleton training context. + + The context is only available in a training or tuning loop. + """ + global _default_context + + with _context_lock: + if _default_context is None: + _default_context = TrainContext() + return _default_context diff --git a/python/ray/train/data_parallel_trainer.py b/python/ray/train/data_parallel_trainer.py index 485c2269f7c8..6fc87e2eabe2 100644 --- a/python/ray/train/data_parallel_trainer.py +++ b/python/ray/train/data_parallel_trainer.py @@ -7,13 +7,13 @@ import ray from ray import tune -from ray.air import session from ray.air.checkpoint import Checkpoint from ray.air._internal.checkpointing import add_preprocessor_to_checkpoint from ray.air.config import DatasetConfig, RunConfig, ScalingConfig, CheckpointConfig from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY, LAZY_CHECKPOINT_MARKER_FILE from ray.air._internal.checkpoint_manager import _TrackedCheckpoint from ray.train import BackendConfig, TrainingIterator +from ray.train._internal import session from ray.train._internal.backend_executor import BackendExecutor, TrialInfo from ray.train._internal.checkpoint import TuneCheckpointManager from ray.train.data_config import DataConfig, _LegacyDataConfigWrapper diff --git a/python/ray/train/torch/config.py b/python/ray/train/torch/config.py index 7c5e0ab8926d..9b1c87c38c70 100644 --- a/python/ray/train/torch/config.py +++ b/python/ray/train/torch/config.py @@ -136,7 +136,7 @@ def _shutdown_torch(destroy_process_group=False): def _set_torch_distributed_env_vars(): # Same env vars as in # https://pytorch.org/docs/stable/elastic/run.html#environment-variables - from ray.air import session + from ray.train._internal import session from ray.train.torch.train_loop_utils import get_device os.environ["LOCAL_RANK"] = str(session.get_local_rank()) diff --git a/python/ray/train/torch/train_loop_utils.py 
b/python/ray/train/torch/train_loop_utils.py index c41177f73a80..c1724ccfeb7c 100644 --- a/python/ray/train/torch/train_loop_utils.py +++ b/python/ray/train/torch/train_loop_utils.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Callable, Union -from ray.air import session +from ray.train._internal import session from ray.train._internal.accelerator import Accelerator from torch.optim import Optimizer from ray.train._internal.session import get_accelerator, set_accelerator
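
For reference, below is a minimal end-to-end sketch of how the `ray.train` surface re-exported by this patch (`report`, `get_dataset_shard`, `get_context`, and the moved configs) could be exercised. This is an illustrative example, not code from the PR: the dataset contents, batch size, epoch count, and run name are placeholders.

```python
# Hypothetical usage sketch of the refactored API exported by this patch.
import ray
from ray import train
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker():
    # The new training context replaces the rank/size helpers on ray.air.session.
    ctx = train.get_context()
    world_rank = ctx.get_world_rank()

    # Same per-worker dataset shard accessor, now importable from ray.train.
    data_shard = train.get_dataset_shard("train")
    for epoch in range(2):
        for batch in data_shard.iter_torch_batches(batch_size=8):
            pass  # placeholder: run the training step on `batch` here
        # report() is also re-exported from ray.train.
        train.report({"epoch": epoch, "world_rank": world_rank})


train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    run_config=RunConfig(name="train_api_refactor_demo"),
    datasets={"train": train_dataset},
)
result = trainer.fit()
print(result.metrics)
```

As the commit message notes, the old `ray.air.session` entry points keep working; the sketch above only shows the `ray.train`-centric style this refactor enables.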