Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[serve] Remove JSON ser/de logic from DAG building #37198

Merged
merged 7 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 0 additions & 62 deletions python/ray/dag/class_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ray.dag.constants import (
PARENT_CLASS_NODE_KEY,
PREV_CLASS_METHOD_CALL_KEY,
DAGNODE_TYPE_KEY,
)
from ray.util.annotations import DeveloperAPI

Expand Down Expand Up @@ -92,36 +91,6 @@ def __getattr__(self, method_name: str):
def __str__(self) -> str:
return get_dag_node_str(self, str(self._body))

def get_import_path(self) -> str:
body = self._body.__ray_actor_class__
return f"{body.__module__}.{body.__qualname__}"

def to_json(self) -> Dict[str, Any]:
return {
DAGNODE_TYPE_KEY: ClassNode.__name__,
# Will be overriden by build()
"import_path": self.get_import_path(),
"args": self.get_args(),
"kwargs": self.get_kwargs(),
# .options() should not contain any DAGNode type
"options": self.get_options(),
"other_args_to_resolve": self.get_other_args_to_resolve(),
"uuid": self.get_stable_uuid(),
}

@classmethod
def from_json(cls, input_json, module, object_hook=None):
assert input_json[DAGNODE_TYPE_KEY] == ClassNode.__name__
node = cls(
module.__ray_metadata__.modified_class,
input_json["args"],
input_json["kwargs"],
input_json["options"],
other_args_to_resolve=input_json["other_args_to_resolve"],
)
node._stable_uuid = input_json["uuid"]
return node


class _UnboundClassMethodNode(object):
def __init__(self, actor: ClassNode, method_name: str):
Expand Down Expand Up @@ -230,34 +199,3 @@ def __str__(self) -> str:

def get_method_name(self) -> str:
return self._method_name

def get_import_path(self) -> str:
body = self._parent_class_node._body.__ray_actor_class__
return f"{body.__module__}.{body.__qualname__}"

def to_json(self) -> Dict[str, Any]:
return {
DAGNODE_TYPE_KEY: ClassMethodNode.__name__,
# Will be overriden by build()
"method_name": self.get_method_name(),
"import_path": self.get_import_path(),
"args": self.get_args(),
"kwargs": self.get_kwargs(),
# .options() should not contain any DAGNode type
"options": self.get_options(),
"other_args_to_resolve": self.get_other_args_to_resolve(),
"uuid": self.get_stable_uuid(),
}

