Skip to content

Commit

Permalink
[serve] wrap obj ref in result wrapper in deployment response (ray-pr…
Browse files Browse the repository at this point in the history
…oject#47655)

## Why are these changes needed?

Abstract `ray.ObjectRef` and `ray.ObjectRefGenerator` in a result
wrapper that the deployment response can directly call into.

---------

Signed-off-by: Cindy Zhang <cindyzyx9@gmail.com>
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
  • Loading branch information
zcin authored and ujjawal-khare committed Oct 15, 2024
1 parent 0da5d0b commit 236f3d2
Show file tree
Hide file tree
Showing 9 changed files with 398 additions and 277 deletions.
6 changes: 0 additions & 6 deletions python/ray/serve/_private/default_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
ClusterNodeInfoCache,
DefaultClusterNodeInfoCache,
)
from ray.serve._private.common import RunningReplicaInfo
from ray.serve._private.deployment_scheduler import (
DefaultDeploymentScheduler,
DeploymentScheduler,
)
from ray.serve._private.replica_scheduler.replica_wrapper import ActorReplicaWrapper
from ray.serve._private.utils import get_head_node_id

# NOTE: Please read carefully before changing!
Expand All @@ -37,7 +35,3 @@ def create_deployment_scheduler(
create_placement_group_fn=create_placement_group_fn_override
or ray.util.placement_group,
)


def create_replica_wrapper(replica_info: RunningReplicaInfo):
    """Build the replica wrapper for the given running replica.

    Args:
        replica_info: Description of the running replica to wrap.

    Returns:
        An ``ActorReplicaWrapper`` constructed from ``replica_info``.
    """
    wrapper = ActorReplicaWrapper(replica_info)
    return wrapper
146 changes: 146 additions & 0 deletions python/ray/serve/_private/replica_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import threading
import time
from abc import ABC, abstractmethod
from typing import Callable, Optional, Union

import ray
from ray.serve._private.utils import calculate_remaining_timeout


class ReplicaResult(ABC):
    """Abstract interface for the result of a request sent to a replica.

    Concrete subclasses must implement every method below; each default
    body only raises ``NotImplementedError``.
    """

    @abstractmethod
    def get(self, timeout_s: Optional[float]):
        """Synchronously fetch the result.

        Args:
            timeout_s: Timeout in seconds; ``None`` is accepted by the
                signature (its meaning is defined by the subclass).
        """
        raise NotImplementedError

    @abstractmethod
    async def get_async(self):
        """Asynchronously fetch the result."""
        raise NotImplementedError

    @abstractmethod
    def __next__(self):
        """Synchronously fetch the next item of the result."""
        raise NotImplementedError

    @abstractmethod
    async def __anext__(self):
        """Asynchronously fetch the next item of the result."""
        raise NotImplementedError

    @abstractmethod
    def add_callback(self, callback: Callable):
        """Register ``callback`` on the result."""
        raise NotImplementedError

    @abstractmethod
    def cancel(self):
        """Cancel the underlying request."""
        raise NotImplementedError


