From dcb27de9a5d7e8bc831c0321a344e7b1d5424758 Mon Sep 17 00:00:00 2001
From: cmuhao
Date: Tue, 10 Sep 2024 01:24:47 -0700
Subject: [PATCH 01/26] update entity retrieval and add duckdb

Signed-off-by: cmuhao
---
 sdk/python/feast/feature_store.py             |  20 +-
 .../contrib/duckdb_online_store/__init__.py   |   0
 .../contrib/duckdb_online_store/duckdb.py     | 166 ++++++++++++++
 .../online_stores/contrib/elasticsearch.py    | 103 +++++----
 .../infra/online_stores/contrib/postgres.py   | 131 +++++------
 .../feast/infra/online_stores/online_store.py | 115 +++++-----
 .../feast/infra/online_stores/sqlite.py       |  30 +--
 sdk/python/feast/utils.py                     | 216 +++++++++++-------
 8 files changed, 494 insertions(+), 287 deletions(-)
 create mode 100644 sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/__init__.py
 create mode 100644 sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py

diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py
index 4f96cfb0fc..9bbd271c6a 100644
--- a/sdk/python/feast/feature_store.py
+++ b/sdk/python/feast/feature_store.py
@@ -85,6 +85,7 @@
 from feast.stream_feature_view import StreamFeatureView
 from feast.utils import _utc_now
 from feast.version import get_version
+from feast.protos.feast.types.EntityKey_pb2 import EntityKey

 warnings.simplefilter("once", DeprecationWarning)

@@ -1666,20 +1667,19 @@ def retrieve_online_documents(
             distance_metric,
         )

-        # TODO Refactor to better way of populating result
-        # TODO populate entity in the response after returning entity in document_features is supported
         # TODO currently not return the vector value since it is same as feature value, if embedding is supported,
         # the feature value can be raw text before embedded
-        document_feature_vals = [feature[2] for feature in document_features]
-        document_feature_distance_vals = [feature[4] for feature in document_features]
+        entity_key_vals = [feature[1] for feature in document_features]
+        document_feature_vals = [feature[3] for feature in document_features]
+        document_feature_distance_vals = [feature[5] for feature in document_features]
         online_features_response = GetOnlineFeaturesResponse(results=[])
         utils._populate_result_rows_from_columnar(
             online_features_response=online_features_response,
-            data={requested_feature: document_feature_vals},
-        )
-        utils._populate_result_rows_from_columnar(
-            online_features_response=online_features_response,
-            data={"distance": document_feature_distance_vals},
+            data={
+                "entity_key": entity_key_vals,
+                requested_feature: document_feature_vals,
+                "distance": document_feature_distance_vals
+            },
         )
         return OnlineResponse(online_features_response)

@@ -1691,7 +1691,7 @@ def _retrieve_from_online_store(
         query: List[float],
         top_k: int,
         distance_metric: Optional[str],
-    ) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value, Value]]:
+    ) -> List[Tuple[Timestamp, EntityKey, "FieldStatus.ValueType", Value, Value, Value]]:
         """
         Search and return document features from the online document store.
         """
@@ -1707,7 +1707,7 @@ def _retrieve_from_online_store(
         read_row_protos = []
         row_ts_proto = Timestamp()

-        for row_ts, feature_val, vector_value, distance_val in documents:
+        for row_ts, entity_key, feature_val, vector_value, distance_val in documents:
            # Reset timestamp to default or update if row_ts is not None
             if row_ts is not None:
                 row_ts_proto.FromDatetime(row_ts)
@@ -1721,7 +1721,7 @@ def _retrieve_from_online_store(
             status = FieldStatus.PRESENT

             read_row_protos.append(
-                (row_ts_proto, status, feature_val, vector_value, distance_val)
+                (row_ts_proto, entity_key, status, feature_val, vector_value, distance_val)
             )
         return read_row_protos
diff --git a/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/__init__.py b/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py b/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py
new file mode 100644
index 0000000000..4fd83724d4
--- /dev/null
+++ b/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py
@@ -0,0 +1,166 @@
+import contextlib
+from datetime import datetime
+
+import duckdb
+from typing import Optional, Dict, Any, List, Tuple, Union
+from feast import Entity
+from feast.feature_view import FeatureView
+from feast.infra.online_stores.online_store import OnlineStore
+from feast.repo_config import RepoConfig
+from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
+from feast.protos.feast.types.Value_pb2 import Value as ValueProto
+from feast.infra.key_encoding_utils import serialize_entity_key
+from feast.utils import _build_retrieve_online_document_results
+
+
+class DuckDBOnlineStoreConfig:
+    type: str = "duckdb"
+    path: str
+    enable_vector_search: bool = False  # New option for enabling vector search
+    dimension: Optional[int] = 512
+    distance_metric: Optional[str] = "L2"
+
+
+class DuckDBOnlineStore(OnlineStore):
+    async def online_read_async(self,
+                                config: RepoConfig,
+                                table: FeatureView,
+                                entity_keys: List[EntityKeyProto],
+                                requested_features: Optional[List[str]] = None) -> List[
+            Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
+        pass
+
+    def __init__(self,
+                 config: DuckDBOnlineStoreConfig):
+        self.config = config
+        self.connection = None
+
+    @contextlib.contextmanager
+    def _get_conn(self,
+                  config: RepoConfig) -> Any:
+        if self.connection is None:
+            self.connection = duckdb.connect(database=self.config.path, read_only=False)
+        yield self.connection
+
+    def create_vector_index(self,
+                            config: RepoConfig,
+                            table_name: str,
+                            vector_column: str) -> None:
+        """Create an HNSW index for vector similarity search."""
+        if not self.config.enable_vector_search:
+            raise ValueError("Vector search is not enabled in the configuration.")
+        distance_metric = self.config.distance_metric
+
+        with self._get_conn(config) as conn:
+            conn.execute(
+                f"CREATE INDEX idx ON {table_name} USING HNSW ({vector_column}) WITH (metric = '{distance_metric}');"
+            )
+
+    def online_write_batch(
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            data: List[Tuple[EntityKeyProto, Dict[str, ValueProto]]],
+    ) -> None:
+        insert_values = []
+        for entity_key, values in data:
+            entity_key_bin = serialize_entity_key(entity_key).hex()
+            for feature_name, val in values.items():
+                insert_values.append((entity_key_bin, feature_name, val.SerializeToString()))
+
+        with self._get_conn(config) as conn:
+            conn.execute(f"CREATE TABLE IF NOT EXISTS {table.name} (entity_key BLOB, feature_name TEXT, value BLOB)")
+            conn.executemany(
+                f"INSERT INTO {table.name} (entity_key, feature_name, value) VALUES (?, ?, ?)",
+                insert_values
+            )
+
+    def online_read(
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            entity_keys: List[EntityKeyProto],
+            requested_features: Optional[List[str]] = None,
+    ) -> List[Tuple[Optional[Dict[str, ValueProto]]]]:
+        keys = [serialize_entity_key(key).hex() for key in entity_keys]
+        query = f"SELECT feature_name, value FROM {table.name} WHERE entity_key IN ({','.join(['?'] * len(keys))})"
+
+        with self._get_conn(config) as conn:
+            results = conn.execute(query, keys).fetchall()
+
+        return [{feature_name: ValueProto().ParseFromString(value) for feature_name, value in results}]
+
+    def retrieve_online_documents(
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            requested_feature: str,
+            embedding: List[float],
+            top_k: int,
+            distance_metric: Optional[str] = "L2",
+    ) -> List[
+        Tuple[
+            Optional[datetime],
+            Optional[EntityKeyProto],
+            Optional[ValueProto],
+            Optional[ValueProto],
+            Optional[ValueProto],
+        ]
+    ]:
+        """Perform a vector similarity search using the HNSW index."""
+        if not self.config.enable_vector_search:
+            raise ValueError("Vector search is not enabled in the configuration.")
+        if config.entity_key_serialization_version < 3:
+            raise ValueError("Entity key serialization version must be at least 3 for vector search.")
+
+        result: List[
+            Tuple[
+                Optional[datetime],
+                Optional[EntityKeyProto],
+                Optional[ValueProto],
+                Optional[ValueProto],
+                Optional[ValueProto],
+            ]
+        ] = []
+
+        with self._get_conn(config) as conn:
+            query = f"""
+                SELECT
+                    entity_key,
+                    feature_name,
+                    value,
+                    vector_value,
+                    event_ts
+                FROM {table.name}
+                WHERE feature_name = '{requested_feature}'
+                ORDER BY array_distance(vec, ?::FLOAT[]) LIMIT ?;
+            """
+            rows = conn.execute(query, (embedding, top_k)).fetchall()
+            result = _build_retrieve_online_document_results(
+                rows,
+                entity_key_serialization_version=config.entity_key_serialization_version
+            )
+
+        return result
+
+    def update(
+            self,
+            config: RepoConfig,
+            tables_to_delete: List[FeatureView],
+            tables_to_keep: List[FeatureView],
+    ) -> None:
+        with self._get_conn(config) as conn:
+            for table in tables_to_delete:
+                conn.execute(f"DROP TABLE IF EXISTS {table.name}")
+            for table in tables_to_keep:
+                conn.execute(
+                    f"CREATE TABLE IF NOT EXISTS {table.name} (entity_key BLOB, feature_name TEXT, value BLOB)")
+
+    def teardown(
+            self,
+            config: RepoConfig,
+            tables: List[FeatureView],
+    ) -> None:
+        with self._get_conn(config) as conn:
+            for table in tables:
+                conn.execute(f"DROP TABLE IF EXISTS {table.name}")
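[Editor's note — review sketch, not part of the patch] Two things worth flagging
in duckdb.py as submitted. First, ValueProto().ParseFromString(...) in
online_read returns the number of bytes parsed, not the message, so the dict
comprehension yields integers; the proto needs to be parsed into a variable
first. Second, _build_retrieve_online_document_results is imported but never
defined in this patch — utils.py below only adds the singular
_build_retrieve_online_document_record. A plural wrapper consistent with the
SELECT above (entity_key, feature_name, value, vector_value, event_ts) might
look like the following; the name, signature, and the 0.0 distance placeholder
are all assumptions:

    from typing import Any, List, Tuple

    from feast.utils import _build_retrieve_online_document_record


    def _build_retrieve_online_document_results(
        rows: List[Tuple[Any, ...]],
        entity_key_serialization_version: int,
    ) -> List[Tuple]:
        results = []
        for entity_key, _, value, vector_value, event_ts in rows:
            results.append(
                _build_retrieve_online_document_record(
                    event_ts,
                    bytes.fromhex(entity_key),  # the store writes the key hex-encoded
                    value,
                    vector_value,
                    0.0,  # the query above does not SELECT the distance
                    entity_key_serialization_version,
                )
            )
        return results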
diff --git a/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py b/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py
index c26b4199ae..1a3d2fbfc2 100644
--- a/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py
+++ b/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py
@@ -9,12 +9,12 @@
 from elasticsearch import Elasticsearch, helpers

 from feast import Entity, FeatureView, RepoConfig
-from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key
+from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key, deserialize_entity_key
 from feast.infra.online_stores.online_store import OnlineStore
 from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
 from feast.protos.feast.types.Value_pb2 import Value as ValueProto
 from feast.repo_config import FeastConfigBaseModel
-from feast.utils import to_naive_utc
+from feast.utils import to_naive_utc, _build_retrieve_online_document_record


 class ElasticSearchOnlineStoreConfig(FeastConfigBaseModel):
@@ -46,7 +46,8 @@ class ElasticSearchOnlineStoreConfig(FeastConfigBaseModel):
 class ElasticSearchOnlineStore(OnlineStore):
     _client: Optional[Elasticsearch] = None

-    def _get_client(self, config: RepoConfig) -> Elasticsearch:
+    def _get_client(self,
+                    config: RepoConfig) -> Elasticsearch:
         online_store_config = config.online_store
         assert isinstance(online_store_config, ElasticSearchOnlineStoreConfig)

@@ -72,7 +73,9 @@ def _get_client(self, config: RepoConfig) -> Elasticsearch:
         )
         return self._client

-    def _bulk_batch_actions(self, table: FeatureView, batch: List[Dict[str, Any]]):
+    def _bulk_batch_actions(self,
+                            table: FeatureView,
+                            batch: List[Dict[str, Any]]):
         for row in batch:
             yield {
                 "_index": table.name,
@@ -81,13 +84,13 @@ def _bulk_batch_actions(self, table: FeatureView, batch: List[Dict[str, Any]]):
             }

     def online_write_batch(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        data: List[
-            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
-        ],
-        progress: Optional[Callable[[int], Any]],
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            data: List[
+                Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
+            ],
+            progress: Optional[Callable[[int], Any]],
     ) -> None:
         insert_values = []
         for entity_key, values, timestamp, created_ts in data:
@@ -117,16 +120,16 @@ def online_write_batch(
         batch_size = config.online_store.write_batch_size

         for i in range(0, len(insert_values), batch_size):
-            batch = insert_values[i : i + batch_size]
+            batch = insert_values[i: i + batch_size]
             actions = self._bulk_batch_actions(table, batch)
             helpers.bulk(self._get_client(config), actions)

     def online_read(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        entity_keys: List[EntityKeyProto],
-        requested_features: Optional[List[str]] = None,
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            entity_keys: List[EntityKeyProto],
+            requested_features: Optional[List[str]] = None,
     ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
         if not requested_features:
             body = {
@@ -156,7 +159,9 @@ def online_read(
         )
         return results

-    def create_index(self, config: RepoConfig, table: FeatureView):
+    def create_index(self,
+                     config: RepoConfig,
+                     table: FeatureView):
         """
         Create an index in ElasticSearch for the given table.
         TODO: This method can be exposed to users to customize the indexing functionality.
@@ -184,13 +189,13 @@ def create_index(self, config: RepoConfig, table: FeatureView):
         )

     def update(
-        self,
-        config: RepoConfig,
-        tables_to_delete: Sequence[FeatureView],
-        tables_to_keep: Sequence[FeatureView],
-        entities_to_delete: Sequence[Entity],
-        entities_to_keep: Sequence[Entity],
-        partial: bool,
+            self,
+            config: RepoConfig,
+            tables_to_delete: Sequence[FeatureView],
+            tables_to_keep: Sequence[FeatureView],
+            entities_to_delete: Sequence[Entity],
+            entities_to_keep: Sequence[Entity],
+            partial: bool,
     ):
         # implement the update method
         for table in tables_to_delete:
@@ -199,10 +204,10 @@ def update(
             self.create_index(config, table)

     def teardown(
-        self,
-        config: RepoConfig,
-        tables: Sequence[FeatureView],
-        entities: Sequence[Entity],
+            self,
+            config: RepoConfig,
+            tables: Sequence[FeatureView],
+            entities: Sequence[Entity],
     ):
         project = config.project
         try:
@@ -213,17 +218,18 @@ def teardown(
             raise

     def retrieve_online_documents(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        requested_feature: str,
-        embedding: List[float],
-        top_k: int,
-        *args,
-        **kwargs,
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            requested_feature: str,
+            embedding: List[float],
+            top_k: int,
+            *args,
+            **kwargs,
     ) -> List[
         Tuple[
             Optional[datetime],
+            Optional[EntityKeyProto],
             Optional[ValueProto],
             Optional[ValueProto],
             Optional[ValueProto],
@@ -232,6 +238,7 @@ def retrieve_online_documents(
         result: List[
             Tuple[
                 Optional[datetime],
+                Optional[EntityKeyProto],
                 Optional[ValueProto],
                 Optional[ValueProto],
                 Optional[ValueProto],
@@ -247,23 +254,19 @@ def retrieve_online_documents(
         )
         rows = response["hits"]["hits"][0:top_k]
         for row in rows:
+            entity_key = row["_source"]["entity_key"]
             feature_value = row["_source"]["feature_value"]
             vector_value = row["_source"]["vector_value"]
             timestamp = row["_source"]["timestamp"]
             distance = row["_score"]
             timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f")

-            feature_value_proto = ValueProto()
-            feature_value_proto.ParseFromString(base64.b64decode(feature_value))
-
-            vector_value_proto = ValueProto(string_val=str(vector_value))
-            distance_value_proto = ValueProto(float_val=distance)
-            result.append(
-                (
-                    timestamp,
-                    feature_value_proto,
-                    vector_value_proto,
-                    distance_value_proto,
-                )
-            )
+            result.append(_build_retrieve_online_document_record(
+                timestamp,
+                entity_key,
+                base64.b64decode(feature_value),
+                str(vector_value),
+                distance,
+                config.entity_key_serialization_version
+            ))
         return result
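[Editor's note — illustrative sketch, not part of the patch] The store above
delegates similarity search to Elasticsearch's kNN search. A minimal sketch of
the query shape against the document schema used in online_write_batch; the
index name and connection details are hypothetical, and the API shown is the
8.x Python client:

    from elasticsearch import Elasticsearch

    client = Elasticsearch("http://localhost:9200")
    response = client.search(
        index="document_embeddings",  # one index per feature view
        knn={
            "field": "vector_value",
            "query_vector": [0.1, 0.2, 0.3],
            "k": 5,
            "num_candidates": 50,
        },
        source=["entity_key", "feature_value", "vector_value", "timestamp"],
    )
    for hit in response["hits"]["hits"]:
        # _score carries the similarity used as the "distance" column above
        print(hit["_source"]["entity_key"], hit["_score"])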
diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py
index 8c6d3e0b99..487bb96b90 100644
--- a/sdk/python/feast/infra/online_stores/contrib/postgres.py
+++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py
@@ -37,6 +37,7 @@
 from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
 from feast.protos.feast.types.Value_pb2 import Value as ValueProto
 from feast.repo_config import RepoConfig
+from feast.utils import _build_retrieve_online_document_record

 SUPPORTED_DISTANCE_METRICS_DICT = {
     "cosine": "<=>",
@@ -64,7 +65,8 @@ class PostgreSQLOnlineStore(OnlineStore):
     _conn_pool_async: Optional[AsyncConnectionPool] = None

     @contextlib.contextmanager
-    def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
+    def _get_conn(self,
+                  config: RepoConfig) -> Generator[Connection, Any, Any]:
         assert config.online_store.type == "postgres"

         if config.online_store.conn_type == ConnectionType.pool:
@@ -81,7 +83,8 @@ def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:

     @contextlib.asynccontextmanager
     async def _get_conn_async(
-        self, config: RepoConfig
+            self,
+            config: RepoConfig
     ) -> AsyncGenerator[AsyncConnection, Any]:
         if config.online_store.conn_type == ConnectionType.pool:
             if not self._conn_pool_async:
@@ -98,13 +101,13 @@ async def _get_conn_async(
         yield self._conn_async

     def online_write_batch(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        data: List[
-            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
-        ],
-        progress: Optional[Callable[[int], Any]],
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            data: List[
+                Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
+            ],
+            progress: Optional[Callable[[int], Any]],
     ) -> None:
         # Format insert values
         insert_values = []
@@ -156,11 +159,11 @@ def online_write_batch(
                     progress(len(data))

     def online_read(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        entity_keys: List[EntityKeyProto],
-        requested_features: Optional[List[str]] = None,
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            entity_keys: List[EntityKeyProto],
+            requested_features: Optional[List[str]] = None,
     ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
         keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version)
         query, params = self._construct_query_and_params(
@@ -174,11 +177,11 @@ def online_read(
         return self._process_rows(keys, rows)

     async def online_read_async(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        entity_keys: List[EntityKeyProto],
-        requested_features: Optional[List[str]] = None,
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            entity_keys: List[EntityKeyProto],
+            requested_features: Optional[List[str]] = None,
     ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
         keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version)
         query, params = self._construct_query_and_params(
@@ -194,10 +197,10 @@ async def online_read_async(

     @staticmethod
     def _construct_query_and_params(
-        config: RepoConfig,
-        table: FeatureView,
-        keys: List[bytes],
-        requested_features: Optional[List[str]] = None,
+            config: RepoConfig,
+            table: FeatureView,
+            keys: List[bytes],
+            requested_features: Optional[List[str]] = None,
     ) -> Tuple[sql.Composed, Union[Tuple[List[bytes], List[str]], Tuple[List[bytes]]]]:
         """Construct the SQL query based on the given parameters."""
         if requested_features:
@@ -224,7 +227,8 @@ def _construct_query_and_params(

     @staticmethod
     def _prepare_keys(
-        entity_keys: List[EntityKeyProto], entity_key_serialization_version: int
+            entity_keys: List[EntityKeyProto],
+            entity_key_serialization_version: int
     ) -> List[bytes]:
         """Prepare all keys in a list to make fewer round trips to the database."""
         return [
@@ -237,7 +241,8 @@ def _prepare_keys(

     @staticmethod
     def _process_rows(
-        keys: List[bytes], rows: List[Tuple]
+            keys: List[bytes],
+            rows: List[Tuple]
     ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
         """Transform the retrieved rows in the desired output.

@@ -266,13 +271,13 @@ def _process_rows(
         return result

     def update(
-        self,
-        config: RepoConfig,
-        tables_to_delete: Sequence[FeatureView],
-        tables_to_keep: Sequence[FeatureView],
-        entities_to_delete: Sequence[Entity],
-        entities_to_keep: Sequence[Entity],
-        partial: bool,
+            self,
+            config: RepoConfig,
+            tables_to_delete: Sequence[FeatureView],
+            tables_to_keep: Sequence[FeatureView],
+            entities_to_delete: Sequence[Entity],
+            entities_to_keep: Sequence[Entity],
+            partial: bool,
     ):
         project = config.project
         schema_name = config.online_store.db_schema or config.online_store.user
@@ -334,10 +339,10 @@ def update(
             conn.commit()

     def teardown(
-        self,
-        config: RepoConfig,
-        tables: Sequence[FeatureView],
-        entities: Sequence[Entity],
+            self,
+            config: RepoConfig,
+            tables: Sequence[FeatureView],
+            entities: Sequence[Entity],
     ):
         project = config.project
         try:
@@ -350,16 +355,17 @@ def teardown(
             raise

     def retrieve_online_documents(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        requested_feature: str,
-        embedding: List[float],
-        top_k: int,
-        distance_metric: Optional[str] = "L2",
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            requested_feature: str,
+            embedding: List[float],
+            top_k: int,
+            distance_metric: Optional[str] = "L2",
     ) -> List[
         Tuple[
             Optional[datetime],
+            Optional[EntityKeyProto],
             Optional[ValueProto],
             Optional[ValueProto],
             Optional[ValueProto],
@@ -397,6 +403,7 @@ def retrieve_online_documents(
         result: List[
             Tuple[
                 Optional[datetime],
+                Optional[EntityKeyProto],
                 Optional[ValueProto],
                 Optional[ValueProto],
                 Optional[ValueProto],
@@ -430,37 +437,21 @@ def retrieve_online_documents(
                     (query_embedding_str,),
                 )
                 rows = cur.fetchall()
-
-                for (
-                    entity_key,
-                    feature_name,
-                    value,
-                    vector_value,
-                    distance,
-                    event_ts,
-                ) in rows:
-                    # TODO Deserialize entity_key to return the entity in response
-                    # entity_key_proto = EntityKeyProto()
-                    # entity_key_proto_bin = bytes(entity_key)
-
-                    feature_value_proto = ValueProto()
-                    feature_value_proto.ParseFromString(bytes(value))
-
-                    vector_value_proto = ValueProto(string_val=vector_value)
-                    distance_value_proto = ValueProto(float_val=distance)
-                    result.append(
-                        (
-                            event_ts,
-                            feature_value_proto,
-                            vector_value_proto,
-                            distance_value_proto,
-                        )
-                    )
+                for entity_key, _, feature_val, vector_value, distance_val, event_ts in rows:
+                    result.append(_build_retrieve_online_document_record(
+                        event_ts,
+                        entity_key,
+                        feature_val,
+                        vector_value,
+                        distance_val,
+                        config.entity_key_serialization_version
+                    ))

         return result


-def _table_id(project: str, table: FeatureView) -> str:
+def _table_id(project: str,
+              table: FeatureView) -> str:
     return f"{project}_{table.name}"
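[Editor's note — illustrative sketch, not part of the patch] The
SUPPORTED_DISTANCE_METRICS_DICT above maps metric names to pgvector operators
(<=> cosine, <-> L2, <#> negative inner product). The store builds a query of
roughly this shape; the table name is hypothetical, and pgvector accepts the
embedding as a bracketed text literal:

    query_embedding_str = "[0.1, 0.2, 0.3]"
    sql_query = """
        SELECT entity_key, feature_name, value, vector_value,
               vector_value <-> %s::vector AS distance,
               event_ts
        FROM demo_document_embeddings
        ORDER BY distance
        LIMIT 5;
    """
    # cur.execute(sql_query, (query_embedding_str,))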
diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py
index 9cf2ef95f6..19a03c04f2 100644
--- a/sdk/python/feast/infra/online_stores/online_store.py
+++ b/sdk/python/feast/infra/online_stores/online_store.py
@@ -36,13 +36,13 @@ class OnlineStore(ABC):

     @abstractmethod
     def online_write_batch(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        data: List[
-            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
-        ],
-        progress: Optional[Callable[[int], Any]],
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            data: List[
+                Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
+            ],
+            progress: Optional[Callable[[int], Any]],
     ) -> None:
         """
         Writes a batch of feature rows to the online store.
@@ -62,11 +62,11 @@ def online_write_batch(

     @abstractmethod
     def online_read(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        entity_keys: List[EntityKeyProto],
-        requested_features: Optional[List[str]] = None,
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            entity_keys: List[EntityKeyProto],
+            requested_features: Optional[List[str]] = None,
     ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
         """
         Reads feature values for the given entity keys.
@@ -85,11 +85,11 @@ def online_read(
         pass

     async def online_read_async(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        entity_keys: List[EntityKeyProto],
-        requested_features: Optional[List[str]] = None,
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            entity_keys: List[EntityKeyProto],
+            requested_features: Optional[List[str]] = None,
     ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
         """
         Reads feature values for the given entity keys asynchronously.
@@ -110,16 +110,16 @@ async def online_read_async(
         )

     def get_online_features(
-        self,
-        config: RepoConfig,
-        features: Union[List[str], FeatureService],
-        entity_rows: Union[
-            List[Dict[str, Any]],
-            Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]],
-        ],
-        registry: BaseRegistry,
-        project: str,
-        full_feature_names: bool = False,
+            self,
+            config: RepoConfig,
+            features: Union[List[str], FeatureService],
+            entity_rows: Union[
+                List[Dict[str, Any]],
+                Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]],
+            ],
+            registry: BaseRegistry,
+            project: str,
+            full_feature_names: bool = False,
     ) -> OnlineResponse:
         if isinstance(entity_rows, list):
             columnar: Dict[str, List[Any]] = {k: [] for k in entity_rows[0].keys()}
@@ -197,16 +197,16 @@ def get_online_features(
         return OnlineResponse(online_features_response)

     async def get_online_features_async(
-        self,
-        config: RepoConfig,
-        features: Union[List[str], FeatureService],
-        entity_rows: Union[
-            List[Dict[str, Any]],
-            Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]],
-        ],
-        registry: BaseRegistry,
-        project: str,
-        full_feature_names: bool = False,
+            self,
+            config: RepoConfig,
+            features: Union[List[str], FeatureService],
+            entity_rows: Union[
+                List[Dict[str, Any]],
+                Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]],
+            ],
+            registry: BaseRegistry,
+            project: str,
+            full_feature_names: bool = False,
     ) -> OnlineResponse:
         if isinstance(entity_rows, list):
             columnar: Dict[str, List[Any]] = {k: [] for k in entity_rows[0].keys()}
@@ -285,13 +285,13 @@ async def get_online_features_async(

     @abstractmethod
     def update(
-        self,
-        config: RepoConfig,
-        tables_to_delete: Sequence[FeatureView],
-        tables_to_keep: Sequence[FeatureView],
-        entities_to_delete: Sequence[Entity],
-        entities_to_keep: Sequence[Entity],
-        partial: bool,
+            self,
+            config: RepoConfig,
+            tables_to_delete: Sequence[FeatureView],
+            tables_to_keep: Sequence[FeatureView],
+            entities_to_delete: Sequence[Entity],
+            entities_to_keep: Sequence[Entity],
+            partial: bool,
     ):
         """
         Reconciles cloud resources with the specified set of Feast objects.
@@ -310,7 +310,9 @@ def update(
         pass

     def plan(
-        self, config: RepoConfig, desired_registry_proto: RegistryProto
+            self,
+            config: RepoConfig,
+            desired_registry_proto: RegistryProto
     ) -> List[InfraObject]:
         """
         Returns the set of InfraObjects required to support the desired registry.
@@ -323,10 +325,10 @@ def plan(

     @abstractmethod
     def teardown(
-        self,
-        config: RepoConfig,
-        tables: Sequence[FeatureView],
-        entities: Sequence[Entity],
+            self,
+            config: RepoConfig,
+            tables: Sequence[FeatureView],
+            entities: Sequence[Entity],
     ):
         """
         Tears down all cloud resources for the specified set of Feast objects.
@@ -339,16 +341,17 @@ def teardown(
         pass

     def retrieve_online_documents(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        requested_feature: str,
-        embedding: List[float],
-        top_k: int,
-        distance_metric: Optional[str] = None,
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            requested_feature: str,
+            embedding: List[float],
+            top_k: int,
+            distance_metric: Optional[str] = None,
     ) -> List[
         Tuple[
             Optional[datetime],
+            Optional[EntityKeyProto],
             Optional[ValueProto],
             Optional[ValueProto],
             Optional[ValueProto],
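[Editor's note — illustrative sketch, not part of the patch] The widened return
type above is the contract every store now satisfies: one tuple per hit, in
rank order, with the deserialized entity key as the new second element. A
sketch of a single well-formed row:

    from datetime import datetime, timezone

    from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
    from feast.protos.feast.types.Value_pb2 import FloatList
    from feast.protos.feast.types.Value_pb2 import Value as ValueProto

    row = (
        datetime.now(tz=timezone.utc),                         # event timestamp
        EntityKeyProto(join_keys=["item_id"]),                 # entity key (new)
        ValueProto(string_val="some document"),                # feature value
        ValueProto(float_list_val=FloatList(val=[0.1, 0.2])),  # vector value
        ValueProto(float_val=0.42),                            # distance
    )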
diff --git a/sdk/python/feast/infra/online_stores/sqlite.py b/sdk/python/feast/infra/online_stores/sqlite.py
index 9896b766d4..bfb566adc3 100644
--- a/sdk/python/feast/infra/online_stores/sqlite.py
+++ b/sdk/python/feast/infra/online_stores/sqlite.py
@@ -27,16 +27,15 @@
 from feast import Entity
 from feast.feature_view import FeatureView
 from feast.infra.infra_object import SQLITE_INFRA_OBJECT_CLASS_TYPE, InfraObject
-from feast.infra.key_encoding_utils import serialize_entity_key
+from feast.infra.key_encoding_utils import serialize_entity_key, deserialize_entity_key
 from feast.infra.online_stores.online_store import OnlineStore
 from feast.protos.feast.core.InfraObject_pb2 import InfraObject as InfraObjectProto
 from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
 from feast.protos.feast.core.SqliteTable_pb2 import SqliteTable as SqliteTableProto
 from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
-from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto
 from feast.protos.feast.types.Value_pb2 import Value as ValueProto
 from feast.repo_config import FeastConfigBaseModel, RepoConfig
-from feast.utils import to_naive_utc
+from feast.utils import to_naive_utc, _build_retrieve_online_document_record


 class SqliteOnlineStoreConfig(FeastConfigBaseModel):
@@ -303,6 +302,7 @@ def retrieve_online_documents(
     ) -> List[
         Tuple[
             Optional[datetime],
+            Optional[EntityKeyProto],
             Optional[ValueProto],
             Optional[ValueProto],
             Optional[ValueProto],
@@ -385,6 +385,7 @@ def retrieve_online_documents(
         result: List[
             Tuple[
                 Optional[datetime],
+                Optional[EntityKeyProto],
                 Optional[ValueProto],
                 Optional[ValueProto],
                 Optional[ValueProto],
             ]
         ] = []

         for entity_key, _, string_value, distance, event_ts in rows:
-            feature_value_proto = ValueProto()
-            feature_value_proto.ParseFromString(string_value if string_value else b"")
-            vector_value_proto = ValueProto(
-                float_list_val=FloatListProto(val=embedding)
-            )
-            distance_value_proto = ValueProto(float_val=distance)
-
-            result.append(
-                (
-                    event_ts,
-                    feature_value_proto,
-                    vector_value_proto,
-                    distance_value_proto,
-                )
-            )
+            result.append(_build_retrieve_online_document_record(
+                event_ts,
+                entity_key,
+                string_value if string_value else b"",
+                embedding,
+                distance,
+                config.entity_key_serialization_version
+            ))

         return result
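[Editor's note — illustrative sketch, not part of the patch] The stores above
persist entity keys with serialize_entity_key (raw bytes here, hex-encoded in
the DuckDB store) and the new read path recovers them with
deserialize_entity_key. Round trip, assuming serialization version 3 — the
version this patch requires before it will return entity keys:

    from feast.infra.key_encoding_utils import (
        deserialize_entity_key,
        serialize_entity_key,
    )
    from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
    from feast.protos.feast.types.Value_pb2 import Value as ValueProto

    key = EntityKeyProto(
        join_keys=["item_id"], entity_values=[ValueProto(int64_val=42)]
    )
    serialized = serialize_entity_key(key, entity_key_serialization_version=3)
    restored = deserialize_entity_key(
        serialized, entity_key_serialization_version=3
    )
    assert list(restored.join_keys) == ["item_id"]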
diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py
index 992869557a..58ed1252a0 100644
--- a/sdk/python/feast/utils.py
+++ b/sdk/python/feast/utils.py
@@ -40,16 +40,17 @@
 from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
 from feast.protos.feast.types.Value_pb2 import RepeatedValue as RepeatedValueProto
 from feast.protos.feast.types.Value_pb2 import Value as ValueProto
+from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto
 from feast.type_map import python_values_to_proto_values
 from feast.value_type import ValueType
 from feast.version import get_version
+from feast.infra.key_encoding_utils import deserialize_entity_key

 if typing.TYPE_CHECKING:
     from feast.feature_service import FeatureService
     from feast.feature_view import FeatureView
     from feast.on_demand_feature_view import OnDemandFeatureView
-

 APPLICATION_NAME = "feast-dev/feast"
 USER_AGENT = "{}/{}".format(APPLICATION_NAME, get_version())
@@ -98,9 +99,9 @@ def get_default_yaml_file_path(repo_path: Path) -> Path:

 def _get_requested_feature_views_to_features_dict(
-    feature_refs: List[str],
-    feature_views: List["FeatureView"],
-    on_demand_feature_views: List["OnDemandFeatureView"],
+        feature_refs: List[str],
+        feature_views: List["FeatureView"],
+        on_demand_feature_views: List["OnDemandFeatureView"],
 ) -> Tuple[Dict["FeatureView", List[str]], Dict["OnDemandFeatureView", List[str]]]:
     """Create a dict of FeatureView -> List[Feature] for all requested features.
     Set full_feature_names to True to have feature names prefixed by their feature view name."""
@@ -132,7 +133,8 @@ def _get_requested_feature_views_to_features_dict(

 def _get_column_names(
-    feature_view: "FeatureView", entities: List[Entity]
+        feature_view: "FeatureView",
+        entities: List[Entity]
 ) -> Tuple[List[str], List[str], str, Optional[str]]:
     """
     If a field mapping exists, run it in reverse on the join keys,
@@ -167,7 +169,7 @@ def _get_column_names(
     created_timestamp_column = (
         reverse_field_mapping[created_timestamp_column]
         if created_timestamp_column
-        and created_timestamp_column in reverse_field_mapping.keys()
+           and created_timestamp_column in reverse_field_mapping.keys()
         else created_timestamp_column
     )
     join_keys = [
@@ -185,8 +187,8 @@ def _get_column_names(
         name
         for name in feature_names
         if name not in join_keys
-        and name != timestamp_field
-        and name != created_timestamp_column
+           and name != timestamp_field
+           and name != created_timestamp_column
     ]
     return (
         join_keys,
@@ -197,8 +199,8 @@ def _get_column_names(

 def _run_pyarrow_field_mapping(
-    table: pyarrow.Table,
-    field_mapping: Dict[str, str],
+        table: pyarrow.Table,
+        field_mapping: Dict[str, str],
 ) -> pyarrow.Table:
     # run field mapping in the forward direction
     cols = table.column_names
@@ -225,17 +227,17 @@ def _coerce_datetime(ts):

 def _convert_arrow_to_proto(
-    table: Union[pyarrow.Table, pyarrow.RecordBatch],
-    feature_view: "FeatureView",
-    join_keys: Dict[str, ValueType],
+        table: Union[pyarrow.Table, pyarrow.RecordBatch],
+        feature_view: "FeatureView",
+        join_keys: Dict[str, ValueType],
 ) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]:
     # Avoid ChunkedArrays which guarantees `zero_copy_only` available.
     if isinstance(table, pyarrow.Table):
         table = table.to_batches()[0]

     columns = [
-        (field.name, field.dtype.to_value_type()) for field in feature_view.features
-    ] + list(join_keys.items())
+                  (field.name, field.dtype.to_value_type()) for field in feature_view.features
+              ] + list(join_keys.items())

     proto_values_by_column = {
         column: python_values_to_proto_values(
@@ -292,7 +294,8 @@ def _validate_entity_values(join_key_values: Dict[str, List[ValueProto]]):
     return set_of_row_lengths.pop()


-def _validate_feature_refs(feature_refs: List[str], full_feature_names: bool = False):
+def _validate_feature_refs(feature_refs: List[str],
+                           full_feature_names: bool = False):
     """
     Validates that there are no collisions among the feature references.

@@ -329,9 +332,9 @@ def _validate_feature_refs(feature_refs: List[str], full_feature_names: bool = F

 def _group_feature_refs(
-    features: List[str],
-    all_feature_views: List["FeatureView"],
-    all_on_demand_feature_views: List["OnDemandFeatureView"],
+        features: List[str],
+        all_feature_views: List["FeatureView"],
+        all_on_demand_feature_views: List["OnDemandFeatureView"],
 ) -> Tuple[
     List[Tuple["FeatureView", List[str]]], List[Tuple["OnDemandFeatureView", List[str]]]
 ]:
@@ -381,7 +384,8 @@ def _group_feature_refs(

 def apply_list_mapping(
-    lst: Iterable[Any], mapping_indexes: Iterable[List[int]]
+        lst: Iterable[Any],
+        mapping_indexes: Iterable[List[int]]
 ) -> Iterable[Any]:
     output_len = sum(len(item) for item in mapping_indexes)
     output = [None] * output_len
@@ -393,10 +397,10 @@ def apply_list_mapping(

 def _augment_response_with_on_demand_transforms(
-    online_features_response: GetOnlineFeaturesResponse,
-    feature_refs: List[str],
-    requested_on_demand_feature_views: List["OnDemandFeatureView"],
-    full_feature_names: bool,
+        online_features_response: GetOnlineFeaturesResponse,
+        feature_refs: List[str],
+        requested_on_demand_feature_views: List["OnDemandFeatureView"],
+        full_feature_names: bool,
 ):
     """Computes on demand feature values and adds them to the result rows.

@@ -489,9 +493,9 @@ def _augment_response_with_on_demand_transforms(

 def _get_entity_maps(
-    registry,
-    project,
-    feature_views,
+        registry,
+        project,
+        feature_views,
 ) -> Tuple[Dict[str, str], Dict[str, ValueType], Set[str]]:
     # TODO(felixwang9817): Support entities that have different types for different feature views.
     entities = registry.list_entities(project, allow_cache=True)
@@ -522,9 +526,9 @@ def _get_entity_maps(

 def _get_table_entity_values(
-    table: "FeatureView",
-    entity_name_to_join_key_map: Dict[str, str],
-    join_key_proto_values: Dict[str, List[ValueProto]],
+        table: "FeatureView",
+        entity_name_to_join_key_map: Dict[str, str],
+        join_key_proto_values: Dict[str, List[ValueProto]],
 ) -> Dict[str, List[ValueProto]]:
     # The correct join_keys expected by the OnlineStore for this Feature View.
     table_join_keys = [
@@ -545,9 +549,9 @@ def _get_table_entity_values(

 def _get_unique_entities(
-    table: "FeatureView",
-    join_key_values: Dict[str, List[ValueProto]],
-    entity_name_to_join_key_map: Dict[str, str],
+        table: "FeatureView",
+        join_key_values: Dict[str, List[ValueProto]],
+        entity_name_to_join_key_map: Dict[str, str],
 ) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...]]:
     """Return the set of unique composite Entities for a Feature View and the indexes at which they appear.

@@ -584,8 +588,8 @@ def _get_unique_entities(

 def _drop_unneeded_columns(
-    online_features_response: GetOnlineFeaturesResponse,
-    requested_result_row_names: Set[str],
+        online_features_response: GetOnlineFeaturesResponse,
+        requested_result_row_names: Set[str],
 ):
     """
     Unneeded feature values such as request data and unrequested input feature views will
@@ -609,8 +613,8 @@ def _drop_unneeded_columns(

 def _populate_result_rows_from_columnar(
-    online_features_response: GetOnlineFeaturesResponse,
-    data: Dict[str, List[ValueProto]],
+        online_features_response: GetOnlineFeaturesResponse,
+        data: Dict[str, List[ValueProto]],
 ):
     timestamp = Timestamp()  # Only initialize this timestamp once.
     # Add more values to the existing result rows
@@ -626,7 +630,7 @@ def _populate_result_rows_from_columnar(

 def get_needed_request_data(
-    grouped_odfv_refs: List[Tuple["OnDemandFeatureView", List[str]]],
+        grouped_odfv_refs: List[Tuple["OnDemandFeatureView", List[str]]],
 ) -> Set[str]:
     needed_request_data: Set[str] = set()
     for odfv, _ in grouped_odfv_refs:
@@ -636,8 +640,8 @@ def get_needed_request_data(

 def ensure_request_data_values_exist(
-    needed_request_data: Set[str],
-    request_data_features: Dict[str, List[Any]],
+        needed_request_data: Set[str],
+        request_data_features: Dict[str, List[Any]],
 ):
     if len(needed_request_data) != len(request_data_features.keys()):
         missing_features = [
@@ -647,16 +651,16 @@ def ensure_request_data_values_exist(

 def _populate_response_from_feature_data(
-    feature_data: Iterable[
-        Tuple[
-            Iterable[Timestamp], Iterable["FieldStatus.ValueType"], Iterable[ValueProto]
-        ]
-    ],
-    indexes: Iterable[List[int]],
-    online_features_response: GetOnlineFeaturesResponse,
-    full_feature_names: bool,
-    requested_features: Iterable[str],
-    table: "FeatureView",
+        feature_data: Iterable[
+            Tuple[
+                Iterable[Timestamp], Iterable["FieldStatus.ValueType"], Iterable[ValueProto]
+            ]
+        ],
+        indexes: Iterable[List[int]],
+        online_features_response: GetOnlineFeaturesResponse,
+        full_feature_names: bool,
+        requested_features: Iterable[str],
+        table: "FeatureView",
 ):
     """Populate the GetOnlineFeaturesResponse with feature data.

@@ -690,8 +694,8 @@ def _populate_response_from_feature_data(
     # Populate the result with data fetched from the OnlineStore
     # which is guaranteed to be aligned with `requested_features`.
     for (
-        feature_idx,
-        (timestamp_vector, statuses_vector, values_vector),
+            feature_idx,
+            (timestamp_vector, statuses_vector, values_vector),
     ) in enumerate(zip(zip(*timestamps), zip(*statuses), zip(*values))):
         online_features_response.results.append(
             GetOnlineFeaturesResponse.FeatureVector(
@@ -703,10 +707,10 @@ def _populate_response_from_feature_data(

 def _get_features(
-    registry,
-    project,
-    features: Union[List[str], "FeatureService"],
-    allow_cache: bool = False,
+        registry,
+        project,
+        features: Union[List[str], "FeatureService"],
+        allow_cache: bool = False,
 ) -> List[str]:
     from feast.feature_service import FeatureService

@@ -737,11 +741,11 @@ def _get_features(

 def _list_feature_views(
-    registry,
-    project,
-    allow_cache: bool = False,
-    hide_dummy_entity: bool = True,
-    tags: Optional[dict[str, str]] = None,
+        registry,
+        project,
+        allow_cache: bool = False,
+        hide_dummy_entity: bool = True,
+        tags: Optional[dict[str, str]] = None,
 ) -> List["FeatureView"]:
     from feast.feature_view import DUMMY_ENTITY_NAME

@@ -755,11 +759,11 @@ def _list_feature_views(

 def _get_feature_views_to_use(
-    registry,
-    project,
-    features: Optional[Union[List[str], "FeatureService"]],
-    allow_cache=False,
-    hide_dummy_entity: bool = True,
+        registry,
+        project,
+        features: Optional[Union[List[str], "FeatureService"]],
+        allow_cache=False,
+        hide_dummy_entity: bool = True,
 ) -> Tuple[List["FeatureView"], List["OnDemandFeatureView"]]:
     from feast.feature_service import FeatureService

@@ -813,10 +817,10 @@ def _get_feature_views_to_use(

 def _get_online_request_context(
-    registry,
-    project,
-    features: Union[List[str], "FeatureService"],
-    full_feature_names: bool,
+        registry,
+        project,
+        features: Union[List[str], "FeatureService"],
+        full_feature_names: bool,
 ):
     from feast.feature_view import DUMMY_ENTITY_NAME

@@ -881,14 +885,14 @@ def _get_online_request_context(

 def _prepare_entities_to_read_from_online_store(
-    registry,
-    project,
-    features: Union[List[str], "FeatureService"],
-    entity_values: Mapping[
-        str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValueProto]
-    ],
-    full_feature_names: bool = False,
-    native_entity_values: bool = True,
+        registry,
+        project,
+        features: Union[List[str], "FeatureService"],
+        entity_values: Mapping[
+            str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValueProto]
+        ],
+        full_feature_names: bool = False,
+        native_entity_values: bool = True,
 ):
     from feast.feature_view import DUMMY_ENTITY, DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL

@@ -976,7 +980,7 @@ def _prepare_entities_to_read_from_online_store(

 def _get_entity_key_protos(
-    entity_rows: Iterable[Mapping[str, ValueProto]],
+        entity_rows: Iterable[Mapping[str, ValueProto]],
 ) -> List[EntityKeyProto]:
     # Instantiate one EntityKeyProto per Entity.
     entity_key_protos = [
@@ -987,8 +991,8 @@ def _get_entity_key_protos(

 def _convert_rows_to_protobuf(
-    requested_features: List[str],
-    read_rows: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]],
+        requested_features: List[str],
+        read_rows: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]],
 ) -> List[Tuple[List[Timestamp], List["FieldStatus.ValueType"], List[ValueProto]]]:
     # Each row is a set of features for a given entity key.
     # We only need to convert the data to Protobuf once.
@@ -1020,7 +1024,8 @@ def _convert_rows_to_protobuf(

 def has_all_tags(
-    object_tags: dict[str, str], requested_tags: Optional[dict[str, str]] = None
+        object_tags: dict[str, str],
+        requested_tags: Optional[dict[str, str]] = None
 ) -> bool:
     if requested_tags is None:
         return True
@@ -1028,7 +1033,7 @@ def has_all_tags(

 def tags_list_to_dict(
-    tags_list: Optional[list[str]] = None,
+        tags_list: Optional[list[str]] = None,
 ) -> Optional[dict[str, str]]:
     if not tags_list:
         return None
@@ -1050,3 +1055,48 @@ def tags_str_to_dict(tags: str = "") -> dict[str, str]:

 def _utc_now() -> datetime:
     return datetime.now(tz=timezone.utc)
+
+
+def _build_retrieve_online_document_record(
+    event_ts: datetime,
+    entity_key: Union[str, bytes],
+    value: Union[str, bytes],
+    vector_value: Union[str, List[float]],
+    distance: float,
+    entity_key_serialization_version: int
+) -> Tuple[
+    Optional[datetime],
+    Optional[EntityKeyProto],
+    Optional[ValueProto],
+    Optional[ValueProto],
+    Optional[ValueProto],
+]:
+    if entity_key_serialization_version < 3:
+        entity_key_proto = None
+    else:
+        if isinstance(entity_key, str):
+            entity_key_proto_bin = entity_key.encode("utf-8")
+        else:
+            entity_key_proto_bin = bytes(entity_key)
+        entity_key_proto = deserialize_entity_key(
+            entity_key_proto_bin,
+            entity_key_serialization_version=entity_key_serialization_version
+        )
+
+    if isinstance(value, str):
+        value = value.encode("utf-8")
+    feature_value_proto = ValueProto()
+    feature_value_proto.ParseFromString(value)
+
+    if isinstance(vector_value, str):
+        vector_value_proto = ValueProto(string_val=vector_value)
+    else:
+        vector_value_proto = ValueProto(
+            float_list_val=FloatListProto(val=vector_value)
+        )
+
+    distance_value_proto = ValueProto(float_val=distance)
+    return (
+        event_ts,
+        entity_key_proto,
+        feature_value_proto,
+        vector_value_proto,
+        distance_value_proto,
+    )
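[Editor's note — worked example, not part of the patch] How the new helper
assembles one result row, under the same version-3 assumption (values are
illustrative):

    from datetime import datetime, timezone

    from feast.infra.key_encoding_utils import serialize_entity_key
    from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
    from feast.protos.feast.types.Value_pb2 import Value as ValueProto
    from feast.utils import _build_retrieve_online_document_record

    key = EntityKeyProto(join_keys=["doc_id"], entity_values=[ValueProto(int64_val=7)])
    event_ts, entity_key, feature, vector, distance = (
        _build_retrieve_online_document_record(
            datetime.now(tz=timezone.utc),
            serialize_entity_key(key, entity_key_serialization_version=3),
            ValueProto(string_val="hello").SerializeToString(),
            [0.1, 0.2, 0.3],
            0.12,
            3,
        )
    )
    assert list(entity_key.join_keys) == ["doc_id"]
    assert feature.string_val == "hello"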
From c24dc7535ecfd36f1342ac01603fde52042871d9 Mon Sep 17 00:00:00 2001
From: cmuhao
Date: Tue, 10 Sep 2024 01:25:17 -0700
Subject: [PATCH 02/26] lint

Signed-off-by: cmuhao
---
 sdk/python/feast/feature_store.py             |  17 +-
 .../contrib/duckdb_online_store/duckdb.py     | 108 +++++-----
 .../online_stores/contrib/elasticsearch.py    | 100 ++++-----
 .../infra/online_stores/contrib/postgres.py   | 122 +++++------
 .../feast/infra/online_stores/online_store.py | 114 ++++++-----
 .../feast/infra/online_stores/sqlite.py       |  22 +-
 sdk/python/feast/utils.py                     | 190 +++++++++---------
 7 files changed, 345 insertions(+), 328 deletions(-)

diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py
index 9bbd271c6a..44713c9d84 100644
--- a/sdk/python/feast/feature_store.py
+++ b/sdk/python/feast/feature_store.py
@@ -34,6 +34,7 @@
 import pyarrow as pa
 from colorama import Fore, Style
 from google.protobuf.timestamp_pb2 import Timestamp
+from feast.protos.feast.types.EntityKey_pb2 import EntityKey
 from tqdm import tqdm

 from feast import feature_server, flags_helper, ui_server, utils
@@ -85,7 +86,6 @@
 from feast.stream_feature_view import StreamFeatureView
 from feast.utils import _utc_now
 from feast.version import get_version
-from feast.protos.feast.types.EntityKey_pb2 import EntityKey

 warnings.simplefilter("once", DeprecationWarning)

@@ -1678,7 +1678,7 @@ def retrieve_online_documents(
             data={
                 "entity_key": entity_key_vals,
                 requested_feature: document_feature_vals,
-                "distance": document_feature_distance_vals
+                "distance": document_feature_distance_vals,
             },
         )
         return OnlineResponse(online_features_response)
@@ -1691,7 +1691,9 @@ def _retrieve_from_online_store(
         query: List[float],
         top_k: int,
         distance_metric: Optional[str],
-    ) -> List[Tuple[Timestamp, EntityKey, "FieldStatus.ValueType", Value, Value, Value]]:
+    ) -> List[
+        Tuple[Timestamp, EntityKey, "FieldStatus.ValueType", Value, Value, Value]
+    ]:
         """
         Search and return document features from the online document store.
         """
@@ -1721,7 +1723,14 @@ def _retrieve_from_online_store(
             status = FieldStatus.PRESENT

             read_row_protos.append(
-                (row_ts_proto, entity_key, status, feature_val, vector_value, distance_val)
+                (
+                    row_ts_proto,
+                    entity_key,
+                    status,
+                    feature_val,
+                    vector_value,
+                    distance_val,
+                )
             )
         return read_row_protos

diff --git a/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py b/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py
index 4fd83724d4..b715c68204 100644
--- a/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py
+++ b/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py
@@ -1,16 +1,16 @@
 import contextlib
 from datetime import datetime
+from typing import Any, Dict, List, Optional, Tuple

 import duckdb
-from typing import Optional, Dict, Any, List, Tuple, Union
-from feast import Entity
+from feast.infra.key_encoding_utils import serialize_entity_key
+from feast.utils import _build_retrieve_online_document_results
+
 from feast.feature_view import FeatureView
 from feast.infra.online_stores.online_store import OnlineStore
-from feast.repo_config import RepoConfig
 from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
 from feast.protos.feast.types.Value_pb2 import Value as ValueProto
-from feast.infra.key_encoding_utils import serialize_entity_key
-from feast.utils import _build_retrieve_online_document_results
+from feast.repo_config import RepoConfig


 class DuckDBOnlineStoreConfig:
@@ -22,30 +22,28 @@ class DuckDBOnlineStoreConfig:

 class DuckDBOnlineStore(OnlineStore):
-    async def online_read_async(self,
-                                config: RepoConfig,
-                                table: FeatureView,
-                                entity_keys: List[EntityKeyProto],
-                                requested_features: Optional[List[str]] = None) -> List[
-            Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
+    async def online_read_async(
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        entity_keys: List[EntityKeyProto],
+        requested_features: Optional[List[str]] = None,
+    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
         pass

-    def __init__(self,
-                 config: DuckDBOnlineStoreConfig):
+    def __init__(self, config: DuckDBOnlineStoreConfig):
         self.config = config
         self.connection = None

     @contextlib.contextmanager
-    def _get_conn(self,
-                  config: RepoConfig) -> Any:
+    def _get_conn(self, config: RepoConfig) -> Any:
         if self.connection is None:
             self.connection = duckdb.connect(database=self.config.path, read_only=False)
         yield self.connection

-    def create_vector_index(self,
-                            config: RepoConfig,
-                            table_name: str,
-                            vector_column: str) -> None:
+    def create_vector_index(
+        self, config: RepoConfig, table_name: str, vector_column: str
+    ) -> None:
         """Create an HNSW index for vector similarity search."""
         if not self.config.enable_vector_search:
             raise ValueError("Vector search is not enabled in the configuration.")
@@ -57,30 +55,34 @@ def create_vector_index(

     def online_write_batch(
-            self,
-            config: RepoConfig,
-            table: FeatureView,
-            data: List[Tuple[EntityKeyProto, Dict[str, ValueProto]]],
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        data: List[Tuple[EntityKeyProto, Dict[str, ValueProto]]],
     ) -> None:
         insert_values = []
         for entity_key, values in data:
             entity_key_bin = serialize_entity_key(entity_key).hex()
             for feature_name, val in values.items():
-                insert_values.append((entity_key_bin, feature_name, val.SerializeToString()))
+                insert_values.append(
+                    (entity_key_bin, feature_name, val.SerializeToString())
+                )

         with self._get_conn(config) as conn:
-            conn.execute(f"CREATE TABLE IF NOT EXISTS {table.name} (entity_key BLOB, feature_name TEXT, value BLOB)")
+            conn.execute(
+                f"CREATE TABLE IF NOT EXISTS {table.name} (entity_key BLOB, feature_name TEXT, value BLOB)"
+            )
             conn.executemany(
                 f"INSERT INTO {table.name} (entity_key, feature_name, value) VALUES (?, ?, ?)",
-                insert_values
+                insert_values,
             )

     def online_read(
-            self,
-            config: RepoConfig,
-            table: FeatureView,
-            entity_keys: List[EntityKeyProto],
-            requested_features: Optional[List[str]] = None,
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        entity_keys: List[EntityKeyProto],
+        requested_features: Optional[List[str]] = None,
     ) -> List[Tuple[Optional[Dict[str, ValueProto]]]]:
         keys = [serialize_entity_key(key).hex() for key in entity_keys]
         query = f"SELECT feature_name, value FROM {table.name} WHERE entity_key IN ({','.join(['?'] * len(keys))})"
@@ -88,16 +90,21 @@ def online_read(

         with self._get_conn(config) as conn:
             results = conn.execute(query, keys).fetchall()

-        return [{feature_name: ValueProto().ParseFromString(value) for feature_name, value in results}]
+        return [
+            {
+                feature_name: ValueProto().ParseFromString(value)
+                for feature_name, value in results
+            }
+        ]

     def retrieve_online_documents(
-            self,
-            config: RepoConfig,
-            table: FeatureView,
-            requested_feature: str,
-            embedding: List[float],
-            top_k: int,
-            distance_metric: Optional[str] = "L2",
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        requested_feature: str,
+        embedding: List[float],
+        top_k: int,
+        distance_metric: Optional[str] = "L2",
     ) -> List[
         Tuple[
             Optional[datetime],
@@ -111,7 +118,9 @@ def retrieve_online_documents(
         """Perform a vector similarity search using the HNSW index."""
         if not self.config.enable_vector_search:
             raise ValueError("Vector search is not enabled in the configuration.")
         if config.entity_key_serialization_version < 3:
-            raise ValueError("Entity key serialization version must be at least 3 for vector search.")
+            raise ValueError(
+                "Entity key serialization version must be at least 3 for vector search."
+            )

         result: List[
             Tuple[
@@ -138,28 +147,29 @@ def retrieve_online_documents(
             rows = conn.execute(query, (embedding, top_k)).fetchall()
             result = _build_retrieve_online_document_results(
                 rows,
-                entity_key_serialization_version=config.entity_key_serialization_version
+                entity_key_serialization_version=config.entity_key_serialization_version,
             )

         return result

     def update(
-            self,
-            config: RepoConfig,
-            tables_to_delete: List[FeatureView],
-            tables_to_keep: List[FeatureView],
+        self,
+        config: RepoConfig,
+        tables_to_delete: List[FeatureView],
+        tables_to_keep: List[FeatureView],
     ) -> None:
         with self._get_conn(config) as conn:
             for table in tables_to_delete:
                 conn.execute(f"DROP TABLE IF EXISTS {table.name}")
             for table in tables_to_keep:
                 conn.execute(
-                    f"CREATE TABLE IF NOT EXISTS {table.name} (entity_key BLOB, feature_name TEXT, value BLOB)")
+                    f"CREATE TABLE IF NOT EXISTS {table.name} (entity_key BLOB, feature_name TEXT, value BLOB)"
+                )

     def teardown(
-            self,
-            config: RepoConfig,
-            tables: List[FeatureView],
+        self,
+        config: RepoConfig,
+        tables: List[FeatureView],
     ) -> None:
         with self._get_conn(config) as conn:
             for table in tables:
                 conn.execute(f"DROP TABLE IF EXISTS {table.name}")
+ ) result: List[ Tuple[ @@ -138,28 +147,29 @@ def retrieve_online_documents( rows = conn.execute(query, (embedding, top_k)).fetchall() result = _build_retrieve_online_document_results( rows, - entity_key_serialization_version=config.entity_key_serialization_version + entity_key_serialization_version=config.entity_key_serialization_version, ) return result def update( - self, - config: RepoConfig, - tables_to_delete: List[FeatureView], - tables_to_keep: List[FeatureView], + self, + config: RepoConfig, + tables_to_delete: List[FeatureView], + tables_to_keep: List[FeatureView], ) -> None: with self._get_conn(config) as conn: for table in tables_to_delete: conn.execute(f"DROP TABLE IF EXISTS {table.name}") for table in tables_to_keep: conn.execute( - f"CREATE TABLE IF NOT EXISTS {table.name} (entity_key BLOB, feature_name TEXT, value BLOB)") + f"CREATE TABLE IF NOT EXISTS {table.name} (entity_key BLOB, feature_name TEXT, value BLOB)" + ) def teardown( - self, - config: RepoConfig, - tables: List[FeatureView], + self, + config: RepoConfig, + tables: List[FeatureView], ) -> None: with self._get_conn(config) as conn: for table in tables: diff --git a/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py b/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py index 1a3d2fbfc2..b4114410c6 100644 --- a/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py +++ b/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py @@ -9,12 +9,15 @@ from elasticsearch import Elasticsearch, helpers from feast import Entity, FeatureView, RepoConfig -from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key, deserialize_entity_key +from feast.infra.key_encoding_utils import ( + get_list_val_str, + serialize_entity_key, +) from feast.infra.online_stores.online_store import OnlineStore from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel -from feast.utils import to_naive_utc, _build_retrieve_online_document_record +from feast.utils import _build_retrieve_online_document_record, to_naive_utc class ElasticSearchOnlineStoreConfig(FeastConfigBaseModel): @@ -46,8 +49,7 @@ class ElasticSearchOnlineStoreConfig(FeastConfigBaseModel): class ElasticSearchOnlineStore(OnlineStore): _client: Optional[Elasticsearch] = None - def _get_client(self, - config: RepoConfig) -> Elasticsearch: + def _get_client(self, config: RepoConfig) -> Elasticsearch: online_store_config = config.online_store assert isinstance(online_store_config, ElasticSearchOnlineStoreConfig) @@ -73,9 +75,7 @@ def _get_client(self, ) return self._client - def _bulk_batch_actions(self, - table: FeatureView, - batch: List[Dict[str, Any]]): + def _bulk_batch_actions(self, table: FeatureView, batch: List[Dict[str, Any]]): for row in batch: yield { "_index": table.name, @@ -84,13 +84,13 @@ def _bulk_batch_actions(self, } def online_write_batch( - self, - config: RepoConfig, - table: FeatureView, - data: List[ - Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] - ], - progress: Optional[Callable[[int], Any]], + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], ) -> None: insert_values = [] for entity_key, values, timestamp, created_ts in data: @@ -120,16 +120,16 @@ def online_write_batch( batch_size = 
config.online_store.write_batch_size for i in range(0, len(insert_values), batch_size): - batch = insert_values[i: i + batch_size] + batch = insert_values[i : i + batch_size] actions = self._bulk_batch_actions(table, batch) helpers.bulk(self._get_client(config), actions) def online_read( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: if not requested_features: body = { @@ -159,9 +159,7 @@ def online_read( ) return results - def create_index(self, - config: RepoConfig, - table: FeatureView): + def create_index(self, config: RepoConfig, table: FeatureView): """ Create an index in ElasticSearch for the given table. TODO: This method can be exposed to users to customize the indexing functionality. @@ -189,13 +187,13 @@ def create_index(self, ) def update( - self, - config: RepoConfig, - tables_to_delete: Sequence[FeatureView], - tables_to_keep: Sequence[FeatureView], - entities_to_delete: Sequence[Entity], - entities_to_keep: Sequence[Entity], - partial: bool, + self, + config: RepoConfig, + tables_to_delete: Sequence[FeatureView], + tables_to_keep: Sequence[FeatureView], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + partial: bool, ): # implement the update method for table in tables_to_delete: @@ -204,10 +202,10 @@ def update( self.create_index(config, table) def teardown( - self, - config: RepoConfig, - tables: Sequence[FeatureView], - entities: Sequence[Entity], + self, + config: RepoConfig, + tables: Sequence[FeatureView], + entities: Sequence[Entity], ): project = config.project try: @@ -218,14 +216,14 @@ def teardown( raise def retrieve_online_documents( - self, - config: RepoConfig, - table: FeatureView, - requested_feature: str, - embedding: List[float], - top_k: int, - *args, - **kwargs, + self, + config: RepoConfig, + table: FeatureView, + requested_feature: str, + embedding: List[float], + top_k: int, + *args, + **kwargs, ) -> List[ Tuple[ Optional[datetime], @@ -261,12 +259,14 @@ def retrieve_online_documents( distance = row["_score"] timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") - result.append(_build_retrieve_online_document_record( - timestamp, - entity_key, - base64.b64decode(feature_value), - str(vector_value), - distance, - config.entity_key_serialization_version - )) + result.append( + _build_retrieve_online_document_record( + timestamp, + entity_key, + base64.b64decode(feature_value), + str(vector_value), + distance, + config.entity_key_serialization_version, + ) + ) return result diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 487bb96b90..3b039079f6 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -19,6 +19,7 @@ from psycopg import AsyncConnection, sql from psycopg.connection import Connection from psycopg_pool import AsyncConnectionPool, ConnectionPool +from utils import _build_retrieve_online_document_record from feast import Entity from feast.feature_view import FeatureView @@ -37,7 +38,6 @@ from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from 
feast.repo_config import RepoConfig -from utils import _build_retrieve_online_document_record SUPPORTED_DISTANCE_METRICS_DICT = { "cosine": "<=>", @@ -65,8 +65,7 @@ class PostgreSQLOnlineStore(OnlineStore): _conn_pool_async: Optional[AsyncConnectionPool] = None @contextlib.contextmanager - def _get_conn(self, - config: RepoConfig) -> Generator[Connection, Any, Any]: + def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]: assert config.online_store.type == "postgres" if config.online_store.conn_type == ConnectionType.pool: @@ -83,8 +82,7 @@ def _get_conn(self, @contextlib.asynccontextmanager async def _get_conn_async( - self, - config: RepoConfig + self, config: RepoConfig ) -> AsyncGenerator[AsyncConnection, Any]: if config.online_store.conn_type == ConnectionType.pool: if not self._conn_pool_async: @@ -101,13 +99,13 @@ async def _get_conn_async( yield self._conn_async def online_write_batch( - self, - config: RepoConfig, - table: FeatureView, - data: List[ - Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] - ], - progress: Optional[Callable[[int], Any]], + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], ) -> None: # Format insert values insert_values = [] @@ -159,11 +157,11 @@ def online_write_batch( progress(len(data)) def online_read( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version) query, params = self._construct_query_and_params( @@ -177,11 +175,11 @@ def online_read( return self._process_rows(keys, rows) async def online_read_async( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version) query, params = self._construct_query_and_params( @@ -197,10 +195,10 @@ async def online_read_async( @staticmethod def _construct_query_and_params( - config: RepoConfig, - table: FeatureView, - keys: List[bytes], - requested_features: Optional[List[str]] = None, + config: RepoConfig, + table: FeatureView, + keys: List[bytes], + requested_features: Optional[List[str]] = None, ) -> Tuple[sql.Composed, Union[Tuple[List[bytes], List[str]], Tuple[List[bytes]]]]: """Construct the SQL query based on the given parameters.""" if requested_features: @@ -227,8 +225,7 @@ def _construct_query_and_params( @staticmethod def _prepare_keys( - entity_keys: List[EntityKeyProto], - entity_key_serialization_version: int + entity_keys: List[EntityKeyProto], entity_key_serialization_version: int ) -> List[bytes]: """Prepare all keys in a list to make fewer round trips to the database.""" return [ @@ -241,8 +238,7 @@ def _prepare_keys( @staticmethod def _process_rows( - keys: List[bytes], - rows: List[Tuple] + keys: List[bytes], rows: List[Tuple] ) -> List[Tuple[Optional[datetime], 
Optional[Dict[str, ValueProto]]]]: """Transform the retrieved rows in the desired output. @@ -271,13 +267,13 @@ def _process_rows( return result def update( - self, - config: RepoConfig, - tables_to_delete: Sequence[FeatureView], - tables_to_keep: Sequence[FeatureView], - entities_to_delete: Sequence[Entity], - entities_to_keep: Sequence[Entity], - partial: bool, + self, + config: RepoConfig, + tables_to_delete: Sequence[FeatureView], + tables_to_keep: Sequence[FeatureView], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + partial: bool, ): project = config.project schema_name = config.online_store.db_schema or config.online_store.user @@ -339,10 +335,10 @@ def update( conn.commit() def teardown( - self, - config: RepoConfig, - tables: Sequence[FeatureView], - entities: Sequence[Entity], + self, + config: RepoConfig, + tables: Sequence[FeatureView], + entities: Sequence[Entity], ): project = config.project try: @@ -355,13 +351,13 @@ def teardown( raise def retrieve_online_documents( - self, - config: RepoConfig, - table: FeatureView, - requested_feature: str, - embedding: List[float], - top_k: int, - distance_metric: Optional[str] = "L2", + self, + config: RepoConfig, + table: FeatureView, + requested_feature: str, + embedding: List[float], + top_k: int, + distance_metric: Optional[str] = "L2", ) -> List[ Tuple[ Optional[datetime], @@ -437,21 +433,29 @@ def retrieve_online_documents( (query_embedding_str,), ) rows = cur.fetchall() - for entity_key, _, feature_val, vector_value, distance_val, event_ts in rows: - result.append(_build_retrieve_online_document_record( - event_ts, - entity_key, - feature_val, - vector_value, - distance_val, - config.entity_key_serialization_version - )) + for ( + entity_key, + _, + feature_val, + vector_value, + distance_val, + event_ts, + ) in rows: + result.append( + _build_retrieve_online_document_record( + event_ts, + entity_key, + feature_val, + vector_value, + distance_val, + config.entity_key_serialization_version, + ) + ) return result -def _table_id(project: str, - table: FeatureView) -> str: +def _table_id(project: str, table: FeatureView) -> str: return f"{project}_{table.name}" diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index 19a03c04f2..fdb5b055cf 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -36,13 +36,13 @@ class OnlineStore(ABC): @abstractmethod def online_write_batch( - self, - config: RepoConfig, - table: FeatureView, - data: List[ - Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] - ], - progress: Optional[Callable[[int], Any]], + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], ) -> None: """ Writes a batch of feature rows to the online store. @@ -62,11 +62,11 @@ def online_write_batch( @abstractmethod def online_read( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: """ Reads features values for the given entity keys. 
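For context on the return shape above, each element of the list pairs an optional event timestamp with an optional feature-name-to-value map; a minimal sketch of consuming that shape, with hypothetical feature names and values:

    from datetime import datetime
    from typing import Dict, List, Optional, Tuple

    from feast.protos.feast.types.Value_pb2 import Value as ValueProto

    # One tuple per requested entity key; (None, None) marks a miss.
    rows: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [
        (datetime(2024, 9, 10), {"driver_rating": ValueProto(float_val=4.8)}),
        (None, None),  # entity key not present in the online store
    ]

    for event_ts, values in rows:
        if values is None:
            continue  # callers surface these as NOT_FOUND
        for name, value in values.items():
            print(name, value.float_val, event_ts)
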
@@ -85,11 +85,11 @@ def online_read( pass async def online_read_async( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: """ Reads features values for the given entity keys asynchronously. @@ -110,16 +110,16 @@ async def online_read_async( ) def get_online_features( - self, - config: RepoConfig, - features: Union[List[str], FeatureService], - entity_rows: Union[ - List[Dict[str, Any]], - Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]], - ], - registry: BaseRegistry, - project: str, - full_feature_names: bool = False, + self, + config: RepoConfig, + features: Union[List[str], FeatureService], + entity_rows: Union[ + List[Dict[str, Any]], + Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]], + ], + registry: BaseRegistry, + project: str, + full_feature_names: bool = False, ) -> OnlineResponse: if isinstance(entity_rows, list): columnar: Dict[str, List[Any]] = {k: [] for k in entity_rows[0].keys()} @@ -197,16 +197,16 @@ def get_online_features( return OnlineResponse(online_features_response) async def get_online_features_async( - self, - config: RepoConfig, - features: Union[List[str], FeatureService], - entity_rows: Union[ - List[Dict[str, Any]], - Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]], - ], - registry: BaseRegistry, - project: str, - full_feature_names: bool = False, + self, + config: RepoConfig, + features: Union[List[str], FeatureService], + entity_rows: Union[ + List[Dict[str, Any]], + Mapping[str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValue]], + ], + registry: BaseRegistry, + project: str, + full_feature_names: bool = False, ) -> OnlineResponse: if isinstance(entity_rows, list): columnar: Dict[str, List[Any]] = {k: [] for k in entity_rows[0].keys()} @@ -285,13 +285,13 @@ async def get_online_features_async( @abstractmethod def update( - self, - config: RepoConfig, - tables_to_delete: Sequence[FeatureView], - tables_to_keep: Sequence[FeatureView], - entities_to_delete: Sequence[Entity], - entities_to_keep: Sequence[Entity], - partial: bool, + self, + config: RepoConfig, + tables_to_delete: Sequence[FeatureView], + tables_to_keep: Sequence[FeatureView], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + partial: bool, ): """ Reconciles cloud resources with the specified set of Feast objects. @@ -310,9 +310,7 @@ def update( pass def plan( - self, - config: RepoConfig, - desired_registry_proto: RegistryProto + self, config: RepoConfig, desired_registry_proto: RegistryProto ) -> List[InfraObject]: """ Returns the set of InfraObjects required to support the desired registry. @@ -325,10 +323,10 @@ def plan( @abstractmethod def teardown( - self, - config: RepoConfig, - tables: Sequence[FeatureView], - entities: Sequence[Entity], + self, + config: RepoConfig, + tables: Sequence[FeatureView], + entities: Sequence[Entity], ): """ Tears down all cloud resources for the specified set of Feast objects. 
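The hunk below only re-indents the retrieve_online_documents signature; the substantive change in this series is that its result tuple now carries the entity key as a second element, matching _build_retrieve_online_document_record. A sketch of the widened record, with hypothetical values:

    from datetime import datetime
    from typing import Optional, Tuple

    from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
    from feast.protos.feast.types.Value_pb2 import Value as ValueProto

    # (event timestamp, entity key, feature value, vector value, distance)
    DocumentRecord = Tuple[
        Optional[datetime],
        Optional[EntityKeyProto],
        Optional[ValueProto],
        Optional[ValueProto],
        Optional[ValueProto],
    ]

    record: DocumentRecord = (
        datetime(2024, 9, 10),
        EntityKeyProto(
            join_keys=["item_id"], entity_values=[ValueProto(int64_val=42)]
        ),
        ValueProto(string_val="raw document text"),
        ValueProto(),  # vector value; stores typically return a float list here
        ValueProto(float_val=0.12),
    )
    print(record[1].join_keys)  # ['item_id']
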
@@ -341,13 +339,13 @@ def teardown( pass def retrieve_online_documents( - self, - config: RepoConfig, - table: FeatureView, - requested_feature: str, - embedding: List[float], - top_k: int, - distance_metric: Optional[str] = None, + self, + config: RepoConfig, + table: FeatureView, + requested_feature: str, + embedding: List[float], + top_k: int, + distance_metric: Optional[str] = None, ) -> List[ Tuple[ Optional[datetime], diff --git a/sdk/python/feast/infra/online_stores/sqlite.py b/sdk/python/feast/infra/online_stores/sqlite.py index bfb566adc3..baf1668cb8 100644 --- a/sdk/python/feast/infra/online_stores/sqlite.py +++ b/sdk/python/feast/infra/online_stores/sqlite.py @@ -27,7 +27,7 @@ from feast import Entity from feast.feature_view import FeatureView from feast.infra.infra_object import SQLITE_INFRA_OBJECT_CLASS_TYPE, InfraObject -from feast.infra.key_encoding_utils import serialize_entity_key, deserialize_entity_key +from feast.infra.key_encoding_utils import serialize_entity_key from feast.infra.online_stores.online_store import OnlineStore from feast.protos.feast.core.InfraObject_pb2 import InfraObject as InfraObjectProto from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto @@ -35,7 +35,7 @@ from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel, RepoConfig -from feast.utils import to_naive_utc, _build_retrieve_online_document_record +from feast.utils import _build_retrieve_online_document_record, to_naive_utc class SqliteOnlineStoreConfig(FeastConfigBaseModel): @@ -393,14 +393,16 @@ def retrieve_online_documents( ] = [] for entity_key, _, string_value, distance, event_ts in rows: - result.append(_build_retrieve_online_document_record( - event_ts, - entity_key, - string_value if string_value else b"", - embedding, - distance, - config.entity_key_serialization_version - )) + result.append( + _build_retrieve_online_document_record( + event_ts, + entity_key, + string_value if string_value else b"", + embedding, + distance, + config.entity_key_serialization_version, + ) + ) return result diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 58ed1252a0..c2bc77813a 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -24,6 +24,7 @@ import pyarrow from dateutil.tz import tzlocal from google.protobuf.timestamp_pb2 import Timestamp +from infra.key_encoding_utils import deserialize_entity_key from feast.constants import FEAST_FS_YAML_FILE_PATH_ENV_NAME from feast.entity import Entity @@ -38,13 +39,12 @@ GetOnlineFeaturesResponse, ) from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto +from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto from feast.protos.feast.types.Value_pb2 import RepeatedValue as RepeatedValueProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto -from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto from feast.type_map import python_values_to_proto_values from feast.value_type import ValueType from feast.version import get_version -from infra.key_encoding_utils import deserialize_entity_key if typing.TYPE_CHECKING: from feast.feature_service import FeatureService @@ -99,9 +99,9 @@ def get_default_yaml_file_path(repo_path: Path) -> Path: def _get_requested_feature_views_to_features_dict( - feature_refs: List[str], - feature_views: List["FeatureView"], - on_demand_feature_views: 
List["OnDemandFeatureView"], + feature_refs: List[str], + feature_views: List["FeatureView"], + on_demand_feature_views: List["OnDemandFeatureView"], ) -> Tuple[Dict["FeatureView", List[str]], Dict["OnDemandFeatureView", List[str]]]: """Create a dict of FeatureView -> List[Feature] for all requested features. Set full_feature_names to True to have feature names prefixed by their feature view name.""" @@ -133,8 +133,7 @@ def _get_requested_feature_views_to_features_dict( def _get_column_names( - feature_view: "FeatureView", - entities: List[Entity] + feature_view: "FeatureView", entities: List[Entity] ) -> Tuple[List[str], List[str], str, Optional[str]]: """ If a field mapping exists, run it in reverse on the join keys, @@ -169,7 +168,7 @@ def _get_column_names( created_timestamp_column = ( reverse_field_mapping[created_timestamp_column] if created_timestamp_column - and created_timestamp_column in reverse_field_mapping.keys() + and created_timestamp_column in reverse_field_mapping.keys() else created_timestamp_column ) join_keys = [ @@ -187,8 +186,8 @@ def _get_column_names( name for name in feature_names if name not in join_keys - and name != timestamp_field - and name != created_timestamp_column + and name != timestamp_field + and name != created_timestamp_column ] return ( join_keys, @@ -199,8 +198,8 @@ def _get_column_names( def _run_pyarrow_field_mapping( - table: pyarrow.Table, - field_mapping: Dict[str, str], + table: pyarrow.Table, + field_mapping: Dict[str, str], ) -> pyarrow.Table: # run field mapping in the forward direction cols = table.column_names @@ -227,17 +226,17 @@ def _coerce_datetime(ts): def _convert_arrow_to_proto( - table: Union[pyarrow.Table, pyarrow.RecordBatch], - feature_view: "FeatureView", - join_keys: Dict[str, ValueType], + table: Union[pyarrow.Table, pyarrow.RecordBatch], + feature_view: "FeatureView", + join_keys: Dict[str, ValueType], ) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]: # Avoid ChunkedArrays which guarantees `zero_copy_only` available. if isinstance(table, pyarrow.Table): table = table.to_batches()[0] columns = [ - (field.name, field.dtype.to_value_type()) for field in feature_view.features - ] + list(join_keys.items()) + (field.name, field.dtype.to_value_type()) for field in feature_view.features + ] + list(join_keys.items()) proto_values_by_column = { column: python_values_to_proto_values( @@ -294,8 +293,7 @@ def _validate_entity_values(join_key_values: Dict[str, List[ValueProto]]): return set_of_row_lengths.pop() -def _validate_feature_refs(feature_refs: List[str], - full_feature_names: bool = False): +def _validate_feature_refs(feature_refs: List[str], full_feature_names: bool = False): """ Validates that there are no collisions among the feature references. 
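For intuition on what _validate_feature_refs guards against: two references that share a bare feature name become ambiguous once view prefixes are dropped. A small illustration with hypothetical view names, assuming Feast's convention of joining view and feature with a double underscore for full names:

    feature_refs = ["driver_stats:trips", "rider_stats:trips"]

    # With full_feature_names=False both references collapse to the bare
    # feature name "trips" and can no longer be told apart:
    short = [ref.split(":")[-1] for ref in feature_refs]
    assert short == ["trips", "trips"]  # a collision, so validation should fail

    # With full_feature_names=True the prefixed forms stay distinct:
    full = [ref.replace(":", "__") for ref in feature_refs]
    assert full == ["driver_stats__trips", "rider_stats__trips"]
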
@@ -332,9 +330,9 @@ def _validate_feature_refs(feature_refs: List[str], def _group_feature_refs( - features: List[str], - all_feature_views: List["FeatureView"], - all_on_demand_feature_views: List["OnDemandFeatureView"], + features: List[str], + all_feature_views: List["FeatureView"], + all_on_demand_feature_views: List["OnDemandFeatureView"], ) -> Tuple[ List[Tuple["FeatureView", List[str]]], List[Tuple["OnDemandFeatureView", List[str]]] ]: @@ -384,8 +382,7 @@ def _group_feature_refs( def apply_list_mapping( - lst: Iterable[Any], - mapping_indexes: Iterable[List[int]] + lst: Iterable[Any], mapping_indexes: Iterable[List[int]] ) -> Iterable[Any]: output_len = sum(len(item) for item in mapping_indexes) output = [None] * output_len @@ -397,10 +394,10 @@ def apply_list_mapping( def _augment_response_with_on_demand_transforms( - online_features_response: GetOnlineFeaturesResponse, - feature_refs: List[str], - requested_on_demand_feature_views: List["OnDemandFeatureView"], - full_feature_names: bool, + online_features_response: GetOnlineFeaturesResponse, + feature_refs: List[str], + requested_on_demand_feature_views: List["OnDemandFeatureView"], + full_feature_names: bool, ): """Computes on demand feature values and adds them to the result rows. @@ -493,9 +490,9 @@ def _augment_response_with_on_demand_transforms( def _get_entity_maps( - registry, - project, - feature_views, + registry, + project, + feature_views, ) -> Tuple[Dict[str, str], Dict[str, ValueType], Set[str]]: # TODO(felixwang9817): Support entities that have different types for different feature views. entities = registry.list_entities(project, allow_cache=True) @@ -526,9 +523,9 @@ def _get_entity_maps( def _get_table_entity_values( - table: "FeatureView", - entity_name_to_join_key_map: Dict[str, str], - join_key_proto_values: Dict[str, List[ValueProto]], + table: "FeatureView", + entity_name_to_join_key_map: Dict[str, str], + join_key_proto_values: Dict[str, List[ValueProto]], ) -> Dict[str, List[ValueProto]]: # The correct join_keys expected by the OnlineStore for this Feature View. table_join_keys = [ @@ -549,9 +546,9 @@ def _get_table_entity_values( def _get_unique_entities( - table: "FeatureView", - join_key_values: Dict[str, List[ValueProto]], - entity_name_to_join_key_map: Dict[str, str], + table: "FeatureView", + join_key_values: Dict[str, List[ValueProto]], + entity_name_to_join_key_map: Dict[str, str], ) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...]]: """Return the set of unique composite Entities for a Feature View and the indexes at which they appear. @@ -588,8 +585,8 @@ def _get_unique_entities( def _drop_unneeded_columns( - online_features_response: GetOnlineFeaturesResponse, - requested_result_row_names: Set[str], + online_features_response: GetOnlineFeaturesResponse, + requested_result_row_names: Set[str], ): """ Unneeded feature values such as request data and unrequested input feature views will @@ -613,8 +610,8 @@ def _drop_unneeded_columns( def _populate_result_rows_from_columnar( - online_features_response: GetOnlineFeaturesResponse, - data: Dict[str, List[ValueProto]], + online_features_response: GetOnlineFeaturesResponse, + data: Dict[str, List[ValueProto]], ): timestamp = Timestamp() # Only initialize this timestamp once. 
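    # Each key in `data` becomes one named FeatureVector in the response, so
    # the document-retrieval path can pass entity join keys, the requested
    # feature, and a "distance" column through a single columnar dict.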
# Add more values to the existing result rows @@ -630,7 +627,7 @@ def _populate_result_rows_from_columnar( def get_needed_request_data( - grouped_odfv_refs: List[Tuple["OnDemandFeatureView", List[str]]], + grouped_odfv_refs: List[Tuple["OnDemandFeatureView", List[str]]], ) -> Set[str]: needed_request_data: Set[str] = set() for odfv, _ in grouped_odfv_refs: @@ -640,8 +637,8 @@ def get_needed_request_data( def ensure_request_data_values_exist( - needed_request_data: Set[str], - request_data_features: Dict[str, List[Any]], + needed_request_data: Set[str], + request_data_features: Dict[str, List[Any]], ): if len(needed_request_data) != len(request_data_features.keys()): missing_features = [ @@ -651,16 +648,16 @@ def ensure_request_data_values_exist( def _populate_response_from_feature_data( - feature_data: Iterable[ - Tuple[ - Iterable[Timestamp], Iterable["FieldStatus.ValueType"], Iterable[ValueProto] - ] - ], - indexes: Iterable[List[int]], - online_features_response: GetOnlineFeaturesResponse, - full_feature_names: bool, - requested_features: Iterable[str], - table: "FeatureView", + feature_data: Iterable[ + Tuple[ + Iterable[Timestamp], Iterable["FieldStatus.ValueType"], Iterable[ValueProto] + ] + ], + indexes: Iterable[List[int]], + online_features_response: GetOnlineFeaturesResponse, + full_feature_names: bool, + requested_features: Iterable[str], + table: "FeatureView", ): """Populate the GetOnlineFeaturesResponse with feature data. @@ -694,8 +691,8 @@ def _populate_response_from_feature_data( # Populate the result with data fetched from the OnlineStore # which is guaranteed to be aligned with `requested_features`. for ( - feature_idx, - (timestamp_vector, statuses_vector, values_vector), + feature_idx, + (timestamp_vector, statuses_vector, values_vector), ) in enumerate(zip(zip(*timestamps), zip(*statuses), zip(*values))): online_features_response.results.append( GetOnlineFeaturesResponse.FeatureVector( @@ -707,10 +704,10 @@ def _populate_response_from_feature_data( def _get_features( - registry, - project, - features: Union[List[str], "FeatureService"], - allow_cache: bool = False, + registry, + project, + features: Union[List[str], "FeatureService"], + allow_cache: bool = False, ) -> List[str]: from feast.feature_service import FeatureService @@ -741,11 +738,11 @@ def _get_features( def _list_feature_views( - registry, - project, - allow_cache: bool = False, - hide_dummy_entity: bool = True, - tags: Optional[dict[str, str]] = None, + registry, + project, + allow_cache: bool = False, + hide_dummy_entity: bool = True, + tags: Optional[dict[str, str]] = None, ) -> List["FeatureView"]: from feast.feature_view import DUMMY_ENTITY_NAME @@ -759,11 +756,11 @@ def _list_feature_views( def _get_feature_views_to_use( - registry, - project, - features: Optional[Union[List[str], "FeatureService"]], - allow_cache=False, - hide_dummy_entity: bool = True, + registry, + project, + features: Optional[Union[List[str], "FeatureService"]], + allow_cache=False, + hide_dummy_entity: bool = True, ) -> Tuple[List["FeatureView"], List["OnDemandFeatureView"]]: from feast.feature_service import FeatureService @@ -817,10 +814,10 @@ def _get_feature_views_to_use( def _get_online_request_context( - registry, - project, - features: Union[List[str], "FeatureService"], - full_feature_names: bool, + registry, + project, + features: Union[List[str], "FeatureService"], + full_feature_names: bool, ): from feast.feature_view import DUMMY_ENTITY_NAME @@ -885,14 +882,14 @@ def _get_online_request_context( def 
_prepare_entities_to_read_from_online_store( - registry, - project, - features: Union[List[str], "FeatureService"], - entity_values: Mapping[ - str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValueProto] - ], - full_feature_names: bool = False, - native_entity_values: bool = True, + registry, + project, + features: Union[List[str], "FeatureService"], + entity_values: Mapping[ + str, Union[Sequence[Any], Sequence[ValueProto], RepeatedValueProto] + ], + full_feature_names: bool = False, + native_entity_values: bool = True, ): from feast.feature_view import DUMMY_ENTITY, DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL @@ -980,7 +977,7 @@ def _prepare_entities_to_read_from_online_store( def _get_entity_key_protos( - entity_rows: Iterable[Mapping[str, ValueProto]], + entity_rows: Iterable[Mapping[str, ValueProto]], ) -> List[EntityKeyProto]: # Instantiate one EntityKeyProto per Entity. entity_key_protos = [ @@ -991,8 +988,8 @@ def _get_entity_key_protos( def _convert_rows_to_protobuf( - requested_features: List[str], - read_rows: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]], + requested_features: List[str], + read_rows: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]], ) -> List[Tuple[List[Timestamp], List["FieldStatus.ValueType"], List[ValueProto]]]: # Each row is a set of features for a given entity key. # We only need to convert the data to Protobuf once. @@ -1024,8 +1021,7 @@ def _convert_rows_to_protobuf( def has_all_tags( - object_tags: dict[str, str], - requested_tags: Optional[dict[str, str]] = None + object_tags: dict[str, str], requested_tags: Optional[dict[str, str]] = None ) -> bool: if requested_tags is None: return True @@ -1033,7 +1029,7 @@ def has_all_tags( def tags_list_to_dict( - tags_list: Optional[list[str]] = None, + tags_list: Optional[list[str]] = None, ) -> Optional[dict[str, str]]: if not tags_list: return None @@ -1058,12 +1054,12 @@ def _utc_now() -> datetime: def _build_retrieve_online_document_record( - event_ts: datetime, - entity_key: str, - value: Union[str, bytes], - vector_value: Union[str, List[float]], - distance: float, - entity_key_serialization_version: int + event_ts: datetime, + entity_key: str, + value: Union[str, bytes], + vector_value: Union[str, List[float]], + distance: float, + entity_key_serialization_version: int, ) -> Tuple[ Optional[datetime], Optional[EntityKeyProto], @@ -1077,7 +1073,7 @@ def _build_retrieve_online_document_record( entity_key_proto_bin = bytes(entity_key) entity_key_proto = deserialize_entity_key( entity_key_proto_bin, - entity_key_serialization_version=entity_key_serialization_version + entity_key_serialization_version=entity_key_serialization_version, ) if isinstance(value, str): @@ -1088,9 +1084,7 @@ def _build_retrieve_online_document_record( if isinstance(vector_value, str): vector_value_proto = ValueProto(string_val=vector_value) else: - vector_value_proto = ValueProto( - float_list_val=FloatListProto(val=vector_value) - ) + vector_value_proto = ValueProto(float_list_val=FloatListProto(val=vector_value)) distance_value_proto = ValueProto(float_val=distance) return ( From 6224ee927f4c361d084e181cbfc825662c522d79 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 01:28:22 -0700 Subject: [PATCH 03/26] fix lint Signed-off-by: cmuhao --- sdk/python/feast/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index c2bc77813a..199b49dc1d 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ 
-1076,10 +1076,8 @@ def _build_retrieve_online_document_record( entity_key_serialization_version=entity_key_serialization_version, ) - if isinstance(value, str): - value = bytes(value) feature_value_proto = ValueProto() - feature_value_proto.ParseFromString(value) + feature_value_proto.ParseFromString(bytes(value)) if isinstance(vector_value, str): vector_value_proto = ValueProto(string_val=vector_value) From ad8e4adebf640c18abc7ae3073a59c42c1b86282 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 01:32:28 -0700 Subject: [PATCH 04/26] fix lint Signed-off-by: cmuhao --- sdk/python/feast/feature_store.py | 4 ++-- sdk/python/feast/infra/provider.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 44713c9d84..320bd81fae 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1670,8 +1670,8 @@ def retrieve_online_documents( # TODO currently not return the vector value since it is same as feature value, if embedding is supported, # the feature value can be raw text before embedded entity_key_vals = [feature[1] for feature in document_features] - document_feature_vals = [feature[2] for feature in document_features] - document_feature_distance_vals = [feature[4] for feature in document_features] + document_feature_vals = [feature[4] for feature in document_features] + document_feature_distance_vals = [feature[5] for feature in document_features] online_features_response = GetOnlineFeaturesResponse(results=[]) utils._populate_result_rows_from_columnar( online_features_response=online_features_response, diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 9940af1d02..c0062dde02 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -364,6 +364,7 @@ def retrieve_online_documents( ) -> List[ Tuple[ Optional[datetime], + Optional[EntityKeyProto], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto], From d5df05629e1726f6fce6f3860630def1fb4def09 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 08:49:42 -0700 Subject: [PATCH 05/26] fix lint Signed-off-by: cmuhao --- .../contrib/duckdb_online_store/duckdb.py | 29 ++++++++++++++----- .../online_stores/contrib/elasticsearch.py | 2 +- .../infra/online_stores/contrib/postgres.py | 2 +- .../feast/infra/online_stores/sqlite.py | 2 +- sdk/python/feast/utils.py | 12 ++++---- 5 files changed, 31 insertions(+), 16 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py b/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py index b715c68204..68628de7c9 100644 --- a/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py +++ b/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py @@ -1,10 +1,11 @@ +import abc import contextlib from datetime import datetime from typing import Any, Dict, List, Optional, Tuple import duckdb from infra.key_encoding_utils import serialize_entity_key -from utils import _build_retrieve_online_document_results +from utils import _build_retrieve_online_document_results, _build_retrieve_online_document_record from feast.feature_view import FeatureView from feast.infra.online_stores.online_store import OnlineStore @@ -22,6 +23,8 @@ class DuckDBOnlineStoreConfig: class DuckDBOnlineStore(OnlineStore): + + @abc.abstractmethod async def online_read_async( self, config: RepoConfig, @@ -45,9 +48,9 @@ def 
create_vector_index( self, config: RepoConfig, table_name: str, vector_column: str ) -> None: """Create an HNSW index for vector similarity search.""" - if not config.enable_vector_search: + if not config.online_store.enable_vector_search: raise ValueError("Vector search is not enabled in the configuration.") - distance_metric = config.distance_metric + distance_metric = config.online_store.distance_metric with self._get_conn(None) as conn: conn.execute( @@ -145,10 +148,22 @@ def retrieve_online_documents( ORDER BY array_distance(vec, ?::FLOAT[]) LIMIT ?; """ rows = conn.execute(query, (embedding, top_k)).fetchall() - result = _build_retrieve_online_document_results( - rows, - entity_key_serialization_version=config.entity_key_serialization_version, - ) + for ( + entity_key, + _, + feature_val, + vector_value, + distance_val, + event_ts, + ) in rows: + result.append(_build_retrieve_online_document_record( + entity_key=entity_key, + feature_value=feature_val, + vector_value=vector_value, + distance_value=distance_val, + event_timestamp=event_ts, + entity_key_serialization_version=config.entity_key_serialization_version, + )) return result diff --git a/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py b/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py index b4114410c6..a0c25b931a 100644 --- a/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py +++ b/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py @@ -261,11 +261,11 @@ def retrieve_online_documents( result.append( _build_retrieve_online_document_record( - timestamp, entity_key, base64.b64decode(feature_value), str(vector_value), distance, + timestamp, config.entity_key_serialization_version, ) ) diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 3b039079f6..7652f77b8a 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -443,11 +443,11 @@ def retrieve_online_documents( ) in rows: result.append( _build_retrieve_online_document_record( - event_ts, entity_key, feature_val, vector_value, distance_val, + event_ts, config.entity_key_serialization_version, ) ) diff --git a/sdk/python/feast/infra/online_stores/sqlite.py b/sdk/python/feast/infra/online_stores/sqlite.py index baf1668cb8..061a766b8c 100644 --- a/sdk/python/feast/infra/online_stores/sqlite.py +++ b/sdk/python/feast/infra/online_stores/sqlite.py @@ -395,11 +395,11 @@ def retrieve_online_documents( for entity_key, _, string_value, distance, event_ts in rows: result.append( _build_retrieve_online_document_record( - event_ts, entity_key, string_value if string_value else b"", embedding, distance, + event_ts, config.entity_key_serialization_version, ) ) diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 199b49dc1d..88025259e9 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -1054,11 +1054,11 @@ def _utc_now() -> datetime: def _build_retrieve_online_document_record( - event_ts: datetime, entity_key: str, - value: Union[str, bytes], + feature_value: Union[str, bytes], vector_value: Union[str, List[float]], - distance: float, + distance_value: float, + event_timestamp: datetime, entity_key_serialization_version: int, ) -> Tuple[ Optional[datetime], @@ -1077,16 +1077,16 @@ def _build_retrieve_online_document_record( ) feature_value_proto = ValueProto() - feature_value_proto.ParseFromString(bytes(value)) + 
feature_value_proto.ParseFromString(bytes(feature_value)) if isinstance(vector_value, str): vector_value_proto = ValueProto(string_val=vector_value) else: vector_value_proto = ValueProto(float_list_val=FloatListProto(val=vector_value)) - distance_value_proto = ValueProto(float_val=distance) + distance_value_proto = ValueProto(float_val=distance_value) return ( - event_ts, + event_timestamp, entity_key_proto, feature_value_proto, vector_value_proto, From 6e841c2371736f3972056e4c6d2b99a76c4d0811 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 08:54:32 -0700 Subject: [PATCH 06/26] fix lint Signed-off-by: cmuhao --- .../contrib/duckdb_online_store/duckdb.py | 113 ++++++++++-------- 1 file changed, 61 insertions(+), 52 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py b/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py index 68628de7c9..d3493deaea 100644 --- a/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py +++ b/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py @@ -5,7 +5,9 @@ import duckdb from infra.key_encoding_utils import serialize_entity_key -from utils import _build_retrieve_online_document_results, _build_retrieve_online_document_record +from utils import ( + _build_retrieve_online_document_record, +) from feast.feature_view import FeatureView from feast.infra.online_stores.online_store import OnlineStore @@ -17,35 +19,40 @@ class DuckDBOnlineStoreConfig: type: str = "duckdb" path: str + read_only: bool = False enable_vector_search: bool = False # New option for enabling vector search dimension: Optional[int] = 512 distance_metric: Optional[str] = "L2" class DuckDBOnlineStore(OnlineStore): + __conn: Optional[duckdb.Connection] = None @abc.abstractmethod async def online_read_async( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: pass - def __init__(self, config: DuckDBOnlineStoreConfig): - self.config = config - self.connection = None - @contextlib.contextmanager - def _get_conn(self, config: RepoConfig) -> Any: - if self.connection is None: - self.connection = duckdb.connect(database=self.config.path, read_only=False) - yield self.connection + def _get_conn(self, + config: RepoConfig) -> Any: + assert config.online_store.type == "duckdb" + online_store_config = config.online_store + + if self.__conn is None: + self.__conn = duckdb.connect(database=online_store_config.path, read_only=online_store_config.read_only) + yield self.__conn def create_vector_index( - self, config: RepoConfig, table_name: str, vector_column: str + self, + config: RepoConfig, + table_name: str, + vector_column: str ) -> None: """Create an HNSW index for vector similarity search.""" if not config.online_store.enable_vector_search: @@ -58,10 +65,10 @@ def create_vector_index( ) def online_write_batch( - self, - config: RepoConfig, - table: FeatureView, - data: List[Tuple[EntityKeyProto, Dict[str, ValueProto]]], + self, + config: RepoConfig, + table: FeatureView, + data: List[Tuple[EntityKeyProto, Dict[str, ValueProto]]], ) -> None: insert_values = [] for entity_key, values in data: @@ -81,11 +88,11 @@ def online_write_batch( ) def online_read( - self, - config: RepoConfig, - 
table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[Dict[str, ValueProto]]]]: keys = [serialize_entity_key(key).hex() for key in entity_keys] query = f"SELECT feature_name, value FROM {table.name} WHERE entity_key IN ({','.join(['?'] * len(keys))})" @@ -101,13 +108,13 @@ def online_read( ] def retrieve_online_documents( - self, - config: RepoConfig, - table: FeatureView, - requested_feature: str, - embedding: List[float], - top_k: int, - distance_metric: Optional[str] = "L2", + self, + config: RepoConfig, + table: FeatureView, + requested_feature: str, + embedding: List[float], + top_k: int, + distance_metric: Optional[str] = "L2", ) -> List[ Tuple[ Optional[datetime], @@ -149,29 +156,31 @@ def retrieve_online_documents( """ rows = conn.execute(query, (embedding, top_k)).fetchall() for ( - entity_key, - _, - feature_val, - vector_value, - distance_val, - event_ts, + entity_key, + _, + feature_val, + vector_value, + distance_val, + event_ts, ) in rows: - result.append(_build_retrieve_online_document_record( - entity_key=entity_key, - feature_value=feature_val, - vector_value=vector_value, - distance_value=distance_val, - event_timestamp=event_ts, - entity_key_serialization_version=config.entity_key_serialization_version, - )) + result.append( + _build_retrieve_online_document_record( + entity_key=entity_key, + feature_value=feature_val, + vector_value=vector_value, + distance_value=distance_val, + event_timestamp=event_ts, + entity_key_serialization_version=config.entity_key_serialization_version, + ) + ) return result def update( - self, - config: RepoConfig, - tables_to_delete: List[FeatureView], - tables_to_keep: List[FeatureView], + self, + config: RepoConfig, + tables_to_delete: List[FeatureView], + tables_to_keep: List[FeatureView], ) -> None: with self._get_conn(config) as conn: for table in tables_to_delete: @@ -182,9 +191,9 @@ def update( ) def teardown( - self, - config: RepoConfig, - tables: List[FeatureView], + self, + config: RepoConfig, + tables: List[FeatureView], ) -> None: with self._get_conn(config) as conn: for table in tables: From 2432b7ddaccc9a2ed2493ca6578ee2c6c9c9cbac Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 08:58:05 -0700 Subject: [PATCH 07/26] fix lint Signed-off-by: cmuhao --- .../contrib/duckdb_online_store/duckdb.py | 95 ++++++++++--------- 1 file changed, 50 insertions(+), 45 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py b/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py index d3493deaea..e77f2a1f4a 100644 --- a/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py +++ b/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py @@ -1,7 +1,7 @@ import abc import contextlib from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Callable import duckdb from infra.key_encoding_utils import serialize_entity_key @@ -26,49 +26,51 @@ class DuckDBOnlineStoreConfig: class DuckDBOnlineStore(OnlineStore): - __conn: Optional[duckdb.Connection] = None + __conn: Optional[duckdb.DuckDBPyConnection] = None @abc.abstractmethod async def online_read_async( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - 
requested_features: Optional[List[str]] = None, + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: pass @contextlib.contextmanager - def _get_conn(self, - config: RepoConfig) -> Any: + def _get_conn(self, config: RepoConfig) -> Any: assert config.online_store.type == "duckdb" online_store_config = config.online_store if self.__conn is None: - self.__conn = duckdb.connect(database=online_store_config.path, read_only=online_store_config.read_only) + self.__conn = duckdb.connect( + database=online_store_config.path, + read_only=online_store_config.read_only, + ) yield self.__conn def create_vector_index( - self, - config: RepoConfig, - table_name: str, - vector_column: str + self, config: RepoConfig, table_name: str, vector_column: str ) -> None: """Create an HNSW index for vector similarity search.""" if not config.online_store.enable_vector_search: raise ValueError("Vector search is not enabled in the configuration.") distance_metric = config.online_store.distance_metric - with self._get_conn(None) as conn: + with self._get_conn(config) as conn: conn.execute( f"CREATE INDEX idx ON {table_name} USING HNSW ({vector_column}) WITH (metric = '{distance_metric}');" ) def online_write_batch( - self, - config: RepoConfig, - table: FeatureView, - data: List[Tuple[EntityKeyProto, Dict[str, ValueProto]]], + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], ) -> None: insert_values = [] for entity_key, values in data: @@ -88,11 +90,11 @@ def online_write_batch( ) def online_read( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[Dict[str, ValueProto]]]]: keys = [serialize_entity_key(key).hex() for key in entity_keys] query = f"SELECT feature_name, value FROM {table.name} WHERE entity_key IN ({','.join(['?'] * len(keys))})" @@ -108,13 +110,13 @@ def online_read( ] def retrieve_online_documents( - self, - config: RepoConfig, - table: FeatureView, - requested_feature: str, - embedding: List[float], - top_k: int, - distance_metric: Optional[str] = "L2", + self, + config: RepoConfig, + table: FeatureView, + requested_feature: str, + embedding: List[float], + top_k: int, + distance_metric: Optional[str] = "L2", ) -> List[ Tuple[ Optional[datetime], @@ -124,9 +126,12 @@ def retrieve_online_documents( Optional[ValueProto], ] ]: + online_store_config = config.online_store """Perform a vector similarity search using the HNSW index.""" - if not self.config.enable_vector_search: + + if not online_store_config.enable_vector_search: raise ValueError("Vector search is not enabled in the configuration.") + if config.entity_key_serialization_version < 3: raise ValueError( "Entity key serialization version must be at least 3 for vector search." 
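For reference, the similarity query this method builds leans on DuckDB's vss extension, which supplies HNSW indexes and accelerates the ORDER BY array_distance ... LIMIT pattern; note that vss names its metrics 'l2sq', 'cosine', and 'ip', so a config default of "L2" would need mapping. A standalone sketch with hypothetical table and column names:

    import duckdb

    conn = duckdb.connect(":memory:")
    conn.execute("INSTALL vss;")  # provides HNSW indexes for fixed-size arrays
    conn.execute("LOAD vss;")

    conn.execute("CREATE TABLE docs (feature_name TEXT, vec FLOAT[3])")
    conn.execute(
        "INSERT INTO docs VALUES ('embedding', [1.0, 0.0, 0.0]),"
        " ('embedding', [0.0, 1.0, 0.0])"
    )
    conn.execute(
        "CREATE INDEX idx ON docs USING HNSW (vec) WITH (metric = 'l2sq')"
    )

    rows = conn.execute(
        "SELECT feature_name, array_distance(vec, ?::FLOAT[3]) AS distance"
        " FROM docs ORDER BY distance LIMIT ?",
        ([1.0, 0.0, 0.0], 2),
    ).fetchall()
    print(rows)  # nearest vectors first; distance 0.0 for the exact match
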
@@ -156,12 +161,12 @@ def retrieve_online_documents( """ rows = conn.execute(query, (embedding, top_k)).fetchall() for ( - entity_key, - _, - feature_val, - vector_value, - distance_val, - event_ts, + entity_key, + _, + feature_val, + vector_value, + distance_val, + event_ts, ) in rows: result.append( _build_retrieve_online_document_record( @@ -177,10 +182,10 @@ def retrieve_online_documents( return result def update( - self, - config: RepoConfig, - tables_to_delete: List[FeatureView], - tables_to_keep: List[FeatureView], + self, + config: RepoConfig, + tables_to_delete: List[FeatureView], + tables_to_keep: List[FeatureView], ) -> None: with self._get_conn(config) as conn: for table in tables_to_delete: @@ -191,9 +196,9 @@ def update( ) def teardown( - self, - config: RepoConfig, - tables: List[FeatureView], + self, + config: RepoConfig, + tables: List[FeatureView], ) -> None: with self._get_conn(config) as conn: for table in tables: From 4f5603b06224377fd81803a3f909b54283b7caf9 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 09:01:48 -0700 Subject: [PATCH 08/26] fix lint Signed-off-by: cmuhao --- .../contrib/duckdb_online_store/__init__.py | 0 .../contrib/duckdb_online_store/duckdb.py | 205 ------------------ 2 files changed, 205 deletions(-) delete mode 100644 sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/__init__.py delete mode 100644 sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py diff --git a/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/__init__.py b/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py b/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py deleted file mode 100644 index e77f2a1f4a..0000000000 --- a/sdk/python/feast/infra/online_stores/contrib/duckdb_online_store/duckdb.py +++ /dev/null @@ -1,205 +0,0 @@ -import abc -import contextlib -from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple, Callable - -import duckdb -from infra.key_encoding_utils import serialize_entity_key -from utils import ( - _build_retrieve_online_document_record, -) - -from feast.feature_view import FeatureView -from feast.infra.online_stores.online_store import OnlineStore -from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto -from feast.protos.feast.types.Value_pb2 import Value as ValueProto -from feast.repo_config import RepoConfig - - -class DuckDBOnlineStoreConfig: - type: str = "duckdb" - path: str - read_only: bool = False - enable_vector_search: bool = False # New option for enabling vector search - dimension: Optional[int] = 512 - distance_metric: Optional[str] = "L2" - - -class DuckDBOnlineStore(OnlineStore): - __conn: Optional[duckdb.DuckDBPyConnection] = None - - @abc.abstractmethod - async def online_read_async( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, - ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - pass - - @contextlib.contextmanager - def _get_conn(self, config: RepoConfig) -> Any: - assert config.online_store.type == "duckdb" - online_store_config = config.online_store - - if self.__conn is None: - self.__conn = duckdb.connect( - database=online_store_config.path, - read_only=online_store_config.read_only, - ) - yield self.__conn - - def 
create_vector_index( - self, config: RepoConfig, table_name: str, vector_column: str - ) -> None: - """Create an HNSW index for vector similarity search.""" - if not config.online_store.enable_vector_search: - raise ValueError("Vector search is not enabled in the configuration.") - distance_metric = config.online_store.distance_metric - - with self._get_conn(config) as conn: - conn.execute( - f"CREATE INDEX idx ON {table_name} USING HNSW ({vector_column}) WITH (metric = '{distance_metric}');" - ) - - def online_write_batch( - self, - config: RepoConfig, - table: FeatureView, - data: List[ - Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] - ], - progress: Optional[Callable[[int], Any]], - ) -> None: - insert_values = [] - for entity_key, values in data: - entity_key_bin = serialize_entity_key(entity_key).hex() - for feature_name, val in values.items(): - insert_values.append( - (entity_key_bin, feature_name, val.SerializeToString()) - ) - - with self._get_conn(config) as conn: - conn.execute( - f"CREATE TABLE IF NOT EXISTS {table.name} (entity_key BLOB, feature_name TEXT, value BLOB)" - ) - conn.executemany( - f"INSERT INTO {table.name} (entity_key, feature_name, value) VALUES (?, ?, ?)", - insert_values, - ) - - def online_read( - self, - config: RepoConfig, - table: FeatureView, - entity_keys: List[EntityKeyProto], - requested_features: Optional[List[str]] = None, - ) -> List[Tuple[Optional[Dict[str, ValueProto]]]]: - keys = [serialize_entity_key(key).hex() for key in entity_keys] - query = f"SELECT feature_name, value FROM {table.name} WHERE entity_key IN ({','.join(['?'] * len(keys))})" - - with self._get_conn(config) as conn: - results = conn.execute(query, keys).fetchall() - - return [ - { - feature_name: ValueProto().ParseFromString(value) - for feature_name, value in results - } - ] - - def retrieve_online_documents( - self, - config: RepoConfig, - table: FeatureView, - requested_feature: str, - embedding: List[float], - top_k: int, - distance_metric: Optional[str] = "L2", - ) -> List[ - Tuple[ - Optional[datetime], - Optional[EntityKeyProto], - Optional[ValueProto], - Optional[ValueProto], - Optional[ValueProto], - ] - ]: - online_store_config = config.online_store - """Perform a vector similarity search using the HNSW index.""" - - if not online_store_config.enable_vector_search: - raise ValueError("Vector search is not enabled in the configuration.") - - if config.entity_key_serialization_version < 3: - raise ValueError( - "Entity key serialization version must be at least 3 for vector search." 
- ) - - result: List[ - Tuple[ - Optional[datetime], - Optional[EntityKeyProto], - Optional[ValueProto], - Optional[ValueProto], - Optional[ValueProto], - ] - ] = [] - - with self._get_conn(config) as conn: - query = f""" - SELECT - entity_key, - feature_name, - value, - vector_value, - event_ts - FROM {table.name} - WHERE feature_name = '{requested_feature}' - ORDER BY array_distance(vec, ?::FLOAT[]) LIMIT ?; - """ - rows = conn.execute(query, (embedding, top_k)).fetchall() - for ( - entity_key, - _, - feature_val, - vector_value, - distance_val, - event_ts, - ) in rows: - result.append( - _build_retrieve_online_document_record( - entity_key=entity_key, - feature_value=feature_val, - vector_value=vector_value, - distance_value=distance_val, - event_timestamp=event_ts, - entity_key_serialization_version=config.entity_key_serialization_version, - ) - ) - - return result - - def update( - self, - config: RepoConfig, - tables_to_delete: List[FeatureView], - tables_to_keep: List[FeatureView], - ) -> None: - with self._get_conn(config) as conn: - for table in tables_to_delete: - conn.execute(f"DROP TABLE IF EXISTS {table.name}") - for table in tables_to_keep: - conn.execute( - f"CREATE TABLE IF NOT EXISTS {table.name} (entity_key BLOB, feature_name TEXT, value BLOB)" - ) - - def teardown( - self, - config: RepoConfig, - tables: List[FeatureView], - ) -> None: - with self._get_conn(config) as conn: - for table in tables: - conn.execute(f"DROP TABLE IF EXISTS {table.name}") From 08039e43e1fb035119339f5124ccb60dce3ad875 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 11:48:01 -0700 Subject: [PATCH 09/26] fix lint Signed-off-by: cmuhao --- sdk/python/feast/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 88025259e9..b6be8e6bb8 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -1070,14 +1070,18 @@ def _build_retrieve_online_document_record( if entity_key_serialization_version < 3: entity_key_proto = None else: - entity_key_proto_bin = bytes(entity_key) + entity_key_proto_bin = entity_key.encode('utf-8') entity_key_proto = deserialize_entity_key( entity_key_proto_bin, entity_key_serialization_version=entity_key_serialization_version, ) feature_value_proto = ValueProto() - feature_value_proto.ParseFromString(bytes(feature_value)) + + if isinstance(feature_value, str): + feature_value_proto.ParseFromString(feature_value.encode('utf-8')) + else: + feature_value_proto.ParseFromString(feature_value) if isinstance(vector_value, str): vector_value_proto = ValueProto(string_val=vector_value) From 61625174aace108ba4ec56b1705065fbeedaf64e Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 11:48:30 -0700 Subject: [PATCH 10/26] fix lint Signed-off-by: cmuhao --- sdk/python/feast/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index b6be8e6bb8..73d2cfefaf 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -1070,7 +1070,7 @@ def _build_retrieve_online_document_record( if entity_key_serialization_version < 3: entity_key_proto = None else: - entity_key_proto_bin = entity_key.encode('utf-8') + entity_key_proto_bin = entity_key.encode("utf-8") entity_key_proto = deserialize_entity_key( entity_key_proto_bin, entity_key_serialization_version=entity_key_serialization_version, @@ -1079,7 +1079,7 @@ def _build_retrieve_online_document_record( feature_value_proto = ValueProto() 
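    # ParseFromString expects serialized protobuf bytes, so str inputs are
    # UTF-8 encoded first below; bytes inputs are parsed as-is.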
if isinstance(feature_value, str): - feature_value_proto.ParseFromString(feature_value.encode('utf-8')) + feature_value_proto.ParseFromString(feature_value.encode("utf-8")) else: feature_value_proto.ParseFromString(feature_value) From a341071a84ea9618db31c45047945a35e676c052 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 11:51:53 -0700 Subject: [PATCH 11/26] fix typo Signed-off-by: cmuhao --- sdk/python/feast/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 73d2cfefaf..81b41fa695 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -24,7 +24,7 @@ import pyarrow from dateutil.tz import tzlocal from google.protobuf.timestamp_pb2 import Timestamp -from infra.key_encoding_utils import deserialize_entity_key +from feast.infra.key_encoding_utils import deserialize_entity_key from feast.constants import FEAST_FS_YAML_FILE_PATH_ENV_NAME from feast.entity import Entity From 4438d050e4b6075fabdcf913e261f1ae455a85f1 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 11:52:54 -0700 Subject: [PATCH 12/26] fix typo Signed-off-by: cmuhao --- sdk/python/feast/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 81b41fa695..71db2c1478 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -24,7 +24,6 @@ import pyarrow from dateutil.tz import tzlocal from google.protobuf.timestamp_pb2 import Timestamp -from feast.infra.key_encoding_utils import deserialize_entity_key from feast.constants import FEAST_FS_YAML_FILE_PATH_ENV_NAME from feast.entity import Entity @@ -34,6 +33,7 @@ FeatureViewNotFoundException, RequestDataNotFoundInEntityRowsException, ) +from feast.infra.key_encoding_utils import deserialize_entity_key from feast.protos.feast.serving.ServingService_pb2 import ( FieldStatus, GetOnlineFeaturesResponse, From 36e094b2efc73e37ba44c0f2fd22069b1150e44d Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 11:55:29 -0700 Subject: [PATCH 13/26] fix typo Signed-off-by: cmuhao --- sdk/python/feast/feature_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 320bd81fae..59a4eeb9e1 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -34,7 +34,6 @@ import pyarrow as pa from colorama import Fore, Style from google.protobuf.timestamp_pb2 import Timestamp -from protos.feast.types.EntityKey_pb2 import EntityKey from tqdm import tqdm from feast import feature_server, flags_helper, ui_server, utils @@ -79,6 +78,7 @@ FieldStatus, GetOnlineFeaturesResponse, ) +from feast.protos.feast.types.EntityKey_pb2 import EntityKey from feast.protos.feast.types.Value_pb2 import RepeatedValue, Value from feast.repo_config import RepoConfig, load_repo_config from feast.repo_contents import RepoContents From 43286373d01aa17a106b3930d1b6e2467394c03c Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 12:33:12 -0700 Subject: [PATCH 14/26] fix typo Signed-off-by: cmuhao --- sdk/python/feast/feature_store.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 59a4eeb9e1..c004b2e4b4 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -78,6 +78,7 @@ FieldStatus, GetOnlineFeaturesResponse, ) +from 
feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.protos.feast.types.EntityKey_pb2 import EntityKey from feast.protos.feast.types.Value_pb2 import RepeatedValue, Value from feast.repo_config import RepoConfig, load_repo_config @@ -1670,13 +1671,20 @@ def retrieve_online_documents( # TODO currently not return the vector value since it is same as feature value, if embedding is supported, # the feature value can be raw text before embedded entity_key_vals = [feature[1] for feature in document_features] + join_key_values: Dict[str, List[ValueProto]] = {} + for entity_key_val in entity_key_vals: + for entity_key in entity_key_val.join_keys(): + if entity_key not in join_key_values: + join_key_values[entity_key] = [] + join_key_values[entity_key].append(entity_key_val[entity_key]) + document_feature_vals = [feature[4] for feature in document_features] document_feature_distance_vals = [feature[5] for feature in document_features] online_features_response = GetOnlineFeaturesResponse(results=[]) utils._populate_result_rows_from_columnar( online_features_response=online_features_response, data={ - "entity_key": entity_key_vals, + **join_key_values, requested_feature: document_feature_vals, "distance": document_feature_distance_vals, }, @@ -1715,6 +1723,7 @@ def _retrieve_from_online_store( row_ts_proto.FromDatetime(row_ts) if feature_val is None or vector_value is None or distance_val is None: + entity_key = EntityKey() feature_val = Value() vector_value = Value() distance_val = Value() From 2d7162c829fd4c1276ad63f5eed8c0b187c0a18d Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 15:05:08 -0700 Subject: [PATCH 15/26] fix typo Signed-off-by: cmuhao --- sdk/python/feast/feature_store.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index c004b2e4b4..05d6a4d574 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -78,9 +78,9 @@ FieldStatus, GetOnlineFeaturesResponse, ) -from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.protos.feast.types.EntityKey_pb2 import EntityKey from feast.protos.feast.types.Value_pb2 import RepeatedValue, Value +from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import RepoConfig, load_repo_config from feast.repo_contents import RepoContents from feast.saved_dataset import SavedDataset, SavedDatasetStorage, ValidationReference @@ -1673,10 +1673,10 @@ def retrieve_online_documents( entity_key_vals = [feature[1] for feature in document_features] join_key_values: Dict[str, List[ValueProto]] = {} for entity_key_val in entity_key_vals: - for entity_key in entity_key_val.join_keys(): - if entity_key not in join_key_values: - join_key_values[entity_key] = [] - join_key_values[entity_key].append(entity_key_val[entity_key]) + for join_key, entity_value in zip(entity_key_val.join_keys, entity_key_val.entity_values): + if join_key not in join_key_values: + join_key_values[join_key] = [] + join_key_values[join_key].append(entity_value) document_feature_vals = [feature[4] for feature in document_features] document_feature_distance_vals = [feature[5] for feature in document_features] From a5520f52dfc335ad45205fba69b844ffc0eac8af Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 15:06:05 -0700 Subject: [PATCH 16/26] fix typo Signed-off-by: cmuhao --- sdk/python/feast/feature_store.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff 
--git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 05d6a4d574..1dff2d2255 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1673,7 +1673,9 @@ def retrieve_online_documents( entity_key_vals = [feature[1] for feature in document_features] join_key_values: Dict[str, List[ValueProto]] = {} for entity_key_val in entity_key_vals: - for join_key, entity_value in zip(entity_key_val.join_keys, entity_key_val.entity_values): + for join_key, entity_value in zip( + entity_key_val.join_keys, entity_key_val.entity_values + ): if join_key not in join_key_values: join_key_values[join_key] = [] join_key_values[join_key].append(entity_value) @@ -1722,7 +1724,7 @@ def _retrieve_from_online_store( if row_ts is not None: row_ts_proto.FromDatetime(row_ts) - if feature_val is None or vector_value is None or distance_val is None: + if entity_key is None or feature_val is None or vector_value is None or distance_val is None: entity_key = EntityKey() feature_val = Value() vector_value = Value() From 85e329a17f092a53ce128ac1d4c7710b9e1976d9 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 15:06:22 -0700 Subject: [PATCH 17/26] fix typo Signed-off-by: cmuhao --- sdk/python/feast/feature_store.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 1dff2d2255..7ed1de4b32 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1702,7 +1702,7 @@ def _retrieve_from_online_store( top_k: int, distance_metric: Optional[str], ) -> List[ - Tuple[Timestamp, EntityKey, "FieldStatus.ValueType", Value, Value, Value] + Tuple[Timestamp, Optional[EntityKey], "FieldStatus.ValueType", Value, Value, Value] ]: """ Search and return document features from the online document store. @@ -1724,8 +1724,7 @@ def _retrieve_from_online_store( if row_ts is not None: row_ts_proto.FromDatetime(row_ts) - if entity_key is None or feature_val is None or vector_value is None or distance_val is None: - entity_key = EntityKey() + if feature_val is None or vector_value is None or distance_val is None: feature_val = Value() vector_value = Value() distance_val = Value() From f67d25a42b5ac4e4bafef337ad3a27055bcfe373 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 15:07:16 -0700 Subject: [PATCH 18/26] fix typo Signed-off-by: cmuhao --- sdk/python/feast/feature_store.py | 446 +++++++++++++++++------------- 1 file changed, 252 insertions(+), 194 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 7ed1de4b32..857be25ef2 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -108,10 +108,10 @@ class FeatureStore: _provider: Provider def __init__( - self, - repo_path: Optional[str] = None, - config: Optional[RepoConfig] = None, - fs_yaml_file: Optional[Path] = None, + self, + repo_path: Optional[str] = None, + config: Optional[RepoConfig] = None, + fs_yaml_file: Optional[Path] = None, ): """ Creates a FeatureStore object. @@ -205,7 +205,9 @@ def refresh_registry(self): self._registry.refresh(self.project) def list_entities( - self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None ) -> List[Entity]: """ Retrieves the list of entities from the registry. 
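
Patch 17 above settles on defaulting only the value fields when a document row comes back incomplete, leaving the entity key as None so join-key population can skip misses. A standalone sketch of the resulting status bookkeeping, assuming the NOT_FOUND/PRESENT pair from Feast's FieldStatus enum as used in the unchanged context:

    from feast.protos.feast.serving.ServingService_pb2 import FieldStatus
    from feast.protos.feast.types.Value_pb2 import Value

    def row_status(feature_val, vector_value, distance_val):
        # Substitute empty Value protos for missing fields and flag the row.
        if feature_val is None or vector_value is None or distance_val is None:
            return FieldStatus.NOT_FOUND, Value(), Value(), Value()
        return FieldStatus.PRESENT, feature_val, vector_value, distance_val

    status, *_ = row_status(None, None, None)
    assert status == FieldStatus.NOT_FOUND
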
@@ -220,10 +222,10 @@ def list_entities( return self._list_entities(allow_cache, tags=tags) def _list_entities( - self, - allow_cache: bool = False, - hide_dummy_entity: bool = True, - tags: Optional[dict[str, str]] = None, + self, + allow_cache: bool = False, + hide_dummy_entity: bool = True, + tags: Optional[dict[str, str]] = None, ) -> List[Entity]: all_entities = self._registry.list_entities( self.project, allow_cache=allow_cache, tags=tags @@ -235,7 +237,8 @@ def _list_entities( ] def list_feature_services( - self, tags: Optional[dict[str, str]] = None + self, + tags: Optional[dict[str, str]] = None ) -> List[FeatureService]: """ Retrieves the list of feature services from the registry. @@ -249,16 +252,18 @@ def list_feature_services( return self._registry.list_feature_services(self.project, tags=tags) def _list_all_feature_views( - self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None ) -> List[BaseFeatureView]: feature_views = [] for fv in self.registry.list_all_feature_views( - self.project, allow_cache=allow_cache, tags=tags + self.project, allow_cache=allow_cache, tags=tags ): if ( - isinstance(fv, FeatureView) - and fv.entities - and fv.entities[0] == DUMMY_ENTITY_NAME + isinstance(fv, FeatureView) + and fv.entities + and fv.entities[0] == DUMMY_ENTITY_NAME ): fv.entities = [] fv.entity_columns = [] @@ -266,7 +271,9 @@ def _list_all_feature_views( return feature_views def list_all_feature_views( - self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None ) -> List[BaseFeatureView]: """ Retrieves the list of feature views from the registry. @@ -280,7 +287,9 @@ def list_all_feature_views( return self._list_all_feature_views(allow_cache, tags=tags) def list_feature_views( - self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None ) -> List[FeatureView]: """ Retrieves the list of feature views from the registry. @@ -297,7 +306,9 @@ def list_feature_views( ) def list_batch_feature_views( - self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None ) -> List[FeatureView]: """ Retrieves the list of feature views from the registry. 
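
The listing hunks in this patch repeatedly apply one scrub: feature views registered without real entities carry Feast's placeholder entity, and it should not leak to callers. As a standalone sketch of that filter, assuming the `DUMMY_ENTITY_NAME` constant from `feast.feature_view` that the surrounding code already relies on:

    from typing import List

    from feast.feature_view import DUMMY_ENTITY_NAME, FeatureView

    def scrub_dummy_entities(feature_views: List[FeatureView]) -> List[FeatureView]:
        # Clear the placeholder entity attached to entity-less feature views.
        for fv in feature_views:
            if fv.entities and fv.entities[0] == DUMMY_ENTITY_NAME:
                fv.entities = []
                fv.entity_columns = []
        return feature_views
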
@@ -312,19 +323,19 @@ def list_batch_feature_views( return self._list_batch_feature_views(allow_cache=allow_cache, tags=tags) def _list_batch_feature_views( - self, - allow_cache: bool = False, - hide_dummy_entity: bool = True, - tags: Optional[dict[str, str]] = None, + self, + allow_cache: bool = False, + hide_dummy_entity: bool = True, + tags: Optional[dict[str, str]] = None, ) -> List[FeatureView]: feature_views = [] for fv in self._registry.list_feature_views( - self.project, allow_cache=allow_cache, tags=tags + self.project, allow_cache=allow_cache, tags=tags ): if ( - hide_dummy_entity - and fv.entities - and fv.entities[0] == DUMMY_ENTITY_NAME + hide_dummy_entity + and fv.entities + and fv.entities[0] == DUMMY_ENTITY_NAME ): fv.entities = [] fv.entity_columns = [] @@ -332,14 +343,14 @@ def _list_batch_feature_views( return feature_views def _list_stream_feature_views( - self, - allow_cache: bool = False, - hide_dummy_entity: bool = True, - tags: Optional[dict[str, str]] = None, + self, + allow_cache: bool = False, + hide_dummy_entity: bool = True, + tags: Optional[dict[str, str]] = None, ) -> List[StreamFeatureView]: stream_feature_views = [] for sfv in self._registry.list_stream_feature_views( - self.project, allow_cache=allow_cache, tags=tags + self.project, allow_cache=allow_cache, tags=tags ): if hide_dummy_entity and sfv.entities[0] == DUMMY_ENTITY_NAME: sfv.entities = [] @@ -348,7 +359,9 @@ def _list_stream_feature_views( return stream_feature_views def list_on_demand_feature_views( - self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None ) -> List[OnDemandFeatureView]: """ Retrieves the list of on demand feature views from the registry. @@ -365,7 +378,9 @@ def list_on_demand_feature_views( ) def list_stream_feature_views( - self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None ) -> List[StreamFeatureView]: """ Retrieves the list of stream feature views from the registry. @@ -376,7 +391,9 @@ def list_stream_feature_views( return self._list_stream_feature_views(allow_cache, tags=tags) def list_data_sources( - self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None ) -> List[DataSource]: """ Retrieves the list of data sources from the registry. @@ -392,7 +409,9 @@ def list_data_sources( self.project, allow_cache=allow_cache, tags=tags ) - def get_entity(self, name: str, allow_registry_cache: bool = False) -> Entity: + def get_entity(self, + name: str, + allow_registry_cache: bool = False) -> Entity: """ Retrieves an entity. @@ -411,7 +430,9 @@ def get_entity(self, name: str, allow_registry_cache: bool = False) -> Entity: ) def get_feature_service( - self, name: str, allow_cache: bool = False + self, + name: str, + allow_cache: bool = False ) -> FeatureService: """ Retrieves a feature service. @@ -429,7 +450,9 @@ def get_feature_service( return self._registry.get_feature_service(name, self.project, allow_cache) def get_feature_view( - self, name: str, allow_registry_cache: bool = False + self, + name: str, + allow_registry_cache: bool = False ) -> FeatureView: """ Retrieves a feature view. 
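
Every listing API reindented above keeps the same two knobs: `allow_cache` to serve from the cached registry, and an optional `tags` filter. A hypothetical call, with the repo path and tag values as assumptions:

    from feast import FeatureStore

    store = FeatureStore(repo_path=".")  # hypothetical local repo

    # Read from the registry cache and keep only views tagged for this team.
    team_views = store.list_feature_views(
        allow_cache=True,
        tags={"team": "search"},  # hypothetical tag
    )
    print([fv.name for fv in team_views])
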
@@ -447,10 +470,10 @@ def get_feature_view( return self._get_feature_view(name, allow_registry_cache=allow_registry_cache) def _get_feature_view( - self, - name: str, - hide_dummy_entity: bool = True, - allow_registry_cache: bool = False, + self, + name: str, + hide_dummy_entity: bool = True, + allow_registry_cache: bool = False, ) -> FeatureView: feature_view = self._registry.get_feature_view( name, self.project, allow_cache=allow_registry_cache @@ -460,7 +483,9 @@ def _get_feature_view( return feature_view def get_stream_feature_view( - self, name: str, allow_registry_cache: bool = False + self, + name: str, + allow_registry_cache: bool = False ) -> StreamFeatureView: """ Retrieves a stream feature view. @@ -480,10 +505,10 @@ def get_stream_feature_view( ) def _get_stream_feature_view( - self, - name: str, - hide_dummy_entity: bool = True, - allow_registry_cache: bool = False, + self, + name: str, + hide_dummy_entity: bool = True, + allow_registry_cache: bool = False, ) -> StreamFeatureView: stream_feature_view = self._registry.get_stream_feature_view( name, self.project, allow_cache=allow_registry_cache @@ -492,7 +517,8 @@ def _get_stream_feature_view( stream_feature_view.entities = [] return stream_feature_view - def get_on_demand_feature_view(self, name: str) -> OnDemandFeatureView: + def get_on_demand_feature_view(self, + name: str) -> OnDemandFeatureView: """ Retrieves a feature view. @@ -507,7 +533,8 @@ def get_on_demand_feature_view(self, name: str) -> OnDemandFeatureView: """ return self._registry.get_on_demand_feature_view(name, self.project) - def get_data_source(self, name: str) -> DataSource: + def get_data_source(self, + name: str) -> DataSource: """ Retrieves the list of data sources from the registry. @@ -522,7 +549,8 @@ def get_data_source(self, name: str) -> DataSource: """ return self._registry.get_data_source(name, self.project) - def delete_feature_view(self, name: str): + def delete_feature_view(self, + name: str): """ Deletes a feature view. @@ -534,7 +562,8 @@ def delete_feature_view(self, name: str): """ return self._registry.delete_feature_view(name, self.project) - def delete_feature_service(self, name: str): + def delete_feature_service(self, + name: str): """ Deletes a feature service. @@ -550,14 +579,14 @@ def _should_use_plan(self): """Returns True if plan and _apply_diffs should be used, False otherwise.""" # Currently only the local provider with sqlite online store supports plan and _apply_diffs. 
return self.config.provider == "local" and ( - self.config.online_store and self.config.online_store.type == "sqlite" + self.config.online_store and self.config.online_store.type == "sqlite" ) def _validate_all_feature_views( - self, - views_to_update: List[FeatureView], - odfvs_to_update: List[OnDemandFeatureView], - sfvs_to_update: List[StreamFeatureView], + self, + views_to_update: List[FeatureView], + odfvs_to_update: List[OnDemandFeatureView], + sfvs_to_update: List[StreamFeatureView], ): """Validates all feature views.""" if len(odfvs_to_update) > 0 and not flags_helper.is_test(): @@ -575,13 +604,13 @@ def _validate_all_feature_views( ) def _make_inferences( - self, - data_sources_to_update: List[DataSource], - entities_to_update: List[Entity], - views_to_update: List[FeatureView], - odfvs_to_update: List[OnDemandFeatureView], - sfvs_to_update: List[StreamFeatureView], - feature_services_to_update: List[FeatureService], + self, + data_sources_to_update: List[DataSource], + entities_to_update: List[Entity], + views_to_update: List[FeatureView], + odfvs_to_update: List[OnDemandFeatureView], + sfvs_to_update: List[StreamFeatureView], + feature_services_to_update: List[FeatureService], ): """Makes inferences for entities, feature views, odfvs, and feature services.""" update_data_sources_with_inferred_event_timestamp_col( @@ -621,8 +650,8 @@ def _make_inferences( feature_service.infer_features(fvs_to_update=fvs_to_update_map) def _get_feature_views_to_materialize( - self, - feature_views: Optional[List[str]], + self, + feature_views: Optional[List[str]], ) -> List[FeatureView]: """ Returns the list of feature views that should be materialized. @@ -669,7 +698,8 @@ def _get_feature_views_to_materialize( return feature_views_to_materialize def plan( - self, desired_repo_contents: RepoContents + self, + desired_repo_contents: RepoContents ) -> Tuple[RegistryDiff, InfraDiff, Infra]: """Dry-run registering objects to metadata store. @@ -747,7 +777,10 @@ def plan( return registry_diff, infra_diff, new_infra def _apply_diffs( - self, registry_diff: RegistryDiff, infra_diff: InfraDiff, new_infra: Infra + self, + registry_diff: RegistryDiff, + infra_diff: InfraDiff, + new_infra: Infra ): """Applies the given diffs to the metadata store and infrastructure. @@ -764,22 +797,22 @@ def _apply_diffs( self._registry.update_infra(new_infra, self.project, commit=True) def apply( - self, - objects: Union[ - Project, - DataSource, - Entity, - FeatureView, - OnDemandFeatureView, - BatchFeatureView, - StreamFeatureView, - FeatureService, - ValidationReference, - Permission, - List[FeastObject], - ], - objects_to_delete: Optional[List[FeastObject]] = None, - partial: bool = True, + self, + objects: Union[ + Project, + DataSource, + Entity, + FeatureView, + OnDemandFeatureView, + BatchFeatureView, + StreamFeatureView, + FeatureService, + ValidationReference, + Permission, + List[FeastObject], + ], + objects_to_delete: Optional[List[FeastObject]] = None, + partial: bool = True, ): """Register objects to metadata store and update related infrastructure. @@ -837,8 +870,8 @@ def apply( if ( # BFVs are not handled separately from FVs right now. 
- (isinstance(ob, FeatureView) or isinstance(ob, BatchFeatureView)) - and not isinstance(ob, StreamFeatureView) + (isinstance(ob, FeatureView) or isinstance(ob, BatchFeatureView)) + and not isinstance(ob, StreamFeatureView) ) ] sfvs_to_update = [ob for ob in objects if isinstance(ob, StreamFeatureView)] @@ -855,9 +888,9 @@ def apply( batch_sources_to_add: List[DataSource] = [] for data_source in data_sources_set_to_update: if ( - isinstance(data_source, PushSource) - or isinstance(data_source, KafkaSource) - or isinstance(data_source, KinesisSource) + isinstance(data_source, PushSource) + or isinstance(data_source, KafkaSource) + or isinstance(data_source, KinesisSource) ): assert data_source.batch_source batch_sources_to_add.append(data_source.batch_source) @@ -926,8 +959,8 @@ def apply( ob for ob in objects_to_delete if ( - (isinstance(ob, FeatureView) or isinstance(ob, BatchFeatureView)) - and not isinstance(ob, StreamFeatureView) + (isinstance(ob, FeatureView) or isinstance(ob, BatchFeatureView)) + and not isinstance(ob, StreamFeatureView) ) ] odfvs_to_delete = [ @@ -1011,10 +1044,10 @@ def teardown(self): self._registry.teardown() def get_historical_features( - self, - entity_df: Union[pd.DataFrame, str], - features: Union[List[str], FeatureService], - full_feature_names: bool = False, + self, + entity_df: Union[pd.DataFrame, str], + features: Union[List[str], FeatureService], + full_feature_names: bool = False, ) -> RetrievalJob: """Enrich an entity dataframe with historical feature values for either training or batch scoring. @@ -1118,13 +1151,13 @@ def get_historical_features( return job def create_saved_dataset( - self, - from_: RetrievalJob, - name: str, - storage: SavedDatasetStorage, - tags: Optional[Dict[str, str]] = None, - feature_service: Optional[FeatureService] = None, - allow_overwrite: bool = False, + self, + from_: RetrievalJob, + name: str, + storage: SavedDatasetStorage, + tags: Optional[Dict[str, str]] = None, + feature_service: Optional[FeatureService] = None, + allow_overwrite: bool = False, ) -> SavedDataset: """ Execute provided retrieval job and persist its outcome in given storage. @@ -1184,7 +1217,8 @@ def create_saved_dataset( self._registry.apply_saved_dataset(dataset, self.project, commit=True) return dataset - def get_saved_dataset(self, name: str) -> SavedDataset: + def get_saved_dataset(self, + name: str) -> SavedDataset: """ Find a saved dataset in the registry by provided name and create a retrieval job to pull whole dataset from storage (offline store). @@ -1216,9 +1250,9 @@ def get_saved_dataset(self, name: str) -> SavedDataset: return dataset.with_retrieval_job(retrieval_job) def materialize_incremental( - self, - end_date: datetime, - feature_views: Optional[List[str]] = None, + self, + end_date: datetime, + feature_views: Optional[List[str]] = None, ) -> None: """ Materialize incremental new data from the offline store into the online store. @@ -1307,10 +1341,10 @@ def tqdm_builder(length): ) def materialize( - self, - start_date: datetime, - end_date: datetime, - feature_views: Optional[List[str]] = None, + self, + start_date: datetime, + end_date: datetime, + feature_views: Optional[List[str]] = None, ) -> None: """ Materialize data from the offline store into the online store. 
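
The `materialize_incremental` and `materialize` signatures above keep their call shape through the reindent. A usage sketch; the view name and the one-day backfill window are assumptions:

    from datetime import datetime, timedelta

    from feast import FeatureStore

    store = FeatureStore(repo_path=".")  # hypothetical local repo

    # Backfill a fixed window once, then top up incrementally afterwards.
    end = datetime.utcnow()
    store.materialize(
        start_date=end - timedelta(days=1),
        end_date=end,
        feature_views=["item_embeddings"],  # hypothetical feature view
    )
    store.materialize_incremental(end_date=end)
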
@@ -1381,11 +1415,11 @@ def tqdm_builder(length): ) def push( - self, - push_source_name: str, - df: pd.DataFrame, - allow_registry_cache: bool = True, - to: PushMode = PushMode.ONLINE, + self, + push_source_name: str, + df: pd.DataFrame, + allow_registry_cache: bool = True, + to: PushMode = PushMode.ONLINE, ): """ Push features to a push source. This updates all the feature views that have the push source as stream source. @@ -1405,9 +1439,9 @@ def push( fv for fv in all_fvs if ( - fv.stream_source is not None - and isinstance(fv.stream_source, PushSource) - and fv.stream_source.name == push_source_name + fv.stream_source is not None + and isinstance(fv.stream_source, PushSource) + and fv.stream_source.name == push_source_name ) } @@ -1425,11 +1459,11 @@ def push( ) def write_to_online_store( - self, - feature_view_name: str, - df: Optional[pd.DataFrame] = None, - inputs: Optional[Union[Dict[str, List[Any]], pd.DataFrame]] = None, - allow_registry_cache: bool = True, + self, + feature_view_name: str, + df: Optional[pd.DataFrame] = None, + inputs: Optional[Union[Dict[str, List[Any]], pd.DataFrame]] = None, + allow_registry_cache: bool = True, ): """ Persists a dataframe to the online store. @@ -1465,11 +1499,11 @@ def write_to_online_store( provider.ingest_df(feature_view, df) def write_to_offline_store( - self, - feature_view_name: str, - df: pd.DataFrame, - allow_registry_cache: bool = True, - reorder_columns: bool = True, + self, + feature_view_name: str, + df: pd.DataFrame, + allow_registry_cache: bool = True, + reorder_columns: bool = True, ): """ Persists the dataframe directly into the batch data source for the given feature view. @@ -1507,13 +1541,13 @@ def write_to_offline_store( provider.ingest_df_to_offline_store(feature_view, table) def get_online_features( - self, - features: Union[List[str], FeatureService], - entity_rows: Union[ - List[Dict[str, Any]], - Mapping[str, Union[Sequence[Any], Sequence[Value], RepeatedValue]], - ], - full_feature_names: bool = False, + self, + features: Union[List[str], FeatureService], + entity_rows: Union[ + List[Dict[str, Any]], + Mapping[str, Union[Sequence[Any], Sequence[Value], RepeatedValue]], + ], + full_feature_names: bool = False, ) -> OnlineResponse: """ Retrieves the latest online feature data. @@ -1568,13 +1602,13 @@ def get_online_features( ) async def get_online_features_async( - self, - features: Union[List[str], FeatureService], - entity_rows: Union[ - List[Dict[str, Any]], - Mapping[str, Union[Sequence[Any], Sequence[Value], RepeatedValue]], - ], - full_feature_names: bool = False, + self, + features: Union[List[str], FeatureService], + entity_rows: Union[ + List[Dict[str, Any]], + Mapping[str, Union[Sequence[Any], Sequence[Value], RepeatedValue]], + ], + full_feature_names: bool = False, ) -> OnlineResponse: """ [Alpha] Retrieves the latest online feature data asynchronously. @@ -1614,11 +1648,11 @@ async def get_online_features_async( ) def retrieve_online_documents( - self, - feature: str, - query: Union[str, List[float]], - top_k: int, - distance_metric: Optional[str] = None, + self, + feature: str, + query: Union[str, List[float]], + top_k: int, + distance_metric: Optional[str] = None, ) -> OnlineResponse: """ Retrieves the top k closest document features. Note, embeddings are a subset of features. 
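
The `retrieve_online_documents` hunk below guards the join-key pivot against rows with no entity key. The pivot itself, turning row-wise EntityKey protos into the columnar layout the response builder expects, as a standalone sketch:

    from typing import Dict, List

    from feast.protos.feast.types.EntityKey_pb2 import EntityKey
    from feast.protos.feast.types.Value_pb2 import Value as ValueProto

    def pivot_entity_keys(entity_keys: List[EntityKey]) -> Dict[str, List[ValueProto]]:
        # join_keys and entity_values are parallel repeated fields on the proto.
        columns: Dict[str, List[ValueProto]] = {}
        for ek in entity_keys:
            for join_key, value in zip(ek.join_keys, ek.entity_values):
                columns.setdefault(join_key, []).append(value)
        return columns

    rows = [
        EntityKey(join_keys=["item_id"], entity_values=[ValueProto(int64_val=1)]),
        EntityKey(join_keys=["item_id"], entity_values=[ValueProto(int64_val=7)]),
    ]
    assert [v.int64_val for v in pivot_entity_keys(rows)["item_id"]] == [1, 7]

`setdefault` stands in for the explicit membership check in the hunk; the behavior is the same.
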
@@ -1673,12 +1707,13 @@ def retrieve_online_documents( entity_key_vals = [feature[1] for feature in document_features] join_key_values: Dict[str, List[ValueProto]] = {} for entity_key_val in entity_key_vals: - for join_key, entity_value in zip( - entity_key_val.join_keys, entity_key_val.entity_values - ): - if join_key not in join_key_values: - join_key_values[join_key] = [] - join_key_values[join_key].append(entity_value) + if entity_key_val is not None: + for join_key, entity_value in zip( + entity_key_val.join_keys, entity_key_val.entity_values + ): + if join_key not in join_key_values: + join_key_values[join_key] = [] + join_key_values[join_key].append(entity_value) document_feature_vals = [feature[4] for feature in document_features] document_feature_distance_vals = [feature[5] for feature in document_features] @@ -1694,15 +1729,17 @@ def retrieve_online_documents( return OnlineResponse(online_features_response) def _retrieve_from_online_store( - self, - provider: Provider, - table: FeatureView, - requested_feature: str, - query: List[float], - top_k: int, - distance_metric: Optional[str], + self, + provider: Provider, + table: FeatureView, + requested_feature: str, + query: List[float], + top_k: int, + distance_metric: Optional[str], ) -> List[ - Tuple[Timestamp, Optional[EntityKey], "FieldStatus.ValueType", Value, Value, Value] + Tuple[ + Timestamp, Optional[EntityKey], "FieldStatus.ValueType", Value, Value, Value + ] ]: """ Search and return document features from the online document store. @@ -1745,15 +1782,15 @@ def _retrieve_from_online_store( return read_row_protos def serve( - self, - host: str, - port: int, - type_: str = "http", - no_access_log: bool = True, - workers: int = 1, - metrics: bool = False, - keep_alive_timeout: int = 30, - registry_ttl_sec: int = 2, + self, + host: str, + port: int, + type_: str = "http", + no_access_log: bool = True, + workers: int = 1, + metrics: bool = False, + keep_alive_timeout: int = 30, + registry_ttl_sec: int = 2, ) -> None: """Start the feature consumption server locally on a given port.""" type_ = type_.lower() @@ -1778,12 +1815,12 @@ def get_feature_server_endpoint(self) -> Optional[str]: return self._provider.get_feature_server_endpoint() def serve_ui( - self, - host: str, - port: int, - get_registry_dump: Callable, - registry_ttl_sec: int, - root_path: str = "", + self, + host: str, + port: int, + get_registry_dump: Callable, + registry_ttl_sec: int, + root_path: str = "", ) -> None: """Start the UI server locally""" if flags_helper.is_test(): @@ -1802,19 +1839,23 @@ def serve_ui( root_path=root_path, ) - def serve_registry(self, port: int) -> None: + def serve_registry(self, + port: int) -> None: """Start registry server locally on a given port.""" from feast import registry_server registry_server.start_server(self, port) - def serve_offline(self, host: str, port: int) -> None: + def serve_offline(self, + host: str, + port: int) -> None: """Start offline server locally on a given port.""" from feast import offline_server offline_server.start_server(self, host, port) - def serve_transformations(self, port: int) -> None: + def serve_transformations(self, + port: int) -> None: """Start the feature transformation server locally on a given port.""" warnings.warn( "On demand feature view is an experimental feature. 
" @@ -1827,7 +1868,9 @@ def serve_transformations(self, port: int) -> None: transformation_server.start_server(self, port) def write_logged_features( - self, logs: Union[pa.Table, Path], source: FeatureService + self, + logs: Union[pa.Table, Path], + source: FeatureService ): """ Write logs produced by a source (currently only feature service is supported as a source) @@ -1841,7 +1884,7 @@ def write_logged_features( raise ValueError("Only feature service is currently supported as a source") assert ( - source.logging_config is not None + source.logging_config is not None ), "Feature service must be configured with logging config in order to use this functionality" assert isinstance(logs, (pa.Table, Path)) @@ -1854,13 +1897,13 @@ def write_logged_features( ) def validate_logged_features( - self, - source: FeatureService, - start: datetime, - end: datetime, - reference: ValidationReference, - throw_exception: bool = True, - cache_profile: bool = True, + self, + source: FeatureService, + start: datetime, + end: datetime, + reference: ValidationReference, + throw_exception: bool = True, + cache_profile: bool = True, ) -> Optional[ValidationFailed]: """ Load logged features from an offline store and validate them against provided validation reference. @@ -1914,7 +1957,9 @@ def validate_logged_features( return None def get_validation_reference( - self, name: str, allow_cache: bool = False + self, + name: str, + allow_cache: bool = False ) -> ValidationReference: """ Retrieves a validation reference. @@ -1929,7 +1974,9 @@ def get_validation_reference( return ref def list_validation_references( - self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None ) -> List[ValidationReference]: """ Retrieves the list of validation references from the registry. @@ -1946,7 +1993,9 @@ def list_validation_references( ) def list_permissions( - self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None ) -> List[Permission]: """ Retrieves the list of permissions from the registry. @@ -1962,7 +2011,8 @@ def list_permissions( self.project, allow_cache=allow_cache, tags=tags ) - def get_permission(self, name: str) -> Permission: + def get_permission(self, + name: str) -> Permission: """ Retrieves a permission from the registry. @@ -1978,7 +2028,9 @@ def get_permission(self, name: str) -> Permission: return self._registry.get_permission(name, self.project) def list_projects( - self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None ) -> List[Project]: """ Retrieves the list of projects from the registry. @@ -1992,7 +2044,8 @@ def list_projects( """ return self._registry.list_projects(allow_cache=allow_cache, tags=tags) - def get_project(self, name: Optional[str]) -> Project: + def get_project(self, + name: Optional[str]) -> Project: """ Retrieves a project from the registry. @@ -2008,7 +2061,9 @@ def get_project(self, name: Optional[str]) -> Project: return self._registry.get_project(name or self.project) def list_saved_datasets( - self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None ) -> List[SavedDataset]: """ Retrieves the list of saved datasets from the registry. 
@@ -2026,7 +2081,10 @@ def list_saved_datasets( def _print_materialization_log( - start_date, end_date, num_feature_views: int, online_store: str + start_date, + end_date, + num_feature_views: int, + online_store: str ): if start_date: print( From 6b9255d62fda213c9d80267b4d0a5e40e3ba1f26 Mon Sep 17 00:00:00 2001 From: cmuhao Date: Tue, 10 Sep 2024 15:07:33 -0700 Subject: [PATCH 19/26] fix lint Signed-off-by: cmuhao --- sdk/python/feast/feature_store.py | 431 +++++++++++++----------------- 1 file changed, 188 insertions(+), 243 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 857be25ef2..ab2bc6cec2 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -108,10 +108,10 @@ class FeatureStore: _provider: Provider def __init__( - self, - repo_path: Optional[str] = None, - config: Optional[RepoConfig] = None, - fs_yaml_file: Optional[Path] = None, + self, + repo_path: Optional[str] = None, + config: Optional[RepoConfig] = None, + fs_yaml_file: Optional[Path] = None, ): """ Creates a FeatureStore object. @@ -205,9 +205,7 @@ def refresh_registry(self): self._registry.refresh(self.project) def list_entities( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[Entity]: """ Retrieves the list of entities from the registry. @@ -222,10 +220,10 @@ def list_entities( return self._list_entities(allow_cache, tags=tags) def _list_entities( - self, - allow_cache: bool = False, - hide_dummy_entity: bool = True, - tags: Optional[dict[str, str]] = None, + self, + allow_cache: bool = False, + hide_dummy_entity: bool = True, + tags: Optional[dict[str, str]] = None, ) -> List[Entity]: all_entities = self._registry.list_entities( self.project, allow_cache=allow_cache, tags=tags @@ -237,8 +235,7 @@ def _list_entities( ] def list_feature_services( - self, - tags: Optional[dict[str, str]] = None + self, tags: Optional[dict[str, str]] = None ) -> List[FeatureService]: """ Retrieves the list of feature services from the registry. @@ -252,18 +249,16 @@ def list_feature_services( return self._registry.list_feature_services(self.project, tags=tags) def _list_all_feature_views( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[BaseFeatureView]: feature_views = [] for fv in self.registry.list_all_feature_views( - self.project, allow_cache=allow_cache, tags=tags + self.project, allow_cache=allow_cache, tags=tags ): if ( - isinstance(fv, FeatureView) - and fv.entities - and fv.entities[0] == DUMMY_ENTITY_NAME + isinstance(fv, FeatureView) + and fv.entities + and fv.entities[0] == DUMMY_ENTITY_NAME ): fv.entities = [] fv.entity_columns = [] @@ -271,9 +266,7 @@ def _list_all_feature_views( return feature_views def list_all_feature_views( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[BaseFeatureView]: """ Retrieves the list of feature views from the registry. @@ -287,9 +280,7 @@ def list_all_feature_views( return self._list_all_feature_views(allow_cache, tags=tags) def list_feature_views( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[FeatureView]: """ Retrieves the list of feature views from the registry. 
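
Elsewhere in the series, `_build_retrieve_online_document_record` only deserializes entity keys at serialization version 3 or later, and patch 20 below bumps the integration-test default to match. A sketch of the version-3 round trip, with the key contents assumed:

    from feast.infra.key_encoding_utils import (
        deserialize_entity_key,
        serialize_entity_key,
    )
    from feast.protos.feast.types.EntityKey_pb2 import EntityKey
    from feast.protos.feast.types.Value_pb2 import Value as ValueProto

    key = EntityKey(join_keys=["item_id"], entity_values=[ValueProto(int64_val=3)])

    # Version 3 keys carry enough type information to be decoded again;
    # older layouts are skipped by the record builder.
    blob = serialize_entity_key(key, entity_key_serialization_version=3)
    restored = deserialize_entity_key(blob, entity_key_serialization_version=3)
    assert list(restored.join_keys) == ["item_id"]
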
@@ -306,9 +297,7 @@ def list_feature_views( ) def list_batch_feature_views( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[FeatureView]: """ Retrieves the list of feature views from the registry. @@ -323,19 +312,19 @@ def list_batch_feature_views( return self._list_batch_feature_views(allow_cache=allow_cache, tags=tags) def _list_batch_feature_views( - self, - allow_cache: bool = False, - hide_dummy_entity: bool = True, - tags: Optional[dict[str, str]] = None, + self, + allow_cache: bool = False, + hide_dummy_entity: bool = True, + tags: Optional[dict[str, str]] = None, ) -> List[FeatureView]: feature_views = [] for fv in self._registry.list_feature_views( - self.project, allow_cache=allow_cache, tags=tags + self.project, allow_cache=allow_cache, tags=tags ): if ( - hide_dummy_entity - and fv.entities - and fv.entities[0] == DUMMY_ENTITY_NAME + hide_dummy_entity + and fv.entities + and fv.entities[0] == DUMMY_ENTITY_NAME ): fv.entities = [] fv.entity_columns = [] @@ -343,14 +332,14 @@ def _list_batch_feature_views( return feature_views def _list_stream_feature_views( - self, - allow_cache: bool = False, - hide_dummy_entity: bool = True, - tags: Optional[dict[str, str]] = None, + self, + allow_cache: bool = False, + hide_dummy_entity: bool = True, + tags: Optional[dict[str, str]] = None, ) -> List[StreamFeatureView]: stream_feature_views = [] for sfv in self._registry.list_stream_feature_views( - self.project, allow_cache=allow_cache, tags=tags + self.project, allow_cache=allow_cache, tags=tags ): if hide_dummy_entity and sfv.entities[0] == DUMMY_ENTITY_NAME: sfv.entities = [] @@ -359,9 +348,7 @@ def _list_stream_feature_views( return stream_feature_views def list_on_demand_feature_views( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[OnDemandFeatureView]: """ Retrieves the list of on demand feature views from the registry. @@ -378,9 +365,7 @@ def list_on_demand_feature_views( ) def list_stream_feature_views( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[StreamFeatureView]: """ Retrieves the list of stream feature views from the registry. @@ -391,9 +376,7 @@ def list_stream_feature_views( return self._list_stream_feature_views(allow_cache, tags=tags) def list_data_sources( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[DataSource]: """ Retrieves the list of data sources from the registry. @@ -409,9 +392,7 @@ def list_data_sources( self.project, allow_cache=allow_cache, tags=tags ) - def get_entity(self, - name: str, - allow_registry_cache: bool = False) -> Entity: + def get_entity(self, name: str, allow_registry_cache: bool = False) -> Entity: """ Retrieves an entity. @@ -430,9 +411,7 @@ def get_entity(self, ) def get_feature_service( - self, - name: str, - allow_cache: bool = False + self, name: str, allow_cache: bool = False ) -> FeatureService: """ Retrieves a feature service. 
@@ -450,9 +429,7 @@ def get_feature_service( return self._registry.get_feature_service(name, self.project, allow_cache) def get_feature_view( - self, - name: str, - allow_registry_cache: bool = False + self, name: str, allow_registry_cache: bool = False ) -> FeatureView: """ Retrieves a feature view. @@ -470,10 +447,10 @@ def get_feature_view( return self._get_feature_view(name, allow_registry_cache=allow_registry_cache) def _get_feature_view( - self, - name: str, - hide_dummy_entity: bool = True, - allow_registry_cache: bool = False, + self, + name: str, + hide_dummy_entity: bool = True, + allow_registry_cache: bool = False, ) -> FeatureView: feature_view = self._registry.get_feature_view( name, self.project, allow_cache=allow_registry_cache @@ -483,9 +460,7 @@ def _get_feature_view( return feature_view def get_stream_feature_view( - self, - name: str, - allow_registry_cache: bool = False + self, name: str, allow_registry_cache: bool = False ) -> StreamFeatureView: """ Retrieves a stream feature view. @@ -505,10 +480,10 @@ def get_stream_feature_view( ) def _get_stream_feature_view( - self, - name: str, - hide_dummy_entity: bool = True, - allow_registry_cache: bool = False, + self, + name: str, + hide_dummy_entity: bool = True, + allow_registry_cache: bool = False, ) -> StreamFeatureView: stream_feature_view = self._registry.get_stream_feature_view( name, self.project, allow_cache=allow_registry_cache @@ -517,8 +492,7 @@ def _get_stream_feature_view( stream_feature_view.entities = [] return stream_feature_view - def get_on_demand_feature_view(self, - name: str) -> OnDemandFeatureView: + def get_on_demand_feature_view(self, name: str) -> OnDemandFeatureView: """ Retrieves a feature view. @@ -533,8 +507,7 @@ def get_on_demand_feature_view(self, """ return self._registry.get_on_demand_feature_view(name, self.project) - def get_data_source(self, - name: str) -> DataSource: + def get_data_source(self, name: str) -> DataSource: """ Retrieves the list of data sources from the registry. @@ -549,8 +522,7 @@ def get_data_source(self, """ return self._registry.get_data_source(name, self.project) - def delete_feature_view(self, - name: str): + def delete_feature_view(self, name: str): """ Deletes a feature view. @@ -562,8 +534,7 @@ def delete_feature_view(self, """ return self._registry.delete_feature_view(name, self.project) - def delete_feature_service(self, - name: str): + def delete_feature_service(self, name: str): """ Deletes a feature service. @@ -579,14 +550,14 @@ def _should_use_plan(self): """Returns True if plan and _apply_diffs should be used, False otherwise.""" # Currently only the local provider with sqlite online store supports plan and _apply_diffs. 
return self.config.provider == "local" and ( - self.config.online_store and self.config.online_store.type == "sqlite" + self.config.online_store and self.config.online_store.type == "sqlite" ) def _validate_all_feature_views( - self, - views_to_update: List[FeatureView], - odfvs_to_update: List[OnDemandFeatureView], - sfvs_to_update: List[StreamFeatureView], + self, + views_to_update: List[FeatureView], + odfvs_to_update: List[OnDemandFeatureView], + sfvs_to_update: List[StreamFeatureView], ): """Validates all feature views.""" if len(odfvs_to_update) > 0 and not flags_helper.is_test(): @@ -604,13 +575,13 @@ def _validate_all_feature_views( ) def _make_inferences( - self, - data_sources_to_update: List[DataSource], - entities_to_update: List[Entity], - views_to_update: List[FeatureView], - odfvs_to_update: List[OnDemandFeatureView], - sfvs_to_update: List[StreamFeatureView], - feature_services_to_update: List[FeatureService], + self, + data_sources_to_update: List[DataSource], + entities_to_update: List[Entity], + views_to_update: List[FeatureView], + odfvs_to_update: List[OnDemandFeatureView], + sfvs_to_update: List[StreamFeatureView], + feature_services_to_update: List[FeatureService], ): """Makes inferences for entities, feature views, odfvs, and feature services.""" update_data_sources_with_inferred_event_timestamp_col( @@ -650,8 +621,8 @@ def _make_inferences( feature_service.infer_features(fvs_to_update=fvs_to_update_map) def _get_feature_views_to_materialize( - self, - feature_views: Optional[List[str]], + self, + feature_views: Optional[List[str]], ) -> List[FeatureView]: """ Returns the list of feature views that should be materialized. @@ -698,8 +669,7 @@ def _get_feature_views_to_materialize( return feature_views_to_materialize def plan( - self, - desired_repo_contents: RepoContents + self, desired_repo_contents: RepoContents ) -> Tuple[RegistryDiff, InfraDiff, Infra]: """Dry-run registering objects to metadata store. @@ -777,10 +747,7 @@ def plan( return registry_diff, infra_diff, new_infra def _apply_diffs( - self, - registry_diff: RegistryDiff, - infra_diff: InfraDiff, - new_infra: Infra + self, registry_diff: RegistryDiff, infra_diff: InfraDiff, new_infra: Infra ): """Applies the given diffs to the metadata store and infrastructure. @@ -797,22 +764,22 @@ def _apply_diffs( self._registry.update_infra(new_infra, self.project, commit=True) def apply( - self, - objects: Union[ - Project, - DataSource, - Entity, - FeatureView, - OnDemandFeatureView, - BatchFeatureView, - StreamFeatureView, - FeatureService, - ValidationReference, - Permission, - List[FeastObject], - ], - objects_to_delete: Optional[List[FeastObject]] = None, - partial: bool = True, + self, + objects: Union[ + Project, + DataSource, + Entity, + FeatureView, + OnDemandFeatureView, + BatchFeatureView, + StreamFeatureView, + FeatureService, + ValidationReference, + Permission, + List[FeastObject], + ], + objects_to_delete: Optional[List[FeastObject]] = None, + partial: bool = True, ): """Register objects to metadata store and update related infrastructure. @@ -870,8 +837,8 @@ def apply( if ( # BFVs are not handled separately from FVs right now. 
- (isinstance(ob, FeatureView) or isinstance(ob, BatchFeatureView)) - and not isinstance(ob, StreamFeatureView) + (isinstance(ob, FeatureView) or isinstance(ob, BatchFeatureView)) + and not isinstance(ob, StreamFeatureView) ) ] sfvs_to_update = [ob for ob in objects if isinstance(ob, StreamFeatureView)] @@ -888,9 +855,9 @@ def apply( batch_sources_to_add: List[DataSource] = [] for data_source in data_sources_set_to_update: if ( - isinstance(data_source, PushSource) - or isinstance(data_source, KafkaSource) - or isinstance(data_source, KinesisSource) + isinstance(data_source, PushSource) + or isinstance(data_source, KafkaSource) + or isinstance(data_source, KinesisSource) ): assert data_source.batch_source batch_sources_to_add.append(data_source.batch_source) @@ -959,8 +926,8 @@ def apply( ob for ob in objects_to_delete if ( - (isinstance(ob, FeatureView) or isinstance(ob, BatchFeatureView)) - and not isinstance(ob, StreamFeatureView) + (isinstance(ob, FeatureView) or isinstance(ob, BatchFeatureView)) + and not isinstance(ob, StreamFeatureView) ) ] odfvs_to_delete = [ @@ -1044,10 +1011,10 @@ def teardown(self): self._registry.teardown() def get_historical_features( - self, - entity_df: Union[pd.DataFrame, str], - features: Union[List[str], FeatureService], - full_feature_names: bool = False, + self, + entity_df: Union[pd.DataFrame, str], + features: Union[List[str], FeatureService], + full_feature_names: bool = False, ) -> RetrievalJob: """Enrich an entity dataframe with historical feature values for either training or batch scoring. @@ -1151,13 +1118,13 @@ def get_historical_features( return job def create_saved_dataset( - self, - from_: RetrievalJob, - name: str, - storage: SavedDatasetStorage, - tags: Optional[Dict[str, str]] = None, - feature_service: Optional[FeatureService] = None, - allow_overwrite: bool = False, + self, + from_: RetrievalJob, + name: str, + storage: SavedDatasetStorage, + tags: Optional[Dict[str, str]] = None, + feature_service: Optional[FeatureService] = None, + allow_overwrite: bool = False, ) -> SavedDataset: """ Execute provided retrieval job and persist its outcome in given storage. @@ -1217,8 +1184,7 @@ def create_saved_dataset( self._registry.apply_saved_dataset(dataset, self.project, commit=True) return dataset - def get_saved_dataset(self, - name: str) -> SavedDataset: + def get_saved_dataset(self, name: str) -> SavedDataset: """ Find a saved dataset in the registry by provided name and create a retrieval job to pull whole dataset from storage (offline store). @@ -1250,9 +1216,9 @@ def get_saved_dataset(self, return dataset.with_retrieval_job(retrieval_job) def materialize_incremental( - self, - end_date: datetime, - feature_views: Optional[List[str]] = None, + self, + end_date: datetime, + feature_views: Optional[List[str]] = None, ) -> None: """ Materialize incremental new data from the offline store into the online store. @@ -1341,10 +1307,10 @@ def tqdm_builder(length): ) def materialize( - self, - start_date: datetime, - end_date: datetime, - feature_views: Optional[List[str]] = None, + self, + start_date: datetime, + end_date: datetime, + feature_views: Optional[List[str]] = None, ) -> None: """ Materialize data from the offline store into the online store. 
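
`get_historical_features` above accepts either a DataFrame or a SQL string as the entity frame. A minimal point-in-time join, with the feature reference and timestamps as assumptions:

    from datetime import datetime

    import pandas as pd

    from feast import FeatureStore

    store = FeatureStore(repo_path=".")  # hypothetical local repo

    # One row per entity and event time to enrich with historical values.
    entity_df = pd.DataFrame(
        {
            "item_id": [1, 2],
            "event_timestamp": [datetime(2024, 9, 1), datetime(2024, 9, 2)],
        }
    )
    training_df = store.get_historical_features(
        entity_df=entity_df,
        features=["item_embeddings:embedding_float"],  # hypothetical reference
    ).to_df()
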
@@ -1415,11 +1381,11 @@ def tqdm_builder(length): ) def push( - self, - push_source_name: str, - df: pd.DataFrame, - allow_registry_cache: bool = True, - to: PushMode = PushMode.ONLINE, + self, + push_source_name: str, + df: pd.DataFrame, + allow_registry_cache: bool = True, + to: PushMode = PushMode.ONLINE, ): """ Push features to a push source. This updates all the feature views that have the push source as stream source. @@ -1439,9 +1405,9 @@ def push( fv for fv in all_fvs if ( - fv.stream_source is not None - and isinstance(fv.stream_source, PushSource) - and fv.stream_source.name == push_source_name + fv.stream_source is not None + and isinstance(fv.stream_source, PushSource) + and fv.stream_source.name == push_source_name ) } @@ -1459,11 +1425,11 @@ def push( ) def write_to_online_store( - self, - feature_view_name: str, - df: Optional[pd.DataFrame] = None, - inputs: Optional[Union[Dict[str, List[Any]], pd.DataFrame]] = None, - allow_registry_cache: bool = True, + self, + feature_view_name: str, + df: Optional[pd.DataFrame] = None, + inputs: Optional[Union[Dict[str, List[Any]], pd.DataFrame]] = None, + allow_registry_cache: bool = True, ): """ Persists a dataframe to the online store. @@ -1499,11 +1465,11 @@ def write_to_online_store( provider.ingest_df(feature_view, df) def write_to_offline_store( - self, - feature_view_name: str, - df: pd.DataFrame, - allow_registry_cache: bool = True, - reorder_columns: bool = True, + self, + feature_view_name: str, + df: pd.DataFrame, + allow_registry_cache: bool = True, + reorder_columns: bool = True, ): """ Persists the dataframe directly into the batch data source for the given feature view. @@ -1541,13 +1507,13 @@ def write_to_offline_store( provider.ingest_df_to_offline_store(feature_view, table) def get_online_features( - self, - features: Union[List[str], FeatureService], - entity_rows: Union[ - List[Dict[str, Any]], - Mapping[str, Union[Sequence[Any], Sequence[Value], RepeatedValue]], - ], - full_feature_names: bool = False, + self, + features: Union[List[str], FeatureService], + entity_rows: Union[ + List[Dict[str, Any]], + Mapping[str, Union[Sequence[Any], Sequence[Value], RepeatedValue]], + ], + full_feature_names: bool = False, ) -> OnlineResponse: """ Retrieves the latest online feature data. @@ -1602,13 +1568,13 @@ def get_online_features( ) async def get_online_features_async( - self, - features: Union[List[str], FeatureService], - entity_rows: Union[ - List[Dict[str, Any]], - Mapping[str, Union[Sequence[Any], Sequence[Value], RepeatedValue]], - ], - full_feature_names: bool = False, + self, + features: Union[List[str], FeatureService], + entity_rows: Union[ + List[Dict[str, Any]], + Mapping[str, Union[Sequence[Any], Sequence[Value], RepeatedValue]], + ], + full_feature_names: bool = False, ) -> OnlineResponse: """ [Alpha] Retrieves the latest online feature data asynchronously. @@ -1648,11 +1614,11 @@ async def get_online_features_async( ) def retrieve_online_documents( - self, - feature: str, - query: Union[str, List[float]], - top_k: int, - distance_metric: Optional[str] = None, + self, + feature: str, + query: Union[str, List[float]], + top_k: int, + distance_metric: Optional[str] = None, ) -> OnlineResponse: """ Retrieves the top k closest document features. Note, embeddings are a subset of features. 
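
With the join-key columns now populated, a document search returns entity ids alongside feature values and distances, which is what the integration test in patch 20 asserts. A hypothetical run against a local store:

    from feast import FeatureStore

    store = FeatureStore(repo_path=".")  # hypothetical local repo

    documents = store.retrieve_online_documents(
        feature="item_embeddings:embedding_float",
        query=[1.0, 2.0],  # must match the indexed embedding dimensionality
        top_k=2,
    ).to_dict()

    # Each hit carries its join key, feature value, and distance to the query.
    for item_id, distance in zip(documents["item_id"], documents["distance"]):
        print(item_id, distance)
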
@@ -1709,7 +1675,7 @@ def retrieve_online_documents( for entity_key_val in entity_key_vals: if entity_key_val is not None: for join_key, entity_value in zip( - entity_key_val.join_keys, entity_key_val.entity_values + entity_key_val.join_keys, entity_key_val.entity_values ): if join_key not in join_key_values: join_key_values[join_key] = [] @@ -1729,13 +1695,13 @@ def retrieve_online_documents( return OnlineResponse(online_features_response) def _retrieve_from_online_store( - self, - provider: Provider, - table: FeatureView, - requested_feature: str, - query: List[float], - top_k: int, - distance_metric: Optional[str], + self, + provider: Provider, + table: FeatureView, + requested_feature: str, + query: List[float], + top_k: int, + distance_metric: Optional[str], ) -> List[ Tuple[ Timestamp, Optional[EntityKey], "FieldStatus.ValueType", Value, Value, Value @@ -1782,15 +1748,15 @@ def _retrieve_from_online_store( return read_row_protos def serve( - self, - host: str, - port: int, - type_: str = "http", - no_access_log: bool = True, - workers: int = 1, - metrics: bool = False, - keep_alive_timeout: int = 30, - registry_ttl_sec: int = 2, + self, + host: str, + port: int, + type_: str = "http", + no_access_log: bool = True, + workers: int = 1, + metrics: bool = False, + keep_alive_timeout: int = 30, + registry_ttl_sec: int = 2, ) -> None: """Start the feature consumption server locally on a given port.""" type_ = type_.lower() @@ -1815,12 +1781,12 @@ def get_feature_server_endpoint(self) -> Optional[str]: return self._provider.get_feature_server_endpoint() def serve_ui( - self, - host: str, - port: int, - get_registry_dump: Callable, - registry_ttl_sec: int, - root_path: str = "", + self, + host: str, + port: int, + get_registry_dump: Callable, + registry_ttl_sec: int, + root_path: str = "", ) -> None: """Start the UI server locally""" if flags_helper.is_test(): @@ -1839,23 +1805,19 @@ def serve_ui( root_path=root_path, ) - def serve_registry(self, - port: int) -> None: + def serve_registry(self, port: int) -> None: """Start registry server locally on a given port.""" from feast import registry_server registry_server.start_server(self, port) - def serve_offline(self, - host: str, - port: int) -> None: + def serve_offline(self, host: str, port: int) -> None: """Start offline server locally on a given port.""" from feast import offline_server offline_server.start_server(self, host, port) - def serve_transformations(self, - port: int) -> None: + def serve_transformations(self, port: int) -> None: """Start the feature transformation server locally on a given port.""" warnings.warn( "On demand feature view is an experimental feature. 
" @@ -1868,9 +1830,7 @@ def serve_transformations(self, transformation_server.start_server(self, port) def write_logged_features( - self, - logs: Union[pa.Table, Path], - source: FeatureService + self, logs: Union[pa.Table, Path], source: FeatureService ): """ Write logs produced by a source (currently only feature service is supported as a source) @@ -1884,7 +1844,7 @@ def write_logged_features( raise ValueError("Only feature service is currently supported as a source") assert ( - source.logging_config is not None + source.logging_config is not None ), "Feature service must be configured with logging config in order to use this functionality" assert isinstance(logs, (pa.Table, Path)) @@ -1897,13 +1857,13 @@ def write_logged_features( ) def validate_logged_features( - self, - source: FeatureService, - start: datetime, - end: datetime, - reference: ValidationReference, - throw_exception: bool = True, - cache_profile: bool = True, + self, + source: FeatureService, + start: datetime, + end: datetime, + reference: ValidationReference, + throw_exception: bool = True, + cache_profile: bool = True, ) -> Optional[ValidationFailed]: """ Load logged features from an offline store and validate them against provided validation reference. @@ -1957,9 +1917,7 @@ def validate_logged_features( return None def get_validation_reference( - self, - name: str, - allow_cache: bool = False + self, name: str, allow_cache: bool = False ) -> ValidationReference: """ Retrieves a validation reference. @@ -1974,9 +1932,7 @@ def get_validation_reference( return ref def list_validation_references( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[ValidationReference]: """ Retrieves the list of validation references from the registry. @@ -1993,9 +1949,7 @@ def list_validation_references( ) def list_permissions( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[Permission]: """ Retrieves the list of permissions from the registry. @@ -2011,8 +1965,7 @@ def list_permissions( self.project, allow_cache=allow_cache, tags=tags ) - def get_permission(self, - name: str) -> Permission: + def get_permission(self, name: str) -> Permission: """ Retrieves a permission from the registry. @@ -2028,9 +1981,7 @@ def get_permission(self, return self._registry.get_permission(name, self.project) def list_projects( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[Project]: """ Retrieves the list of projects from the registry. @@ -2044,8 +1995,7 @@ def list_projects( """ return self._registry.list_projects(allow_cache=allow_cache, tags=tags) - def get_project(self, - name: Optional[str]) -> Project: + def get_project(self, name: Optional[str]) -> Project: """ Retrieves a project from the registry. @@ -2061,9 +2011,7 @@ def get_project(self, return self._registry.get_project(name or self.project) def list_saved_datasets( - self, - allow_cache: bool = False, - tags: Optional[dict[str, str]] = None + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[SavedDataset]: """ Retrieves the list of saved datasets from the registry. 
@@ -2081,10 +2029,7 @@ def list_saved_datasets(
 
 
 def _print_materialization_log(
-        start_date,
-        end_date,
-        num_feature_views: int,
-        online_store: str
+    start_date, end_date, num_feature_views: int, online_store: str
 ):
     if start_date:
         print(

From 23e54382d53b23037e89eb242d610519bb3b255a Mon Sep 17 00:00:00 2001
From: cmuhao
Date: Tue, 10 Sep 2024 23:17:34 -0700
Subject: [PATCH 20/26] fix test

Signed-off-by: cmuhao
---
 .../feast/infra/online_stores/contrib/postgres.py     | 11 ++++++-----
 sdk/python/feast/utils.py                             |  7 +++++--
 .../integration/feature_repos/repo_configuration.py   |  2 +-
 .../integration/online_store/test_universal_online.py |  3 +++
 4 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py
index 7652f77b8a..af2913b539 100644
--- a/sdk/python/feast/infra/online_stores/contrib/postgres.py
+++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py
@@ -16,10 +16,11 @@
     Union,
 )
 
+import numpy as np
 from psycopg import AsyncConnection, sql
 from psycopg.connection import Connection
 from psycopg_pool import AsyncConnectionPool, ConnectionPool
-from utils import _build_retrieve_online_document_record
+from feast.utils import _build_retrieve_online_document_record
 
 from feast import Entity
 from feast.feature_view import FeatureView
@@ -394,7 +395,7 @@ def retrieve_online_documents(
         distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric]
 
         # Convert the embedding to a string to be used in postgres vector search
-        query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"
+        query_embedding_array = np.array(embedding)
 
         result: List[
             Tuple[
@@ -418,19 +419,19 @@ def retrieve_online_documents(
                             feature_name,
                             value,
                             vector_value,
-                            vector_value {distance_metric_sql} %s as distance,
+                            vector_value {distance_metric_sql} %s::vector as distance,
                             event_ts FROM {table_name}
                         WHERE feature_name = {feature_name}
                         ORDER BY distance
                         LIMIT {top_k};
                     """
                 ).format(
-                    distance_metric_sql=distance_metric_sql,
+                    distance_metric_sql=sql.SQL(distance_metric_sql),
                     table_name=sql.Identifier(table_name),
                     feature_name=sql.Literal(requested_feature),
                     top_k=sql.Literal(top_k),
                 ),
-                (query_embedding_str,),
+                (embedding,),
             )
             rows = cur.fetchall()
             for (
diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py
index 71db2c1478..a6d7853e1b 100644
--- a/sdk/python/feast/utils.py
+++ b/sdk/python/feast/utils.py
@@ -1054,7 +1054,7 @@ def _utc_now() -> datetime:
 
 
 def _build_retrieve_online_document_record(
-    entity_key: str,
+    entity_key: Union[str, bytes],
     feature_value: Union[str, bytes],
     vector_value: Union[str, List[float]],
     distance_value: float,
@@ -1070,7 +1070,10 @@ def _build_retrieve_online_document_record(
     if entity_key_serialization_version < 3:
         entity_key_proto = None
     else:
-        entity_key_proto_bin = entity_key.encode("utf-8")
+        if isinstance(entity_key, str):
+            entity_key_proto_bin = entity_key.encode("utf-8")
+        else:
+            entity_key_proto_bin = entity_key
         entity_key_proto = deserialize_entity_key(
             entity_key_proto_bin,
             entity_key_serialization_version=entity_key_serialization_version,
diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py
index 73f99fb7c2..3b2dad2d6a 100644
--- a/sdk/python/tests/integration/feature_repos/repo_configuration.py
+++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py
@@ -520,7 +520,7 @@ def construct_test_environment(
     fixture_request: Optional[pytest.FixtureRequest],
     test_suite_name: str = "integration_test",
     worker_id: str = "worker_id",
-    entity_key_serialization_version: int = 2,
+    entity_key_serialization_version: int = 3,
 ) -> Environment:
 
     _uuid = str(uuid.uuid4()).replace("-", "")[:6]
diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py
index 308201590d..55c73a49db 100644
--- a/sdk/python/tests/integration/online_store/test_universal_online.py
+++ b/sdk/python/tests/integration/online_store/test_universal_online.py
@@ -861,6 +861,9 @@ def test_retrieve_online_documents(environment, fake_document_data):
     ).to_dict()
     assert len(documents["embedding_float"]) == 2
 
+    # assert returned the entity_id
+    assert len(documents["item_id"]) == 2
+
     documents = fs.retrieve_online_documents(
         feature="item_embeddings:embedding_float",
         query=[1.0, 2.0],

From a649ad11651d75fe1862beaa9a82a601d7f308bd Mon Sep 17 00:00:00 2001
From: cmuhao
Date: Tue, 10 Sep 2024 23:17:50 -0700
Subject: [PATCH 21/26] fix test

Signed-off-by: cmuhao
---
 sdk/python/feast/infra/online_stores/contrib/postgres.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py
index af2913b539..1f579f155f 100644
--- a/sdk/python/feast/infra/online_stores/contrib/postgres.py
+++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py
@@ -20,7 +20,6 @@
 from psycopg import AsyncConnection, sql
 from psycopg.connection import Connection
 from psycopg_pool import AsyncConnectionPool, ConnectionPool
-from feast.utils import _build_retrieve_online_document_record
 
 from feast import Entity
 from feast.feature_view import FeatureView
@@ -39,6 +38,7 @@
 from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
 from feast.protos.feast.types.Value_pb2 import Value as ValueProto
 from feast.repo_config import RepoConfig
+from feast.utils import _build_retrieve_online_document_record
 
 SUPPORTED_DISTANCE_METRICS_DICT = {
     "cosine": "<=>",

From ca89dcb114c0148d5255b1bf35d8a011f323dffc Mon Sep 17 00:00:00 2001
From: cmuhao
Date: Tue, 10 Sep 2024 23:18:55 -0700
Subject: [PATCH 22/26] fix test

Signed-off-by: cmuhao
---
 sdk/python/feast/infra/online_stores/contrib/postgres.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py
index 1f579f155f..1c929a4bb7 100644
--- a/sdk/python/feast/infra/online_stores/contrib/postgres.py
+++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py
@@ -394,8 +394,6 @@ def retrieve_online_documents(
         )
         distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric]
 
-        # Convert the embedding to a string to be used in postgres vector search
-        query_embedding_array = np.array(embedding)
 
         result: List[
             Tuple[

From 90bceff455a4ca6e846c635e5de43c67ada0dc0b Mon Sep 17 00:00:00 2001
From: cmuhao
Date: Tue, 10 Sep 2024 23:19:11 -0700
Subject: [PATCH 23/26] fix test

Signed-off-by: cmuhao
---
 sdk/python/feast/infra/online_stores/contrib/postgres.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py
index 1c929a4bb7..8125da33be 100644
--- a/sdk/python/feast/infra/online_stores/contrib/postgres.py
+++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py
@@ -16,7 +16,6 @@
     Union,
 )
 
-import numpy as np
 from psycopg import AsyncConnection, sql
 from psycopg.connection import Connection
 from psycopg_pool import AsyncConnectionPool, ConnectionPool

From fe9cdf0914fb5cc86175c044cceb443c7d8dbd93 Mon Sep 17 00:00:00 2001
From: cmuhao
Date: Thu, 19 Sep 2024 10:39:44 -0700
Subject: [PATCH 24/26] fix test

Signed-off-by: cmuhao
---
 sdk/python/tests/conftest.py                  | 19 +++++++++++++++++++
 .../online_store/test_universal_online.py     |  4 ++--
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py
index a9bb9ba9c4..50763a73a3 100644
--- a/sdk/python/tests/conftest.py
+++ b/sdk/python/tests/conftest.py
@@ -196,6 +196,25 @@ def environment(request, worker_id):
     e.teardown()
 
+@pytest.fixture
+def vectordb_environment(request, worker_id):
+    e = construct_test_environment(
+        request.param,
+        worker_id=worker_id,
+        fixture_request=request,
+        entity_key_serialization_version=3
+    )
+
+    e.setup()
+
+    if hasattr(e.data_source_creator, "mock_environ"):
+        with mock.patch.dict(os.environ, e.data_source_creator.mock_environ):
+            yield e
+    else:
+        yield e
+
+    e.teardown()
+
 
 _config_cache: Any = {}
diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py
index 55c73a49db..1a0803acff 100644
--- a/sdk/python/tests/integration/online_store/test_universal_online.py
+++ b/sdk/python/tests/integration/online_store/test_universal_online.py
@@ -846,8 +846,8 @@ def assert_feature_service_entity_mapping_correctness(
 
 @pytest.mark.integration
 @pytest.mark.universal_online_stores(only=["pgvector", "elasticsearch"])
-def test_retrieve_online_documents(environment, fake_document_data):
-    fs = environment.feature_store
+def test_retrieve_online_documents(vectordb_environment, fake_document_data):
+    fs = vectordb_environment.feature_store
     df, data_source = fake_document_data
     item_embeddings_feature_view = create_item_embeddings_feature_view(data_source)
     fs.apply([item_embeddings_feature_view, item()])

From f9ffb8504d99da33b20a5fe427a296dd8eb4c7cf Mon Sep 17 00:00:00 2001
From: cmuhao
Date: Thu, 19 Sep 2024 10:40:34 -0700
Subject: [PATCH 25/26] fix test

Signed-off-by: cmuhao
---
 sdk/python/tests/conftest.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py
index 50763a73a3..08b8757b95 100644
--- a/sdk/python/tests/conftest.py
+++ b/sdk/python/tests/conftest.py
@@ -196,13 +196,14 @@ def environment(request, worker_id):
     e.teardown()
 
+
 @pytest.fixture
 def vectordb_environment(request, worker_id):
     e = construct_test_environment(
         request.param,
         worker_id=worker_id,
         fixture_request=request,
-        entity_key_serialization_version=3
+        entity_key_serialization_version=3,
     )
 
     e.setup()

From 3b3fa0eab3dac7bd9a590f5892fcf9f4b4d124bb Mon Sep 17 00:00:00 2001
From: cmuhao
Date: Thu, 19 Sep 2024 23:02:33 -0700
Subject: [PATCH 26/26] fix test

Signed-off-by: cmuhao
---
 .../tests/integration/feature_repos/repo_configuration.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py
index 3b2dad2d6a..73f99fb7c2 100644
--- a/sdk/python/tests/integration/feature_repos/repo_configuration.py
+++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py
@@ -520,7 +520,7 @@ def construct_test_environment(
     fixture_request: Optional[pytest.FixtureRequest],
     test_suite_name: str = "integration_test",
     worker_id: str = "worker_id",
-    entity_key_serialization_version: int = 3,
+    entity_key_serialization_version: int = 2,
 ) -> Environment:
 
     _uuid = str(uuid.uuid4()).replace("-", "")[:6]