From ea75e2aab512a384bbb45e63c0c41c3024be952b Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Fri, 18 Nov 2022 08:26:19 +0100 Subject: [PATCH] feat: store metadata using JSON in SQLDocumentStore (#3547) * add warnings * make the field cachable * review comment --- haystack/document_stores/sql.py | 39 ++++++++++++++++---------------- test/document_stores/test_sql.py | 36 ++++++++--------------------- 2 files changed, 29 insertions(+), 46 deletions(-) diff --git a/haystack/document_stores/sql.py b/haystack/document_stores/sql.py index 89a7ab220b..654bf20d9d 100644 --- a/haystack/document_stores/sql.py +++ b/haystack/document_stores/sql.py @@ -2,6 +2,7 @@ import logging import itertools +import json from uuid import uuid4 import numpy as np @@ -20,9 +21,10 @@ JSON, ForeignKeyConstraint, UniqueConstraint, + TypeDecorator, ) from sqlalchemy.ext.declarative import declarative_base - from sqlalchemy.orm import relationship, sessionmaker, validates + from sqlalchemy.orm import relationship, sessionmaker from sqlalchemy.sql import case, null except (ImportError, ModuleNotFoundError) as ie: from haystack.utils.import_utils import _optional_component_not_installed @@ -38,6 +40,20 @@ Base = declarative_base() # type: Any +class ArrayType(TypeDecorator): + + impl = String + cache_ok = True + + def process_bind_param(self, value, dialect): + return json.dumps(value) + + def process_result_value(self, value, dialect): + if value is not None: + return json.loads(value) + return value + + class ORMBase(Base): __abstract__ = True @@ -64,7 +80,7 @@ class MetaDocumentORM(ORMBase): __tablename__ = "meta_document" name = Column(String(100), index=True) - value = Column(String(1000), index=True) + value = Column(ArrayType(100), index=True) documents = relationship("DocumentORM", back_populates="meta") document_id = Column(String(100), nullable=False, index=True) @@ -76,17 +92,6 @@ class MetaDocumentORM(ORMBase): {}, ) # type: ignore - valid_metadata_types = (str, int, float, bool, bytes, bytearray, type(None)) - - @validates("value") - def validate_value(self, key, value): - if not isinstance(value, self.valid_metadata_types): - raise TypeError( - f"Discarded metadata '{self.name}', since it has invalid type: {type(value).__name__}.\n" - f"SQLDocumentStore can accept and cast to string only the following types: {', '.join([el.__name__ for el in self.valid_metadata_types])}" - ) - return value - class LabelORM(ORMBase): __tablename__ = "label" @@ -298,6 +303,7 @@ def _query( ).filter_by(index=index) if filters: + logger.warning("filters won't work on metadata fields containing compound data types") parsed_filter = LogicalFilterClause.parse(filters) select_ids = parsed_filter.convert_to_sql(MetaDocumentORM) documents_query = documents_query.filter(DocumentORM.id.in_(select_ids)) @@ -402,12 +408,7 @@ def write_documents( if "classification" in meta_fields: meta_fields = self._flatten_classification_meta_fields(meta_fields) vector_id = meta_fields.pop("vector_id", None) - meta_orms = [] - for key, value in meta_fields.items(): - try: - meta_orms.append(MetaDocumentORM(name=key, value=value)) - except TypeError as ex: - logger.error("Document %s - %s", doc.id, ex) + meta_orms = [MetaDocumentORM(name=key, value=value) for key, value in meta_fields.items()] doc_orm = DocumentORM( id=doc.id, content=doc.to_dict()["content"], diff --git a/test/document_stores/test_sql.py b/test/document_stores/test_sql.py index f26153b20b..ee777c61cb 100644 --- a/test/document_stores/test_sql.py +++ b/test/document_stores/test_sql.py @@ -1,3 +1,5 @@ +import logging + import pytest from haystack.document_stores.sql import SQLDocumentStore @@ -24,28 +26,6 @@ def test_delete_index(self, ds, documents): ds.delete_index(index="custom_index") assert ds.get_document_count(index="custom_index") == 0 - @pytest.mark.integration - def test_sql_write_document_invalid_meta(self, ds): - documents = [ - { - "content": "dict_with_invalid_meta", - "valid_meta_field": "test1", - "invalid_meta_field": [1, 2, 3], - "name": "filename1", - "id": "1", - }, - Document( - content="document_object_with_invalid_meta", - meta={"valid_meta_field": "test2", "invalid_meta_field": [1, 2, 3], "name": "filename2"}, - id="2", - ), - ] - ds.write_documents(documents) - documents_in_store = ds.get_all_documents() - assert len(documents_in_store) == 2 - assert ds.get_document_by_id("1").meta == {"name": "filename1", "valid_meta_field": "test1"} - assert ds.get_document_by_id("2").meta == {"name": "filename2", "valid_meta_field": "test2"} - @pytest.mark.integration def test_sql_write_different_documents_same_vector_id(self, ds): doc1 = {"content": "content 1", "name": "doc1", "id": "1", "vector_id": "vector_id"} @@ -98,13 +78,15 @@ def test_sql_get_documents_using_nested_filters_about_classification(self, ds): assert len(ds.get_all_documents(filters={"classification.score": {"$gt": 0.95}})) == 0 assert len(ds.get_all_documents(filters={"classification.label": ["LABEL_100"]})) == 0 - # NOTE: the SQLDocumentStore behaves differently to the others when filters are applied. - # While this should be considered a bug, the relative tests are skipped in the meantime + # NOTE: the SQLDocumentStore marshals metadata values with JSON so querying + # using filters doesn't always work. While this should be considered a bug, + # the relative tests are either customized or skipped while we work on a fix. - @pytest.mark.skip @pytest.mark.integration - def test_ne_filters(self, ds, documents): - pass + def test_ne_filters(self, ds, caplog): + with caplog.at_level(logging.WARNING): + ds.get_all_documents(filters={"year": {"$ne": "2020"}}) + assert "filters won't work on metadata fields" in caplog.text @pytest.mark.skip @pytest.mark.integration