diff --git a/python/ray/serve/_private/api.py b/python/ray/serve/_private/api.py index 5b8147f93893..db04df80e7d6 100644 --- a/python/ray/serve/_private/api.py +++ b/python/ray/serve/_private/api.py @@ -146,9 +146,6 @@ def _start_controller( else: controller_name = format_actor_name(get_random_letters(), SERVE_CONTROLLER_NAME) - # Used for scheduling things to the head node explicitly. - # Assumes that `serve.start` runs on the head node. - head_node_id = ray.get_runtime_context().get_node_id() controller_actor_options = { "num_cpus": 1 if dedicated_cpu else 0, "name": controller_name, @@ -164,7 +161,6 @@ def _start_controller( controller = ServeController.options(**controller_actor_options).remote( controller_name, http_config=http_options, - head_node_id=head_node_id, detached=detached, _disable_http_proxy=True, ) @@ -186,7 +182,6 @@ def _start_controller( controller = ServeController.options(**controller_actor_options).remote( controller_name, http_config=http_options, - head_node_id=head_node_id, detached=detached, ) diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index 29e0c8811fe2..0f4b6c9cd27a 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -1239,7 +1239,11 @@ def get_active_node_ids(self) -> Set[str]: ReplicaState.RECOVERING, ReplicaState.RUNNING, ] - return {replica.actor_node_id for replica in self._replicas.get(active_states)} + return { + replica.actor_node_id + for replica in self._replicas.get(active_states) + if replica.actor_node_id is not None + } def list_replica_details(self) -> List[ReplicaDetails]: return [replica.actor_details for replica in self._replicas.get()] diff --git a/python/ray/serve/_private/utils.py b/python/ray/serve/_private/utils.py index fad624f9ba33..e4b58f53c596 100644 --- a/python/ray/serve/_private/utils.py +++ b/python/ray/serve/_private/utils.py @@ -42,6 +42,7 @@ from ray._raylet import 
MessagePackSerializer from ray._private.utils import import_attr from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag +from ray._private.resource_spec import HEAD_NODE_RESOURCE_NAME import __main__ @@ -672,3 +673,19 @@ def call_function_from_import_path(import_path: str) -> Any: return callback_func() except Exception as e: raise RuntimeError(f"The function {import_path} raised an exception: {e}") + + +def get_head_node_id() -> str: + """Get the head node id. + + Iterate through all nodes in the ray cluster and return the node id of the first + alive node with head node resource. + """ + head_node_id = None + for node in ray.nodes(): + if HEAD_NODE_RESOURCE_NAME in node["Resources"] and node["Alive"]: + head_node_id = node["NodeID"] + break + assert head_node_id is not None, "Cannot find alive head node." + + return head_node_id diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index 726635131747..af9c7cf32213 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -64,6 +64,7 @@ DEFAULT, override_runtime_envs_except_env_vars, call_function_from_import_path, + get_head_node_id, ) from ray.serve._private.application_state import ApplicationStateManager @@ -108,10 +109,14 @@ async def __init__( controller_name: str, *, http_config: HTTPOptions, - head_node_id: str, detached: bool = False, _disable_http_proxy: bool = False, ): + self._controller_node_id = ray.get_runtime_context().get_node_id() + assert ( + self._controller_node_id == get_head_node_id() + ), "Controller must be on the head node." 
+ configure_component_logger( component_name="controller", component_id=str(os.getpid()) ) @@ -143,7 +148,7 @@ async def __init__( controller_name, detached, http_config, - head_node_id, + self._controller_node_id, gcs_client, ) @@ -186,7 +191,6 @@ async def __init__( run_background_task(self.run_control_loop()) self._recover_config_from_checkpoint() - self._head_node_id = head_node_id self._active_nodes = set() self._update_active_nodes() @@ -289,7 +293,7 @@ def _update_active_nodes(self): replicas). If the active nodes set changes, it will notify the long poll client. """ new_active_nodes = self.deployment_state_manager.get_active_node_ids() - new_active_nodes.add(self._head_node_id) + new_active_nodes.add(self._controller_node_id) if self._active_nodes != new_active_nodes: self._active_nodes = new_active_nodes self.long_poll_host.notify_changed( @@ -936,8 +940,6 @@ def __init__( except ValueError: self._controller = None if self._controller is None: - # Used for scheduling things to the head node explicitly. 
- head_node_id = ray.get_runtime_context().get_node_id() http_config = HTTPOptions() http_config.port = http_proxy_port self._controller = ServeController.options( @@ -952,7 +954,6 @@ def __init__( ).remote( controller_name, http_config=http_config, - head_node_id=head_node_id, detached=detached, ) diff --git a/python/ray/serve/tests/test_callback.py b/python/ray/serve/tests/test_callback.py index 837b41a92bdd..12396a9486a9 100644 --- a/python/ray/serve/tests/test_callback.py +++ b/python/ray/serve/tests/test_callback.py @@ -171,7 +171,6 @@ def test_callback_fail(ray_instance): handle = actor_def.remote( "controller", http_config={}, - head_node_id="123", ) with pytest.raises(RayActorError, match="cannot be imported"): ray.get(handle.check_alive.remote()) diff --git a/python/ray/serve/tests/test_deployment_state.py b/python/ray/serve/tests/test_deployment_state.py index d5ba9538a7ba..187ebe5a5ca2 100644 --- a/python/ray/serve/tests/test_deployment_state.py +++ b/python/ray/serve/tests/test_deployment_state.py @@ -73,6 +73,7 @@ def __init__( deployment_name: str, version: DeploymentVersion, scheduling_strategy="SPREAD", + node_id=None, ): self._actor_name = actor_name self._replica_tag = replica_tag @@ -99,6 +100,8 @@ def __init__( self._is_cross_language = False self._scheduling_strategy = scheduling_strategy self._actor_handle = MockActorHandle() + self._node_id = node_id + self._node_id_is_set = False @property def is_cross_language(self) -> bool: @@ -146,6 +149,8 @@ def worker_id(self) -> Optional[str]: @property def node_id(self) -> Optional[str]: + if self._node_id_is_set: + return self._node_id if isinstance(self._scheduling_strategy, NodeAffinitySchedulingStrategy): return self._scheduling_strategy.node_id if self.ready == ReplicaStartupStatus.SUCCEEDED or self.started: @@ -180,6 +185,10 @@ def set_starting_version(self, version: DeploymentVersion): """Mocked deployment_worker return version from reconfigure()""" self.starting_version = version + def 
set_node_id(self, node_id: str): + self._node_id = node_id + self._node_id_is_set = True + def start(self, deployment_info: DeploymentInfo): self.started = True @@ -2755,5 +2764,60 @@ def test_get_active_node_ids(mock_get_all_node_ids, mock_deployment_state_manage assert deployment_state_manager.get_active_node_ids() == set() +@patch.object(DriverDeploymentState, "_get_all_node_ids") +def test_get_active_node_ids_none( + mock_get_all_node_ids, mock_deployment_state_manager_full +): + """Test that get_active_node_ids() does not collect None node ids. + + When a running replica has None as its node id, `get_active_node_ids()` should + not include it in the set. + """ + node_ids = ("node1", "node2", "node2") + mock_get_all_node_ids.return_value = [node_ids] + + tag = "test_deployment" + create_deployment_state_manager, _ = mock_deployment_state_manager_full + deployment_state_manager = create_deployment_state_manager() + + # Deploy deployment with version "1" and 3 replicas + info1, version1 = deployment_info(version="1", num_replicas=3) + updating = deployment_state_manager.deploy(tag, info1) + deployment_state = deployment_state_manager._deployment_states[tag] + assert updating + + # When the replicas are in the STARTING state, `get_active_node_ids()` should + # return a set of node ids. + deployment_state_manager.update() + check_counts( + deployment_state, + total=3, + version=version1, + by_state=[(ReplicaState.STARTING, 3)], + ) + mocked_replicas = deployment_state._replicas.get() + for idx, mocked_replica in enumerate(mocked_replicas): + mocked_replica._actor.set_scheduling_strategy( + NodeAffinitySchedulingStrategy(node_id=node_ids[idx], soft=True) + ) + assert deployment_state.get_active_node_ids() == set(node_ids) + assert deployment_state_manager.get_active_node_ids() == set(node_ids) + + # When the replicas are in the RUNNING state and have None as their node id, + # `get_active_node_ids()` should return an empty set. 
+ for mocked_replica in mocked_replicas: + mocked_replica._actor.set_node_id(None) + mocked_replica._actor.set_ready() + deployment_state_manager.update() + check_counts( + deployment_state, + total=3, + version=version1, + by_state=[(ReplicaState.RUNNING, 3)], + ) + assert None not in deployment_state.get_active_node_ids() + assert None not in deployment_state_manager.get_active_node_ids() + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_http_state.py b/python/ray/serve/tests/test_http_state.py index 9aadbd268ca8..3a07af1c9753 100644 --- a/python/ray/serve/tests/test_http_state.py +++ b/python/ray/serve/tests/test_http_state.py @@ -57,7 +57,6 @@ def setup_controller(): ).remote( SERVE_CONTROLLER_NAME, http_config=None, - head_node_id=HEAD_NODE_ID, detached=True, _disable_http_proxy=True, ) diff --git a/python/ray/serve/tests/test_serve_ha.py b/python/ray/serve/tests/test_serve_ha.py index 87fe6211a313..8cef448f52ad 100644 --- a/python/ray/serve/tests/test_serve_ha.py +++ b/python/ray/serve/tests/test_serve_ha.py @@ -37,7 +37,7 @@ def pid(self): import os return {{"pid": os.getpid()}} -serve.start(detached=True) +serve.start(detached=True, http_options={{"location": "EveryNode"}}) Counter.options(num_replicas={num_replicas}).deploy() """ diff --git a/python/ray/serve/tests/test_standalone3.py b/python/ray/serve/tests/test_standalone3.py index 3689501fec84..444d1ee9d020 100644 --- a/python/ray/serve/tests/test_standalone3.py +++ b/python/ray/serve/tests/test_standalone3.py @@ -431,14 +431,14 @@ def test_healthz_and_routes_on_head_and_worker_nodes( # Setup a cluster with 2 nodes cluster = Cluster() - cluster.add_node(num_cpus=3) - cluster.add_node(num_cpus=3) + cluster.add_node(num_cpus=0) + cluster.add_node(num_cpus=2) cluster.wait_for_nodes() ray.init(address=cluster.address) serve.start(http_options={"location": "EveryNode"}) - # Deploy 2 replicas, one to each node - 
@serve.deployment(num_replicas=2, ray_actor_options={"num_cpus": 2}) + # Deploy 2 replicas, both should be on the worker node. + @serve.deployment(num_replicas=2) class HelloModel: def __call__(self): return "hello" @@ -446,21 +446,45 @@ def __call__(self): model = HelloModel.bind() serve.run(target=model) + # Ensure worker node has both replicas. + def check_replicas_on_worker_nodes(): + _actors = ray._private.state.actors().values() + replica_nodes = [ + a["Address"]["NodeID"] + for a in _actors + if a["ActorClassName"].startswith("ServeReplica") + ] + return len(set(replica_nodes)) == 1 + + wait_for_condition(check_replicas_on_worker_nodes) + # Ensure total actors of 2 proxies, 1 controller, and 2 replicas, and 2 nodes exist. wait_for_condition(lambda: len(ray._private.state.actors()) == 5) assert len(ray.nodes()) == 2 # Ensure `/-/healthz` and `/-/routes` return 200 and expected responses # on both nodes. - assert requests.get("http://127.0.0.1:8000/-/healthz").status_code == 200 - assert requests.get("http://127.0.0.1:8000/-/healthz").text == "success" + def check_request(url: str, expected_code: int, expected_text: str): + req = requests.get(url) + return req.status_code == expected_code and req.text == expected_text + + wait_for_condition( + condition_predictor=check_request, + url="http://127.0.0.1:8000/-/healthz", + expected_code=200, + expected_text="success", + ) assert requests.get("http://127.0.0.1:8000/-/routes").status_code == 200 assert ( requests.get("http://127.0.0.1:8000/-/routes").text == '{"/":"default_HelloModel"}' ) - assert requests.get("http://127.0.0.1:8001/-/healthz").status_code == 200 - assert requests.get("http://127.0.0.1:8001/-/healthz").text == "success" + wait_for_condition( + condition_predictor=check_request, + url="http://127.0.0.1:8001/-/healthz", + expected_code=200, + expected_text="success", + ) assert requests.get("http://127.0.0.1:8001/-/routes").status_code == 200 assert ( 
requests.get("http://127.0.0.1:8001/-/routes").text @@ -490,20 +514,25 @@ def _check(): # Ensure head node `/-/healthz` and `/-/routes` continue to return 200 and expected # responses. Also, the worker node `/-/healthz` and `/-/routes` should return 503 # and unavailable responses. - assert requests.get("http://127.0.0.1:8000/-/healthz").text == "success" - assert requests.get("http://127.0.0.1:8000/-/healthz").status_code == 200 - assert requests.get("http://127.0.0.1:8000/-/routes").text == "{}" + wait_for_condition( + condition_predictor=check_request, + url="http://127.0.0.1:8000/-/healthz", + expected_code=200, + expected_text="success", + ) assert requests.get("http://127.0.0.1:8000/-/routes").status_code == 200 - assert ( - requests.get("http://127.0.0.1:8001/-/healthz").text - == "This node is being drained." + assert requests.get("http://127.0.0.1:8000/-/routes").text == "{}" + wait_for_condition( + condition_predictor=check_request, + url="http://127.0.0.1:8001/-/healthz", + expected_code=503, + expected_text="This node is being drained.", ) - assert requests.get("http://127.0.0.1:8001/-/healthz").status_code == 503 + assert requests.get("http://127.0.0.1:8001/-/routes").status_code == 503 assert ( requests.get("http://127.0.0.1:8001/-/routes").text == "This node is being drained." ) - assert requests.get("http://127.0.0.1:8001/-/routes").status_code == 503 # Clean up serve. 
serve.shutdown() diff --git a/python/ray/serve/tests/test_util.py b/python/ray/serve/tests/test_util.py index 5237ee7bf98b..a5e477476b96 100644 --- a/python/ray/serve/tests/test_util.py +++ b/python/ray/serve/tests/test_util.py @@ -4,6 +4,7 @@ import sys import tempfile from copy import deepcopy +from unittest.mock import patch import numpy as np import pytest @@ -20,7 +21,9 @@ msgpack_deserialize, snake_to_camel_case, dict_keys_snake_to_camel_case, + get_head_node_id, ) +from ray._private.resource_spec import HEAD_NODE_RESOURCE_NAME def test_serialize(): @@ -531,6 +534,39 @@ def test_shallow_copy(self): assert camel_dict["nested"]["list2"] is list2 +def test_get_head_node_id(): + """Test get_head_node_id() returning the correct head node id. + + When there are worker nodes, dead head nodes, and other alive head nodes, + get_head_node_id() should return the node id of the first alive head node. + When there are no alive head nodes, get_head_node_id() should raise assertion error. + """ + nodes = [ + {"NodeID": "worker_node1", "Alive": True, "Resources": {"CPU": 1}}, + { + "NodeID": "dead_head_node1", + "Alive": False, + "Resources": {"CPU": 1, HEAD_NODE_RESOURCE_NAME: 1.0}, + }, + { + "NodeID": "alive_head_node1", + "Alive": True, + "Resources": {"CPU": 1, HEAD_NODE_RESOURCE_NAME: 1.0}, + }, + { + "NodeID": "alive_head_node2", + "Alive": True, + "Resources": {"CPU": 1, HEAD_NODE_RESOURCE_NAME: 1.0}, + }, + ] + with patch("ray.nodes", return_value=nodes): + assert get_head_node_id() == "alive_head_node1" + + with patch("ray.nodes", return_value=[]): + with pytest.raises(AssertionError): + get_head_node_id() + + if __name__ == "__main__": import sys