
[Train] Update run status and actor status for train runs. #46395

Merged
51 changes: 48 additions & 3 deletions python/ray/dashboard/modules/train/train_head.py
@@ -5,6 +5,8 @@
import ray
import ray.dashboard.optional_utils as dashboard_optional_utils
import ray.dashboard.utils as dashboard_utils
from ray.core.generated import gcs_service_pb2, gcs_service_pb2_grpc
from ray.dashboard.modules.actor.actor_head import actor_table_data_to_dict
from ray.dashboard.modules.job.common import JobInfoStorageClient
from ray.dashboard.modules.job.utils import find_jobs_by_job_ids
from ray.util.annotations import DeveloperAPI
@@ -20,10 +22,9 @@ def __init__(self, dashboard_head):
super().__init__(dashboard_head)
self._train_stats_actor = None
self._job_info_client = None
self._gcs_actor_info_stub = None

# TODO(aguo): Update this to a "v2" path since I made a backwards-incompatible
# change. Will do so after the API is more stable.
@routes.get("/api/train/runs")
@routes.get("/api/train/v2/runs")
@dashboard_optional_utils.init_ray_and_catch_exceptions()
@DeveloperAPI
async def get_train_runs(self, req: Request) -> Response:
@@ -57,6 +58,7 @@ async def get_train_runs(self, req: Request) -> Response:
else:
try:
train_runs = await stats_actor.get_all_train_runs.remote()
await self._add_actor_status_and_update_run_status(train_runs)
# Sort train runs in reverse chronological order
train_runs = sorted(
train_runs.values(),
@@ -92,6 +94,44 @@ async def get_train_runs(self, req: Request) -> Response:
content_type="application/json",
)

async def _add_actor_status_and_update_run_status(self, train_runs):
from ray.train._internal.state.schema import ActorStatusEnum, RunStatusEnum

actor_status_table = {}
try:
logger.info("Getting all actor info from GCS.")
request = gcs_service_pb2.GetAllActorInfoRequest()
reply = await self._gcs_actor_info_stub.GetAllActorInfo(request, timeout=5)
if reply.status.code == 0:
for message in reply.actor_table_data:
actor_table_data = actor_table_data_to_dict(message)
actor_status_table[actor_table_data["actorId"]] = actor_table_data[
"state"
]
except Exception:
logger.exception("Error Getting all actor info from GCS.")

for train_run in train_runs.values():
for worker_info in train_run.workers:
worker_info.status = actor_status_table.get(worker_info.actor_id, None)

# The train run can be unexpectedly terminated before the final run
# status was updated. This could be due to errors outside of the training
# function (e.g., system failure or user interruption) that crashed the
# train controller.
# We need to detect this case and mark the train run as ABORTED.
controller_actor_status = actor_status_table.get(
train_run.controller_actor_id, None
)
if (
controller_actor_status == ActorStatusEnum.DEAD
and train_run.run_status == RunStatusEnum.STARTED
):
train_run.run_status = RunStatusEnum.ABORTED
train_run.status_detail = (
"Unexpectedly terminated due to system errors."
)

@staticmethod
def is_minimal_module():
return False
@@ -102,6 +142,11 @@ async def run(self, server):
self._dashboard_head.gcs_aio_client
)

gcs_channel = self._dashboard_head.aiogrpc_gcs_channel
self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
gcs_channel
)

async def get_train_stats_actor(self):
"""
Gets the train stats actor and caches it as an instance variable.
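
As a usage note for the route change above: a minimal sketch of querying the new `/api/train/v2/runs` endpoint from a client. The dashboard address and the response envelope (assumed here to expose the runs under a `train_runs` key) are assumptions; only the route and the schema fields appear in this diff.

```python
import requests

# Assumes a local Ray dashboard at the default address; the "train_runs"
# envelope key is an assumption not shown in this diff.
resp = requests.get("http://localhost:8265/api/train/v2/runs", timeout=5)
resp.raise_for_status()

for run in resp.json().get("train_runs", []):
    # run_status is one of STARTED, FINISHED, ERRORED, ABORTED (RunStatusEnum).
    print(run["name"], run["run_status"], run.get("status_detail", ""))
    for worker in run.get("workers", []):
        # Worker status is backfilled from GCS actor state: ALIVE, DEAD, or null.
        print("  worker", worker["actor_id"], worker.get("status"))
```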
2 changes: 1 addition & 1 deletion python/ray/train/BUILD
@@ -503,7 +503,7 @@ py_test(

py_test(
name = "test_state",
size = "small",
size = "medium",
srcs = ["tests/test_state.py"],
tags = ["team:ml", "exclusive"],
deps = [":train_lib", ":conftest"]
22 changes: 22 additions & 0 deletions python/ray/train/_internal/backend_executor.py
@@ -543,6 +543,8 @@ def initialize_session(

# Register Train Run before training starts
if self.state_tracking_enabled:
from ray.train._internal.state.schema import RunStatusEnum

core_context = ray.runtime_context.get_runtime_context()

self.state_manager.register_train_run(
@@ -553,6 +555,7 @@
datasets=datasets,
worker_group=self.worker_group,
start_time_ms=self._start_time_ms,
run_status=RunStatusEnum.STARTED,
)

# Run the training function asynchronously in its own thread.
@@ -650,6 +653,25 @@ def end_training():
results = self.get_with_failure_handling(futures)
return results

def report_final_run_status(self, errored=False):
"""Report the final train run status and end time to TrainStateActor."""
if self.state_tracking_enabled:
from ray.train._internal.state.schema import RunStatusEnum

if errored:
run_status = RunStatusEnum.ERRORED
status_detail = "Terminated due to an error in the training function."
else:
run_status = RunStatusEnum.FINISHED
status_detail = ""

self.state_manager.end_train_run(
run_id=self._trial_info.run_id,
run_status=run_status,
status_detail=status_detail,
end_time_ms=int(time.time() * 1000),
)

def get_with_failure_handling(self, remote_values):
"""Gets the remote values while handling for worker failures.

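
The call site of the new `report_final_run_status` hook is not shown in this diff; the sketch below illustrates one plausible caller-side pattern under that assumption. The `backend_executor` handle and `start_training` entry point are placeholders, not Ray APIs.

```python
# Illustrative only: ensure the final run status is always reported,
# whether the training function raised or finished cleanly.
def run_and_report(backend_executor, start_training):
    errored = False
    try:
        start_training()  # placeholder for the actual training invocation
    except Exception:
        errored = True
        raise
    finally:
        backend_executor.report_final_run_status(errored=errored)
```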
36 changes: 36 additions & 0 deletions python/ray/train/_internal/state/schema.py
@@ -1,10 +1,31 @@
from enum import Enum
from typing import List, Optional

from ray._private.pydantic_compat import BaseModel, Field
from ray.dashboard.modules.job.pydantic_models import JobDetails
from ray.util.annotations import DeveloperAPI


@DeveloperAPI
class RunStatusEnum(str, Enum):
"""Enumeration for the status of a train run."""

# The train run has started
STARTED = "STARTED"
# The train run was terminated as expected
FINISHED = "FINISHED"
# The train run was terminated early due to errors in the training function
ERRORED = "ERRORED"
# The train run was terminated early due to system errors or controller errors
ABORTED = "ABORTED"


@DeveloperAPI
class ActorStatusEnum(str, Enum):
DEAD = "DEAD"
ALIVE = "ALIVE"


@DeveloperAPI
class TrainWorkerInfo(BaseModel):
"""Metadata of a Ray Train worker."""
@@ -21,6 +42,9 @@ class TrainWorkerInfo(BaseModel):
gpu_ids: List[int] = Field(
description="A list of GPU ids allocated to that worker."
)
status: Optional[ActorStatusEnum] = Field(
description="The status of the train worker actor. It can be ALIVE or DEAD."
)


@DeveloperAPI
@@ -46,9 +70,21 @@ class TrainRunInfo(BaseModel):
datasets: List[TrainDatasetInfo] = Field(
description="A List of dataset info for this Train run."
)
run_status: RunStatusEnum = Field(
description="The current status of the train run. It can be one of the "
"following: STARTED, FINISHED, ERRORED, or ABORTED."
)
status_detail: str = Field(
description="Detailed information about the current run status, "
"such as error messages."
)
Comment on lines +77 to +80

Contributor: What's the purpose of this? We only ever have one "User Error" message right now.

Member (Author): This is for tracking the error reason for now; it can be extended to track details of the current run status in the future (e.g., scaling up/down or recovering during elastic training).

start_time_ms: int = Field(
description="The UNIX timestamp of the start time of this Train run."
)
end_time_ms: Optional[int] = Field(
description="The UNIX timestamp of the end time of this Train run. "
"If null, the Train run has not ended yet."
)


@DeveloperAPI
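
Because `RunStatusEnum` and `ActorStatusEnum` subclass `str`, their members compare equal to the raw string values and serialize directly to JSON. A small sketch of that behavior; the `is_terminal` helper is hypothetical, not part of the schema.

```python
from ray.train._internal.state.schema import ActorStatusEnum, RunStatusEnum

# str-backed enums compare equal to their raw string values.
assert RunStatusEnum.ABORTED == "ABORTED"
assert ActorStatusEnum.DEAD == "DEAD"


def is_terminal(status: RunStatusEnum) -> bool:
    # Hypothetical helper: every status other than STARTED is terminal
    # under the four-state schema introduced in this PR.
    return status in (
        RunStatusEnum.FINISHED,
        RunStatusEnum.ERRORED,
        RunStatusEnum.ABORTED,
    )
```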
35 changes: 32 additions & 3 deletions python/ray/train/_internal/state/state_manager.py
@@ -1,10 +1,12 @@
import logging
import os
from typing import Dict
from collections import defaultdict
from typing import Any, Dict

import ray
from ray.data import Dataset
from ray.train._internal.state.schema import (
RunStatusEnum,
TrainDatasetInfo,
TrainRunInfo,
TrainWorkerInfo,
@@ -23,16 +25,19 @@ class TrainRunStateManager:

def __init__(self, state_actor) -> None:
self.state_actor = state_actor
self.train_run_info_dict = defaultdict(dict)

def register_train_run(
self,
run_id: str,
job_id: str,
run_name: str,
run_status: str,
controller_actor_id: str,
datasets: Dict[str, Dataset],
worker_group: WorkerGroup,
start_time_ms: float,
status_detail: str = "",
) -> None:
"""Collect Train Run Info and report to StateActor."""

@@ -82,14 +87,38 @@ def collect_train_worker_info():
for ds_name, ds in datasets.items()
]

train_run_info = TrainRunInfo(
updates = dict(
id=run_id,
job_id=job_id,
name=run_name,
controller_actor_id=controller_actor_id,
workers=worker_info_list,
datasets=dataset_info_list,
start_time_ms=start_time_ms,
run_status=run_status,
status_detail=status_detail,
)

ray.get(self.state_actor.register_train_run.remote(train_run_info))
# Clear the cached info to avoid registering the same run twice
self.train_run_info_dict[run_id] = {}
self._update_train_run_info(run_id, updates)

def end_train_run(
self,
run_id: str,
run_status: RunStatusEnum,
status_detail: str,
end_time_ms: int,
):
"""Update the train run status when the training is finished."""
updates = dict(
run_status=run_status, status_detail=status_detail, end_time_ms=end_time_ms
)
self._update_train_run_info(run_id, updates)

def _update_train_run_info(self, run_id: str, updates: Dict[str, Any]) -> None:
"""Update specific fields of a registered TrainRunInfo instance."""
if run_id in self.train_run_info_dict:
self.train_run_info_dict[run_id].update(updates)
train_run_info = TrainRunInfo(**self.train_run_info_dict[run_id])
ray.get(self.state_actor.register_train_run.remote(train_run_info))
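
The manager now caches each run's full field dict and re-publishes the whole validated record on every partial update. Below is a standalone sketch of that cache-and-merge pattern, using illustrative names rather than Ray APIs.

```python
from typing import Any, Dict


class RunInfoCache:
    """Illustrative stand-in for TrainRunStateManager's caching behavior."""

    def __init__(self) -> None:
        self._runs: Dict[str, Dict[str, Any]] = {}

    def register(self, run_id: str, **fields: Any) -> Dict[str, Any]:
        # Registration resets the cached record, so re-registering the same
        # run does not accumulate stale fields.
        self._runs[run_id] = dict(fields)
        return dict(self._runs[run_id])

    def update(self, run_id: str, **updates: Any) -> Dict[str, Any]:
        # Mirrors the `if run_id in self.train_run_info_dict` guard:
        # only previously registered runs are updated and re-published.
        if run_id not in self._runs:
            return {}
        self._runs[run_id].update(updates)
        return dict(self._runs[run_id])


cache = RunInfoCache()
cache.register("run-1", run_status="STARTED", end_time_ms=None)
cache.update("run-1", run_status="FINISHED", end_time_ms=1717448500000)
```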
51 changes: 48 additions & 3 deletions python/ray/train/tests/test_state.py
@@ -8,6 +8,7 @@
from ray.cluster_utils import Cluster
from ray.train import RunConfig, ScalingConfig
from ray.train._internal.state.schema import (
RunStatusEnum,
TrainDatasetInfo,
TrainRunInfo,
TrainWorkerInfo,
@@ -46,6 +47,9 @@ def ray_start_gpu_cluster():
"job_id": "0000000001",
"controller_actor_id": "3abd1972a19148d78acc78dd9414736e",
"start_time_ms": 1717448423000,
"run_status": "STARTED",
"status_detail": "",
"end_time_ms": null,
"workers": [
{
"actor_id": "3d86c25634a71832dac32c8802000000",
@@ -55,7 +59,8 @@
"node_id": "b1e6cbed8533ae2def4e7e7ced9d19858ceb1ed8ab9ba81ab9c07825",
"node_ip": "10.0.208.100",
"pid": 76071,
"gpu_ids": [0]
"gpu_ids": [0],
"status": null
},
{
"actor_id": "8f162dd8365346d1b5c98ebd7338c4f9",
@@ -65,7 +70,8 @@
"node_id": "b1e6cbed8533ae2def4e7e7ced9d19858ceb1ed8ab9ba81ab9c07825",
"node_ip": "10.0.208.100",
"pid": 76072,
"gpu_ids": [1]
"gpu_ids": [1],
"status": null
}
],
"datasets": [
Expand Down Expand Up @@ -113,6 +119,8 @@ def _get_run_info_sample(run_id=None, run_name=None) -> TrainRunInfo:
workers=[worker_info_0, worker_info_1],
datasets=[dataset_info],
start_time_ms=1717448423000,
run_status=RunStatusEnum.STARTED,
status_detail="",
)
return run_info

@@ -137,7 +145,7 @@ def test_schema_equivalance():
assert _get_run_info_sample() == run_info_from_json


def test_state_actor_api():
def test_state_actor_api(ray_start_4_cpus):
state_actor = get_or_create_state_actor()
named_actors = ray.util.list_named_actors(all_namespaces=True)
assert {
@@ -173,6 +181,7 @@ def test_state_manager(ray_start_gpu_cluster):
datasets={},
worker_group=worker_group,
start_time_ms=int(time.time() * 1000),
run_status=RunStatusEnum.STARTED,
)

# Register 100 runs with 10 TrainRunStateManagers
@@ -192,6 +201,7 @@
},
worker_group=worker_group,
start_time_ms=int(time.time() * 1000),
run_status=RunStatusEnum.STARTED,
)

runs = ray.get(state_actor.get_all_train_runs.remote())
@@ -275,6 +285,41 @@ def test_track_e2e_training(ray_start_gpu_cluster, gpus_per_worker):
assert dataset_info.dataset_uuid == dataset._plan._dataset_uuid


@pytest.mark.parametrize("raise_error", [True, False])
def test_train_run_status(ray_start_gpu_cluster, raise_error):
os.environ["RAY_TRAIN_ENABLE_STATE_TRACKING"] = "1"

def check_run_status(expected_status):
state_actor = ray.get_actor(
name=TRAIN_STATE_ACTOR_NAME, namespace=TRAIN_STATE_ACTOR_NAMESPACE
)
runs = ray.get(state_actor.get_all_train_runs.remote())
run = next(iter(runs.values()))
assert run.run_status == expected_status

def train_func():
check_run_status(expected_status=RunStatusEnum.STARTED)
if raise_error:
raise RuntimeError

trainer = DataParallelTrainer(
train_loop_per_worker=train_func,
scaling_config=ScalingConfig(num_workers=4, use_gpu=False),
)

try:
trainer.fit()
except Exception:
pass

if raise_error:
check_run_status(expected_status=RunStatusEnum.ERRORED)
else:
check_run_status(expected_status=RunStatusEnum.FINISHED)

ray.shutdown()


if __name__ == "__main__":
import sys
