Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Fault tolerance for compiled DAGs #41943

Merged
merged 18 commits into from
Dec 21, 2023
Merged
14 changes: 10 additions & 4 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,10 +860,16 @@ def get_objects(
debugger_breakpoint = metadata_fields[1][
len(ray_constants.OBJECT_METADATA_DEBUG_PREFIX) :
]
return (
self.deserialize_objects(data_metadata_pairs, object_refs),
debugger_breakpoint,
)
values = self.deserialize_objects(data_metadata_pairs, object_refs)
for i, value in enumerate(values):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copied from non-experimental path.

if isinstance(value, RayError):
if isinstance(value, ray.exceptions.ObjectLostError):
global_worker.core_worker.dump_object_store_memory_usage()
if isinstance(value, RayTaskError):
raise value.as_instanceof_cause()
else:
raise value
return values, debugger_breakpoint

def main_loop(self):
"""The main loop a worker runs to receive and execute tasks."""
Expand Down
2 changes: 2 additions & 0 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3518,6 +3518,7 @@ cdef class CoreWorker:
def experimental_mutable_object_put_serialized(self, serialized_object,
ObjectRef object_ref,
num_readers,
try_wait=False
ericl marked this conversation as resolved.
Show resolved Hide resolved
):
cdef:
CObjectID c_object_id = object_ref.native()
Expand All @@ -3532,6 +3533,7 @@ cdef class CoreWorker:
metadata,
data_size,
num_readers,
try_wait,
&data,
))
if data_size > 0:
Expand Down
96 changes: 76 additions & 20 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
import logging
from typing import Any, Dict, List, Tuple, Union, Optional

from collections import defaultdict
from typing import Any, Dict, List, Tuple, Union, Optional
import logging
import threading

import ray
import ray.experimental.channel as ray_channel
from ray.exceptions import RayTaskError, TaskCancelledError
from ray.experimental.channel import Channel
from ray.util.annotations import DeveloperAPI


MAX_BUFFER_SIZE = int(100 * 1e6) # 100MB

ChannelType = "ray.experimental.channel.Channel"

logger = logging.getLogger(__name__)


@DeveloperAPI
def do_allocate_channel(
self, buffer_size_bytes: int, num_readers: int = 1
) -> ChannelType:
def do_allocate_channel(self, buffer_size_bytes: int, num_readers: int = 1) -> Channel:
"""Generic actor method to allocate an output channel.

Args:
Expand All @@ -28,14 +25,14 @@ def do_allocate_channel(
Returns:
The allocated channel.
"""
self._output_channel = ray_channel.Channel(buffer_size_bytes, num_readers)
self._output_channel = Channel(buffer_size_bytes, num_readers)
return self._output_channel


@DeveloperAPI
def do_exec_compiled_task(
self,
inputs: List[Union[Any, ChannelType]],
inputs: List[Union[Any, Channel]],
actor_method_name: str,
) -> None:
"""Generic actor method to begin executing a compiled DAG. This runs an
Expand All @@ -52,13 +49,14 @@ def do_exec_compiled_task(
the loop.
"""
try:
self._input_channels = [i for i in inputs if isinstance(i, Channel)]
method = getattr(self, actor_method_name)

