Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] Handle multiple changed objects per LongPollHost.listen_for_change RPC #48803

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions python/ray/serve/_private/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ def reconfigure_global_logging_config(self, global_logging_config: LoggingConfig
self.global_logging_config = global_logging_config

self.long_poll_host.notify_changed(
LongPollNamespace.GLOBAL_LOGGING_CONFIG,
global_logging_config,
{LongPollNamespace.GLOBAL_LOGGING_CONFIG: global_logging_config}
)
configure_component_logger(
component_name="controller",
Expand Down
24 changes: 12 additions & 12 deletions python/ray/serve/_private/deployment_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,16 +1448,17 @@ def broadcast_running_replicas_if_changed(self) -> None:
return

self._long_poll_host.notify_changed(
(LongPollNamespace.RUNNING_REPLICAS, self._id),
running_replica_infos,
)
# NOTE(zcin): notify changed for Java routers. Since Java only
# supports 1.x API, there is no concept of applications in Java,
# so the key should remain a string describing the deployment
# name. If there are no Java routers, this is a no-op.
self._long_poll_host.notify_changed(
(LongPollNamespace.RUNNING_REPLICAS, self._id.name),
running_replica_infos,
{
(LongPollNamespace.RUNNING_REPLICAS, self._id): running_replica_infos,
# NOTE(zcin): notify changed for Java routers. Since Java only
# supports 1.x API, there is no concept of applications in Java,
# so the key should remain a string describing the deployment
# name. If there are no Java routers, this is a no-op.
(
LongPollNamespace.RUNNING_REPLICAS,
self._id.name,
): running_replica_infos,
}
)
self._last_broadcasted_running_replica_infos = running_replica_infos
self._multiplexed_model_ids_updated = False
Expand All @@ -1473,8 +1474,7 @@ def broadcast_deployment_config_if_changed(self) -> None:
return

self._long_poll_host.notify_changed(
(LongPollNamespace.DEPLOYMENT_CONFIG, self._id),
current_deployment_config,
{(LongPollNamespace.DEPLOYMENT_CONFIG, self._id): current_deployment_config}
)

self._last_broadcasted_deployment_config = current_deployment_config
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/_private/endpoint_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _checkpoint(self):

def _notify_route_table_changed(self):
self._long_poll_host.notify_changed(
LongPollNamespace.ROUTE_TABLE, self._endpoints
{LongPollNamespace.ROUTE_TABLE: self._endpoints}
)

def _get_endpoint_for_route(self, route: str) -> Optional[DeploymentID]:
Expand Down
54 changes: 28 additions & 26 deletions python/ray/serve/_private/long_poll.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import random
from asyncio.events import AbstractEventLoop
from collections import defaultdict
from collections.abc import Mapping
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Callable, DefaultDict, Dict, Optional, Set, Tuple, Union
Expand Down Expand Up @@ -179,12 +180,12 @@ class LongPollHost:

The desired use case is to embed this in an Ray actor. Client will be
expected to call actor.listen_for_change.remote(...). On the host side,
you can call host.notify_changed(key, object) to update the state and
you can call host.notify_changed({key: object}) to update the state and
potentially notify whoever is polling for these values.

Internally, we use snapshot_ids for each object to identify client with
outdated object and immediately return the result. If the client has the
up-to-date verison, then the listen_for_change call will only return when
up-to-date version, then the listen_for_change call will only return when
the object is updated.
"""

Expand Down Expand Up @@ -306,15 +307,15 @@ async def listen_for_change(
self._count_send(LongPollState.TIME_OUT)
return LongPollState.TIME_OUT
else:
updated_object_key: str = async_task_to_watched_keys[done.pop()]
updated_object = {
updated_object_key: UpdatedObject(
updated_objects = {}
for task in done:
updated_object_key = async_task_to_watched_keys[task]
updated_objects[updated_object_key] = UpdatedObject(
self.object_snapshots[updated_object_key],
self.snapshot_ids[updated_object_key],
)
}
self._count_send(updated_object)
return updated_object
self._count_send(updated_objects)
return updated_objects

async def listen_for_change_java(
self,
Expand Down Expand Up @@ -403,21 +404,22 @@ def _listen_result_to_proto_bytes(
proto = LongPollResult(**data)
return proto.SerializeToString()

def notify_changed(
self,
object_key: KeyType,
updated_object: Any,
):
try:
self.snapshot_ids[object_key] += 1
except KeyError:
# Initial snapshot id must be >= 0, so that the long poll client
# can send a negative initial snapshot id to get a fast update.
# They should also be randomized;
# see https://github.com/ray-project/ray/pull/45881#discussion_r1645243485
self.snapshot_ids[object_key] = random.randint(0, 1_000_000)
self.object_snapshots[object_key] = updated_object
logger.debug(f"LongPollHost: Notify change for key {object_key}.")

for event in self.notifier_events.pop(object_key, set()):
event.set()
def notify_changed(self, updates: Mapping[KeyType, Any]) -> None:
"""
Update the current snapshot of some objects
and notify any long poll clients.
"""
for object_key, updated_object in updates.items():
try:
self.snapshot_ids[object_key] += 1
except KeyError:
# Initial snapshot id must be >= 0, so that the long poll client
# can send a negative initial snapshot id to get a fast update.
# They should also be randomized; see
# https://github.com/ray-project/ray/pull/45881#discussion_r1645243485
self.snapshot_ids[object_key] = random.randint(0, 1_000_000)
self.object_snapshots[object_key] = updated_object
logger.debug(f"LongPollHost: Notify change for key {object_key}.")

for event in self.notifier_events.pop(object_key, set()):
event.set()
34 changes: 17 additions & 17 deletions python/ray/serve/tests/test_long_poll.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_notifier_events_cleared_without_update(serve_instance):
host = ray.remote(LongPollHost).remote(
listen_for_change_request_timeout_s=(0.1, 0.1)
)
ray.get(host.notify_changed.remote("key_1", 999))
ray.get(host.notify_changed.remote({"key_1": 999}))

# Get an initial object snapshot for the key.
object_ref = host.listen_for_change.remote({"key_1": -1})
Expand All @@ -60,8 +60,8 @@ def test_host_standalone(serve_instance):
host = ray.remote(LongPollHost).remote()

# Write two values
ray.get(host.notify_changed.remote("key_1", 999))
ray.get(host.notify_changed.remote("key_2", 999))
ray.get(host.notify_changed.remote({"key_1": 999}))
ray.get(host.notify_changed.remote({"key_2": 999}))
object_ref = host.listen_for_change.remote({"key_1": -1, "key_2": -1})

# We should be able to get the result immediately
Expand All @@ -77,7 +77,7 @@ def test_host_standalone(serve_instance):
assert len(not_done) == 1

# Now update the value, we should immediately get updated value
ray.get(host.notify_changed.remote("key_2", 999))
ray.get(host.notify_changed.remote({"key_2": 999}))
result = ray.get(object_ref)
assert len(result) == 1
assert "key_2" in result
Expand All @@ -88,13 +88,13 @@ def test_long_poll_wait_for_keys(serve_instance):
# are set.
host = ray.remote(LongPollHost).remote()
object_ref = host.listen_for_change.remote({"key_1": -1, "key_2": -1})
ray.get(host.notify_changed.remote("key_1", 999))
ray.get(host.notify_changed.remote("key_2", 999))

# We should be able to get the one of the result immediately
ray.get(host.notify_changed.remote({"key_1": 123, "key_2": 456}))

# We should be able to get the both results immediately
result: Dict[str, UpdatedObject] = ray.get(object_ref)
assert set(result.keys()).issubset({"key_1", "key_2"})
assert {v.object_snapshot for v in result.values()} == {999}
assert result.keys() == {"key_1", "key_2"}
assert {v.object_snapshot for v in result.values()} == {123, 456}


def test_long_poll_restarts(serve_instance):
Expand All @@ -106,7 +106,7 @@ class RestartableLongPollHost:
def __init__(self) -> None:
print("actor started")
self.host = LongPollHost()
self.host.notify_changed("timer", time.time())
self.host.notify_changed({"timer": time.time()})
self.should_exit = False

async def listen_for_change(self, key_to_ids):
Expand Down Expand Up @@ -142,8 +142,8 @@ async def test_client_callbacks(serve_instance):
host = ray.remote(LongPollHost).remote()

# Write two values
ray.get(host.notify_changed.remote("key_1", 100))
ray.get(host.notify_changed.remote("key_2", 999))
ray.get(host.notify_changed.remote({"key_1": 100}))
ray.get(host.notify_changed.remote({"key_2": 999}))

callback_results = dict()

Expand All @@ -167,7 +167,7 @@ def key_2_callback(result):
timeout=1,
)

ray.get(host.notify_changed.remote("key_2", 1999))
ray.get(host.notify_changed.remote({"key_2": 1999}))

await async_wait_for_condition(
lambda: callback_results == {"key_1": 100, "key_2": 999},
Expand All @@ -178,7 +178,7 @@ def key_2_callback(result):
@pytest.mark.asyncio
async def test_client_threadsafe(serve_instance):
host = ray.remote(LongPollHost).remote()
ray.get(host.notify_changed.remote("key_1", 100))
ray.get(host.notify_changed.remote({"key_1": 100}))

e = asyncio.Event()

Expand All @@ -198,7 +198,7 @@ def key_1_callback(_):

def test_listen_for_change_java(serve_instance):
host = ray.remote(LongPollHost).remote()
ray.get(host.notify_changed.remote("key_1", 999))
ray.get(host.notify_changed.remote({"key_1": 999}))
request_1 = {"keys_to_snapshot_ids": {"key_1": -1}}
object_ref = host.listen_for_change_java.remote(
LongPollRequest(**request_1).SerializeToString()
Expand All @@ -211,7 +211,7 @@ def test_listen_for_change_java(serve_instance):
endpoints: Dict[DeploymentID, EndpointInfo] = dict()
endpoints["deployment_name"] = EndpointInfo(route="/test/xlang/poll")
endpoints["deployment_name1"] = EndpointInfo(route="/test/xlang/poll1")
ray.get(host.notify_changed.remote(LongPollNamespace.ROUTE_TABLE, endpoints))
ray.get(host.notify_changed.remote({LongPollNamespace.ROUTE_TABLE: endpoints}))
object_ref_2 = host.listen_for_change_java.remote(
LongPollRequest(**request_2).SerializeToString()
)
Expand Down Expand Up @@ -240,7 +240,7 @@ def test_listen_for_change_java(serve_instance):
]
ray.get(
host.notify_changed.remote(
(LongPollNamespace.RUNNING_REPLICAS, "deployment_name"), replicas
{(LongPollNamespace.RUNNING_REPLICAS, "deployment_name"): replicas}
)
)
object_ref_3 = host.listen_for_change_java.remote(
Expand Down
6 changes: 3 additions & 3 deletions python/ray/serve/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1581,7 +1581,7 @@ def test_long_poll_host_sends_counted(serve_instance):
)

# Write a value.
ray.get(host.notify_changed.remote("key_1", 999))
ray.get(host.notify_changed.remote({"key_1": 999}))
object_ref = host.listen_for_change.remote({"key_1": -1})

# Check that the result's size is reported.
Expand All @@ -1595,8 +1595,8 @@ def test_long_poll_host_sends_counted(serve_instance):
)

# Write two new values.
ray.get(host.notify_changed.remote("key_1", 1000))
ray.get(host.notify_changed.remote("key_2", 1000))
ray.get(host.notify_changed.remote({"key_1": 1000}))
ray.get(host.notify_changed.remote({"key_2": 1000}))
object_ref = host.listen_for_change.remote(
{"key_1": result_1["key_1"].snapshot_id, "key_2": -1}
)
Expand Down
Loading