Skip to content

Commit

Permalink
[ADAG] Support tasks with multiple return values in aDAG (ray-project…
Browse files Browse the repository at this point in the history
…#47024)

aDAG currently does not support multiple return values. We would like to
add general support for multiple return values.

This PR supports multiple returns by returning a separate
`ClassMethodNode` for each return value of the tuple. It is an
incremental change for `ClassMethodNode`, addign
`_is_class_method_output`, `_class_method_call`, `_output_idx`.
`_output_idx` is used to guide channel allocation and output writes.
User needs to specify `num_returns > 1` to hint multiple return values.
The upstream task allocates a separate output channel for each return
value. A downstream task reads from one of the output channels.

## What is done?

We modify `ClassMethodNode` to handle two logics, one is a class method
call which is the original semantics (`self.is_class_method_call ==
True`), another is a class method output which is responsible for one of
the multiple return values (`self.is_class_method_output == True`).

We modify `WriterInterface` to support writes to multiple
`output_channels` with `output_idxs`. If an output index is None, it
means the complete return value is written to the output channel.
Otherwise, the return value is a tuple and the index is used to extract
the value to be written to the output channel.

We allocate separate output channels to different readers. The
downstream tasks of a `ClassMethodNode` with
`self.is_class_method_output == True` are the readers of an output
channel of its upstream `ClassMethodNode`. The example below
demonstrates this.

```
upstream ClassMethodNode (self.is_class_method_call == True, self.output_channels = [c1, c2])
--> downstream ClassMethodNode (self.is_class_method_method == True, self.output_channels[c1])
--> ...
```

Closes ray-project#45569

---------

Signed-off-by: Weixin Deng <weixin@cs.washington.edu>
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
  • Loading branch information
dengwxn authored and ujjawal-khare committed Oct 15, 2024
1 parent a1f83d7 commit a1cbb6a
Show file tree
Hide file tree
Showing 10 changed files with 628 additions and 135 deletions.
21 changes: 18 additions & 3 deletions python/ray/actor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import logging
import weakref
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import ray._private.ray_constants as ray_constants
import ray._private.signature as signature
Expand Down Expand Up @@ -235,9 +235,10 @@ def _bind(
num_returns=None,
concurrency_group=None,
_generator_backpressure_num_objects=None,
):
) -> Union["ray.dag.ClassMethodNode", Tuple["ray.dag.ClassMethodNode", ...]]:
from ray.dag.class_node import (
BIND_INDEX_KEY,
IS_CLASS_METHOD_OUTPUT_KEY,
PARENT_CLASS_NODE_KEY,
PREV_CLASS_METHOD_CALL_KEY,
ClassMethodNode,
Expand Down Expand Up @@ -271,7 +272,21 @@ def _bind(
options,
other_args_to_resolve=other_args_to_resolve,
)
return node

if node.num_returns > 1:
output_nodes: List[ClassMethodNode] = []
for i in range(node.num_returns):
output_node = ClassMethodNode(
f"return_idx_{i}",
(node, i),
dict(),
dict(),
{IS_CLASS_METHOD_OUTPUT_KEY: True},
)
output_nodes.append(output_node)
return tuple(output_nodes)
else:
return node

@wrap_auto_init
@_tracing_actor_method_invocation
Expand Down
5 changes: 4 additions & 1 deletion python/ray/dag/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from ray.dag.dag_node import DAGNode
from ray.dag.function_node import FunctionNode
from ray.dag.class_node import ClassNode, ClassMethodNode
from ray.dag.class_node import (
ClassNode,
ClassMethodNode,
)
from ray.dag.input_node import (
InputNode,
InputAttributeNode,
Expand Down
109 changes: 98 additions & 11 deletions python/ray/dag/class_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
PARENT_CLASS_NODE_KEY,
PREV_CLASS_METHOD_CALL_KEY,
BIND_INDEX_KEY,
IS_CLASS_METHOD_OUTPUT_KEY,
)
from ray.util.annotations import DeveloperAPI

Expand Down Expand Up @@ -134,6 +135,24 @@ def options(self, **options):
return self


class _ClassMethodOutput:
"""Represents a class method output in a Ray function DAG."""

def __init__(self, class_method_call: "ClassMethodNode", output_idx: int):
# The upstream class method call that returns multiple values.
self._class_method_call = class_method_call
# The output index of the return value from the upstream class method call.
self._output_idx = output_idx

@property
def class_method_call(self) -> "ClassMethodNode":
return self._class_method_call

@property
def output_idx(self) -> int:
return self._output_idx


@DeveloperAPI
class ClassMethodNode(DAGNode):
"""Represents an actor method invocation in a Ray function DAG."""
Expand Down Expand Up @@ -164,6 +183,21 @@ def __init__(
self._bind_index: Optional[int] = other_args_to_resolve.get(
BIND_INDEX_KEY, None
)
# Represent if the ClassMethodNode is a class method output. If True,
# the node is a placeholder for a return value from the ClassMethodNode
# that returns multiple values. If False, the node is a class method call.
self._is_class_method_output: bool = other_args_to_resolve.get(
IS_CLASS_METHOD_OUTPUT_KEY, False
)
# Represents the return value from the upstream ClassMethodNode that
# returns multiple values. If the node is a class method call, this is None.
self._class_method_output: Optional[_ClassMethodOutput] = None
if self._is_class_method_output:
# Set the upstream ClassMethodNode and the output index of the return
# value from `method_args`.
self._class_method_output = _ClassMethodOutput(
method_args[0], method_args[1]
)

# The actor creation task dependency is encoded as the first argument,
# and the ordering dependency as the second, which ensures they are
Expand Down Expand Up @@ -198,12 +232,16 @@ def _execute_impl(self, *args, **kwargs):
with value in bound_args and bound_kwargs via bottom-up recursion when
current node is executed.
"""
method_body = getattr(self._parent_class_node, self._method_name)
# Execute with bound args.
return method_body.options(**self._bound_options).remote(
*self._bound_args,
**self._bound_kwargs,
)
if self.is_class_method_call:
method_body = getattr(self._parent_class_node, self._method_name)
# Execute with bound args.
return method_body.options(**self._bound_options).remote(
*self._bound_args,
**self._bound_kwargs,
)
else:
assert self._class_method_output is not None
return self._bound_args[0][self._class_method_output.output_idx]

def __str__(self) -> str:
return get_dag_node_str(self, f"{self._method_name}()")
Expand All @@ -225,8 +263,57 @@ def _get_actor_handle(self) -> Optional["ray.actor.ActorHandle"]:

@property
def num_returns(self) -> int:
num_returns = self._bound_options.get("num_returns", None)
if num_returns is None:
method = self._get_remote_method(self._method_name)
num_returns = method.__getstate__()["num_returns"]
return num_returns
"""
Return the number of return values from the class method call. If the
node is a class method output, return the number of return values from
the upstream class method call.
"""

if self.is_class_method_call:
num_returns = self._bound_options.get("num_returns", None)
if num_returns is None:
method = self._get_remote_method(self._method_name)
num_returns = method.__getstate__()["num_returns"]
return num_returns
else:
assert self._class_method_output is not None
return self._class_method_output.class_method_call.num_returns

@property
def is_class_method_call(self) -> bool:
"""
Return True if the node is a class method call, False if the node is a
class method output.
"""
return not self._is_class_method_output

@property
def is_class_method_output(self) -> bool:
"""
Return True if the node is a class method output, False if the node is a
class method call.
"""
return self._is_class_method_output

@property
def class_method_call(self) -> Optional["ClassMethodNode"]:
"""
Return the upstream class method call that returns multiple values. If
the node is a class method output, return None.
"""

if self._class_method_output is None:
return None
return self._class_method_output.class_method_call

@property
def output_idx(self) -> Optional[int]:
"""
Return the output index of the return value from the upstream class
method call that returns multiple values. If the node is a class method
call, return None.
"""

if self._class_method_output is None:
return None
return self._class_method_output.output_idx
Loading

0 comments on commit a1cbb6a

Please sign in to comment.