sql.py · 754 lines (641 loc) · 32 KB
from typing import Any, Dict, Union, List, Optional, Generator
import logging
import itertools
from uuid import uuid4
import numpy as np
try:
    from sqlalchemy import (
        and_,
        func,
        create_engine,
        Column,
        String,
        DateTime,
        Boolean,
        Text,
        text,
        JSON,
        ForeignKeyConstraint,
    )
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship, sessionmaker
    from sqlalchemy.sql import case, null
except (ImportError, ModuleNotFoundError) as ie:
    from haystack.utils.import_utils import _optional_component_not_installed

    _optional_component_not_installed(__name__, "sql", ie)
from haystack.schema import Document, Label, Answer
from haystack.document_stores.base import BaseDocumentStore
from haystack.document_stores.filter_utils import LogicalFilterClause

logger = logging.getLogger(__name__)

Base = declarative_base()  # type: Any
class ORMBase(Base):
    __abstract__ = True

    id = Column(String(100), default=lambda: str(uuid4()), primary_key=True)
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime, server_default=func.now(), server_onupdate=func.now())
class DocumentORM(ORMBase):
    __tablename__ = "document"

    content = Column(JSON, nullable=False)
    content_type = Column(Text, nullable=True)
    # primary key in combination with id to allow the same doc in different indices
    index = Column(String(100), nullable=False, primary_key=True)
    vector_id = Column(String(100), unique=True, nullable=True)

    # speeds up queries for get_documents_by_vector_ids() by having a single query that returns joined metadata
    meta = relationship("MetaDocumentORM", back_populates="documents", lazy="joined")
class MetaDocumentORM(ORMBase):
    __tablename__ = "meta_document"

    name = Column(String(100), index=True)
    value = Column(String(1000), index=True)
    documents = relationship("DocumentORM", back_populates="meta")

    document_id = Column(String(100), nullable=False, index=True)
    document_index = Column(String(100), nullable=False, index=True)
    __table_args__ = (
        ForeignKeyConstraint(
            [document_id, document_index], [DocumentORM.id, DocumentORM.index], ondelete="CASCADE", onupdate="CASCADE"
        ),
        {},
    )  # type: ignore
class LabelORM(ORMBase):
    __tablename__ = "label"

    index = Column(String(100), nullable=False, primary_key=True)
    query = Column(Text, nullable=False)
    answer = Column(JSON, nullable=True)
    document = Column(JSON, nullable=False)
    no_answer = Column(Boolean, nullable=False)
    origin = Column(String(100), nullable=False)
    is_correct_answer = Column(Boolean, nullable=False)
    is_correct_document = Column(Boolean, nullable=False)
    pipeline_id = Column(String(500), nullable=True)

    meta = relationship("MetaLabelORM", back_populates="labels", lazy="joined")
class MetaLabelORM(ORMBase):
    __tablename__ = "meta_label"

    name = Column(String(100), index=True)
    value = Column(String(1000), index=True)
    labels = relationship("LabelORM", back_populates="meta")

    label_id = Column(String(100), nullable=False, index=True)
    label_index = Column(String(100), nullable=False, index=True)
    __table_args__ = (
        ForeignKeyConstraint(
            [label_id, label_index], [LabelORM.id, LabelORM.index], ondelete="CASCADE", onupdate="CASCADE"
        ),
        {},
    )  # type: ignore
class SQLDocumentStore(BaseDocumentStore):
    def __init__(
        self,
        url: str = "sqlite://",
        index: str = "document",
        label_index: str = "label",
        duplicate_documents: str = "overwrite",
        check_same_thread: bool = False,
        isolation_level: Optional[str] = None,
    ):
        """
        An SQL-backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.

        :param url: URL for the SQL database as expected by SQLAlchemy. More info here: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
        :param index: The documents are scoped to an index attribute that can be used when writing, querying, or deleting documents.
                      This parameter sets the default value for the document index.
        :param label_index: The default value of the index attribute for labels.
        :param duplicate_documents: Handle duplicate documents based on parameter options.
                                    Parameter options: ('skip', 'overwrite', 'fail')
                                    skip: Ignore duplicate documents.
                                    overwrite: Update any existing documents with the same ID when adding documents.
                                    fail: Raise an error if the ID of the document being added already exists.
        :param check_same_thread: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior)
        :param isolation_level: See SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
        """
        super().__init__()

        create_engine_params = {}
        if isolation_level:
            create_engine_params["isolation_level"] = isolation_level
        if "sqlite" in url:
            engine = create_engine(url, connect_args={"check_same_thread": check_same_thread}, **create_engine_params)
        else:
            engine = create_engine(url, **create_engine_params)
        Base.metadata.create_all(engine)
        Session = sessionmaker(bind=engine)
        self.session = Session()
        self.index: str = index
        self.label_index = label_index
        self.duplicate_documents = duplicate_documents
        if getattr(self, "similarity", None) is None:
            self.similarity = None
        self.use_windowed_query = True
        if "sqlite" in url:
            import sqlite3

            if sqlite3.sqlite_version < "3.25":
                self.use_windowed_query = False
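    # A minimal usage sketch (illustrative only; the SQLite file name and document text below
    # are hypothetical examples, not part of this module):
    #
    #   store = SQLDocumentStore(url="sqlite:///haystack_test.db", index="document")
    #   store.write_documents([{"content": "Berlin is the capital of Germany."}])
    #   print(store.get_document_count())  # -> 1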
    def get_document_by_id(
        self, id: str, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None
    ) -> Optional[Document]:
        """Fetch a document by specifying its text id string"""
        if headers:
            raise NotImplementedError("SQLDocumentStore does not support headers.")

        documents = self.get_documents_by_id([id], index)
        document = documents[0] if documents else None
        return document

    def get_documents_by_id(
        self,
        ids: List[str],
        index: Optional[str] = None,
        batch_size: int = 10_000,
        headers: Optional[Dict[str, str]] = None,
    ) -> List[Document]:
        """Fetch documents by specifying a list of text id strings"""
        if headers:
            raise NotImplementedError("SQLDocumentStore does not support headers.")

        index = index or self.index

        documents = []
        for i in range(0, len(ids), batch_size):
            query = self.session.query(DocumentORM).filter(
                DocumentORM.id.in_(ids[i : i + batch_size]), DocumentORM.index == index
            )
            for row in query.all():
                documents.append(self._convert_sql_row_to_document(row))

        return documents

    def get_documents_by_vector_ids(self, vector_ids: List[str], index: Optional[str] = None, batch_size: int = 10_000):
        """Fetch documents by specifying a list of text vector id strings"""
        index = index or self.index

        documents = []
        for i in range(0, len(vector_ids), batch_size):
            query = self.session.query(DocumentORM).filter(
                DocumentORM.vector_id.in_(vector_ids[i : i + batch_size]), DocumentORM.index == index
            )
            for row in query.all():
                documents.append(self._convert_sql_row_to_document(row))
        sorted_documents = sorted(documents, key=lambda doc: vector_ids.index(doc.meta["vector_id"]))
        return sorted_documents
    def get_all_documents(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,  # TODO: Adapt type once we allow extended filters in SQLDocStore
        return_embedding: Optional[bool] = None,
        batch_size: int = 10_000,
        headers: Optional[Dict[str, str]] = None,
    ) -> List[Document]:
        if headers:
            raise NotImplementedError("SQLDocumentStore does not support headers.")

        documents = list(
            self.get_all_documents_generator(
                index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
            )
        )
        return documents

    def get_all_documents_generator(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,  # TODO: Adapt type once we allow extended filters in SQLDocStore
        return_embedding: Optional[bool] = None,
        batch_size: int = 10_000,
        headers: Optional[Dict[str, str]] = None,
    ) -> Generator[Document, None, None]:
        """
        Get documents from the document store. Under the hood, documents are fetched in batches from the
        document store and yielded as individual documents. This method can be used to iteratively process
        a large number of documents without having to load all documents into memory.

        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param filters: Optional filters to narrow down the documents to return.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param return_embedding: Whether to return the document embeddings.
        :param batch_size: When working with a large number of documents, batching can help reduce memory footprint.
        """
        if headers:
            raise NotImplementedError("SQLDocumentStore does not support headers.")

        if return_embedding is True:
            raise Exception("return_embedding is not supported by SQLDocumentStore.")
        result = self._query(index=index, filters=filters, batch_size=batch_size)
        yield from result
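    # Sketch of streaming documents instead of loading them all at once (assumes `store` is an
    # SQLDocumentStore instance and a "category" meta field was written with the documents):
    #
    #   for doc in store.get_all_documents_generator(filters={"category": ["news"]}, batch_size=1000):
    #       process(doc)  # `process` is a placeholder for user code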
    def _create_document_field_map(self) -> Dict:
        """
        There is no field mapping required
        """
        return {}
    def _query(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,  # TODO: Adapt type once we allow extended filters in SQLDocStore
        vector_ids: Optional[List[str]] = None,
        only_documents_without_embedding: bool = False,
        batch_size: int = 10_000,
    ):
        """
        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param filters: Optional filters to narrow down the documents to return.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param vector_ids: List of vector_id strings to filter the documents by.
        :param only_documents_without_embedding: return only documents without an embedding.
        :param batch_size: When working with a large number of documents, batching can help reduce memory footprint.
        """
        index = index or self.index
        # Keeping full ORM objects in memory generally causes performance issues,
        # so we query the column names directly to improve memory usage and speed.
        # Refer https://stackoverflow.com/questions/23185319/why-is-loading-sqlalchemy-objects-via-the-orm-5-8x-slower-than-rows-via-a-raw-my
        documents_query = self.session.query(
            DocumentORM.id, DocumentORM.content, DocumentORM.content_type, DocumentORM.vector_id
        ).filter_by(index=index)

        if filters:
            parsed_filter = LogicalFilterClause.parse(filters)
            select_ids = parsed_filter.convert_to_sql(MetaDocumentORM)
            documents_query = documents_query.filter(DocumentORM.id.in_(select_ids))

        if only_documents_without_embedding:
            documents_query = documents_query.filter(DocumentORM.vector_id.is_(None))
        if vector_ids:
            documents_query = documents_query.filter(DocumentORM.vector_id.in_(vector_ids))

        documents_map = {}

        if self.use_windowed_query:
            documents_query = self._windowed_query(documents_query, DocumentORM.id, batch_size)

        for i, row in enumerate(documents_query, start=1):
            documents_map[row.id] = Document.from_dict(
                {
                    "id": row.id,
                    "content": row.content,
                    "content_type": row.content_type,
                    "meta": {} if row.vector_id is None else {"vector_id": row.vector_id},
                }
            )
            if i % batch_size == 0:
                documents_map = self._get_documents_meta(documents_map)
                yield from documents_map.values()
                documents_map = {}
        if documents_map:
            documents_map = self._get_documents_meta(documents_map)
            yield from documents_map.values()
    def _get_documents_meta(self, documents_map):
        doc_ids = documents_map.keys()
        meta_query = self.session.query(
            MetaDocumentORM.document_id, MetaDocumentORM.name, MetaDocumentORM.value
        ).filter(MetaDocumentORM.document_id.in_(doc_ids))

        for row in meta_query.all():
            documents_map[row.document_id].meta[row.name] = row.value
        return documents_map
    def get_all_labels(self, index=None, filters: Optional[dict] = None, headers: Optional[Dict[str, str]] = None):
        """
        Return all labels in the document store
        """
        if headers:
            raise NotImplementedError("SQLDocumentStore does not support headers.")

        index = index or self.label_index
        # TODO: Use batch_size
        label_rows = self.session.query(LabelORM).filter_by(index=index).all()
        labels = [self._convert_sql_row_to_label(row) for row in label_rows]

        return labels
    def write_documents(
        self,
        documents: Union[List[dict], List[Document]],
        index: Optional[str] = None,
        batch_size: int = 10_000,
        duplicate_documents: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> None:
        """
        Indexes documents for later queries.

        :param documents: a list of Python dictionaries or a list of Haystack Document objects.
                          For documents as dictionaries, the format is {"content": "<the-actual-text>"}.
                          Optionally: Include metadata via {"content": "<the-actual-text>",
                          "meta": {"name": "<some-document-name>", "author": "somebody", ...}}
                          It can be used for filtering and is accessible in the responses of the Finder.
        :param index: add an optional index attribute to documents. It can be later used for filtering. For instance,
                      documents for evaluation can be indexed in a separate index than the documents for search.
        :param batch_size: When working with a large number of documents, batching can help reduce memory footprint.
        :param duplicate_documents: Handle duplicate documents based on parameter options.
                                    Parameter options: ('skip', 'overwrite', 'fail')
                                    skip: Ignore duplicate documents.
                                    overwrite: Update any existing documents with the same ID when adding documents,
                                               but is considerably slower (default).
                                    fail: Raise an error if the ID of the document being added already exists.

        :return: None
        """
        if headers:
            raise NotImplementedError("SQLDocumentStore does not support headers.")

        index = index or self.index
        duplicate_documents = duplicate_documents or self.duplicate_documents
        if len(documents) == 0:
            return
        # Make sure we comply to Document class format
        document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]

        document_objects = self._handle_duplicate_documents(
            documents=document_objects, index=index, duplicate_documents=duplicate_documents
        )
        for i in range(0, len(document_objects), batch_size):
            docs_orm = []
            for doc in document_objects[i : i + batch_size]:
                meta_fields = doc.meta or {}
                vector_id = meta_fields.pop("vector_id", None)
                meta_orms = [MetaDocumentORM(name=key, value=value) for key, value in meta_fields.items()]
                doc_mapping = {
                    "id": doc.id,
                    "content": doc.to_dict()["content"],
                    "content_type": doc.content_type,
                    "vector_id": vector_id,
                    "meta": meta_orms,
                    "index": index,
                }
                if duplicate_documents == "overwrite":
                    doc_orm = DocumentORM(**doc_mapping)
                    # The old metadata must be deleted first
                    self.session.query(MetaDocumentORM).filter_by(document_id=doc.id).delete()
                    self.session.merge(doc_orm)
                else:
                    docs_orm.append(doc_mapping)

            if docs_orm:
                self.session.bulk_insert_mappings(DocumentORM, docs_orm)

            try:
                self.session.commit()
            except Exception as ex:
                logger.error(f"Transaction rollback: {ex.__cause__}")
                # Rollback is important here, otherwise self.session is left in an inconsistent state and the next call will fail
                self.session.rollback()
                raise ex
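    # Sketch of writing documents with metadata; dict and Document forms are interchangeable here,
    # and the field values are made-up examples (assumes `store` is an SQLDocumentStore instance):
    #
    #   store.write_documents(
    #       [
    #           {"content": "Paris is the capital of France.", "meta": {"name": "doc_1", "author": "somebody"}},
    #           Document(content="Rome is the capital of Italy.", meta={"name": "doc_2"}),
    #       ],
    #       duplicate_documents="skip",
    #   )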
    def write_labels(self, labels, index=None, headers: Optional[Dict[str, str]] = None):
        """Write annotation labels into document store."""
        if headers:
            raise NotImplementedError("SQLDocumentStore does not support headers.")

        labels = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
        index = index or self.label_index

        duplicate_ids: list = [label.id for label in self._get_duplicate_labels(labels, index=index)]
        if len(duplicate_ids) > 0:
            logger.warning(
                f"Duplicate Label IDs: Inserting a Label whose id already exists in this document store."
                f" This will overwrite the old Label. Please make sure Label.id is a unique identifier of"
                f" the answer annotation and not the question."
                f" Problematic ids: {','.join(duplicate_ids)}"
            )
        # TODO: Use batch_size

        for label in labels:
            # TODO: As of now, we write documents as part of the Label table as this is consistent with the other
            # document stores (e.g. elasticsearch) where "indices" are completely independent.
            # We should eventually switch to an approach here that writes related documents to the document table if not already existing.
            # See Issue XXX
            # self.write_documents(documents=[label.document], index=index, duplicate_documents="skip")

            # TODO: Handle label meta data

            label_orm = LabelORM(
                id=label.id,
                no_answer=label.no_answer,
                # document_id=label.document.id,
                document=label.document.to_json(),
                origin=label.origin,
                query=label.query,
                is_correct_answer=label.is_correct_answer,
                is_correct_document=label.is_correct_document,
                answer=label.answer.to_json(),
                pipeline_id=label.pipeline_id,
                index=index,
            )
            if label.id in duplicate_ids:
                self.session.merge(label_orm)
            else:
                self.session.add(label_orm)
            # TODO: investigate why test_multilabel() failed when not committing within the loop
            # It seems that in some cases only the last label gets committed otherwise
            self.session.commit()
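    # Sketch of writing a label; the field values are made-up examples and the keyword
    # arguments are an assumption based on haystack.schema.Label (see _convert_sql_row_to_label below):
    #
    #   label = Label(
    #       query="What is the capital of Germany?",
    #       document=Document(content="Berlin is the capital of Germany."),
    #       answer=Answer(answer="Berlin"),
    #       is_correct_answer=True,
    #       is_correct_document=True,
    #       origin="gold-label",
    #   )
    #   store.write_labels([label])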
    def update_vector_ids(self, vector_id_map: Dict[str, str], index: Optional[str] = None, batch_size: int = 10_000):
        """
        Update vector_ids for given document_ids.

        :param vector_id_map: dict containing mapping of document_id -> vector_id.
        :param index: filter documents by the optional index attribute for documents in database.
        :param batch_size: When working with a large number of documents, batching can help reduce memory footprint.
        """
        index = index or self.index
        for chunk_map in self.chunked_dict(vector_id_map, size=batch_size):
            self.session.query(DocumentORM).filter(DocumentORM.id.in_(chunk_map), DocumentORM.index == index).update(
                {DocumentORM.vector_id: case(chunk_map, value=DocumentORM.id)}, synchronize_session=False
            )
            try:
                self.session.commit()
            except Exception as ex:
                logger.error(f"Transaction rollback: {ex.__cause__}")
                self.session.rollback()
                raise ex
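    # Sketch of attaching embedding vector ids to already-written documents
    # (the document and vector ids are made-up examples):
    #
    #   store.update_vector_ids({"doc_id_1": "0", "doc_id_2": "1"})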
    def reset_vector_ids(self, index: Optional[str] = None):
        """
        Set vector IDs for all documents to None
        """
        index = index or self.index
        self.session.query(DocumentORM).filter_by(index=index).update({DocumentORM.vector_id: null()})
        self.session.commit()

    def update_document_meta(self, id: str, meta: Dict[str, str], index: Optional[str] = None):
        """
        Update the metadata dictionary of a document by specifying its string id
        """
        if not index:
            index = self.index
        self.session.query(MetaDocumentORM).filter_by(document_id=id, document_index=index).delete()
        meta_orms = [
            MetaDocumentORM(name=key, value=value, document_id=id, document_index=index) for key, value in meta.items()
        ]
        for m in meta_orms:
            self.session.add(m)
        self.session.commit()
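    # Sketch of replacing a document's metadata; note that the existing meta rows for the id are
    # deleted first (the id and field values are made-up examples):
    #
    #   store.update_document_meta(id="doc_id_1", meta={"author": "somebody else", "year": "2021"})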
    def get_document_count(
        self,
        filters: Optional[Dict[str, Any]] = None,  # TODO: Adapt type once we allow extended filters in SQLDocStore
        index: Optional[str] = None,
        only_documents_without_embedding: bool = False,
        headers: Optional[Dict[str, str]] = None,
    ) -> int:
        """
        Return the number of documents in the document store.
        """
        if headers:
            raise NotImplementedError("SQLDocumentStore does not support headers.")

        index = index or self.index
        query = self.session.query(DocumentORM).filter_by(index=index)

        if filters:
            for key, values in filters.items():
                query = query.join(MetaDocumentORM, aliased=True).filter(
                    MetaDocumentORM.name == key, MetaDocumentORM.value.in_(values)
                )

        if only_documents_without_embedding:
            query = query.filter(DocumentORM.vector_id.is_(None))
        count = query.count()
        return count
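    # Sketch of counting only the documents that match a metadata filter
    # (assumes an "author" meta field was written with the documents):
    #
    #   n = store.get_document_count(filters={"author": ["somebody"]})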
    def get_label_count(self, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> int:
        """
        Return the number of labels in the document store
        """
        if headers:
            raise NotImplementedError("SQLDocumentStore does not support headers.")

        index = index or self.label_index
        return self.session.query(LabelORM).filter_by(index=index).count()
    def _convert_sql_row_to_document(self, row) -> Document:
        doc_dict = {
            "id": row.id,
            "content": row.content,
            "content_type": row.content_type,
            "meta": {meta.name: meta.value for meta in row.meta},
        }
        document = Document.from_dict(doc_dict)

        if row.vector_id:
            document.meta["vector_id"] = row.vector_id
        return document

    def _convert_sql_row_to_label(self, row) -> Label:
        # doc = self._convert_sql_row_to_document(row.document)
        label = Label(
            query=row.query,
            answer=Answer.from_json(row.answer),  # type: ignore
            document=Document.from_json(row.document),
            is_correct_answer=row.is_correct_answer,
            is_correct_document=row.is_correct_document,
            origin=row.origin,
            id=row.id,
            no_answer=row.no_answer,
            pipeline_id=row.pipeline_id,
            created_at=str(row.created_at),
            updated_at=str(row.updated_at),
            meta=row.meta,
        )
        return label
    def query_by_embedding(
        self,
        query_emb: np.ndarray,
        filters: Optional[dict] = None,
        top_k: int = 10,
        index: Optional[str] = None,
        return_embedding: Optional[bool] = None,
        headers: Optional[Dict[str, str]] = None,
        scale_score: bool = True,
    ) -> List[Document]:
        raise NotImplementedError(
            "SQLDocumentStore is currently not supporting embedding queries. "
            "Change the query type (e.g. by choosing a different retriever) "
            "or change the DocumentStore (e.g. to ElasticsearchDocumentStore)"
        )
    def delete_all_documents(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,  # TODO: Adapt type once we allow extended filters in SQLDocStore
        headers: Optional[Dict[str, str]] = None,
    ):
        """
        Delete documents in an index. All documents are deleted if no filters are passed.

        :param index: Index name to delete the document from.
        :param filters: Optional filters to narrow down the documents to be deleted.
        :return: None
        """
        if headers:
            raise NotImplementedError("SQLDocumentStore does not support headers.")

        logger.warning(
            """DEPRECATION WARNINGS:
            1. delete_all_documents() method is deprecated, please use delete_documents method
            For more details, please refer to the issue: https://github.com/deepset-ai/haystack/issues/1045
            """
        )
        self.delete_documents(index, None, filters)
    def delete_documents(
        self,
        index: Optional[str] = None,
        ids: Optional[List[str]] = None,
        filters: Optional[Dict[str, Any]] = None,  # TODO: Adapt type once we allow extended filters in SQLDocStore
        headers: Optional[Dict[str, str]] = None,
    ):
        """
        Delete documents in an index. All documents are deleted if no filters are passed.

        :param index: Index name to delete the document from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param ids: Optional list of IDs to narrow down the documents to be deleted.
        :param filters: Optional filters to narrow down the documents to be deleted.
                        Example filters: {"name": ["some", "more"], "category": ["only_one"]}.
                        If filters are provided along with a list of IDs, this method deletes the
                        intersection of the two query results (documents that match the filters and
                        have their ID in the list).
        :return: None
        """
        if headers:
            raise NotImplementedError("SQLDocumentStore does not support headers.")

        index = index or self.index
        if not filters and not ids:
            self.session.query(DocumentORM).filter_by(index=index).delete(synchronize_session=False)
        else:
            document_ids_to_delete = self.session.query(DocumentORM.id).filter(DocumentORM.index == index)
            if filters:
                for key, values in filters.items():
                    document_ids_to_delete = document_ids_to_delete.join(MetaDocumentORM, aliased=True).filter(
                        MetaDocumentORM.name == key, MetaDocumentORM.value.in_(values)
                    )
            if ids:
                document_ids_to_delete = document_ids_to_delete.filter(DocumentORM.id.in_(ids))

            self.session.query(DocumentORM).filter(DocumentORM.id.in_(document_ids_to_delete)).delete(
                synchronize_session=False
            )

        self.session.commit()
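    # Sketch of deleting by ids and filters; when both are given, only documents that match the
    # filters AND appear in the id list are removed (ids and values are made-up examples):
    #
    #   store.delete_documents(ids=["doc_id_1", "doc_id_2"], filters={"category": ["news"]})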
    def delete_index(self, index: str):
        """
        Delete an existing index. The index including all data will be removed.

        :param index: The name of the index to delete.
        :return: None
        """
        SQLDocumentStore.delete_documents(self, index)
    def delete_labels(
        self,
        index: Optional[str] = None,
        ids: Optional[List[str]] = None,
        filters: Optional[Dict[str, Any]] = None,  # TODO: Adapt type once we allow extended filters in SQLDocStore
        headers: Optional[Dict[str, str]] = None,
    ):
        """
        Delete labels from the document store. All labels are deleted if no filters are passed.

        :param index: Index name to delete the labels from. If None, the
                      DocumentStore's default label index (self.label_index) will be used.
        :param ids: Optional list of IDs to narrow down the labels to be deleted.
        :param filters: Optional filters to narrow down the labels to be deleted.
                        Example filters: {"id": ["9a196e41-f7b5-45b4-bd19-5feb7501c159", "9a196e41-f7b5-45b4-bd19-5feb7501c159"]} or {"query": ["question2"]}
        :return: None
        """
        if headers:
            raise NotImplementedError("SQLDocumentStore does not support headers.")

        index = index or self.label_index
        if not filters and not ids:
            self.session.query(LabelORM).filter_by(index=index).delete(synchronize_session=False)
        else:
            label_ids_to_delete = self.session.query(LabelORM.id).filter_by(index=index)
            if filters:
                for key, values in filters.items():
                    label_attribute = getattr(LabelORM, key)
                    label_ids_to_delete = label_ids_to_delete.filter(label_attribute.in_(values))
            if ids:
                label_ids_to_delete = label_ids_to_delete.filter(LabelORM.id.in_(ids))

            self.session.query(LabelORM).filter(LabelORM.id.in_(label_ids_to_delete)).delete(synchronize_session=False)

        self.session.commit()
    def _get_or_create(self, session, model, **kwargs):
        instance = session.query(model).filter_by(**kwargs).first()
        if instance:
            return instance
        else:
            instance = model(**kwargs)
            session.add(instance)
            session.commit()
            return instance

    def chunked_dict(self, dictionary, size):
        it = iter(dictionary)
        for i in range(0, len(dictionary), size):
            yield {k: dictionary[k] for k in itertools.islice(it, size)}
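    # Sketch of how chunked_dict splits a mapping (keys and values are illustrative):
    #
    #   list(store.chunked_dict({"a": 1, "b": 2, "c": 3}, size=2))
    #   # -> [{"a": 1, "b": 2}, {"c": 3}]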
    def _column_windows(self, session, column, windowsize):
        """Return a series of WHERE clauses against
        a given column that break it into windows.

        The result is an iterable of WHERE clauses, one per window, each built from the
        (start, end) ids that delimit that window.

        The code is taken from: https://github.com/sqlalchemy/sqlalchemy/wiki/RangeQuery-and-WindowedRangeQuery
        """

        def int_for_range(start_id, end_id):
            if end_id:
                return and_(column >= start_id, column < end_id)
            else:
                return column >= start_id

        q = session.query(column, func.row_number().over(order_by=column).label("rownum")).from_self(column)

        if windowsize > 1:
            q = q.filter(text("rownum %% %d=1" % windowsize))

        intervals = [id for id, in q]

        while intervals:
            start = intervals.pop(0)
            if intervals:
                end = intervals[0]
            else:
                end = None
            yield int_for_range(start, end)
    def _windowed_query(self, q, column, windowsize):
        """Break a Query into windows on a given column."""
        for whereclause in self._column_windows(q.session, column, windowsize):
            for row in q.filter(whereclause).order_by(column):
                yield row
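# Illustrative note on the windowed-query helpers above (a sketch, not executed anywhere in this module):
# _column_windows() uses row_number() to pick every `windowsize`-th id and yields WHERE clauses such as
# `column >= start AND column < end`; _windowed_query() then re-runs the base query once per window so
# results are streamed in ordered batches instead of one huge result set, e.g.
#
#   ids_query = store.session.query(DocumentORM.id)
#   for row in store._windowed_query(ids_query, DocumentORM.id, windowsize=1000):
#       ...  # rows arrive window by window, ordered by DocumentORM.id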