-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsqlitent.py
296 lines (238 loc) · 10.6 KB
/
sqlitent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
import collections.abc
import collections
import sqlite3
import pickle
import itertools
import types
################################################################################
# Helpers
################################################################################
def _sqlname(name):
    ''' Appends a hex-encoded version of the name to itself to distinguish
        between lower and upper case.

        The name is hex-encoded from its UTF-8 bytes, so non-ASCII Python
        identifiers are accepted as well; for pure ASCII names the result is
        identical to an ASCII encoding.

        >>> _sqlname('AAA')
        'AAA_414141'
        >>> _sqlname('aaa')
        'aaa_616161'
    '''
    return f'{name}_{name.encode("utf-8").hex()}'
def _sqltype(_type):
    ''' Maps a Python type onto the name of the Sqlite column type used to
        store it. Anything that is not int, float, str or bytes is stored
        as a BLOB.

        >>> _sqltype(int)
        'INTEGER'
        >>> _sqltype(object)
        'BLOB'
    '''
    known = (
        (int, 'INTEGER'),
        (float, 'REAL'),
        (str, 'TEXT'),
        (bytes, 'BLOB'),
    )
    for pytype, sqltype in known:
        if pytype == _type:
            return sqltype
    return 'BLOB'
def _identity(something):
    ''' Hands the argument back untouched. Used as the no-op encoder and
        decoder for values that can be stored without conversion.
    '''
    return something
def _istrivial(val):
    ''' Tells whether val is one of the types NoneType, int, float, str or
        bytes, or an instance of exactly one of these types (subclasses such
        as bool do not count).
    '''
    trivial = (type(None), int, float, str, bytes)
    if val in trivial:
        # val is itself one of the trivial types
        return True
    # otherwise val counts only if its exact type is trivial
    return type(val) in trivial
def _flatten(it):
    ''' Lazily yields every namedtuple found in an arbitrarily nested
        structure of tuples, lists, sets and frozensets.

        A namedtuple is recognized by being a direct subclass of tuple that
        has a _fields attribute; it is yielded as is and never descended
        into.

        Raises TypeError (a subclass of Exception) as soon as an element is
        reached that is neither a namedtuple nor one of the supported
        containers. Because the generator is lazy, the error surfaces only
        when that element is reached during iteration.
    '''
    cls = type(it)
    if tuple in cls.__bases__ and hasattr(it, '_fields'):
        yield it
    elif isinstance(it, (tuple, list, set, frozenset)):
        for item in it:
            yield from _flatten(item)
    else:
        raise TypeError(
            'expected a namedtuple or a tuple/list/set/frozenset of '
            f'namedtuples, got {cls.__name__}'
        )
