Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Astra DB integration for API changes #9193

Merged
merged 4 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### New Features

- Add new abstractions for `LlamaDataset`s (#9165)
- Add metadata filtering and MMR mode support for `AstraDBVectorStore` (#9193)

### Breaking Changes / Deprecations

Expand All @@ -13,6 +14,7 @@
### Bug Fixes / Nits

- Use `azure_deployment` kwarg in `AzureOpenAILLM` (#9174)
- Fix the similarity scores returned by the `AstraDBVectorStore` integration (#9193)

## [0.9.8] - 2023-11-26

Expand Down
109 changes: 97 additions & 12 deletions llama_index/vector_stores/astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@

"""
import logging
from typing import Any, List, Optional, cast
from typing import Any, Dict, List, Optional, cast

from llama_index.indices.query.embedding_utils import get_top_k_mmr_embeddings
from llama_index.schema import BaseNode, MetadataMode
from llama_index.vector_stores.types import (
ExactMatchFilter,
MetadataFilters,
VectorStore,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import (
Expand All @@ -20,6 +24,7 @@

_logger = logging.getLogger(__name__)

DEFAULT_MMR_PREFETCH_FACTOR = 4.0
MAX_INSERT_BATCH_SIZE = 20


Expand Down Expand Up @@ -58,7 +63,9 @@ def __init__(
namespace: Optional[str] = None,
ttl_seconds: Optional[int] = None,
) -> None:
import_err_msg = "`astrapy` package not found, please run `pip install astrapy`"
import_err_msg = (
"`astrapy` package not found, please run `pip install --upgrade astrapy`"
)

# Try to import astrapy for use
try:
Expand Down Expand Up @@ -153,25 +160,103 @@ def client(self) -> Any:
"""Return the underlying Astra vector table object."""
return self._astra_db_collection

@staticmethod
def _query_filters_to_dict(query_filters: MetadataFilters) -> Dict[str, Any]:
    """Convert exact-match metadata filters into a flat filter dictionary.

    Each filter becomes one entry whose key is the filter key prefixed
    with ``metadata.`` and whose value is the filter value.

    Args:
        query_filters: The metadata filters attached to the query.

    Returns:
        A dictionary mapping ``metadata.<key>`` to the value to match.

    Raises:
        NotImplementedError: If any filter is not an ``ExactMatchFilter``.
    """
    requested_filters = query_filters.filters

    # Validate everything up front so nothing is built from a partly
    # supported filter set.
    for candidate in requested_filters:
        if not isinstance(candidate, ExactMatchFilter):
            raise NotImplementedError("Only `ExactMatchFilter` filters are supported")

    translated: Dict[str, Any] = {}
    for exact_filter in requested_filters:
        translated[f"metadata.{exact_filter.key}"] = exact_filter.value
    return translated

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes."""
# Get the currently available query modes
_available_query_modes = [
VectorStoreQueryMode.DEFAULT,
VectorStoreQueryMode.MMR,
]

# Reject query if not available
if query.mode not in _available_query_modes:
raise NotImplementedError(f"Query mode {query.mode} not available.")

# Get the query embedding
query_embedding = cast(List[float], query.query_embedding)

# Set the parameters accordingly
sort = {"$vector": query_embedding}
options = {"limit": query.similarity_top_k}
projection = {"$vector": 1, "$similarity": 1, "content": 1}
# Process the metadata filters as needed
if query.filters is not None:
query_metadata = self._query_filters_to_dict(query.filters)
else:
query_metadata = {}

# Get the scores depending on the query mode
if query.mode == VectorStoreQueryMode.DEFAULT:
# Call the vector_find method of AstraPy
matches = self._astra_db_collection.vector_find(
vector=query_embedding,
limit=query.similarity_top_k,
filter=query_metadata,
)

# Get the scores associated with each
top_k_scores = [match["$similarity"] for match in matches]
elif query.mode == VectorStoreQueryMode.MMR:
# Querying a larger number of vectors and then doing MMR on them.
if (
kwargs.get("mmr_prefetch_factor") is not None
and kwargs.get("mmr_prefetch_k") is not None
):
raise ValueError(
"'mmr_prefetch_factor' and 'mmr_prefetch_k' "
"cannot coexist in a call to query()"
)
else:
if kwargs.get("mmr_prefetch_k") is not None:
prefetch_k0 = int(kwargs["mmr_prefetch_k"])
else:
prefetch_k0 = int(
query.similarity_top_k
* kwargs.get("mmr_prefetch_factor", DEFAULT_MMR_PREFETCH_FACTOR)
)
# Get the most we can possibly need to fetch
prefetch_k = max(prefetch_k0, query.similarity_top_k)

# Call AstraPy to fetch them
prefetch_matches = self._astra_db_collection.vector_find(
vector=query_embedding,
limit=prefetch_k,
filter=query_metadata,
)

# Get the MMR threshold
mmr_threshold = query.mmr_threshold or kwargs.get("mmr_threshold")

# If we have found documents, we can proceed
if prefetch_matches:
pf_match_indices, pf_match_embeddings = zip(
*enumerate(match["$vector"] for match in prefetch_matches)
)
else:
pf_match_indices, pf_match_embeddings = [], []

# Create lists for the indices and embeddings
pf_match_indices = list(pf_match_indices)
pf_match_embeddings = list(pf_match_embeddings)

# Call the Llama utility function to get the top k
mmr_similarities, mmr_indices = get_top_k_mmr_embeddings(
query_embedding,
pf_match_embeddings,
similarity_top_k=query.similarity_top_k,
embedding_ids=pf_match_indices,
mmr_threshold=mmr_threshold,
)

# Call the find method of the Astra API
matches = self._astra_db_collection.find(
sort=sort, options=options, projection=projection
)["data"]["documents"]
# Finally, build the final results based on the mmr values
matches = [prefetch_matches[mmr_index] for mmr_index in mmr_indices]
top_k_scores = mmr_similarities

# We have three lists to return
top_k_nodes = []
top_k_ids = []
top_k_scores = []

# Get every match
for my_match in matches:
Expand All @@ -184,8 +269,8 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
# Append to the respective lists
top_k_nodes.append(node)
top_k_ids.append(my_match["_id"])
top_k_scores.append(my_match["$similarity"])

# return our final result
return VectorStoreQueryResult(
nodes=top_k_nodes,
similarities=top_k_scores,
Expand Down