Skip to content

Commit

Permalink
fix token expiration for ray autoscaler
Browse files Browse the repository at this point in the history
Signed-off-by: wa101 <wadhah.mahroug15@gmail.com>
  • Loading branch information
wadhah101 committed Nov 1, 2024
1 parent ba41ae9 commit 8567c0f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 18 deletions.
19 changes: 5 additions & 14 deletions python/ray/autoscaler/_private/kuberay/autoscaling_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ class AutoscalingConfigProducer:
"""

def __init__(self, ray_cluster_name, ray_cluster_namespace):
self._headers, self._verify = node_provider.load_k8s_secrets()
self._ray_cr_url = node_provider.url_from_resource(
namespace=ray_cluster_namespace, path=f"rayclusters/{ray_cluster_name}"
self.kubernetes_api_client = node_provider.KubernetesHttpApiClient(
namespace=ray_cluster_namespace
)
self._ray_cr_path = f"rayclusters/{ray_cluster_name}"

def __call__(self):
ray_cr = self._fetch_ray_cr_from_k8s_with_retries()
Expand All @@ -67,7 +67,7 @@ def _fetch_ray_cr_from_k8s_with_retries(self) -> Dict[str, Any]:
"""
for i in range(1, MAX_RAYCLUSTER_FETCH_TRIES + 1):
try:
return self._fetch_ray_cr_from_k8s()
return self.kubernetes_api_client.get(self._ray_cr_path)
except requests.HTTPError as e:
if i < MAX_RAYCLUSTER_FETCH_TRIES:
logger.exception(
Expand All @@ -80,15 +80,6 @@ def _fetch_ray_cr_from_k8s_with_retries(self) -> Dict[str, Any]:
# This branch is inaccessible. Raise to satisfy mypy.
raise AssertionError

def _fetch_ray_cr_from_k8s(self) -> Dict[str, Any]:
result = requests.get(
self._ray_cr_url, headers=self._headers, verify=self._verify
)
if not result.status_code == 200:
result.raise_for_status()
ray_cr = result.json()
return ray_cr


def _derive_autoscaling_config_from_ray_cr(ray_cr: Dict[str, Any]) -> Dict[str, Any]:
provider_config = _generate_provider_config(ray_cr["metadata"]["namespace"])
Expand Down Expand Up @@ -185,7 +176,7 @@ def _generate_legacy_autoscaling_config_fields() -> Dict[str, Any]:


def _generate_available_node_types_from_ray_cr_spec(
ray_cr_spec: Dict[str, Any]
ray_cr_spec: Dict[str, Any],
) -> Dict[str, Any]:
"""Formats autoscaler "available_node_types" field based on the Ray CR's group
specs.
Expand Down
25 changes: 21 additions & 4 deletions python/ray/autoscaler/_private/kuberay/node_provider.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import json
import logging
import os
Expand Down Expand Up @@ -48,6 +49,8 @@
# Key for GKE label that identifies which multi-host replica a pod belongs to
REPLICA_INDEX_KEY = "replicaIndex"

TOKEN_REFRESH_PERIOD = datetime.timedelta(minutes=1)

# Design:

# Each modification the autoscaler wants to make is posted to the API server goal state
Expand Down Expand Up @@ -249,7 +252,18 @@ class KubernetesHttpApiClient(IKubernetesHttpApiClient):
def __init__(self, namespace: str, kuberay_crd_version: str = KUBERAY_CRD_VER):
self._kuberay_crd_version = kuberay_crd_version
self._namespace = namespace
self._headers, self._verify = load_k8s_secrets()
self._token_expires_at = datetime.datetime.now() + TOKEN_REFRESH_PERIOD
self._headers, self._verify = None, None

def _get_refreshed_headers_and_verify(self):
if (datetime.datetime.now() >= self._token_expires_at) or (
self._headers is None or self._verify is None
):
self._headers, self._verify = load_k8s_secrets()
self._token_expires_at = datetime.datetime.now() + TOKEN_REFRESH_PERIOD
return self._headers, self._verify
else:
return self._headers, self._verify

def get(self, path: str) -> Dict[str, Any]:
"""Wrapper for REST GET of resource with proper headers.
Expand All @@ -268,7 +282,9 @@ def get(self, path: str) -> Dict[str, Any]:
path=path,
kuberay_crd_version=self._kuberay_crd_version,
)
result = requests.get(url, headers=self._headers, verify=self._verify)

headers, verify = self._get_refreshed_headers_and_verify()
result = requests.get(url, headers=headers, verify=verify)
if not result.status_code == 200:
result.raise_for_status()
return result.json()
Expand All @@ -291,11 +307,12 @@ def patch(self, path: str, payload: List[Dict[str, Any]]) -> Dict[str, Any]:
path=path,
kuberay_crd_version=self._kuberay_crd_version,
)
headers, verify = self._get_refreshed_headers_and_verify()
result = requests.patch(
url,
json.dumps(payload),
headers={**self._headers, "Content-type": "application/json-patch+json"},
verify=self._verify,
headers={**headers, "Content-type": "application/json-patch+json"},
verify=verify,
)
if not result.status_code == 200:
result.raise_for_status()
Expand Down

0 comments on commit 8567c0f

Please sign in to comment.