Skip to content

Commit

Permalink
core: add support for other backends apart from sqlite, add file-base…
Browse files Browse the repository at this point in the history
…d backend (basically jsonl)

plus various improvements for tests
  • Loading branch information
karlicoss committed Sep 17, 2023
1 parent 291f635 commit 9eb9bf9
Show file tree
Hide file tree
Showing 7 changed files with 404 additions and 192 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
'appdirs' , # default cache dir
'sqlalchemy>=1.0', # cache DB interaction
'orjson', # fast json serialization
'pytz', # used to properly marshall pytz datetimes
]


Expand Down
235 changes: 49 additions & 186 deletions src/cachew/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import json
import logging
from pathlib import Path
import sqlite3
import stat
import sys
import time
from typing import (
Any,
Callable,
Expand All @@ -22,7 +20,6 @@
Type,
TypeVar,
Union,
Sequence,
cast,
get_args,
get_type_hints,
Expand All @@ -45,10 +42,11 @@ def orjson_dumps(*args, **kwargs): # type: ignore[misc]
orjson_loads = json.loads

import appdirs
import sqlalchemy
from sqlalchemy import Column, Table, event, text
from sqlalchemy.dialects import sqlite

from .backend.common import AbstractBackend
from .backend.file import FileBackend
from .backend.sqlite import SqliteBackend
from .common import SourceHash
from .logging_helper import makeLogger
from .marshall.cachew import CachewMarshall, build_schema
from .utils import (
Expand All @@ -60,9 +58,10 @@ def orjson_dumps(*args, **kwargs): # type: ignore[misc]
# in case of changes in the way cachew stores data, this should be changed to discard old caches
CACHEW_VERSION: str = importlib.metadata.version(__name__)


PathIsh = Union[Path, str]

Backend = Literal['sqlite', 'file']


class settings:
'''
Expand All @@ -81,71 +80,17 @@ class settings:
'''
THROW_ON_ERROR: bool = False

DEFAULT_BACKEND: Backend = 'sqlite'


def get_logger() -> logging.Logger:
return makeLogger(__name__)


# TODO better name to represent what it means?
SourceHash = str


class DbHelper:
def __init__(self, db_path: Path, cls: Type) -> None:
self.engine = sqlalchemy.create_engine(f'sqlite:///{db_path}', connect_args={'timeout': 0})
# NOTE: timeout is necessary so we don't lose time waiting during recursive calls
# by default, it's several seconds? you'd see 'test_recursive' test performance degrade

@event.listens_for(self.engine, 'connect')
def set_sqlite_pragma(dbapi_connection, connection_record):
# without wal, concurrent reading/writing is not gonna work

# ugh. that's odd, how are we supposed to set WAL if the very fact of setting wal might lock the db?
while True:
try:
dbapi_connection.execute('PRAGMA journal_mode=WAL')
break
except sqlite3.OperationalError as oe:
if 'database is locked' not in str(oe):
# ugh, pretty annoying that exception doesn't include database path for some reason
raise CachewException(f'Error while setting WAL on {db_path}') from oe
time.sleep(0.1)

self.connection = self.engine.connect()

"""
Erm... this is pretty confusing.
https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#transaction-isolation-level
Somehow without this thing sqlalchemy logs BEGIN (implicit) instead of BEGIN TRANSACTION which actually works in sqlite...
Judging by sqlalchemy/dialects/sqlite/base.py, looks like some sort of python sqlite driver problem??
test_transaction should check this behaviour
"""

@event.listens_for(self.connection, 'begin')
# pylint: disable=unused-variable
def do_begin(conn):
# NOTE there is also BEGIN CONCURRENT in newer versions of sqlite. could use it later?
conn.execute(text('BEGIN DEFERRED'))

self.meta = sqlalchemy.MetaData()
self.table_hash = Table('hash', self.meta, Column('value', sqlalchemy.String))

# fmt: off
# actual cache
self.table_cache = Table('cache' , self.meta, Column('data', sqlalchemy.BLOB))
# temporary table, we use it to insert and then (atomically?) rename to the above table at the very end
self.table_cache_tmp = Table('cache_tmp', self.meta, Column('data', sqlalchemy.BLOB))
# fmt: on

def __enter__(self) -> 'DbHelper':
return self

def __exit__(self, *args) -> None:
self.connection.close()
self.engine.dispose()
BACKENDS: Dict[Backend, Type[AbstractBackend]] = {
'file': FileBackend,
'sqlite': SqliteBackend,
}


R = TypeVar('R')
Expand Down Expand Up @@ -349,6 +294,7 @@ def cachew_impl(
# - too small values (e.g. 10) are slower than 100 (presumably, too many sql statements)
# - too large values (e.g. 10K) are slightly slower as well (not sure why?)
synthetic_key: Optional[str] = None,
backend: Optional[Backend] = None,
**kwargs,
):
r"""
Expand Down Expand Up @@ -449,6 +395,7 @@ def cachew_impl(
logger =logger,
chunk_by =chunk_by,
synthetic_key=synthetic_key,
backend =backend,
# fmt: on
)

Expand Down Expand Up @@ -481,6 +428,7 @@ def cachew(
logger: Optional[logging.Logger] = ...,
chunk_by: int = ...,
synthetic_key: Optional[str] = ...,
backend: Optional[Backend] = ...,
) -> Callable[[F], F]:
...

Expand Down Expand Up @@ -513,6 +461,7 @@ class Context(Generic[P]):
logger : logging.Logger
chunk_by : int
synthetic_key: Optional[str]
backend : Optional[Backend]

def composite_hash(self, *args, **kwargs) -> Dict[str, Any]:
fsig = inspect.signature(self.func)
Expand Down Expand Up @@ -562,8 +511,11 @@ def cachew_wrapper(
logger = C.logger
chunk_by = C.chunk_by
synthetic_key = C.synthetic_key
backend_name = C.backend
# fmt: on

used_backend = backend_name or settings.DEFAULT_BACKEND

func_name = callable_name(func)
if not settings.ENABLE:
logger.debug(f'[{func_name}]: cache explicitly disabled (settings.ENABLE is False)')
Expand Down Expand Up @@ -602,60 +554,9 @@ def get_db_path() -> Optional[Path]:
if stat.S_ISDIR(st.st_mode):
db_path = db_path / func_name

logger.debug(f'[{func_name}]: using {db_path} for db cache')
logger.debug(f'[{func_name}]: using {used_backend}:{db_path} for cache')
return db_path

def get_old_hash(db: DbHelper) -> Optional[SourceHash]:
# first, try to do as much as possible read-only, benefiting from deferred transaction
old_hashes: Sequence
try:
# not sure if there is a better way...
cursor = conn.execute(db.table_hash.select())
except sqlalchemy.exc.OperationalError as e:
# meh. not sure if this is a good way to handle this..
if 'no such table: hash' in str(e):
old_hashes = []
else:
raise e
else:
old_hashes = cursor.fetchall()

assert len(old_hashes) <= 1, old_hashes # shouldn't happen

old_hash: Optional[SourceHash]
if len(old_hashes) == 0:
old_hash = None
else:
old_hash = old_hashes[0][0] # returns a tuple...
return old_hash

def cached_items():
total = list(conn.execute(sqlalchemy.select(sqlalchemy.func.count()).select_from(table_cache)))[0][0]
logger.info(f'{func_name}: loading {total} objects from cachew (sqlite {db_path})')

rows = conn.execute(table_cache.select())
# by default, sqlalchemy wraps all results into Row object
# this can cause quite a lot of overhead if you're reading many rows
# it seems that in principle, sqlalchemy supports just returning bare underlying tuple from the dbapi
# but from browsing the code it doesn't seem like this functionality exposed
# if you're looking for cues, see
# - ._source_supports_scalars
# - ._generate_rows
# - ._row_getter
# by using this raw iterator we speed up reading the cache quite a bit
# asked here https://github.com/sqlalchemy/sqlalchemy/discussions/10350
raw_row_iterator = getattr(rows, '_raw_row_iterator', None)
if raw_row_iterator is None:
warnings.warn("CursorResult._raw_row_iterator method isn't found. This could lead to degraded cache reading performance.")
row_iterator = rows
else:
row_iterator = raw_row_iterator()

for (blob,) in row_iterator:
j = orjson_loads(blob)
obj = marshall.load(j)
yield obj

def try_use_synthetic_key() -> None:
if synthetic_key is None:
return
Expand Down Expand Up @@ -710,62 +611,30 @@ def missing_keys(cached: List[str], wanted: List[str]) -> Optional[List[str]]:
kwargs[_CACHEW_CACHED] = cached_items()
kwargs[synthetic_key] = missing

def get_exclusive_write_transaction() -> bool:
# returns whether it actually managed to get it

# NOTE on recursive calls
# somewhat magically, they should work as expected with no extra database inserts?
# the top level call 'wins' the write transaction and once it's gathered all data, will write it
# the 'intermediate' level calls fail to get it and will pass data through
# the cached 'bottom' level is read only and will be yielded without a write transaction
try:
# first 'write' statement will upgrade transaction to write transaction which might fail due to concurrency
# see https://www.sqlite.org/lang_transaction.html
# NOTE: because of 'checkfirst=True', only the last .create will guarantee the transaction upgrade to write transaction
db.table_hash.create(conn, checkfirst=True)

# 'table' used to be old 'cache' table name, so we just delete it regardless
# otherwise it might overinflate the cache db with stale values
conn.execute(text('DROP TABLE IF EXISTS `table`'))

# NOTE: we have to use .drop and then .create (e.g. instead of some sort of replace)
# since it's possible to have schema changes in between calls
# checkfirst=True because it might be the first time we're using cache
table_cache_tmp.drop(conn, checkfirst=True)
table_cache_tmp.create(conn)
except sqlalchemy.exc.OperationalError as e:
if e.code == 'e3q8' and 'database is locked' in str(e):
# someone else must have won the write lock
# not much we can do here
# NOTE: important to close early, otherwise we might hold onto too many file descriptors during yielding
# see test_recursive_deep
# (normally connection is closed in DbHelper.__exit__)
conn.close()
# in this case all the callee can do is just to call the actual function
return False
else:
raise e
return True

early_exit = False

def written_to_cache():
nonlocal early_exit

datas = func(*args, **kwargs)

# uhh. this gives a huge speedup for inserting
# since we don't have to create intermediate dictionaries
insert_into_table_cache_tmp_raw = str(table_cache_tmp.insert().compile(dialect=sqlite.dialect(paramstyle='qmark')))
# I also tried setting paramstyle='qmark' in create_engine, but it seems to be ignored :(
# idk what benefit sqlalchemy gives at this point, seems to just complicate things
if isinstance(backend, FileBackend):
# FIXME uhhh.. this is a bit crap
# but in sqlite mode we don't want to publish new hash before we write new items
# maybe should use tmp table for hashes as well?
backend.write_new_hash(new_hash)
else:
# happens later for sqlite
pass

flush_blobs = backend.flush_blobs

chunk: List[Any] = []

def flush() -> None:
nonlocal chunk
if len(chunk) > 0:
conn.exec_driver_sql(insert_into_table_cache_tmp_raw, [(c,) for c in chunk])
flush_blobs(chunk=chunk)
chunk = []

total_objects = 0
Expand All @@ -784,23 +653,20 @@ def flush() -> None:
flush()
flush()

# delete hash first, so if we are interrupted somewhere, it mismatches next time and everything is recomputed
# pylint: disable=no-value-for-parameter
conn.execute(db.table_hash.delete())

# checkfirst is necessary since it might not have existed in the first place
# e.g. first time we use cache
table_cache.drop(conn, checkfirst=True)
backend.finalize(new_hash)
logger.info(f'{func_name}: wrote {total_objects} objects to cachew ({used_backend}:{db_path})')

# meh https://docs.sqlalchemy.org/en/14/faq/metadata_schema.html#does-sqlalchemy-support-alter-table-create-view-create-trigger-schema-upgrade-functionality
# also seems like sqlalchemy doesn't have any primitives to escape table names.. sigh
conn.execute(text(f"ALTER TABLE `{table_cache_tmp.name}` RENAME TO `{table_cache.name}`"))
def cached_items():
total_cached = backend.cached_blobs_total()
total_cached_s = '' if total_cached is None else f'{total_cached} '
logger.info(f'{func_name}: loading {total_cached_s}objects from cachew ({used_backend}:{db_path})')

# pylint: disable=no-value-for-parameter
conn.execute(db.table_hash.insert().values([{'value': new_hash}]))
logger.info(f'{func_name}: wrote {total_objects} objects to cachew (sqlite {db_path})')
for blob in backend.cached_blobs():
j = orjson_loads(blob)
obj = marshall.load(j)
yield obj

# WARNING: annoyingly huge try/catch ahead...
# NOTE: annoyingly huge try/catch ahead...
# but it lets us save a function call, hence a stack frame
# see test_recursive*
try:
Expand All @@ -809,20 +675,17 @@ def flush() -> None:
yield from func(*args, **kwargs)
return

BackendCls = BACKENDS[used_backend]

new_hash_d = C.composite_hash(*args, **kwargs)
new_hash = json.dumps(new_hash_d)
new_hash: SourceHash = json.dumps(new_hash_d)
logger.debug('new hash: %s', new_hash)

marshall = CachewMarshall(Type_=cls)

with DbHelper(db_path, cls) as db, db.connection.begin():
# NOTE: deferred transaction
conn = db.connection
table_cache = db.table_cache
table_cache_tmp = db.table_cache_tmp

old_hash = get_old_hash(db=db)
logger.debug('old hash: %s', old_hash)
with BackendCls(cache_path=db_path, logger=logger) as backend:
old_hash = backend.get_old_hash()
logger.debug(f'old hash: {old_hash}')

if new_hash == old_hash:
logger.debug('hash matched: loading from cache')
Expand All @@ -833,7 +696,7 @@ def flush() -> None:

try_use_synthetic_key()

got_write = get_exclusive_write_transaction()
got_write = backend.get_exclusive_write()
if not got_write:
# NOTE: this is the bit we really have to watch out for and not put in a helper function
# otherwise it's causing an extra stack frame on every call
Expand Down
Loading

0 comments on commit 9eb9bf9

Please sign in to comment.