Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add batching to Qdrant #5443

Merged
merged 3 commits into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 48 additions & 30 deletions langchain/vectorstores/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid
import warnings
from hashlib import md5
from itertools import islice
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -158,6 +159,7 @@ def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
batch_size: int = 64,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.
Expand All @@ -171,24 +173,30 @@ def add_texts(
"""
from qdrant_client.http import models as rest

texts = list(
texts
) # otherwise iterable might be exhausted after id calculation
ids = [md5(text.encode("utf-8")).hexdigest() for text in texts]

self.client.upsert(
collection_name=self.collection_name,
points=rest.Batch.construct(
ids=ids,
vectors=self._embed_texts(texts),
payloads=self._build_payloads(
texts,
metadatas,
self.content_payload_key,
self.metadata_payload_key,
ids = []
texts_iterator = iter(texts)
metadatas_iterator = iter(metadatas or [])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None

batch_ids = [md5(text.encode("utf-8")).hexdigest() for text in batch_texts]

self.client.upsert(
collection_name=self.collection_name,
points=rest.Batch.construct(
ids=batch_ids,
vectors=self._embed_texts(batch_texts),
payloads=self._build_payloads(
batch_texts,
batch_metadatas,
self.content_payload_key,
self.metadata_payload_key,
),
),
),
)
)

ids.extend(batch_ids)

return ids

Expand Down Expand Up @@ -309,6 +317,7 @@ def from_texts(
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
batch_size: int = 64,
**kwargs: Any,
) -> Qdrant:
"""Construct Qdrant wrapper from a list of texts.
Expand Down Expand Up @@ -361,7 +370,7 @@ def from_texts(
**kwargs:
Additional arguments passed directly into REST client initialization

This is a user friendly interface that:
This is a user-friendly interface that:
1. Creates embeddings, one for each text
2. Initializes the Qdrant database as an in-memory docstore by default
(and overridable to a remote docstore)
Expand Down Expand Up @@ -417,19 +426,28 @@ def from_texts(
),
)

# Now generate the embeddings for all the texts
embeddings = embedding.embed_documents(texts)

client.upsert(
collection_name=collection_name,
points=rest.Batch.construct(
ids=[md5(text.encode("utf-8")).hexdigest() for text in texts],
vectors=embeddings,
payloads=cls._build_payloads(
texts, metadatas, content_payload_key, metadata_payload_key
texts_iterator = iter(texts)
metadatas_iterator = iter(metadatas or [])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None

# Generate the embeddings for all the texts in a batch
batch_embeddings = embedding.embed_documents(batch_texts)

client.upsert(
collection_name=collection_name,
points=rest.Batch.construct(
ids=[md5(text.encode("utf-8")).hexdigest() for text in batch_texts],
vectors=batch_embeddings,
payloads=cls._build_payloads(
batch_texts,
batch_metadatas,
content_payload_key,
metadata_payload_key,
),
),
),
)
)

return cls(
client=client,
Expand Down
25 changes: 25 additions & 0 deletions tests/integration_tests/vectorstores/fake_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,28 @@ def embed_query(self, text: str) -> List[float]:
Distance to each text will be that text's index,
as it was passed to embed_documents."""
return [float(1.0)] * 9 + [float(0.0)]


class ConsistentFakeEmbeddings(FakeEmbeddings):
"""Fake embeddings which remember all the texts seen so far to return consistent
vectors for the same texts."""

def __init__(self) -> None:
self.known_texts: List[str] = []

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return consistent embeddings for each text seen so far."""
out_vectors = []
for text in texts:
if text not in self.known_texts:
self.known_texts.append(text)
vector = [float(1.0)] * 9 + [float(self.known_texts.index(text))]
out_vectors.append(vector)
return out_vectors

def embed_query(self, text: str) -> List[float]:
"""Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown."""
if text not in self.known_texts:
return [float(1.0)] * 9 + [float(0.0)]
return [float(1.0)] * 9 + [float(self.known_texts.index(text))]
69 changes: 50 additions & 19 deletions tests/integration_tests/vectorstores/test_qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import Qdrant
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
)


@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize(
["content_payload_key", "metadata_payload_key"],
[
Expand All @@ -18,36 +21,59 @@
("foo", Qdrant.METADATA_KEY),
],
)
def test_qdrant(content_payload_key: str, metadata_payload_key: str) -> None:
def test_qdrant_similarity_search(
batch_size: int, content_payload_key: str, metadata_payload_key: str
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch = Qdrant.from_texts(
texts,
FakeEmbeddings(),
ConsistentFakeEmbeddings(),
location=":memory:",
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]


def test_qdrant_add_documents() -> None:
@pytest.mark.parametrize("batch_size", [1, 64])
def test_qdrant_add_documents(batch_size: int) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch: Qdrant = Qdrant.from_texts(texts, FakeEmbeddings(), location=":memory:")
docsearch: Qdrant = Qdrant.from_texts(
texts, ConsistentFakeEmbeddings(), location=":memory:", batch_size=batch_size
)

new_texts = ["foobar", "foobaz"]
docsearch.add_documents([Document(page_content=content) for content in new_texts])
docsearch.add_documents(
[Document(page_content=content) for content in new_texts], batch_size=batch_size
)
output = docsearch.similarity_search("foobar", k=1)
# FakeEmbeddings return the same query embedding as the first document embedding
# computed in `embedding.embed_documents`. Since embed_documents is called twice,
# "foo" embedding is the same as "foobar" embedding
# StatefulFakeEmbeddings return the same query embedding as the first document
# embedding computed in `embedding.embed_documents`. Thus, "foo" embedding is the
# same as "foobar" embedding
assert output == [Document(page_content="foobar")] or output == [
Document(page_content="foo")
]


@pytest.mark.parametrize("batch_size", [1, 64])
def test_qdrant_add_texts_returns_all_ids(batch_size: int) -> None:
docsearch: Qdrant = Qdrant.from_texts(
["foobar"],
ConsistentFakeEmbeddings(),
location=":memory:",
batch_size=batch_size,
)

ids = docsearch.add_texts(["foo", "bar", "baz"])
assert 3 == len(ids)
assert 3 == len(set(ids))


@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize(
["content_payload_key", "metadata_payload_key"],
[
Expand All @@ -58,24 +84,26 @@ def test_qdrant_add_documents() -> None:
],
)
def test_qdrant_with_metadatas(
content_payload_key: str, metadata_payload_key: str
batch_size: int, content_payload_key: str, metadata_payload_key: str
) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = Qdrant.from_texts(
texts,
FakeEmbeddings(),
ConsistentFakeEmbeddings(),
metadatas=metadatas,
location=":memory:",
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": 0})]


def test_qdrant_similarity_search_filters() -> None:
@pytest.mark.parametrize("batch_size", [1, 64])
def test_qdrant_similarity_search_filters(batch_size: int) -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
metadatas = [
Expand All @@ -84,9 +112,10 @@ def test_qdrant_similarity_search_filters() -> None:
]
docsearch = Qdrant.from_texts(
texts,
FakeEmbeddings(),
ConsistentFakeEmbeddings(),
metadatas=metadatas,
location=":memory:",
batch_size=batch_size,
)

output = docsearch.similarity_search(
Expand All @@ -100,6 +129,7 @@ def test_qdrant_similarity_search_filters() -> None:
]


@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize(
["content_payload_key", "metadata_payload_key"],
[
Expand All @@ -110,18 +140,19 @@ def test_qdrant_similarity_search_filters() -> None:
],
)
def test_qdrant_max_marginal_relevance_search(
content_payload_key: str, metadata_payload_key: str
batch_size: int, content_payload_key: str, metadata_payload_key: str
) -> None:
"""Test end to end construction and MRR search."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = Qdrant.from_texts(
texts,
FakeEmbeddings(),
ConsistentFakeEmbeddings(),
metadatas=metadatas,
location=":memory:",
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
batch_size=batch_size,
)
output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
assert output == [
Expand All @@ -133,9 +164,9 @@ def test_qdrant_max_marginal_relevance_search(
@pytest.mark.parametrize(
["embeddings", "embedding_function"],
[
(FakeEmbeddings(), None),
(FakeEmbeddings().embed_query, None),
(None, FakeEmbeddings().embed_query),
(ConsistentFakeEmbeddings(), None),
(ConsistentFakeEmbeddings().embed_query, None),
(None, ConsistentFakeEmbeddings().embed_query),
],
)
def test_qdrant_embedding_interface(
Expand All @@ -157,7 +188,7 @@ def test_qdrant_embedding_interface(
@pytest.mark.parametrize(
["embeddings", "embedding_function"],
[
(FakeEmbeddings(), FakeEmbeddings().embed_query),
(ConsistentFakeEmbeddings(), ConsistentFakeEmbeddings().embed_query),
(None, None),
],
)
Expand Down