Skip to content

Commit

Permalink
Update Astra DB integration for API changes (#9193)
Browse files Browse the repository at this point in the history
* Update Astra DB integration for API changes

* Update astra.py

* Fix issue in specification of metadata filter
  • Loading branch information
erichare authored Nov 29, 2023
1 parent 6efaaeb commit e1d513d
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 12 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### New Features

- Add new abstractions for `LlamaDataset`'s (#9165)
- Add metadata filtering and MMR mode support for `AstraDBVectorStore` (#9193)

### Breaking Changes / Deprecations

Expand All @@ -13,6 +14,7 @@
### Bug Fixes / Nits

- Use `azure_deployment` kwarg in `AzureOpenAILLM` (#9174)
- Fix similarity score return for `AstraDBVectorStore` Integration (#9193)

## [0.9.8] - 2023-11-26

Expand Down
109 changes: 97 additions & 12 deletions llama_index/vector_stores/astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
"""
import logging
from typing import Any, List, Optional, cast
from typing import Any, Dict, List, Optional, cast

from llama_index.indices.query.embedding_utils import get_top_k_mmr_embeddings
from llama_index.schema import BaseNode, MetadataMode
from llama_index.vector_stores.types import (
ExactMatchFilter,
MetadataFilters,
VectorStore,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import (
Expand All @@ -20,6 +24,7 @@

_logger = logging.getLogger(__name__)

DEFAULT_MMR_PREFETCH_FACTOR = 4.0
MAX_INSERT_BATCH_SIZE = 20


Expand Down Expand Up @@ -58,7 +63,9 @@ def __init__(
namespace: Optional[str] = None,
ttl_seconds: Optional[int] = None,
) -> None:
import_err_msg = "`astrapy` package not found, please run `pip install astrapy`"
import_err_msg = (
"`astrapy` package not found, please run `pip install --upgrade astrapy`"
)

# Try to import astrapy for use
try:
Expand Down Expand Up @@ -153,25 +160,103 @@ def client(self) -> Any:
"""Return the underlying Astra vector table object."""
return self._astra_db_collection

@staticmethod
def _query_filters_to_dict(query_filters: MetadataFilters) -> Dict[str, Any]:
    """Translate metadata filters into a flat filter dictionary.

    Each filter contributes one entry whose key is the filter's key
    prefixed with ``metadata.`` and whose value is the filter's value.

    Args:
        query_filters: The metadata filters attached to the query.

    Returns:
        A mapping of ``metadata.<key>`` to the value to match.

    Raises:
        NotImplementedError: If any filter is not an ``ExactMatchFilter``.
    """
    # Validate everything up front so no partial result is ever built.
    for candidate in query_filters.filters:
        if not isinstance(candidate, ExactMatchFilter):
            raise NotImplementedError("Only `ExactMatchFilter` filters are supported")

    filter_dict: Dict[str, Any] = {}
    for entry in query_filters.filters:
        filter_dict[f"metadata.{entry.key}"] = entry.value
    return filter_dict

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes."""
# Get the currently available query modes
_available_query_modes = [
VectorStoreQueryMode.DEFAULT,
VectorStoreQueryMode.MMR,
]

# Reject query if not available
if query.mode not in _available_query_modes:
raise NotImplementedError(f"Query mode {query.mode} not available.")

# Get the query embedding
query_embedding = cast(List[float], query.query_embedding)

# Set the parameters accordingly
sort = {"$vector": query_embedding}
options = {"limit": query.similarity_top_k}
projection = {"$vector": 1, "$similarity": 1, "content": 1}
# Process the metadata filters as needed
if query.filters is not None:
query_metadata = self._query_filters_to_dict(query.filters)
else:
query_metadata = {}

# Get the scores depending on the query mode
if query.mode == VectorStoreQueryMode.DEFAULT:
# Call the vector_find method of AstraPy
matches = self._astra_db_collection.vector_find(
vector=query_embedding,
limit=query.similarity_top_k,
filter=query_metadata,
)

# Get the scores associated with each
top_k_scores = [match["$similarity"] for match in matches]
elif query.mode == VectorStoreQueryMode.MMR:
# Querying a larger number of vectors and then doing MMR on them.
if (
kwargs.get("mmr_prefetch_factor") is not None
and kwargs.get("mmr_prefetch_k") is not None
):
raise ValueError(
"'mmr_prefetch_factor' and 'mmr_prefetch_k' "
"cannot coexist in a call to query()"
)
else:
if kwargs.get("mmr_prefetch_k") is not None:
prefetch_k0 = int(kwargs["mmr_prefetch_k"])
else:
prefetch_k0 = int(
query.similarity_top_k
* kwargs.get("mmr_prefetch_factor", DEFAULT_MMR_PREFETCH_FACTOR)
)
# Get the most we can possibly need to fetch
prefetch_k = max(prefetch_k0, query.similarity_top_k)

# Call AstraPy to fetch them
prefetch_matches = self._astra_db_collection.vector_find(
vector=query_embedding,
limit=prefetch_k,
filter=query_metadata,
)

# Get the MMR threshold
mmr_threshold = query.mmr_threshold or kwargs.get("mmr_threshold")

# If we have found documents, we can proceed
if prefetch_matches:
pf_match_indices, pf_match_embeddings = zip(
*enumerate(match["$vector"] for match in prefetch_matches)
)
else:
pf_match_indices, pf_match_embeddings = [], []

# Create lists for the indices and embeddings
pf_match_indices = list(pf_match_indices)
pf_match_embeddings = list(pf_match_embeddings)

# Call the Llama utility function to get the top k
mmr_similarities, mmr_indices = get_top_k_mmr_embeddings(
query_embedding,
pf_match_embeddings,
similarity_top_k=query.similarity_top_k,
embedding_ids=pf_match_indices,
mmr_threshold=mmr_threshold,
)

# Call the find method of the Astra API
matches = self._astra_db_collection.find(
sort=sort, options=options, projection=projection
)["data"]["documents"]
# Finally, build the final results based on the mmr values
matches = [prefetch_matches[mmr_index] for mmr_index in mmr_indices]
top_k_scores = mmr_similarities

# We have three lists to return
top_k_nodes = []
top_k_ids = []
top_k_scores = []

# Get every match
for my_match in matches:
Expand All @@ -184,8 +269,8 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
# Append to the respective lists
top_k_nodes.append(node)
top_k_ids.append(my_match["_id"])
top_k_scores.append(my_match["$similarity"])

# return our final result
return VectorStoreQueryResult(
nodes=top_k_nodes,
similarities=top_k_scores,
Expand Down

0 comments on commit e1d513d

Please sign in to comment.