-
Notifications
You must be signed in to change notification settings - Fork 2
/
itemdb.py
715 lines (578 loc) · 26.1 KB
/
itemdb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
# Copyright (c) 2019-2021 Almar Klein - This code is subject to the MIT license
"""
The itemdb library allows you to store and retrieve Python dicts in a
database on the local filesystem, in an easy, fast, and reliable way.
Based on the rock-solid and ACID compliant SQLite, but with easy and
explicit transactions using a ``with`` statement. It provides a simple
object-based API, with the flexibility to store (JSON-compatible) items
with arbitrary fields, and add indices when needed.
"""
import os
import json
import queue
import asyncio
import sqlite3
import threading
# Package metadata. ``version_info`` is the numeric tuple form of
# ``__version__`` (handy for comparisons); ``__all__`` is the public API.
__version__ = "1.2.0"
version_info = tuple(int(part) for part in __version__.split("."))
__all__ = ["ItemDB", "AsyncItemDB", "asyncify"]
# Module-wide JSON codec, built once and reused for every item that is
# stored or loaded. The encoder escapes all non-ASCII characters.
json_decode = json.JSONDecoder().decode
json_encode = json.JSONEncoder(ensure_ascii=True).encode
# Notes:
#
# * Setting isolation_level to None turns on autocommit mode. We need to do
# this to prevent Python from issuing BEGIN before DML statements.
# * Using a connection object as a context manager auto-commits/rollbacks a
# transaction.
# * We should close cursor objects as soon as possible, because they can hold
# back waiting writers. That's why we don't have an iterator.
# * MongoDB's approach of db.tablename.push() looks nice, but I don't like
# the "magical" side of it, especially since the db does not know its tables.
# Also it makes the code more complex, introduces an extra class, and
# increases the risk of preventing a db from closing (by holding a table).
def asyncify(func):
    """Wrap a normal function into an awaitable co-routine. Can be used
    as a decorator.

    The original function will be executed in a separate thread. This
    allows async code to execute io-bound code (like querying a sqlite
    database) without stalling.

    Note that the code in func must be thread-safe. It's probably best to
    isolate the io-bound parts of your code and only wrap these.

    Parameters
    ----------
    func : callable
        The blocking function to wrap.

    Returns a coroutine function named ``"asyncified_" + <name of func>``.
    Awaiting it yields the return value of ``func``, or raises whatever
    ``func`` raised. A new thread is started for every call.
    """

    # Callables such as functools.partial objects have no __name__.
    func_name = getattr(func, "__name__", type(func).__name__)

    def deliver(loop, future, setter, value):
        # Hand the outcome over to the event loop's thread.
        def apply():
            # The awaiting task may have been cancelled in the meantime
            # (e.g. by a timeout), which also cancels the future. Setting a
            # result on a finished future raises InvalidStateError.
            if not future.done():
                setter(value)

        try:
            loop.call_soon_threadsafe(apply)
        except RuntimeError:
            # The loop was closed while func was running; nobody is waiting.
            pass

    def threaded_func(loop, future, args, kwargs):
        try:
            result = func(*args, **kwargs)
        except BaseException as e:
            deliver(loop, future, future.set_exception, e)
        else:
            deliver(loop, future, future.set_result, result)

    async def asyncified_func(*args, **kwargs):
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        threading.Thread(
            name="asyncify " + func_name,
            target=threaded_func,
            args=(loop, future, args, kwargs),
        ).start()
        return await future

    asyncified_func.__name__ = "asyncified_" + func_name
    return asyncified_func
