Skip to content

Commit

Permalink
fix(sqlite): allow sqlite to work in mp/mt context (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao authored Jan 26, 2022
1 parent ab788b2 commit d127742
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
8 changes: 7 additions & 1 deletion docarray/array/storage/sqlite/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class SqliteConfig:
table_name: Optional[str] = None
serialize_config: Dict = field(default_factory=dict)
conn_config: Dict = field(default_factory=dict)
journal_mode: str = 'DELETE'
synchronous: str = 'OFF'


class BackendMixin(BaseBackendMixin):
Expand Down Expand Up @@ -69,7 +71,9 @@ def _init_storage(
'Document', lambda x: Document.from_bytes(x, **config.serialize_config)
)

-        _conn_kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES)
+        _conn_kwargs = dict(
+            detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False
+        )
_conn_kwargs.update(config.conn_config)
if config.connection is None:
self._connection = sqlite3.connect(
Expand All @@ -83,6 +87,8 @@ def _init_storage(
raise TypeError(
f'connection argument must be None or a string or a sqlite3.Connection, not `{type(config.connection)}`'
)
self._connection.execute(f'PRAGMA synchronous={config.synchronous}')
self._connection.execute(f'PRAGMA journal_mode={config.journal_mode}')

self._table_name = (
_sanitize_table_name(self.__class__.__name__ + random_identity())
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/array/mixins/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,22 @@ def test_map_lambda(pytestconfig, da_cls):

for d in da.map(lambda x: x.load_uri_to_image_tensor()):
assert d.tensor is not None


@pytest.mark.parametrize('storage', ['memory', 'sqlite'])
@pytest.mark.parametrize('backend', ['thread', 'process'])
def test_apply_diff_backend_storage(storage, backend):
    """Check that ``apply`` works for each storage/backend combination.

    Builds a 1000-document array on the given ``storage``, embeds every
    document through ``apply`` using the given parallel ``backend``, then
    matches an embedded query document against the array and checks the
    shape of the top-5 result selection.
    """
    corpus = (
        Document(text='hello world she smiled too much') for _ in range(1000)
    )
    da = DocumentArray(corpus, storage=storage)
    da.apply(lambda d: d.embed_feature_hashing(), backend=backend)

    embedded_query = Document(text='she smiled too much').embed_feature_hashing()
    q = embedded_query.match(da, metric='jaccard', use_scipy=True)

    # Two attributes are requested for the first five matches, so the
    # selection should hold two sequences, the first with five entries.
    wanted = ('text', 'scores__jaccard__value')
    top_five = q.matches[:5, wanted]
    assert len(top_five) == 2
    assert len(top_five[0]) == 5

0 comments on commit d127742

Please sign in to comment.