Skip to content

Commit

Permalink
fix: Remote apply using offline store (feast-dev#4559)
Browse files Browse the repository at this point in the history
* remote apply using offline store

Signed-off-by: Daniele Martinoli <dmartino@redhat.com>

* passing data source proto to the offline server

Signed-off-by: Daniele Martinoli <dmartino@redhat.com>

* fixed linting, added permission asserts

Signed-off-by: Daniele Martinoli <dmartino@redhat.com>

---------

Signed-off-by: Daniele Martinoli <dmartino@redhat.com>
  • Loading branch information
dmartinol authored Oct 15, 2024
1 parent ba05893 commit ac62a32
Show file tree
Hide file tree
Showing 11 changed files with 282 additions and 64 deletions.
19 changes: 14 additions & 5 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,16 +602,23 @@ def _make_inferences(

# New feature views may reference previously applied entities.
entities = self._list_entities()
provider = self._get_provider()
update_feature_views_with_inferred_features_and_entities(
views_to_update, entities + entities_to_update, self.config
provider,
views_to_update,
entities + entities_to_update,
self.config,
)
update_feature_views_with_inferred_features_and_entities(
sfvs_to_update, entities + entities_to_update, self.config
provider,
sfvs_to_update,
entities + entities_to_update,
self.config,
)
# We need to attach the time stamp fields to the underlying data sources
# and cascade the dependencies
update_feature_views_with_inferred_features_and_entities(
odfvs_to_update, entities + entities_to_update, self.config
provider, odfvs_to_update, entities + entities_to_update, self.config
)
# TODO(kevjumba): Update schema inference
for sfv in sfvs_to_update:
Expand Down Expand Up @@ -1529,9 +1536,12 @@ def write_to_offline_store(
feature_view_name, allow_registry_cache=allow_registry_cache
)

provider = self._get_provider()
# Get columns of the batch source and the input dataframe.
column_names_and_types = (
feature_view.batch_source.get_table_column_names_and_types(self.config)
provider.get_table_column_names_and_types_from_data_source(
self.config, feature_view.batch_source
)
)
source_columns = [column for column, _ in column_names_and_types]
input_columns = df.columns.values.tolist()
Expand All @@ -1545,7 +1555,6 @@ def write_to_offline_store(
df = df.reindex(columns=source_columns)

table = pa.Table.from_pandas(df)
provider = self._get_provider()
provider.ingest_df_to_offline_store(feature_view, table)

def get_online_features(
Expand Down
10 changes: 8 additions & 2 deletions sdk/python/feast/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from feast.infra.offline_stores.file_source import FileSource
from feast.infra.offline_stores.redshift_source import RedshiftSource
from feast.infra.offline_stores.snowflake_source import SnowflakeSource
from feast.infra.provider import Provider
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.repo_config import RepoConfig
from feast.stream_feature_view import StreamFeatureView
Expand Down Expand Up @@ -95,6 +96,7 @@ def update_data_sources_with_inferred_event_timestamp_col(


def update_feature_views_with_inferred_features_and_entities(
provider: Provider,
fvs: Union[List[FeatureView], List[StreamFeatureView], List[OnDemandFeatureView]],
entities: List[Entity],
config: RepoConfig,
Expand Down Expand Up @@ -176,6 +178,7 @@ def update_feature_views_with_inferred_features_and_entities(

if run_inference_for_entities or run_inference_for_features:
_infer_features_and_entities(
provider,
fv,
join_keys,
run_inference_for_features,
Expand All @@ -193,6 +196,7 @@ def update_feature_views_with_inferred_features_and_entities(


def _infer_features_and_entities(
provider: Provider,
fv: Union[FeatureView, OnDemandFeatureView],
join_keys: Set[Optional[str]],
run_inference_for_features,
Expand Down Expand Up @@ -222,8 +226,10 @@ def _infer_features_and_entities(
columns_to_exclude.remove(mapped_col)
columns_to_exclude.add(original_col)

table_column_names_and_types = fv.batch_source.get_table_column_names_and_types(
config
table_column_names_and_types = (
provider.get_table_column_names_and_types_from_data_source(
config, fv.batch_source
)
)

for col_name, col_datatype in table_column_names_and_types:
Expand Down
27 changes: 25 additions & 2 deletions sdk/python/feast/infra/offline_stores/offline_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,16 @@
from abc import ABC
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
List,
Optional,
Tuple,
Union,
)

import pandas as pd
import pyarrow
Expand Down Expand Up @@ -352,8 +361,8 @@ def offline_write_batch(
"""
raise NotImplementedError

@staticmethod
def validate_data_source(
self,
config: RepoConfig,
data_source: DataSource,
):
Expand All @@ -365,3 +374,17 @@ def validate_data_source(
data_source: DataSource object that needs to be validated
"""
data_source.validate(config=config)

def get_table_column_names_and_types_from_data_source(
self,
config: RepoConfig,
data_source: DataSource,
) -> Iterable[Tuple[str, str]]:
"""
Returns the list of column names and raw column types for a DataSource.
Args:
config: Configuration object used to configure a feature store.
data_source: DataSource object
"""
return data_source.get_table_column_names_and_types(config=config)
53 changes: 52 additions & 1 deletion sdk/python/feast/infra/offline_stores/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -328,6 +328,57 @@ def offline_write_batch(
entity_df=None,
)

def validate_data_source(
self,
config: RepoConfig,
data_source: DataSource,
):
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)

client = build_arrow_flight_client(
config.offline_store.host, config.offline_store.port, config.auth_config
)

api_parameters = {
"data_source_proto": str(data_source),
}
logger.debug(f"validating DataSource {data_source.name}")
_call_put(
api=OfflineStore.validate_data_source.__name__,
api_parameters=api_parameters,
client=client,
table=None,
entity_df=None,
)

def get_table_column_names_and_types_from_data_source(
self, config: RepoConfig, data_source: DataSource
) -> Iterable[Tuple[str, str]]:
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)

client = build_arrow_flight_client(
config.offline_store.host, config.offline_store.port, config.auth_config
)

api_parameters = {
"data_source_proto": str(data_source),
}
logger.debug(
f"Calling {OfflineStore.get_table_column_names_and_types_from_data_source.__name__} with {api_parameters}"
)
table = _send_retrieve_remote(
api=OfflineStore.get_table_column_names_and_types_from_data_source.__name__,
api_parameters=api_parameters,
client=client,
table=None,
entity_df=None,
)

logger.debug(
f"get_table_column_names_and_types_from_data_source for {data_source.name}: {table}"
)
return zip(table.column("name").to_pylist(), table.column("type").to_pylist())


def _create_retrieval_metadata(feature_refs: List[str], entity_df: pd.DataFrame):
entity_schema = _get_entity_schema(
Expand Down
20 changes: 19 additions & 1 deletion sdk/python/feast/infra/passthrough_provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
from datetime import datetime, timedelta
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)

import pandas as pd
import pyarrow as pa
Expand Down Expand Up @@ -455,3 +466,10 @@ def validate_data_source(
data_source: DataSource,
):
self.offline_store.validate_data_source(config=config, data_source=data_source)

def get_table_column_names_and_types_from_data_source(
self, config: RepoConfig, data_source: DataSource
) -> Iterable[Tuple[str, str]]:
return self.offline_store.get_table_column_names_and_types_from_data_source(
config=config, data_source=data_source
)
26 changes: 25 additions & 1 deletion sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)

import pandas as pd
import pyarrow
Expand Down Expand Up @@ -405,6 +416,19 @@ def validate_data_source(
"""
pass

@abstractmethod
def get_table_column_names_and_types_from_data_source(
self, config: RepoConfig, data_source: DataSource
) -> Iterable[Tuple[str, str]]:
"""
Returns the list of column names and raw column types for a DataSource.
Args:
config: Configuration object used to configure a feature store.
data_source: DataSource object
"""
pass


def get_provider(config: RepoConfig) -> Provider:
if "." not in config.provider:
Expand Down
6 changes: 2 additions & 4 deletions sdk/python/feast/infra/registry/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from feast.infra.infra_object import Infra
from feast.infra.registry.base_registry import BaseRegistry
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import AuthConfig, NoAuthConfig
from feast.permissions.client.grpc_client_auth_interceptor import (
GrpcClientAuthHeaderInterceptor,
Expand Down Expand Up @@ -67,9 +66,8 @@ def __init__(
):
self.auth_config = auth_config
self.channel = grpc.insecure_channel(registry_config.path)
if self.auth_config.type != AuthType.NONE.value:
auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config)
self.channel = grpc.intercept_channel(self.channel, auth_header_interceptor)
auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config)
self.channel = grpc.intercept_channel(self.channel, auth_header_interceptor)
self.stub = RegistryServer_pb2_grpc.RegistryServerStub(self.channel)

def close(self):
Expand Down
46 changes: 46 additions & 0 deletions sdk/python/feast/offline_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

import pyarrow as pa
import pyarrow.flight as fl
from google.protobuf.json_format import Parse

from feast import FeatureStore, FeatureView, utils
from feast.arrow_error_handler import arrow_server_error_handling_decorator
from feast.data_source import DataSource
from feast.feature_logging import FeatureServiceLoggingSource
from feast.feature_view import DUMMY_ENTITY_NAME
from feast.infra.offline_stores.offline_utils import get_offline_store_from_config
Expand All @@ -26,6 +28,7 @@
init_security_manager,
str_to_auth_manager_type,
)
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.saved_dataset import SavedDatasetStorage

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -138,6 +141,9 @@ def _call_api(self, api: str, command: dict, key: str):
elif api == OfflineServer.persist.__name__:
self.persist(command, key)
remove_data = True
elif api == OfflineServer.validate_data_source.__name__:
self.validate_data_source(command)
remove_data = True
except Exception as e:
remove_data = True
logger.exception(e)
Expand Down Expand Up @@ -224,6 +230,11 @@ def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
table = self.pull_all_from_table_or_query(command).to_arrow()
elif api == OfflineServer.pull_latest_from_table_or_query.__name__:
table = self.pull_latest_from_table_or_query(command).to_arrow()
elif (
api
== OfflineServer.get_table_column_names_and_types_from_data_source.__name__
):
table = self.get_table_column_names_and_types_from_data_source(command)
else:
raise NotImplementedError
except Exception as e:
Expand Down Expand Up @@ -457,6 +468,41 @@ def persist(self, command: dict, key: str):
traceback.print_exc()
raise e

@staticmethod
def _extract_data_source_from_command(command) -> DataSource:
data_source_proto_str = command["data_source_proto"]
logger.debug(f"Extracted data_source_proto {data_source_proto_str}")
data_source_proto = DataSourceProto()
Parse(data_source_proto_str, data_source_proto)
data_source = DataSource.from_proto(data_source_proto)
logger.debug(f"Converted to DataSource {data_source}")
return data_source

def validate_data_source(self, command: dict):
data_source = OfflineServer._extract_data_source_from_command(command)
logger.debug(f"Validating data source {data_source.name}")
assert_permissions(data_source, actions=[AuthzedAction.READ_OFFLINE])

self.offline_store.validate_data_source(
config=self.store.config,
data_source=data_source,
)

def get_table_column_names_and_types_from_data_source(self, command: dict):
data_source = OfflineServer._extract_data_source_from_command(command)
logger.debug(f"Fetching table columns metadata from {data_source.name}")
assert_permissions(data_source, actions=[AuthzedAction.READ_OFFLINE])

column_names_and_types = data_source.get_table_column_names_and_types(
self.store.config
)

column_names, types = zip(*column_names_and_types)
logger.debug(
f"DataSource {data_source.name} has columns {column_names} with types {types}"
)
return pa.table({"name": column_names, "type": types})


def remove_dummies(fv: FeatureView) -> FeatureView:
"""
Expand Down
Loading

0 comments on commit ac62a32

Please sign in to comment.