From ba9f4efd5eccd0548a39521a145c6573ac90c221 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Tue, 7 May 2024 08:06:42 -0400 Subject: [PATCH] feat: Enable other distance metrics for Vector DB and Update docs (#4170) * updated PGVector docs Signed-off-by: Francisco Javier Arceo * adding distance metric to arguments and defaulting to L2 Signed-off-by: Francisco Javier Arceo * linter Signed-off-by: Francisco Javier Arceo * testing other distance metric Signed-off-by: Francisco Javier Arceo * updated default Signed-off-by: Francisco Javier Arceo * linter Signed-off-by: Francisco Javier Arceo * fixed some copy Signed-off-by: Francisco Javier Arceo * updated Signed-off-by: Francisco Javier Arceo --------- Signed-off-by: Francisco Javier Arceo --- docs/reference/online-stores/postgres.md | 10 +++++++-- sdk/python/feast/feature_store.py | 7 +++++++ .../infra/online_stores/contrib/postgres.py | 18 +++++++++++++++- .../feast/infra/online_stores/online_store.py | 2 +- .../feast/infra/passthrough_provider.py | 8 ++++++- sdk/python/feast/infra/provider.py | 5 +++-- sdk/python/tests/foo_provider.py | 1 + .../online_store/test_universal_online.py | 21 ++++++++++++++++++- 8 files changed, 64 insertions(+), 8 deletions(-) diff --git a/docs/reference/online-stores/postgres.md b/docs/reference/online-stores/postgres.md index 34d4de3488..77a9408d2b 100644 --- a/docs/reference/online-stores/postgres.md +++ b/docs/reference/online-stores/postgres.md @@ -65,10 +65,16 @@ To compare this set of functionality against other online stores, please see the ## PGVector The Postgres online store supports the use of [PGVector](https://github.com/pgvector/pgvector) for storing feature values. -To enable PGVector, set `pgvector_enabled: true` in the online store configuration. +To enable PGVector, set `pgvector_enabled: true` in the online store configuration. + The `vector_len` parameter can be used to specify the length of the vector. The default value is 512. 
-Then you can use `retrieve_online_documents` to retrieve the top k closest vectors to a query vector. +Please make sure to follow the instructions in the repository, which, as of the time of this writing, require you to +run `CREATE EXTENSION vector;` in the database. + + +Then you can use `retrieve_online_documents` to retrieve the top k closest vectors to a query vector. +For the Retrieval Augmented Generation (RAG) use-case, you have to embed the query prior to passing the query vector. {% code title="python" %} ```python diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index bc492e4208..f45dbb1bc8 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1740,6 +1740,7 @@ def retrieve_online_documents( feature: str, query: Union[str, List[float]], top_k: int, + distance_metric: str, ) -> OnlineResponse: """ Retrieves the top k closest document features. Note, embeddings are a subset of features. @@ -1750,11 +1751,13 @@ def retrieve_online_documents( references must have format "feature_view:feature", e.g, "document_fv:document_embeddings". query: The query to retrieve the closest document features for. top_k: The number of closest document features to retrieve. + distance_metric: The distance metric to use for retrieval. 
""" return self._retrieve_online_documents( feature=feature, query=query, top_k=top_k, + distance_metric=distance_metric, ) def _retrieve_online_documents( @@ -1762,6 +1765,7 @@ def _retrieve_online_documents( feature: str, query: Union[str, List[float]], top_k: int, + distance_metric: str = "L2", ): if isinstance(query, str): raise ValueError( @@ -1783,6 +1787,7 @@ def _retrieve_online_documents( requested_feature, query, top_k, + distance_metric, ) # TODO Refactor to better way of populating result @@ -2025,6 +2030,7 @@ def _retrieve_from_online_store( requested_feature: str, query: List[float], top_k: int, + distance_metric: str, ) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value, Value]]: """ Search and return document features from the online document store. @@ -2035,6 +2041,7 @@ def _retrieve_from_online_store( requested_feature=requested_feature, query=query, top_k=top_k, + distance_metric=distance_metric, ) read_row_protos = [] diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 6ed0885d13..f2c32fdafd 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -21,6 +21,13 @@ from feast.repo_config import RepoConfig from feast.usage import log_exceptions_and_usage +SUPPORTED_DISTANCE_METRICS_DICT = { + "cosine": "<=>", + "L1": "<+>", + "L2": "<->", + "inner_product": "<#>", +} + class PostgreSQLOnlineStoreConfig(PostgreSQLConfig): type: Literal["postgres"] = "postgres" @@ -276,6 +283,7 @@ def retrieve_online_documents( requested_feature: str, embedding: List[float], top_k: int, + distance_metric: str = "L2", ) -> List[ Tuple[ Optional[datetime], @@ -292,6 +300,7 @@ def retrieve_online_documents( requested_feature: The requested feature as the column to search embedding: The query embedding to search for top_k: The number of items to return + distance_metric: The distance metric to use 
for the search. Returns: List of tuples containing the event timestamp and the document feature @@ -303,6 +312,12 @@ def retrieve_online_documents( "pgvector is not enabled in the online store configuration" ) + if distance_metric not in SUPPORTED_DISTANCE_METRICS_DICT: + raise ValueError( + f"Distance metric {distance_metric} is not supported. Supported distance metrics are {SUPPORTED_DISTANCE_METRICS_DICT.keys()}" + ) + + distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric] # Convert the embedding to a string to be used in postgres vector search query_embedding_str = f"[{','.join(str(el) for el in embedding)}]" @@ -327,13 +342,14 @@ def retrieve_online_documents( feature_name, value, vector_value, - vector_value <-> %s as distance, + vector_value {distance_metric_sql} %s as distance, event_ts FROM {table_name} WHERE feature_name = {feature_name} ORDER BY distance LIMIT {top_k}; """ ).format( + distance_metric_sql=distance_metric_sql, table_name=sql.Identifier(table_name), feature_name=sql.Literal(requested_feature), top_k=sql.Literal(top_k), diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index 67c5a931dd..2a81e37042 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -158,7 +158,7 @@ def retrieve_online_documents( table: The feature view whose feature values should be read. requested_feature: The name of the feature whose embeddings should be used for retrieval. embedding: The embeddings to use for retrieval. - top_k: The number of nearest neighbors to retrieve. + top_k: The number of documents to retrieve. Returns: object: A list of top k closest documents to the specified embedding. 
Each item in the list is a tuple diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index 6476acbcb9..2f3e30018a 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -196,12 +196,18 @@ def retrieve_online_documents( requested_feature: str, query: List[float], top_k: int, + distance_metric: str, ) -> List: set_usage_attribute("provider", self.__class__.__name__) result = [] if self.online_store: result = self.online_store.retrieve_online_documents( - config, table, requested_feature, query, top_k + config, + table, + requested_feature, + query, + top_k, + distance_metric, ) return result diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index a45051a1b6..02fba0c1f6 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -303,6 +303,7 @@ def retrieve_online_documents( requested_feature: str, query: List[float], top_k: int, + distance_metric: str = "L2", ) -> List[ Tuple[ Optional[datetime], @@ -312,14 +313,14 @@ def retrieve_online_documents( ] ]: """ - Searches for the top-k nearest neighbors of the given document in the online document store. + Searches for the top-k most similar documents in the online document store. Args: config: The config for the current feature store. table: The feature view whose embeddings should be searched. requested_feature: the requested document feature name. query: The query embedding to search for. - top_k: The number of nearest neighbors to return. + top_k: The number of documents to return. Returns: A list of dictionaries, where each dictionary contains the document feature. 
diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index 2a830d424c..f869d82e11 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -111,6 +111,7 @@ def retrieve_online_documents( requested_feature: str, query: List[float], top_k: int, + distance_metric: str, ) -> List[ Tuple[ Optional[datetime], diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index 3ae7be9e1e..5d6462e5e3 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -798,6 +798,25 @@ def test_retrieve_online_documents(environment, fake_document_data): fs.write_to_online_store("item_embeddings", df) documents = fs.retrieve_online_documents( - feature="item_embeddings:embedding_float", query=[1.0, 2.0], top_k=2 + feature="item_embeddings:embedding_float", + query=[1.0, 2.0], + top_k=2, + distance_metric="L2", ).to_dict() assert len(documents["embedding_float"]) == 2 + + documents = fs.retrieve_online_documents( + feature="item_embeddings:embedding_float", + query=[1.0, 2.0], + top_k=2, + distance_metric="L1", + ).to_dict() + assert len(documents["embedding_float"]) == 2 + + with pytest.raises(ValueError): + fs.retrieve_online_documents( + feature="item_embeddings:embedding_float", + query=[1.0, 2.0], + top_k=2, + distance_metric="wrong", + ).to_dict()