From 94c5ba43a3b1d7b4bb15fadb880ee28b90ae219f Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Fri, 15 Dec 2023 15:46:18 -0800
Subject: [PATCH] handle cancellation of ongoing dags

Signed-off-by: Eric Liang
---
 python/ray/_raylet.pyx                       |  2 +-
 python/ray/dag/compiled_dag_node.py          | 11 ++++
 python/ray/dag/tests/test_accelerated_dag.py | 66 +++++++++++++++++++-
 src/ray/object_manager/plasma/client.cc      |  5 +-
 4 files changed, 76 insertions(+), 8 deletions(-)

diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx
index 97a03dad732a..484d3b5a853b 100644
--- a/python/ray/_raylet.pyx
+++ b/python/ray/_raylet.pyx
@@ -3518,7 +3518,7 @@ cdef class CoreWorker:
     def experimental_mutable_object_put_serialized(self, serialized_object,
                                                    ObjectRef object_ref,
                                                    num_readers,
-                                                   try_wait: bool=False
+                                                   try_wait: bool = False
                                                    ):
         cdef:
             CObjectID c_object_id = object_ref.native()
diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py
index 6a2320e3888e..7facc2189d8c 100644
--- a/python/ray/dag/compiled_dag_node.py
+++ b/python/ray/dag/compiled_dag_node.py
@@ -48,6 +48,8 @@ def do_exec_compiled_task(
         actor_method_name: The name of the actual actor method to execute in
             the loop.
     """
+    self._dag_cancelled = False
+
     try:
         self._input_channels = [i for i in inputs if isinstance(i, Channel)]
         method = getattr(self, actor_method_name)
@@ -68,6 +70,8 @@ def do_exec_compiled_task(
 
             output_val = method(*resolved_inputs)
 
+            if self._dag_cancelled:
+                raise RuntimeError("DAG execution cancelled")
             self._output_channel.write(output_val)
             for _, channel in input_channel_idxs:
                 channel.end_read()
@@ -79,6 +83,7 @@ def do_exec_compiled_task(
 
 @DeveloperAPI
 def do_cancel_compiled_task(self):
+    self._dag_cancelled = True
     e = RayTaskError(
         function_name="do_exec_compiled_task",
         traceback_str="",
@@ -408,6 +413,12 @@ def teardown(self):
                     except Exception:
                         logger.exception("Error cancelling worker task")
                         pass
+                logger.info("Waiting for worker tasks to exit")
+                for ref in outer.worker_task_refs:
+                    try:
+                        ray.get(ref)
+                    except Exception:
+                        pass
                 logger.info("Teardown complete")
 
             def run(self):
diff --git a/python/ray/dag/tests/test_accelerated_dag.py b/python/ray/dag/tests/test_accelerated_dag.py
index 077d5bb4ec7c..705a9b395c74 100644
--- a/python/ray/dag/tests/test_accelerated_dag.py
+++ b/python/ray/dag/tests/test_accelerated_dag.py
@@ -3,11 +3,13 @@
 import os
 import random
 import sys
+import time
 
 import pytest
 
 import ray
 import ray.cluster_utils
+from ray.exceptions import RayActorError
 from ray.dag import InputNode, MultiOutputNode
 from ray.tests.conftest import *  # noqa
 
@@ -20,17 +22,21 @@
 
 @ray.remote
 class Actor:
-    def __init__(self, init_value, fail_after=None):
+    def __init__(self, init_value, fail_after=None, sys_exit=False):
         print("__init__ PID", os.getpid())
         self.i = init_value
         self.fail_after = fail_after
+        self.sys_exit = sys_exit
 
     def inc(self, x):
         self.i += x
         if self.fail_after and self.i > self.fail_after:
             # Randomize the failures to better cover multi actor scenarios.
             if random.random() > 0.5:
-                raise ValueError("injected fault")
+                if self.sys_exit:
+                    os._exit(1)
+                else:
+                    raise ValueError("injected fault")
         return self.i
 
     def append_to(self, lst):
@@ -42,6 +48,10 @@ def inc_two(self, x, y):
         self.i += y
         return self.i
 
+    def sleep(self, x):
+        time.sleep(x)
+        return x
+
 
 def test_basic(ray_start_regular):
     a = Actor.remote(0)
@@ -194,7 +204,9 @@ def f(x):
 
 @pytest.mark.parametrize("num_actors", [1, 4])
 def test_dag_fault_tolerance(ray_start_regular, num_actors):
-    actors = [Actor.remote(0, fail_after=100) for _ in range(num_actors)]
+    actors = [
+        Actor.remote(0, fail_after=100, sys_exit=False) for _ in range(num_actors)
+    ]
     with InputNode() as i:
         out = [a.inc.bind(i) for a in actors]
         dag = MultiOutputNode(out)
@@ -218,6 +230,54 @@ def test_dag_fault_tolerance(ray_start_regular, num_actors):
             chan.end_read()
 
 
+def test_dag_fault_tolerance_sys_exit(ray_start_regular):
+    actors = [Actor.remote(0, fail_after=100, sys_exit=True) for _ in range(1)]
+    with InputNode() as i:
+        out = [a.inc.bind(i) for a in actors]
+        dag = MultiOutputNode(out)
+
+    compiled_dag = dag.experimental_compile()
+
+    for i in range(99):
+        output_channels = compiled_dag.execute(1)
+        # TODO(swang): Replace with fake ObjectRef.
+        results = [chan.begin_read() for chan in output_channels]
+        assert results == [i + 1]
+        for chan in output_channels:
+            chan.end_read()
+
+    with pytest.raises(RayActorError):
+        for i in range(99):
+            output_channels = compiled_dag.execute(1)
+            for chan in output_channels:
+                chan.begin_read()
+            for chan in output_channels:
+                chan.end_read()
+
+
+def test_dag_teardown_while_running(ray_start_regular):
+    a = Actor.remote(0)
+
+    with InputNode() as inp:
+        dag = a.sleep.bind(inp)
+
+    compiled_dag = dag.experimental_compile()
+    compiled_dag.execute(3)  # 3-second slow task running async
+    compiled_dag.teardown()
+
+    # Check we can still use the actor after first DAG teardown.
+    with InputNode() as inp:
+        dag = a.sleep.bind(inp)
+
+    compiled_dag = dag.experimental_compile()
+    chan = compiled_dag.execute(0.1)
+    result = chan.begin_read()
+    assert result == 0.1
+    chan.end_read()
+
+    compiled_dag.teardown()
+
+
 if __name__ == "__main__":
     if os.environ.get("PARALLEL_CI"):
         sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
diff --git a/src/ray/object_manager/plasma/client.cc b/src/ray/object_manager/plasma/client.cc
index 88356aa88eff..3ffdcb3ac2e7 100644
--- a/src/ray/object_manager/plasma/client.cc
+++ b/src/ray/object_manager/plasma/client.cc
@@ -444,10 +444,7 @@ Status PlasmaClient::Impl::ExperimentalMutableObjectWriteAcquire(
         ") is larger than allocated buffer size " +
         std::to_string(entry->object.allocated_size));
   }
-  if (!plasma_header->WriteAcquire(data_size,
-                                   metadata_size,
-                                   num_readers,
-                                   try_wait)) {
+  if (!plasma_header->WriteAcquire(data_size, metadata_size, num_readers, try_wait)) {
     return Status::IOError("write acquire failed");
   };
 