-
Notifications
You must be signed in to change notification settings - Fork 5.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[core] Fault tolerance for compiled DAGs #41943
Changes from 3 commits
0780ed4
13bd97f
13e8bd7
b2e493b
94c5ba4
a6762a4
a15f318
6ca9b43
f353c76
069dc45
082c6b2
5fab84d
482296d
04ad6b4
de7156f
9923a3d
dd5558e
1e19048
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,21 @@ | ||
import logging | ||
from typing import Any, Dict, List, Tuple, Union, Optional | ||
|
||
from collections import defaultdict | ||
from typing import Any, Dict, List, Tuple, Union, Optional | ||
import logging | ||
import threading | ||
|
||
import ray | ||
import ray.experimental.channel as ray_channel | ||
from ray.exceptions import RayTaskError, TaskCancelledError | ||
from ray.experimental.channel import Channel | ||
from ray.util.annotations import DeveloperAPI | ||
|
||
|
||
MAX_BUFFER_SIZE = int(100 * 1e6) # 100MB | ||
|
||
ChannelType = "ray.experimental.channel.Channel" | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@DeveloperAPI | ||
def do_allocate_channel( | ||
self, buffer_size_bytes: int, num_readers: int = 1 | ||
) -> ChannelType: | ||
def do_allocate_channel(self, buffer_size_bytes: int, num_readers: int = 1) -> Channel: | ||
"""Generic actor method to allocate an output channel. | ||
|
||
Args: | ||
|
@@ -28,14 +25,14 @@ def do_allocate_channel( | |
Returns: | ||
The allocated channel. | ||
""" | ||
self._output_channel = ray_channel.Channel(buffer_size_bytes, num_readers) | ||
self._output_channel = Channel(buffer_size_bytes, num_readers) | ||
return self._output_channel | ||
|
||
|
||
@DeveloperAPI | ||
def do_exec_compiled_task( | ||
self, | ||
inputs: List[Union[Any, ChannelType]], | ||
inputs: List[Union[Any, Channel]], | ||
actor_method_name: str, | ||
) -> None: | ||
"""Generic actor method to begin executing a compiled DAG. This runs an | ||
|
@@ -52,13 +49,14 @@ def do_exec_compiled_task( | |
the loop. | ||
""" | ||
try: | ||
self._input_channels = [i for i in inputs if isinstance(i, Channel)] | ||
method = getattr(self, actor_method_name) | ||
|
||
resolved_inputs = [] | ||
input_channel_idxs = [] | ||
# Add placeholders for input channels. | ||
for idx, inp in enumerate(inputs): | ||
if isinstance(inp, ray_channel.Channel): | ||
if isinstance(inp, Channel): | ||
input_channel_idxs.append((idx, inp)) | ||
resolved_inputs.append(None) | ||
else: | ||
|
@@ -75,10 +73,21 @@ def do_exec_compiled_task( | |
channel.end_read() | ||
|
||
except Exception as e: | ||
logging.warn(f"Compiled DAG task aborted with exception: {e}") | ||
logging.info(f"Compiled DAG task exited with exception: {e}") | ||
raise | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For non-Ray exceptions, I wonder if we should instead store the error and keep looping? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 |
||
|
||
|
||
@DeveloperAPI | ||
def do_cancel_compiled_task(self): | ||
e = RayTaskError( | ||
function_name="do_exec_compiled_task", | ||
traceback_str="", | ||
cause=TaskCancelledError(), | ||
) | ||
for channel in self._input_channels: | ||
channel.set_error(e) | ||
|
||
|
||
@DeveloperAPI | ||
class CompiledTask: | ||
"""Wraps the normal Ray DAGNode with some metadata.""" | ||
|
@@ -147,11 +156,13 @@ def __init__(self, buffer_size_bytes: Optional[int]): | |
self.actor_task_count: Dict["ray._raylet.ActorID", int] = defaultdict(int) | ||
|
||
# Cached attributes that are set during compilation. | ||
self.dag_input_channel: Optional[ChannelType] = None | ||
self.dag_output_channels: Optional[ChannelType] = None | ||
self.dag_input_channel: Optional[Channel] = None | ||
self.dag_output_channels: Optional[Channel] = None | ||
# ObjectRef for each worker's task. The task is an infinite loop that | ||
# repeatedly executes the method specified in the DAG. | ||
self.worker_task_refs: List["ray.ObjectRef"] = [] | ||
# Set of actors present in the DAG. | ||
self.actor_refs = set() | ||
|
||
def _add_node(self, node: "ray.dag.DAGNode") -> None: | ||
idx = self.counter | ||
|
@@ -254,7 +265,7 @@ def _preprocess(self) -> None: | |
|
||
def _get_or_compile( | ||
self, | ||
) -> Tuple[ChannelType, Union[ChannelType, List[ChannelType]]]: | ||
) -> Tuple[Channel, Union[Channel, List[Channel]]]: | ||
"""Compile an execution path. This allocates channels for adjacent | ||
tasks to send/receive values. An infinite task is submitted to each | ||
actor in the DAG that repeatedly receives from input channel(s) and | ||
|
@@ -303,8 +314,9 @@ def _get_or_compile( | |
num_readers=task.num_readers, | ||
) | ||
) | ||
self.actor_refs.add(task.dag_node._get_actor_handle()) | ||
elif isinstance(task.dag_node, InputNode): | ||
task.output_channel = ray_channel.Channel( | ||
task.output_channel = Channel( | ||
buffer_size_bytes=self._buffer_size_bytes, | ||
num_readers=task.num_readers, | ||
) | ||
|
@@ -345,7 +357,7 @@ def _get_or_compile( | |
# Assign the task with the correct input and output buffers. | ||
worker_fn = task.dag_node._get_remote_method("__ray_call__") | ||
self.worker_task_refs.append( | ||
worker_fn.remote( | ||
worker_fn.options(concurrency_group="_ray_system").remote( | ||
do_exec_compiled_task, | ||
resolved_args, | ||
task.dag_node.get_method_name(), | ||
|
@@ -373,13 +385,53 @@ def _get_or_compile( | |
self.dag_output_channels = self.dag_output_channels[0] | ||
|
||
# Driver should ray.put on input, ray.get/release on output | ||
return (self.dag_input_channel, self.dag_output_channels) | ||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._monitor = self._monitor_failures() | ||
return (self.dag_input_channel, self.dag_output_channels, self._monitor) | ||
|
||
def _monitor_failures(self): | ||
outer = self | ||
|
||
class Monitor(threading.Thread): | ||
def __init__(self): | ||
super().__init__(daemon=True) | ||
self.in_teardown = False | ||
|
||
def teardown(self): | ||
if self.in_teardown: | ||
return | ||
logger.info("Tearing down compiled DAG") | ||
self.in_teardown = True | ||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for actor in outer.actor_refs: | ||
logger.info(f"Cancelling compiled worker on actor: {actor}") | ||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
try: | ||
ray.get(actor.__ray_call__.remote(do_cancel_compiled_task)) | ||
except Exception as e: | ||
logger.info(f"Error cancelling worker task: {e}") | ||
pass | ||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
logger.info("Teardown complete") | ||
|
||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def run(self): | ||
try: | ||
ray.get(outer.worker_task_refs) | ||
except Exception as e: | ||
if self.in_teardown: | ||
return | ||
if isinstance(outer.dag_output_channels, list): | ||
for output_channel in outer.dag_output_channels: | ||
output_channel.set_error(e) | ||
else: | ||
outer.dag_output_channels.set_error(e) | ||
self.teardown() | ||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
monitor = Monitor() | ||
monitor.start() | ||
return monitor | ||
|
||
def execute( | ||
self, | ||
*args, | ||
**kwargs, | ||
) -> Union[ChannelType, List[ChannelType]]: | ||
) -> Union[Channel, List[Channel]]: | ||
"""Execute this DAG using the compiled execution path. | ||
|
||
Args: | ||
|
@@ -400,6 +452,10 @@ def execute( | |
input_channel.write(args[0]) | ||
return output_channels | ||
|
||
def teardown(self): | ||
"""Teardown and cancel all worker tasks for this DAG.""" | ||
self._monitor.teardown() | ||
|
||
|
||
@DeveloperAPI | ||
def build_compiled_dag_from_ray_dag( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
# coding: utf-8 | ||
import logging | ||
import os | ||
import random | ||
import sys | ||
|
||
import pytest | ||
|
@@ -9,7 +10,6 @@ | |
import ray.cluster_utils | ||
from ray.dag import InputNode, MultiOutputNode | ||
from ray.tests.conftest import * # noqa | ||
from ray._private.test_utils import wait_for_condition | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
@@ -20,12 +20,17 @@ | |
|
||
@ray.remote | ||
class Actor: | ||
def __init__(self, init_value): | ||
def __init__(self, init_value, fail_after=None): | ||
print("__init__ PID", os.getpid()) | ||
self.i = init_value | ||
self.fail_after = fail_after | ||
|
||
def inc(self, x): | ||
self.i += x | ||
if self.fail_after and self.i > self.fail_after: | ||
# Randomize the failures to better cover multi actor scenarios. | ||
if random.random() > 0.5: | ||
raise ValueError("injected fault") | ||
return self.i | ||
|
||
def append_to(self, lst): | ||
|
@@ -52,6 +57,10 @@ def test_basic(ray_start_regular): | |
assert result == i + 1 | ||
output_channel.end_read() | ||
|
||
# Note: must teardown before starting a new Ray session, otherwise you'll get | ||
# a segfault from the dangling monitor thread upon the new Ray init. | ||
compiled_dag.teardown() | ||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def test_regular_args(ray_start_regular): | ||
# Test passing regular args to .bind in addition to DAGNode args. | ||
|
@@ -68,6 +77,8 @@ def test_regular_args(ray_start_regular): | |
assert result == (i + 1) * 3 | ||
output_channel.end_read() | ||
|
||
compiled_dag.teardown() | ||
|
||
|
||
@pytest.mark.parametrize("num_actors", [1, 4]) | ||
def test_scatter_gather_dag(ray_start_regular, num_actors): | ||
|
@@ -86,6 +97,8 @@ def test_scatter_gather_dag(ray_start_regular, num_actors): | |
for chan in output_channels: | ||
chan.end_read() | ||
|
||
compiled_dag.teardown() | ||
|
||
|
||
@pytest.mark.parametrize("num_actors", [1, 4]) | ||
def test_chain_dag(ray_start_regular, num_actors): | ||
|
@@ -104,17 +117,20 @@ def test_chain_dag(ray_start_regular, num_actors): | |
assert result == list(range(num_actors)) | ||
output_channel.end_read() | ||
|
||
compiled_dag.teardown() | ||
|
||
|
||
def test_dag_exception(ray_start_regular, capsys): | ||
a = Actor.remote(0) | ||
with InputNode() as inp: | ||
dag = a.inc.bind(inp) | ||
|
||
compiled_dag = dag.experimental_compile() | ||
compiled_dag.execute("hello") | ||
wait_for_condition( | ||
lambda: "Compiled DAG task aborted with exception" in capsys.readouterr().err | ||
) | ||
output_channel = compiled_dag.execute("hello") | ||
with pytest.raises(TypeError): | ||
output_channel.begin_read() | ||
|
||
compiled_dag.teardown() | ||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def test_dag_errors(ray_start_regular): | ||
|
@@ -176,6 +192,32 @@ def f(x): | |
dag.experimental_compile() | ||
|
||
|
||
@pytest.mark.parametrize("num_actors", [1, 4]) | ||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def test_dag_fault_tolerance(ray_start_regular, num_actors): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you also add a test for worker process dying? |
||
actors = [Actor.remote(0, fail_after=100) for _ in range(num_actors)] | ||
with InputNode() as i: | ||
out = [a.inc.bind(i) for a in actors] | ||
dag = MultiOutputNode(out) | ||
|
||
compiled_dag = dag.experimental_compile() | ||
|
||
for i in range(99): | ||
output_channels = compiled_dag.execute(1) | ||
# TODO(swang): Replace with fake ObjectRef. | ||
results = [chan.begin_read() for chan in output_channels] | ||
assert results == [i + 1] * num_actors | ||
for chan in output_channels: | ||
chan.end_read() | ||
|
||
with pytest.raises(ValueError): | ||
for i in range(99): | ||
rkooo567 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
output_channels = compiled_dag.execute(1) | ||
for chan in output_channels: | ||
chan.begin_read() | ||
for chan in output_channels: | ||
chan.end_read() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Q: is end_read necessary to call btw? Since now begin_read can raise an exception, it is easy to forget calling end_read. |
||
|
||
|
||
if __name__ == "__main__": | ||
if os.environ.get("PARALLEL_CI"): | ||
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
from typing import Any, Optional | ||
|
||
import ray | ||
from ray.exceptions import RaySystemError | ||
from ray.util.annotations import PublicAPI | ||
|
||
# Logger for this module. It should be configured at the entry point | ||
|
@@ -159,3 +160,32 @@ def end_read(self): | |
self._worker.core_worker.experimental_mutable_object_read_release( | ||
[self._base_ref] | ||
) | ||
|
||
def set_error(self, e: Exception) -> None: | ||
""" | ||
Shutdown the channel with the specified error object. New readers will see | ||
the error raised when they try to read from the channel. | ||
|
||
Does not block. | ||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Args: | ||
e: The exception object to write to the channel. | ||
""" | ||
logger.debug(f"Writing error to channel: {self._base_ref}: {e}") | ||
serialized_exc = self._worker.get_serialization_context().serialize(e) | ||
try: | ||
self._worker.core_worker.experimental_mutable_object_put_serialized( | ||
serialized_exc, | ||
self._base_ref, | ||
num_readers=1, | ||
try_wait=True, | ||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
except Exception as e: | ||
if not _is_write_acquire_failed_error(e): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm I didn't quite understand this condition. It seems to fail silently if we fail to acquire? Also, could you comment on what cases we expect to fail to acquire? |
||
logger.exception("Error setting error on channel") | ||
raise | ||
|
||
|
||
def _is_write_acquire_failed_error(e: Exception) -> bool: | ||
# TODO(ekl): detect the exception type better | ||
return isinstance(e, RaySystemError) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copied from non-experimental path.