Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Fault tolerance for compiled DAGs #41943

Merged
merged 18 commits into from
Dec 21, 2023
Merged
14 changes: 10 additions & 4 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,10 +860,16 @@ def get_objects(
debugger_breakpoint = metadata_fields[1][
len(ray_constants.OBJECT_METADATA_DEBUG_PREFIX) :
]
return (
self.deserialize_objects(data_metadata_pairs, object_refs),
debugger_breakpoint,
)
values = self.deserialize_objects(data_metadata_pairs, object_refs)
for i, value in enumerate(values):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copied from non-experimental path.

if isinstance(value, RayError):
if isinstance(value, ray.exceptions.ObjectLostError):
global_worker.core_worker.dump_object_store_memory_usage()
if isinstance(value, RayTaskError):
raise value.as_instanceof_cause()
else:
raise value
return values, debugger_breakpoint

def main_loop(self):
"""The main loop a worker runs to receive and execute tasks."""
Expand Down
9 changes: 8 additions & 1 deletion python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3517,7 +3517,7 @@ cdef class CoreWorker:

def experimental_mutable_object_put_serialized(self, serialized_object,
ObjectRef object_ref,
num_readers,
num_readers
):
cdef:
CObjectID c_object_id = object_ref.native()
Expand All @@ -3542,6 +3542,13 @@ cdef class CoreWorker:
c_object_id,
))

def experimental_mutable_object_set_error(self, ObjectRef object_ref):
cdef:
CObjectID c_object_id = object_ref.native()

check_status(CCoreWorkerProcess.GetCoreWorker()
.ExperimentalMutableObjectSetError(c_object_id))