class ActorReplicaResult(ReplicaResult):
    """`ReplicaResult` backed by a Ray `ObjectRef` or `ObjectRefGenerator`.

    A non-streaming result may still be constructed from a generator; in that
    case the generator is resolved (once) to its first object ref, which is
    cached in `_obj_ref` and used by `get()` / `get_async()`.
    """

    def __init__(
        self,
        obj_ref_or_gen: Union[ray.ObjectRef, ray.ObjectRefGenerator],
        is_streaming: bool,
    ):
        """Store the ref or generator and the streaming flag.

        Args:
            obj_ref_or_gen: The Ray object ref or object ref generator that
                carries the replica's response.
            is_streaming: Whether the request is a streaming request. When
                False and a generator was passed, the generator is resolved
                to a single object ref on first access.
        """
        self._obj_ref: Optional[ray.ObjectRef] = None
        self._obj_ref_gen: Optional[ray.ObjectRefGenerator] = None
        self._is_streaming: bool = is_streaming
        # Guards the synchronous generator -> ref resolution.
        self._object_ref_or_gen_sync_lock = threading.Lock()
        # Guards the asynchronous resolution. Created lazily (see the
        # `_object_ref_or_gen_asyncio_lock` property) rather than here.
        self._lazy_object_ref_or_gen_asyncio_lock = None

        if isinstance(obj_ref_or_gen, ray.ObjectRefGenerator):
            self._obj_ref_gen = obj_ref_or_gen
        else:
            self._obj_ref = obj_ref_or_gen

    @property
    def _object_ref_or_gen_asyncio_lock(self):
        """Lazily created `asyncio.Lock` guarding the async resolution path."""
        # Local import keeps this block independent of the module's import list.
        import asyncio

        if self._lazy_object_ref_or_gen_asyncio_lock is None:
            self._lazy_object_ref_or_gen_asyncio_lock = asyncio.Lock()

        return self._lazy_object_ref_or_gen_asyncio_lock

    @property
    def obj_ref(self) -> Optional[ray.ObjectRef]:
        """The cached object ref, or None if there is none (yet)."""
        return self._obj_ref

    @property
    def obj_ref_gen(self) -> Optional[ray.ObjectRefGenerator]:
        """The object ref generator, or None if constructed from a plain ref."""
        return self._obj_ref_gen

    def resolve_gen_to_ref_if_necessary_sync(
        self, timeout_s: Optional[float] = None
    ) -> Optional[ray.ObjectRef]:
        """Returns the object ref pointing to the result.

        For a non-streaming result with no cached ref, pulls the next ref from
        the generator and caches it.

        Args:
            timeout_s: Timeout passed to the generator's `_next_sync`.

        Returns:
            The cached object ref (None for a streaming result that has none).

        Raises:
            TimeoutError: If the generator returns a nil ref, which this code
                treats as a timeout.
        """

        # NOTE(edoakes): this section needs to be guarded with a lock and the resulting
        # object ref cached in order to avoid calling `__next__()` to
        # resolve to the underlying object ref more than once.
        # See: https://github.com/ray-project/ray/issues/43879.
        with self._object_ref_or_gen_sync_lock:
            if self._obj_ref is None and not self._is_streaming:
                # Populate _obj_ref
                obj_ref = self._obj_ref_gen._next_sync(timeout_s=timeout_s)

                # Check for timeout
                if obj_ref.is_nil():
                    raise TimeoutError("Timed out resolving to ObjectRef.")

                self._obj_ref = obj_ref

        return self._obj_ref

    async def resolve_gen_to_ref_if_necessary_async(self) -> Optional[ray.ObjectRef]:
        """Returns the object ref pointing to the result.

        Async counterpart of `resolve_gen_to_ref_if_necessary_sync`: for a
        non-streaming result with no cached ref, awaits the generator's next
        ref and caches it.
        """

        # NOTE(edoakes): this section needs to be guarded with a lock and the resulting
        # object ref cached in order to avoid calling `__anext__()` to
        # resolve to the underlying object ref more than once.
        # See: https://github.com/ray-project/ray/issues/43879.
        #
        # This must be an asyncio lock, not the threading lock: a threading.Lock
        # held across `await` deadlocks the event loop as soon as a second
        # coroutine on the same loop tries to acquire it (the blocking acquire
        # stalls the loop thread, so the holder can never resume to release it).
        # NOTE(review): the sync and async paths use separate locks, so
        # concurrent mixed sync/async resolution of the same result is not
        # mutually excluded — confirm callers never mix them.
        async with self._object_ref_or_gen_asyncio_lock:
            if self._obj_ref is None and not self._is_streaming:
                self._obj_ref = await self._obj_ref_gen.__anext__()

        return self._obj_ref

    def get(self, timeout_s: Optional[float]):
        """Resolve to an object ref if needed, then `ray.get` it.

        The time spent resolving the generator is subtracted from `timeout_s`
        before it is passed to `ray.get`.
        """
        assert (
            self._obj_ref is not None or not self._is_streaming
        ), "get() can only be called on a non-streaming ActorReplicaResult"

        start_time_s = time.time()
        self.resolve_gen_to_ref_if_necessary_sync(timeout_s)

        remaining_timeout_s = calculate_remaining_timeout(
            timeout_s=timeout_s,
            start_time_s=start_time_s,
            curr_time_s=time.time(),
        )
        return ray.get(self._obj_ref, timeout=remaining_timeout_s)

    async def get_async(self):
        """Resolve to an object ref if needed, then await it."""
        assert (
            self._obj_ref is not None or not self._is_streaming
        ), "get_async() can only be called on a non-streaming ActorReplicaResult"

        await self.resolve_gen_to_ref_if_necessary_async()
        return await self._obj_ref

    def __next__(self):
        """Pull the next object ref from the generator and `ray.get` it."""
        assert self._obj_ref_gen is not None, (
            "next() can only be called on an ActorReplicaResult initialized with a "
            "ray.ObjectRefGenerator"
        )

        next_obj_ref = self._obj_ref_gen.__next__()
        return ray.get(next_obj_ref)

    async def __anext__(self):
        """Await the next object ref from the generator, then await its value."""
        assert self._obj_ref_gen is not None, (
            "anext() can only be called on an ActorReplicaResult initialized with a "
            "ray.ObjectRefGenerator"
        )

        next_obj_ref = await self._obj_ref_gen.__anext__()
        return await next_obj_ref

    def add_callback(self, callback: Callable):
        """Register `callback` via `_on_completed`.

        When a generator is present the callback is attached to the ref
        returned by the generator's `completed()`; otherwise it is attached
        to the object ref directly.
        """
        if self._obj_ref_gen is not None:
            self._obj_ref_gen.completed()._on_completed(callback)
        else:
            self._obj_ref._on_completed(callback)

    def cancel(self):
        """Call `ray.cancel` on the generator if present, else on the object ref."""
        if self._obj_ref_gen is not None:
            ray.cancel(self._obj_ref_gen)
        else:
            ray.cancel(self._obj_ref)
