[Train] Update run status and actor status for train runs. #46395

Merged
16 changes: 16 additions & 0 deletions dashboard/modules/train/train_head.py
@@ -12,6 +12,7 @@
from ray.dashboard.modules.job.utils import (
find_jobs_by_job_ids,
)
from ray.experimental.state.api import list_actors

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -61,6 +62,8 @@ async def get_train_runs(self, req: Request) -> Response:
else:
try:
train_runs = await stats_actor.get_all_train_runs.remote()
self._update_actor_status(train_runs)

# Sort train runs in reverse chronological order
train_runs = sorted(
train_runs.values(),
@@ -96,6 +99,19 @@ async def get_train_runs(self, req: Request) -> Response:
content_type="application/json",
)

def _update_actor_status(self, train_runs):
actor_status_table = {actor.actor_id: actor.state for actor in list_actors()}

# train_runs is the dict returned by get_all_train_runs, keyed by run id.
for train_run in train_runs.values():
train_run.controller_actor_status = actor_status_table.get(
train_run.controller_actor_id, None
)

for worker_info in train_run.workers:
worker_info.actor_status = actor_status_table.get(
worker_info.actor_id, None
)

@staticmethod
def is_minimal_module():
return False
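Note: a minimal sketch of how the new fields surface through the dashboard endpoint exercised in test_train_head.py further down; the dashboard address is illustrative and assumes a running cluster with RAY_TRAIN_ENABLE_STATE_TRACKING=1.

import requests

# Hypothetical dashboard address; substitute ray._private.worker.get_dashboard_url().
DASHBOARD_URL = "http://127.0.0.1:8265"

resp = requests.get(f"{DASHBOARD_URL}/api/train/runs")
resp.raise_for_status()

for run in resp.json()["train_runs"]:
    # run_status: STARTED, FINISHED, or ERRORED; actor statuses: ALIVE or DEAD.
    print(run["name"], run["run_status"], run["controller_actor_status"])
    for worker in run["workers"]:
        print("  worker", worker["actor_id"], worker["actor_status"])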
17 changes: 17 additions & 0 deletions python/ray/train/_internal/backend_executor.py
@@ -543,6 +543,8 @@ def initialize_session(

# Register Train Run before training starts
if self.state_tracking_enabled:
from ray.train._internal.state.schema import RunStatusEnum

core_context = ray.runtime_context.get_runtime_context()

self.state_manager.register_train_run(
@@ -553,6 +555,7 @@
datasets=datasets,
worker_group=self.worker_group,
start_time_ms=self._start_time_ms,
run_status=RunStatusEnum.STARTED,
)

# Run the training function asynchronously in its own thread.
@@ -650,6 +653,20 @@ def end_training():
results = self.get_with_failure_handling(futures)
return results

def report_final_run_status(self, errored=False):
"""Report the final train run status and end time to TrainStateActor."""
if self.state_tracking_enabled:
from ray.train._internal.state.schema import RunStatusEnum

self.state_manager.update_train_run_info(
updates=dict(
run_status=RunStatusEnum.ERRORED
if errored
else RunStatusEnum.FINISHED,
end_time_ms=int(time.time() * 1000),
)
)

def get_with_failure_handling(self, remote_values):
"""Gets the remote values while handling for worker failures.

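Note: report_final_run_status collapses the terminal state into a single update; a minimal equivalent sketch, assuming a TrainRunStateManager instance (state_manager) with a registered run — the names mirror the diff above.

import time

from ray.train._internal.state.schema import RunStatusEnum

# Equivalent of report_final_run_status(errored=False); with errored=True the
# status would be RunStatusEnum.ERRORED instead.
state_manager.update_train_run_info(
    updates=dict(
        run_status=RunStatusEnum.FINISHED,
        end_time_ms=int(time.time() * 1000),
    )
)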
27 changes: 27 additions & 0 deletions python/ray/train/_internal/state/schema.py
@@ -1,10 +1,24 @@
from enum import Enum
from typing import List, Optional

from ray._private.pydantic_compat import BaseModel, Field
from ray.dashboard.modules.job.pydantic_models import JobDetails
from ray.util.annotations import DeveloperAPI


@DeveloperAPI
class RunStatusEnum(str, Enum):
STARTED = "STARTED"
FINISHED = "FINISHED"
ERRORED = "ERRORED"


@DeveloperAPI
class ActorStatusEnum(str, Enum):
DEAD = "DEAD"
ALIVE = "ALIVE"


@DeveloperAPI
class TrainWorkerInfo(BaseModel):
"""Metadata of a Ray Train worker."""
@@ -21,6 +35,9 @@ class TrainWorkerInfo(BaseModel):
gpu_ids: List[int] = Field(
description="A list of GPU ids allocated to that worker."
)
actor_status: Optional[ActorStatusEnum] = Field(
description="The status of the train worker actor. It can be ALIVE or DEAD."
)


@DeveloperAPI
@@ -46,9 +63,19 @@ class TrainRunInfo(BaseModel):
datasets: List[TrainDatasetInfo] = Field(
description="A List of dataset info for this Train run."
)
run_status: RunStatusEnum = Field(
description="The current status of the train run. It can be one of the "
"following: STARTED, FINISHED, or ERRORED."
)
start_time_ms: int = Field(
description="The UNIX timestamp of the start time of this Train run."
)
end_time_ms: Optional[int] = Field(
description="The UNIX timestamp of the end time of this Train run."
)
controller_actor_status: Optional[ActorStatusEnum] = Field(
description="The status of the controller actor. It can be ALIVE or DEAD."
)


@DeveloperAPI
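Note: both new enums subclass str, so their values round-trip through the JSON returned by the dashboard; a small sketch of that assumption:

from ray.train._internal.state.schema import ActorStatusEnum, RunStatusEnum

# str-backed enums compare equal to their raw values, so JSON payloads can be
# checked against them directly.
assert RunStatusEnum.FINISHED == "FINISHED"
assert ActorStatusEnum.DEAD == "DEAD"
assert [s.value for s in RunStatusEnum] == ["STARTED", "FINISHED", "ERRORED"]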
13 changes: 11 additions & 2 deletions python/ray/train/_internal/state/state_manager.py
@@ -1,6 +1,6 @@
import logging
import os
from typing import Dict
from typing import Any, Dict

import ray
from ray.data import Dataset
@@ -23,12 +23,14 @@ class TrainRunStateManager:

def __init__(self, state_actor) -> None:
self.state_actor = state_actor
self.train_run_info_dict = {}

def register_train_run(
self,
run_id: str,
job_id: str,
run_name: str,
run_status: str,
controller_actor_id: str,
datasets: Dict[str, Dataset],
worker_group: WorkerGroup,
@@ -82,14 +84,21 @@ def collect_train_worker_info():
for ds_name, ds in datasets.items()
]

train_run_info = TrainRunInfo(
self.train_run_info_dict = dict(
id=run_id,
job_id=job_id,
name=run_name,
controller_actor_id=controller_actor_id,
workers=worker_info_list,
datasets=dataset_info_list,
start_time_ms=start_time_ms,
run_status=run_status,
)
train_run_info = TrainRunInfo(**self.train_run_info_dict)
ray.get(self.state_actor.register_train_run.remote(train_run_info))

def update_train_run_info(self, updates: Dict[str, Any]) -> None:
"""Update specific fields of a registered TrainRunInfo instance."""
self.train_run_info_dict.update(updates)
train_run_info = TrainRunInfo(**self.train_run_info_dict)
ray.get(self.state_actor.register_train_run.remote(train_run_info))
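Note: each update re-registers a full TrainRunInfo snapshot rather than patching fields on the actor; a minimal readback sketch, assuming a state_manager whose run has been registered as above:

import ray

# The state actor always holds the latest merged snapshot per run id.
runs = ray.get(state_manager.state_actor.get_all_train_runs.remote())
for run_id, run_info in runs.items():
    print(run_id, run_info.run_status, run_info.end_time_ms)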
48 changes: 46 additions & 2 deletions python/ray/train/tests/test_state.py
@@ -8,6 +8,7 @@
from ray.cluster_utils import Cluster
from ray.train import RunConfig, ScalingConfig
from ray.train._internal.state.schema import (
RunStatusEnum,
TrainDatasetInfo,
TrainRunInfo,
TrainWorkerInfo,
@@ -46,6 +47,9 @@ def ray_start_gpu_cluster():
"job_id": "0000000001",
"controller_actor_id": "3abd1972a19148d78acc78dd9414736e",
"start_time_ms": 1717448423000,
"run_status": "STARTED",
"controller_actor_status": null,
"end_time_ms": null,
"workers": [
{
"actor_id": "3d86c25634a71832dac32c8802000000",
@@ -55,7 +59,8 @@
"node_id": "b1e6cbed8533ae2def4e7e7ced9d19858ceb1ed8ab9ba81ab9c07825",
"node_ip": "10.0.208.100",
"pid": 76071,
"gpu_ids": [0]
"gpu_ids": [0],
"actor_status": null
},
{
"actor_id": "8f162dd8365346d1b5c98ebd7338c4f9",
@@ -65,7 +70,8 @@
"node_id": "b1e6cbed8533ae2def4e7e7ced9d19858ceb1ed8ab9ba81ab9c07825",
"node_ip": "10.0.208.100",
"pid": 76072,
"gpu_ids": [1]
"gpu_ids": [1],
"actor_status": null
}
],
"datasets": [
@@ -113,6 +119,7 @@ def _get_run_info_sample(run_id=None, run_name=None) -> TrainRunInfo:
workers=[worker_info_0, worker_info_1],
datasets=[dataset_info],
start_time_ms=1717448423000,
run_status=RunStatusEnum.STARTED,
)
return run_info

@@ -173,6 +180,7 @@ def test_state_manager(ray_start_gpu_cluster):
datasets={},
worker_group=worker_group,
start_time_ms=int(time.time() * 1000),
run_status=RunStatusEnum.STARTED,
)

# Register 100 runs with 10 TrainRunStateManagers
@@ -192,6 +200,7 @@
},
worker_group=worker_group,
start_time_ms=int(time.time() * 1000),
run_status=RunStatusEnum.STARTED,
)

runs = ray.get(state_actor.get_all_train_runs.remote())
@@ -275,6 +284,41 @@ def test_track_e2e_training(ray_start_gpu_cluster, gpus_per_worker):
assert dataset_info.dataset_uuid == dataset._plan._dataset_uuid


@pytest.mark.parametrize("raise_error", [True, False])
def test_train_run_status(ray_start_gpu_cluster, raise_error):
os.environ["RAY_TRAIN_ENABLE_STATE_TRACKING"] = "1"

def check_run_status(expected_status):
state_actor = ray.get_actor(
name=TRAIN_STATE_ACTOR_NAME, namespace=TRAIN_STATE_ACTOR_NAMESPACE
)
runs = ray.get(state_actor.get_all_train_runs.remote())
run = next(iter(runs.values()))
assert run.run_status == expected_status

def train_func():
check_run_status(expected_status=RunStatusEnum.STARTED)
if raise_error:
raise RuntimeError

trainer = DataParallelTrainer(
train_loop_per_worker=train_func,
scaling_config=ScalingConfig(num_workers=4, use_gpu=False),
)

try:
trainer.fit()
except Exception:
pass

if raise_error:
check_run_status(expected_status=RunStatusEnum.ERRORED)
else:
check_run_status(expected_status=RunStatusEnum.FINISHED)

ray.shutdown()


if __name__ == "__main__":
import sys

38 changes: 38 additions & 0 deletions python/ray/train/tests/test_train_head.py
@@ -38,6 +38,44 @@ def train_func():
assert body["train_runs"][0]["name"] == "my_train_run"
assert len(body["train_runs"][0]["workers"]) == 4

ray.shutdown()


def test_update_actor_status(monkeypatch, shutdown_only):
monkeypatch.setenv("RAY_TRAIN_ENABLE_STATE_TRACKING", "1")

from ray.train._internal.state.schema import ActorStatusEnum

ray.init(num_cpus=8)

def check_actor_status(expected_actor_status):
url = ray._private.worker.get_dashboard_url()
resp = requests.get("http://" + url + "/api/train/runs")
assert resp.status_code == 200
body = resp.json()
train_run = body["train_runs"][0]

assert train_run["controller_actor_status"] == expected_actor_status

for worker_info in body["train_runs"][0]["workers"]:
assert worker_info["actor_status"] == expected_actor_status

def train_func():
print("Training Starts")
time.sleep(0.5)
check_actor_status(expected_actor_status=ActorStatusEnum.ALIVE)

trainer = TorchTrainer(
train_func,
run_config=RunConfig(name="my_train_run", storage_path="/tmp/cluster_storage"),
scaling_config=ScalingConfig(num_workers=4, use_gpu=False),
)
trainer.fit()

check_actor_status(expected_actor_status=ActorStatusEnum.DEAD)

ray.shutdown()


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
3 changes: 3 additions & 0 deletions python/ray/train/trainer.py
@@ -118,10 +118,12 @@ def _run_with_error_handling(self, func: Callable):

def __next__(self):
if self.is_finished():
self._backend_executor.report_final_run_status(errored=False)
raise StopIteration
try:
next_results = self._run_with_error_handling(self._fetch_next_result)
if next_results is None:
self._backend_executor.report_final_run_status(errored=False)
self._run_with_error_handling(self._finish_training)
self._finished_training = True
raise StopIteration
Expand All @@ -130,6 +132,7 @@ def __next__(self):
except StartTraceback:
# If this is a StartTraceback, then this is a user error.
# We raise it directly
self._backend_executor.report_final_run_status(errored=True)
try:
# Exception raised in at least one training worker. Immediately raise
# this error to the user and do not attempt to terminate gracefully.
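Note: end to end, the new status fields only appear when state tracking is opted into; a minimal sketch based on the tests above (the trainer choice and training function are placeholders):

import os

import ray
from ray.train import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

# State tracking (and therefore run/actor status reporting) is opt-in.
os.environ["RAY_TRAIN_ENABLE_STATE_TRACKING"] = "1"
ray.init()

def train_func():
    pass  # placeholder training loop

trainer = DataParallelTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)
trainer.fit()
# After fit() returns, /api/train/runs reports run_status=FINISHED (or ERRORED
# on failure) and the now-dead worker actors as DEAD.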