feat: ✨ Added filtering option to FAISS vectorstore #5966

Merged · 3 commits · Jun 11, 2023
Changes from 2 commits
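For orientation before the diff: this PR adds an optional `filter` argument (exact-match on document metadata) to the FAISS similarity, scored, and MMR search methods. A minimal usage sketch assuming the PR is merged; the toy embedding class below is only a deterministic stand-in for a real embedding model and is not part of the change:

```python
from typing import List

from langchain.embeddings.base import Embeddings
from langchain.vectorstores import FAISS


class ToyEmbeddings(Embeddings):
    """Deterministic stand-in for a real embedding model (illustrative only)."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # One distinct 10-dimensional vector per input text.
        return [[1.0] * 9 + [float(i)] for i in range(len(texts))]

    def embed_query(self, text: str) -> List[float]:
        # Query vector matches the first document's vector.
        return [1.0] * 9 + [0.0]


texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = FAISS.from_texts(texts, ToyEmbeddings(), metadatas=metadatas)

# New in this PR: restrict results to documents whose metadata matches the filter.
# fetch_k is how many candidates FAISS returns before the metadata filter is applied.
docs = docsearch.similarity_search("foo", k=1, filter={"page": 1}, fetch_k=20)
print(docs)  # expected: [Document(page_content="bar", metadata={"page": 1})]
```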
54 changes: 42 additions & 12 deletions langchain/vectorstores/faiss.py
@@ -180,13 +180,19 @@ def add_embeddings(
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs)

def similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
fetch_k: int = 20,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query.

Args:
embedding: Embedding vector to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
fetch_k: Number of Documents to fetch before filtering. Defaults to 20.

Returns:
List of documents most similar to the query text and L2 distance
@@ -196,7 +202,7 @@ def similarity_search_with_score_by_vector(
vector = np.array([embedding], dtype=np.float32)
if self._normalize_L2:
faiss.normalize_L2(vector)
scores, indices = self.index.search(vector, k)
scores, indices = self.index.search(vector, k if filter is None else fetch_k)
docs = []
for j, i in enumerate(indices[0]):
if i == -1:
@@ -206,11 +212,15 @@ def similarity_search_with_score_by_vector(
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
docs.append((doc, scores[0][j]))
return docs
if filter is not None:
if all(doc.metadata.get(key) == value for key, value in filter.items()):
docs.append((doc, scores[0][j]))
else:
docs.append((doc, scores[0][j]))
return docs[:k]

def similarity_search_with_score(
self, query: str, k: int = 4
self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query.

@@ -223,7 +233,7 @@ def similarity_search_with_score(
L2 distance in float. Lower score represents more similarity.
"""
embedding = self.embedding_function(query)
docs = self.similarity_search_with_score_by_vector(embedding, k)
docs = self.similarity_search_with_score_by_vector(embedding, k, **kwargs)
return docs

def similarity_search_by_vector(
@@ -238,7 +248,9 @@ def similarity_search_by_vector(
Returns:
List of Documents most similar to the embedding.
"""
docs_and_scores = self.similarity_search_with_score_by_vector(embedding, k)
docs_and_scores = self.similarity_search_with_score_by_vector(
embedding, k, **kwargs
)
return [doc for doc, _ in docs_and_scores]

def similarity_search(
@@ -253,7 +265,7 @@ def similarity_search(
Returns:
List of Documents most similar to the query.
"""
docs_and_scores = self.similarity_search_with_score(query, k)
docs_and_scores = self.similarity_search_with_score(query, k, **kwargs)
return [doc for doc, _ in docs_and_scores]

def max_marginal_relevance_search_by_vector(
@@ -262,6 +274,7 @@ def max_marginal_relevance_search_by_vector(
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@@ -272,15 +285,32 @@ def max_marginal_relevance_search_by_vector(
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
fetch_k: Number of Documents to fetch after filtering to
pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
_, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k)
_, indices = self.index.search(
np.array([embedding], dtype=np.float32),
fetch_k if filter is None else fetch_k * 2,
)
if filter is not None:
filtered_indices = []
for i in indices[0]:
if i == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
if all(doc.metadata.get(key) == value for key, value in filter.items()):
filtered_indices.append(i)
indices = np.array([filtered_indices])
# -1 happens when not enough docs are returned.
embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
mmr_selected = maximal_marginal_relevance(
@@ -328,7 +358,7 @@ def max_marginal_relevance_search(
"""
embedding = self.embedding_function(query)
docs = self.max_marginal_relevance_search_by_vector(
embedding, k, fetch_k, lambda_mult=lambda_mult
embedding, k, fetch_k, lambda_mult=lambda_mult, **kwargs
)
return docs

@@ -530,5 +560,5 @@ def _similarity_search_with_relevance_scores(
"normalize_score_fn must be provided to"
" FAISS constructor to normalize scores"
)
docs_and_scores = self.similarity_search_with_score(query, k=k)
docs_and_scores = self.similarity_search_with_score(query, k=k, **kwargs)
return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores]
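
To summarize the approach taken in the diff above: filtering happens after retrieval. FAISS is asked for more candidates than `k` (`fetch_k`, or `fetch_k * 2` on the MMR path), each candidate's metadata is compared for exact equality against every key/value pair in `filter`, and the surviving results are truncated to `k`. A standalone sketch of that post-filtering step, with illustrative names, not taken verbatim from the PR:

```python
from typing import Any, Dict, List, Optional, Tuple


def post_filter(
    candidates: List[Tuple[dict, float]],  # (metadata, score) pairs, best match first
    k: int,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[dict, float]]:
    """Keep candidates whose metadata matches every key/value in `filter`, then cut to k."""
    if filter is None:
        return candidates[:k]
    kept = [
        (metadata, score)
        for metadata, score in candidates
        if all(metadata.get(key) == value for key, value in filter.items())
    ]
    return kept[:k]


# Example: only the candidate with page == 1 survives.
print(post_filter([({"page": 0}, 0.1), ({"page": 1}, 0.4)], k=1, filter={"page": 1}))
```

One consequence of this design: if the filter is very selective, fewer than `k` documents may survive even when more matches exist beyond the first `fetch_k` hits, so callers may need to raise `fetch_k`.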
22 changes: 22 additions & 0 deletions tests/integration_tests/vectorstores/test_faiss.py
@@ -74,6 +74,28 @@ def test_faiss_with_metadatas() -> None:
assert output == [Document(page_content="foo", metadata={"page": 0})]


def test_faiss_with_metadatas_and_filter() -> None:
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
expected_docstore = InMemoryDocstore(
{
docsearch.index_to_docstore_id[0]: Document(
page_content="foo", metadata={"page": 0}
),
docsearch.index_to_docstore_id[1]: Document(
page_content="bar", metadata={"page": 1}
),
docsearch.index_to_docstore_id[2]: Document(
page_content="baz", metadata={"page": 2}
),
}
)
assert docsearch.docstore.__dict__ == expected_docstore.__dict__
output = docsearch.similarity_search("foo", k=1, filter={"page": 1})
assert output == [Document(page_content="bar", metadata={"page": 1})]


def test_faiss_search_not_found() -> None:
"""Test what happens when document is not found."""
texts = ["foo", "bar", "baz"]
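
The new integration test covers the plain similarity path. The MMR path accepts the same keyword, since `max_marginal_relevance_search` forwards `**kwargs`. A sketch of a companion test, not part of this PR, reusing the existing `FakeEmbeddings` helper and `FAISS` import from this test module:

```python
def test_faiss_mmr_with_metadata_filter() -> None:
    """Illustrative only: exercise the filter kwarg on the MMR search path."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = FAISS.from_texts(texts, FakeEmbeddings(), metadatas=metadatas)
    output = docsearch.max_marginal_relevance_search(
        "foo", k=2, fetch_k=3, filter={"page": 1}
    )
    # Every returned document must satisfy the metadata filter.
    assert all(doc.metadata["page"] == 1 for doc in output)
```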