Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for BM25Retriever in InMemoryDocumentStore #3561

Merged
merged 31 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
cf16c25
Merge remote-tracking branch 'origin/main' into imds_support_for_bm25
anakin87 Nov 10, 2022
ee89a34
very first draft
anakin87 Nov 10, 2022
3d7e8aa
implement query and query_batch
anakin87 Nov 11, 2022
5890742
add more bm25 parameters
anakin87 Nov 12, 2022
d94433f
add rank_bm25 dependency
anakin87 Nov 12, 2022
ce5efae
fix mypy
anakin87 Nov 12, 2022
432eff7
remove tokenizer callable parameter
anakin87 Nov 12, 2022
91d40ff
remove unused import
anakin87 Nov 12, 2022
2e39f05
only json serializable attributes
anakin87 Nov 12, 2022
343f1d4
try to fix: pylint too-many-public-methods / R0904
anakin87 Nov 12, 2022
514c248
bm25 attribute always present
anakin87 Nov 12, 2022
03a35a2
convert errors into warnings to make the tutorial 1 work
anakin87 Nov 12, 2022
25d6d42
add docstrings; tests
anakin87 Nov 13, 2022
707a81b
try to make tests run
anakin87 Nov 13, 2022
ac67603
better docstrings; revert not running tests
anakin87 Nov 13, 2022
2381901
some suggestions from review
anakin87 Nov 14, 2022
2d830c9
Merge remote-tracking branch 'upstream/main' into imds_support_for_bm25
anakin87 Nov 14, 2022
34fedfc
rename elasticsearch retriever as bm25 in tests; try to test memory_bm25
anakin87 Nov 14, 2022
bbd9faa
exclude tests with filters
anakin87 Nov 15, 2022
b3b2668
change elasticsearch to bm25 retriever in test_summarizer
anakin87 Nov 15, 2022
be5969e
merge; bm25_algorithm as a property
anakin87 Nov 15, 2022
47df196
Merge branch 'imds_support_for_bm25' of https://github.com/anakin87/h…
anakin87 Nov 15, 2022
d408128
add tests
anakin87 Nov 15, 2022
de970e8
Merge branch 'main' into imds_support_for_bm25
anakin87 Nov 15, 2022
2e06683
try to improve tests
anakin87 Nov 17, 2022
1ee1544
Merge branch 'main' into imds_support_for_bm25
anakin87 Nov 17, 2022
ba89540
better type hint
anakin87 Nov 17, 2022
832ef82
Merge branch 'main' into imds_support_for_bm25
anakin87 Nov 18, 2022
f64016e
adapt test_table_text_retriever_embedding
anakin87 Nov 18, 2022
99429f6
handle non-textual docs
anakin87 Nov 21, 2022
aad1970
query only textual documents
anakin87 Nov 21, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 177 additions & 3 deletions haystack/document_stores/memory.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
from typing import Any, Dict, List, Optional, Union, Generator

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal # type: ignore

import time
import logging
from copy import deepcopy
from collections import defaultdict
import re

import numpy as np
import torch
from tqdm import tqdm
import rank_bm25

from haystack.schema import Document, Label
from haystack.errors import DuplicateDocumentError
from haystack.document_stores import BaseDocumentStore
from haystack.errors import DuplicateDocumentError, DocumentStoreError
from haystack.document_stores import KeywordDocumentStore
from haystack.document_stores.base import get_batches_from_generator
from haystack.modeling.utils import initialize_device_settings
from haystack.document_stores.filter_utils import LogicalFilterClause
Expand All @@ -20,7 +27,8 @@
logger = logging.getLogger(__name__)


class InMemoryDocumentStore(BaseDocumentStore):
class InMemoryDocumentStore(KeywordDocumentStore):
# pylint: disable=R0904
ZanSara marked this conversation as resolved.
Show resolved Hide resolved
"""
In-memory document store
"""
Expand All @@ -38,6 +46,10 @@ def __init__(
use_gpu: bool = True,
scoring_batch_size: int = 500000,
devices: Optional[List[Union[str, torch.device]]] = None,
use_bm25: bool = False,
bm25_tokenization_regex: str = r"(?u)\b\w\w+\b",
bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25Okapi",
bm25_parameters: dict = {},
):
"""
:param index: The documents are scoped to an index attribute that can be used when writing, querying,
Expand Down Expand Up @@ -67,6 +79,14 @@ def __init__(
A list containing torch device objects and/or strings is supported (For example
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
parameter is not used and a single cpu device is used for inference.
:param use_bm25: Whether to build a sparse representation of documents based on BM25.
`use_bm25=True` is required to connect `BM25Retriever` to this Document Store.
:param bm25_tokenization_regex: The regular expression to use for tokenization of the text.
:param bm25_algorithm: The specific BM25 implementation to adopt.
Parameter options : ( 'BM25Okapi', 'BM25L', 'BM25Plus')
:param bm25_parameters: Parameters for BM25 implementation in a dictionary format.
For example: {'k1':1.5, 'b':0.75, 'epsilon':0.25}
You can learn more about these parameters by visiting https://github.com/dorianbrown/rank_bm25
"""
super().__init__()

