diff --git a/haystack/document_stores/memory.py b/haystack/document_stores/memory.py index dd61bd9286..5d893cd00c 100644 --- a/haystack/document_stores/memory.py +++ b/haystack/document_stores/memory.py @@ -1,17 +1,24 @@ from typing import Any, Dict, List, Optional, Union, Generator +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal # type: ignore + import time import logging from copy import deepcopy from collections import defaultdict +import re import numpy as np import torch from tqdm import tqdm +import rank_bm25 from haystack.schema import Document, Label -from haystack.errors import DuplicateDocumentError -from haystack.document_stores import BaseDocumentStore +from haystack.errors import DuplicateDocumentError, DocumentStoreError +from haystack.document_stores import KeywordDocumentStore from haystack.document_stores.base import get_batches_from_generator from haystack.modeling.utils import initialize_device_settings from haystack.document_stores.filter_utils import LogicalFilterClause @@ -20,7 +27,8 @@ logger = logging.getLogger(__name__) -class InMemoryDocumentStore(BaseDocumentStore): +class InMemoryDocumentStore(KeywordDocumentStore): + # pylint: disable=R0904 """ In-memory document store """ @@ -38,6 +46,10 @@ def __init__( use_gpu: bool = True, scoring_batch_size: int = 500000, devices: Optional[List[Union[str, torch.device]]] = None, + use_bm25: bool = False, + bm25_tokenization_regex: str = r"(?u)\b\w\w+\b", + bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25Okapi", + bm25_parameters: dict = {}, ): """ :param index: The documents are scoped to an index attribute that can be used when writing, querying, @@ -67,6 +79,14 @@ def __init__( A list containing torch device objects and/or strings is supported (For example [torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices parameter is not used and a single cpu device is used for inference. + :param use_bm25: Whether to build a sparse representation of documents based on BM25. + `use_bm25=True` is required to connect `BM25Retriever` to this Document Store. + :param bm25_tokenization_regex: The regular expression to use for tokenization of the text. + :param bm25_algorithm: The specific BM25 implementation to adopt. + Parameter options : ( 'BM25Okapi', 'BM25L', 'BM25Plus') + :param bm25_parameters: Parameters for BM25 implementation in a dictionary format. + For example: {'k1':1.5, 'b':0.75, 'epsilon':0.25} + You can learn more about these parameters by visiting https://github.com/dorianbrown/rank_bm25 """ super().__init__() @@ -81,6 +101,11 @@ def __init__( self.duplicate_documents = duplicate_documents self.use_gpu = use_gpu self.scoring_batch_size = scoring_batch_size + self.use_bm25 = use_bm25 + self.bm25_tokenization_regex = bm25_tokenization_regex + self.bm25_algorithm = bm25_algorithm + self.bm25_parameters = bm25_parameters + self.bm25: Dict[str, rank_bm25.BM25] = {} self.devices, _ = initialize_device_settings(devices=devices, use_cuda=self.use_gpu, multi_gpu=False) if len(self.devices) > 1: @@ -91,6 +116,22 @@ def __init__( self.main_device = self.devices[0] + @property + def bm25_tokenization_regex(self): + return self._tokenizer + + @bm25_tokenization_regex.setter + def bm25_tokenization_regex(self, regex_string: str): + self._tokenizer = re.compile(regex_string).findall + + @property + def bm25_algorithm(self): + return self._bm25_class + + @bm25_algorithm.setter + def bm25_algorithm(self, algorithm: str): + self._bm25_class = getattr(rank_bm25, algorithm) + def write_documents( self, documents: Union[List[dict], List[Document]], @@ -134,6 +175,7 @@ def write_documents( Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents ] documents_objects = self._drop_duplicate_documents(documents=documents_objects) + modified_documents = 0 for document in documents_objects: if document.id in self.indexes[index]: if duplicate_documents == "fail": @@ -146,6 +188,32 @@ def write_documents( ) continue self.indexes[index][document.id] = document + modified_documents += 1 + + if self.use_bm25 is True and modified_documents > 0: + self.update_bm25(index=index) + + def update_bm25(self, index: Optional[str] = None): + """ + Updates the BM25 sparse representation in the the document store. + + :param index: Index name for which the BM25 representation is to be updated. If set to None, the default self.index is used. + """ + index = index or self.index + + all_documents = self.get_all_documents(index=index) + textual_documents = [doc for doc in all_documents if doc.content_type == "text"] + if len(textual_documents) < len(all_documents): + logger.warning( + f"Some documents in {index} index are non-textual." + f" They will be written to the index, but the corresponding BM25 representations will not be generated." + ) + + tokenized_corpus = [ + self.bm25_tokenization_regex(doc.content.lower()) + for doc in tqdm(textual_documents, unit=" docs", desc="Updating BM25 representation...") + ] + self.bm25[index] = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters) def _create_document_field_map(self): return {self.embedding_field: "embedding"} @@ -759,12 +827,16 @@ def delete_documents( index = index or self.index if not filters and not ids: self.indexes[index] = {} + if index in self.bm25: + self.bm25[index] = {} return docs_to_delete = self.get_all_documents(index=index, filters=filters) if ids: docs_to_delete = [doc for doc in docs_to_delete if doc.id in ids] for doc in docs_to_delete: del self.indexes[index][doc.id] + if self.use_bm25 is True and len(docs_to_delete) > 0: + self.update_bm25(index=index) def delete_index(self, index: str): """ @@ -777,6 +849,9 @@ def delete_index(self, index: str): del self.indexes[index] logger.info("Index '%s' deleted.", index) + if index in self.bm25: + del self.bm25[index] + def delete_labels( self, index: Optional[str] = None, @@ -828,3 +903,111 @@ def delete_labels( labels_to_delete = [label for label in labels_to_delete if label.id in ids] for label in labels_to_delete: del self.indexes[index][label.id] + + def query( + self, + query: Optional[str], + filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None, + top_k: int = 10, + custom_query: Optional[str] = None, + index: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + all_terms_must_match: bool = False, + scale_score: bool = False, + ) -> List[Document]: + """ + Scan through documents in DocumentStore and return a small number documents + that are most relevant to the query as defined by the BM25 algorithm. + :param query: The query. + :param top_k: How many documents to return per query. + :param index: The name of the index in the DocumentStore from which to retrieve documents. + """ + + if headers: + logger.warning("InMemoryDocumentStore does not support headers. This parameter is ignored.") + if custom_query: + logger.warning("InMemoryDocumentStore does not support custom_query. This parameter is ignored.") + if all_terms_must_match is True: + logger.warning("InMemoryDocumentStore does not support all_terms_must_match. This parameter is ignored.") + if filters: + logger.warning( + "InMemoryDocumentStore does not support filters for BM25 retrieval. This parameter is ignored." + ) + if scale_score is True: + logger.warning( + "InMemoryDocumentStore does not support scale_score for BM25 retrieval. This parameter is ignored." + ) + + index = index or self.index + if index not in self.bm25: + raise DocumentStoreError( + f"No BM25 representation found for the index: {index}. The Document store should be initialized with use_bm25=True" + ) + + if query is None: + return [] + + tokenized_query = self.bm25_tokenization_regex(query.lower()) + docs_scores = self.bm25[index].get_scores(tokenized_query) + top_docs_positions = np.argsort(docs_scores)[::-1][:top_k] + + textual_docs_list = [doc for doc in self.indexes[index].values() if doc.content_type == "text"] + top_docs = [] + for i in top_docs_positions: + doc = textual_docs_list[i] + doc.score = docs_scores[i] + top_docs.append(doc) + + return top_docs + + def query_batch( + self, + queries: List[str], + filters: Optional[ + Union[ + Dict[str, Union[Dict, List, str, int, float, bool]], + List[Dict[str, Union[Dict, List, str, int, float, bool]]], + ] + ] = None, + top_k: int = 10, + custom_query: Optional[str] = None, + index: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + all_terms_must_match: bool = False, + scale_score: bool = False, + ) -> List[List[Document]]: + """ + Scan through documents in DocumentStore and return a small number documents + that are most relevant to the provided queries as defined by keyword matching algorithms like BM25. + This method lets you find relevant documents for list of query strings (output: List of Lists of Documents). + :param query: The query. + :param top_k: How many documents to return per query. + :param index: The name of the index in the DocumentStore from which to retrieve documents. + """ + + if headers: + logger.warning("InMemoryDocumentStore does not support headers. This parameter is ignored.") + if custom_query: + logger.warning("InMemoryDocumentStore does not support custom_query. This parameter is ignored.") + if all_terms_must_match is True: + logger.warning("InMemoryDocumentStore does not support all_terms_must_match. This parameter is ignored.") + if filters: + logger.warning( + "InMemoryDocumentStore does not support filters for BM25 retrieval. This parameter is ignored." + ) + if scale_score is True: + logger.warning( + "InMemoryDocumentStore does not support scale_score for BM25 retrieval. This parameter is ignored." + ) + + index = index or self.index + if index not in self.bm25: + raise DocumentStoreError( + f"No BM25 representation found for the index: {index}. The Document store should be initialized with use_bm25=True" + ) + + result_documents = [] + for query in queries: + result_documents.append(self.query(query=query, top_k=top_k, index=index)) + + return result_documents diff --git a/pyproject.toml b/pyproject.toml index 6d78d098b5..4b90e689a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ dependencies = [ "transformers==4.21.2", "nltk", "pandas", + "rank_bm25", # Utils "dill", # pickle extension for (de-)serialization diff --git a/test/conftest.py b/test/conftest.py index 4c8ea38198..2017b37b70 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -687,7 +687,7 @@ def indexing_document_classifier(): ) -@pytest.fixture(params=["es_filter_only", "elasticsearch", "dpr", "embedding", "tfidf", "table_text_retriever"]) +@pytest.fixture(params=["es_filter_only", "bm25", "dpr", "embedding", "tfidf", "table_text_retriever"]) def retriever(request, document_store): return get_retriever(request.param, document_store) @@ -753,7 +753,7 @@ def get_retriever(retriever_type, document_store): use_gpu=False, embed_title=True, ) - elif retriever_type == "elasticsearch": + elif retriever_type == "bm25": retriever = BM25Retriever(document_store=document_store) elif retriever_type == "es_filter_only": retriever = FilterRetriever(document_store=document_store) @@ -941,6 +941,7 @@ def get_document_store( embedding_field=embedding_field, index=index, similarity=similarity, + use_bm25=True, ) elif document_store_type == "elasticsearch": diff --git a/test/document_stores/test_document_store.py b/test/document_stores/test_document_store.py index 26a5cbacaa..a4acbbb312 100644 --- a/test/document_stores/test_document_store.py +++ b/test/document_stores/test_document_store.py @@ -5,12 +5,14 @@ import numpy as np import pandas as pd +from rank_bm25 import BM25 import pytest from unittest.mock import Mock from ..conftest import get_document_store, ensure_ids_are_correct_uuids from haystack.document_stores import ( + InMemoryDocumentStore, WeaviateDocumentStore, MilvusDocumentStore, FAISSDocumentStore, @@ -1444,3 +1446,31 @@ def test_normalize_embeddings_diff_shapes(): VEC_1 = np.array([0.1, 0.2, 0.3], dtype="float32").reshape(1, -1) BaseDocumentStore.normalize_embedding(VEC_1) assert np.linalg.norm(VEC_1) - 1 < 0.01 + + +def test_memory_update_bm25(): + ds = InMemoryDocumentStore(use_bm25=False) + ds.write_documents(DOCUMENTS) + ds.update_bm25() + bm25_representation = ds.bm25[ds.index] + assert isinstance(bm25_representation, BM25) + assert bm25_representation.corpus_size == ds.get_document_count() + + +@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True) +def test_memory_query(document_store_with_docs): + query_text = "Rome" + docs = document_store_with_docs.query(query=query_text, top_k=1) + assert len(docs) == 1 + assert docs[0].content == "My name is Matteo and I live in Rome" + + +@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True) +def test_memory_query_batch(document_store_with_docs): + query_texts = ["Paris", "Madrid"] + docs = document_store_with_docs.query_batch(queries=query_texts, top_k=5) + assert len(docs) == 2 + assert len(docs[0]) == 5 + assert docs[0][0].content == "My name is Christelle and I live in Paris" + assert len(docs[1]) == 5 + assert docs[1][0].content == "My name is Camila and I live in Madrid" diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index 94919482af..b779b1ee5d 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -45,17 +45,17 @@ ("embedding", "faiss"), ("embedding", "memory"), ("embedding", "milvus"), - ("elasticsearch", "elasticsearch"), + ("bm25", "elasticsearch"), + ("bm25", "memory"), ("es_filter_only", "elasticsearch"), ("tfidf", "memory"), ], indirect=True, ) -def test_retrieval(retriever_with_docs: BaseRetriever, document_store_with_docs: BaseDocumentStore): +def test_retrieval_without_filters(retriever_with_docs: BaseRetriever, document_store_with_docs: BaseDocumentStore): if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever, TfidfRetriever)): document_store_with_docs.update_embeddings(retriever_with_docs) - # test without filters # NOTE: FilterRetriever simply returns all documents matching a filter, # so without filters applied it does nothing if not isinstance(retriever_with_docs, FilterRetriever): @@ -64,29 +64,44 @@ def test_retrieval(retriever_with_docs: BaseRetriever, document_store_with_docs: assert len(res) == 5 assert res[0].meta["name"] == "filename1" - # test with filters - if not isinstance(document_store_with_docs, (FAISSDocumentStore, MilvusDocumentStore)) and not isinstance( - retriever_with_docs, TfidfRetriever - ): - # single filter - result = retriever_with_docs.retrieve(query="Christelle", filters={"name": ["filename3"]}, top_k=5) - assert len(result) == 1 - assert type(result[0]) == Document - assert result[0].content == "My name is Christelle and I live in Paris" - assert result[0].meta["name"] == "filename3" - - # multiple filters - result = retriever_with_docs.retrieve( - query="Paul", filters={"name": ["filename2"], "meta_field": ["test2", "test3"]}, top_k=5 - ) - assert len(result) == 1 - assert type(result[0]) == Document - assert result[0].meta["name"] == "filename2" - result = retriever_with_docs.retrieve( - query="Carla", filters={"name": ["filename1"], "meta_field": ["test2", "test3"]}, top_k=5 - ) - assert len(result) == 0 +@pytest.mark.parametrize( + "retriever_with_docs,document_store_with_docs", + [ + ("mdr", "elasticsearch"), + ("mdr", "memory"), + ("dpr", "elasticsearch"), + ("dpr", "memory"), + ("embedding", "elasticsearch"), + ("embedding", "memory"), + ("bm25", "elasticsearch"), + ("es_filter_only", "elasticsearch"), + ], + indirect=True, +) +def test_retrieval_with_filters(retriever_with_docs: BaseRetriever, document_store_with_docs: BaseDocumentStore): + if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever)): + document_store_with_docs.update_embeddings(retriever_with_docs) + + # single filter + result = retriever_with_docs.retrieve(query="Christelle", filters={"name": ["filename3"]}, top_k=5) + assert len(result) == 1 + assert type(result[0]) == Document + assert result[0].content == "My name is Christelle and I live in Paris" + assert result[0].meta["name"] == "filename3" + + # multiple filters + result = retriever_with_docs.retrieve( + query="Paul", filters={"name": ["filename2"], "meta_field": ["test2", "test3"]}, top_k=5 + ) + assert len(result) == 1 + assert type(result[0]) == Document + assert result[0].meta["name"] == "filename2" + + result = retriever_with_docs.retrieve( + query="Carla", filters={"name": ["filename1"], "meta_field": ["test2", "test3"]}, top_k=5 + ) + assert len(result) == 0 class MockBaseRetriever(MockRetriever): @@ -310,6 +325,9 @@ def test_retriever_basic_search(document_store, retriever, docs_with_ids): @pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True) @pytest.mark.embedding_dim(512) def test_table_text_retriever_embedding(document_store, retriever, docs): + # BM25 representation is incompatible with table retriever + if isinstance(document_store, InMemoryDocumentStore): + document_store.use_bm25 = False document_store.return_embedding = True document_store.write_documents(docs) diff --git a/test/nodes/test_summarizer.py b/test/nodes/test_summarizer.py index ae27ed52ca..3a41736709 100644 --- a/test/nodes/test_summarizer.py +++ b/test/nodes/test_summarizer.py @@ -70,7 +70,7 @@ def test_summarization_batch_multiple_doc_lists(summarizer): @pytest.mark.integration @pytest.mark.summarizer @pytest.mark.parametrize( - "retriever,document_store", [("embedding", "memory"), ("elasticsearch", "elasticsearch")], indirect=True + "retriever,document_store", [("embedding", "memory"), ("bm25", "elasticsearch")], indirect=True ) def test_summarization_pipeline(document_store, retriever, summarizer): document_store.write_documents(DOCS) @@ -118,7 +118,7 @@ def test_summarization_one_summary(summarizer): @pytest.mark.integration @pytest.mark.summarizer @pytest.mark.parametrize( - "retriever,document_store", [("embedding", "memory"), ("elasticsearch", "elasticsearch")], indirect=True + "retriever,document_store", [("embedding", "memory"), ("bm25", "elasticsearch")], indirect=True ) def test_summarization_pipeline_one_summary(document_store, retriever, summarizer): document_store.write_documents(SPLIT_DOCS) diff --git a/test/pipelines/test_eval.py b/test/pipelines/test_eval.py index 9ff14dc50c..0f44191880 100644 --- a/test/pipelines/test_eval.py +++ b/test/pipelines/test_eval.py @@ -164,7 +164,7 @@ def test_eval_reader(reader, document_store, use_confidence_scores): @pytest.mark.elasticsearch @pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True) @pytest.mark.parametrize("open_domain", [True, False]) -@pytest.mark.parametrize("retriever", ["elasticsearch"], indirect=True) +@pytest.mark.parametrize("retriever", ["bm25"], indirect=True) def test_eval_elastic_retriever(document_store, open_domain, retriever): # add eval data (SQUAD format) document_store.add_eval_data( @@ -188,7 +188,7 @@ def test_eval_elastic_retriever(document_store, open_domain, retriever): @pytest.mark.elasticsearch @pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True) @pytest.mark.parametrize("reader", ["farm"], indirect=True) -@pytest.mark.parametrize("retriever", ["elasticsearch"], indirect=True) +@pytest.mark.parametrize("retriever", ["bm25"], indirect=True) def test_eval_pipeline(document_store, reader, retriever): # add eval data (SQUAD format) document_store.add_eval_data( @@ -431,7 +431,7 @@ def test_extractive_qa_eval_multiple_queries(reader, retriever_with_docs, tmp_pa assert metrics["Retriever"]["ndcg"] == 0.5 -@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch"], indirect=True) +@pytest.mark.parametrize("retriever_with_docs", ["bm25"], indirect=True) @pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True) @pytest.mark.parametrize("reader", ["farm"], indirect=True) def test_extractive_qa_labels_with_filters(reader, retriever_with_docs, tmp_path): diff --git a/test/pipelines/test_standard_pipelines.py b/test/pipelines/test_standard_pipelines.py index 880980afe2..847f241136 100644 --- a/test/pipelines/test_standard_pipelines.py +++ b/test/pipelines/test_standard_pipelines.py @@ -120,7 +120,7 @@ def test_document_search_pipeline_batch(retriever, document_store): @pytest.mark.integration -@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch", "dpr", "embedding"], indirect=True) +@pytest.mark.parametrize("retriever_with_docs", ["bm25", "dpr", "embedding"], indirect=True) @pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True) def test_documentsearch_es_authentication(retriever_with_docs, document_store_with_docs: ElasticsearchDocumentStore): if isinstance(retriever_with_docs, (DensePassageRetriever, EmbeddingRetriever)):