def experimental_mutable_object_read_release(self, object_refs):
"""
For experimental.channel.Channel.
Expand Down
127 changes: 103 additions & 24 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
import logging
from typing import Any, Dict, List, Tuple, Union, Optional

from collections import defaultdict
from typing import Any, Dict, List, Tuple, Union, Optional
import logging
import threading
import traceback

import ray
import ray.experimental.channel as ray_channel
from ray.exceptions import RayTaskError
from ray.experimental.channel import Channel
from ray.util.annotations import DeveloperAPI


MAX_BUFFER_SIZE = int(100 * 1e6) # 100MB

ChannelType = "ray.experimental.channel.Channel"

logger = logging.getLogger(__name__)


@DeveloperAPI
def do_allocate_channel(
self, buffer_size_bytes: int, num_readers: int = 1
) -> ChannelType:
def do_allocate_channel(self, buffer_size_bytes: int, num_readers: int = 1) -> Channel:
"""Generic actor method to allocate an output channel.
Args:
Expand All @@ -28,14 +26,14 @@ def do_allocate_channel(
Returns:
The allocated channel.
"""
self._output_channel = ray_channel.Channel(buffer_size_bytes, num_readers)
self._output_channel = Channel(buffer_size_bytes, num_readers)
return self._output_channel


@DeveloperAPI
def do_exec_compiled_task(
self,
inputs: List[Union[Any, ChannelType]],
inputs: List[Union[Any, Channel]],
actor_method_name: str,
) -> None:
"""Generic actor method to begin executing a compiled DAG. This runs an
Expand All @@ -51,14 +49,17 @@ def do_exec_compiled_task(
actor_method_name: The name of the actual actor method to execute in
the loop.
"""
self._dag_cancelled = False

try:
self._input_channels = [i for i in inputs if isinstance(i, Channel)]
method = getattr(self, actor_method_name)

resolved_inputs = []
input_channel_idxs = []
# Add placeholders for input channels.
for idx, inp in enumerate(inputs):
if isinstance(inp, ray_channel.Channel):
if isinstance(inp, Channel):
input_channel_idxs.append((idx, inp))
resolved_inputs.append(None)
else:
Expand All @@ -68,17 +69,42 @@ def do_exec_compiled_task(
for idx, channel in input_channel_idxs:
resolved_inputs[idx] = channel.begin_read()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be good to explicitly try-catch the channel calls so that we can differentiate between expected errors (channel closed), application code errors, and anything else that might error in this loop (most likely system bugs). The try-catch at the end can be for system errors only.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I played around with this and deciding the semantics is tricky, so I think we should tackle this later on for productionization.

output_val = method(*resolved_inputs)
try:
output_val = method(*resolved_inputs)
except Exception as exc:
backtrace = ray._private.utils.format_error_message(
"".join(
traceback.format_exception(type(exc), exc, exc.__traceback__)
),
task_exception=True,
)
wrapped = RayTaskError(
function_name="do_exec_compiled_task",
traceback_str=backtrace,
cause=exc,
)
self._output_channel.write(wrapped)
else:
if self._dag_cancelled:
raise RuntimeError("DAG execution cancelled")
self._output_channel.write(output_val)

self._output_channel.write(output_val)
for _, channel in input_channel_idxs:
channel.end_read()

except Exception as e:
logging.warn(f"Compiled DAG task aborted with exception: {e}")
except Exception:
logging.exception("Compiled DAG task exited with exception")
raise


@DeveloperAPI
def do_cancel_compiled_task(self):
self._dag_cancelled = True
for channel in self._input_channels:
channel.close()
self._output_channel.close()


@DeveloperAPI
class CompiledTask:
"""Wraps the normal Ray DAGNode with some metadata."""
Expand Down Expand Up @@ -147,11 +173,13 @@ def __init__(self, buffer_size_bytes: Optional[int]):
self.actor_task_count: Dict["ray._raylet.ActorID", int] = defaultdict(int)

# Cached attributes that are set during compilation.
self.dag_input_channel: Optional[ChannelType] = None
self.dag_output_channels: Optional[ChannelType] = None
self.dag_input_channel: Optional[Channel] = None
self.dag_output_channels: Optional[Channel] = None
# ObjectRef for each worker's task. The task is an infinite loop that
# repeatedly executes the method specified in the DAG.
self.worker_task_refs: List["ray.ObjectRef"] = []
# Set of actors present in the DAG.
self.actor_refs = set()

def _add_node(self, node: "ray.dag.DAGNode") -> None:
idx = self.counter
Expand Down Expand Up @@ -254,7 +282,7 @@ def _preprocess(self) -> None:

def _get_or_compile(
self,
) -> Tuple[ChannelType, Union[ChannelType, List[ChannelType]]]:
) -> Tuple[Channel, Union[Channel, List[Channel]]]:
"""Compile an execution path. This allocates channels for adjacent
tasks to send/receive values. An infinite task is submitted to each
actor in the DAG that repeatedly receives from input channel(s) and
Expand All @@ -276,7 +304,6 @@ def _get_or_compile(

if self.dag_input_channel is not None:
assert self.dag_output_channels is not None
# Driver should ray.put on input, ray.get/release on output
return (
self.dag_input_channel,
self.dag_output_channels,
Expand All @@ -303,8 +330,9 @@ def _get_or_compile(
num_readers=task.num_readers,
)
)
self.actor_refs.add(task.dag_node._get_actor_handle())
elif isinstance(task.dag_node, InputNode):
task.output_channel = ray_channel.Channel(
task.output_channel = Channel(
buffer_size_bytes=self._buffer_size_bytes,
num_readers=task.num_readers,
)
Expand Down Expand Up @@ -345,7 +373,7 @@ def _get_or_compile(
# Assign the task with the correct input and output buffers.
worker_fn = task.dag_node._get_remote_method("__ray_call__")
self.worker_task_refs.append(
worker_fn.remote(
worker_fn.options(concurrency_group="_ray_system").remote(
do_exec_compiled_task,
resolved_args,
task.dag_node.get_method_name(),
Expand Down Expand Up @@ -373,13 +401,60 @@ def _get_or_compile(
self.dag_output_channels = self.dag_output_channels[0]

# Driver should ray.put on input, ray.get/release on output
return (self.dag_input_channel, self.dag_output_channels)
ericl marked this conversation as resolved.
Show resolved Hide resolved
self._monitor = self._monitor_failures()
return (self.dag_input_channel, self.dag_output_channels, self._monitor)

def _monitor_failures(self):
outer = self

class Monitor(threading.Thread):
def __init__(self):
super().__init__(daemon=True)
self.in_teardown = False

def teardown(self):
if self.in_teardown:
return
logger.info("Tearing down compiled DAG")
self.in_teardown = True
ericl marked this conversation as resolved.
Show resolved Hide resolved
for actor in outer.actor_refs:
logger.info(f"Cancelling compiled worker on actor: {actor}")
ericl marked this conversation as resolved.
Show resolved Hide resolved
try:
ray.get(actor.__ray_call__.remote(do_cancel_compiled_task))
except Exception:
logger.exception("Error cancelling worker task")
pass
logger.info("Waiting for worker tasks to exit")
for ref in outer.worker_task_refs:
try:
ray.get(ref)
except Exception:
pass
logger.info("Teardown complete")

ericl marked this conversation as resolved.
Show resolved Hide resolved
def run(self):
try:
ray.get(outer.worker_task_refs)
except Exception as e:
logger.debug(f"Handling exception from worker tasks: {e}")
if self.in_teardown:
return
if isinstance(outer.dag_output_channels, list):
for output_channel in outer.dag_output_channels:
output_channel.close()
else:
outer.dag_output_channels.close()
self.teardown()

monitor = Monitor()
monitor.start()
return monitor

def execute(
self,
*args,
**kwargs,
) -> Union[ChannelType, List[ChannelType]]:
) -> Union[Channel, List[Channel]]:
"""Execute this DAG using the compiled execution path.
Args:
Expand All @@ -400,6 +475,10 @@ def execute(
input_channel.write(args[0])
return output_channels

def teardown(self):
"""Teardown and cancel all worker tasks for this DAG."""
self._monitor.teardown()


@DeveloperAPI
def build_compiled_dag_from_ray_dag(
Expand Down
Loading
Loading