Skip to content

Commit

Permalink
[core] Fault tolerance for compiled DAGs (ray-project#41943)
Browse files Browse the repository at this point in the history
This adds fault tolerance and a teardown method for compiled DAGs.
  • Loading branch information
ericl authored and vickytsang committed Jan 12, 2024
1 parent 170aabb commit 9969eaf
Show file tree
Hide file tree
Showing 15 changed files with 393 additions and 112 deletions.
14 changes: 10 additions & 4 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,10 +860,16 @@ def get_objects(
debugger_breakpoint = metadata_fields[1][
len(ray_constants.OBJECT_METADATA_DEBUG_PREFIX) :
]
return (
self.deserialize_objects(data_metadata_pairs, object_refs),
debugger_breakpoint,
)
values = self.deserialize_objects(data_metadata_pairs, object_refs)
for i, value in enumerate(values):
if isinstance(value, RayError):
if isinstance(value, ray.exceptions.ObjectLostError):
global_worker.core_worker.dump_object_store_memory_usage()
if isinstance(value, RayTaskError):
raise value.as_instanceof_cause()
else:
raise value
return values, debugger_breakpoint

def main_loop(self):
"""The main loop a worker runs to receive and execute tasks."""
Expand Down
9 changes: 8 additions & 1 deletion python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3517,7 +3517,7 @@ cdef class CoreWorker:

def experimental_mutable_object_put_serialized(self, serialized_object,
ObjectRef object_ref,
num_readers,
num_readers
):
cdef:
CObjectID c_object_id = object_ref.native()
Expand All @@ -3542,6 +3542,13 @@ cdef class CoreWorker:
c_object_id,
))

def experimental_mutable_object_set_error(self, ObjectRef object_ref):
cdef:
CObjectID c_object_id = object_ref.native()

check_status(CCoreWorkerProcess.GetCoreWorker()
.ExperimentalMutableObjectSetError(c_object_id))

