Skip to content

Commit

Permalink
[core][aDAG] Fix a bug where multi arg + exception doesn't work (ray-…
Browse files Browse the repository at this point in the history
…project#47704)

Currently, when there's an exception, there's only 1 return value, but multi ref assumes that the return value has to match the # of output channels. It fixes the issue by duplicating exception to match the number of output channels.
  • Loading branch information
rkooo567 authored Sep 19, 2024
1 parent 5f69744 commit ab94e48
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
44 changes: 44 additions & 0 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ def return_two_as_one(self, x):
def return_two_from_three(self, x):
return x, x + 1, x + 2

@ray.method(num_returns=2)
def return_two_but_raise_exception(self, x):
raise RuntimeError
return 1, 2

def get_events(self):
return getattr(self, "__ray_adag_events", [])

Expand Down Expand Up @@ -2449,6 +2454,45 @@ def call(self, value):
assert torch.equal(ray.get(ref), torch.tensor([5, 5, 5, 5, 5]))


def test_multi_arg_exception(shutdown_only):
a = Actor.remote(0)
with InputNode() as i:
o1, o2 = a.return_two_but_raise_exception.bind(i)
dag = MultiOutputNode([o1, o2])

compiled_dag = dag.experimental_compile()
for _ in range(3):
x, y = compiled_dag.execute(1)
with pytest.raises(RuntimeError):
ray.get(x)
with pytest.raises(RuntimeError):
ray.get(y)

compiled_dag.teardown()


def test_multi_arg_exception_async(shutdown_only):
a = Actor.remote(0)
with InputNode() as i:
o1, o2 = a.return_two_but_raise_exception.bind(i)
dag = MultiOutputNode([o1, o2])

compiled_dag = dag.experimental_compile(enable_asyncio=True)

async def main():
for _ in range(3):
x, y = await compiled_dag.execute_async(1)
with pytest.raises(RuntimeError):
await x
with pytest.raises(RuntimeError):
await y

loop = get_or_create_event_loop()
loop.run_until_complete(main())

compiled_dag.teardown()


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
Expand Down
6 changes: 6 additions & 0 deletions python/ray/experimental/channel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,12 @@ def start(self):
channel.ensure_registered_as_writer()

def write(self, val: Any, timeout: Optional[float] = None) -> None:
# If it is an exception, there's only 1 return value.
# We have to send the same data to all channels.
if isinstance(val, Exception):
if len(self._output_channels) > 1:
val = tuple(val for _ in range(len(self._output_channels)))

if not self._is_input:
if len(self._output_channels) > 1:
if not isinstance(val, tuple):
Expand Down

0 comments on commit ab94e48

Please sign in to comment.