Skip to content

Commit

Permalink
[Serve] Let the controller look up the head node and fix flaky standa…
Browse files Browse the repository at this point in the history
…lone3 healthz test (ray-project#36878)

- Make sure we are using wait_for_condition in the test (could take
time to broadcast).
- Remove head_node_id from controller init args and instead fetch it
in the controller init. Also remove it from serve_start in
_private/api.py.
- Add an assertion to check that the controller actually runs on the
head node (use ray.nodes() and look for head node resource).
- Filter Nones from the active node set in deployment_state. Add a
unit test for this, it should never return None.

Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
  • Loading branch information
GeneDer authored and arvind-chandra committed Aug 31, 2023
1 parent 8850bea commit 6caac99
Show file tree
Hide file tree
Showing 10 changed files with 176 additions and 32 deletions.
5 changes: 0 additions & 5 deletions python/ray/serve/_private/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,6 @@ def _start_controller(
else:
controller_name = format_actor_name(get_random_letters(), SERVE_CONTROLLER_NAME)

# Used for scheduling things to the head node explicitly.
# Assumes that `serve.start` runs on the head node.
head_node_id = ray.get_runtime_context().get_node_id()
controller_actor_options = {
"num_cpus": 1 if dedicated_cpu else 0,
"name": controller_name,
Expand All @@ -164,7 +161,6 @@ def _start_controller(
controller = ServeController.options(**controller_actor_options).remote(
controller_name,
http_config=http_options,
head_node_id=head_node_id,
detached=detached,
_disable_http_proxy=True,
)
Expand All @@ -186,7 +182,6 @@ def _start_controller(
controller = ServeController.options(**controller_actor_options).remote(
controller_name,
http_config=http_options,
head_node_id=head_node_id,
detached=detached,
)

Expand Down
6 changes: 5 additions & 1 deletion python/ray/serve/_private/deployment_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,7 +1239,11 @@ def get_active_node_ids(self) -> Set[str]:
ReplicaState.RECOVERING,
ReplicaState.RUNNING,
]
return {replica.actor_node_id for replica in self._replicas.get(active_states)}
return {
replica.actor_node_id
for replica in self._replicas.get(active_states)
if replica.actor_node_id is not None
}

def list_replica_details(self) -> List[ReplicaDetails]:
return [replica.actor_details for replica in self._replicas.get()]
Expand Down
17 changes: 17 additions & 0 deletions python/ray/serve/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from ray._raylet import MessagePackSerializer
from ray._private.utils import import_attr
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
from ray._private.resource_spec import HEAD_NODE_RESOURCE_NAME

import __main__

Expand Down Expand Up @@ -686,3 +687,19 @@ def call_function_from_import_path(import_path: str) -> Any:
return callback_func()
except Exception as e:
raise RuntimeError(f"The function {import_path} raised an exception: {e}")


def get_head_node_id() -> str:
"""Get the head node id.
Iterate through all nodes in the ray cluster and return the node id of the first
alive node with head node resource.
"""
head_node_id = None
for node in ray.nodes():
if HEAD_NODE_RESOURCE_NAME in node["Resources"] and node["Alive"]:
head_node_id = node["NodeID"]
break
assert head_node_id is not None, "Cannot find alive head node."

return head_node_id
15 changes: 8 additions & 7 deletions python/ray/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
DEFAULT,
override_runtime_envs_except_env_vars,
call_function_from_import_path,
get_head_node_id,
)
from ray.serve._private.application_state import ApplicationStateManager

Expand Down Expand Up @@ -108,10 +109,14 @@ async def __init__(
controller_name: str,
*,
http_config: HTTPOptions,
head_node_id: str,
detached: bool = False,
_disable_http_proxy: bool = False,
):
self._controller_node_id = ray.get_runtime_context().get_node_id()
assert (
self._controller_node_id == get_head_node_id()
), "Controller must be on the head node."

configure_component_logger(
component_name="controller", component_id=str(os.getpid())
)
Expand Down Expand Up @@ -143,7 +148,7 @@ async def __init__(
controller_name,
detached,
http_config,
head_node_id,
self._controller_node_id,
gcs_client,
)

Expand Down Expand Up @@ -186,7 +191,6 @@ async def __init__(
run_background_task(self.run_control_loop())

self._recover_config_from_checkpoint()
self._head_node_id = head_node_id
self._active_nodes = set()
self._update_active_nodes()

Expand Down Expand Up @@ -289,7 +293,7 @@ def _update_active_nodes(self):
replicas). If the active nodes set changes, it will notify the long poll client.
"""
new_active_nodes = self.deployment_state_manager.get_active_node_ids()
new_active_nodes.add(self._head_node_id)
new_active_nodes.add(self._controller_node_id)
if self._active_nodes != new_active_nodes:
self._active_nodes = new_active_nodes
self.long_poll_host.notify_changed(
Expand Down Expand Up @@ -936,8 +940,6 @@ def __init__(
except ValueError:
self._controller = None
if self._controller is None:
# Used for scheduling things to the head node explicitly.
head_node_id = ray.get_runtime_context().get_node_id()
http_config = HTTPOptions()
http_config.port = http_proxy_port
self._controller = ServeController.options(
Expand All @@ -952,7 +954,6 @@ def __init__(
).remote(
controller_name,
http_config=http_config,
head_node_id=head_node_id,
detached=detached,
)

Expand Down
1 change: 0 additions & 1 deletion python/ray/serve/tests/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def test_callback_fail(ray_instance):
handle = actor_def.remote(
"controller",
http_config={},
head_node_id="123",
)
with pytest.raises(RayActorError, match="cannot be imported"):
ray.get(handle.check_alive.remote())
Expand Down
64 changes: 64 additions & 0 deletions python/ray/serve/tests/test_deployment_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
deployment_name: str,
version: DeploymentVersion,
scheduling_strategy="SPREAD",
node_id=None,
):
self._actor_name = actor_name
self._replica_tag = replica_tag
Expand All @@ -99,6 +100,8 @@ def __init__(
self._is_cross_language = False
self._scheduling_strategy = scheduling_strategy
self._actor_handle = MockActorHandle()
self._node_id = node_id
self._node_id_is_set = False

@property
def is_cross_language(self) -> bool:
Expand Down Expand Up @@ -146,6 +149,8 @@ def worker_id(self) -> Optional[str]:

@property
def node_id(self) -> Optional[str]:
if self._node_id_is_set:
return self._node_id
if isinstance(self._scheduling_strategy, NodeAffinitySchedulingStrategy):
return self._scheduling_strategy.node_id
if self.ready == ReplicaStartupStatus.SUCCEEDED or self.started:
Expand Down Expand Up @@ -180,6 +185,10 @@ def set_starting_version(self, version: DeploymentVersion):
"""Mocked deployment_worker return version from reconfigure()"""
self.starting_version = version

def set_node_id(self, node_id: str):
self._node_id = node_id
self._node_id_is_set = True

def start(self, deployment_info: DeploymentInfo):
self.started = True

Expand Down Expand Up @@ -2755,5 +2764,60 @@ def test_get_active_node_ids(mock_get_all_node_ids, mock_deployment_state_manage
assert deployment_state_manager.get_active_node_ids() == set()


@patch.object(DriverDeploymentState, "_get_all_node_ids")
def test_get_active_node_ids_none(
mock_get_all_node_ids, mock_deployment_state_manager_full
):
"""Test get_active_node_ids() are not collecting none node ids.
When the running replicas has None as the node id, `get_active_node_ids()` should
not include it in the set.
"""
node_ids = ("node1", "node2", "node2")
mock_get_all_node_ids.return_value = [node_ids]

tag = "test_deployment"
create_deployment_state_manager, _ = mock_deployment_state_manager_full
deployment_state_manager = create_deployment_state_manager()

# Deploy deployment with version "1" and 3 replicas
info1, version1 = deployment_info(version="1", num_replicas=3)
updating = deployment_state_manager.deploy(tag, info1)
deployment_state = deployment_state_manager._deployment_states[tag]
assert updating

# When the replicas are in the STARTING state, `get_active_node_ids()` should
# return a set of node ids.
deployment_state_manager.update()
check_counts(
deployment_state,
total=3,
version=version1,
by_state=[(ReplicaState.STARTING, 3)],
)
mocked_replicas = deployment_state._replicas.get()
for idx, mocked_replica in enumerate(mocked_replicas):
mocked_replica._actor.set_scheduling_strategy(
NodeAffinitySchedulingStrategy(node_id=node_ids[idx], soft=True)
)
assert deployment_state.get_active_node_ids() == set(node_ids)
assert deployment_state_manager.get_active_node_ids() == set(node_ids)

# When the replicas are in the RUNNING state and are having None node id,
# `get_active_node_ids()` should return empty set.
for mocked_replica in mocked_replicas:
mocked_replica._actor.set_node_id(None)
mocked_replica._actor.set_ready()
deployment_state_manager.update()
check_counts(
deployment_state,
total=3,
version=version1,
by_state=[(ReplicaState.RUNNING, 3)],
)
assert None not in deployment_state.get_active_node_ids()
assert None not in deployment_state_manager.get_active_node_ids()


if __name__ == "__main__":
sys.exit(pytest.main(["-v", "-s", __file__]))
1 change: 0 additions & 1 deletion python/ray/serve/tests/test_http_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def setup_controller():
).remote(
SERVE_CONTROLLER_NAME,
http_config=None,
head_node_id=HEAD_NODE_ID,
detached=True,
_disable_http_proxy=True,
)
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/tests/test_serve_ha.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def pid(self):
import os
return {{"pid": os.getpid()}}
serve.start(detached=True)
serve.start(detached=True, http_options={{"location": "EveryNode"}})
Counter.options(num_replicas={num_replicas}).deploy()
"""
Expand Down
61 changes: 45 additions & 16 deletions python/ray/serve/tests/test_standalone3.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,36 +431,60 @@ def test_healthz_and_routes_on_head_and_worker_nodes(

# Setup a cluster with 2 nodes
cluster = Cluster()
cluster.add_node(num_cpus=3)
cluster.add_node(num_cpus=3)
cluster.add_node(num_cpus=0)
cluster.add_node(num_cpus=2)
cluster.wait_for_nodes()
ray.init(address=cluster.address)
serve.start(http_options={"location": "EveryNode"})

# Deploy 2 replicas, one to each node
@serve.deployment(num_replicas=2, ray_actor_options={"num_cpus": 2})
# Deploy 2 replicas, both should be on the worker node.
@serve.deployment(num_replicas=2)
class HelloModel:
def __call__(self):
return "hello"

model = HelloModel.bind()
serve.run(target=model)

# Ensure worker node has both replicas.
def check_replicas_on_worker_nodes():
_actors = ray._private.state.actors().values()
replica_nodes = [
a["Address"]["NodeID"]
for a in _actors
if a["ActorClassName"].startswith("ServeReplica")
]
return len(set(replica_nodes)) == 1

wait_for_condition(check_replicas_on_worker_nodes)

# Ensure total actors of 2 proxies, 1 controller, and 2 replicas, and 2 nodes exist.
wait_for_condition(lambda: len(ray._private.state.actors()) == 5)
assert len(ray.nodes()) == 2

# Ensure `/-/healthz` and `/-/routes` return 200 and expected responses
# on both nodes.
assert requests.get("http://127.0.0.1:8000/-/healthz").status_code == 200
assert requests.get("http://127.0.0.1:8000/-/healthz").text == "success"
def check_request(url: str, expected_code: int, expected_text: str):
req = requests.get(url)
return req.status_code == expected_code and req.text == expected_text

wait_for_condition(
condition_predictor=check_request,
url="http://127.0.0.1:8000/-/healthz",
expected_code=200,
expected_text="success",
)
assert requests.get("http://127.0.0.1:8000/-/routes").status_code == 200
assert (
requests.get("http://127.0.0.1:8000/-/routes").text
== '{"/":"default_HelloModel"}'
)
assert requests.get("http://127.0.0.1:8001/-/healthz").status_code == 200
assert requests.get("http://127.0.0.1:8001/-/healthz").text == "success"
wait_for_condition(
condition_predictor=check_request,
url="http://127.0.0.1:8001/-/healthz",
expected_code=200,
expected_text="success",
)
assert requests.get("http://127.0.0.1:8001/-/routes").status_code == 200
assert (
requests.get("http://127.0.0.1:8001/-/routes").text
Expand Down Expand Up @@ -490,20 +514,25 @@ def _check():
# Ensure head node `/-/healthz` and `/-/routes` continue to return 200 and expected
# responses. Also, the worker node `/-/healthz` and `/-/routes` should return 503
# and unavailable responses.
assert requests.get("http://127.0.0.1:8000/-/healthz").text == "success"
assert requests.get("http://127.0.0.1:8000/-/healthz").status_code == 200
assert requests.get("http://127.0.0.1:8000/-/routes").text == "{}"
wait_for_condition(
condition_predictor=check_request,
url="http://127.0.0.1:8000/-/healthz",
expected_code=200,
expected_text="success",
)
assert requests.get("http://127.0.0.1:8000/-/routes").status_code == 200
assert (
requests.get("http://127.0.0.1:8001/-/healthz").text
== "This node is being drained."
assert requests.get("http://127.0.0.1:8000/-/routes").text == "{}"
wait_for_condition(
condition_predictor=check_request,
url="http://127.0.0.1:8001/-/healthz",
expected_code=503,
expected_text="This node is being drained.",
)
assert requests.get("http://127.0.0.1:8001/-/healthz").status_code == 503
assert requests.get("http://127.0.0.1:8001/-/routes").status_code == 503
assert (
requests.get("http://127.0.0.1:8001/-/routes").text
== "This node is being drained."
)
assert requests.get("http://127.0.0.1:8001/-/routes").status_code == 503

# Clean up serve.
serve.shutdown()
Expand Down
Loading

0 comments on commit 6caac99

Please sign in to comment.