@classmethod
def from_service(
    cls,
    name: str,
    author: str,
    *,
    ns: str | None = None,
    is_secure: bool = True,
    user_token: str | None = None,
    custom_ca: str | None = None,
) -> "ModelRegistry":
    """Create a client from a service name.

    The service is looked up through the in-cluster Kubernetes API and its
    address (protocol, host, port) is used to build the client.

    Args:
        name: Name of the service to resolve.
        author: Author, passed through to the client constructor.

    Keyword Args:
        ns: Namespace. Defaults to DSC registriesNamespace, or `kubeflow` if unavailable.
        is_secure: Whether to use a secure connection. Defaults to True.
        user_token: The PEM-encoded user token as a string. Defaults to content of path on envvar KF_PIPELINES_SA_TOKEN_PATH.
        custom_ca: Path to the PEM-encoded root certificates as a string. Defaults to path on envvar CERT.

    Raises:
        StoreError: If the namespace or the service address cannot be
            determined, or if `is_secure` is set but the service is not
            reachable over https.
    """
    # Imported here so the Kubernetes helpers are only loaded when this
    # constructor is actually used.
    from ._utils import Kube

    with Kube(user_token) as kc:
        if not ns:
            ns_res = kc.get_mr_ns()
            if ns_res.error is not None:
                # Discovery failed but a fallback namespace is still returned.
                warn(str(ns_res.error), stacklevel=2)
            ns = ns_res.value
            # Explicit check instead of `assert`, which is removed under -O.
            if not isinstance(ns, str):
                msg = f"Could not determine the namespace of service {name}"
                raise StoreError(msg)

        addr_res = kc.get_service_addr(name, ns)

    addr = addr_res.value
    if addr is None:
        # No usable address: surface the lookup error when there is one.
        if addr_res.error is not None:
            raise addr_res.error
        msg = f"Could not resolve an address for service {name}"
        raise StoreError(msg)
    if addr_res.error is not None:
        # An address was produced from fallbacks; keep going but tell the user.
        warn(str(addr_res.error), stacklevel=2)

    if is_secure and addr.protocol != "https":
        msg = "Service does not support secure connection. To proceed with insecure connection, set is_secure=False"
        raise StoreError(msg)

    return cls(
        f"{addr.protocol}://{addr.host}",
        addr.port,
        author=author,
        is_secure=is_secure,
        user_token=user_token,
        custom_ca=custom_ca,
    )
@@ -107,3 +110,163 @@ def wrapper(*args: object, **kwargs: object) -> object: return wrapper # type: ignore return inner + + +T = t.TypeVar("T") + +E = t.TypeVar("E", bound=Exception) + + +@dataclass +class Result(t.Generic[T, E]): + value: T | None + error: E | None + + @property + def ok(self) -> bool: + return self.error is None + + @property + def has_value(self) -> bool: + return self.value is not None + + +class Address(t.NamedTuple): + protocol: str + host: str + port: int + + +@dataclass +class Kube: + user_token: str | None = None + # TODO: do we need to take care of custom CA config too? + from kubernetes import client, config + + DEFAULT_NS = "kubeflow" + DSC_CRD = "datasciencecluster.opendatahub.io/v1" + DSC_NS_CONFIG = "registriesNamespace" + EXTERNAL_ADDR_ANNOTATION = "routing.opendatahub.io/external-address-rest" + + def __post_init__(self): + self.config.load_incluster_config() + client = Kube.client.ApiClient() + self.sa_token = client.configuration.api_key["authorization"] + self.api_client = client + + def __enter__(self) -> Kube: + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.api_client.close() + + def try_get( + self, op: t.Callable[[], t.Any], as_user: bool = False + ) -> Result[t.Any, client.ApiException]: + if as_user and self.user_token is not None: + # NOTE: even though this config is consumed by the RESTClient, auth is refresh on every request: https://github.com/kubernetes-client/python/blob/b7ccf179f1b0194a0ed18e39fb063ef8a963fc6b/kubernetes/client/api_client.py#L166 + self.api_client.configuration.api_key["authorization"] = self.user_token + try: + return Result(op(), None) + except Kube.client.ApiException as e: + if e.status != 403: + raise e + return Result(None, e) + finally: + self.api_client.configuration.api_key["authorization"] = self.sa_token + + def try_get_with_any_token( + self, op: t.Callable[[], t.Any] + ) -> Result[t.Any, client.ApiException]: + res = self.try_get(op) + if res.error is 
not None and self.user_token: + res = self.try_get(op, as_user=True) + return res + + def get_default_dsc(self) -> Result[dict[str, t.Any], StoreError]: + kcustom = Kube.client.CustomObjectsApi(self.api_client) + + g, v = Kube.DSC_CRD.split("/") + p = f"{g.split('.')[0]}s" + + def list_dscs() -> t.Any: + return kcustom.list_cluster_custom_object( + group=g, + version=v, + plural=p, + ) + + res = self.try_get_with_any_token(list_dscs) + if dscs := res.value: + return Result( + t.cast( + dict[str, t.Any], + dscs["items"][0], + ), + None, + ) + return Result(None, StoreError(f"Failed to list {p}: {res.error}")) + + def get_mr_ns(self) -> Result[str, StoreError]: + res = self.get_default_dsc() + if dsc_raw := res.value: + return Result( + dsc_raw["status"]["components"]["modelregistry"][Kube.DSC_NS_CONFIG], + None, + ) + return Result(Kube.DEFAULT_NS, res.error) + + def get_namespaced_service( + self, name: str, ns: str + ) -> Result[client.V1Service, StoreError]: + kcore = self.client.CoreV1Api(self.api_client) + + def get_service() -> t.Any: + return kcore.read_namespaced_service(name, ns) + + res = self.try_get_with_any_token(get_service) + if serv := res.value: + return Result(t.cast(Kube.client.V1Service, serv), None) + return Result(None, StoreError(f"Failed to get service {name}: {res.error}")) + + def get_service_addr(self, name: str, ns: str) -> Result[Address, StoreError]: + res = self.get_namespaced_service(name, ns) + if res.error: + return Result(None, res.error) + + serv = res.value + assert serv is not None + meta = t.cast(Kube.client.V1ObjectMeta, serv.metadata) + ext_addr = t.cast(dict[str, str], meta.annotations).get( + Kube.EXTERNAL_ADDR_ANNOTATION + ) + err = None + if not ext_addr: + host = str(meta.name) + port_by_protocol = { + port.app_protocol: port + for port in t.cast( + list[Kube.client.V1ServicePort], + t.cast(Kube.client.V1ServiceSpec, serv.spec).ports, + ) + if port.app_protocol in ("http", "https") + } + if p := 
port_by_protocol.get("https"): + port = int(str(p.port)) + protocol = "https" + elif p := port_by_protocol.get("http"): + port = int(str(p.port)) + protocol = "http" + else: + err = StoreError(f"Service {name} has no http(s) ports") + port = 8080 + protocol = "http" + else: + from urllib.parse import urlparse + + parsed = urlparse(ext_addr) + protocol = parsed.scheme + host, port = parsed.netloc.split(":") + port = int(port) + + return Result(Address(protocol, host, port), err)