Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2/X][Pipeline] Add python generation for ClassNode #22617

Merged
merged 6 commits into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/ray/experimental/dag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
from ray.experimental.dag.function_node import FunctionNode
from ray.experimental.dag.class_node import ClassNode, ClassMethodNode
from ray.experimental.dag.input_node import InputNode

from ray.experimental.dag.constants import (
PARENT_CLASS_NODE_KEY,
PREV_CLASS_METHOD_CALL_KEY,
)

__all__ = [
"ClassNode",
"ClassMethodNode",
"DAGNode",
"FunctionNode",
"InputNode",
"PARENT_CLASS_NODE_KEY",
"PREV_CLASS_METHOD_CALL_KEY",
]
12 changes: 8 additions & 4 deletions python/ray/experimental/dag/class_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from ray.experimental.dag.dag_node import DAGNode
from ray.experimental.dag.input_node import InputNode
from ray.experimental.dag.format_utils import get_dag_node_str
from ray.experimental.dag.constants import (
PARENT_CLASS_NODE_KEY,
PREV_CLASS_METHOD_CALL_KEY,
)

from typing import Any, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -85,8 +89,8 @@ def __init__(self, actor: ClassNode, method_name: str):

def _bind(self, *args, **kwargs):
other_args_to_resolve = {
"parent_class_node": self._actor,
"prev_class_method_call": self._actor._last_call,
PARENT_CLASS_NODE_KEY: self._actor,
PREV_CLASS_METHOD_CALL_KEY: self._actor._last_call,
}