109 changes: 53 additions & 56 deletions python/ray/serve/_private/replica_scheduler/replica_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import pickle
from abc import ABC
from abc import ABC, abstractmethod
from typing import Optional, Set, Tuple, Union

import ray
Expand All @@ -11,6 +11,7 @@
ReplicaQueueLengthInfo,
RunningReplicaInfo,
)
from ray.serve._private.replica_result import ActorReplicaResult, ReplicaResult
from ray.serve._private.replica_scheduler.common import PendingRequest
from ray.serve._private.utils import JavaActorHandleProxy
from ray.serve.generated.serve_pb2 import RequestMetadata as RequestMetadataProto
Expand All @@ -22,54 +23,6 @@ class ReplicaWrapper(ABC):
This is used to abstract away details of Ray actor calls for testing.
"""

@property
def replica_id(self) -> ReplicaID:
"""ID of this replica."""
pass

@property
def multiplexed_model_ids(self) -> Set[str]:
"""Set of model IDs on this replica."""
pass

@property
def max_ongoing_requests(self) -> int:
"""Max concurrent requests that can be sent to this replica."""
pass

def push_proxy_handle(self, handle: ActorHandle):
"""When on proxy, push proxy's self handle to replica"""
pass

async def get_queue_len(self, *, deadline_s: float) -> int:
"""Returns current queue len for the replica.
`deadline_s` is passed to verify backoff for testing.
"""
pass

def send_request(self, pr: PendingRequest) -> Union[ObjectRef, ObjectRefGenerator]:
"""Send request to this replica."""
pass

async def send_request_with_rejection(
self,
pr: PendingRequest,
) -> Tuple[Optional[ObjectRefGenerator], ReplicaQueueLengthInfo]:
"""Send request to this replica.
The replica will yield a system message (ReplicaQueueLengthInfo) before
executing the actual request. This can cause it to reject the request.
The result will *always* be a generator, so for non-streaming requests it's up
to the caller to resolve it to its first (and only) ObjectRef.
Only supported for Python replicas.
"""
pass


