Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enable other distance metrics for Vector DB and Update docs #4170

Merged
merged 8 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions docs/reference/online-stores/postgres.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,16 @@ To compare this set of functionality against other online stores, please see the
## PGVector
The Postgres online store supports the use of [PGVector](https://github.com/pgvector/pgvector) for storing feature values.
To enable PGVector, set `pgvector_enabled: true` in the online store configuration.
To enable PGVector, set `pgvector_enabled: true` in the online store configuration.

The `vector_len` parameter can be used to specify the length of the vector. The default value is 512.

Then you can use `retrieve_online_documents` to retrieve the top k closest vectors to a query vector.
Please make sure to follow the instructions in the repository, which, as the time of this writing, requires you to
run `CREATE EXTENSION vector;` in the database.


Then you can use `retrieve_online_documents` to retrieve the top k closest vectors to a query vector.
For the Retrieval Augmented Generation (RAG) use-case, you have to embed the query prior to passing the query vector.

{% code title="python" %}
```python
Expand Down
7 changes: 7 additions & 0 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,6 +1740,7 @@ def retrieve_online_documents(
feature: str,
query: Union[str, List[float]],
top_k: int,
distance_metric: str,
) -> OnlineResponse:
"""
Retrieves the top k closest document features. Note, embeddings are a subset of features.
Expand All @@ -1750,18 +1751,21 @@ def retrieve_online_documents(
references must have format "feature_view:feature", e.g, "document_fv:document_embeddings".
query: The query to retrieve the closest document features for.
top_k: The number of closest document features to retrieve.
distance_metric: The distance metric to use for retrieval.
"""
return self._retrieve_online_documents(
feature=feature,
query=query,
top_k=top_k,
distance_metric=distance_metric,
)

def _retrieve_online_documents(
self,
feature: str,
query: Union[str, List[float]],
top_k: int,
distance_metric: str = "L2",
):
if isinstance(query, str):
raise ValueError(
Expand All @@ -1783,6 +1787,7 @@ def _retrieve_online_documents(
requested_feature,
query,
top_k,
distance_metric,
)

# TODO Refactor to better way of populating result
Expand Down Expand Up @@ -2025,6 +2030,7 @@ def _retrieve_from_online_store(
requested_feature: str,
query: List[float],
top_k: int,
distance_metric: str,
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value, Value]]:
"""
Search and return document features from the online document store.
Expand All @@ -2035,6 +2041,7 @@ def _retrieve_from_online_store(
requested_feature=requested_feature,
query=query,
top_k=top_k,
distance_metric=distance_metric,
)

read_row_protos = []
Expand Down
18 changes: 17 additions & 1 deletion sdk/python/feast/infra/online_stores/contrib/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
from feast.repo_config import RepoConfig
from feast.usage import log_exceptions_and_usage

SUPPORTED_DISTANCE_METRICS_DICT = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's cool!

"cosine": "<=>",
"L1": "<+>",
"L2": "<->",
"inner_product": "<#>",
}


class PostgreSQLOnlineStoreConfig(PostgreSQLConfig):
type: Literal["postgres"] = "postgres"
Expand Down Expand Up @@ -276,6 +283,7 @@ def retrieve_online_documents(
requested_feature: str,
embedding: List[float],
top_k: int,
distance_metric: str = "L2",
) -> List[
Tuple[
Optional[datetime],
Expand All @@ -292,6 +300,7 @@ def retrieve_online_documents(
requested_feature: The requested feature as the column to search
embedding: The query embedding to search for
top_k: The number of items to return
distance_metric: The distance metric to use for the search.G
Returns:
List of tuples containing the event timestamp and the document feature

Expand All @@ -303,6 +312,12 @@ def retrieve_online_documents(
"pgvector is not enabled in the online store configuration"
)

if distance_metric not in SUPPORTED_DISTANCE_METRICS_DICT:
raise ValueError(
f"Distance metric {distance_metric} is not supported. Supported distance metrics are {SUPPORTED_DISTANCE_METRICS_DICT.keys()}"
)

distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric]
# Convert the embedding to a string to be used in postgres vector search
query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"

Expand All @@ -327,13 +342,14 @@ def retrieve_online_documents(
feature_name,
value,
vector_value,
vector_value <-> %s as distance,
vector_value {distance_metric_sql} %s as distance,
event_ts FROM {table_name}
WHERE feature_name = {feature_name}
ORDER BY distance
LIMIT {top_k};
"""
).format(
distance_metric_sql=distance_metric_sql,
table_name=sql.Identifier(table_name),
feature_name=sql.Literal(requested_feature),
top_k=sql.Literal(top_k),
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def retrieve_online_documents(
table: The feature view whose feature values should be read.
requested_feature: The name of the feature whose embeddings should be used for retrieval.
embedding: The embeddings to use for retrieval.
top_k: The number of nearest neighbors to retrieve.
top_k: The number of documents to retrieve.

Returns:
object: A list of top k closest documents to the specified embedding. Each item in the list is a tuple
Expand Down
8 changes: 7 additions & 1 deletion sdk/python/feast/infra/passthrough_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,18 @@ def retrieve_online_documents(
requested_feature: str,
query: List[float],
top_k: int,
distance_metric: str,
) -> List:
set_usage_attribute("provider", self.__class__.__name__)
result = []
if self.online_store:
result = self.online_store.retrieve_online_documents(
config, table, requested_feature, query, top_k
config,
table,
requested_feature,
query,
top_k,
distance_metric,
)
return result

Expand Down
5 changes: 3 additions & 2 deletions sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def retrieve_online_documents(
requested_feature: str,
query: List[float],
top_k: int,
distance_metric: str = "L2",
) -> List[
Tuple[
Optional[datetime],
Expand All @@ -312,14 +313,14 @@ def retrieve_online_documents(
]
]:
"""
Searches for the top-k nearest neighbors of the given document in the online document store.
Searches for the top-k most similar documents in the online document store.

Args:
config: The config for the current feature store.
table: The feature view whose embeddings should be searched.
requested_feature: the requested document feature name.
query: The query embedding to search for.
top_k: The number of nearest neighbors to return.
top_k: The number of documents to return.

Returns:
A list of dictionaries, where each dictionary contains the document feature.
Expand Down
1 change: 1 addition & 0 deletions sdk/python/tests/foo_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def retrieve_online_documents(
requested_feature: str,
query: List[float],
top_k: int,
distance_metric: str,
) -> List[
Tuple[
Optional[datetime],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,25 @@ def test_retrieve_online_documents(environment, fake_document_data):
fs.write_to_online_store("item_embeddings", df)

documents = fs.retrieve_online_documents(
feature="item_embeddings:embedding_float", query=[1.0, 2.0], top_k=2
feature="item_embeddings:embedding_float",
query=[1.0, 2.0],
top_k=2,
distance_metric="L2",
).to_dict()
assert len(documents["embedding_float"]) == 2

documents = fs.retrieve_online_documents(
feature="item_embeddings:embedding_float",
query=[1.0, 2.0],
top_k=2,
distance_metric="L1",
).to_dict()
assert len(documents["embedding_float"]) == 2

with pytest.raises(ValueError):
fs.retrieve_online_documents(
feature="item_embeddings:embedding_float",
query=[1.0, 2.0],
top_k=2,
distance_metric="wrong",
).to_dict()
Loading