node = ClassMethodNode(
Expand Down Expand Up @@ -122,13 +126,13 @@ def __init__(
self._method_name: str = method_name
# Parse other_args_to_resolve and assign to variables
self._parent_class_node: ClassNode = other_args_to_resolve.get(
"parent_class_node"
PARENT_CLASS_NODE_KEY
)
# Used to track lineage of ClassMethodCall to preserve deterministic
# submission and execution order.
self._prev_class_method_call: Optional[
ClassMethodNode
] = other_args_to_resolve.get("prev_class_method_call", None)
] = other_args_to_resolve.get(PREV_CLASS_METHOD_CALL_KEY, None)
# The actor creation task dependency is encoded as the first argument,
# and the ordering dependency as the second, which ensures they are
# executed prior to this node.
Expand Down
3 changes: 3 additions & 0 deletions python/ray/experimental/dag/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Reserved keys used to handle ClassMethodNode in Ray DAG building.
PARENT_CLASS_NODE_KEY = "parent_class_node"
PREV_CLASS_METHOD_CALL_KEY = "prev_class_method_call"
18 changes: 11 additions & 7 deletions python/ray/experimental/dag/tests/test_class_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import pickle

import ray
from ray.experimental.dag import DAGNode
from ray.experimental.dag import (
DAGNode,
PARENT_CLASS_NODE_KEY,
PREV_CLASS_METHOD_CALL_KEY,
)


@ray.remote
Expand Down Expand Up @@ -147,36 +151,36 @@ def combine(x, y):
assert test_a2.get_options() == {} # No .options() at outer call
# refer to a2 constructor .options() call
assert (
test_a2.get_other_args_to_resolve()["parent_class_node"]
test_a2.get_other_args_to_resolve()[PARENT_CLASS_NODE_KEY]
.get_options()
.get("name")
== "a2_v0"
)
# refer to actor method a2.inc.options() call
assert (
test_a2.get_other_args_to_resolve()["prev_class_method_call"]
test_a2.get_other_args_to_resolve()[PREV_CLASS_METHOD_CALL_KEY]
.get_options()
.get("name")
== "v3"
)
# refer to a1 constructor .options() call
assert (
test_a1.get_other_args_to_resolve()["parent_class_node"]
test_a1.get_other_args_to_resolve()[PARENT_CLASS_NODE_KEY]
.get_options()
.get("name")
== "a1_v1"
)
# refer to latest actor method a1.inc.options() call
assert (
test_a1.get_other_args_to_resolve()["prev_class_method_call"]
test_a1.get_other_args_to_resolve()[PREV_CLASS_METHOD_CALL_KEY]
.get_options()
.get("name")
== "v2"
)
# refer to first bound actor method a1.inc.options() call
assert (
test_a1.get_other_args_to_resolve()["prev_class_method_call"]
.get_other_args_to_resolve()["prev_class_method_call"]
test_a1.get_other_args_to_resolve()[PREV_CLASS_METHOD_CALL_KEY]
.get_other_args_to_resolve()[PREV_CLASS_METHOD_CALL_KEY]
.get_options()
.get("name")
== "v1"
Expand Down
5 changes: 5 additions & 0 deletions python/ray/serve/pipeline/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Reserved constant used as key in other_args_to_resolve to configure if we
# return sync or async handle of a deployment.
# True -> RayServeSyncHandle
# False -> RayServeHandle
USE_SYNC_HANDLE_KEY = "use_sync_handle"
49 changes: 38 additions & 11 deletions python/ray/serve/pipeline/deployment_method_node.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Any, Dict, Optional, Tuple, List
from typing import Any, Dict, Optional, Tuple, List, Union

from ray.experimental.dag import DAGNode
from ray.experimental.dag.format_utils import get_dag_node_str
from ray.serve.api import Deployment
from ray.serve.handle import RayServeSyncHandle, RayServeHandle
from ray.serve.pipeline.constants import USE_SYNC_HANDLE_KEY


class DeploymentMethodNode(DAGNode):
Expand All @@ -26,16 +27,9 @@ def __init__(
method_options,
other_args_to_resolve=other_args_to_resolve,
)
# Serve handle is sync by default.
if (
"sync_handle" in self._bound_other_args_to_resolve
and self._bound_other_args_to_resolve.get("sync_handle") is True
):
self._deployment_handle: RayServeSyncHandle = deployment.get_handle(
sync=True
)
else:
self._deployment_handle: RayServeHandle = deployment.get_handle(sync=False)
self._deployment_handle: Union[
RayServeHandle, RayServeSyncHandle
] = self._get_serve_deployment_handle(deployment, other_args_to_resolve)

def _copy_impl(
self,
Expand Down Expand Up @@ -63,6 +57,39 @@ def _execute_impl(self, *args):
**self._bound_kwargs,
)

def _get_serve_deployment_handle(
self,
deployment: Deployment,
bound_other_args_to_resolve: Dict[str, Any],
) -> Union[RayServeHandle, RayServeSyncHandle]:
"""
Return a sync or async handle of the encapsulated Deployment based on
config.

Args:
deployment (Deployment): Deployment instance wrapped in the DAGNode.
bound_other_args_to_resolve (Dict[str, Any]): Contains args used
to configure DeploymentNode.

Returns:
RayServeHandle: Default and catch-all is to return sync handle.
return async handle only if user explicitly set
USE_SYNC_HANDLE_KEY with value of False.
"""
if USE_SYNC_HANDLE_KEY not in bound_other_args_to_resolve:
# Return sync RayServeSyncHandle
return deployment.get_handle(sync=True)
elif bound_other_args_to_resolve.get(USE_SYNC_HANDLE_KEY) is True:
# Return sync RayServeSyncHandle
return deployment.get_handle(sync=True)
elif bound_other_args_to_resolve.get(USE_SYNC_HANDLE_KEY) is False:
# Return async RayServeHandle
return deployment.get_handle(sync=False)
else:
raise ValueError(
f"{USE_SYNC_HANDLE_KEY} should only be set with a boolean value."
)

def __str__(self) -> str:
return get_dag_node_str(
self, str(self._method_name) + "() @ " + str(self._body)
Expand Down
49 changes: 38 additions & 11 deletions python/ray/serve/pipeline/deployment_node.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Any, Dict, Optional, List, Tuple
from typing import Any, Dict, Optional, List, Tuple, Union

from ray.experimental.dag import DAGNode, InputNode
from ray.serve.api import Deployment
from ray.serve.handle import RayServeSyncHandle, RayServeHandle
from ray.serve.pipeline.deployment_method_node import DeploymentMethodNode
from ray.serve.pipeline.constants import USE_SYNC_HANDLE_KEY
from ray.experimental.dag.format_utils import get_dag_node_str


Expand All @@ -25,16 +26,9 @@ def __init__(
cls_options,
other_args_to_resolve=other_args_to_resolve,
)
# Serve handle is sync by default.
if (
"sync_handle" in self._bound_other_args_to_resolve
and self._bound_other_args_to_resolve.get("sync_handle") is True
):
self._deployment_handle: RayServeSyncHandle = deployment.get_handle(
sync=True
)
else:
self._deployment_handle: RayServeHandle = deployment.get_handle(sync=False)
self._deployment_handle: Union[
RayServeHandle, RayServeSyncHandle
] = self._get_serve_deployment_handle(deployment, other_args_to_resolve)

if self._contains_input_node():
raise ValueError(
Expand Down Expand Up @@ -65,6 +59,39 @@ def _execute_impl(self, *args):
*self._bound_args, **self._bound_kwargs
)

def _get_serve_deployment_handle(
self,
deployment: Deployment,
bound_other_args_to_resolve: Dict[str, Any],
) -> Union[RayServeHandle, RayServeSyncHandle]:
"""
Return a sync or async handle of the encapsulated Deployment based on
config.

Args:
deployment (Deployment): Deployment instance wrapped in the DAGNode.
bound_other_args_to_resolve (Dict[str, Any]): Contains args used
to configure DeploymentNode.

Returns:
RayServeHandle: Default and catch-all is to return sync handle.
return async handle only if user explicitly set
USE_SYNC_HANDLE_KEY with value of False.
"""
if USE_SYNC_HANDLE_KEY not in bound_other_args_to_resolve:
# Return sync RayServeSyncHandle
return deployment.get_handle(sync=True)
elif bound_other_args_to_resolve.get(USE_SYNC_HANDLE_KEY) is True:
# Return sync RayServeSyncHandle
return deployment.get_handle(sync=True)
elif bound_other_args_to_resolve.get(USE_SYNC_HANDLE_KEY) is False:
# Return async RayServeHandle
return deployment.get_handle(sync=False)
else:
raise ValueError(
f"{USE_SYNC_HANDLE_KEY} should only be set with a boolean value."
)

def _contains_input_node(self) -> bool:
"""Check if InputNode is used in children DAGNodes with current node
as the root.
Expand Down
Loading