Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Return entity key in the retrieval document api #4511

Merged
merged 27 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@
FieldStatus,
GetOnlineFeaturesResponse,
)
from feast.protos.feast.types.EntityKey_pb2 import EntityKey
from feast.protos.feast.types.Value_pb2 import RepeatedValue, Value
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import RepoConfig, load_repo_config
from feast.repo_contents import RepoContents
from feast.saved_dataset import SavedDataset, SavedDatasetStorage, ValidationReference
Expand Down Expand Up @@ -1666,20 +1668,29 @@ def retrieve_online_documents(
distance_metric,
)

# TODO Refactor to better way of populating result
# TODO populate entity in the response after returning entity in document_features is supported
# TODO: the vector value is currently not returned because it is the same as the feature value; once embedding
# is supported, the feature value can be the raw text before it is embedded
document_feature_vals = [feature[2] for feature in document_features]
document_feature_distance_vals = [feature[4] for feature in document_features]
entity_key_vals = [feature[1] for feature in document_features]
join_key_values: Dict[str, List[ValueProto]] = {}
for entity_key_val in entity_key_vals:
if entity_key_val is not None:
for join_key, entity_value in zip(
entity_key_val.join_keys, entity_key_val.entity_values
):
if join_key not in join_key_values:
join_key_values[join_key] = []
join_key_values[join_key].append(entity_value)

document_feature_vals = [feature[4] for feature in document_features]
document_feature_distance_vals = [feature[5] for feature in document_features]
online_features_response = GetOnlineFeaturesResponse(results=[])
utils._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={requested_feature: document_feature_vals},
)
utils._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={"distance": document_feature_distance_vals},
data={
**join_key_values,
requested_feature: document_feature_vals,
"distance": document_feature_distance_vals,
},
)
return OnlineResponse(online_features_response)