class ItemDB:
    """A transactional database for storage and retrieval of dict items.

    Parameters
    ----------
    filename : str
        The file to open. Use ":memory:" for an in-memory db.

    The items in the database can be any JSON serializable dictionary.
    Indices can be defined for specific fields to enable fast selection
    of items based on these fields. Indices can be marked as unique to
    make a field mandatory and *identify* items based on that field.

    Transactions are done by using the ``with`` statement, and are mandatory
    for all operations that write to the database.
    """

    def __init__(self, filename):
        # Set first, so that close() and __del__() also work when the
        # remainder of the constructor fails.
        self._conn = None
        # The modification time is sampled before the connection is opened.
        self._mtime = -1
        if os.path.isfile(filename):
            self._mtime = os.path.getmtime(filename)
        # isolation_level=None turns on autocommit mode, which prevents
        # Python from issuing BEGIN before DML statements; transactions are
        # started explicitly in __enter__().
        self._conn = sqlite3.connect(
            filename, timeout=60, isolation_level=None, check_same_thread=False
        )
        # The cursor of the active transaction, or None outside a transaction.
        self._cur = None
        # Cache: table name -> set of index names, see get_indices().
        self._indices_per_table = {}

    @property
    def mtime(self):
        """The time that the database file was last modified, as a Unix timestamp.
        Is -1 if the file did not exist, or if the filename is not represented
        on the filesystem.
        """
        return self._mtime

    def __enter__(self):
        """Begin a transaction. Raises IOError if one is already active."""
        if self._cur is not None:
            raise IOError("Already in a transaction")
        cur = self._conn.cursor()
        try:
            cur.execute("BEGIN IMMEDIATE")
        except BaseException:
            # __exit__() is not called when __enter__() fails, so do not
            # leave a cursor behind: that would make every later attempt
            # fail with "Already in a transaction".
            cur.close()
            raise
        self._cur = cur
        return self

    def __exit__(self, type, value, traceback):
        """End the transaction: commit, or roll back if an exception occurred."""
        self._cur.close()
        self._cur = None
        if type is not None or value is not None:
            self._conn.rollback()
            self._indices_per_table.clear()  # we cannot trust this cache anymore
        else:
            self._conn.commit()

    def __del__(self):
        # The connection is absent when the constructor did not complete.
        conn = getattr(self, "_conn", None)
        if conn is not None:
            conn.close()

    def close(self):
        """Close the database connection.

        This will be automatically called when the instance is deleted.
        But since it can be held e.g. in a traceback, consider using
        ``with closing(db):``.
        """
        conn = getattr(self, "_conn", None)
        if conn is not None:
            conn.close()

    def get_table_names(self):
        """Return a (sorted) list of table names present in the database."""
        cur = self._conn.cursor()
        try:
            cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
            table_names = {x[0] for x in cur}
        finally:
            cur.close()
        return sorted(table_names)

    def get_indices(self, table_name):
        """Get a set of index names for the given table.

        Parameters
        ----------
        table_name : str
            The name of the table to get the indices for.
            *To avoid SQL injection, this arg should not be based on unsafe data.*

        Names prefixed with "!" represent fields that are required and
        unique. Raises KeyError if the table does not exist, TypeError if
        the table name is not a str, or ValueError if it is not an
        identifier. The result is cached per table.
        """
        # Use cached?
        try:
            return self._indices_per_table[table_name]
        except KeyError:
            pass
        except TypeError:
            # Unhashable table_name
            raise TypeError(f"Table name must be str, not {table_name}.")

        # Check table name
        if not isinstance(table_name, str):
            raise TypeError(f"Table name must be str, not {table_name}")
        elif not table_name.isidentifier():
            raise ValueError(f"Table name must be an identifier, not '{table_name}'")

        # Get columns for the table (cid, name, type, notnull, default, pk)
        cur = self._conn.cursor()
        try:
            cur.execute(f"PRAGMA table_info('{table_name}');")
            # Columns flagged notnull get the "!" prefix
            found_indices = {(x[3] * "!" + x[1]) for x in cur}  # includes !_ob
        finally:
            cur.close()

        # Cache and return - or fail
        if found_indices:
            found_indices.difference_update({"!_ob", "_ob"})
            self._indices_per_table[table_name] = found_indices
            return found_indices
        else:
            raise KeyError(f"Table {table_name} not present, maybe use ensure_table()?")

    def ensure_table(self, table_name, *indices):
        """Ensure that the given table exists and has the given indices.

        Parameters
        ----------
        table_name : str
            The name of the table to make sure exists.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        indices : varargs
            A sequence of strings, representing index names. Fields that are
            indexed can be queried with e.g. ``select()``.
            *To avoid SQL injection, this arg should not be based on unsafe data.*

        If an index name is prefixed with "!", it indicates a field that is
        mandatory and unique. Note that new unique indices cannot be added
        when the table already exists.

        This method returns as quickly as possible when the table
        already exists and has the appropriate indices. Returns the
        ItemDB object, so calls to this method can be stacked.

        Although this call may modify the database, one does not need
        to call this in a transaction.
        """
        if not all(isinstance(x, str) for x in indices):
            raise TypeError("Indices must be str")

        # Select missing indices
        try:
            missing_indices = set(indices).difference(self.get_indices(table_name))
        except KeyError:
            missing_indices = {"--table--"}  # the table itself is missing

        # Do we need to do some work? Allow being used under a context and not
        if missing_indices:
            if self._cur is not None:
                self._ensure_table_helper1(table_name, indices, missing_indices)
            else:
                with self:
                    self._ensure_table_helper1(table_name, indices, missing_indices)

        return self  # allow stacking this function

    def _ensure_table_helper1(self, table_name, indices, missing_indices):
        """Create/extend the table and fill the newly added index columns."""
        # Make sure the table is complete
        self._ensure_table_helper2(table_name, indices)
        self._indices_per_table.pop(table_name, None)  # let it refresh
        # Update values that already had a value for the just added columns/indices
        items = [
            item
            for item in self.select_all(table_name)
            if any(x.lstrip("!") in item for x in missing_indices)
        ]
        self.put(table_name, *items)

    def _ensure_table_helper2(self, table_name, indices):
        """Slow version to ensure table. Must run inside a transaction."""
        cur = self._cur

        # Check the column names
        for fieldname in indices:
            key = fieldname.lstrip("!")
            if not key.isidentifier():
                raise ValueError("Column names must be identifiers.")
            elif key == "_ob":
                raise IndexError("Column names cannot be '_ob' (name is reserved).")

        # Ensure the table.
        # If there is one unique key, make it the primary key and omit rowid.
        # This results in smaller and faster databases.
        text = f"CREATE TABLE IF NOT EXISTS {table_name} (_ob TEXT NOT NULL"
        unique_keys = sorted(x.lstrip("!") for x in indices if x.startswith("!"))
        if len(unique_keys) == 1:
            index_key = unique_keys[0]
            text += f", {index_key} NOT NULL PRIMARY KEY) WITHOUT ROWID;"
        else:
            for index_key in unique_keys:
                text += f", {index_key} NOT NULL UNIQUE"
            text += ");"
        cur.execute(text)

        # Ensure the columns and indices
        cur.execute(f"PRAGMA table_info('{table_name}');")
        found_indices = {(x[3] * "!" + x[1]) for x in cur}
        for fieldname in sorted(indices):
            index_key = fieldname.lstrip("!")
            if fieldname not in found_indices:
                if fieldname.startswith("!"):
                    raise IndexError(
                        f"Cannot add unique index {fieldname!r} after the table has been created."
                    )
                elif fieldname in {x.lstrip("!") for x in found_indices}:
                    raise IndexError(f"Given index {fieldname!r} should be unique.")
                cur.execute(f"ALTER TABLE {table_name} ADD {index_key};")
            cmd = "CREATE INDEX IF NOT EXISTS"
            cur.execute(
                f"{cmd} idx_{table_name}_{index_key} ON {table_name} ({index_key})"
            )

    def delete_table(self, table_name):
        """Delete the table with the given name.

        Parameters
        ----------
        table_name : str
            The name of the table to delete.
            *To avoid SQL injection, this arg should not be based on unsafe data.*

        Be aware that this deletes the whole table, including all of
        its items.

        This method must be called within a transaction. Can raise
        KeyError if an invalid table is given, or IOError if not used
        within a transaction.
        """
        self.get_indices(table_name)  # Fail with KeyError for invalid table name

        cur = self._cur
        if cur is None:
            raise IOError("Can only use delete_table() within a transaction.")

        self._indices_per_table.pop(table_name, None)
        cur.execute(f"DROP TABLE {table_name}")

    def rename_table(self, table_name, new_table_name):
        """Rename a table.

        Parameters
        ----------
        table_name : str
            The current name of the table.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        new_table_name : str
            The new name.
            *To avoid SQL injection, this arg should not be based on unsafe data.*

        This method must be called within a transaction. Can raise
        KeyError if an invalid table is given, TypeError if the new name
        is not a str identifier, or IOError if not used within a
        transaction.
        """
        self.get_indices(table_name)  # Fail with KeyError for invalid table name

        if not (isinstance(new_table_name, str) and new_table_name.isidentifier()):
            # Report the name that is actually wrong, i.e. the new one
            raise TypeError(
                f"Table name must be a str identifier, not '{new_table_name}'"
            )

        cur = self._cur
        if cur is None:
            raise IOError("Can only use rename_table() within a transaction.")

        # Neither name can be trusted in the cache after the rename
        self._indices_per_table.pop(table_name, None)
        self._indices_per_table.pop(new_table_name, None)
        cur.execute(f"ALTER TABLE {table_name} RENAME TO {new_table_name}")

    def count_all(self, table_name):
        """Get the total number of items in the given table."""
        self.get_indices(table_name)  # Fail with KeyError for invalid table name

        cur = self._conn.cursor()
        try:
            cur.execute(f"SELECT COUNT(*) FROM {table_name}")
            return cur.fetchone()[0]
        finally:
            cur.close()

    def count(self, table_name, query, *save_args):
        """Get the number of items in the given table that match the given query.

        Parameters
        ----------
        table_name : str
            The name of the table to count items in.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        query : str
            The query to select items on.
            *To avoid SQL injection, this arg should not be based on unsafe data;
            use save_args for end-user input.*
        save_args : varargs
            The values to select items on.

        Examples::

            # Count the persons older than 20
            db.count("persons", "age > ?", 20)
            # Count the persons older than a given value
            db.count("persons", "age > ?", min_age)
            # Use AND and OR for more precise queries
            db.count("persons", "age > ? AND age < ?", min_age, max_age)

        See ``select()`` for details on queries.

        Can raise KeyError if an invalid table is given, IndexError if an
        invalid field is used in the query, or sqlite3.OperationalError for
        an invalid query.
        """
        self.get_indices(table_name)  # Fail with KeyError for invalid table name

        cur = self._conn.cursor()
        try:
            cur.execute(f"SELECT COUNT(*) FROM {table_name} WHERE {query}", save_args)
            return cur.fetchone()[0]
        except sqlite3.OperationalError as err:
            # Querying a field that has no column means it is not indexed
            if "no such column" in str(err).lower():
                raise IndexError(str(err)) from err
            raise
        finally:
            cur.close()

    def select_all(self, table_name):
        """Get all items in the given table. See ``select()`` for details."""
        self.get_indices(table_name)  # Fail with KeyError for invalid table name

        cur = self._conn.cursor()
        try:
            cur.execute(f"SELECT _ob FROM {table_name}")
            return [json_decode(x[0]) for x in cur]
        finally:
            cur.close()

    def select(self, table_name, query, *save_args):
        """Get the items in the given table that match the given query.

        Parameters
        ----------
        table_name : str
            The name of the table to select items in.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        query : str
            The query to select items on.
            *To avoid SQL injection, this arg should not be based on unsafe data;
            use save_args for end-user input.*
        save_args : varargs
            The values to select items on.

        The query follows SQLite syntax and can only include indexed
        fields. If needed, use ensure_table() to add indices. The query
        is always fast (which is why this method is called 'select', and
        not 'search').

        Examples::

            # Select the persons older than 20
            db.select("persons", "age > ?", 20)
            # Select the persons older than a given age
            db.select("persons", "age > ?", min_age)
            # Use AND and OR for more precise queries
            db.select("persons", "age > ? AND age < ?", min_age, max_age)

        There is no method to filter items based on non-indexed fields,
        because this is easy using a list comprehension, e.g.::

            items = db.select_all("persons")
            items = [i for i in items if i["age"] > 20]

        Can raise KeyError if an invalid table is given, IndexError if an
        invalid field is used in the query, or sqlite3.OperationalError for
        an invalid query.
        """
        self.get_indices(table_name)  # Fail with KeyError for invalid table name

        # It is tempting to make this a generator, but also dangerous because
        # the cursor might not be closed if the generator is stored somewhere
        # and not run through the end.
        cur = self._conn.cursor()
        try:
            cur.execute(f"SELECT _ob FROM {table_name} WHERE {query}", save_args)
            return [json_decode(x[0]) for x in cur]
        except sqlite3.OperationalError as err:
            # Querying a field that has no column means it is not indexed
            if "no such column" in str(err).lower():
                raise IndexError(str(err)) from err
            raise
        finally:
            cur.close()

    def select_one(self, table_name, query, *args):
        """Get the first item in the given table that match the given query.

        Parameters
        ----------
        table_name : str
            The name of the table to select an item in.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        query : str
            The query to select the item on.
            *To avoid SQL injection, this arg should not be based on unsafe data;
            use args for end-user input.*
        args : varargs
            The values to select the item on.

        Returns None if there was no match. See ``select()`` for details.
        """
        items = self.select(table_name, query, *args)
        return items[0] if items else None

    def put(self, table_name, *items):
        """Put one or more items into the given table.

        Parameters
        ----------
        table_name : str
            The name of the table to put the item(s) in.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        items : varargs
            The dicts to add. Keys that match an index can later be used for
            fast querying.

        This method must be called within a transaction. Can raise
        KeyError if an invalid table is given, IOError if not used
        within a transaction, TypeError if an item is not a (JSON
        serializable) dict, or IndexError if an item does not have a
        required field.
        """
        cur = self._cur
        if cur is None:
            raise IOError("Can only use put() within a transaction.")

        # Get indices - fail with KeyError for invalid table name
        indices = self.get_indices(table_name)

        for item in items:
            if not isinstance(item, dict):
                raise TypeError("Expecing each item to be a dict")

            # The full item goes in the _ob column as JSON; the value of each
            # indexed field is additionally stored in its own column.
            index_keys = "_ob"
            row_plac = "?"
            row_vals = [json_encode(item)]  # Can raise TypeError
            for fieldname in indices:
                index_key = fieldname.lstrip("!")
                if index_key in item:
                    index_keys += ", " + index_key
                    row_plac += ", ?"
                    row_vals.append(item[index_key])
                elif fieldname.startswith("!"):
                    raise IndexError(f"Item does not have required field {index_key!r}")

            cur.execute(
                f"INSERT OR REPLACE INTO {table_name} ({index_keys}) VALUES ({row_plac})",
                row_vals,
            )

    def put_one(self, table_name, **item):
        """Put an item into the given table using kwargs.

        Parameters
        ----------
        table_name : str
            The name of the table to put the item(s) in.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        item : kwargs
            The dict to add. Keys that match an index can later be used for
            fast querying.

        This method must be called within a transaction.
        """
        self.put(table_name, item)

    def delete(self, table_name, query, *save_args):
        """Delete items from the given table.

        Parameters
        ----------
        table_name : str
            The name of the table to delete items from.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        query : str
            The query to select the items to delete.
            *To avoid SQL injection, this arg should not be based on unsafe data;
            use save_args for end-user input.*
        save_args : varargs
            The values to select the item on.

        Examples::

            # Delete the persons older than 20
            db.delete("persons", "age > ?", 20)
            # Delete the persons older than a given age
            db.delete("persons", "age > ?", min_age)
            # Use AND and OR for more precise queries
            db.delete("persons", "age > ? AND age < ?", min_age, max_age)

        See ``select()`` for details on queries.

        This method must be called within a transaction. Can raise
        KeyError if an invalid table is given, IOError if not used
        within a transaction, IndexError if an invalid field is used
        in the query, or sqlite3.OperationalError for an invalid query.
        """
        self.get_indices(table_name)  # Fail with KeyError for invalid table name

        cur = self._cur
        if cur is None:
            raise IOError("Can only use delete() within a transaction.")

        # Note: the cursor is deliberately not closed here. It belongs to the
        # transaction and is closed in __exit__(); closing it earlier would
        # break every subsequent write in the same ``with`` block.
        try:
            cur.execute(f"DELETE FROM {table_name} WHERE {query}", save_args)
        except sqlite3.OperationalError as err:
            # Querying a field that has no column means it is not indexed
            if "no such column" in str(err).lower():
                raise IndexError(str(err)) from err
            raise