Expand All @@ -81,6 +101,11 @@ def __init__(
self.duplicate_documents = duplicate_documents
self.use_gpu = use_gpu
self.scoring_batch_size = scoring_batch_size
self.use_bm25 = use_bm25
self.bm25_tokenization_regex = bm25_tokenization_regex
self.bm25_algorithm = bm25_algorithm
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
self.bm25_parameters = bm25_parameters
self.bm25: Dict[str, rank_bm25.BM25] = {}

self.devices, _ = initialize_device_settings(devices=devices, use_cuda=self.use_gpu, multi_gpu=False)
if len(self.devices) > 1:
Expand All @@ -91,6 +116,22 @@ def __init__(

self.main_device = self.devices[0]

@property
def bm25_tokenization_regex(self):
return self._tokenizer

@bm25_tokenization_regex.setter
def bm25_tokenization_regex(self, regex_string: str):
self._tokenizer = re.compile(regex_string).findall

@property
def bm25_algorithm(self):
return self._bm25_class

@bm25_algorithm.setter
def bm25_algorithm(self, algorithm: str):
self._bm25_class = getattr(rank_bm25, algorithm)

def write_documents(
self,
documents: Union[List[dict], List[Document]],
Expand Down Expand Up @@ -134,6 +175,7 @@ def write_documents(
Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents
]
documents_objects = self._drop_duplicate_documents(documents=documents_objects)
modified_documents = 0
for document in documents_objects:
if document.id in self.indexes[index]:
if duplicate_documents == "fail":
Expand All @@ -146,6 +188,23 @@ def write_documents(
)
continue
self.indexes[index][document.id] = document
modified_documents += 1

if self.use_bm25 is True and modified_documents > 0:
self.update_bm25(index=index)

def update_bm25(self, index: Optional[str] = None):
"""
Updates the BM25 sparse representation in the the document store.

:param index: Index name for which the BM25 representation is to be updated. If set to None, the default self.index is used.
"""
index = index or self.index
tokenized_corpus = [
self.bm25_tokenization_regex(doc.content.lower())
for doc in tqdm(self.get_all_documents(index=index), unit=" docs", desc="Updating BM25 representation...")
]
self.bm25[index] = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters)

def _create_document_field_map(self):
return {self.embedding_field: "embedding"}
Expand Down Expand Up @@ -759,12 +818,16 @@ def delete_documents(
index = index or self.index
if not filters and not ids:
self.indexes[index] = {}
if index in self.bm25:
self.bm25[index] = {}
return
docs_to_delete = self.get_all_documents(index=index, filters=filters)
if ids:
docs_to_delete = [doc for doc in docs_to_delete if doc.id in ids]
for doc in docs_to_delete:
del self.indexes[index][doc.id]
if self.use_bm25 is True and len(docs_to_delete) > 0:
self.update_bm25(index=index)

def delete_index(self, index: str):
"""
Expand All @@ -777,6 +840,9 @@ def delete_index(self, index: str):
del self.indexes[index]
logger.info("Index '%s' deleted.", index)

if index in self.bm25:
del self.bm25[index]

def delete_labels(
self,
index: Optional[str] = None,
Expand Down Expand Up @@ -828,3 +894,111 @@ def delete_labels(
labels_to_delete = [label for label in labels_to_delete if label.id in ids]
for label in labels_to_delete:
del self.indexes[index][label.id]

def query(
self,
query: Optional[str],
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
top_k: int = 10,
custom_query: Optional[str] = None,
index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
all_terms_must_match: bool = False,
scale_score: bool = False,
) -> List[Document]:
"""
Scan through documents in DocumentStore and return a small number documents
that are most relevant to the query as defined by the BM25 algorithm.
:param query: The query.
:param top_k: How many documents to return per query.
:param index: The name of the index in the DocumentStore from which to retrieve documents.
"""

if headers:
logger.warning("InMemoryDocumentStore does not support headers. This parameter is ignored.")
if custom_query:
logger.warning("InMemoryDocumentStore does not support custom_query. This parameter is ignored.")
if all_terms_must_match is True:
logger.warning("InMemoryDocumentStore does not support all_terms_must_match. This parameter is ignored.")
if filters:
logger.warning(
"InMemoryDocumentStore does not support filters for BM25 retrieval. This parameter is ignored."
)
if scale_score is True:
logger.warning(
"InMemoryDocumentStore does not support scale_score for BM25 retrieval. This parameter is ignored."
)

ZanSara marked this conversation as resolved.
Show resolved Hide resolved
index = index or self.index
if index not in self.bm25:
raise DocumentStoreError(
f"No BM25 representation found for the index: {index}. The Document store should be initialized with use_bm25=True"
)

if query is None:
return []

tokenized_query = self.bm25_tokenization_regex(query.lower())
docs_scores = self.bm25[index].get_scores(tokenized_query)
top_docs_positions = np.argsort(docs_scores)[::-1][:top_k]

docs_list = list(self.indexes[index].values())
top_docs = []
for i in top_docs_positions:
doc = docs_list[i]
doc.score = docs_scores[i]
top_docs.append(doc)

return top_docs

def query_batch(
self,
queries: List[str],
filters: Optional[
Union[
Dict[str, Union[Dict, List, str, int, float, bool]],
List[Dict[str, Union[Dict, List, str, int, float, bool]]],
]
] = None,
top_k: int = 10,
custom_query: Optional[str] = None,
index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
all_terms_must_match: bool = False,
scale_score: bool = False,
) -> List[List[Document]]:
"""
Scan through documents in DocumentStore and return a small number documents
that are most relevant to the provided queries as defined by keyword matching algorithms like BM25.
This method lets you find relevant documents for list of query strings (output: List of Lists of Documents).
:param query: The query.
:param top_k: How many documents to return per query.
:param index: The name of the index in the DocumentStore from which to retrieve documents.
"""

if headers:
logger.warning("InMemoryDocumentStore does not support headers. This parameter is ignored.")
if custom_query:
logger.warning("InMemoryDocumentStore does not support custom_query. This parameter is ignored.")
if all_terms_must_match is True:
logger.warning("InMemoryDocumentStore does not support all_terms_must_match. This parameter is ignored.")
if filters:
logger.warning(
"InMemoryDocumentStore does not support filters for BM25 retrieval. This parameter is ignored."
)
if scale_score is True:
logger.warning(
"InMemoryDocumentStore does not support scale_score for BM25 retrieval. This parameter is ignored."
)

index = index or self.index
if index not in self.bm25:
raise DocumentStoreError(
f"No BM25 representation found for the index: {index}. The Document store should be initialized with use_bm25=True"
)

result_documents = []
for query in queries:
result_documents.append(self.query(query=query, top_k=top_k, index=index))

return result_documents
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ dependencies = [
"transformers==4.21.2",
"nltk",
"pandas",
"rank_bm25",

# Utils
"dill", # pickle extension for (de-)serialization
Expand Down
5 changes: 3 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def indexing_document_classifier():
)


@pytest.fixture(params=["es_filter_only", "elasticsearch", "dpr", "embedding", "tfidf", "table_text_retriever"])
@pytest.fixture(params=["es_filter_only", "bm25", "dpr", "embedding", "tfidf", "table_text_retriever"])
def retriever(request, document_store):
return get_retriever(request.param, document_store)

Expand Down Expand Up @@ -753,7 +753,7 @@ def get_retriever(retriever_type, document_store):
use_gpu=False,
embed_title=True,
)
elif retriever_type == "elasticsearch":
elif retriever_type == "bm25":
retriever = BM25Retriever(document_store=document_store)
elif retriever_type == "es_filter_only":
retriever = FilterRetriever(document_store=document_store)
Expand Down Expand Up @@ -941,6 +941,7 @@ def get_document_store(
embedding_field=embedding_field,
index=index,
similarity=similarity,
use_bm25=True,
)

elif document_store_type == "elasticsearch":
Expand Down
30 changes: 30 additions & 0 deletions test/document_stores/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@

import numpy as np
import pandas as pd
from rank_bm25 import BM25
import pytest
from unittest.mock import Mock


from ..conftest import get_document_store, ensure_ids_are_correct_uuids
from haystack.document_stores import (
InMemoryDocumentStore,
WeaviateDocumentStore,
MilvusDocumentStore,
FAISSDocumentStore,
Expand Down Expand Up @@ -1444,3 +1446,31 @@ def test_normalize_embeddings_diff_shapes():
VEC_1 = np.array([0.1, 0.2, 0.3], dtype="float32").reshape(1, -1)
BaseDocumentStore.normalize_embedding(VEC_1)
assert np.linalg.norm(VEC_1) - 1 < 0.01


def test_memory_update_bm25():
ds = InMemoryDocumentStore(use_bm25=False)
ds.write_documents(DOCUMENTS)
ds.update_bm25()
bm25_representation = ds.bm25[ds.index]
assert isinstance(bm25_representation, BM25)
assert bm25_representation.corpus_size == ds.get_document_count()


@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
def test_memory_query(document_store_with_docs):
query_text = "Rome"
docs = document_store_with_docs.query(query=query_text, top_k=1)
assert len(docs) == 1
assert docs[0].content == "My name is Matteo and I live in Rome"


@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
def test_memory_query_batch(document_store_with_docs):
query_texts = ["Paris", "Madrid"]
docs = document_store_with_docs.query_batch(queries=query_texts, top_k=5)
assert len(docs) == 2
assert len(docs[0]) == 5
assert docs[0][0].content == "My name is Christelle and I live in Paris"
assert len(docs[1]) == 5
assert docs[1][0].content == "My name is Camila and I live in Madrid"
Loading