Skip to content

Commit

Permalink
Added the arrow flight interceptor to inject the auth header. (feast-…
Browse files Browse the repository at this point in the history
…dev#68)

* * Added the arrow flight interceptor to inject the auth header.
* Injecting grpc interceptor if it is needed when auth type is not NO_AUTH.

Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com>

* Fixing the failing integration test cases by setting the header in binary format.

Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com>

* Refactored method and moved to factory class to incorporate code review comment.
Fixed lint error by removing the type of port. and other minor changes.

Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com>

* Incorproating code review comments from Daniel.

Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com>

---------

Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com>
  • Loading branch information
lokeshrangineni authored Aug 11, 2024
1 parent c13f229 commit 0aad7a8
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 113 deletions.
61 changes: 23 additions & 38 deletions sdk/python/feast/infra/offline_stores/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
RetrievalMetadata,
)
from feast.infra.registry.base_registry import BaseRegistry
from feast.permissions.client.utils import create_flight_call_options
from feast.permissions.client.arrow_flight_auth_interceptor import (
build_arrow_flight_client,
)
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage

Expand All @@ -47,7 +49,6 @@ class RemoteRetrievalJob(RetrievalJob):
def __init__(
self,
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
api: str,
api_parameters: Dict[str, Any],
entity_df: Union[pd.DataFrame, str] = None,
Expand All @@ -56,7 +57,6 @@ def __init__(
):
# Initialize the client connection
self.client = client
self.options = options
self.api = api
self.api_parameters = api_parameters
self.entity_df = entity_df
Expand All @@ -77,7 +77,6 @@ def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
self.entity_df,
self.table,
self.client,
self.options,
)

@property
Expand Down Expand Up @@ -118,7 +117,6 @@ def persist(
api=RemoteRetrievalJob.persist.__name__,
api_parameters=api_parameters,
client=self.client,
options=self.options,
table=self.table,
entity_df=self.entity_df,
)
Expand All @@ -137,9 +135,9 @@ def get_historical_features(
) -> RemoteRetrievalJob:
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)
client = build_arrow_flight_client(
config.offline_store.host, config.offline_store.port, config.auth_config
)

