From 8567c0f3733f327239f77d1ce65cf926c8de71c2 Mon Sep 17 00:00:00 2001 From: wa101 Date: Fri, 1 Nov 2024 11:25:25 +0100 Subject: [PATCH] fix token expiration for ray autoscaler Signed-off-by: wa101 --- .../_private/kuberay/autoscaling_config.py | 19 ++++---------- .../_private/kuberay/node_provider.py | 25 ++++++++++++++++--- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/python/ray/autoscaler/_private/kuberay/autoscaling_config.py b/python/ray/autoscaler/_private/kuberay/autoscaling_config.py index d659bf93da76..29d12af319b5 100644 --- a/python/ray/autoscaler/_private/kuberay/autoscaling_config.py +++ b/python/ray/autoscaler/_private/kuberay/autoscaling_config.py @@ -49,10 +49,10 @@ class AutoscalingConfigProducer: """ def __init__(self, ray_cluster_name, ray_cluster_namespace): - self._headers, self._verify = node_provider.load_k8s_secrets() - self._ray_cr_url = node_provider.url_from_resource( - namespace=ray_cluster_namespace, path=f"rayclusters/{ray_cluster_name}" + self.kubernetes_api_client = node_provider.KubernetesHttpApiClient( + namespace=ray_cluster_namespace ) + self._ray_cr_path = f"rayclusters/{ray_cluster_name}" def __call__(self): ray_cr = self._fetch_ray_cr_from_k8s_with_retries() @@ -67,7 +67,7 @@ def _fetch_ray_cr_from_k8s_with_retries(self) -> Dict[str, Any]: """ for i in range(1, MAX_RAYCLUSTER_FETCH_TRIES + 1): try: - return self._fetch_ray_cr_from_k8s() + return self.kubernetes_api_client.get(self._ray_cr_path) except requests.HTTPError as e: if i < MAX_RAYCLUSTER_FETCH_TRIES: logger.exception( @@ -80,15 +80,6 @@ def _fetch_ray_cr_from_k8s_with_retries(self) -> Dict[str, Any]: # This branch is inaccessible. Raise to satisfy mypy. raise AssertionError - def _fetch_ray_cr_from_k8s(self) -> Dict[str, Any]: - result = requests.get( - self._ray_cr_url, headers=self._headers, verify=self._verify - ) - if not result.status_code == 200: - result.raise_for_status() - ray_cr = result.json() - return ray_cr - def _derive_autoscaling_config_from_ray_cr(ray_cr: Dict[str, Any]) -> Dict[str, Any]: provider_config = _generate_provider_config(ray_cr["metadata"]["namespace"]) @@ -185,7 +176,7 @@ def _generate_legacy_autoscaling_config_fields() -> Dict[str, Any]: def _generate_available_node_types_from_ray_cr_spec( - ray_cr_spec: Dict[str, Any] + ray_cr_spec: Dict[str, Any], ) -> Dict[str, Any]: """Formats autoscaler "available_node_types" field based on the Ray CR's group specs. diff --git a/python/ray/autoscaler/_private/kuberay/node_provider.py b/python/ray/autoscaler/_private/kuberay/node_provider.py index 823afe34e548..13ad4a5ee332 100644 --- a/python/ray/autoscaler/_private/kuberay/node_provider.py +++ b/python/ray/autoscaler/_private/kuberay/node_provider.py @@ -1,3 +1,4 @@ +import datetime import json import logging import os @@ -48,6 +49,8 @@ # Key for GKE label that identifies which multi-host replica a pod belongs to REPLICA_INDEX_KEY = "replicaIndex" +TOKEN_REFRESH_PERIOD = datetime.timedelta(minutes=1) + # Design: # Each modification the autoscaler wants to make is posted to the API server goal state @@ -249,7 +252,18 @@ class KubernetesHttpApiClient(IKubernetesHttpApiClient): def __init__(self, namespace: str, kuberay_crd_version: str = KUBERAY_CRD_VER): self._kuberay_crd_version = kuberay_crd_version self._namespace = namespace - self._headers, self._verify = load_k8s_secrets() + self._token_expires_at = datetime.datetime.now() + TOKEN_REFRESH_PERIOD + self._headers, self._verify = None, None + + def _get_refreshed_headers_and_verify(self): + if (datetime.datetime.now() >= self._token_expires_at) or ( + self._headers is None or self._verify is None + ): + self._headers, self._verify = load_k8s_secrets() + self._token_expires_at = datetime.datetime.now() + TOKEN_REFRESH_PERIOD + return self._headers, self._verify + else: + return self._headers, self._verify def get(self, path: str) -> Dict[str, Any]: """Wrapper for REST GET of resource with proper headers. @@ -268,7 +282,9 @@ def get(self, path: str) -> Dict[str, Any]: path=path, kuberay_crd_version=self._kuberay_crd_version, ) - result = requests.get(url, headers=self._headers, verify=self._verify) + + headers, verify = self._get_refreshed_headers_and_verify() + result = requests.get(url, headers=headers, verify=verify) if not result.status_code == 200: result.raise_for_status() return result.json() @@ -291,11 +307,12 @@ def patch(self, path: str, payload: List[Dict[str, Any]]) -> Dict[str, Any]: path=path, kuberay_crd_version=self._kuberay_crd_version, ) + headers, verify = self._get_refreshed_headers_and_verify() result = requests.patch( url, json.dumps(payload), - headers={**self._headers, "Content-type": "application/json-patch+json"}, - verify=self._verify, + headers={**headers, "Content-type": "application/json-patch+json"}, + verify=verify, ) if not result.status_code == 200: result.raise_for_status()