diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index f22185fd1d72d..6e3739c1962f7 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -1214,22 +1214,20 @@ def _get_or_compile( # or use InputAttributeNode, but not both. num_input_consumers = 0 - # Step 1: populate num_channel_reads and perform some validation. + # Step 1: populate `arg_to_consumers` and `num_input_consumers` and + # perform some validation. for task in tasks: has_at_least_one_channel_input = False + is_input_consumer = False for arg in task.args: if isinstance(arg, InputNode): has_at_least_one_channel_input = True arg_to_consumers[arg].add(task) - num_input_consumers = max( - num_input_consumers, len(arg_to_consumers[arg]) - ) + is_input_consumer = True elif isinstance(arg, InputAttributeNode): has_at_least_one_channel_input = True arg_to_consumers[arg].add(task) - num_input_consumers = max( - num_input_consumers, len(arg_to_consumers[arg]) - ) + is_input_consumer = True elif isinstance(arg, DAGNode): # Other DAGNodes has_at_least_one_channel_input = True arg_to_consumers[arg].add(task) @@ -1238,6 +1236,8 @@ def _get_or_compile( assert len(upstream_task.output_channels) == 1 arg_channel = upstream_task.output_channels[0] assert arg_channel is not None + if is_input_consumer: + num_input_consumers += 1 # TODO: Support no-input DAGs (use an empty object to signal). if not has_at_least_one_channel_input: raise ValueError( diff --git a/python/ray/dag/tests/experimental/test_accelerated_dag.py b/python/ray/dag/tests/experimental/test_accelerated_dag.py index 3e7daffa86401..2dd658731b01f 100644 --- a/python/ray/dag/tests/experimental/test_accelerated_dag.py +++ b/python/ray/dag/tests/experimental/test_accelerated_dag.py @@ -418,6 +418,133 @@ def test_actor_method_bind_same_input_attr(ray_start_regular): compiled_dag.teardown() +def test_actor_method_bind_diff_input_attr_1(ray_start_regular): + actor = Actor.remote(0) + c = Collector.remote() + with InputNode() as inp: + # Two class methods are bound to two different input + # attribute nodes. + branch1 = actor.inc.bind(inp[0]) + branch2 = actor.inc.bind(inp[1]) + dag = c.collect_two.bind(branch1, branch2) + compiled_dag = dag.experimental_compile() + ref = compiled_dag.execute(0, 1) + assert ray.get(ref) == [0, 1] + + ref = compiled_dag.execute(1, 2) + assert ray.get(ref) == [0, 1, 2, 4] + + ref = compiled_dag.execute(2, 3) + assert ray.get(ref) == [0, 1, 2, 4, 6, 9] + + compiled_dag.teardown() + + +def test_actor_method_bind_diff_input_attr_2(ray_start_regular): + actor = Actor.remote(0) + c = Collector.remote() + with InputNode() as inp: + # Three class methods are bound to two different input + # attribute nodes. Two methods are bound to the same input + # attribute node. + branch1 = actor.inc.bind(inp[0]) + branch2 = actor.inc.bind(inp[0]) + branch3 = actor.inc.bind(inp[1]) + dag = c.collect_three.bind(branch1, branch2, branch3) + compiled_dag = dag.experimental_compile() + ref = compiled_dag.execute(0, 1) + assert ray.get(ref) == [0, 0, 1] + + ref = compiled_dag.execute(1, 2) + assert ray.get(ref) == [0, 0, 1, 2, 3, 5] + + ref = compiled_dag.execute(2, 3) + assert ray.get(ref) == [0, 0, 1, 2, 3, 5, 7, 9, 12] + + compiled_dag.teardown() + + +def test_actor_method_bind_diff_input_attr_3(ray_start_regular): + actor = Actor.remote(0) + with InputNode() as inp: + # A single class method is bound to two different input + # attribute nodes. + dag = actor.inc_two.bind(inp[0], inp[1]) + compiled_dag = dag.experimental_compile() + ref = compiled_dag.execute(0, 1) + assert ray.get(ref) == 1 + + ref = compiled_dag.execute(1, 2) + assert ray.get(ref) == 4 + + ref = compiled_dag.execute(2, 3) + assert ray.get(ref) == 9 + + compiled_dag.teardown() + + +def test_actor_method_bind_diff_input_attr_4(ray_start_regular): + actor = Actor.remote(0) + c = Collector.remote() + with InputNode() as inp: + branch1 = actor.inc_two.bind(inp[0], inp[1]) + branch2 = actor.inc.bind(inp[2]) + dag = c.collect_two.bind(branch1, branch2) + compiled_dag = dag.experimental_compile() + ref = compiled_dag.execute(0, 1, 2) + assert ray.get(ref) == [1, 3] + + ref = compiled_dag.execute(1, 2, 3) + assert ray.get(ref) == [1, 3, 6, 9] + + ref = compiled_dag.execute(2, 3, 4) + assert ray.get(ref) == [1, 3, 6, 9, 14, 18] + + compiled_dag.teardown() + + +def test_actor_method_bind_diff_input_attr_5(ray_start_regular): + actor = Actor.remote(0) + c = Collector.remote() + with InputNode() as inp: + branch1 = actor.inc_two.bind(inp[0], inp[1]) + branch2 = actor.inc_two.bind(inp[2], inp[0]) + dag = c.collect_two.bind(branch1, branch2) + compiled_dag = dag.experimental_compile() + ref = compiled_dag.execute(0, 1, 2) + assert ray.get(ref) == [1, 3] + + ref = compiled_dag.execute(1, 2, 3) + assert ray.get(ref) == [1, 3, 6, 10] + + ref = compiled_dag.execute(2, 3, 4) + assert ray.get(ref) == [1, 3, 6, 10, 15, 21] + + compiled_dag.teardown() + + +def test_actor_method_bind_diff_kwargs_input_attr(ray_start_regular): + actor = Actor.remote(0) + c = Collector.remote() + with InputNode() as inp: + # Two class methods are bound to two different kwargs input + # attribute nodes. + branch1 = actor.inc.bind(inp.x) + branch2 = actor.inc.bind(inp.y) + dag = c.collect_two.bind(branch1, branch2) + compiled_dag = dag.experimental_compile() + ref = compiled_dag.execute(x=0, y=1) + assert ray.get(ref) == [0, 1] + + ref = compiled_dag.execute(x=1, y=2) + assert ray.get(ref) == [0, 1, 2, 4] + + ref = compiled_dag.execute(x=2, y=3) + assert ray.get(ref) == [0, 1, 2, 4, 6, 9] + + compiled_dag.teardown() + + def test_actor_method_bind_same_arg(ray_start_regular): a1 = Actor.remote(0) a2 = Actor.remote(0)