Skip to content

Commit

Permalink
feat: Enable Vector database and retrieve_online_documents API (#4061)
Browse files Browse the repository at this point in the history
* feat: add document store

* feat: add document store

* feat: add document store

* feat: add document store

* remove DocumentStore

* format

* format

* format

* format

* format

* format

* remove unused vars

* add test

* add test

* format

* format

* format

* format

* format

* fix not implemented issue

* fix not implemented issue

* fix test

* format

* format

* format

* format

* format

* format

* update testcontainer

* format

* fix postgres integration test

* format

* fix postgres test

* fix postgres test

* fix postgres test

* fix postgres test

* fix postgres test

* format

* format

* format
  • Loading branch information
HaoXuAI authored Apr 15, 2024
1 parent 3c6ce86 commit ec19036
Show file tree
Hide file tree
Showing 15 changed files with 419 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ test-python-universal-postgres-offline:
test-python-universal-postgres-online:
PYTHONPATH='.' \
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.postgres_repo_configuration \
PYTEST_PLUGINS=sdk.python.feast.infra.offline_stores.contrib.postgres_offline_store.tests \
PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.postgres \
python -m pytest -n 8 --integration \
-k "not test_universal_cli and \
not test_go_feature_server and \
Expand Down
103 changes: 103 additions & 0 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,6 +1690,72 @@ def _get_online_features(
)
return OnlineResponse(online_features_response)

@log_exceptions_and_usage
def retrieve_online_documents(
self,
feature: str,
query: Union[str, List[float]],
top_k: int,
) -> OnlineResponse:
"""
Retrieves the top k closest document features. Note, embeddings are a subset of features.
Args:
feature: The list of document features that should be retrieved from the online document store. These features can be
specified either as a list of string document feature references or as a feature service. String feature
references must have format "feature_view:feature", e.g, "document_fv:document_embeddings".
query: The query to retrieve the closest document features for.
top_k: The number of closest document features to retrieve.
"""
return self._retrieve_online_documents(
feature=feature,
query=query,
top_k=top_k,
)

def _retrieve_online_documents(
self,
feature: str,
query: Union[str, List[float]],
top_k: int,
):
if isinstance(query, str):
raise ValueError(
"Using embedding functionality is not supported for document retrieval. Please embed the query before calling retrieve_online_documents."
)
(
requested_feature_views,
_,
) = self._get_feature_views_to_use(
features=[feature], allow_cache=True, hide_dummy_entity=False
)
requested_feature = (
feature.split(":")[1] if isinstance(feature, str) else feature
)
provider = self._get_provider()
document_features = self._retrieve_from_online_store(
provider,
requested_feature_views[0],
requested_feature,
query,
top_k,
)
document_feature_vals = [feature[2] for feature in document_features]
document_feature_distance_vals = [feature[3] for feature in document_features]
online_features_response = GetOnlineFeaturesResponse(results=[])

# TODO Refactor to better way of populating result
# TODO populate entity in the response after returning entity in document_features is supported
self._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={requested_feature: document_feature_vals},
)
self._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={"distance": document_feature_distance_vals},
)
return OnlineResponse(online_features_response)