feature_view_names = [fv.name for fv in feature_views]
name_aliases = [fv.projection.name_alias for fv in feature_views]
Expand All @@ -154,7 +152,6 @@ def get_historical_features(

return RemoteRetrievalJob(
client=client,
options=options,
api=OfflineStore.get_historical_features.__name__,
api_parameters=api_parameters,
entity_df=entity_df,
Expand All @@ -174,8 +171,9 @@ def pull_all_from_table_or_query(
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)
client = build_arrow_flight_client(
config.offline_store.host, config.offline_store.port, config.auth_config
)

api_parameters = {
"data_source_name": data_source.name,
Expand All @@ -188,7 +186,6 @@ def pull_all_from_table_or_query(

return RemoteRetrievalJob(
client=client,
options=options,
api=OfflineStore.pull_all_from_table_or_query.__name__,
api_parameters=api_parameters,
)
Expand All @@ -207,8 +204,9 @@ def pull_latest_from_table_or_query(
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)
client = build_arrow_flight_client(
config.offline_store.host, config.offline_store.port, config.auth_config
)

api_parameters = {
"data_source_name": data_source.name,
Expand All @@ -222,7 +220,6 @@ def pull_latest_from_table_or_query(

return RemoteRetrievalJob(
client=client,
options=options,
api=OfflineStore.pull_latest_from_table_or_query.__name__,
api_parameters=api_parameters,
)
Expand All @@ -242,8 +239,9 @@ def write_logged_features(
data = pyarrow.parquet.read_table(data, use_threads=False, pre_buffer=False)

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)
client = build_arrow_flight_client(
config.offline_store.host, config.offline_store.port, config.auth_config
)

api_parameters = {
"feature_service_name": source._feature_service.name,
Expand All @@ -253,7 +251,6 @@ def write_logged_features(
api=OfflineStore.write_logged_features.__name__,
api_parameters=api_parameters,
client=client,
options=options,
table=data,
entity_df=None,
)
Expand All @@ -268,8 +265,9 @@ def offline_write_batch(
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)

# Initialize the client connection
client = RemoteOfflineStore.init_client(config)
options = create_flight_call_options(config.auth_config)
client = build_arrow_flight_client(
config.offline_store.host, config.offline_store.port, config.auth_config
)

feature_view_names = [feature_view.name]
name_aliases = [feature_view.projection.name_alias]
Expand All @@ -284,18 +282,10 @@ def offline_write_batch(
api=OfflineStore.offline_write_batch.__name__,
api_parameters=api_parameters,
client=client,
options=options,
table=table,
entity_df=None,
)

@staticmethod
def init_client(config):
location = f"grpc://{config.offline_store.host}:{config.offline_store.port}"
client = fl.connect(location=location)
logger.info(f"Connecting FlightClient at {location}")
return client


def _create_retrieval_metadata(feature_refs: List[str], entity_df: pd.DataFrame):
entity_schema = _get_entity_schema(
Expand Down Expand Up @@ -349,35 +339,31 @@ def _send_retrieve_remote(
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
):
command_descriptor = _call_put(
api,
api_parameters,
client,
options,
entity_df,
table,
)
return _call_get(client, options, command_descriptor)
return _call_get(client, command_descriptor)


def _call_get(
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
command_descriptor: fl.FlightDescriptor,
):
flight = client.get_flight_info(command_descriptor, options)
flight = client.get_flight_info(command_descriptor)
ticket = flight.endpoints[0].ticket
reader = client.do_get(ticket, options)
reader = client.do_get(ticket)
return reader.read_all()


def _call_put(
api: str,
api_parameters: Dict[str, Any],
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
):
Expand All @@ -397,7 +383,7 @@ def _call_put(
)
)

_put_parameters(command_descriptor, entity_df, table, client, options)
_put_parameters(command_descriptor, entity_df, table, client)
return command_descriptor


Expand All @@ -406,7 +392,6 @@ def _put_parameters(
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
client: fl.FlightClient,
options: pa.flight.FlightCallOptions,
):
updatedTable: pa.Table

Expand All @@ -417,7 +402,7 @@ def _put_parameters(
else:
updatedTable = _create_empty_table()

writer, _ = client.do_put(command_descriptor, updatedTable.schema, options)
writer, _ = client.do_put(command_descriptor, updatedTable.schema)

writer.write_table(updatedTable)
writer.close()
Expand Down
10 changes: 5 additions & 5 deletions sdk/python/feast/infra/registry/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from feast.infra.infra_object import Infra
from feast.infra.registry.base_registry import BaseRegistry
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import (
AuthConfig,
NoAuthConfig,
Expand Down Expand Up @@ -48,13 +49,12 @@ def __init__(
repo_path: Optional[Path],
auth_config: AuthConfig = NoAuthConfig(),
):
auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config)
self.auth_config = auth_config
channel = grpc.insecure_channel(registry_config.path)
self.intercepted_channel = grpc.intercept_channel(
channel, auth_header_interceptor
)
self.stub = RegistryServer_pb2_grpc.RegistryServerStub(self.intercepted_channel)
if self.auth_config.type != AuthType.NONE.value:
auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config)
channel = grpc.intercept_channel(channel, auth_header_interceptor)
self.stub = RegistryServer_pb2_grpc.RegistryServerStub(channel)

def apply_entity(self, entity: Entity, project: str, commit: bool = True):
request = RegistryServer_pb2.ApplyEntityRequest(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pyarrow.flight as fl

from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import AuthConfig
from feast.permissions.client.auth_client_manager_factory import get_auth_token


class FlightBearerTokenInterceptor(fl.ClientMiddleware):
def __init__(self, auth_config: AuthConfig):
super().__init__()
self.auth_config = auth_config

def call_completed(self, exception):
pass

def received_headers(self, headers):
pass

def sending_headers(self):
access_token = get_auth_token(self.auth_config)
return {b"authorization": b"Bearer " + access_token.encode("utf-8")}


class FlightAuthInterceptorFactory(fl.ClientMiddlewareFactory):
def __init__(self, auth_config: AuthConfig):
super().__init__()
self.auth_config = auth_config

def start_call(self, info):
return FlightBearerTokenInterceptor(self.auth_config)


def build_arrow_flight_client(host: str, port, auth_config: AuthConfig):
if auth_config.type != AuthType.NONE.value:
middleware_factory = FlightAuthInterceptorFactory(auth_config)
return fl.FlightClient(f"grpc://{host}:{port}", middleware=[middleware_factory])
else:
return fl.FlightClient(f"grpc://{host}:{port}")
30 changes: 0 additions & 30 deletions sdk/python/feast/permissions/client/auth_client_manager.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,8 @@
from abc import ABC, abstractmethod

from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import (
AuthConfig,
KubernetesAuthConfig,
OidcAuthConfig,
)


class AuthenticationClientManager(ABC):
@abstractmethod
def get_token(self) -> str:
"""Retrieves the token based on the authentication type configuration"""
pass


def get_auth_client_manager(auth_config: AuthConfig) -> AuthenticationClientManager:
if auth_config.type == AuthType.OIDC.value:
assert isinstance(auth_config, OidcAuthConfig)

from feast.permissions.client.oidc_authentication_client_manager import (
OidcAuthClientManager,
)

return OidcAuthClientManager(auth_config)
elif auth_config.type == AuthType.KUBERNETES.value:
assert isinstance(auth_config, KubernetesAuthConfig)

from feast.permissions.client.kubernetes_auth_client_manager import (
KubernetesAuthClientManager,
)

return KubernetesAuthClientManager(auth_config)
else:
raise RuntimeError(
f"No Auth client manager implemented for the auth type:${auth_config.type}"
)
30 changes: 30 additions & 0 deletions sdk/python/feast/permissions/client/auth_client_manager_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import (
AuthConfig,
KubernetesAuthConfig,
OidcAuthConfig,
)
from feast.permissions.client.auth_client_manager import AuthenticationClientManager
from feast.permissions.client.kubernetes_auth_client_manager import (
KubernetesAuthClientManager,
)
from feast.permissions.client.oidc_authentication_client_manager import (
OidcAuthClientManager,
)


def get_auth_client_manager(auth_config: AuthConfig) -> AuthenticationClientManager:
if auth_config.type == AuthType.OIDC.value:
assert isinstance(auth_config, OidcAuthConfig)
return OidcAuthClientManager(auth_config)
elif auth_config.type == AuthType.KUBERNETES.value:
assert isinstance(auth_config, KubernetesAuthConfig)
return KubernetesAuthClientManager(auth_config)
else:
raise RuntimeError(
f"No Auth client manager implemented for the auth type:${auth_config.type}"
)


def get_auth_token(auth_config: AuthConfig) -> str:
return get_auth_client_manager(auth_config).get_token()
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import grpc

from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import AuthConfig
from feast.permissions.client.auth_client_manager import get_auth_client_manager
from feast.permissions.client.auth_client_manager_factory import get_auth_token

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -43,16 +42,11 @@ def intercept_stream_stream(
return continuation(client_call_details, request_iterator)

def _append_auth_header_metadata(self, client_call_details):
if self._auth_type.type is not AuthType.NONE.value:
logger.info(
f"Intercepted the grpc api method {client_call_details.method} call to inject Authorization header "
f"token. "
)
metadata = client_call_details.metadata or []
auth_client_manager = get_auth_client_manager(self._auth_type)
access_token = auth_client_manager.get_token()
metadata.append(
(b"authorization", b"Bearer " + access_token.encode("utf-8"))
)
client_call_details = client_call_details._replace(metadata=metadata)
logger.debug(
"Intercepted the grpc api method call to inject Authorization header "
)
metadata = client_call_details.metadata or []
access_token = get_auth_token(self._auth_type)
metadata.append((b"authorization", b"Bearer " + access_token.encode("utf-8")))
client_call_details = client_call_details._replace(metadata=metadata)
return client_call_details
Loading

0 comments on commit 0aad7a8

Please sign in to comment.