################################################################################
# sqlitent API
################################################################################
class sqlitent(collections.abc.Collection):
    ''' A collection of namedtuples stored in a Sqlite database.

        Every namedtuple class gets its own table, created the first time
        an instance of that class is added. Field values whose type (taken
        from the first instance added) is NoneType, int, float, str or bytes
        are stored as they are; every other value is passed through
        `encode` before it is stored and through `decode` when it is read.

        NOTE: the default decoder is pickle.loads, which can execute
        arbitrary code. Only open databases from trusted sources, or supply
        your own encode/decode pair.
    '''

    def __init__(self, database, encode=pickle.dumps, decode=pickle.loads):
        ''' Opens (or creates) the Sqlite database.

            database -- passed to sqlite3.connect, e.g. a path or ':memory:'
            encode   -- callable turning a non-trivial value into something
                        Sqlite can store
            decode   -- callable reversing encode
        '''
        self.__db = sqlite3.connect(database)
        self.__encode = encode
        self.__decode = decode

        # We need to keep track of recognized namedtuples and how to encode
        # and decode them.
        self.__tupletypes = set()
        self.__encoder = {}
        self.__decoder = {}
        # tupletype -> {field name: encoder}; lets many() encode query
        # values the same way the stored values were encoded.
        self.__field_encoder = {}

        # We cache frequently used SQL statements
        self.__insert_stmt = {}
        self.__select_stmt = {}
        self.__delete_stmt = {}
        self.__count_stmt = {}

    def __register(self, tupletype, fields):
        # Registers a namedtuple with the database. All fields in the tupletype
        # need to be mapped to a type in fields; the table is created if it
        # doesn't exist yet.

        # Is tupletype really a namedtuple? The checks raise explicitly
        # instead of using assert statements so that they survive python -O,
        # while still raising the AssertionError they always raised. The
        # _source attribute is not checked: namedtuples no longer have it
        # since Python 3.7.
        required = ('_fields', '_make', '_replace', '_asdict')
        is_namedtuple = (
            tuple in tupletype.__bases__
            and all(hasattr(tupletype, attr) for attr in required)
            and all(hasattr(tupletype, n) for n in tupletype._fields)
        )
        if not is_namedtuple:
            raise AssertionError('expected namedtuple')

        # We require that all fields in the namedtuple are typed to create a
        # typed database table and handle encoding complex types.
        if not all(f in fields for f in tupletype._fields):
            raise AssertionError('untyped field(s)')
        fields = collections.OrderedDict((f, fields[f]) for f in tupletype._fields)

        encs = [_identity if _istrivial(t) else self.__encode for t in fields.values()]
        decs = [_identity if _istrivial(t) else self.__decode for t in fields.values()]

        def _encode(tup):
            return tuple(enc(v) for enc, v in zip(encs, tup))

        def _decode(row):
            return tupletype._make(dec(v) for dec, v in zip(decs, row))

        self.__encoder[tupletype] = _encode
        self.__decoder[tupletype] = _decode
        self.__field_encoder[tupletype] = dict(zip(fields, encs))

        name = tupletype.__name__
        self.__insert_stmt[tupletype] = self.__build_insert_stmt(name, fields.keys())
        self.__select_stmt[tupletype] = self.__build_select_stmt(name, fields.keys())
        self.__delete_stmt[tupletype] = self.__build_delete_stmt(name, fields.keys())
        self.__count_stmt[tupletype] = self.__build_count_stmt(name)

        self.__execute(self.__build_create_table_stmt(name, fields))

        # If we get to this point, everything is ready to deal instances of
        # tupletype. Hence tupletype is added to the recognized namedtuples.
        self.__tupletypes.add(tupletype)

    def __build_insert_stmt(self, name, fieldnames):
        # INSERT OR IGNORE: rows violating the UNIQUE constraint are skipped.
        cols = ','.join(map(_sqlname, fieldnames))
        gaps = ','.join('?' for _ in fieldnames)
        return f'INSERT OR IGNORE INTO {_sqlname(name)} ({cols}) VALUES ({gaps});'

    def __build_select_stmt(self, name, fieldnames=()):
        # One "IS ?" clause per field name; IS (rather than =) also matches
        # NULL. Without field names all rows are selected.
        clauses = ' AND '.join(f'{_sqlname(f)} IS ?' for f in fieldnames)
        where = f' WHERE {clauses}' if clauses else ''
        return f'SELECT * FROM {_sqlname(name)}{where};'

    def __build_delete_stmt(self, name, fieldnames=()):
        # Same WHERE construction as the select statement.
        clauses = ' AND '.join(f'{_sqlname(f)} IS ?' for f in fieldnames)
        where = f' WHERE {clauses}' if clauses else ''
        return f'DELETE FROM {_sqlname(name)}{where};'

    def __build_count_stmt(self, name):
        return f'SELECT count(*) FROM {_sqlname(name)};'

    def __build_create_table_stmt(self, name, fieldtypes):
        # fieldtypes maps field names to Python types. The UNIQUE constraint
        # over all columns is what makes the table behave like a set.
        defs = ','.join(f'{_sqlname(f)} {_sqltype(v)}' for f, v in fieldtypes.items())
        cols = ','.join(map(_sqlname, fieldtypes.keys()))
        return f'CREATE TABLE IF NOT EXISTS {_sqlname(name)} ({defs}, UNIQUE ({cols}));'

    def __assert_registered(self, tupletype):
        if tupletype not in self.__tupletypes:
            raise Exception(f'unknown tupletype: {tupletype}')

    def __execute(self, stmt, *args, **kwargs):
        # Runs one statement on a fresh cursor, commits, and returns the
        # cursor so that callers can read the result rows from it.
        cur = self.__db.cursor().execute(stmt, *args, **kwargs)
        self.__db.commit()
        return cur

    def __contains__(self, tup):
        # Instances of namedtuple classes that were never added cannot be
        # in the database.
        tupletype = type(tup)
        if tupletype not in self.__tupletypes:
            return False
        tmp = self.__encoder[tupletype](tup)
        return self.__execute(self.__select_stmt[tupletype], tmp).fetchone() is not None

    def __iter__(self):
        # Chains the decoded rows of all registered tables.
        return itertools.chain.from_iterable(
            map(self.__decoder[t], self.__execute(self.__build_select_stmt(t.__name__)))
            for t in self.__tupletypes
        )

    def __len__(self):
        return sum(self.__execute(self.__count_stmt[t]).fetchone()[0] for t in self.__tupletypes)

    def add(self, nt):
        ''' Add a namedtuple to the database. Registers the namedtuple class
            with the database if necessary.

            >>> Pal = collections.namedtuple('Pal', ['name', 'age'])
            >>> db = sqlitent(':memory:')
            >>> p = Pal('Jim', 35)
            >>> db.add(p)
            >>> p in db
            True
        '''
        tupletype = type(nt)
        if tupletype not in self.__tupletypes:
            self.__register(tupletype, {f: type(v) for f, v in nt._asdict().items()})
        if None in nt and nt in self:
            return  # abort if exists, because Sqlite's NULL isn't unique
        tmp = self.__encoder[tupletype](nt)
        self.__execute(self.__insert_stmt[tupletype], tmp)

    def insert(self, *nts):
        ''' Insert one or more namedtuples to the database. The arguments
            may be namedtuples or nested containers of namedtuples.

            >>> Pal = collections.namedtuple('Pal', ['name', 'age'])
            >>> db = sqlitent(':memory:')
            >>> db.insert(Pal('Jim', 35), [Pal('Ann', 28), Pal('Jim', 35)])
            >>> len(db)
            2
        '''
        # Flatten everything first so that an invalid argument is reported
        # before anything is written. Duplicates are not filtered here (that
        # would require hashable field values); add() ignores them.
        for tup in list(_flatten(nts)):
            self.add(tup)

    def remove(self, tup):
        ''' Remove one matching namedtuple from the database. Raises an
            Exception if the namedtuple class was never added.

            >>> Pal = collections.namedtuple('Pal', ['name', 'age'])
            >>> db = sqlitent(':memory:')
            >>> p = Pal('Jim', 35)
            >>> db.add(p)
            >>> db.remove(p)
            >>> p in db
            False
        '''
        tupletype = type(tup)
        self.__assert_registered(tupletype)
        tmp = self.__encoder[tupletype](tup)
        self.__execute(self.__delete_stmt[tupletype], tmp)

    def delete(self, *nts):
        ''' Remove one or more namedtuples from the database. The arguments
            may be namedtuples or nested containers of namedtuples.

            >>> Pal = collections.namedtuple('Pal', ['name', 'age'])
            >>> db = sqlitent(':memory:')
            >>> p = Pal('Jim', 35)
            >>> db.add(p)
            >>> db.delete([p, p])
            >>> len(db)
            0
        '''
        # See insert() for why the arguments are flattened into a list.
        # Removing the same namedtuple twice is harmless.
        for tup in list(_flatten(nts)):
            self.remove(tup)

    def one(self, tupletype, **kwargs):
        ''' Return one matching namedtuple or None. Accepts the same
            keyword arguments as many().

            >>> Pal = collections.namedtuple('Pal', ['name', 'age'])
            >>> db = sqlitent(':memory:')
            >>> db.add(Pal('Jim', 35))
            >>> db.one(Pal, name='Jim')
            Pal(name='Jim', age=35)
            >>> db.one(Pal, name='Ann') is None
            True
        '''
        self.__assert_registered(tupletype)
        for tup in self.many(tupletype, **kwargs):
            return tup
        return None

    def pop(self, tupletype, **kwargs):
        ''' Return one matching namedtuple or None and remove the returned
            namedtuple from the database.

            >>> Pal = collections.namedtuple('Pal', ['name', 'age'])
            >>> db = sqlitent(':memory:')
            >>> db.add(Pal('Jim', 35))
            >>> db.pop(Pal, name='Jim')
            Pal(name='Jim', age=35)
            >>> len(db)
            0
        '''
        self.__assert_registered(tupletype)
        tup = self.one(tupletype, **kwargs)
        if tup is not None:
            self.remove(tup)
        return tup

    def many(self, tupletype, **kwargs):
        ''' Return a generator over zero or more matching namedtuples.

            Each keyword names a field. A plain value must equal the stored
            field value; a function (or lambda) is called with the field
            value and keeps the namedtuple if it returns a true value.

            As this is a generator, an unknown keyword or an unregistered
            tupletype raises an Exception only once iteration starts.

            >>> Pal = collections.namedtuple('Pal', ['name', 'age'])
            >>> db = sqlitent(':memory:')
            >>> db.insert(Pal('Jim', 35), Pal('Ann', 28))
            >>> list(db.many(Pal, name='Jim'))
            [Pal(name='Jim', age=35)]
            >>> list(db.many(Pal, age=lambda a: a < 30))
            [Pal(name='Ann', age=28)]
        '''
        if not all(k in tupletype._fields for k in kwargs):
            raise Exception(f'{tupletype} doesn\'t have one of your keywords')
        self.__assert_registered(tupletype)

        field_encoder = self.__field_encoder[tupletype]
        sqlargs = []
        sqlvals = []
        filters = []
        for field, value in sorted(kwargs.items()):
            if isinstance(value, types.FunctionType):
                # Bind the current field/value as defaults; a plain closure
                # would see only the last pair of the loop when it finally
                # runs.
                filters.append(
                    lambda t, _check=value, _field=field: _check(getattr(t, _field))
                )
            else:
                sqlargs.append(field)
                # Encode exactly like the stored value, otherwise fields
                # that go through encode() could never be matched.
                sqlvals.append(field_encoder[field](value))

        stmt = self.__build_select_stmt(tupletype.__name__, sqlargs)
        it = map(self.__decoder[tupletype], self.__execute(stmt, sqlvals))
        for fn in filters:
            it = filter(fn, it)
        yield from it

    def popmany(self, tupletype, **kwargs):
        ''' Return a list of zero or more matching namedtuples and remove
            them from the database. Accepts the same keyword arguments as
            many().

            >>> Pal = collections.namedtuple('Pal', ['name', 'age'])
            >>> db = sqlitent(':memory:')
            >>> db.insert(Pal('Jim', 35), Pal('Ann', 28))
            >>> db.popmany(Pal, name='Jim')
            [Pal(name='Jim', age=35)]
            >>> len(db)
            1
        '''
        tups = list(self.many(tupletype, **kwargs))
        self.delete(tups)
        return tups