class AsyncItemDB:
    """An async version of ItemDB. The API is exactly the same, except
    that all methods are async, and one must use `async with` instead
    of the normal `with`.

    Because ``__new__`` is a coroutine, an instance is created with
    ``db = await AsyncItemDB(filename)``. All database work is executed
    by a dedicated worker thread that is fed via a queue, and the results
    are handed back through futures on the loop that was running when the
    object was created.
    """

    async def __new__(cls, filename):
        self = super().__new__(cls)
        # Set first, so that __del__() works when the steps below fail.
        self._closed = False
        self._loop = asyncio.get_running_loop()
        self._queue = queue.Queue()
        self._thread = Thread4AsyncItemDB(self._queue)
        self._thread.start()
        try:
            # The db is created in the worker thread
            self.db = self._thread.db = await self._handle(ItemDB, filename)
        except BaseException:
            # Opening failed: stop the worker thread, otherwise it would
            # stay blocked on the queue forever.
            self._closed = True
            self._queue.put_nowait((None, None, None, None))
            raise
        return self

    @property
    def mtime(self):
        """The ``mtime`` of the wrapped ItemDB."""
        return self.db.mtime

    async def _handle(self, function, *args, **kwargs):
        """Let the worker thread call the function, and await its outcome."""
        future = self._loop.create_future()
        self._queue.put_nowait((future, function, args, kwargs))
        return await future

    def _schedule_close(self):
        """Queue closing the db, followed by the sentinel that stops the thread.

        Returns the future for the close call, or None if closing was
        already requested (nothing would consume the queue anymore).
        """
        if self._closed:
            return None
        self._closed = True
        future = self._loop.create_future()
        self._queue.put_nowait((future, self.db.close, (), {}))
        self._queue.put_nowait((None, None, None, None))
        return future

    async def __aenter__(self):
        return await self._handle(self.db.__enter__)

    async def __aexit__(self, type, value, traceback):
        return await self._handle(self.db.__exit__, type, value, traceback)

    def __del__(self):
        # The flag is absent if __new__ did not get to setting it.
        if not getattr(self, "_closed", True):
            self._schedule_close()

    async def close(self):
        """Close the database and stop the worker thread.

        Calling this more than once is allowed; later calls return
        immediately.
        """
        future = self._schedule_close()
        if future is None:
            return None
        return await future

    # The methods below forward to the ItemDB method of the same name.

    async def get_table_names(self, *args, **kwargs):
        return await self._handle(self.db.get_table_names, *args, **kwargs)

    async def get_indices(self, *args, **kwargs):
        return await self._handle(self.db.get_indices, *args, **kwargs)

    async def ensure_table(self, *args, **kwargs):
        return await self._handle(self.db.ensure_table, *args, **kwargs)

    async def delete_table(self, *args, **kwargs):
        return await self._handle(self.db.delete_table, *args, **kwargs)

    async def rename_table(self, *args, **kwargs):
        return await self._handle(self.db.rename_table, *args, **kwargs)

    async def count_all(self, *args, **kwargs):
        return await self._handle(self.db.count_all, *args, **kwargs)

    async def count(self, *args, **kwargs):
        return await self._handle(self.db.count, *args, **kwargs)

    async def select_all(self, *args, **kwargs):
        return await self._handle(self.db.select_all, *args, **kwargs)

    async def select(self, *args, **kwargs):
        return await self._handle(self.db.select, *args, **kwargs)

    async def select_one(self, *args, **kwargs):
        return await self._handle(self.db.select_one, *args, **kwargs)

    async def put(self, *args, **kwargs):
        return await self._handle(self.db.put, *args, **kwargs)

    async def put_one(self, *args, **kwargs):
        return await self._handle(self.db.put_one, *args, **kwargs)

    async def delete(self, *args, **kwargs):
        return await self._handle(self.db.delete, *args, **kwargs)
