Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][experimental] Correct num_input_consumers for CachedChannel #47489

Merged
merged 4 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,22 +1214,20 @@ def _get_or_compile(
# or use InputAttributeNode, but not both.
num_input_consumers = 0

# Step 1: populate num_channel_reads and perform some validation.
# Step 1: populate `arg_to_consumers` and `num_input_consumers` and
# perform some validation.
for task in tasks:
has_at_least_one_channel_input = False
is_input_consumer = False
for arg in task.args:
if isinstance(arg, InputNode):
has_at_least_one_channel_input = True
arg_to_consumers[arg].add(task)
num_input_consumers = max(
num_input_consumers, len(arg_to_consumers[arg])
)
is_input_consumer = True
elif isinstance(arg, InputAttributeNode):
has_at_least_one_channel_input = True
arg_to_consumers[arg].add(task)
num_input_consumers = max(
num_input_consumers, len(arg_to_consumers[arg])
)
is_input_consumer = True
elif isinstance(arg, DAGNode): # Other DAGNodes
has_at_least_one_channel_input = True
arg_to_consumers[arg].add(task)
Expand All @@ -1238,6 +1236,8 @@ def _get_or_compile(
assert len(upstream_task.output_channels) == 1
arg_channel = upstream_task.output_channels[0]
assert arg_channel is not None
if is_input_consumer:
num_input_consumers += 1
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
# TODO: Support no-input DAGs (use an empty object to signal).
if not has_at_least_one_channel_input:
raise ValueError(
Expand Down
127 changes: 127 additions & 0 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,133 @@ def test_actor_method_bind_same_input_attr(ray_start_regular):
compiled_dag.teardown()


def test_actor_method_bind_diff_input_attr_1(ray_start_regular):
actor = Actor.remote(0)
c = Collector.remote()
with InputNode() as inp:
# Two class methods are bound to two different input
# attribute nodes.
branch1 = actor.inc.bind(inp[0])
branch2 = actor.inc.bind(inp[1])
dag = c.collect_two.bind(branch1, branch2)
compiled_dag = dag.experimental_compile()
ref = compiled_dag.execute(0, 1)
assert ray.get(ref) == [0, 1]

ref = compiled_dag.execute(1, 2)
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
assert ray.get(ref) == [0, 1, 2, 4]

ref = compiled_dag.execute(2, 3)
assert ray.get(ref) == [0, 1, 2, 4, 6, 9]

compiled_dag.teardown()


def test_actor_method_bind_diff_input_attr_2(ray_start_regular):
actor = Actor.remote(0)
c = Collector.remote()
with InputNode() as inp:
# Three class methods are bound to two different input
# attribute nodes. Two methods are bound to the same input
# attribute node.
branch1 = actor.inc.bind(inp[0])
branch2 = actor.inc.bind(inp[0])
branch3 = actor.inc.bind(inp[1])
dag = c.collect_three.bind(branch1, branch2, branch3)
compiled_dag = dag.experimental_compile()
ref = compiled_dag.execute(0, 1)
assert ray.get(ref) == [0, 0, 1]

ref = compiled_dag.execute(1, 2)
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
assert ray.get(ref) == [0, 0, 1, 2, 3, 5]

ref = compiled_dag.execute(2, 3)
assert ray.get(ref) == [0, 0, 1, 2, 3, 5, 7, 9, 12]

compiled_dag.teardown()


def test_actor_method_bind_diff_input_attr_3(ray_start_regular):
actor = Actor.remote(0)
with InputNode() as inp:
# A single class method is bound to two different input
# attribute nodes.
dag = actor.inc_two.bind(inp[0], inp[1])
compiled_dag = dag.experimental_compile()
ref = compiled_dag.execute(0, 1)
assert ray.get(ref) == 1

ref = compiled_dag.execute(1, 2)
assert ray.get(ref) == 4

ref = compiled_dag.execute(2, 3)
assert ray.get(ref) == 9

compiled_dag.teardown()


def test_actor_method_bind_diff_input_attr_4(ray_start_regular):
actor = Actor.remote(0)
c = Collector.remote()
with InputNode() as inp:
branch1 = actor.inc_two.bind(inp[0], inp[1])
branch2 = actor.inc.bind(inp[2])
dag = c.collect_two.bind(branch1, branch2)
compiled_dag = dag.experimental_compile()
ref = compiled_dag.execute(0, 1, 2)
assert ray.get(ref) == [1, 3]

ref = compiled_dag.execute(1, 2, 3)
assert ray.get(ref) == [1, 3, 6, 9]

ref = compiled_dag.execute(2, 3, 4)
assert ray.get(ref) == [1, 3, 6, 9, 14, 18]

compiled_dag.teardown()


def test_actor_method_bind_diff_input_attr_5(ray_start_regular):
actor = Actor.remote(0)
c = Collector.remote()
with InputNode() as inp:
branch1 = actor.inc_two.bind(inp[0], inp[1])
branch2 = actor.inc_two.bind(inp[2], inp[0])
dag = c.collect_two.bind(branch1, branch2)
compiled_dag = dag.experimental_compile()
ref = compiled_dag.execute(0, 1, 2)
assert ray.get(ref) == [1, 3]

ref = compiled_dag.execute(1, 2, 3)
assert ray.get(ref) == [1, 3, 6, 10]

ref = compiled_dag.execute(2, 3, 4)
assert ray.get(ref) == [1, 3, 6, 10, 15, 21]

compiled_dag.teardown()


def test_actor_method_bind_diff_kwargs_input_attr(ray_start_regular):
actor = Actor.remote(0)
c = Collector.remote()
with InputNode() as inp:
# Two class methods are bound to two different kwargs input
# attribute nodes.
branch1 = actor.inc.bind(inp.x)
branch2 = actor.inc.bind(inp.y)
dag = c.collect_two.bind(branch1, branch2)
compiled_dag = dag.experimental_compile()
ref = compiled_dag.execute(x=0, y=1)
assert ray.get(ref) == [0, 1]

ref = compiled_dag.execute(x=1, y=2)
assert ray.get(ref) == [0, 1, 2, 4]

ref = compiled_dag.execute(x=2, y=3)
assert ray.get(ref) == [0, 1, 2, 4, 6, 9]

compiled_dag.teardown()


def test_actor_method_bind_same_arg(ray_start_regular):
a1 = Actor.remote(0)
a2 = Actor.remote(0)
Expand Down
Loading