resolved_inputs = []
input_channel_idxs = []
# Add placeholders for input channels.
for idx, inp in enumerate(inputs):
if isinstance(inp, ray_channel.Channel):
if isinstance(inp, Channel):
input_channel_idxs.append((idx, inp))
resolved_inputs.append(None)
else:
Expand All @@ -75,10 +73,21 @@ def do_exec_compiled_task(
channel.end_read()

except Exception as e:
logging.warn(f"Compiled DAG task aborted with exception: {e}")
logging.info(f"Compiled DAG task exited with exception: {e}")
raise
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For non-Ray exceptions, I wonder if we should instead store the error and keep looping?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1



@DeveloperAPI
def do_cancel_compiled_task(self):
e = RayTaskError(
function_name="do_exec_compiled_task",
traceback_str="",
cause=TaskCancelledError(),
)
for channel in self._input_channels:
channel.set_error(e)


@DeveloperAPI
class CompiledTask:
"""Wraps the normal Ray DAGNode with some metadata."""
Expand Down Expand Up @@ -147,11 +156,13 @@ def __init__(self, buffer_size_bytes: Optional[int]):
self.actor_task_count: Dict["ray._raylet.ActorID", int] = defaultdict(int)

# Cached attributes that are set during compilation.
self.dag_input_channel: Optional[ChannelType] = None
self.dag_output_channels: Optional[ChannelType] = None
self.dag_input_channel: Optional[Channel] = None
self.dag_output_channels: Optional[Channel] = None
# ObjectRef for each worker's task. The task is an infinite loop that
# repeatedly executes the method specified in the DAG.
self.worker_task_refs: List["ray.ObjectRef"] = []
# Set of actors present in the DAG.
self.actor_refs = set()

def _add_node(self, node: "ray.dag.DAGNode") -> None:
idx = self.counter
Expand Down Expand Up @@ -254,7 +265,7 @@ def _preprocess(self) -> None:

def _get_or_compile(
self,
) -> Tuple[ChannelType, Union[ChannelType, List[ChannelType]]]:
) -> Tuple[Channel, Union[Channel, List[Channel]]]:
"""Compile an execution path. This allocates channels for adjacent
tasks to send/receive values. An infinite task is submitted to each
actor in the DAG that repeatedly receives from input channel(s) and
Expand Down Expand Up @@ -303,8 +314,9 @@ def _get_or_compile(
num_readers=task.num_readers,
)
)
self.actor_refs.add(task.dag_node._get_actor_handle())
elif isinstance(task.dag_node, InputNode):
task.output_channel = ray_channel.Channel(
task.output_channel = Channel(
buffer_size_bytes=self._buffer_size_bytes,
num_readers=task.num_readers,
)
Expand Down Expand Up @@ -345,7 +357,7 @@ def _get_or_compile(
# Assign the task with the correct input and output buffers.
worker_fn = task.dag_node._get_remote_method("__ray_call__")
self.worker_task_refs.append(
worker_fn.remote(
worker_fn.options(concurrency_group="_ray_system").remote(
do_exec_compiled_task,
resolved_args,
task.dag_node.get_method_name(),
Expand Down Expand Up @@ -373,13 +385,53 @@ def _get_or_compile(
self.dag_output_channels = self.dag_output_channels[0]

# Driver should ray.put on input, ray.get/release on output
return (self.dag_input_channel, self.dag_output_channels)
ericl marked this conversation as resolved.
Show resolved Hide resolved
self._monitor = self._monitor_failures()
return (self.dag_input_channel, self.dag_output_channels, self._monitor)

def _monitor_failures(self):
outer = self

class Monitor(threading.Thread):
def __init__(self):
super().__init__(daemon=True)
self.in_teardown = False

def teardown(self):
if self.in_teardown:
return
logger.info("Tearing down compiled DAG")
self.in_teardown = True
ericl marked this conversation as resolved.
Show resolved Hide resolved
for actor in outer.actor_refs:
logger.info(f"Cancelling compiled worker on actor: {actor}")
ericl marked this conversation as resolved.
Show resolved Hide resolved
try:
ray.get(actor.__ray_call__.remote(do_cancel_compiled_task))
except Exception as e:
logger.info(f"Error cancelling worker task: {e}")
pass
ericl marked this conversation as resolved.
Show resolved Hide resolved
logger.info("Teardown complete")

ericl marked this conversation as resolved.
Show resolved Hide resolved
def run(self):
try:
ray.get(outer.worker_task_refs)
except Exception as e:
if self.in_teardown:
return
if isinstance(outer.dag_output_channels, list):
for output_channel in outer.dag_output_channels:
output_channel.set_error(e)
else:
outer.dag_output_channels.set_error(e)
self.teardown()
ericl marked this conversation as resolved.
Show resolved Hide resolved

monitor = Monitor()
monitor.start()
return monitor

def execute(
self,
*args,
**kwargs,
) -> Union[ChannelType, List[ChannelType]]:
) -> Union[Channel, List[Channel]]:
"""Execute this DAG using the compiled execution path.