@classmethod
def from_json(cls, input_json):
assert input_json[DAGNODE_TYPE_KEY] == ClassMethodNode.__name__
node = cls(
input_json["method_name"],
input_json["args"],
input_json["kwargs"],
input_json["options"],
other_args_to_resolve=input_json["other_args_to_resolve"],
)
node._stable_uuid = input_json["uuid"]
return node
12 changes: 6 additions & 6 deletions python/ray/dag/dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,13 +322,13 @@ def _copy(
instance._stable_uuid = self._stable_uuid
return instance

def __reduce__(self):
"""We disallow serialization to prevent inadvertent closure-capture.
def __getstate__(self):
"""Required due to overriding `__getattr__` else pickling fails."""
return self.__dict__

Use ``.to_json()`` and ``.from_json()`` to convert DAGNodes to a
serializable form.
"""
raise ValueError(f"DAGNode cannot be serialized. DAGNode: {str(self)}")
def __setstate__(self, d: Dict[str, Any]):
"""Required due to overriding `__getattr__` else pickling fails."""
self.__dict__.update(d)

def __getattr__(self, attr: str):
if attr == "bind":
Expand Down
30 changes: 0 additions & 30 deletions python/ray/dag/function_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import ray
from ray.dag.dag_node import DAGNode
from ray.dag.format_utils import get_dag_node_str
from ray.dag.constants import DAGNODE_TYPE_KEY
from ray.util.annotations import DeveloperAPI


Expand Down Expand Up @@ -59,32 +58,3 @@ def _execute_impl(self, *args, **kwargs):

def __str__(self) -> str:
return get_dag_node_str(self, str(self._body))

def get_import_path(self):
return f"{self._body.__module__}.{self._body.__qualname__}"

def to_json(self) -> Dict[str, Any]:
return {
DAGNODE_TYPE_KEY: FunctionNode.__name__,
# Will be overriden by build()
"import_path": self.get_import_path(),
"args": self.get_args(),
"kwargs": self.get_kwargs(),
# .options() should not contain any DAGNode type
"options": self.get_options(),
"other_args_to_resolve": self.get_other_args_to_resolve(),
"uuid": self.get_stable_uuid(),
}

@classmethod
def from_json(cls, input_json, module):
assert input_json[DAGNODE_TYPE_KEY] == FunctionNode.__name__
node = cls(
module._function,
input_json["args"],
input_json["kwargs"],
input_json["options"],
other_args_to_resolve=input_json["other_args_to_resolve"],
)
node._stable_uuid = input_json["uuid"]
return node
34 changes: 0 additions & 34 deletions python/ray/dag/input_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from ray.dag import DAGNode
from ray.dag.format_utils import get_dag_node_str
from ray.experimental.gradio_utils import type_to_string
from ray.dag.constants import DAGNODE_TYPE_KEY
from ray.util.annotations import DeveloperAPI

IN_CONTEXT_MANAGER = "__in_context_manager__"
Expand Down Expand Up @@ -173,20 +172,6 @@ def __enter__(self):
def __exit__(self, *args):
pass

def to_json(self) -> Dict[str, Any]:
return {
DAGNODE_TYPE_KEY: InputNode.__name__,
"other_args_to_resolve": self.get_other_args_to_resolve(),
"uuid": self.get_stable_uuid(),
}

@classmethod
def from_json(cls, input_json):
assert input_json[DAGNODE_TYPE_KEY] == InputNode.__name__
node = cls(_other_args_to_resolve=input_json["other_args_to_resolve"])
node._stable_uuid = input_json["uuid"]
return node

def get_result_type(self) -> str:
"""Get type of the output of this DAGNode.

Expand Down Expand Up @@ -298,25 +283,6 @@ def _execute_impl(self, *args, **kwargs):
def __str__(self) -> str:
return get_dag_node_str(self, f'["{self._key}"]')

def to_json(self) -> Dict[str, Any]:
return {
DAGNODE_TYPE_KEY: InputAttributeNode.__name__,
"other_args_to_resolve": self.get_other_args_to_resolve(),
"uuid": self.get_stable_uuid(),
}

@classmethod
def from_json(cls, input_json):
assert input_json[DAGNODE_TYPE_KEY] == InputAttributeNode.__name__
node = cls(
input_json["other_args_to_resolve"]["dag_input_node"],
input_json["other_args_to_resolve"]["key"],
input_json["other_args_to_resolve"]["accessor_method"],
input_json["other_args_to_resolve"]["result_type_string"],
)
node._stable_uuid = input_json["uuid"]
return node

def get_result_type(self) -> str:
"""Get type of the output of this DAGNode.

Expand Down
8 changes: 0 additions & 8 deletions python/ray/dag/tests/test_class_dag.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import pytest
import pickle

import ray
from ray.dag import (
DAGNode,
PARENT_CLASS_NODE_KEY,
PREV_CLASS_METHOD_CALL_KEY,
)
Expand Down Expand Up @@ -33,12 +31,6 @@ def get(self):
return self.i


def test_serialize_warning():
node = DAGNode([], {}, {}, {})
with pytest.raises(ValueError):
pickle.dumps(node)


def test_basic_actor_dag(shared_ray_instance):
@ray.remote
def combine(x, y):
Expand Down
8 changes: 0 additions & 8 deletions python/ray/serve/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -610,14 +610,6 @@ py_test(
deps = [":serve_lib"]
)

py_test(
name = "test_json_serde",
size = "medium",
srcs = serve_tests_srcs,
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
)

py_test(
name = "test_gradio",
size = "medium",
Expand Down
6 changes: 0 additions & 6 deletions python/ray/serve/_private/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from enum import Enum

#: Used for debugging to turn on DEBUG-level logs
DEBUG_LOG_ENV_VAR = "SERVE_DEBUG_LOG"
Expand Down Expand Up @@ -150,11 +149,6 @@
PUSH_MULTIPLEXED_MODEL_IDS_INTERVAL_S = 1.0


class ServeHandleType(str, Enum):
SYNC = "SYNC"
ASYNC = "ASYNC"


# Deprecation message for V1 migrations.
MIGRATION_MESSAGE = (
"See https://docs.ray.io/en/latest/serve/index.html for more information."
Expand Down
7 changes: 3 additions & 4 deletions python/ray/serve/_private/deployment_graph_build.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import inspect
import json
from typing import List
from collections import OrderedDict

from ray import cloudpickle

from ray.serve.deployment import Deployment, schema_to_deployment
from ray.serve.deployment_graph import RayServeDAGHandle
from ray.serve._private.constants import DEPLOYMENT_NAME_PREFIX_SEPARATOR
Expand All @@ -16,7 +17,6 @@
from ray.serve._private.deployment_function_executor_node import (
DeploymentFunctionExecutorNode,
)
from ray.serve._private.json_serde import DAGNodeEncoder
from ray.serve.handle import RayServeDeploymentHandle
from ray.serve.schema import DeploymentSchema

Expand Down Expand Up @@ -389,8 +389,7 @@ def replace_with_handle(node):
DeploymentFunctionExecutorNode,
),
):
serve_dag_root_json = json.dumps(node, cls=DAGNodeEncoder)
return RayServeDAGHandle(serve_dag_root_json)
return RayServeDAGHandle(cloudpickle.dumps(node))

(
replaced_deployment_init_args,
Expand Down
Loading