Expand All @@ -1691,7 +1702,11 @@ def _retrieve_from_online_store(
query: List[float],
top_k: int,
distance_metric: Optional[str],
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value, Value]]:
) -> List[
Tuple[
Timestamp, Optional[EntityKey], "FieldStatus.ValueType", Value, Value, Value
]
]:
"""
Search and return document features from the online document store.
"""
Expand All @@ -1707,7 +1722,7 @@ def _retrieve_from_online_store(
read_row_protos = []
row_ts_proto = Timestamp()

for row_ts, feature_val, vector_value, distance_val in documents:
for row_ts, entity_key, feature_val, vector_value, distance_val in documents:
# Reset timestamp to default or update if row_ts is not None
if row_ts is not None:
row_ts_proto.FromDatetime(row_ts)
Expand All @@ -1721,7 +1736,14 @@ def _retrieve_from_online_store(
status = FieldStatus.PRESENT

read_row_protos.append(
(row_ts_proto, status, feature_val, vector_value, distance_val)
(
row_ts_proto,
entity_key,
status,
feature_val,
vector_value,
distance_val,
)
)
return read_row_protos

Expand Down
25 changes: 14 additions & 11 deletions sdk/python/feast/infra/online_stores/contrib/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
from elasticsearch import Elasticsearch, helpers

from feast import Entity, FeatureView, RepoConfig
from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key
from feast.infra.key_encoding_utils import (
get_list_val_str,
serialize_entity_key,
)
from feast.infra.online_stores.online_store import OnlineStore
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import FeastConfigBaseModel
from feast.utils import to_naive_utc
from feast.utils import _build_retrieve_online_document_record, to_naive_utc


class ElasticSearchOnlineStoreConfig(FeastConfigBaseModel):
Expand Down Expand Up @@ -224,6 +227,7 @@ def retrieve_online_documents(
) -> List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Expand All @@ -232,6 +236,7 @@ def retrieve_online_documents(
result: List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Expand All @@ -247,23 +252,21 @@ def retrieve_online_documents(
)
rows = response["hits"]["hits"][0:top_k]
for row in rows:
entity_key = row["_source"]["entity_key"]
feature_value = row["_source"]["feature_value"]
vector_value = row["_source"]["vector_value"]
timestamp = row["_source"]["timestamp"]
distance = row["_score"]
timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f")

feature_value_proto = ValueProto()
feature_value_proto.ParseFromString(base64.b64decode(feature_value))

vector_value_proto = ValueProto(string_val=str(vector_value))
distance_value_proto = ValueProto(float_val=distance)
result.append(
(
_build_retrieve_online_document_record(
entity_key,
base64.b64decode(feature_value),
str(vector_value),
distance,
timestamp,
feature_value_proto,
vector_value_proto,
distance_value_proto,
config.entity_key_serialization_version,
)
)
return result
37 changes: 15 additions & 22 deletions sdk/python/feast/infra/online_stores/contrib/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import RepoConfig
from feast.utils import _build_retrieve_online_document_record

SUPPORTED_DISTANCE_METRICS_DICT = {
"cosine": "<=>",
Expand Down Expand Up @@ -360,6 +361,7 @@ def retrieve_online_documents(
) -> List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Expand Down Expand Up @@ -391,12 +393,11 @@ def retrieve_online_documents(
)

distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric]
# Convert the embedding to a string to be used in postgres vector search
query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"

result: List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Expand All @@ -415,45 +416,37 @@ def retrieve_online_documents(
feature_name,
value,
vector_value,
vector_value {distance_metric_sql} %s as distance,
vector_value {distance_metric_sql} %s::vector as distance,
event_ts FROM {table_name}
WHERE feature_name = {feature_name}
ORDER BY distance
LIMIT {top_k};
"""
).format(
distance_metric_sql=distance_metric_sql,
distance_metric_sql=sql.SQL(distance_metric_sql),
table_name=sql.Identifier(table_name),
feature_name=sql.Literal(requested_feature),
top_k=sql.Literal(top_k),
),
(query_embedding_str,),
(embedding,),
)
rows = cur.fetchall()

for (
entity_key,
feature_name,
value,
_,
feature_val,
vector_value,
distance,
distance_val,
event_ts,
) in rows:
# TODO Deserialize entity_key to return the entity in response
# entity_key_proto = EntityKeyProto()
# entity_key_proto_bin = bytes(entity_key)

feature_value_proto = ValueProto()
feature_value_proto.ParseFromString(bytes(value))

vector_value_proto = ValueProto(string_val=vector_value)
distance_value_proto = ValueProto(float_val=distance)
result.append(
(
_build_retrieve_online_document_record(
entity_key,
feature_val,
vector_value,
distance_val,
event_ts,
feature_value_proto,
vector_value_proto,
distance_value_proto,
config.entity_key_serialization_version,
)
)

Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def retrieve_online_documents(
) -> List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Expand Down
22 changes: 9 additions & 13 deletions sdk/python/feast/infra/online_stores/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@
from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
from feast.protos.feast.core.SqliteTable_pb2 import SqliteTable as SqliteTableProto
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.utils import to_naive_utc
from feast.utils import _build_retrieve_online_document_record, to_naive_utc


class SqliteOnlineStoreConfig(FeastConfigBaseModel):
Expand Down Expand Up @@ -303,6 +302,7 @@ def retrieve_online_documents(
) -> List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Expand Down Expand Up @@ -385,26 +385,22 @@ def retrieve_online_documents(
result: List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
]
] = []

for entity_key, _, string_value, distance, event_ts in rows:
feature_value_proto = ValueProto()
feature_value_proto.ParseFromString(string_value if string_value else b"")
vector_value_proto = ValueProto(
float_list_val=FloatListProto(val=embedding)
)
distance_value_proto = ValueProto(float_val=distance)

result.append(
(
_build_retrieve_online_document_record(
entity_key,
string_value if string_value else b"",
embedding,
distance,
event_ts,
feature_value_proto,
vector_value_proto,
distance_value_proto,
config.entity_key_serialization_version,
)
)

Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ def retrieve_online_documents(
) -> List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Expand Down
51 changes: 50 additions & 1 deletion sdk/python/feast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
FeatureViewNotFoundException,
RequestDataNotFoundInEntityRowsException,
)
from feast.infra.key_encoding_utils import deserialize_entity_key
from feast.protos.feast.serving.ServingService_pb2 import (
FieldStatus,
GetOnlineFeaturesResponse,
)
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto
from feast.protos.feast.types.Value_pb2 import RepeatedValue as RepeatedValueProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.type_map import python_values_to_proto_values
Expand All @@ -49,7 +51,6 @@
from feast.feature_view import FeatureView
from feast.on_demand_feature_view import OnDemandFeatureView


APPLICATION_NAME = "feast-dev/feast"
USER_AGENT = "{}/{}".format(APPLICATION_NAME, get_version())

Expand Down Expand Up @@ -1050,3 +1051,51 @@ def tags_str_to_dict(tags: str = "") -> dict[str, str]:

def _utc_now() -> datetime:
return datetime.now(tz=timezone.utc)


def _build_retrieve_online_document_record(
    entity_key: Union[str, bytes],
    feature_value: Union[str, bytes],
    vector_value: Union[str, List[float]],
    distance_value: float,
    event_timestamp: datetime,
    entity_key_serialization_version: int,
) -> Tuple[
    Optional[datetime],
    Optional[EntityKeyProto],
    Optional[ValueProto],
    Optional[ValueProto],
    Optional[ValueProto],
]:
    """
    Assemble one document-retrieval result row from raw online store values.

    Args:
        entity_key: Serialized entity key; a ``str`` is UTF-8 encoded to bytes
            before deserialization.
        feature_value: Serialized ``ValueProto``; a ``str`` is UTF-8 encoded to
            bytes before parsing.
        vector_value: Either a string (stored in ``string_val``) or a list of
            floats (stored in ``float_list_val``).
        distance_value: Distance/score, stored in ``float_val``.
        event_timestamp: Passed through to the result unchanged.
        entity_key_serialization_version: The entity key is only deserialized
            when this is 3 or greater; otherwise the entity key slot is ``None``.

    Returns:
        A tuple of ``(event_timestamp, entity_key_proto, feature_value_proto,
        vector_value_proto, distance_value_proto)``.
    """
    # NOTE(review): presumably versions below 3 cannot be deserialized back
    # into an EntityKey proto — confirm against deserialize_entity_key.
    entity_key_proto: Optional[EntityKeyProto] = None
    if entity_key_serialization_version >= 3:
        raw_entity_key = (
            entity_key.encode("utf-8") if isinstance(entity_key, str) else entity_key
        )
        entity_key_proto = deserialize_entity_key(
            raw_entity_key,
            entity_key_serialization_version=entity_key_serialization_version,
        )

    # Normalise the feature payload to bytes, then parse it into a fresh proto.
    raw_feature_value = (
        feature_value.encode("utf-8")
        if isinstance(feature_value, str)
        else feature_value
    )
    feature_value_proto = ValueProto()
    feature_value_proto.ParseFromString(raw_feature_value)

    # The Value field used for the vector depends on how the store returned it.
    vector_value_proto = (
        ValueProto(string_val=vector_value)
        if isinstance(vector_value, str)
        else ValueProto(float_list_val=FloatListProto(val=vector_value))
    )

    return (
        event_timestamp,
        entity_key_proto,
        feature_value_proto,
        vector_value_proto,
        ValueProto(float_val=distance_value),
    )
Loading
Loading