Args:
Expand All @@ -400,6 +452,10 @@ def execute(
input_channel.write(args[0])
return output_channels

def teardown(self):
"""Teardown and cancel all worker tasks for this DAG."""
self._monitor.teardown()


@DeveloperAPI
def build_compiled_dag_from_ray_dag(
Expand Down
54 changes: 48 additions & 6 deletions python/ray/dag/tests/test_accelerated_dag.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# coding: utf-8
import logging
import os
import random
import sys

import pytest
Expand All @@ -9,7 +10,6 @@
import ray.cluster_utils
from ray.dag import InputNode, MultiOutputNode
from ray.tests.conftest import * # noqa
from ray._private.test_utils import wait_for_condition


logger = logging.getLogger(__name__)
Expand All @@ -20,12 +20,17 @@

@ray.remote
class Actor:
def __init__(self, init_value):
def __init__(self, init_value, fail_after=None):
print("__init__ PID", os.getpid())
self.i = init_value
self.fail_after = fail_after

def inc(self, x):
self.i += x
if self.fail_after and self.i > self.fail_after:
# Randomize the failures to better cover multi actor scenarios.
if random.random() > 0.5:
raise ValueError("injected fault")
return self.i

def append_to(self, lst):
Expand All @@ -52,6 +57,10 @@ def test_basic(ray_start_regular):
assert result == i + 1
output_channel.end_read()

# Note: must teardown before starting a new Ray session, otherwise you'll get
# a segfault from the dangling monitor thread upon the new Ray init.
compiled_dag.teardown()
ericl marked this conversation as resolved.
Show resolved Hide resolved


def test_regular_args(ray_start_regular):
# Test passing regular args to .bind in addition to DAGNode args.
Expand All @@ -68,6 +77,8 @@ def test_regular_args(ray_start_regular):
assert result == (i + 1) * 3
output_channel.end_read()

compiled_dag.teardown()


@pytest.mark.parametrize("num_actors", [1, 4])
def test_scatter_gather_dag(ray_start_regular, num_actors):
Expand All @@ -86,6 +97,8 @@ def test_scatter_gather_dag(ray_start_regular, num_actors):
for chan in output_channels:
chan.end_read()

compiled_dag.teardown()


@pytest.mark.parametrize("num_actors", [1, 4])
def test_chain_dag(ray_start_regular, num_actors):
Expand All @@ -104,17 +117,20 @@ def test_chain_dag(ray_start_regular, num_actors):
assert result == list(range(num_actors))
output_channel.end_read()

compiled_dag.teardown()


def test_dag_exception(ray_start_regular, capsys):
a = Actor.remote(0)
with InputNode() as inp:
dag = a.inc.bind(inp)

compiled_dag = dag.experimental_compile()
compiled_dag.execute("hello")
wait_for_condition(
lambda: "Compiled DAG task aborted with exception" in capsys.readouterr().err
)
output_channel = compiled_dag.execute("hello")
with pytest.raises(TypeError):
output_channel.begin_read()

compiled_dag.teardown()
ericl marked this conversation as resolved.
Show resolved Hide resolved


def test_dag_errors(ray_start_regular):
Expand Down Expand Up @@ -176,6 +192,32 @@ def f(x):
dag.experimental_compile()


@pytest.mark.parametrize("num_actors", [1, 4])
ericl marked this conversation as resolved.
Show resolved Hide resolved
def test_dag_fault_tolerance(ray_start_regular, num_actors):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add a test for worker process dying?

actors = [Actor.remote(0, fail_after=100) for _ in range(num_actors)]
with InputNode() as i:
out = [a.inc.bind(i) for a in actors]
dag = MultiOutputNode(out)

compiled_dag = dag.experimental_compile()

for i in range(99):
output_channels = compiled_dag.execute(1)
# TODO(swang): Replace with fake ObjectRef.
results = [chan.begin_read() for chan in output_channels]
assert results == [i + 1] * num_actors
for chan in output_channels:
chan.end_read()

with pytest.raises(ValueError):
for i in range(99):
rkooo567 marked this conversation as resolved.
Show resolved Hide resolved
output_channels = compiled_dag.execute(1)
for chan in output_channels:
chan.begin_read()
for chan in output_channels:
chan.end_read()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: is end_read necessary to call btw?

Since now begin_read can raise an exception, it is easy to forget calling end_read.



if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
Expand Down
30 changes: 30 additions & 0 deletions python/ray/experimental/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Optional

import ray
from ray.exceptions import RaySystemError
from ray.util.annotations import PublicAPI

# Logger for this module. It should be configured at the entry point
Expand Down Expand Up @@ -159,3 +160,32 @@ def end_read(self):
self._worker.core_worker.experimental_mutable_object_read_release(
[self._base_ref]
)

def set_error(self, e: Exception) -> None:
"""
Shutdown the channel with the specified error object. New readers will see
the error raised when they try to read from the channel.

Does not block.
ericl marked this conversation as resolved.
Show resolved Hide resolved

Args:
e: The exception object to write to the channel.
"""
logger.debug(f"Writing error to channel: {self._base_ref}: {e}")
serialized_exc = self._worker.get_serialization_context().serialize(e)
try:
self._worker.core_worker.experimental_mutable_object_put_serialized(
serialized_exc,
self._base_ref,
num_readers=1,
try_wait=True,
ericl marked this conversation as resolved.
Show resolved Hide resolved
)
except Exception as e:
if not _is_write_acquire_failed_error(e):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I didn't quite understand this condition. It seems to fail silently if we fail to acquire?

Also, could you comment on what cases we expect to fail to acquire?

logger.exception("Error setting error on channel")
raise


def _is_write_acquire_failed_error(e: Exception) -> bool:
# TODO(ekl): detect the exception type better
return isinstance(e, RaySystemError)
Loading
Loading