class ActorReplicaWrapper:
def __init__(self, replica_info: RunningReplicaInfo):
self._replica_info = replica_info
self._multiplexed_model_ids = set(replica_info.multiplexed_model_ids)
Expand All @@ -81,6 +34,7 @@ def __init__(self, replica_info: RunningReplicaInfo):

@property
def replica_id(self) -> ReplicaID:
"""ID of this replica."""
return self._replica_info.replica_id

@property
Expand All @@ -93,19 +47,53 @@ def availability_zone(self) -> Optional[str]:

@property
def multiplexed_model_ids(self) -> Set[str]:
"""Set of model IDs on this replica."""
return self._multiplexed_model_ids

@property
def max_ongoing_requests(self) -> int:
"""Max concurrent requests that can be sent to this replica."""
return self._replica_info.max_ongoing_requests

@property
def is_cross_language(self) -> bool:
return self._replica_info.is_cross_language

def push_proxy_handle(self, handle: ActorHandle):
"""When on proxy, push proxy's self handle to replica"""
self._actor_handle.push_proxy_handle.remote(handle)

@abstractmethod
async def get_queue_len(self, *, deadline_s: float) -> int:
"""Returns current queue len for the replica.
`deadline_s` is passed to verify backoff for testing.
"""
raise NotImplementedError

@abstractmethod
def send_request(self, pr: PendingRequest) -> ReplicaResult:
"""Send request to this replica."""
raise NotImplementedError

@abstractmethod
async def send_request_with_rejection(
self, pr: PendingRequest
) -> Tuple[Optional[ReplicaResult], ReplicaQueueLengthInfo]:
"""Send request to this replica.
The replica will yield a system message (ReplicaQueueLengthInfo) before
executing the actual request. This can cause it to reject the request.
The result will *always* be a generator, so for non-streaming requests it's up
to the caller to resolve it to its first (and only) ObjectRef.
Only supported for Python replicas.
"""
raise NotImplementedError


class ActorReplicaWrapper(ReplicaWrapper):
async def get_queue_len(self, *, deadline_s: float) -> int:
# NOTE(edoakes): the `get_num_ongoing_requests` method name is shared by
# the Python and Java replica implementations. If you change it, you need to
Expand Down Expand Up @@ -160,16 +148,20 @@ def _send_request_python(

return method.remote(pickle.dumps(pr.metadata), *pr.args, **pr.kwargs)

def send_request(self, pr: PendingRequest) -> Union[ObjectRef, ObjectRefGenerator]:
def send_request(self, pr: PendingRequest) -> ReplicaResult:
if self._replica_info.is_cross_language:
return self._send_request_java(pr)
return ActorReplicaResult(
self._send_request_java(pr), is_streaming=pr.metadata.is_streaming
)
else:
return self._send_request_python(pr, with_rejection=False)
return ActorReplicaResult(
self._send_request_python(pr, with_rejection=False),
is_streaming=pr.metadata.is_streaming,
)

async def send_request_with_rejection(
self,
pr: PendingRequest,
) -> Tuple[Optional[ObjectRefGenerator], ReplicaQueueLengthInfo]:
self, pr: PendingRequest
) -> Tuple[Optional[ReplicaResult], ReplicaQueueLengthInfo]:
assert (
not self._replica_info.is_cross_language
), "Request rejection not supported for Java."
Expand All @@ -182,7 +174,12 @@ async def send_request_with_rejection(
if not queue_len_info.accepted:
return None, queue_len_info
else:
return obj_ref_gen, queue_len_info
return (
ActorReplicaResult(
obj_ref_gen, is_streaming=pr.metadata.is_streaming
),
queue_len_info,
)
except asyncio.CancelledError as e:
# HTTP client disconnected or request was explicitly canceled.
ray.cancel(obj_ref_gen)
Expand Down
Loading

0 comments on commit 236f3d2

Please sign in to comment.