-
Notifications
You must be signed in to change notification settings - Fork 5.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[core] Fault tolerance for compiled DAGs #41943
Changes from all commits
0780ed4
13bd97f
13e8bd7
b2e493b
94c5ba4
a6762a4
a15f318
6ca9b43
f353c76
069dc45
082c6b2
5fab84d
482296d
04ad6b4
de7156f
9923a3d
dd5558e
1e19048
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,22 @@ | ||
import logging | ||
from typing import Any, Dict, List, Tuple, Union, Optional | ||
|
||
from collections import defaultdict | ||
from typing import Any, Dict, List, Tuple, Union, Optional | ||
import logging | ||
import threading | ||
import traceback | ||
|
||
import ray | ||
import ray.experimental.channel as ray_channel | ||
from ray.exceptions import RayTaskError | ||
from ray.experimental.channel import Channel | ||
from ray.util.annotations import DeveloperAPI | ||
|
||
|
||
MAX_BUFFER_SIZE = int(100 * 1e6) # 100MB | ||
|
||
ChannelType = "ray.experimental.channel.Channel" | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@DeveloperAPI | ||
def do_allocate_channel( | ||
self, buffer_size_bytes: int, num_readers: int = 1 | ||
) -> ChannelType: | ||
def do_allocate_channel(self, buffer_size_bytes: int, num_readers: int = 1) -> Channel: | ||
"""Generic actor method to allocate an output channel. | ||
Args: | ||
|
@@ -28,14 +26,14 @@ def do_allocate_channel( | |
Returns: | ||
The allocated channel. | ||
""" | ||
self._output_channel = ray_channel.Channel(buffer_size_bytes, num_readers) | ||
self._output_channel = Channel(buffer_size_bytes, num_readers) | ||
return self._output_channel | ||
|
||
|
||
@DeveloperAPI | ||
def do_exec_compiled_task( | ||
self, | ||
inputs: List[Union[Any, ChannelType]], | ||
inputs: List[Union[Any, Channel]], | ||
actor_method_name: str, | ||
) -> None: | ||
"""Generic actor method to begin executing a compiled DAG. This runs an | ||
|
@@ -51,14 +49,17 @@ def do_exec_compiled_task( | |
actor_method_name: The name of the actual actor method to execute in | ||
the loop. | ||
""" | ||
self._dag_cancelled = False | ||
|
||
try: | ||
self._input_channels = [i for i in inputs if isinstance(i, Channel)] | ||
method = getattr(self, actor_method_name) | ||
|
||
resolved_inputs = [] | ||
input_channel_idxs = [] | ||
# Add placeholders for input channels. | ||
for idx, inp in enumerate(inputs): | ||
if isinstance(inp, ray_channel.Channel): | ||
if isinstance(inp, Channel): | ||
input_channel_idxs.append((idx, inp)) | ||
resolved_inputs.append(None) | ||
else: | ||
|
@@ -68,17 +69,42 @@ def do_exec_compiled_task( | |
for idx, channel in input_channel_idxs: | ||
resolved_inputs[idx] = channel.begin_read() | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be good to explicitly try-catch the channel calls so that we can differentiate between expected errors (channel closed), application code errors, and anything else that might error in this loop (most likely system bugs). The try-catch at the end can be for system errors only. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm I played around with this and deciding the semantics is tricky, so I think we should tackle this later on for productionization. |
||
output_val = method(*resolved_inputs) | ||
try: | ||
output_val = method(*resolved_inputs) | ||
except Exception as exc: | ||
backtrace = ray._private.utils.format_error_message( | ||
"".join( | ||
traceback.format_exception(type(exc), exc, exc.__traceback__) | ||
), | ||
task_exception=True, | ||
) | ||
wrapped = RayTaskError( | ||
function_name="do_exec_compiled_task", | ||
traceback_str=backtrace, | ||
cause=exc, | ||
) | ||
self._output_channel.write(wrapped) | ||
else: | ||
if self._dag_cancelled: | ||
raise RuntimeError("DAG execution cancelled") | ||
self._output_channel.write(output_val) | ||
|
||
self._output_channel.write(output_val) | ||
for _, channel in input_channel_idxs: | ||
channel.end_read() | ||
|
||
except Exception as e: | ||
logging.warn(f"Compiled DAG task aborted with exception: {e}") | ||
except Exception: | ||
logging.exception("Compiled DAG task exited with exception") | ||
raise | ||
|
||
|
||
@DeveloperAPI | ||
def do_cancel_compiled_task(self): | ||
self._dag_cancelled = True | ||
for channel in self._input_channels: | ||
channel.close() | ||
self._output_channel.close() | ||
|
||
|
||
@DeveloperAPI | ||
class CompiledTask: | ||
"""Wraps the normal Ray DAGNode with some metadata.""" | ||
|
@@ -147,11 +173,13 @@ def __init__(self, buffer_size_bytes: Optional[int]): | |
self.actor_task_count: Dict["ray._raylet.ActorID", int] = defaultdict(int) | ||
|
||
# Cached attributes that are set during compilation. | ||
self.dag_input_channel: Optional[ChannelType] = None | ||
self.dag_output_channels: Optional[ChannelType] = None | ||
self.dag_input_channel: Optional[Channel] = None | ||
self.dag_output_channels: Optional[Channel] = None | ||
# ObjectRef for each worker's task. The task is an infinite loop that | ||
# repeatedly executes the method specified in the DAG. | ||
self.worker_task_refs: List["ray.ObjectRef"] = [] | ||
# Set of actors present in the DAG. | ||
self.actor_refs = set() | ||
|
||
def _add_node(self, node: "ray.dag.DAGNode") -> None: | ||
idx = self.counter | ||
|
@@ -254,7 +282,7 @@ def _preprocess(self) -> None: | |
|
||
def _get_or_compile( | ||
self, | ||
) -> Tuple[ChannelType, Union[ChannelType, List[ChannelType]]]: | ||
) -> Tuple[Channel, Union[Channel, List[Channel]]]: | ||
"""Compile an execution path. This allocates channels for adjacent | ||
tasks to send/receive values. An infinite task is submitted to each | ||
actor in the DAG that repeatedly receives from input channel(s) and | ||
|
@@ -276,7 +304,6 @@ def _get_or_compile( | |
|
||
if self.dag_input_channel is not None: | ||
assert self.dag_output_channels is not None | ||
# Driver should ray.put on input, ray.get/release on output | ||
return ( | ||
self.dag_input_channel, | ||
self.dag_output_channels, | ||
|
@@ -303,8 +330,9 @@ def _get_or_compile( | |
num_readers=task.num_readers, | ||
) | ||
) | ||
self.actor_refs.add(task.dag_node._get_actor_handle()) | ||
elif isinstance(task.dag_node, InputNode): | ||
task.output_channel = ray_channel.Channel( | ||
task.output_channel = Channel( | ||
buffer_size_bytes=self._buffer_size_bytes, | ||
num_readers=task.num_readers, | ||
) | ||
|
@@ -345,7 +373,7 @@ def _get_or_compile( | |
# Assign the task with the correct input and output buffers. | ||
worker_fn = task.dag_node._get_remote_method("__ray_call__") | ||
self.worker_task_refs.append( | ||
worker_fn.remote( | ||
worker_fn.options(concurrency_group="_ray_system").remote( | ||
do_exec_compiled_task, | ||
resolved_args, | ||
task.dag_node.get_method_name(), | ||
|
@@ -373,13 +401,60 @@ def _get_or_compile( | |
self.dag_output_channels = self.dag_output_channels[0] | ||
|
||
# Driver should ray.put on input, ray.get/release on output | ||
return (self.dag_input_channel, self.dag_output_channels) | ||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._monitor = self._monitor_failures() | ||
return (self.dag_input_channel, self.dag_output_channels, self._monitor) | ||
|
||
def _monitor_failures(self): | ||
outer = self | ||
|
||
class Monitor(threading.Thread): | ||
def __init__(self): | ||
super().__init__(daemon=True) | ||
self.in_teardown = False | ||
|
||
def teardown(self): | ||
if self.in_teardown: | ||
return | ||
logger.info("Tearing down compiled DAG") | ||
self.in_teardown = True | ||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for actor in outer.actor_refs: | ||
logger.info(f"Cancelling compiled worker on actor: {actor}") | ||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
try: | ||
ray.get(actor.__ray_call__.remote(do_cancel_compiled_task)) | ||
except Exception: | ||
logger.exception("Error cancelling worker task") | ||
pass | ||
logger.info("Waiting for worker tasks to exit") | ||
for ref in outer.worker_task_refs: | ||
try: | ||
ray.get(ref) | ||
except Exception: | ||
pass | ||
logger.info("Teardown complete") | ||
|
||
ericl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def run(self): | ||
try: | ||
ray.get(outer.worker_task_refs) | ||
except Exception as e: | ||
logger.debug(f"Handling exception from worker tasks: {e}") | ||
if self.in_teardown: | ||
return | ||
if isinstance(outer.dag_output_channels, list): | ||
for output_channel in outer.dag_output_channels: | ||
output_channel.close() | ||
else: | ||
outer.dag_output_channels.close() | ||
self.teardown() | ||
|
||
monitor = Monitor() | ||
monitor.start() | ||
return monitor | ||
|
||
def execute( | ||
self, | ||
*args, | ||
**kwargs, | ||
) -> Union[ChannelType, List[ChannelType]]: | ||
) -> Union[Channel, List[Channel]]: | ||
"""Execute this DAG using the compiled execution path. | ||
Args: | ||
|
@@ -400,6 +475,10 @@ def execute( | |
input_channel.write(args[0]) | ||
return output_channels | ||
|
||
def teardown(self): | ||
"""Teardown and cancel all worker tasks for this DAG.""" | ||
self._monitor.teardown() | ||
|
||
|
||
@DeveloperAPI | ||
def build_compiled_dag_from_ray_dag( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copied from non-experimental path.