Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][experimental] Raise an exception if a leaf node is found during compilation #47757

Merged
merged 9 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,27 @@ def _preprocess(self) -> None:
# Add all readers to the NCCL actors of P2P.
nccl_actors_p2p.add(downstream_actor_handle)

# Collect all leaf nodes: actor-method tasks that have no downstream tasks
# and are not marked as output nodes of the DAG.
leaf_nodes: list[DAGNode] = [
    task.dag_node
    for task in self.idx_to_task.values()
    if isinstance(task.dag_node, ClassMethodNode)
    and len(task.downstream_task_idxs) == 0
    and not task.dag_node.is_adag_output_node
]
# Leaf nodes are not allowed because the exception thrown by the leaf
# node will not be propagated to the driver.
if leaf_nodes:
    # Name the offending methods so the user knows which outputs to add.
    leaf_method_names = [leaf_node.get_method_name() for leaf_node in leaf_nodes]
    raise ValueError(
        "Compiled DAG doesn't support leaf nodes, i.e., nodes that don't have "
        "downstream nodes and are not output nodes. There are "
        f"{len(leaf_nodes)} leaf nodes in the DAG. Please add the outputs of "
        f"{leaf_method_names} to the MultiOutputNode."
    )

nccl_actors_p2p = list(nccl_actors_p2p)
if None in nccl_actors_p2p:
raise ValueError("Driver cannot participate in the NCCL group.")
Expand Down
106 changes: 66 additions & 40 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,30 +174,6 @@ def test_basic(ray_start_regular):
del result


def test_two_returns_first(ray_start_regular):
    """Compile a DAG that outputs only the first of two return values.

    The compiled DAG is executed three times with input 1 and every
    result is expected to equal 1.
    """
    actor = Actor.remote(0)
    with InputNode() as inp:
        first, second = actor.return_two.bind(inp)
        dag = first

    compiled_dag = dag.experimental_compile()
    for _ in range(3):
        ref = compiled_dag.execute(1)
        assert ray.get(ref) == 1


def test_two_returns_second(ray_start_regular):
    """Compile a DAG that outputs only the second of two return values.

    The compiled DAG is executed three times with input 1 and every
    result is expected to equal 2.
    """
    actor = Actor.remote(0)
    with InputNode() as inp:
        first, second = actor.return_two.bind(inp)
        dag = second

    compiled_dag = dag.experimental_compile()
    for _ in range(3):
        ref = compiled_dag.execute(1)
        assert ray.get(ref) == 2


@pytest.mark.parametrize("single_fetch", [True, False])
def test_two_returns_one_reader(ray_start_regular, single_fetch):
a = Actor.remote(0)
Expand Down Expand Up @@ -1262,15 +1238,15 @@ def test_compile_twice_with_different_nodes(self, ray_start_regular):
with InputNode() as i:
branch1 = a.echo.bind(i)
branch2 = b.echo.bind(i)
dag = MultiOutputNode([branch1])
dag = MultiOutputNode([branch1, branch2])
compiled_dag = dag.experimental_compile()
compiled_dag.teardown()
with pytest.raises(
ValueError,
match="The DAG was compiled more than once. The following two "
"nodes call `experimental_compile`: ",
):
compiled_dag = branch2.experimental_compile()
branch2.experimental_compile()
ruisearch42 marked this conversation as resolved.
Show resolved Hide resolved


