Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Pgvector patch #4108

Merged
merged 7 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,12 +1740,14 @@ def _retrieve_online_documents(
query,
top_k,
)
document_feature_vals = [feature[2] for feature in document_features]
document_feature_distance_vals = [feature[3] for feature in document_features]
online_features_response = GetOnlineFeaturesResponse(results=[])

# TODO: Refactor to a better way of populating the result
# TODO: Populate the entity in the response once returning the entity in document_features is supported
# TODO: The vector value is currently not returned since it is the same as the feature value; if embedding is supported,
# the feature value can be the raw text before it is embedded
document_feature_vals = [feature[2] for feature in document_features]
document_feature_distance_vals = [feature[4] for feature in document_features]
online_features_response = GetOnlineFeaturesResponse(results=[])
self._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={requested_feature: document_feature_vals},
Expand Down Expand Up @@ -1979,7 +1981,7 @@ def _retrieve_from_online_store(
requested_feature: str,
query: List[float],
top_k: int,
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value]]:
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value, Value]]:
"""
Search and return document features from the online document store.
"""
Expand All @@ -1994,19 +1996,22 @@ def _retrieve_from_online_store(
read_row_protos = []
row_ts_proto = Timestamp()

for row_ts, feature_val, distance_val in documents:
for row_ts, feature_val, vector_value, distance_val in documents:
# Reset timestamp to default or update if row_ts is not None
if row_ts is not None:
row_ts_proto.FromDatetime(row_ts)

if feature_val is None or distance_val is None:
if feature_val is None or vector_value is None or distance_val is None:
feature_val = Value()
vector_value = Value()
distance_val = Value()
status = FieldStatus.NOT_FOUND
else:
status = FieldStatus.PRESENT

read_row_protos.append((row_ts_proto, status, feature_val, distance_val))
read_row_protos.append(
(row_ts_proto, status, feature_val, vector_value, distance_val)
)
return read_row_protos

@staticmethod
Expand Down
55 changes: 37 additions & 18 deletions sdk/python/feast/infra/online_stores/contrib/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,7 @@ def online_write_batch(

for feature_name, val in values.items():
vector_val = None
if (
"pgvector_enabled" in config.online_store
and config.online_store.pgvector_enabled
):
if config.online_store.pgvector_enabled:
vector_val = get_list_val_str(val)
insert_values.append(
(
Expand Down Expand Up @@ -226,10 +223,7 @@ def update(

for table in tables_to_keep:
table_name = _table_id(project, table)
if (
"pgvector_enabled" in config.online_store
and config.online_store.pgvector_enabled
):
if config.online_store.pgvector_enabled:
vector_value_type = f"vector({config.online_store.vector_len})"
else:
# keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility
Expand Down Expand Up @@ -282,7 +276,14 @@ def retrieve_online_documents(
requested_feature: str,
embedding: List[float],
top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
) -> List[
Tuple[
Optional[datetime],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
]
]:
"""

Args:
Expand All @@ -297,10 +298,7 @@ def retrieve_online_documents(
"""
project = config.project

if (
"pgvector_enabled" not in config.online_store
or not config.online_store.pgvector_enabled
):
if not config.online_store.pgvector_enabled:
raise ValueError(
"pgvector is not enabled in the online store configuration"
)
Expand All @@ -309,7 +307,12 @@ def retrieve_online_documents(
query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"

result: List[
Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]
Tuple[
Optional[datetime],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
]
] = []
with self._get_conn(config) as conn, conn.cursor() as cur:
table_name = _table_id(project, table)
Expand All @@ -322,6 +325,7 @@ def retrieve_online_documents(
SELECT
entity_key,
feature_name,
value,
vector_value,
vector_value <-> %s as distance,
event_ts FROM {table_name}
Expand All @@ -338,16 +342,31 @@ def retrieve_online_documents(
)
rows = cur.fetchall()

for entity_key, feature_name, vector_value, distance, event_ts in rows:
for (
entity_key,
feature_name,
value,
vector_value,
distance,
event_ts,
) in rows:
# TODO Deserialize entity_key to return the entity in response
# entity_key_proto = EntityKeyProto()
# entity_key_proto_bin = bytes(entity_key)

# TODO Convert to List[float] for value type proto
feature_value_proto = ValueProto(string_val=vector_value)
feature_value_proto = ValueProto()
feature_value_proto.ParseFromString(bytes(value))

vector_value_proto = ValueProto(string_val=vector_value)
distance_value_proto = ValueProto(float_val=distance)
result.append((event_ts, feature_value_proto, distance_value_proto))
result.append(
(
event_ts,
feature_value_proto,
vector_value_proto,
distance_value_proto,
)
)

return result

Expand Down
9 changes: 8 additions & 1 deletion sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,14 @@ def retrieve_online_documents(
requested_feature: str,
embedding: List[float],
top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
) -> List[
Tuple[
Optional[datetime],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
]
]:
"""
Retrieves online feature values for the specified embeddings.

Expand Down
9 changes: 8 additions & 1 deletion sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,14 @@ def retrieve_online_documents(
requested_feature: str,
query: List[float],
top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
) -> List[
Tuple[
Optional[datetime],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
]
]:
"""
Searches for the top-k nearest neighbors of the given document in the online document store.

Expand Down
9 changes: 8 additions & 1 deletion sdk/python/tests/foo_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,12 @@ def retrieve_online_documents(
requested_feature: str,
query: List[float],
top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
) -> List[
Tuple[
Optional[datetime],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
]
]:
return []
Loading