Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the arrow flight interceptor to inject the auth header. #68

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 23 additions & 38 deletions sdk/python/feast/infra/offline_stores/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
RetrievalMetadata,
)
from feast.infra.registry.base_registry import BaseRegistry
from feast.permissions.client.utils import create_flight_call_options
from feast.permissions.client.arrow_flight_auth_interceptor import (
build_arrow_flight_client,
)
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage

Expand All @@ -47,7 +49,6 @@ class RemoteRetrievalJob(RetrievalJob):
def __init__(
self,
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
api: str,
api_parameters: Dict[str, Any],
entity_df: Union[pd.DataFrame, str] = None,
Expand All @@ -56,7 +57,6 @@ def __init__(
):
# Initialize the client connection
self.client = client
self.options = options
self.api = api
self.api_parameters = api_parameters
self.entity_df = entity_df
Expand All @@ -77,7 +77,6 @@ def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
self.entity_df,
self.table,
self.client,
self.options,
)

@property
Expand Down Expand Up @@ -118,7 +117,6 @@ def persist(
api=RemoteRetrievalJob.persist.__name__,
api_parameters=api_parameters,
client=self.client,
options=self.options,
table=self.table,
entity_df=self.entity_df,
)
Expand All @@ -137,9 +135,9 @@ def get_historical_features(
) -> RemoteRetrievalJob:
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)
client = build_arrow_flight_client(
config.offline_store.host, config.offline_store.port, config.auth_config
)

