diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 4f96cfb0fc..ab2bc6cec2 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -78,7 +78,9 @@ FieldStatus, GetOnlineFeaturesResponse, ) +from feast.protos.feast.types.EntityKey_pb2 import EntityKey from feast.protos.feast.types.Value_pb2 import RepeatedValue, Value +from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import RepoConfig, load_repo_config from feast.repo_contents import RepoContents from feast.saved_dataset import SavedDataset, SavedDatasetStorage, ValidationReference @@ -1666,20 +1668,29 @@ def retrieve_online_documents( distance_metric, ) - # TODO Refactor to better way of populating result - # TODO populate entity in the response after returning entity in document_features is supported # TODO currently not return the vector value since it is same as feature value, if embedding is supported, # the feature value can be raw text before embedded - document_feature_vals = [feature[2] for feature in document_features] - document_feature_distance_vals = [feature[4] for feature in document_features] + entity_key_vals = [feature[1] for feature in document_features] + join_key_values: Dict[str, List[ValueProto]] = {} + for entity_key_val in entity_key_vals: + if entity_key_val is not None: + for join_key, entity_value in zip( + entity_key_val.join_keys, entity_key_val.entity_values + ): + if join_key not in join_key_values: + join_key_values[join_key] = [] + join_key_values[join_key].append(entity_value) + + document_feature_vals = [feature[4] for feature in document_features] + document_feature_distance_vals = [feature[5] for feature in document_features] online_features_response = GetOnlineFeaturesResponse(results=[]) utils._populate_result_rows_from_columnar( online_features_response=online_features_response, - data={requested_feature: document_feature_vals}, - ) - utils._populate_result_rows_from_columnar( - online_features_response=online_features_response, - data={"distance": document_feature_distance_vals}, + data={ + **join_key_values, + requested_feature: document_feature_vals, + "distance": document_feature_distance_vals, + }, ) return OnlineResponse(online_features_response) @@ -1691,7 +1702,11 @@ def _retrieve_from_online_store( query: List[float], top_k: int, distance_metric: Optional[str], - ) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value, Value]]: + ) -> List[ + Tuple[ + Timestamp, Optional[EntityKey], "FieldStatus.ValueType", Value, Value, Value + ] + ]: """ Search and return document features from the online document store. """ @@ -1707,7 +1722,7 @@ def _retrieve_from_online_store( read_row_protos = [] row_ts_proto = Timestamp() - for row_ts, feature_val, vector_value, distance_val in documents: + for row_ts, entity_key, feature_val, vector_value, distance_val in documents: # Reset timestamp to default or update if row_ts is not None if row_ts is not None: row_ts_proto.FromDatetime(row_ts) @@ -1721,7 +1736,14 @@ def _retrieve_from_online_store( status = FieldStatus.PRESENT read_row_protos.append( - (row_ts_proto, status, feature_val, vector_value, distance_val) + ( + row_ts_proto, + entity_key, + status, + feature_val, + vector_value, + distance_val, + ) ) return read_row_protos diff --git a/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py b/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py index c26b4199ae..a0c25b931a 100644 --- a/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py +++ b/sdk/python/feast/infra/online_stores/contrib/elasticsearch.py @@ -9,12 +9,15 @@ from elasticsearch import Elasticsearch, helpers from feast import Entity, FeatureView, RepoConfig -from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key +from feast.infra.key_encoding_utils import ( + get_list_val_str, + serialize_entity_key, +) from feast.infra.online_stores.online_store import OnlineStore from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel -from feast.utils import to_naive_utc +from feast.utils import _build_retrieve_online_document_record, to_naive_utc class ElasticSearchOnlineStoreConfig(FeastConfigBaseModel): @@ -224,6 +227,7 @@ def retrieve_online_documents( ) -> List[ Tuple[ Optional[datetime], + Optional[EntityKeyProto], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto], @@ -232,6 +236,7 @@ def retrieve_online_documents( result: List[ Tuple[ Optional[datetime], + Optional[EntityKeyProto], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto], @@ -247,23 +252,21 @@ def retrieve_online_documents( ) rows = response["hits"]["hits"][0:top_k] for row in rows: + entity_key = row["_source"]["entity_key"] feature_value = row["_source"]["feature_value"] vector_value = row["_source"]["vector_value"] timestamp = row["_source"]["timestamp"] distance = row["_score"] timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") - feature_value_proto = ValueProto() - feature_value_proto.ParseFromString(base64.b64decode(feature_value)) - - vector_value_proto = ValueProto(string_val=str(vector_value)) - distance_value_proto = ValueProto(float_val=distance) result.append( - ( + _build_retrieve_online_document_record( + entity_key, + base64.b64decode(feature_value), + str(vector_value), + distance, timestamp, - feature_value_proto, - vector_value_proto, - distance_value_proto, + config.entity_key_serialization_version, ) ) return result diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 8c6d3e0b99..8125da33be 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -37,6 +37,7 @@ from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import RepoConfig +from feast.utils import _build_retrieve_online_document_record SUPPORTED_DISTANCE_METRICS_DICT = { "cosine": "<=>", @@ -360,6 +361,7 @@ def retrieve_online_documents( ) -> List[ Tuple[ Optional[datetime], + Optional[EntityKeyProto], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto], @@ -391,12 +393,11 @@ def retrieve_online_documents( ) distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric] - # Convert the embedding to a string to be used in postgres vector search - query_embedding_str = f"[{','.join(str(el) for el in embedding)}]" result: List[ Tuple[ Optional[datetime], + Optional[EntityKeyProto], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto], @@ -415,45 +416,37 @@ def retrieve_online_documents( feature_name, value, vector_value, - vector_value {distance_metric_sql} %s as distance, + vector_value {distance_metric_sql} %s::vector as distance, event_ts FROM {table_name} WHERE feature_name = {feature_name} ORDER BY distance LIMIT {top_k}; """ ).format( - distance_metric_sql=distance_metric_sql, + distance_metric_sql=sql.SQL(distance_metric_sql), table_name=sql.Identifier(table_name), feature_name=sql.Literal(requested_feature), top_k=sql.Literal(top_k), ), - (query_embedding_str,), + (embedding,), ) rows = cur.fetchall() - for ( entity_key, - feature_name, - value, + _, + feature_val, vector_value, - distance, + distance_val, event_ts, ) in rows: - # TODO Deserialize entity_key to return the entity in response - # entity_key_proto = EntityKeyProto() - # entity_key_proto_bin = bytes(entity_key) - - feature_value_proto = ValueProto() - feature_value_proto.ParseFromString(bytes(value)) - - vector_value_proto = ValueProto(string_val=vector_value) - distance_value_proto = ValueProto(float_val=distance) result.append( - ( + _build_retrieve_online_document_record( + entity_key, + feature_val, + vector_value, + distance_val, event_ts, - feature_value_proto, - vector_value_proto, - distance_value_proto, + config.entity_key_serialization_version, ) ) diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index 9cf2ef95f6..fdb5b055cf 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -349,6 +349,7 @@ def retrieve_online_documents( ) -> List[ Tuple[ Optional[datetime], + Optional[EntityKeyProto], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto], diff --git a/sdk/python/feast/infra/online_stores/sqlite.py b/sdk/python/feast/infra/online_stores/sqlite.py index 9896b766d4..061a766b8c 100644 --- a/sdk/python/feast/infra/online_stores/sqlite.py +++ b/sdk/python/feast/infra/online_stores/sqlite.py @@ -33,10 +33,9 @@ from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto from feast.protos.feast.core.SqliteTable_pb2 import SqliteTable as SqliteTableProto from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto -from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel, RepoConfig -from feast.utils import to_naive_utc +from feast.utils import _build_retrieve_online_document_record, to_naive_utc class SqliteOnlineStoreConfig(FeastConfigBaseModel): @@ -303,6 +302,7 @@ def retrieve_online_documents( ) -> List[ Tuple[ Optional[datetime], + Optional[EntityKeyProto], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto], @@ -385,6 +385,7 @@ def retrieve_online_documents( result: List[ Tuple[ Optional[datetime], + Optional[EntityKeyProto], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto], @@ -392,19 +393,14 @@ def retrieve_online_documents( ] = [] for entity_key, _, string_value, distance, event_ts in rows: - feature_value_proto = ValueProto() - feature_value_proto.ParseFromString(string_value if string_value else b"") - vector_value_proto = ValueProto( - float_list_val=FloatListProto(val=embedding) - ) - distance_value_proto = ValueProto(float_val=distance) - result.append( - ( + _build_retrieve_online_document_record( + entity_key, + string_value if string_value else b"", + embedding, + distance, event_ts, - feature_value_proto, - vector_value_proto, - distance_value_proto, + config.entity_key_serialization_version, ) ) diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 9940af1d02..c0062dde02 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -364,6 +364,7 @@ def retrieve_online_documents( ) -> List[ Tuple[ Optional[datetime], + Optional[EntityKeyProto], Optional[ValueProto], Optional[ValueProto], Optional[ValueProto], diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 992869557a..a6d7853e1b 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -33,11 +33,13 @@ FeatureViewNotFoundException, RequestDataNotFoundInEntityRowsException, ) +from feast.infra.key_encoding_utils import deserialize_entity_key from feast.protos.feast.serving.ServingService_pb2 import ( FieldStatus, GetOnlineFeaturesResponse, ) from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto +from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto from feast.protos.feast.types.Value_pb2 import RepeatedValue as RepeatedValueProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.type_map import python_values_to_proto_values @@ -49,7 +51,6 @@ from feast.feature_view import FeatureView from feast.on_demand_feature_view import OnDemandFeatureView - APPLICATION_NAME = "feast-dev/feast" USER_AGENT = "{}/{}".format(APPLICATION_NAME, get_version()) @@ -1050,3 +1051,51 @@ def tags_str_to_dict(tags: str = "") -> dict[str, str]: def _utc_now() -> datetime: return datetime.now(tz=timezone.utc) + + +def _build_retrieve_online_document_record( + entity_key: Union[str, bytes], + feature_value: Union[str, bytes], + vector_value: Union[str, List[float]], + distance_value: float, + event_timestamp: datetime, + entity_key_serialization_version: int, +) -> Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[ValueProto], + Optional[ValueProto], + Optional[ValueProto], +]: + if entity_key_serialization_version < 3: + entity_key_proto = None + else: + if isinstance(entity_key, str): + entity_key_proto_bin = entity_key.encode("utf-8") + else: + entity_key_proto_bin = entity_key + entity_key_proto = deserialize_entity_key( + entity_key_proto_bin, + entity_key_serialization_version=entity_key_serialization_version, + ) + + feature_value_proto = ValueProto() + + if isinstance(feature_value, str): + feature_value_proto.ParseFromString(feature_value.encode("utf-8")) + else: + feature_value_proto.ParseFromString(feature_value) + + if isinstance(vector_value, str): + vector_value_proto = ValueProto(string_val=vector_value) + else: + vector_value_proto = ValueProto(float_list_val=FloatListProto(val=vector_value)) + + distance_value_proto = ValueProto(float_val=distance_value) + return ( + event_timestamp, + entity_key_proto, + feature_value_proto, + vector_value_proto, + distance_value_proto, + ) diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py index a9bb9ba9c4..08b8757b95 100644 --- a/sdk/python/tests/conftest.py +++ b/sdk/python/tests/conftest.py @@ -197,6 +197,26 @@ def environment(request, worker_id): e.teardown() +@pytest.fixture +def vectordb_environment(request, worker_id): + e = construct_test_environment( + request.param, + worker_id=worker_id, + fixture_request=request, + entity_key_serialization_version=3, + ) + + e.setup() + + if hasattr(e.data_source_creator, "mock_environ"): + with mock.patch.dict(os.environ, e.data_source_creator.mock_environ): + yield e + else: + yield e + + e.teardown() + + _config_cache: Any = {} diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index 308201590d..1a0803acff 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -846,8 +846,8 @@ def assert_feature_service_entity_mapping_correctness( @pytest.mark.integration @pytest.mark.universal_online_stores(only=["pgvector", "elasticsearch"]) -def test_retrieve_online_documents(environment, fake_document_data): - fs = environment.feature_store +def test_retrieve_online_documents(vectordb_environment, fake_document_data): + fs = vectordb_environment.feature_store df, data_source = fake_document_data item_embeddings_feature_view = create_item_embeddings_feature_view(data_source) fs.apply([item_embeddings_feature_view, item()]) @@ -861,6 +861,9 @@ def test_retrieve_online_documents(environment, fake_document_data): ).to_dict() assert len(documents["embedding_float"]) == 2 + # assert returned the entity_id + assert len(documents["item_id"]) == 2 + documents = fs.retrieve_online_documents( feature="item_embeddings:embedding_float", query=[1.0, 2.0],