Skip to content

Commit

Permalink
handle cancellation of ongoing dags
Browse files Browse the repository at this point in the history
Signed-off-by: Eric Liang <ekhliang@gmail.com>
  • Loading branch information
ericl committed Dec 15, 2023
1 parent b2e493b commit 94c5ba4
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 8 deletions.
2 changes: 1 addition & 1 deletion python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3518,7 +3518,7 @@ cdef class CoreWorker:
def experimental_mutable_object_put_serialized(self, serialized_object,
ObjectRef object_ref,
num_readers,
try_wait: bool=False
try_wait: bool = False
):
cdef:
CObjectID c_object_id = object_ref.native()
Expand Down
11 changes: 11 additions & 0 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def do_exec_compiled_task(
actor_method_name: The name of the actual actor method to execute in
the loop.
"""
self._dag_cancelled = False

try:
self._input_channels = [i for i in inputs if isinstance(i, Channel)]
method = getattr(self, actor_method_name)
Expand All @@ -68,6 +70,8 @@ def do_exec_compiled_task(

output_val = method(*resolved_inputs)

if self._dag_cancelled:
raise RuntimeError("DAG execution cancelled")
self._output_channel.write(output_val)
for _, channel in input_channel_idxs:
channel.end_read()
Expand All @@ -79,6 +83,7 @@ def do_exec_compiled_task(

@DeveloperAPI
def do_cancel_compiled_task(self):
self._dag_cancelled = True
e = RayTaskError(
function_name="do_exec_compiled_task",
traceback_str="",
Expand Down Expand Up @@ -408,6 +413,12 @@ def teardown(self):
except Exception:
logger.exception("Error cancelling worker task")
pass
logger.info("Waiting for worker tasks to exit")
for ref in outer.worker_task_refs:
try:
ray.get(ref)
except Exception:
pass
logger.info("Teardown complete")

def run(self):
Expand Down
66 changes: 63 additions & 3 deletions python/ray/dag/tests/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import os
import random
import sys
import time

import pytest

import ray
import ray.cluster_utils
from ray.exceptions import RayActorError
from ray.dag import InputNode, MultiOutputNode
from ray.tests.conftest import * # noqa

Expand All @@ -20,17 +22,21 @@

@ray.remote
class Actor:
def __init__(self, init_value, fail_after=None):
def __init__(self, init_value, fail_after=None, sys_exit=False):
print("__init__ PID", os.getpid())
self.i = init_value
self.fail_after = fail_after
self.sys_exit = sys_exit

def inc(self, x):
self.i += x
if self.fail_after and self.i > self.fail_after:
# Randomize the failures to better cover multi actor scenarios.
if random.random() > 0.5:
raise ValueError("injected fault")
if self.sys_exit:
os._exit(1)
else:
raise ValueError("injected fault")
return self.i

def append_to(self, lst):
Expand All @@ -42,6 +48,10 @@ def inc_two(self, x, y):
self.i += y
return self.i

def sleep(self, x):
time.sleep(x)
return x


def test_basic(ray_start_regular):
a = Actor.remote(0)
Expand Down Expand Up @@ -194,7 +204,9 @@ def f(x):

@pytest.mark.parametrize("num_actors", [1, 4])
def test_dag_fault_tolerance(ray_start_regular, num_actors):
actors = [Actor.remote(0, fail_after=100) for _ in range(num_actors)]
actors = [
Actor.remote(0, fail_after=100, sys_exit=False) for _ in range(num_actors)
]
with InputNode() as i:
out = [a.inc.bind(i) for a in actors]
dag = MultiOutputNode(out)
Expand All @@ -218,6 +230,54 @@ def test_dag_fault_tolerance(ray_start_regular, num_actors):
chan.end_read()


def test_dag_fault_tolerance_sys_exit(ray_start_regular):
actors = [Actor.remote(0, fail_after=100, sys_exit=True) for _ in range(1)]
with InputNode() as i:
out = [a.inc.bind(i) for a in actors]
dag = MultiOutputNode(out)

compiled_dag = dag.experimental_compile()

for i in range(99):
output_channels = compiled_dag.execute(1)
# TODO(swang): Replace with fake ObjectRef.
results = [chan.begin_read() for chan in output_channels]
assert results == [i + 1]
for chan in output_channels:
chan.end_read()

with pytest.raises(RayActorError):
for i in range(99):
output_channels = compiled_dag.execute(1)
for chan in output_channels:
chan.begin_read()
for chan in output_channels:
chan.end_read()


def test_dag_teardown_while_running(ray_start_regular):
a = Actor.remote(0)

with InputNode() as inp:
dag = a.sleep.bind(inp)

compiled_dag = dag.experimental_compile()
compiled_dag.execute(3) # 3-second slow task running async
compiled_dag.teardown()

# Check we can still use the actor after first DAG teardown.
with InputNode() as inp:
dag = a.sleep.bind(inp)

compiled_dag = dag.experimental_compile()
chan = compiled_dag.execute(0.1)
result = chan.begin_read()
assert result == 0.1
chan.end_read()

compiled_dag.teardown()


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
Expand Down
5 changes: 1 addition & 4 deletions src/ray/object_manager/plasma/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -444,10 +444,7 @@ Status PlasmaClient::Impl::ExperimentalMutableObjectWriteAcquire(
") is larger than allocated buffer size " +
std::to_string(entry->object.allocated_size));
}
if (!plasma_header->WriteAcquire(data_size,
metadata_size,
num_readers,
try_wait)) {
if (!plasma_header->WriteAcquire(data_size, metadata_size, num_readers, try_wait)) {
return Status::IOError("write acquire failed");
};

Expand Down

0 comments on commit 94c5ba4

Please sign in to comment.