@staticmethod
def _get_columnar_entity_values(
rowise: Optional[List[Dict[str, Any]]], columnar: Optional[Dict[str, List[Any]]]
Expand Down Expand Up @@ -1906,6 +1972,43 @@ def _read_from_online_store(
read_row_protos.append((event_timestamps, statuses, values))
return read_row_protos

def _retrieve_from_online_store(
self,
provider: Provider,
table: FeatureView,
requested_feature: str,
query: List[float],
top_k: int,
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value]]:
"""
Search and return document features from the online document store.
"""
documents = provider.retrieve_online_documents(
config=self.config,
table=table,
requested_feature=requested_feature,
query=query,
top_k=top_k,
)

read_row_protos = []
row_ts_proto = Timestamp()

for row_ts, feature_val, distance_val in documents:
# Reset timestamp to default or update if row_ts is not None
if row_ts is not None:
row_ts_proto.FromDatetime(row_ts)

if feature_val is None or distance_val is None:
feature_val = Value()
distance_val = Value()
status = FieldStatus.NOT_FOUND
else:
status = FieldStatus.PRESENT

read_row_protos.append((row_ts_proto, status, feature_val, distance_val))
return read_row_protos

@staticmethod
def _populate_response_from_feature_data(
feature_data: Iterable[
Expand Down
8 changes: 8 additions & 0 deletions sdk/python/feast/infra/key_encoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,11 @@ def serialize_entity_key(
output.append(val_bytes)

return b"".join(output)


def get_val_str(val):
accept_value_types = ["float_list_val", "double_list_val", "int_list_val"]
for accept_type in accept_value_types:
if val.HasField(accept_type):
return str(getattr(val, accept_type).val)
return None
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from testcontainers.core.waiting_utils import wait_for_logs

from feast.data_source import DataSource
from feast.feature_logging import LoggingDestination
from feast.infra.offline_stores.contrib.postgres_offline_store.postgres import (
PostgreSQLOfflineStoreConfig,
PostgreSQLSource,
Expand Down Expand Up @@ -57,6 +58,9 @@ def postgres_container():


class PostgreSQLDataSourceCreator(DataSourceCreator, OnlineStoreCreator):
def create_logged_features_destination(self) -> LoggingDestination:
return None # type: ignore

def __init__(
self, project_name: str, fixture_request: pytest.FixtureRequest, **kwargs
):
Expand Down
97 changes: 93 additions & 4 deletions sdk/python/feast/infra/online_stores/contrib/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from collections import defaultdict
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union

import psycopg2
import pytz
Expand All @@ -12,7 +12,7 @@

from feast import Entity
from feast.feature_view import FeatureView
from feast.infra.key_encoding_utils import serialize_entity_key
from feast.infra.key_encoding_utils import get_val_str, serialize_entity_key
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool
from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig
Expand All @@ -25,6 +25,12 @@
class PostgreSQLOnlineStoreConfig(PostgreSQLConfig):
type: Literal["postgres"] = "postgres"

# Whether to enable the pgvector extension for vector similarity search
pgvector_enabled: Optional[bool] = False

# If pgvector is enabled, the length of the vector field
vector_len: Optional[int] = 512


class PostgreSQLOnlineStore(OnlineStore):
_conn: Optional[psycopg2._psycopg.connection] = None
Expand Down Expand Up @@ -68,11 +74,19 @@ def online_write_batch(
created_ts = _to_naive_utc(created_ts)

for feature_name, val in values.items():
val_str: Union[str, bytes]
if (
"pgvector_enabled" in config.online_config
and config.online_config["pgvector_enabled"]
):
val_str = get_val_str(val)
else:
val_str = val.SerializeToString()
insert_values.append(
(
entity_key_bin,
feature_name,
val.SerializeToString(),
val_str,
timestamp,
created_ts,
)
Expand Down Expand Up @@ -212,14 +226,20 @@ def update(

for table in tables_to_keep:
table_name = _table_id(project, table)
value_type = "BYTEA"
if (
"pgvector_enabled" in config.online_config
and config.online_config["pgvector_enabled"]
):
value_type = f'vector({config.online_config["vector_len"]})'
cur.execute(
sql.SQL(
"""
CREATE TABLE IF NOT EXISTS {}
(
entity_key BYTEA,
feature_name TEXT,
value BYTEA,
value {},
event_ts TIMESTAMPTZ,
created_ts TIMESTAMPTZ,
PRIMARY KEY(entity_key, feature_name)
Expand All @@ -228,6 +248,7 @@ def update(
"""
).format(
sql.Identifier(table_name),
sql.SQL(value_type),
sql.Identifier(f"{table_name}_ek"),
sql.Identifier(table_name),
)
Expand All @@ -251,6 +272,74 @@ def teardown(
logging.exception("Teardown failed")
raise

def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: str,
embedding: List[float],
top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
"""
Args:
config: Feast configuration object
table: FeatureView object as the table to search
requested_feature: The requested feature as the column to search
embedding: The query embedding to search for
top_k: The number of items to return
Returns:
List of tuples containing the event timestamp and the document feature
"""
project = config.project

# Convert the embedding to a string to be used in postgres vector search
query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"

result: List[
Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]
] = []
with self._get_conn(config) as conn, conn.cursor() as cur:
table_name = _table_id(project, table)

# Search query template to find the top k items that are closest to the given embedding
# SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5;
cur.execute(
sql.SQL(
"""
SELECT
entity_key,
feature_name,
value,
value <-> %s as distance,
event_ts FROM {table_name}
WHERE feature_name = {feature_name}
ORDER BY distance
LIMIT {top_k};
"""
).format(
table_name=sql.Identifier(table_name),
feature_name=sql.Literal(requested_feature),
top_k=sql.Literal(top_k),
),
(query_embedding_str,),
)
rows = cur.fetchall()

for entity_key, feature_name, value, distance, event_ts in rows:
# TODO Deserialize entity_key to return the entity in response
# entity_key_proto = EntityKeyProto()
# entity_key_proto_bin = bytes(entity_key)

# TODO Convert to List[float] for value type proto
feature_value_proto = ValueProto(string_val=value)

distance_value_proto = ValueProto(float_val=distance)
result.append((event_ts, feature_value_proto, distance_value_proto))

return result


def _table_id(project: str, table: FeatureView) -> str:
return f"{project}_{table.name}"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from feast.infra.offline_stores.contrib.postgres_offline_store.tests.data_source import (
PostgreSQLDataSourceCreator,
)
from tests.integration.feature_repos.integration_test_repo_config import (
IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.universal.online_store.postgres import (
PGVectorOnlineStoreCreator,
PostgresOnlineStoreCreator,
)

FULL_REPO_CONFIGS = [
IntegrationTestRepoConfig(online_store_creator=PostgreSQLDataSourceCreator),
IntegrationTestRepoConfig(
online_store="postgres", online_store_creator=PostgresOnlineStoreCreator
),
IntegrationTestRepoConfig(
online_store="pgvector", online_store_creator=PGVectorOnlineStoreCreator
),
]

AVAILABLE_ONLINE_STORES = {"pgvector": PGVectorOnlineStoreCreator}
27 changes: 27 additions & 0 deletions sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,30 @@ def teardown(
entities: Entities whose corresponding infrastructure should be deleted.
"""
pass

def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: str,
embedding: List[float],
top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
"""
Retrieves online feature values for the specified embeddings.
Args:
config: The config for the current feature store.
table: The feature view whose feature values should be read.
requested_feature: The name of the feature whose embeddings should be used for retrieval.
embedding: The embeddings to use for retrieval.
top_k: The number of nearest neighbors to retrieve.
Returns:
object: A list of top k closest documents to the specified embedding. Each item in the list is a tuple
where the first item is the event timestamp for the row, and the second item is a dict of feature
name to embeddings.
"""
raise NotImplementedError(
f"Online store {self.__class__.__name__} does not support online retrieval"
)
17 changes: 17 additions & 0 deletions sdk/python/feast/infra/passthrough_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,23 @@ def online_read(
)
return result

@log_exceptions_and_usage(sampler=RatioSampler(ratio=0.001))
def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: str,
query: List[float],
top_k: int,
) -> List:
set_usage_attribute("provider", self.__class__.__name__)
result = []
if self.online_store:
result = self.online_store.retrieve_online_documents(
config, table, requested_feature, query, top_k
)
return result

def ingest_df(
self,
feature_view: FeatureView,
Expand Down
Loading

0 comments on commit ec19036

Please sign in to comment.