diff --git a/sdk/python/feast/permissions/client/arrow_flight_auth_interceptor.py b/sdk/python/feast/permissions/client/arrow_flight_auth_interceptor.py index 724c7df5ca..7ef84fbeae 100644 --- a/sdk/python/feast/permissions/client/arrow_flight_auth_interceptor.py +++ b/sdk/python/feast/permissions/client/arrow_flight_auth_interceptor.py @@ -2,7 +2,7 @@ from feast.permissions.auth.auth_type import AuthType from feast.permissions.auth_model import AuthConfig -from feast.permissions.client.auth_client_manager_factory import get_auth_token +from feast.permissions.client.client_auth_token import get_auth_token class FlightBearerTokenInterceptor(fl.ClientMiddleware): diff --git a/sdk/python/feast/permissions/client/auth_client_manager.py b/sdk/python/feast/permissions/client/auth_client_manager.py index 82f9b7433e..2151cfb409 100644 --- a/sdk/python/feast/permissions/client/auth_client_manager.py +++ b/sdk/python/feast/permissions/client/auth_client_manager.py @@ -1,8 +1,49 @@ +import os from abc import ABC, abstractmethod +from feast.permissions.auth.auth_type import AuthType +from feast.permissions.auth_model import ( + AuthConfig, + KubernetesAuthConfig, + OidcClientAuthConfig, +) + class AuthenticationClientManager(ABC): @abstractmethod def get_token(self) -> str: """Retrieves the token based on the authentication type configuration""" pass + + +class AuthenticationClientManagerFactory(ABC): + def __init__(self, auth_config: AuthConfig): + self.auth_config = auth_config + + def get_auth_client_manager(self) -> AuthenticationClientManager: + from feast.permissions.client.intra_comm_authentication_client_manager import ( + IntraCommAuthClientManager, + ) + from feast.permissions.client.kubernetes_auth_client_manager import ( + KubernetesAuthClientManager, + ) + from feast.permissions.client.oidc_authentication_client_manager import ( + OidcAuthClientManager, + ) + + intra_communication_base64 = os.getenv("INTRA_COMMUNICATION_BASE64") + if intra_communication_base64: + return IntraCommAuthClientManager( + self.auth_config, intra_communication_base64 + ) + + if self.auth_config.type == AuthType.OIDC.value: + assert isinstance(self.auth_config, OidcClientAuthConfig) + return OidcAuthClientManager(self.auth_config) + elif self.auth_config.type == AuthType.KUBERNETES.value: + assert isinstance(self.auth_config, KubernetesAuthConfig) + return KubernetesAuthClientManager(self.auth_config) + else: + raise RuntimeError( + f"No Auth client manager implemented for the auth type:${self.auth_config.type}" + ) diff --git a/sdk/python/feast/permissions/client/auth_client_manager_factory.py b/sdk/python/feast/permissions/client/auth_client_manager_factory.py deleted file mode 100644 index 359072f38e..0000000000 --- a/sdk/python/feast/permissions/client/auth_client_manager_factory.py +++ /dev/null @@ -1,41 +0,0 @@ -import os -from typing import cast - -from feast.permissions.auth.auth_type import AuthType -from feast.permissions.auth_model import ( - AuthConfig, - KubernetesAuthConfig, - OidcAuthConfig, - OidcClientAuthConfig, -) -from feast.permissions.client.auth_client_manager import AuthenticationClientManager -from feast.permissions.client.kubernetes_auth_client_manager import ( - KubernetesAuthClientManager, -) -from feast.permissions.client.oidc_authentication_client_manager import ( - OidcAuthClientManager, -) - - -def get_auth_client_manager(auth_config: AuthConfig) -> AuthenticationClientManager: - if auth_config.type == AuthType.OIDC.value: - intra_communication_base64 = os.getenv("INTRA_COMMUNICATION_BASE64") - # If intra server communication call - if intra_communication_base64: - assert isinstance(auth_config, OidcAuthConfig) - client_auth_config = cast(OidcClientAuthConfig, auth_config) - else: - assert isinstance(auth_config, OidcClientAuthConfig) - client_auth_config = auth_config - return OidcAuthClientManager(client_auth_config) - elif auth_config.type == AuthType.KUBERNETES.value: - assert isinstance(auth_config, KubernetesAuthConfig) - return KubernetesAuthClientManager(auth_config) - else: - raise RuntimeError( - f"No Auth client manager implemented for the auth type:${auth_config.type}" - ) - - -def get_auth_token(auth_config: AuthConfig) -> str: - return get_auth_client_manager(auth_config).get_token() diff --git a/sdk/python/feast/permissions/client/client_auth_token.py b/sdk/python/feast/permissions/client/client_auth_token.py new file mode 100644 index 0000000000..68821e3f9c --- /dev/null +++ b/sdk/python/feast/permissions/client/client_auth_token.py @@ -0,0 +1,14 @@ +from feast.permissions.auth_model import ( + AuthConfig, +) +from feast.permissions.client.auth_client_manager import ( + AuthenticationClientManagerFactory, +) + + +def get_auth_token(auth_config: AuthConfig) -> str: + return ( + AuthenticationClientManagerFactory(auth_config) + .get_auth_client_manager() + .get_token() + ) diff --git a/sdk/python/feast/permissions/client/grpc_client_auth_interceptor.py b/sdk/python/feast/permissions/client/grpc_client_auth_interceptor.py index 5155b80cb5..121735e351 100644 --- a/sdk/python/feast/permissions/client/grpc_client_auth_interceptor.py +++ b/sdk/python/feast/permissions/client/grpc_client_auth_interceptor.py @@ -4,7 +4,7 @@ from feast.errors import FeastError from feast.permissions.auth_model import AuthConfig -from feast.permissions.client.auth_client_manager_factory import get_auth_token +from feast.permissions.client.client_auth_token import get_auth_token logger = logging.getLogger(__name__) diff --git a/sdk/python/feast/permissions/client/http_auth_requests_wrapper.py b/sdk/python/feast/permissions/client/http_auth_requests_wrapper.py index 3232e25025..ba02fab8d8 100644 --- a/sdk/python/feast/permissions/client/http_auth_requests_wrapper.py +++ b/sdk/python/feast/permissions/client/http_auth_requests_wrapper.py @@ -5,7 +5,7 @@ from feast.permissions.auth_model import ( AuthConfig, ) -from feast.permissions.client.auth_client_manager_factory import get_auth_token +from feast.permissions.client.client_auth_token import get_auth_token class AuthenticatedRequestsSession(Session): diff --git a/sdk/python/feast/permissions/client/intra_comm_authentication_client_manager.py b/sdk/python/feast/permissions/client/intra_comm_authentication_client_manager.py new file mode 100644 index 0000000000..678e1f39e5 --- /dev/null +++ b/sdk/python/feast/permissions/client/intra_comm_authentication_client_manager.py @@ -0,0 +1,31 @@ +import logging + +import jwt + +from feast.permissions.auth.auth_type import AuthType +from feast.permissions.auth_model import AuthConfig +from feast.permissions.client.auth_client_manager import AuthenticationClientManager + +logger = logging.getLogger(__name__) + + +class IntraCommAuthClientManager(AuthenticationClientManager): + def __init__(self, auth_config: AuthConfig, intra_communication_base64: str): + self.auth_config = auth_config + self.intra_communication_base64 = intra_communication_base64 + + def get_token(self): + if self.auth_config.type == AuthType.OIDC.value: + payload = { + "preferred_username": f"{self.intra_communication_base64}", # Subject claim + } + elif self.auth_config.type == AuthType.KUBERNETES.value: + payload = { + "sub": f":::{self.intra_communication_base64}", # Subject claim + } + else: + raise RuntimeError( + f"No Auth client manager implemented for the auth type:{self.auth_config.type}" + ) + + return jwt.encode(payload, "") diff --git a/sdk/python/tests/unit/permissions/auth/client/test_authentication_client_manager_factory.py b/sdk/python/tests/unit/permissions/auth/client/test_authentication_client_manager_factory.py new file mode 100644 index 0000000000..5a6a8d70fa --- /dev/null +++ b/sdk/python/tests/unit/permissions/auth/client/test_authentication_client_manager_factory.py @@ -0,0 +1,55 @@ +import os +from unittest import mock + +import assertpy +import jwt +import pytest +import yaml + +from feast.permissions.auth.auth_type import AuthType +from feast.permissions.auth_model import ( + AuthConfig, +) +from feast.permissions.client.auth_client_manager import ( + AuthenticationClientManagerFactory, +) +from feast.permissions.client.intra_comm_authentication_client_manager import ( + IntraCommAuthClientManager, +) + + +@mock.patch.dict(os.environ, {"INTRA_COMMUNICATION_BASE64": "server_intra_com_val"}) +def test_authentication_client_manager_factory(auth_config): + raw_config = yaml.safe_load(auth_config) + auth_config = AuthConfig(type=raw_config["auth"]["type"]) + + authentication_client_manager_factory = AuthenticationClientManagerFactory( + auth_config + ) + + authentication_client_manager = ( + authentication_client_manager_factory.get_auth_client_manager() + ) + + if auth_config.type not in [AuthType.KUBERNETES.value, AuthType.OIDC.value]: + with pytest.raises( + RuntimeError, + match=f"No Auth client manager implemented for the auth type:{auth_config.type}", + ): + authentication_client_manager.get_token() + else: + token = authentication_client_manager.get_token() + + decoded_token = jwt.decode(token, options={"verify_signature": False}) + assertpy.assert_that(authentication_client_manager).is_type_of( + IntraCommAuthClientManager + ) + + if AuthType.KUBERNETES.value == auth_config.type: + assertpy.assert_that(decoded_token["sub"]).is_equal_to( + ":::server_intra_com_val" + ) + elif AuthType.OIDC.value in auth_config.type: + assertpy.assert_that(decoded_token["preferred_username"]).is_equal_to( + "server_intra_com_val" + )