feature_view_names = [fv.name for fv in feature_views]
name_aliases = [fv.projection.name_alias for fv in feature_views]
Expand All @@ -154,7 +152,6 @@ def get_historical_features(

return RemoteRetrievalJob(
client=client,
options=options,
api=OfflineStore.get_historical_features.__name__,
api_parameters=api_parameters,
entity_df=entity_df,
Expand All @@ -174,8 +171,9 @@ def pull_all_from_table_or_query(
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)
client = build_arrow_flight_client(
config.offline_store.host, config.offline_store.port, config.auth_config
)

api_parameters = {
"data_source_name": data_source.name,
Expand All @@ -188,7 +186,6 @@ def pull_all_from_table_or_query(

return RemoteRetrievalJob(
client=client,
options=options,
api=OfflineStore.pull_all_from_table_or_query.__name__,
api_parameters=api_parameters,
)
Expand All @@ -207,8 +204,9 @@ def pull_latest_from_table_or_query(
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)
client = build_arrow_flight_client(
config.offline_store.host, config.offline_store.port, config.auth_config
)

api_parameters = {
"data_source_name": data_source.name,
Expand All @@ -222,7 +220,6 @@ def pull_latest_from_table_or_query(

return RemoteRetrievalJob(
client=client,
options=options,
api=OfflineStore.pull_latest_from_table_or_query.__name__,
api_parameters=api_parameters,
)
Expand All @@ -242,8 +239,9 @@ def write_logged_features(
data = pyarrow.parquet.read_table(data, use_threads=False, pre_buffer=False)

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)
client = build_arrow_flight_client(
config.offline_store.host, config.offline_store.port, config.auth_config
)

api_parameters = {
"feature_service_name": source._feature_service.name,
Expand All @@ -253,7 +251,6 @@ def write_logged_features(
api=OfflineStore.write_logged_features.__name__,
api_parameters=api_parameters,
client=client,
options=options,
table=data,
entity_df=None,
)
Expand All @@ -268,8 +265,9 @@ def offline_write_batch(
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)
client = build_arrow_flight_client(
config.offline_store.host, config.offline_store.port, config.auth_config
)

feature_view_names = [feature_view.name]
name_aliases = [feature_view.projection.name_alias]
Expand All @@ -284,18 +282,10 @@ def offline_write_batch(
api=OfflineStore.offline_write_batch.__name__,
api_parameters=api_parameters,
client=client,
options=options,
table=table,
entity_df=None,
)

@staticmethod
def init_client(config):
location = f"grpc://{config.offline_store.host}:{config.offline_store.port}"
client = fl.connect(location=location)
logger.info(f"Connecting FlightClient at {location}")
return client


def _create_retrieval_metadata(feature_refs: List[str], entity_df: pd.DataFrame):
entity_schema = _get_entity_schema(
Expand Down Expand Up @@ -349,35 +339,31 @@ def _send_retrieve_remote(
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
):
command_descriptor = _call_put(
api,
api_parameters,
client,
options,
entity_df,
table,
)
return _call_get(client, options, command_descriptor)
return _call_get(client, command_descriptor)


def _call_get(
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
command_descriptor: fl.FlightDescriptor,
):
flight = client.get_flight_info(command_descriptor, options)
flight = client.get_flight_info(command_descriptor)
ticket = flight.endpoints[0].ticket
reader = client.do_get(ticket, options)
reader = client.do_get(ticket)
return reader.read_all()


def _call_put(
api: str,
api_parameters: Dict[str, Any],
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
):
Expand All @@ -397,7 +383,7 @@ def _call_put(
)
)

_put_parameters(command_descriptor, entity_df, table, client, options)
_put_parameters(command_descriptor, entity_df, table, client)
return command_descriptor


Expand All @@ -406,7 +392,6 @@ def _put_parameters(
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
):
updatedTable: pa.Table

Expand All @@ -417,7 +402,7 @@ def _put_parameters(
else:
updatedTable = _create_empty_table()

writer, _ = client.do_put(command_descriptor, updatedTable.schema, options)
writer, _ = client.do_put(command_descriptor, updatedTable.schema)

writer.write_table(updatedTable)
writer.close()
Expand Down
10 changes: 5 additions & 5 deletions sdk/python/feast/infra/registry/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from feast.infra.infra_object import Infra
from feast.infra.registry.base_registry import BaseRegistry
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import (
AuthConfig,
NoAuthConfig,
Expand Down Expand Up @@ -48,13 +49,12 @@ def __init__(
repo_path: Optional[Path],
auth_config: AuthConfig = NoAuthConfig(),
):
auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config)
self.auth_config = auth_config
channel = grpc.insecure_channel(registry_config.path)
self.intercepted_channel = grpc.intercept_channel(
channel, auth_header_interceptor
)
self.stub = RegistryServer_pb2_grpc.RegistryServerStub(self.intercepted_channel)
if self.auth_config.type != AuthType.NONE.value:
auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config)
channel = grpc.intercept_channel(channel, auth_header_interceptor)
self.stub = RegistryServer_pb2_grpc.RegistryServerStub(channel)

def apply_entity(self, entity: Entity, project: str, commit: bool = True):
request = RegistryServer_pb2.ApplyEntityRequest(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pyarrow.flight as fl

from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import AuthConfig
from feast.permissions.client.auth_client_manager_factory import get_auth_token


class FlightBearerTokenInterceptor(fl.ClientMiddleware):
    """Arrow Flight client middleware that supplies a bearer ``authorization`` header.

    The token is resolved through ``get_auth_token`` from the ``AuthConfig``
    given at construction time.
    """

    def __init__(self, auth_config: AuthConfig):
        """Keep the auth configuration used later to resolve the token."""
        super().__init__()
        self.auth_config = auth_config

    def sending_headers(self):
        """Return the headers to add: a single ``authorization: Bearer <token>`` entry.

        ``get_auth_token`` is invoked every time this method runs.
        """
        token_bytes = get_auth_token(self.auth_config).encode("utf-8")
        return {b"authorization": b"Bearer " + token_bytes}

    def received_headers(self, headers):
        """Ignore the headers passed in; nothing is read from them."""
        return None

    def call_completed(self, exception):
        """Do nothing on completion, whether or not an exception is reported."""
        return None


class FlightAuthInterceptorFactory(fl.ClientMiddlewareFactory):
    """Middleware factory that hands out ``FlightBearerTokenInterceptor`` instances."""

    def __init__(self, auth_config: AuthConfig):
        """Remember the auth configuration passed on to every interceptor."""
        super().__init__()
        self.auth_config = auth_config

    def start_call(self, info):
        """Build a new interceptor bound to this factory's auth configuration.

        ``info`` is accepted but not consulted.
        """
        interceptor = FlightBearerTokenInterceptor(self.auth_config)
        return interceptor


def build_arrow_flight_client(host: str, port, auth_config: AuthConfig):
    """Create a ``FlightClient`` for ``grpc://<host>:<port>``.

    Args:
        host: Host name or address of the Flight server.
        port: Port of the Flight server; interpolated into the location string.
        auth_config: Auth configuration. When its ``type`` is not
            ``AuthType.NONE.value`` the client is created with a
            ``FlightAuthInterceptorFactory`` as middleware.

    Returns:
        The constructed ``fl.FlightClient``.
    """
    location = f"grpc://{host}:{port}"

    # No auth configured: plain client, no middleware at all.
    if auth_config.type == AuthType.NONE.value:
        return fl.FlightClient(location)

    auth_middleware = FlightAuthInterceptorFactory(auth_config)
    return fl.FlightClient(location, middleware=[auth_middleware])
30 changes: 0 additions & 30 deletions sdk/python/feast/permissions/client/auth_client_manager.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,8 @@
from abc import ABC, abstractmethod

from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import (
AuthConfig,
KubernetesAuthConfig,
OidcAuthConfig,
)


class AuthenticationClientManager(ABC):
    """Abstract base for client-side managers that can produce an auth token."""

    @abstractmethod
    def get_token(self) -> str:
        """Retrieves the token based on the authentication type configuration"""


def get_auth_client_manager(auth_config: AuthConfig) -> AuthenticationClientManager:
    """Return the client manager matching ``auth_config.type``.

    Args:
        auth_config: Auth configuration whose ``type`` selects the manager.

    Returns:
        An ``OidcAuthClientManager`` for ``AuthType.OIDC`` or a
        ``KubernetesAuthClientManager`` for ``AuthType.KUBERNETES``.

    Raises:
        RuntimeError: If no manager is implemented for ``auth_config.type``.
    """
    if auth_config.type == AuthType.OIDC.value:
        assert isinstance(auth_config, OidcAuthConfig)

        # Imported inside the branch rather than at module level
        # (presumably to avoid a circular import -- confirm before moving).
        from feast.permissions.client.oidc_authentication_client_manager import (
            OidcAuthClientManager,
        )

        return OidcAuthClientManager(auth_config)
    elif auth_config.type == AuthType.KUBERNETES.value:
        assert isinstance(auth_config, KubernetesAuthConfig)

        # Same deferred-import arrangement as the OIDC branch.
        from feast.permissions.client.kubernetes_auth_client_manager import (
            KubernetesAuthClientManager,
        )

        return KubernetesAuthClientManager(auth_config)
    else:
        # Fixed: the message previously used "${...}", which left a literal
        # "$" in front of the interpolated auth type.
        raise RuntimeError(
            f"No Auth client manager implemented for the auth type:{auth_config.type}"
        )
30 changes: 30 additions & 0 deletions sdk/python/feast/permissions/client/auth_client_manager_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import (
AuthConfig,
KubernetesAuthConfig,
OidcAuthConfig,
)
from feast.permissions.client.auth_client_manager import AuthenticationClientManager
from feast.permissions.client.kubernetes_auth_client_manager import (
KubernetesAuthClientManager,
)
from feast.permissions.client.oidc_authentication_client_manager import (
OidcAuthClientManager,
)


def get_auth_client_manager(auth_config: AuthConfig) -> AuthenticationClientManager:
    """Return the client manager matching ``auth_config.type``.

    Args:
        auth_config: Auth configuration whose ``type`` selects the manager.

    Returns:
        An ``OidcAuthClientManager`` for ``AuthType.OIDC`` or a
        ``KubernetesAuthClientManager`` for ``AuthType.KUBERNETES``.

    Raises:
        RuntimeError: If no manager is implemented for ``auth_config.type``.
    """
    if auth_config.type == AuthType.OIDC.value:
        assert isinstance(auth_config, OidcAuthConfig)
        return OidcAuthClientManager(auth_config)
    elif auth_config.type == AuthType.KUBERNETES.value:
        assert isinstance(auth_config, KubernetesAuthConfig)
        return KubernetesAuthClientManager(auth_config)
    else:
        # Fixed: the message previously used "${...}", which left a literal
        # "$" in front of the interpolated auth type.
        raise RuntimeError(
            f"No Auth client manager implemented for the auth type:{auth_config.type}"
        )


def get_auth_token(auth_config: AuthConfig) -> str:
    """Resolve the client manager for ``auth_config`` and return its token.

    Any error raised by ``get_auth_client_manager`` or by the manager's
    ``get_token`` propagates to the caller.
    """
    manager = get_auth_client_manager(auth_config)
    return manager.get_token()
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import grpc

from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import AuthConfig
from feast.permissions.client.auth_client_manager import get_auth_client_manager
from feast.permissions.client.auth_client_manager_factory import get_auth_token

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -43,16 +42,11 @@ def intercept_stream_stream(
return continuation(client_call_details, request_iterator)

def _append_auth_header_metadata(self, client_call_details):
if self._auth_type.type is not AuthType.NONE.value:
logger.info(
f"Intercepted the grpc api method {client_call_details.method} call to inject Authorization header "
f"token. "
)
metadata = client_call_details.metadata or []
auth_client_manager = get_auth_client_manager(self._auth_type)
access_token = auth_client_manager.get_token()
metadata.append(
(b"authorization", b"Bearer " + access_token.encode("utf-8"))
)
client_call_details = client_call_details._replace(metadata=metadata)
logger.debug(
"Intercepted the grpc api method call to inject Authorization header "
)
metadata = client_call_details.metadata or []
access_token = get_auth_token(self._auth_type)
metadata.append((b"authorization", b"Bearer " + access_token.encode("utf-8")))
client_call_details = client_call_details._replace(metadata=metadata)
return client_call_details
Loading
Loading