def experimental_mutable_object_read_release(self, object_refs):
"""
For experimental.channel.Channel.
Expand Down
127 changes: 103 additions & 24 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
import logging
from typing import Any, Dict, List, Tuple, Union, Optional

from collections import defaultdict
from typing import Any, Dict, List, Tuple, Union, Optional
import logging
import threading
import traceback

import ray
import ray.experimental.channel as ray_channel
from ray.exceptions import RayTaskError
from ray.experimental.channel import Channel
from ray.util.annotations import DeveloperAPI


MAX_BUFFER_SIZE = int(100 * 1e6) # 100MB

ChannelType = "ray.experimental.channel.Channel"

logger = logging.getLogger(__name__)


@DeveloperAPI
def do_allocate_channel(
self, buffer_size_bytes: int, num_readers: int = 1
) -> ChannelType:
def do_allocate_channel(self, buffer_size_bytes: int, num_readers: int = 1) -> Channel:
"""Generic actor method to allocate an output channel.
Args:
Expand All @@ -28,14 +26,14 @@ def do_allocate_channel(
Returns:
The allocated channel.
"""
self._output_channel = ray_channel.Channel(buffer_size_bytes, num_readers)
self._output_channel = Channel(buffer_size_bytes, num_readers)
return self._output_channel


@DeveloperAPI
def do_exec_compiled_task(
self,
inputs: List[Union[Any, ChannelType]],
inputs: List[Union[Any, Channel]],
actor_method_name: str,
) -> None:
"""Generic actor method to begin executing a compiled DAG. This runs an
Expand All @@ -51,14 +49,17 @@ def do_exec_compiled_task(
actor_method_name: The name of the actual actor method to execute in
the loop.
"""
self._dag_cancelled = False

try:
self._input_channels = [i for i in inputs if isinstance(i, Channel)]
method = getattr(self, actor_method_name)

resolved_inputs = []
input_channel_idxs = []
# Add placeholders for input channels.
for idx, inp in enumerate(inputs):
if isinstance(inp, ray_channel.Channel):
if isinstance(inp, Channel):
input_channel_idxs.append((idx, inp))
resolved_inputs.append(None)
else:
Expand All @@ -68,17 +69,42 @@ def do_exec_compiled_task(
for idx, channel in input_channel_idxs:
resolved_inputs[idx] = channel.begin_read()

output_val = method(*resolved_inputs)
try:
output_val = method(*resolved_inputs)
except Exception as exc:
backtrace = ray._private.utils.format_error_message(
"".join(
traceback.format_exception(type(exc), exc, exc.__traceback__)
),
task_exception=True,
)
wrapped = RayTaskError(
function_name="do_exec_compiled_task",
traceback_str=backtrace,
cause=exc,
)
self._output_channel.write(wrapped)
else:
if self._dag_cancelled:
raise RuntimeError("DAG execution cancelled")
self._output_channel.write(output_val)

self._output_channel.write(output_val)
for _, channel in input_channel_idxs:
channel.end_read()

except Exception as e:
logging.warn(f"Compiled DAG task aborted with exception: {e}")
except Exception:
logging.exception("Compiled DAG task exited with exception")
raise


@DeveloperAPI
def do_cancel_compiled_task(self):
self._dag_cancelled = True
for channel in self._input_channels:
channel.close()
self._output_channel.close()


@DeveloperAPI
class CompiledTask:
"""Wraps the normal Ray DAGNode with some metadata."""
Expand Down Expand Up @@ -147,11 +173,13 @@ def __init__(self, buffer_size_bytes: Optional[int]):
self.actor_task_count: Dict["ray._raylet.ActorID", int] = defaultdict(int)

# Cached attributes that are set during compilation.
self.dag_input_channel: Optional[ChannelType] = None
self.dag_output_channels: Optional[ChannelType] = None
self.dag_input_channel: Optional[Channel] = None
self.dag_output_channels: Optional[Channel] = None
# ObjectRef for each worker's task. The task is an infinite loop that
# repeatedly executes the method specified in the DAG.
self.worker_task_refs: List["ray.ObjectRef"] = []
# Set of actors present in the DAG.
self.actor_refs = set()

def _add_node(self, node: "ray.dag.DAGNode") -> None:
idx = self.counter
Expand Down Expand Up @@ -254,7 +282,7 @@ def _preprocess(self) -> None:

def _get_or_compile(
self,
) -> Tuple[ChannelType, Union[ChannelType, List[ChannelType]]]:
) -> Tuple[Channel, Union[Channel, List[Channel]]]:
"""Compile an execution path. This allocates channels for adjacent
tasks to send/receive values. An infinite task is submitted to each
actor in the DAG that repeatedly receives from input channel(s) and
Expand All @@ -276,7 +304,6 @@ def _get_or_compile(

if self.dag_input_channel is not None:
assert self.dag_output_channels is not None
# Driver should ray.put on input, ray.get/release on output
return (
self.dag_input_channel,
self.dag_output_channels,
Expand All @@ -303,8 +330,9 @@ def _get_or_compile(
num_readers=task.num_readers,
)
)
self.actor_refs.add(task.dag_node._get_actor_handle())
elif isinstance(task.dag_node, InputNode):
task.output_channel = ray_channel.Channel(
task.output_channel = Channel(
buffer_size_bytes=self._buffer_size_bytes,
num_readers=task.num_readers,
)
Expand Down Expand Up @@ -345,7 +373,7 @@ def _get_or_compile(
# Assign the task with the correct input and output buffers.
worker_fn = task.dag_node._get_remote_method("__ray_call__")
self.worker_task_refs.append(
worker_fn.remote(
worker_fn.options(concurrency_group="_ray_system").remote(
do_exec_compiled_task,
resolved_args,
task.dag_node.get_method_name(),
Expand Down Expand Up @@ -373,13 +401,60 @@ def _get_or_compile(
self.dag_output_channels = self.dag_output_channels[0]

# Driver should ray.put on input, ray.get/release on output
return (self.dag_input_channel, self.dag_output_channels)
self._monitor = self._monitor_failures()
return (self.dag_input_channel, self.dag_output_channels, self._monitor)

def _monitor_failures(self):
outer = self

class Monitor(threading.Thread):
def __init__(self):
super().__init__(daemon=True)
self.in_teardown = False

def teardown(self):
if self.in_teardown:
return
logger.info("Tearing down compiled DAG")
self.in_teardown = True
for actor in outer.actor_refs:
logger.info(f"Cancelling compiled worker on actor: {actor}")
try:
ray.get(actor.__ray_call__.remote(do_cancel_compiled_task))
except Exception:
logger.exception("Error cancelling worker task")
pass
logger.info("Waiting for worker tasks to exit")
for ref in outer.worker_task_refs:
try:
ray.get(ref)
except Exception:
pass
logger.info("Teardown complete")

def run(self):
try:
ray.get(outer.worker_task_refs)
except Exception as e:
logger.debug(f"Handling exception from worker tasks: {e}")
if self.in_teardown:
return
if isinstance(outer.dag_output_channels, list):
for output_channel in outer.dag_output_channels:
output_channel.close()
else:
outer.dag_output_channels.close()
self.teardown()

monitor = Monitor()
monitor.start()
return monitor

def execute(
self,
*args,
**kwargs,
) -> Union[ChannelType, List[ChannelType]]:
) -> Union[Channel, List[Channel]]:
"""Execute this DAG using the compiled execution path.
Args:
Expand All @@ -400,6 +475,10 @@ def execute(
input_channel.write(args[0])
return output_channels

def teardown(self):
"""Teardown and cancel all worker tasks for this DAG."""
self._monitor.teardown()


@DeveloperAPI
def build_compiled_dag_from_ray_dag(
Expand Down
Loading

0 comments on commit 9969eaf

Please sign in to comment.