def test_exceed_max_buffered_results(ray_start_regular):
Expand Down Expand Up @@ -1782,15 +1758,22 @@ def test_intra_process_channel_with_multi_readers(


class TestLeafNode:
"""
Leaf nodes are not allowed right now because the exception thrown by the leaf
node will not be propagated to the driver and silently ignored, which is undesired.
"""

LEAF_NODE_EXCEPTION_TEMPLATE = (
"Compiled DAG doesn't support leaf nodes, i.e., nodes that don't have "
"downstream nodes and are not output nodes. There are {num_leaf_nodes} "
"leaf nodes in the DAG. Please add the outputs of"
)

def test_leaf_node_one_actor(self, ray_start_regular):
"""
driver -> a.inc
|
-> a.inc -> driver

The upper branch (branch 1) is a leaf node, and it will be executed
before the lower `a.inc` task because of the control dependency. Hence,
the result will be [20] because `a.inc` will be executed twice.
"""
a = Actor.remote(0)
with InputNode() as i:
Expand All @@ -1799,10 +1782,11 @@ def test_leaf_node_one_actor(self, ray_start_regular):
branch2 = a.inc.bind(input_data)
dag = MultiOutputNode([branch2])

compiled_dag = dag.experimental_compile()

ref = compiled_dag.execute(10)
assert ray.get(ref) == [20]
with pytest.raises(
ValueError,
match=TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(num_leaf_nodes=1),
):
dag.experimental_compile()

def test_leaf_node_two_actors(self, ray_start_regular):
"""
Expand All @@ -1811,20 +1795,62 @@ def test_leaf_node_two_actors(self, ray_start_regular):
| -> b.inc ----> driver
|
-> a.inc (branch 1)

The lower branch (branch 1) is a leaf node, and it will be executed
before the upper `a.inc` task because of the control dependency.
"""
a = Actor.remote(0)
b = Actor.remote(100)
with InputNode() as i:
a.inc.bind(i) # branch1: leaf node
branch2 = b.inc.bind(i)
dag = MultiOutputNode([a.inc.bind(branch2), b.inc.bind(branch2)])
compiled_dag = dag.experimental_compile()
with pytest.raises(
ValueError,
match=TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(num_leaf_nodes=1),
):
dag.experimental_compile()

def test_multi_leaf_nodes(self, ray_start_regular):
    """
    driver -> a.inc -> a.inc (branch 1, leaf node)
           |        |
           |        -> a.inc -> driver
           |
           -> a.inc (branch 2, leaf node)

    Compilation is expected to raise a ValueError reporting 2 leaf nodes.
    """
    actor = Actor.remote(0)
    with InputNode() as inp:
        root = actor.inc.bind(inp)
        actor.inc.bind(root)  # branch1: leaf node
        actor.inc.bind(inp)  # branch2: leaf node
        dag = MultiOutputNode([actor.inc.bind(root)])

    expected_msg = TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(
        num_leaf_nodes=2
    )
    with pytest.raises(ValueError, match=expected_msg):
        dag.experimental_compile()

ref = compiled_dag.execute(10)
assert ray.get(ref) == [120, 220]
def test_two_returns_first(self, ray_start_regular):
    """Using only the first of two return values must fail to compile.

    Compilation is expected to raise a ValueError reporting 1 leaf node
    (presumably the second return value, which is never consumed).
    """
    actor = Actor.remote(0)
    with InputNode() as inp:
        first, second = actor.return_two.bind(inp)
        dag = first

    expected_msg = TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(
        num_leaf_nodes=1
    )
    with pytest.raises(ValueError, match=expected_msg):
        dag.experimental_compile()

def test_two_returns_second(self, ray_start_regular):
    """Using only the second of two return values must fail to compile.

    Compilation is expected to raise a ValueError reporting 1 leaf node
    (presumably the first return value, which is never consumed).
    """
    actor = Actor.remote(0)
    with InputNode() as inp:
        first, second = actor.return_two.bind(inp)
        dag = second

    expected_msg = TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(
        num_leaf_nodes=1
    )
    with pytest.raises(ValueError, match=expected_msg):
        dag.experimental_compile()


def test_output_node(ray_start_regular):
Expand Down
5 changes: 5 additions & 0 deletions python/ray/dag/tests/experimental/test_collective_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def test_comm_deduplicate_p2p_and_collective(ray_start_regular, monkeypatch):
dag = workers[1].recv.bind(
collectives[0].with_type_hint(TorchTensorType(transport="nccl"))
)
dag = MultiOutputNode([dag, collectives[1]])

compiled_dag, mock_nccl_group_set = check_nccl_group_init(
monkeypatch,
Expand Down Expand Up @@ -435,6 +436,7 @@ def test_custom_comm_deduplicate(ray_start_regular, monkeypatch):
dag = workers[0].recv.bind(
collectives[1].with_type_hint(TorchTensorType(transport="nccl"))
)
dag = MultiOutputNode([dag, collectives[0]])

compiled_dag, mock_nccl_group_set = check_nccl_group_init(
monkeypatch,
Expand All @@ -453,6 +455,7 @@ def test_custom_comm_deduplicate(ray_start_regular, monkeypatch):
dag = workers[0].recv.bind(
collectives[1].with_type_hint(TorchTensorType(transport=comm))
)
dag = MultiOutputNode([dag, collectives[0]])

compiled_dag, mock_nccl_group_set = check_nccl_group_init(
monkeypatch,
Expand Down Expand Up @@ -487,6 +490,7 @@ def test_custom_comm_init_teardown(ray_start_regular, monkeypatch):
dag = workers[0].recv.bind(
allreduce[1].with_type_hint(TorchTensorType(transport=comm))
)
dag = MultiOutputNode([dag, allreduce[0]])

compiled_dag, mock_nccl_group_set = check_nccl_group_init(
monkeypatch,
Expand All @@ -508,6 +512,7 @@ def test_custom_comm_init_teardown(ray_start_regular, monkeypatch):
dag = workers[0].recv.bind(
allreduce2[1].with_type_hint(TorchTensorType(transport=comm_3))
)
dag = MultiOutputNode([dag, allreduce2[0]])

compiled_dag, mock_nccl_group_set = check_nccl_group_init(
monkeypatch,
Expand Down
4 changes: 2 additions & 2 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ def test_torch_tensor_nccl_all_reduce_get_partial(ray_start_regular):
collectives = collective.allreduce.bind(computes, ReduceOp.SUM)
recv = workers[0].recv.bind(collectives[0])
tensor = workers[1].recv_tensor.bind(collectives[0])
dag = MultiOutputNode([recv, tensor])
dag = MultiOutputNode([recv, tensor, collectives[1]])

compiled_dag = dag.experimental_compile()

Expand All @@ -955,7 +955,7 @@ def test_torch_tensor_nccl_all_reduce_get_partial(ray_start_regular):
[(shape, dtype, i + idx + 1) for idx in range(num_workers)]
)
result = ray.get(ref)
metadata, tensor = result
metadata, tensor, _ = result
reduced_val = sum(i + idx + 1 for idx in range(num_workers))
assert metadata == (reduced_val, shape, dtype)
tensor = tensor.to("cpu")
Expand Down