Skip to content

Commit

Permalink
fix: Added Offline Store Arrow client errors handler (feast-dev#4524)
Browse files Browse the repository at this point in the history
* fix: Added Offline Store Arrow client errors handler

Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>

* Added more tests

Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>

---------

Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>
  • Loading branch information
tmihalac authored Sep 17, 2024
1 parent c5a4d90 commit 7535b40
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 61 deletions.
49 changes: 49 additions & 0 deletions sdk/python/feast/arrow_error_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import logging
from functools import wraps

import pyarrow.flight as fl

from feast.errors import FeastError

logger = logging.getLogger(__name__)


def arrow_client_error_handling_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
mapped_error = FeastError.from_error_detail(_get_exception_data(e.args[0]))
if mapped_error is not None:
raise mapped_error
raise e

return wrapper


def arrow_server_error_handling_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
if isinstance(e, FeastError):
raise fl.FlightError(e.to_error_detail())

return wrapper


def _get_exception_data(except_str) -> str:
substring = "Flight error: "

# Find the starting index of the substring
position = except_str.find(substring)
end_json_index = except_str.find("}")

if position != -1 and end_json_index != -1:
# Extract the part of the string after the substring
result = except_str[position + len(substring) : end_json_index + 1]
return result

return ""
68 changes: 60 additions & 8 deletions sdk/python/feast/infra/offline_stores/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
import pyarrow as pa
import pyarrow.flight as fl
import pyarrow.parquet
from pyarrow import Schema
from pyarrow._flight import FlightCallOptions, FlightDescriptor, Ticket
from pydantic import StrictInt, StrictStr

from feast import OnDemandFeatureView
from feast.arrow_error_handler import arrow_client_error_handling_decorator
from feast.data_source import DataSource
from feast.feature_logging import (
FeatureServiceLoggingSource,
Expand All @@ -27,15 +30,54 @@
RetrievalMetadata,
)
from feast.infra.registry.base_registry import BaseRegistry
from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import AuthConfig
from feast.permissions.client.arrow_flight_auth_interceptor import (
build_arrow_flight_client,
FlightAuthInterceptorFactory,
)
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage

logger = logging.getLogger(__name__)


class FeastFlightClient(fl.FlightClient):
@arrow_client_error_handling_decorator
def get_flight_info(
self, descriptor: FlightDescriptor, options: FlightCallOptions = None
):
return super().get_flight_info(descriptor, options)

@arrow_client_error_handling_decorator
def do_get(self, ticket: Ticket, options: FlightCallOptions = None):
return super().do_get(ticket, options)

@arrow_client_error_handling_decorator
def do_put(
self,
descriptor: FlightDescriptor,
schema: Schema,
options: FlightCallOptions = None,
):
return super().do_put(descriptor, schema, options)

@arrow_client_error_handling_decorator
def list_flights(self, criteria: bytes = b"", options: FlightCallOptions = None):
return super().list_flights(criteria, options)

@arrow_client_error_handling_decorator
def list_actions(self, options: FlightCallOptions = None):
return super().list_actions(options)


def build_arrow_flight_client(host: str, port, auth_config: AuthConfig):
if auth_config.type != AuthType.NONE.value:
middlewares = [FlightAuthInterceptorFactory(auth_config)]
return FeastFlightClient(f"grpc://{host}:{port}", middleware=middlewares)

return FeastFlightClient(f"grpc://{host}:{port}")


class RemoteOfflineStoreConfig(FeastConfigBaseModel):
type: Literal["remote"] = "remote"
host: StrictStr
Expand All @@ -48,7 +90,7 @@ class RemoteOfflineStoreConfig(FeastConfigBaseModel):
class RemoteRetrievalJob(RetrievalJob):
def __init__(
self,
client: fl.FlightClient,
client: FeastFlightClient,
api: str,
api_parameters: Dict[str, Any],
entity_df: Union[pd.DataFrame, str] = None,
Expand Down Expand Up @@ -338,7 +380,7 @@ def _send_retrieve_remote(
api_parameters: Dict[str, Any],
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
client: fl.FlightClient,
client: FeastFlightClient,
):
command_descriptor = _call_put(
api,
Expand All @@ -351,19 +393,19 @@ def _send_retrieve_remote(


def _call_get(
client: fl.FlightClient,
client: FeastFlightClient,
command_descriptor: fl.FlightDescriptor,
):
flight = client.get_flight_info(command_descriptor)
ticket = flight.endpoints[0].ticket
reader = client.do_get(ticket)
return reader.read_all()
return read_all(reader)


def _call_put(
api: str,
api_parameters: Dict[str, Any],
client: fl.FlightClient,
client: FeastFlightClient,
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
):
Expand Down Expand Up @@ -391,7 +433,7 @@ def _put_parameters(
command_descriptor: fl.FlightDescriptor,
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
client: fl.FlightClient,
client: FeastFlightClient,
):
updatedTable: pa.Table

Expand All @@ -404,10 +446,20 @@ def _put_parameters(

writer, _ = client.do_put(command_descriptor, updatedTable.schema)

writer.write_table(updatedTable)
write_table(writer, updatedTable)


@arrow_client_error_handling_decorator
def write_table(writer, updated_table: pa.Table):
writer.write_table(updated_table)
writer.close()


@arrow_client_error_handling_decorator
def read_all(reader):
return reader.read_all()


def _create_empty_table():
schema = pa.schema(
{
Expand Down
54 changes: 37 additions & 17 deletions sdk/python/feast/offline_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@
import pyarrow.flight as fl

from feast import FeatureStore, FeatureView, utils
from feast.arrow_error_handler import arrow_server_error_handling_decorator
from feast.feature_logging import FeatureServiceLoggingSource
from feast.feature_view import DUMMY_ENTITY_NAME
from feast.infra.offline_stores.offline_utils import get_offline_store_from_config
from feast.permissions.action import AuthzedAction
from feast.permissions.security_manager import assert_permissions
from feast.permissions.server.arrow import (
arrowflight_middleware,
AuthorizationMiddlewareFactory,
inject_user_details_decorator,
)
from feast.permissions.server.utils import (
AuthManagerType,
ServerType,
init_auth_manager,
init_security_manager,
Expand All @@ -34,7 +36,7 @@ class OfflineServer(fl.FlightServerBase):
def __init__(self, store: FeatureStore, location: str, **kwargs):
super(OfflineServer, self).__init__(
location,
middleware=arrowflight_middleware(
middleware=self.arrow_flight_auth_middleware(
str_to_auth_manager_type(store.config.auth_config.type)
),
**kwargs,
Expand All @@ -45,6 +47,25 @@ def __init__(self, store: FeatureStore, location: str, **kwargs):
self.store = store
self.offline_store = get_offline_store_from_config(store.config.offline_store)

def arrow_flight_auth_middleware(
self,
auth_type: AuthManagerType,
) -> dict[str, fl.ServerMiddlewareFactory]:
"""
A dictionary with the configured middlewares to support extracting the user details when the authorization manager is defined.
The authorization middleware key is `auth`.
Returns:
dict[str, fl.ServerMiddlewareFactory]: Optional dictionary of middlewares. If the authorization type is set to `NONE`, it returns an empty dict.
"""

if auth_type == AuthManagerType.NONE:
return {}

return {
"auth": AuthorizationMiddlewareFactory(),
}

@classmethod
def descriptor_to_key(self, descriptor: fl.FlightDescriptor):
return (
Expand All @@ -61,15 +82,7 @@ def _make_flight_info(self, key: Any, descriptor: fl.FlightDescriptor):
return fl.FlightInfo(schema, descriptor, endpoints, -1, -1)

@inject_user_details_decorator
def get_flight_info(
self, context: fl.ServerCallContext, descriptor: fl.FlightDescriptor
):
key = OfflineServer.descriptor_to_key(descriptor)
if key in self.flights:
return self._make_flight_info(key, descriptor)
raise KeyError("Flight not found.")

@inject_user_details_decorator
@arrow_server_error_handling_decorator
def list_flights(self, context: fl.ServerCallContext, criteria: bytes):
for key, table in self.flights.items():
if key[1] is not None:
Expand All @@ -79,9 +92,20 @@ def list_flights(self, context: fl.ServerCallContext, criteria: bytes):

yield self._make_flight_info(key, descriptor)

@inject_user_details_decorator
@arrow_server_error_handling_decorator
def get_flight_info(
self, context: fl.ServerCallContext, descriptor: fl.FlightDescriptor
):
key = OfflineServer.descriptor_to_key(descriptor)
if key in self.flights:
return self._make_flight_info(key, descriptor)
raise KeyError("Flight not found.")

# Expects to receive request parameters and stores them in the flights dictionary
# Indexed by the unique command
@inject_user_details_decorator
@arrow_server_error_handling_decorator
def do_put(
self,
context: fl.ServerCallContext,
Expand Down Expand Up @@ -179,6 +203,7 @@ def _validate_do_get_parameters(self, command: dict):
# Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance
# and returns the stream of data
@inject_user_details_decorator
@arrow_server_error_handling_decorator
def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
key = ast.literal_eval(ticket.ticket.decode())
if key not in self.flights:
Expand Down Expand Up @@ -337,6 +362,7 @@ def pull_latest_from_table_or_query(self, command: dict):
utils.make_tzaware(datetime.fromisoformat(command["end_date"])),
)

@arrow_server_error_handling_decorator
def list_actions(self, context):
return [
(
Expand Down Expand Up @@ -431,12 +457,6 @@ def persist(self, command: dict, key: str):
traceback.print_exc()
raise e

def do_action(self, context: fl.ServerCallContext, action: fl.Action):
pass

def do_drop_dataset(self, dataset):
pass


def remove_dummies(fv: FeatureView) -> FeatureView:
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pyarrow.flight as fl

from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import AuthConfig
from feast.permissions.client.client_auth_token import get_auth_token

Expand Down Expand Up @@ -28,11 +27,3 @@ def __init__(self, auth_config: AuthConfig):

def start_call(self, info):
return FlightBearerTokenInterceptor(self.auth_config)


def build_arrow_flight_client(host: str, port, auth_config: AuthConfig):
if auth_config.type != AuthType.NONE.value:
middleware_factory = FlightAuthInterceptorFactory(auth_config)
return fl.FlightClient(f"grpc://{host}:{port}", middleware=[middleware_factory])
else:
return fl.FlightClient(f"grpc://{host}:{port}")
31 changes: 5 additions & 26 deletions sdk/python/feast/permissions/server/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import asyncio
import functools
import logging
from typing import Optional, cast
from typing import cast

import pyarrow.flight as fl
from pyarrow.flight import ServerCallContext
Expand All @@ -14,41 +14,19 @@
get_auth_manager,
)
from feast.permissions.security_manager import get_security_manager
from feast.permissions.server.utils import (
AuthManagerType,
)
from feast.permissions.user import User

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def arrowflight_middleware(
auth_type: AuthManagerType,
) -> Optional[dict[str, fl.ServerMiddlewareFactory]]:
"""
A dictionary with the configured middlewares to support extracting the user details when the authorization manager is defined.
The authorization middleware key is `auth`.
Returns:
dict[str, fl.ServerMiddlewareFactory]: Optional dictionary of middlewares. If the authorization type is set to `NONE`, it returns `None`.
"""

if auth_type == AuthManagerType.NONE:
return None

return {
"auth": AuthorizationMiddlewareFactory(),
}


class AuthorizationMiddlewareFactory(fl.ServerMiddlewareFactory):
"""
A middleware factory to intercept the authorization header and propagate it to the authorization middleware.
"""

def __init__(self):
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def start_call(self, info, headers):
"""
Expand All @@ -65,7 +43,8 @@ class AuthorizationMiddleware(fl.ServerMiddleware):
A server middleware holding the authorization header and offering a method to extract the user credentials.
"""

def __init__(self, access_token: str):
def __init__(self, access_token: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self.access_token = access_token

def call_completed(self, exception):
Expand Down
Loading

0 comments on commit 7535b40

Please sign in to comment.