Skip to content

Commit

Permalink
fix: Flatten DocumentClassifier output in SQLDocumentStore; remove `_sql_session_rollback` hack in tests (#3273)
Browse files Browse the repository at this point in the history

* first draft

* fix

* fix

* move test to test_sql
  • Loading branch information
anakin87 authored Nov 16, 2022
1 parent af78f8b commit dc26e6d
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 19 deletions.
13 changes: 13 additions & 0 deletions haystack/document_stores/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,8 @@ def write_documents(
docs_orm = []
for doc in document_objects[i : i + batch_size]:
meta_fields = doc.meta or {}
if "classification" in meta_fields:
meta_fields = self._flatten_classification_meta_fields(meta_fields)
vector_id = meta_fields.pop("vector_id", None)
meta_orms = []
for key, value in meta_fields.items():
Expand Down Expand Up @@ -785,3 +787,14 @@ def _windowed_query(self, q, column, windowsize):
for whereclause in self._column_windows(q.session, column, windowsize):
for row in q.filter(whereclause).order_by(column):
yield row

def _flatten_classification_meta_fields(self, meta_fields: dict) -> dict:
"""
Since SQLDocumentStore does not support dictionaries for metadata values,
the DocumentClassifier output is flattened
"""
meta_fields["classification.label"] = meta_fields["classification"]["label"]
meta_fields["classification.score"] = meta_fields["classification"]["score"]
meta_fields["classification.details"] = str(meta_fields["classification"]["details"])
del meta_fields["classification"]
return meta_fields
18 changes: 0 additions & 18 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,24 +106,6 @@
requests_cache.install_cache(urls_expire_after={"huggingface.co": timedelta(hours=1), "*": requests_cache.DO_NOT_CACHE})


def _sql_session_rollback(self, attr):
"""
Inject SQLDocumentStore at runtime to do a session rollback each time it is called. This allows to catch
errors where an intended operation is still in a transaction, but not committed to the database.
"""
method = object.__getattribute__(self, attr)
if callable(method):
try:
self.session.rollback()
except AttributeError:
pass

return method


SQLDocumentStore.__getattribute__ = _sql_session_rollback


def pytest_collection_modifyitems(config, items):
# add pytest markers for tests that are not explicitly marked but include some keywords
name_to_markers = {
Expand Down
1 change: 0 additions & 1 deletion test/document_stores/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def test_write_with_duplicate_doc_ids_custom_index(document_store: BaseDocumentS


def test_get_all_documents_without_filters(document_store_with_docs):
print("hey!")
documents = document_store_with_docs.get_all_documents()
assert all(isinstance(d, Document) for d in documents)
assert len(documents) == 5
Expand Down
36 changes: 36 additions & 0 deletions test/document_stores/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,42 @@ def test_sql_write_different_documents_same_vector_id(self, ds):
with pytest.raises(Exception, match=r"(?i)unique"):
ds.write_documents([doc2], index="index3")

@pytest.mark.integration
def test_sql_get_documents_using_nested_filters_about_classification(self, ds):
    """
    Write two documents whose meta holds a nested "classification" dictionary and
    check that they can be filtered through the "classification.label" and
    "classification.score" keys.
    """
    liked = Document(
        content="That's good. I like it.",
        id="1",
        meta={
            "classification": {
                "label": "LABEL_1",
                "score": 0.694,
                "details": {"LABEL_1": 0.694, "LABEL_0": 0.306},
            }
        },
    )
    disliked = Document(
        content="That's bad. I don't like it.",
        id="2",
        meta={
            "classification": {
                "label": "LABEL_0",
                "score": 0.898,
                "details": {"LABEL_0": 0.898, "LABEL_1": 0.102},
            }
        },
    )
    ds.write_documents([liked, disliked])

    assert ds.get_document_count() == 2

    # (filters, number of documents expected to match)
    expectations = [
        ({"classification.score": {"$gt": 0.1}}, 2),
        ({"classification.label": ["LABEL_1", "LABEL_0"]}, 2),
        ({"classification.score": {"$gt": 0.8}}, 1),
        ({"classification.label": ["LABEL_1"]}, 1),
        ({"classification.score": {"$gt": 0.95}}, 0),
        ({"classification.label": ["LABEL_100"]}, 0),
    ]
    for filters, expected_count in expectations:
        assert len(ds.get_all_documents(filters=filters)) == expected_count

# NOTE: the SQLDocumentStore behaves differently from the other document stores when filters are applied.
# While this should be considered a bug, the related tests are skipped in the meantime

Expand Down

0 comments on commit dc26e6d

Please sign in to comment.