From 6d710e1bdef8d221f26ddd66563da47c91750ad2 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Fri, 7 Jul 2023 10:42:04 -0500 Subject: [PATCH 1/5] rm Signed-off-by: Edward Oakes --- python/ray/dag/class_node.py | 62 ---- python/ray/dag/dag_node.py | 10 +- python/ray/dag/function_node.py | 30 -- python/ray/dag/input_node.py | 34 --- python/ray/serve/_private/constants.py | 6 - .../serve/_private/deployment_graph_build.py | 7 +- python/ray/serve/_private/json_serde.py | 169 ----------- python/ray/serve/deployment_graph.py | 19 +- python/ray/serve/drivers.py | 6 +- .../experimental/gradio_visualize_graph.py | 31 +- python/ray/serve/handle.py | 36 +-- python/ray/serve/tests/test_json_serde.py | 268 ------------------ 12 files changed, 35 insertions(+), 643 deletions(-) delete mode 100644 python/ray/serve/_private/json_serde.py delete mode 100644 python/ray/serve/tests/test_json_serde.py diff --git a/python/ray/dag/class_node.py b/python/ray/dag/class_node.py index d5842c1d402f..66eb83084d21 100644 --- a/python/ray/dag/class_node.py +++ b/python/ray/dag/class_node.py @@ -5,7 +5,6 @@ from ray.dag.constants import ( PARENT_CLASS_NODE_KEY, PREV_CLASS_METHOD_CALL_KEY, - DAGNODE_TYPE_KEY, ) from ray.util.annotations import DeveloperAPI @@ -92,36 +91,6 @@ def __getattr__(self, method_name: str): def __str__(self) -> str: return get_dag_node_str(self, str(self._body)) - def get_import_path(self) -> str: - body = self._body.__ray_actor_class__ - return f"{body.__module__}.{body.__qualname__}" - - def to_json(self) -> Dict[str, Any]: - return { - DAGNODE_TYPE_KEY: ClassNode.__name__, - # Will be overriden by build() - "import_path": self.get_import_path(), - "args": self.get_args(), - "kwargs": self.get_kwargs(), - # .options() should not contain any DAGNode type - "options": self.get_options(), - "other_args_to_resolve": self.get_other_args_to_resolve(), - "uuid": self.get_stable_uuid(), - } - - @classmethod - def from_json(cls, input_json, module, object_hook=None): - assert input_json[DAGNODE_TYPE_KEY] == ClassNode.__name__ - node = cls( - module.__ray_metadata__.modified_class, - input_json["args"], - input_json["kwargs"], - input_json["options"], - other_args_to_resolve=input_json["other_args_to_resolve"], - ) - node._stable_uuid = input_json["uuid"] - return node - class _UnboundClassMethodNode(object): def __init__(self, actor: ClassNode, method_name: str): @@ -230,34 +199,3 @@ def __str__(self) -> str: def get_method_name(self) -> str: return self._method_name - - def get_import_path(self) -> str: - body = self._parent_class_node._body.__ray_actor_class__ - return f"{body.__module__}.{body.__qualname__}" - - def to_json(self) -> Dict[str, Any]: - return { - DAGNODE_TYPE_KEY: ClassMethodNode.__name__, - # Will be overriden by build() - "method_name": self.get_method_name(), - "import_path": self.get_import_path(), - "args": self.get_args(), - "kwargs": self.get_kwargs(), - # .options() should not contain any DAGNode type - "options": self.get_options(), - "other_args_to_resolve": self.get_other_args_to_resolve(), - "uuid": self.get_stable_uuid(), - } - - @classmethod - def from_json(cls, input_json): - assert input_json[DAGNODE_TYPE_KEY] == ClassMethodNode.__name__ - node = cls( - input_json["method_name"], - input_json["args"], - input_json["kwargs"], - input_json["options"], - other_args_to_resolve=input_json["other_args_to_resolve"], - ) - node._stable_uuid = input_json["uuid"] - return node diff --git a/python/ray/dag/dag_node.py b/python/ray/dag/dag_node.py index f2f615e85224..387fea1a99c5 100644 --- a/python/ray/dag/dag_node.py +++ b/python/ray/dag/dag_node.py @@ -322,13 +322,11 @@ def _copy( instance._stable_uuid = self._stable_uuid return instance - def __reduce__(self): - """We disallow serialization to prevent inadvertent closure-capture. + def __getstate__(self): + return self.__dict__ - Use ``.to_json()`` and ``.from_json()`` to convert DAGNodes to a - serializable form. - """ - raise ValueError(f"DAGNode cannot be serialized. DAGNode: {str(self)}") + def __setstate__(self, d: Dict[str, Any]): + self.__dict__.update(d) def __getattr__(self, attr: str): if attr == "bind": diff --git a/python/ray/dag/function_node.py b/python/ray/dag/function_node.py index 42dc01bdba11..4565fcffe8ff 100644 --- a/python/ray/dag/function_node.py +++ b/python/ray/dag/function_node.py @@ -4,7 +4,6 @@ import ray from ray.dag.dag_node import DAGNode from ray.dag.format_utils import get_dag_node_str -from ray.dag.constants import DAGNODE_TYPE_KEY from ray.util.annotations import DeveloperAPI @@ -59,32 +58,3 @@ def _execute_impl(self, *args, **kwargs): def __str__(self) -> str: return get_dag_node_str(self, str(self._body)) - - def get_import_path(self): - return f"{self._body.__module__}.{self._body.__qualname__}" - - def to_json(self) -> Dict[str, Any]: - return { - DAGNODE_TYPE_KEY: FunctionNode.__name__, - # Will be overriden by build() - "import_path": self.get_import_path(), - "args": self.get_args(), - "kwargs": self.get_kwargs(), - # .options() should not contain any DAGNode type - "options": self.get_options(), - "other_args_to_resolve": self.get_other_args_to_resolve(), - "uuid": self.get_stable_uuid(), - } - - @classmethod - def from_json(cls, input_json, module): - assert input_json[DAGNODE_TYPE_KEY] == FunctionNode.__name__ - node = cls( - module._function, - input_json["args"], - input_json["kwargs"], - input_json["options"], - other_args_to_resolve=input_json["other_args_to_resolve"], - ) - node._stable_uuid = input_json["uuid"] - return node diff --git a/python/ray/dag/input_node.py b/python/ray/dag/input_node.py index b64a9c16f1b4..aeada6da7e9b 100644 --- a/python/ray/dag/input_node.py +++ b/python/ray/dag/input_node.py @@ -3,7 +3,6 @@ from ray.dag import DAGNode from ray.dag.format_utils import get_dag_node_str from ray.experimental.gradio_utils import type_to_string -from ray.dag.constants import DAGNODE_TYPE_KEY from ray.util.annotations import DeveloperAPI IN_CONTEXT_MANAGER = "__in_context_manager__" @@ -173,20 +172,6 @@ def __enter__(self): def __exit__(self, *args): pass - def to_json(self) -> Dict[str, Any]: - return { - DAGNODE_TYPE_KEY: InputNode.__name__, - "other_args_to_resolve": self.get_other_args_to_resolve(), - "uuid": self.get_stable_uuid(), - } - - @classmethod - def from_json(cls, input_json): - assert input_json[DAGNODE_TYPE_KEY] == InputNode.__name__ - node = cls(_other_args_to_resolve=input_json["other_args_to_resolve"]) - node._stable_uuid = input_json["uuid"] - return node - def get_result_type(self) -> str: """Get type of the output of this DAGNode. @@ -298,25 +283,6 @@ def _execute_impl(self, *args, **kwargs): def __str__(self) -> str: return get_dag_node_str(self, f'["{self._key}"]') - def to_json(self) -> Dict[str, Any]: - return { - DAGNODE_TYPE_KEY: InputAttributeNode.__name__, - "other_args_to_resolve": self.get_other_args_to_resolve(), - "uuid": self.get_stable_uuid(), - } - - @classmethod - def from_json(cls, input_json): - assert input_json[DAGNODE_TYPE_KEY] == InputAttributeNode.__name__ - node = cls( - input_json["other_args_to_resolve"]["dag_input_node"], - input_json["other_args_to_resolve"]["key"], - input_json["other_args_to_resolve"]["accessor_method"], - input_json["other_args_to_resolve"]["result_type_string"], - ) - node._stable_uuid = input_json["uuid"] - return node - def get_result_type(self) -> str: """Get type of the output of this DAGNode. diff --git a/python/ray/serve/_private/constants.py b/python/ray/serve/_private/constants.py index 4feff1f32b98..2f495199ae75 100644 --- a/python/ray/serve/_private/constants.py +++ b/python/ray/serve/_private/constants.py @@ -1,5 +1,4 @@ import os -from enum import Enum #: Used for debugging to turn on DEBUG-level logs DEBUG_LOG_ENV_VAR = "SERVE_DEBUG_LOG" @@ -153,11 +152,6 @@ PUSH_MULTIPLEXED_MODEL_IDS_INTERVAL_S = 1.0 -class ServeHandleType(str, Enum): - SYNC = "SYNC" - ASYNC = "ASYNC" - - # Deprecation message for V1 migrations. MIGRATION_MESSAGE = ( "See https://docs.ray.io/en/latest/serve/index.html for more information." diff --git a/python/ray/serve/_private/deployment_graph_build.py b/python/ray/serve/_private/deployment_graph_build.py index a9d4b6977203..2186e168adf0 100644 --- a/python/ray/serve/_private/deployment_graph_build.py +++ b/python/ray/serve/_private/deployment_graph_build.py @@ -1,8 +1,9 @@ import inspect -import json from typing import List from collections import OrderedDict +from ray import cloudpickle + from ray.serve.deployment import Deployment, schema_to_deployment from ray.serve.deployment_graph import RayServeDAGHandle from ray.serve._private.constants import DEPLOYMENT_NAME_PREFIX_SEPARATOR @@ -16,7 +17,6 @@ from ray.serve._private.deployment_function_executor_node import ( DeploymentFunctionExecutorNode, ) -from ray.serve._private.json_serde import DAGNodeEncoder from ray.serve.handle import RayServeDeploymentHandle from ray.serve.schema import DeploymentSchema @@ -389,8 +389,7 @@ def replace_with_handle(node): DeploymentFunctionExecutorNode, ), ): - serve_dag_root_json = json.dumps(node, cls=DAGNodeEncoder) - return RayServeDAGHandle(serve_dag_root_json) + return RayServeDAGHandle(cloudpickle.dumps(node)) ( replaced_deployment_init_args, diff --git a/python/ray/serve/_private/json_serde.py b/python/ray/serve/_private/json_serde.py deleted file mode 100644 index 72d3f4077193..000000000000 --- a/python/ray/serve/_private/json_serde.py +++ /dev/null @@ -1,169 +0,0 @@ -from typing import Any, Union -from importlib import import_module - -import json - -from ray.dag import ( - DAGNode, - ClassNode, - FunctionNode, - InputNode, - InputAttributeNode, - DAGNODE_TYPE_KEY, -) -from ray.serve._private.deployment_executor_node import DeploymentExecutorNode -from ray.serve._private.deployment_method_executor_node import ( - DeploymentMethodExecutorNode, -) -from ray.serve._private.deployment_function_executor_node import ( - DeploymentFunctionExecutorNode, -) - -from ray.serve.schema import ( - DeploymentSchema, -) -from ray.serve._private.utils import parse_import_path -from ray.serve.handle import ( - HandleOptions, - RayServeHandle, - RayServeDeploymentHandle, - _serve_handle_to_json_dict, - _serve_handle_from_json_dict, -) -from ray.serve._private.constants import SERVE_HANDLE_JSON_KEY -from ray.serve.deployment_graph import RayServeDAGHandle - - -def convert_to_json_safe_obj(obj: Any, *, err_key: str) -> Any: - """Converts the provided object into a JSON-safe version of it. - - The returned object can safely be `json.dumps`'d to a string. - - Uses the Ray Serve encoder to serialize special objects such as - ServeHandles and DAGHandles. - - Raises: TypeError if the object contains fields that cannot be - JSON-serialized. - """ - try: - return json.loads(json.dumps(obj, cls=DAGNodeEncoder)) - except Exception as e: - raise TypeError( - "All provided fields must be JSON-serializable to build the " - f"Serve app. Failed while serializing {err_key}:\n{e}" - ) - - -def convert_from_json_safe_obj(obj: Any, *, err_key: str) -> Any: - """Converts a JSON-safe object to one that contains Serve special types. - - The provided object should have been serialized using - convert_to_json_safe_obj. Any special-cased objects such as ServeHandles - will be recovered on this pass. - """ - try: - return json.loads(json.dumps(obj), object_hook=dagnode_from_json) - except Exception as e: - raise ValueError(f"Failed to convert {err_key} from JSON:\n{e}") - - -class DAGNodeEncoder(json.JSONEncoder): - """ - Custom JSON serializer for DAGNode type that takes care of RayServeHandle - used in deployment init args or kwargs, as well as all other DAGNode types - with potentially deeply nested structure with other DAGNode instances. - - Enforcements: - - All args, kwargs and other_args_to_resolve used in Ray DAG needs to - be JSON serializable in order to be converted and deployed using - Ray Serve. - - All modules such as class or functions need to be visible and - importable on top of its file, and can be resolved via a fully - qualified import_path. - - No DAGNode instance should appear in bound .options(), which should be - JSON serializable with default encoder. - """ - - def default(self, obj): - if isinstance(obj, DeploymentSchema): - return { - DAGNODE_TYPE_KEY: "DeploymentSchema", - # The schema's default values are Python enums that aren't - # JSON-serializable by design. exclude_defaults omits these, - # so the return value can be JSON-serialized. - "schema": obj.dict(exclude_defaults=True), - } - elif isinstance(obj, RayServeHandle): - return _serve_handle_to_json_dict(obj) - elif isinstance(obj, RayServeDAGHandle): - # TODO(simon) Do a proper encoder - return { - DAGNODE_TYPE_KEY: RayServeDAGHandle.__name__, - "dag_node_json": obj.dag_node_json, - } - elif isinstance(obj, RayServeDeploymentHandle): - return { - DAGNODE_TYPE_KEY: RayServeDeploymentHandle.__name__, - "deployment_name": obj.deployment_name, - "handle_options_method_name": obj.handle_options.method_name, - } - # For all other DAGNode types. - elif isinstance(obj, DAGNode): - return obj.to_json() - else: - return json.JSONEncoder.default(self, obj) - - -def dagnode_from_json(input_json: Any) -> Union[DAGNode, RayServeHandle, Any]: - """ - Decode a DAGNode from given input json dictionary. JSON serialization is - only used and enforced in ray serve from ray core API authored DAGNode(s). - - Covers both RayServeHandle and DAGNode types. - - Assumptions: - - User object's JSON dict does not have keys that collide with our - reserved DAGNODE_TYPE_KEY - - RayServeHandle and Deployment can be re-constructed without losing - states needed for their functionality or correctness. - - DAGNode type can be re-constructed with new stable_uuid upon each - deserialization without effective correctness of execution. - - Only exception is ClassNode used as parent of ClassMethodNode - that we perserve the same parent node. - - .options() does not contain any DAGNode type - """ - node_type_to_cls = { - # Ray DAG Inputs - InputNode.__name__: InputNode, - InputAttributeNode.__name__: InputAttributeNode, - # Deployment graph execution nodes - DeploymentExecutorNode.__name__: DeploymentExecutorNode, - DeploymentMethodExecutorNode.__name__: DeploymentMethodExecutorNode, - DeploymentFunctionExecutorNode.__name__: DeploymentFunctionExecutorNode, - } - # Deserialize RayServeHandle type - if SERVE_HANDLE_JSON_KEY in input_json: - return _serve_handle_from_json_dict(input_json) - # Base case for plain objects - elif DAGNODE_TYPE_KEY not in input_json: - return input_json - elif input_json[DAGNODE_TYPE_KEY] == RayServeDAGHandle.__name__: - return RayServeDAGHandle(input_json["dag_node_json"]) - elif input_json[DAGNODE_TYPE_KEY] == "DeploymentSchema": - return DeploymentSchema.parse_obj(input_json["schema"]) - elif input_json[DAGNODE_TYPE_KEY] == RayServeDeploymentHandle.__name__: - return RayServeDeploymentHandle( - input_json["deployment_name"], - HandleOptions(input_json["handle_options_method_name"]), - ) - # Deserialize DAGNode type - elif input_json[DAGNODE_TYPE_KEY] in node_type_to_cls: - return node_type_to_cls[input_json[DAGNODE_TYPE_KEY]].from_json(input_json) - else: - # Class and Function nodes require original module as body. - module_name, attr_name = parse_import_path(input_json["import_path"]) - module = getattr(import_module(module_name), attr_name) - if input_json[DAGNODE_TYPE_KEY] == FunctionNode.__name__: - return FunctionNode.from_json(input_json, module) - elif input_json[DAGNODE_TYPE_KEY] == ClassNode.__name__: - return ClassNode.from_json(input_json, module) diff --git a/python/ray/serve/deployment_graph.py b/python/ray/serve/deployment_graph.py index d1325ce8de61..4a2da419099a 100644 --- a/python/ray/serve/deployment_graph.py +++ b/python/ray/serve/deployment_graph.py @@ -1,6 +1,8 @@ -import json import os + import ray +from ray import cloudpickle +from ray.util.annotations import PublicAPI from ray.dag.class_node import ClassNode # noqa: F401 from ray.dag.function_node import FunctionNode # noqa: F401 @@ -9,7 +11,6 @@ from ray.serve._private.constants import ( SYNC_HANDLE_IN_DAG_FEATURE_FLAG_ENV_KEY, ) -from ray.util.annotations import PublicAPI FLAG_SERVE_DEPLOYMENT_HANDLE_IS_SYNC = ( os.environ.get(SYNC_HANDLE_IN_DAG_FEATURE_FLAG_ENV_KEY, "0") == "1" @@ -24,8 +25,8 @@ class RayServeDAGHandle: orchestrate a deployment graph. """ - def __init__(self, dag_node_json: str) -> None: - self.dag_node_json = dag_node_json + def __init__(self, pickled_dag_node: bytes) -> None: + self.pickled_dag_node = pickled_dag_node # NOTE(simon): Making this lazy to avoid deserialization in controller for now # This would otherwise hang because it's trying to get handles from within @@ -34,22 +35,18 @@ def __init__(self, dag_node_json: str) -> None: @classmethod def _deserialize(cls, *args): - """Required for this class's __reduce__ method to be picklable.""" + """Required for this class's __reduce__ method to be pickleable.""" return cls(*args) def __reduce__(self): - return RayServeDAGHandle._deserialize, (self.dag_node_json,) + return RayServeDAGHandle._deserialize, (self.pickled_dag_node,) async def remote( self, *args, _ray_cache_refs: bool = False, **kwargs ) -> ray.ObjectRef: """Execute the request, returns a ObjectRef representing final result.""" if self.dag_node is None: - from ray.serve._private.json_serde import dagnode_from_json - - self.dag_node = json.loads( - self.dag_node_json, object_hook=dagnode_from_json - ) + self.dag_node = cloudpickle.loads(self.pickled_dag_node) if FLAG_SERVE_DEPLOYMENT_HANDLE_IS_SYNC: return self.dag_node.execute( diff --git a/python/ray/serve/drivers.py b/python/ray/serve/drivers.py index 5c964381031a..71c8781a247c 100644 --- a/python/ray/serve/drivers.py +++ b/python/ray/serve/drivers.py @@ -114,9 +114,9 @@ async def get_intermediate_object_refs(self) -> Dict[str, Any]: return await root_dag_node.get_object_refs_from_last_execute() - async def get_dag_node_json(self) -> str: - """Returns the json serialized root dag node""" - return self.dags[self.MATCH_ALL_ROUTE_PREFIX].dag_node_json + async def get_pickled_dag_node(self) -> bytes: + """Returns the serialized root dag node.""" + return self.dags[self.MATCH_ALL_ROUTE_PREFIX].pickled_dag_node @PublicAPI(stability="alpha") diff --git a/python/ray/serve/experimental/gradio_visualize_graph.py b/python/ray/serve/experimental/gradio_visualize_graph.py index d6bf16b230c7..4b8aac556865 100644 --- a/python/ray/serve/experimental/gradio_visualize_graph.py +++ b/python/ray/serve/experimental/gradio_visualize_graph.py @@ -1,27 +1,28 @@ +from collections import defaultdict +from io import BytesIO +import logging +from pydoc import locate +from typing import Any, Dict, Optional + import ray +from ray import cloudpickle +from ray.experimental.gradio_utils import type_to_string + from ray.dag import ( DAGNode, InputNode, InputAttributeNode, ) +from ray.dag.utils import _DAGNodeNameGenerator +from ray.dag.vis_utils import _dag_to_dot + +from ray.serve.handle import RayServeHandle from ray.serve._private.deployment_function_executor_node import ( DeploymentFunctionExecutorNode, ) from ray.serve._private.deployment_method_executor_node import ( DeploymentMethodExecutorNode, ) -from ray.serve._private.json_serde import dagnode_from_json -from ray.dag.utils import _DAGNodeNameGenerator -from ray.dag.vis_utils import _dag_to_dot -from ray.serve.handle import RayServeHandle -from ray.experimental.gradio_utils import type_to_string - -from typing import Any, Dict, Optional -from collections import defaultdict -import json -import logging -from io import BytesIO -from pydoc import locate logger = logging.getLogger(__name__) @@ -311,9 +312,9 @@ def visualize_with_gradio( self._reset_state() self.handle = driver_handle - # Load the root DAG node from handle - dag_node_json = ray.get(self.handle.get_dag_node_json.remote()) - self.dag = json.loads(dag_node_json, object_hook=dagnode_from_json) + # Load the root DAG node from handle. + pickled_dag_node = ray.get(self.handle.get_pickled_dag_node.remote()) + self.dag = cloudpickle.loads(pickled_dag_node) # Get level for each node in dag uuid_to_depths = defaultdict(lambda: 0) diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index e2c45194b4f3..154821fed85d 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -4,8 +4,8 @@ from functools import wraps import inspect import os -from typing import Coroutine, Dict, Optional, Union import threading +from typing import Coroutine, Optional, Union import ray from ray._private.utils import get_or_create_event_loop @@ -15,9 +15,7 @@ from ray.serve._private.common import EndpointTag from ray.serve._private.constants import ( RAY_SERVE_ENABLE_NEW_ROUTING, - SERVE_HANDLE_JSON_KEY, SYNC_HANDLE_IN_DAG_FEATURE_FLAG_ENV_KEY, - ServeHandleType, ) from ray.serve._private.utils import ( get_random_letters, @@ -418,35 +416,3 @@ def __getattr__(self, name): def __repr__(self): return f"{self.__class__.__name__}" f"(deployment='{self.deployment_name}')" - - -def _serve_handle_to_json_dict(handle: RayServeHandle) -> Dict[str, str]: - """Converts a Serve handle to a JSON-serializable dictionary. - - The dictionary can be converted back to a ServeHandle using - _serve_handle_from_json_dict. - """ - if isinstance(handle, RayServeSyncHandle): - handle_type = ServeHandleType.SYNC - else: - handle_type = ServeHandleType.ASYNC - - return { - SERVE_HANDLE_JSON_KEY: handle_type, - "deployment_name": handle.deployment_name, - } - - -def _serve_handle_from_json_dict(d: Dict[str, str]) -> RayServeHandle: - """Converts a JSON-serializable dictionary back to a ServeHandle. - - The dictionary should be constructed using _serve_handle_to_json_dict. - """ - if SERVE_HANDLE_JSON_KEY not in d: - raise ValueError(f"dict must contain {SERVE_HANDLE_JSON_KEY} key.") - - return serve.context.get_global_client().get_handle( - d["deployment_name"], - sync=d[SERVE_HANDLE_JSON_KEY] == ServeHandleType.SYNC, - missing_ok=True, - ) diff --git a/python/ray/serve/tests/test_json_serde.py b/python/ray/serve/tests/test_json_serde.py deleted file mode 100644 index 16f45dc5bc57..000000000000 --- a/python/ray/serve/tests/test_json_serde.py +++ /dev/null @@ -1,268 +0,0 @@ -import pytest -import json -from typing import TypeVar - -import ray -from ray.dag.dag_node import DAGNode -from ray.dag.input_node import InputNode -from ray import serve -from ray.dag.utils import _DAGNodeNameGenerator -from ray.serve.handle import ( - RayServeSyncHandle, - _serve_handle_to_json_dict, - _serve_handle_from_json_dict, -) -from ray.serve._private.json_serde import ( - DAGNodeEncoder, - dagnode_from_json, -) -from ray.serve.tests.resources.test_modules import ( - Model, - combine, - Counter, - ClassHello, - fn_hello, - Combine, - NESTED_HANDLE_KEY, -) -from ray.serve._private.deployment_graph_build import ( - transform_ray_dag_to_serve_dag, - extract_deployments_from_serve_dag, - transform_serve_dag_to_serve_executor_dag, -) - -RayHandleLike = TypeVar("RayHandleLike") -pytestmark = pytest.mark.asyncio - - -async def test_non_json_serializable_args(): - """Use non-JSON serializable object in Ray DAG and ensure we throw exception - with reasonable error messages. - """ - - class MyNonJSONClass: - def __init__(self, val): - self.val = val - - ray_dag = combine.bind(MyNonJSONClass(1), MyNonJSONClass(2)) - # General context - with pytest.raises( - TypeError, - match=r"Object of type .* is not JSON serializable", - ): - _ = json.dumps(ray_dag, cls=DAGNodeEncoder) - - -async def test_simple_function_node_json_serde(serve_instance): - """ - Test the following behavior - 1) Ray DAG node can go through full JSON serde cycle - 2) Ray DAG node and deserialized DAG node produces same output - 3) Ray DAG node can go through multiple rounds of JSON serde and still - provides the same value as if it's only JSON serde once - Against following test cases - - Simple function with no args - - Simple function with only args, all primitive types - - Simple function with args + kwargs, all primitive types - """ - original_dag_node = combine.bind(1, 2) - await _test_deployment_json_serde_helper( - original_dag_node, - expected_num_deployments=1, - ) - - original_dag_node = combine.bind(1, 2, kwargs_output=3) - await _test_deployment_json_serde_helper( - original_dag_node, - expected_num_deployments=1, - ) - - original_dag_node = fn_hello.bind() - await _test_deployment_json_serde_helper( - original_dag_node, - expected_num_deployments=1, - ) - - -async def test_simple_class_node_json_serde(serve_instance): - """ - Test the following behavior - 1) Ray DAG node can go through full JSON serde cycle - 2) Ray DAG node and deserialized DAG node produces same actor instances - with same method call output - 3) Ray DAG node can go through multiple rounds of JSON serde and still - provides the same value as if it's only JSON serde once - Against following test cases - - Simple class with no args - - Simple class with only args, all primitive types - - Simple class with args + kwargs, all primitive types - - Simple chain of class method calls, all primitive types - """ - hello_actor = ClassHello.bind() - original_dag_node = hello_actor.hello.bind() - await _test_deployment_json_serde_helper( - original_dag_node, - expected_num_deployments=1, - ) - - model_actor = Model.bind(1) - original_dag_node = model_actor.forward.bind(1) - await _test_deployment_json_serde_helper( - original_dag_node, - expected_num_deployments=1, - ) - - model_actor = Model.bind(1, ratio=0.5) - original_dag_node = model_actor.forward.bind(1) - await _test_deployment_json_serde_helper( - original_dag_node, - expected_num_deployments=1, - ) - - -async def _test_deployment_json_serde_helper( - ray_dag: DAGNode, input=None, expected_num_deployments=None -): - """Helper function for DeploymentNode and DeploymentMethodNode calls, checks - the following: - 1) Transform ray dag to serve dag, and ensure serve dag is JSON - serializable. - 2) Serve dag JSON and be deserialized back to serve dag. - 3) Deserialized serve dag can extract correct number and async definition of - serve deployments. - """ - with _DAGNodeNameGenerator() as node_name_generator: - serve_root_dag = ray_dag.apply_recursive( - lambda node: transform_ray_dag_to_serve_dag(node, node_name_generator) - ) - deserialized_deployments = extract_deployments_from_serve_dag(serve_root_dag) - serve_executor_root_dag = serve_root_dag.apply_recursive( - transform_serve_dag_to_serve_executor_dag - ) - json_serialized = json.dumps(serve_executor_root_dag, cls=DAGNodeEncoder) - deserialized_serve_executor_root_dag_node = json.loads( - json_serialized, object_hook=dagnode_from_json - ) - assert len(deserialized_deployments) == expected_num_deployments - # Deploy deserilized version to ensure JSON serde correctness - for model in deserialized_deployments: - model.deploy() - if input is None: - assert ray.get(ray_dag.execute()) == ray.get( - await serve_executor_root_dag.execute() - ) - else: - assert ray.get(ray_dag.execute(input)) == ray.get( - await serve_executor_root_dag.execute(input) - ) - return serve_executor_root_dag, deserialized_serve_executor_root_dag_node - - -async def test_simple_deployment_method_call_chain(serve_instance): - """In Ray Core DAG, we maintain a simple linked list to keep track of - method call lineage on the SAME parent class node with same uuid. However - JSON serialization is only applicable in serve and we convert all - ClassMethodNode to DeploymentMethodNode that acts on deployment handle - that is uniquely identified by its name without dependency of uuid. - """ - counter = Counter.bind(0) - counter.inc.bind(1) - counter.inc.bind(2) - ray_dag = counter.get.bind() - assert ray.get(ray_dag.execute()) == 3 - - # note(simon): Equivalence is not guaranteed here and - # nor should it be a supported workflow. - - # ( - # serve_root_dag, - # deserialized_serve_root_dag_node, - # ) = await _test_deployment_json_serde_helper(ray_dag, expected_num_deployments=1) - # # Deployment to Deployment, possible DeploymentMethodNode call chain - # # Both serve dags uses the same underlying deployments, thus the rhs value - # # went through two execute() - # assert ray.get(serve_root_dag.execute()) + ray.get(ray_dag.execute()) == ray.get( - # deserialized_serve_root_dag_node.execute() - # ) - - -async def test_multi_instantiation_class_nested_deployment_arg(serve_instance): - with InputNode() as dag_input: - m1 = Model.bind(2) - m2 = Model.bind(3) - combine = Combine.bind(m1, m2={NESTED_HANDLE_KEY: m2}, m2_nested=True) - ray_dag = combine.__call__.bind(dag_input) - - ( - serve_root_dag, - deserialized_serve_root_dag_node, - ) = await _test_deployment_json_serde_helper( - ray_dag, input=1, expected_num_deployments=3 - ) - assert ray.get(await serve_root_dag.execute(1)) == ray.get( - await deserialized_serve_root_dag_node.execute(1) - ) - - -async def test_nested_deployment_node_json_serde(serve_instance): - with InputNode() as dag_input: - m1 = Model.bind(2) - m2 = Model.bind(3) - - m1_output = m1.forward.bind(dag_input) - m2_output = m2.forward.bind(dag_input) - - ray_dag = combine.bind(m1_output, m2_output) - ( - serve_root_dag, - deserialized_serve_root_dag_node, - ) = await _test_deployment_json_serde_helper( - ray_dag, input=1, expected_num_deployments=3 - ) - assert ray.get(await serve_root_dag.execute(1)) == ray.get( - await deserialized_serve_root_dag_node.execute(1) - ) - - -def get_handle(sync: bool = True): - @serve.deployment - def echo(inp: str): - return inp - - echo.deploy() - return echo.get_handle(sync=sync) - - -async def call(handle, inp): - if isinstance(handle, RayServeSyncHandle): - ref = handle.remote(inp) - else: - ref = await handle.remote(inp) - - return ray.get(ref) - - -@pytest.mark.asyncio -class TestHandleJSON: - def test_invalid(self, serve_instance): - with pytest.raises(ValueError): - _serve_handle_from_json_dict({"blah": 123}) - - @pytest.mark.parametrize("sync", [False, True]) - async def test_basic(self, serve_instance, sync): - handle = get_handle(sync) - assert await call(handle, "hi") == "hi" - - serialized = json.dumps(_serve_handle_to_json_dict(handle)) - # Check we can go through multiple rounds of serde. - serialized = json.dumps(json.loads(serialized)) - - # Load the handle back from the dict. - handle = _serve_handle_from_json_dict(json.loads(serialized)) - assert await call(handle, "hi") == "hi" - - -if __name__ == "__main__": - import sys - - sys.exit(pytest.main(["-v", __file__])) From 4b3c5e20ad874c74c7b43e53606d288940a6cc83 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Fri, 7 Jul 2023 10:44:52 -0500 Subject: [PATCH 2/5] fix Signed-off-by: Edward Oakes --- python/ray/dag/dag_node.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/ray/dag/dag_node.py b/python/ray/dag/dag_node.py index 387fea1a99c5..6408f92de15f 100644 --- a/python/ray/dag/dag_node.py +++ b/python/ray/dag/dag_node.py @@ -323,9 +323,11 @@ def _copy( return instance def __getstate__(self): + """Required due to overriding `__getattr__` else pickling fails.""" return self.__dict__ def __setstate__(self, d: Dict[str, Any]): + """Required due to overriding `__getattr__` else pickling fails.""" self.__dict__.update(d) def __getattr__(self, attr: str): From 47ca6616ae409a43b06f7cf635bf226ef6576072 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Fri, 7 Jul 2023 11:45:40 -0500 Subject: [PATCH 3/5] fix Signed-off-by: Edward Oakes --- python/ray/serve/BUILD | 8 -------- 1 file changed, 8 deletions(-) diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index 5c58e6df244e..f80c8b213cd3 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -601,14 +601,6 @@ py_test( deps = [":serve_lib"] ) -py_test( - name = "test_json_serde", - size = "medium", - srcs = serve_tests_srcs, - tags = ["exclusive", "team:serve"], - deps = [":serve_lib"], -) - py_test( name = "test_gradio", size = "medium", From 855066ea45a61d114e599b0316bc496c7b1301ff Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Fri, 7 Jul 2023 11:47:24 -0500 Subject: [PATCH 4/5] fix Signed-off-by: Edward Oakes --- python/ray/dag/tests/test_class_dag.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/python/ray/dag/tests/test_class_dag.py b/python/ray/dag/tests/test_class_dag.py index 55cff1e540c4..8bef8c792f9f 100644 --- a/python/ray/dag/tests/test_class_dag.py +++ b/python/ray/dag/tests/test_class_dag.py @@ -1,9 +1,7 @@ import pytest -import pickle import ray from ray.dag import ( - DAGNode, PARENT_CLASS_NODE_KEY, PREV_CLASS_METHOD_CALL_KEY, ) @@ -33,12 +31,6 @@ def get(self): return self.i -def test_serialize_warning(): - node = DAGNode([], {}, {}, {}) - with pytest.raises(ValueError): - pickle.dumps(node) - - def test_basic_actor_dag(shared_ray_instance): @ray.remote def combine(x, y): From 6bbd1de4d51900ebacb3a971830b78a7f34d0543 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Fri, 7 Jul 2023 13:24:36 -0500 Subject: [PATCH 5/5] fix Signed-off-by: Edward Oakes --- python/ray/serve/deployment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/serve/deployment.py b/python/ray/serve/deployment.py index bd776e28afd3..1f93a843bcbd 100644 --- a/python/ray/serve/deployment.py +++ b/python/ray/serve/deployment.py @@ -266,7 +266,7 @@ def bind(self, *args, **kwargs) -> Application: """ copied_self = copy(self) - copied_self._func_or_class = "dummpy.module" + copied_self._func_or_class = "dummy.module" schema_shell = deployment_to_schema(copied_self) if inspect.isfunction(self._func_or_class):