diff --git a/python/ray/serve/_private/controller.py b/python/ray/serve/_private/controller.py index 8eff4c80315a..4aa6906b241f 100644 --- a/python/ray/serve/_private/controller.py +++ b/python/ray/serve/_private/controller.py @@ -226,8 +226,7 @@ def reconfigure_global_logging_config(self, global_logging_config: LoggingConfig self.global_logging_config = global_logging_config self.long_poll_host.notify_changed( - LongPollNamespace.GLOBAL_LOGGING_CONFIG, - global_logging_config, + {LongPollNamespace.GLOBAL_LOGGING_CONFIG: global_logging_config} ) configure_component_logger( component_name="controller", diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index ca0fb2d446c6..562fd62f62c6 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -1448,16 +1448,17 @@ def broadcast_running_replicas_if_changed(self) -> None: return self._long_poll_host.notify_changed( - (LongPollNamespace.RUNNING_REPLICAS, self._id), - running_replica_infos, - ) - # NOTE(zcin): notify changed for Java routers. Since Java only - # supports 1.x API, there is no concept of applications in Java, - # so the key should remain a string describing the deployment - # name. If there are no Java routers, this is a no-op. - self._long_poll_host.notify_changed( - (LongPollNamespace.RUNNING_REPLICAS, self._id.name), - running_replica_infos, + { + (LongPollNamespace.RUNNING_REPLICAS, self._id): running_replica_infos, + # NOTE(zcin): notify changed for Java routers. Since Java only + # supports 1.x API, there is no concept of applications in Java, + # so the key should remain a string describing the deployment + # name. If there are no Java routers, this is a no-op. + ( + LongPollNamespace.RUNNING_REPLICAS, + self._id.name, + ): running_replica_infos, + } ) self._last_broadcasted_running_replica_infos = running_replica_infos self._multiplexed_model_ids_updated = False @@ -1473,8 +1474,7 @@ def broadcast_deployment_config_if_changed(self) -> None: return self._long_poll_host.notify_changed( - (LongPollNamespace.DEPLOYMENT_CONFIG, self._id), - current_deployment_config, + {(LongPollNamespace.DEPLOYMENT_CONFIG, self._id): current_deployment_config} ) self._last_broadcasted_deployment_config = current_deployment_config diff --git a/python/ray/serve/_private/endpoint_state.py b/python/ray/serve/_private/endpoint_state.py index abc4c0615ad6..fd2074fd6669 100644 --- a/python/ray/serve/_private/endpoint_state.py +++ b/python/ray/serve/_private/endpoint_state.py @@ -46,7 +46,7 @@ def _checkpoint(self): def _notify_route_table_changed(self): self._long_poll_host.notify_changed( - LongPollNamespace.ROUTE_TABLE, self._endpoints + {LongPollNamespace.ROUTE_TABLE: self._endpoints} ) def _get_endpoint_for_route(self, route: str) -> Optional[DeploymentID]: diff --git a/python/ray/serve/_private/long_poll.py b/python/ray/serve/_private/long_poll.py index f3538913b76b..d6fb52e72310 100644 --- a/python/ray/serve/_private/long_poll.py +++ b/python/ray/serve/_private/long_poll.py @@ -4,6 +4,7 @@ import random from asyncio.events import AbstractEventLoop from collections import defaultdict +from collections.abc import Mapping from dataclasses import dataclass from enum import Enum, auto from typing import Any, Callable, DefaultDict, Dict, Optional, Set, Tuple, Union @@ -179,12 +180,12 @@ class LongPollHost: The desired use case is to embed this in an Ray actor. Client will be expected to call actor.listen_for_change.remote(...). On the host side, - you can call host.notify_changed(key, object) to update the state and + you can call host.notify_changed({key: object}) to update the state and potentially notify whoever is polling for these values. Internally, we use snapshot_ids for each object to identify client with outdated object and immediately return the result. If the client has the - up-to-date verison, then the listen_for_change call will only return when + up-to-date version, then the listen_for_change call will only return when the object is updated. """ @@ -306,15 +307,15 @@ async def listen_for_change( self._count_send(LongPollState.TIME_OUT) return LongPollState.TIME_OUT else: - updated_object_key: str = async_task_to_watched_keys[done.pop()] - updated_object = { - updated_object_key: UpdatedObject( + updated_objects = {} + for task in done: + updated_object_key = async_task_to_watched_keys[task] + updated_objects[updated_object_key] = UpdatedObject( self.object_snapshots[updated_object_key], self.snapshot_ids[updated_object_key], ) - } - self._count_send(updated_object) - return updated_object + self._count_send(updated_objects) + return updated_objects async def listen_for_change_java( self, @@ -403,21 +404,22 @@ def _listen_result_to_proto_bytes( proto = LongPollResult(**data) return proto.SerializeToString() - def notify_changed( - self, - object_key: KeyType, - updated_object: Any, - ): - try: - self.snapshot_ids[object_key] += 1 - except KeyError: - # Initial snapshot id must be >= 0, so that the long poll client - # can send a negative initial snapshot id to get a fast update. - # They should also be randomized; - # see https://github.com/ray-project/ray/pull/45881#discussion_r1645243485 - self.snapshot_ids[object_key] = random.randint(0, 1_000_000) - self.object_snapshots[object_key] = updated_object - logger.debug(f"LongPollHost: Notify change for key {object_key}.") - - for event in self.notifier_events.pop(object_key, set()): - event.set() + def notify_changed(self, updates: Mapping[KeyType, Any]) -> None: + """ + Update the current snapshot of some objects + and notify any long poll clients. + """ + for object_key, updated_object in updates.items(): + try: + self.snapshot_ids[object_key] += 1 + except KeyError: + # Initial snapshot id must be >= 0, so that the long poll client + # can send a negative initial snapshot id to get a fast update. + # They should also be randomized; see + # https://github.com/ray-project/ray/pull/45881#discussion_r1645243485 + self.snapshot_ids[object_key] = random.randint(0, 1_000_000) + self.object_snapshots[object_key] = updated_object + logger.debug(f"LongPollHost: Notify change for key {object_key}.") + + for event in self.notifier_events.pop(object_key, set()): + event.set() diff --git a/python/ray/serve/tests/test_long_poll.py b/python/ray/serve/tests/test_long_poll.py index 86bf03880e33..2ba31d414e05 100644 --- a/python/ray/serve/tests/test_long_poll.py +++ b/python/ray/serve/tests/test_long_poll.py @@ -38,7 +38,7 @@ def test_notifier_events_cleared_without_update(serve_instance): host = ray.remote(LongPollHost).remote( listen_for_change_request_timeout_s=(0.1, 0.1) ) - ray.get(host.notify_changed.remote("key_1", 999)) + ray.get(host.notify_changed.remote({"key_1": 999})) # Get an initial object snapshot for the key. object_ref = host.listen_for_change.remote({"key_1": -1}) @@ -60,8 +60,8 @@ def test_host_standalone(serve_instance): host = ray.remote(LongPollHost).remote() # Write two values - ray.get(host.notify_changed.remote("key_1", 999)) - ray.get(host.notify_changed.remote("key_2", 999)) + ray.get(host.notify_changed.remote({"key_1": 999})) + ray.get(host.notify_changed.remote({"key_2": 999})) object_ref = host.listen_for_change.remote({"key_1": -1, "key_2": -1}) # We should be able to get the result immediately @@ -77,7 +77,7 @@ def test_host_standalone(serve_instance): assert len(not_done) == 1 # Now update the value, we should immediately get updated value - ray.get(host.notify_changed.remote("key_2", 999)) + ray.get(host.notify_changed.remote({"key_2": 999})) result = ray.get(object_ref) assert len(result) == 1 assert "key_2" in result @@ -88,13 +88,13 @@ def test_long_poll_wait_for_keys(serve_instance): # are set. host = ray.remote(LongPollHost).remote() object_ref = host.listen_for_change.remote({"key_1": -1, "key_2": -1}) - ray.get(host.notify_changed.remote("key_1", 999)) - ray.get(host.notify_changed.remote("key_2", 999)) - # We should be able to get the one of the result immediately + ray.get(host.notify_changed.remote({"key_1": 123, "key_2": 456})) + + # We should be able to get the both results immediately result: Dict[str, UpdatedObject] = ray.get(object_ref) - assert set(result.keys()).issubset({"key_1", "key_2"}) - assert {v.object_snapshot for v in result.values()} == {999} + assert result.keys() == {"key_1", "key_2"} + assert {v.object_snapshot for v in result.values()} == {123, 456} def test_long_poll_restarts(serve_instance): @@ -106,7 +106,7 @@ class RestartableLongPollHost: def __init__(self) -> None: print("actor started") self.host = LongPollHost() - self.host.notify_changed("timer", time.time()) + self.host.notify_changed({"timer": time.time()}) self.should_exit = False async def listen_for_change(self, key_to_ids): @@ -142,8 +142,8 @@ async def test_client_callbacks(serve_instance): host = ray.remote(LongPollHost).remote() # Write two values - ray.get(host.notify_changed.remote("key_1", 100)) - ray.get(host.notify_changed.remote("key_2", 999)) + ray.get(host.notify_changed.remote({"key_1": 100})) + ray.get(host.notify_changed.remote({"key_2": 999})) callback_results = dict() @@ -167,7 +167,7 @@ def key_2_callback(result): timeout=1, ) - ray.get(host.notify_changed.remote("key_2", 1999)) + ray.get(host.notify_changed.remote({"key_2": 1999})) await async_wait_for_condition( lambda: callback_results == {"key_1": 100, "key_2": 999}, @@ -178,7 +178,7 @@ def key_2_callback(result): @pytest.mark.asyncio async def test_client_threadsafe(serve_instance): host = ray.remote(LongPollHost).remote() - ray.get(host.notify_changed.remote("key_1", 100)) + ray.get(host.notify_changed.remote({"key_1": 100})) e = asyncio.Event() @@ -198,7 +198,7 @@ def key_1_callback(_): def test_listen_for_change_java(serve_instance): host = ray.remote(LongPollHost).remote() - ray.get(host.notify_changed.remote("key_1", 999)) + ray.get(host.notify_changed.remote({"key_1": 999})) request_1 = {"keys_to_snapshot_ids": {"key_1": -1}} object_ref = host.listen_for_change_java.remote( LongPollRequest(**request_1).SerializeToString() @@ -211,7 +211,7 @@ def test_listen_for_change_java(serve_instance): endpoints: Dict[DeploymentID, EndpointInfo] = dict() endpoints["deployment_name"] = EndpointInfo(route="/test/xlang/poll") endpoints["deployment_name1"] = EndpointInfo(route="/test/xlang/poll1") - ray.get(host.notify_changed.remote(LongPollNamespace.ROUTE_TABLE, endpoints)) + ray.get(host.notify_changed.remote({LongPollNamespace.ROUTE_TABLE: endpoints})) object_ref_2 = host.listen_for_change_java.remote( LongPollRequest(**request_2).SerializeToString() ) @@ -240,7 +240,7 @@ def test_listen_for_change_java(serve_instance): ] ray.get( host.notify_changed.remote( - (LongPollNamespace.RUNNING_REPLICAS, "deployment_name"), replicas + {(LongPollNamespace.RUNNING_REPLICAS, "deployment_name"): replicas} ) ) object_ref_3 = host.listen_for_change_java.remote( diff --git a/python/ray/serve/tests/test_metrics.py b/python/ray/serve/tests/test_metrics.py index f93e37661394..6b3e674e7907 100644 --- a/python/ray/serve/tests/test_metrics.py +++ b/python/ray/serve/tests/test_metrics.py @@ -1581,7 +1581,7 @@ def test_long_poll_host_sends_counted(serve_instance): ) # Write a value. - ray.get(host.notify_changed.remote("key_1", 999)) + ray.get(host.notify_changed.remote({"key_1": 999})) object_ref = host.listen_for_change.remote({"key_1": -1}) # Check that the result's size is reported. @@ -1595,8 +1595,8 @@ def test_long_poll_host_sends_counted(serve_instance): ) # Write two new values. - ray.get(host.notify_changed.remote("key_1", 1000)) - ray.get(host.notify_changed.remote("key_2", 1000)) + ray.get(host.notify_changed.remote({"key_1": 1000})) + ray.get(host.notify_changed.remote({"key_2": 1000})) object_ref = host.listen_for_change.remote( {"key_1": result_1["key_1"].snapshot_id, "key_2": -1} )