class Thread4AsyncItemDB(threading.Thread):
    """Thread that does the work for the AsyncItemDB.

    The thread consumes ``(future, function, args, kwargs)`` tuples from
    the given queue, calls ``function(*args, **kwargs)``, and passes the
    return value (or the raised exception) to the future, on the event
    loop that the future belongs to. A tuple whose future is None stops
    the thread.
    """

    # Number of threads created so far, used to give each a unique name.
    _count = 0

    def __init__(self, queue):
        Thread4AsyncItemDB._count += 1
        super().__init__(name=f"AsyncItemDB_{Thread4AsyncItemDB._count}")
        self.daemon = True
        self._queue = queue
        self.db = None

    @staticmethod
    def _resolve(future, setter_name, value):
        """Set the result or exception of the future, in its loop's thread."""

        def apply():
            # The future may have been cancelled in the meantime
            if not future.done():
                getattr(future, setter_name)(value)

        try:
            future.get_loop().call_soon_threadsafe(apply)
        except RuntimeError:
            # The event loop is closed, so nobody is waiting for this
            # future. Carry on, so the remaining queue items (e.g. closing
            # the db) are still processed.
            pass

    def run(self) -> None:
        while True:
            # Continues running until all queue items are processed,
            # even after closed (so we can finalize all futures)
            future, function, args, kwargs = self._queue.get()
            if future is None:
                break
            try:
                result = function(*args, **kwargs)
            except BaseException as e:
                self._resolve(future, "set_exception", e)
            else:
                self._resolve(future, "set_result", result)