diff --git a/setup.py b/setup.py index c60d65c..774724a 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,7 @@ 'appdirs' , # default cache dir 'sqlalchemy>=1.0', # cache DB interaction 'orjson', # fast json serialization + 'pytz', # used to properly marshall pytz datatimes ] diff --git a/src/cachew/__init__.py b/src/cachew/__init__.py index dec5564..aa70dbe 100644 --- a/src/cachew/__init__.py +++ b/src/cachew/__init__.py @@ -5,10 +5,8 @@ import json import logging from pathlib import Path -import sqlite3 import stat import sys -import time from typing import ( Any, Callable, @@ -22,7 +20,6 @@ Type, TypeVar, Union, - Sequence, cast, get_args, get_type_hints, @@ -45,10 +42,11 @@ def orjson_dumps(*args, **kwargs): # type: ignore[misc] orjson_loads = json.loads import appdirs -import sqlalchemy -from sqlalchemy import Column, Table, event, text -from sqlalchemy.dialects import sqlite +from .backend.common import AbstractBackend +from .backend.file import FileBackend +from .backend.sqlite import SqliteBackend +from .common import SourceHash from .logging_helper import makeLogger from .marshall.cachew import CachewMarshall, build_schema from .utils import ( @@ -60,9 +58,10 @@ def orjson_dumps(*args, **kwargs): # type: ignore[misc] # in case of changes in the way cachew stores data, this should be changed to discard old caches CACHEW_VERSION: str = importlib.metadata.version(__name__) - PathIsh = Union[Path, str] +Backend = Literal['sqlite', 'file'] + class settings: ''' @@ -81,71 +80,17 @@ class settings: ''' THROW_ON_ERROR: bool = False + DEFAULT_BACKEND: Backend = 'sqlite' + def get_logger() -> logging.Logger: return makeLogger(__name__) -# TODO better name to represent what it means? -SourceHash = str - - -class DbHelper: - def __init__(self, db_path: Path, cls: Type) -> None: - self.engine = sqlalchemy.create_engine(f'sqlite:///{db_path}', connect_args={'timeout': 0}) - # NOTE: timeout is necessary so we don't lose time waiting during recursive calls - # by default, it's several seconds? you'd see 'test_recursive' test performance degrade - - @event.listens_for(self.engine, 'connect') - def set_sqlite_pragma(dbapi_connection, connection_record): - # without wal, concurrent reading/writing is not gonna work - - # ugh. that's odd, how are we supposed to set WAL if the very fact of setting wal might lock the db? - while True: - try: - dbapi_connection.execute('PRAGMA journal_mode=WAL') - break - except sqlite3.OperationalError as oe: - if 'database is locked' not in str(oe): - # ugh, pretty annoying that exception doesn't include database path for some reason - raise CachewException(f'Error while setting WAL on {db_path}') from oe - time.sleep(0.1) - - self.connection = self.engine.connect() - - """ - Erm... this is pretty confusing. - https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#transaction-isolation-level - - Somehow without this thing sqlalchemy logs BEGIN (implicit) instead of BEGIN TRANSACTION which actually works in sqlite... - - Judging by sqlalchemy/dialects/sqlite/base.py, looks like some sort of python sqlite driver problem?? - - test_transaction should check this behaviour - """ - - @event.listens_for(self.connection, 'begin') - # pylint: disable=unused-variable - def do_begin(conn): - # NOTE there is also BEGIN CONCURRENT in newer versions of sqlite. could use it later? 
- conn.execute(text('BEGIN DEFERRED')) - - self.meta = sqlalchemy.MetaData() - self.table_hash = Table('hash', self.meta, Column('value', sqlalchemy.String)) - - # fmt: off - # actual cache - self.table_cache = Table('cache' , self.meta, Column('data', sqlalchemy.BLOB)) - # temporary table, we use it to insert and then (atomically?) rename to the above table at the very end - self.table_cache_tmp = Table('cache_tmp', self.meta, Column('data', sqlalchemy.BLOB)) - # fmt: on - - def __enter__(self) -> 'DbHelper': - return self - - def __exit__(self, *args) -> None: - self.connection.close() - self.engine.dispose() +BACKENDS: Dict[Backend, Type[AbstractBackend]] = { + 'file': FileBackend, + 'sqlite': SqliteBackend, +} R = TypeVar('R') @@ -349,6 +294,7 @@ def cachew_impl( # - too small values (e.g. 10) are slower than 100 (presumably, too many sql statements) # - too large values (e.g. 10K) are slightly slower as well (not sure why?) synthetic_key: Optional[str] = None, + backend: Optional[Backend] = None, **kwargs, ): r""" @@ -449,6 +395,7 @@ def cachew_impl( logger =logger, chunk_by =chunk_by, synthetic_key=synthetic_key, + backend =backend, # fmt: on ) @@ -481,6 +428,7 @@ def cachew( logger: Optional[logging.Logger] = ..., chunk_by: int = ..., synthetic_key: Optional[str] = ..., + backend: Optional[Backend] = ..., ) -> Callable[[F], F]: ... @@ -513,6 +461,7 @@ class Context(Generic[P]): logger : logging.Logger chunk_by : int synthetic_key: Optional[str] + backend : Optional[Backend] def composite_hash(self, *args, **kwargs) -> Dict[str, Any]: fsig = inspect.signature(self.func) @@ -562,8 +511,11 @@ def cachew_wrapper( logger = C.logger chunk_by = C.chunk_by synthetic_key = C.synthetic_key + backend_name = C.backend # fmt: on + used_backend = backend_name or settings.DEFAULT_BACKEND + func_name = callable_name(func) if not settings.ENABLE: logger.debug(f'[{func_name}]: cache explicitly disabled (settings.ENABLE is False)') @@ -602,60 +554,9 @@ def get_db_path() -> Optional[Path]: if stat.S_ISDIR(st.st_mode): db_path = db_path / func_name - logger.debug(f'[{func_name}]: using {db_path} for db cache') + logger.debug(f'[{func_name}]: using {used_backend}:{db_path} for cache') return db_path - def get_old_hash(db: DbHelper) -> Optional[SourceHash]: - # first, try to do as much as possible read-only, benefiting from deferred transaction - old_hashes: Sequence - try: - # not sure if there is a better way... - cursor = conn.execute(db.table_hash.select()) - except sqlalchemy.exc.OperationalError as e: - # meh. not sure if this is a good way to handle this.. - if 'no such table: hash' in str(e): - old_hashes = [] - else: - raise e - else: - old_hashes = cursor.fetchall() - - assert len(old_hashes) <= 1, old_hashes # shouldn't happen - - old_hash: Optional[SourceHash] - if len(old_hashes) == 0: - old_hash = None - else: - old_hash = old_hashes[0][0] # returns a tuple... 
- return old_hash - - def cached_items(): - total = list(conn.execute(sqlalchemy.select(sqlalchemy.func.count()).select_from(table_cache)))[0][0] - logger.info(f'{func_name}: loading {total} objects from cachew (sqlite {db_path})') - - rows = conn.execute(table_cache.select()) - # by default, sqlalchemy wraps all results into Row object - # this can cause quite a lot of overhead if you're reading many rows - # it seems that in principle, sqlalchemy supports just returning bare underlying tuple from the dbapi - # but from browsing the code it doesn't seem like this functionality exposed - # if you're looking for cues, see - # - ._source_supports_scalars - # - ._generate_rows - # - ._row_getter - # by using this raw iterator we speed up reading the cache quite a bit - # asked here https://github.com/sqlalchemy/sqlalchemy/discussions/10350 - raw_row_iterator = getattr(rows, '_raw_row_iterator', None) - if raw_row_iterator is None: - warnings.warn("CursorResult._raw_row_iterator method isn't found. This could lead to degraded cache reading performance.") - row_iterator = rows - else: - row_iterator = raw_row_iterator() - - for (blob,) in row_iterator: - j = orjson_loads(blob) - obj = marshall.load(j) - yield obj - def try_use_synthetic_key() -> None: if synthetic_key is None: return @@ -710,43 +611,6 @@ def missing_keys(cached: List[str], wanted: List[str]) -> Optional[List[str]]: kwargs[_CACHEW_CACHED] = cached_items() kwargs[synthetic_key] = missing - def get_exclusive_write_transaction() -> bool: - # returns whether it actually managed to get it - - # NOTE on recursive calls - # somewhat magically, they should work as expected with no extra database inserts? - # the top level call 'wins' the write transaction and once it's gathered all data, will write it - # the 'intermediate' level calls fail to get it and will pass data through - # the cached 'bottom' level is read only and will be yielded without a write transaction - try: - # first 'write' statement will upgrade transaction to write transaction which might fail due to concurrency - # see https://www.sqlite.org/lang_transaction.html - # NOTE: because of 'checkfirst=True', only the last .create will guarantee the transaction upgrade to write transaction - db.table_hash.create(conn, checkfirst=True) - - # 'table' used to be old 'cache' table name, so we just delete it regardless - # otherwise it might overinfalte the cache db with stale values - conn.execute(text('DROP TABLE IF EXISTS `table`')) - - # NOTE: we have to use .drop and then .create (e.g. instead of some sort of replace) - # since it's possible to have schema changes inbetween calls - # checkfirst=True because it might be the first time we're using cache - table_cache_tmp.drop(conn, checkfirst=True) - table_cache_tmp.create(conn) - except sqlalchemy.exc.OperationalError as e: - if e.code == 'e3q8' and 'database is locked' in str(e): - # someone else must be have won the write lock - # not much we can do here - # NOTE: important to close early, otherwise we might hold onto too many file descriptors during yielding - # see test_recursive_deep - # (normally connection is closed in DbHelper.__exit__) - conn.close() - # in this case all the callee can do is just to call the actual function - return False - else: - raise e - return True - early_exit = False def written_to_cache(): @@ -754,18 +618,23 @@ def written_to_cache(): datas = func(*args, **kwargs) - # uhh. 
this gives a huge speedup for inserting
-        # since we don't have to create intermediate dictionaries
-        insert_into_table_cache_tmp_raw = str(table_cache_tmp.insert().compile(dialect=sqlite.dialect(paramstyle='qmark')))
-        # I also tried setting paramstyle='qmark' in create_engine, but it seems to be ignored :(
-        # idk what benefit sqlalchemy gives at this point, seems to just complicate things
+        if isinstance(backend, FileBackend):
+            # FIXME: this is a bit awkward -- the file backend writes the new hash upfront,
+            # whereas in sqlite mode we don't want to publish the new hash before the new items are written
+            # maybe the hash should go through a tmp table as well?
+            backend.write_new_hash(new_hash)
+        else:
+            # for sqlite the hash is written later, in finalize()
+            pass
+
+        flush_blobs = backend.flush_blobs

         chunk: List[Any] = []

         def flush() -> None:
             nonlocal chunk
             if len(chunk) > 0:
-                conn.exec_driver_sql(insert_into_table_cache_tmp_raw, [(c,) for c in chunk])
+                flush_blobs(chunk=chunk)
                 chunk = []

         total_objects = 0
@@ -784,23 +653,20 @@ def flush() -> None:
                 flush()
         flush()

-        # delete hash first, so if we are interrupted somewhere, it mismatches next time and everything is recomputed
-        # pylint: disable=no-value-for-parameter
-        conn.execute(db.table_hash.delete())
-
-        # checkfirst is necessary since it might not have existed in the first place
-        # e.g. first time we use cache
-        table_cache.drop(conn, checkfirst=True)
+        backend.finalize(new_hash)
+        logger.info(f'{func_name}: wrote {total_objects} objects to cachew ({used_backend}:{db_path})')

-        # meh https://docs.sqlalchemy.org/en/14/faq/metadata_schema.html#does-sqlalchemy-support-alter-table-create-view-create-trigger-schema-upgrade-functionality
-        # also seems like sqlalchemy doesn't have any primitives to escape table names.. sigh
-        conn.execute(text(f"ALTER TABLE `{table_cache_tmp.name}` RENAME TO `{table_cache.name}`"))
+    def cached_items():
+        total_cached = backend.cached_blobs_total()
+        total_cached_s = '' if total_cached is None else f'{total_cached} '
+        logger.info(f'{func_name}: loading {total_cached_s}objects from cachew ({used_backend}:{db_path})')

-        # pylint: disable=no-value-for-parameter
-        conn.execute(db.table_hash.insert().values([{'value': new_hash}]))
-        logger.info(f'{func_name}: wrote {total_objects} objects to cachew (sqlite {db_path})')
+        for blob in backend.cached_blobs():
+            j = orjson_loads(blob)
+            obj = marshall.load(j)
+            yield obj

-    # WARNING: annoyingly huge try/catch ahead...
+    # NOTE: annoyingly huge try/catch ahead...
# but it lets us save a function call, hence a stack frame # see test_recursive* try: @@ -809,20 +675,17 @@ def flush() -> None: yield from func(*args, **kwargs) return + BackendCls = BACKENDS[used_backend] + new_hash_d = C.composite_hash(*args, **kwargs) - new_hash = json.dumps(new_hash_d) + new_hash: SourceHash = json.dumps(new_hash_d) logger.debug('new hash: %s', new_hash) marshall = CachewMarshall(Type_=cls) - with DbHelper(db_path, cls) as db, db.connection.begin(): - # NOTE: deferred transaction - conn = db.connection - table_cache = db.table_cache - table_cache_tmp = db.table_cache_tmp - - old_hash = get_old_hash(db=db) - logger.debug('old hash: %s', old_hash) + with BackendCls(cache_path=db_path, logger=logger) as backend: + old_hash = backend.get_old_hash() + logger.debug(f'old hash: {old_hash}') if new_hash == old_hash: logger.debug('hash matched: loading from cache') @@ -833,7 +696,7 @@ def flush() -> None: try_use_synthetic_key() - got_write = get_exclusive_write_transaction() + got_write = backend.get_exclusive_write() if not got_write: # NOTE: this is the bit we really have to watch out for and not put in a helper function # otherwise it's causing an extra stack frame on every call diff --git a/src/cachew/backend/common.py b/src/cachew/backend/common.py new file mode 100644 index 0000000..e1a28d4 --- /dev/null +++ b/src/cachew/backend/common.py @@ -0,0 +1,47 @@ +from abc import abstractmethod +import logging +from pathlib import Path +from typing import ( + Iterator, + Optional, + Sequence, +) + +from ..common import SourceHash + + +class AbstractBackend: + @abstractmethod + def __init__(self, cache_path: Path, *, logger: logging.Logger) -> None: + raise NotImplementedError + + @abstractmethod + def __enter__(self): + raise NotImplementedError + + def __exit__(self, *args) -> None: + raise NotImplementedError + + def get_old_hash(self) -> Optional[SourceHash]: + raise NotImplementedError + + def cached_blobs_total(self) -> Optional[int]: + raise NotImplementedError + + def cached_blobs(self) -> Iterator[bytes]: + raise NotImplementedError + + def get_exclusive_write(self) -> bool: + ''' + Returns whether it actually managed to get it + ''' + raise NotImplementedError + + def write_new_hash(self, new_hash: SourceHash) -> None: + raise NotImplementedError + + def flush_blobs(self, chunk: Sequence[bytes]) -> None: + raise NotImplementedError + + def finalize(self, new_hash: SourceHash) -> None: + raise NotImplementedError diff --git a/src/cachew/backend/file.py b/src/cachew/backend/file.py new file mode 100644 index 0000000..e3701f6 --- /dev/null +++ b/src/cachew/backend/file.py @@ -0,0 +1,84 @@ +import logging +import os +from pathlib import Path +from typing import ( + BinaryIO, + Iterator, + Optional, + Sequence, +) + +from ..common import SourceHash +from .common import AbstractBackend + + +class FileBackend(AbstractBackend): + jsonl: Path + jsonl_tmp: Path + jsonl_fr: Optional[BinaryIO] + jsonl_tmp_fw: Optional[BinaryIO] + + def __init__(self, cache_path: Path, *, logger: logging.Logger) -> None: + self.logger = logger + self.jsonl = cache_path + self.jsonl_tmp = Path(str(self.jsonl) + '.tmp') + + self.jsonl_fr = None + self.jsonl_tmp_fw = None + + def __enter__(self) -> 'FileBackend': + try: + self.jsonl_fr = self.jsonl.open('rb') + except FileNotFoundError: + self.jsonl_fr = None + return self + + def __exit__(self, *args) -> None: + if self.jsonl_tmp_fw is not None: + self.jsonl_tmp_fw.close() + + # might still exist in case of early exit + 
self.jsonl_tmp.unlink(missing_ok=True) + + if self.jsonl_fr is not None: + self.jsonl_fr.close() + + def get_old_hash(self) -> Optional[SourceHash]: + if self.jsonl_fr is None: + return None + hash_line = self.jsonl_fr.readline().rstrip(b'\n') + return hash_line.decode('utf8') + + def cached_blobs_total(self) -> Optional[int]: + # not really sure how to support that for a plaintext file? + # could wc -l but it might be costly.. + return None + + def cached_blobs(self) -> Iterator[bytes]: + assert self.jsonl_fr is not None # should be guaranteed by get_old_hash + yield from self.jsonl_fr # yields line by line + + def get_exclusive_write(self) -> bool: + # NOTE: opening in x (exclusive write) mode just in case, so it throws if file exists + try: + self.jsonl_tmp_fw = self.jsonl_tmp.open('xb') + except FileExistsError: + self.jsonl_tmp_fw = None + return False + else: + return True + + def write_new_hash(self, new_hash: SourceHash) -> None: + assert self.jsonl_tmp_fw is not None + self.jsonl_tmp_fw.write(new_hash.encode('utf8') + b'\n') + + def flush_blobs(self, chunk: Sequence[bytes]) -> None: + fw = self.jsonl_tmp_fw + assert fw is not None + for blob in chunk: + fw.write(blob) + fw.write(b'\n') + + def finalize(self, new_hash: SourceHash) -> None: + # TODO defensive?? + os.rename(self.jsonl_tmp, self.jsonl) diff --git a/src/cachew/backend/sqlite.py b/src/cachew/backend/sqlite.py new file mode 100644 index 0000000..ef51f72 --- /dev/null +++ b/src/cachew/backend/sqlite.py @@ -0,0 +1,190 @@ +import logging +from pathlib import Path +import sqlite3 +import time +from typing import ( + Iterator, + Optional, + Sequence, +) +import warnings + +import sqlalchemy +from sqlalchemy import Column, Table, event, text +from sqlalchemy.dialects import sqlite + +from ..common import SourceHash +from .common import AbstractBackend + + +class SqliteBackend(AbstractBackend): + def __init__(self, cache_path: Path, *, logger: logging.Logger) -> None: + self.logger = logger + self.engine = sqlalchemy.create_engine(f'sqlite:///{cache_path}', connect_args={'timeout': 0}) + # NOTE: timeout is necessary so we don't lose time waiting during recursive calls + # by default, it's several seconds? you'd see 'test_recursive' test performance degrade + + @event.listens_for(self.engine, 'connect') + def set_sqlite_pragma(dbapi_connection, connection_record): + # without wal, concurrent reading/writing is not gonna work + + # ugh. that's odd, how are we supposed to set WAL if the very fact of setting wal might lock the db? + while True: + try: + dbapi_connection.execute('PRAGMA journal_mode=WAL') + break + except sqlite3.OperationalError as oe: + if 'database is locked' not in str(oe): + # ugh, pretty annoying that exception doesn't include database path for some reason + raise RuntimeError(f'Error while setting WAL on {cache_path}') from oe + time.sleep(0.1) + + self.connection = self.engine.connect() + + """ + Erm... this is pretty confusing. + https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#transaction-isolation-level + + Somehow without this thing sqlalchemy logs BEGIN (implicit) instead of BEGIN TRANSACTION which actually works in sqlite... + + Judging by sqlalchemy/dialects/sqlite/base.py, looks like some sort of python sqlite driver problem?? + + test_transaction should check this behaviour + """ + + @event.listens_for(self.connection, 'begin') + # pylint: disable=unused-variable + def do_begin(conn): + # NOTE there is also BEGIN CONCURRENT in newer versions of sqlite. could use it later? 
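+            # DEFERRED means no lock is taken until the transaction's first actual statement,
+            # so cache reads stay read-only and only the first write statement upgrades it to a write transaction (see get_exclusive_write)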
+ conn.execute(text('BEGIN DEFERRED')) + + self.meta = sqlalchemy.MetaData() + self.table_hash = Table('hash', self.meta, Column('value', sqlalchemy.String)) + + # fmt: off + # actual cache + self.table_cache = Table('cache' , self.meta, Column('data', sqlalchemy.BLOB)) + # temporary table, we use it to insert and then (atomically?) rename to the above table at the very end + self.table_cache_tmp = Table('cache_tmp', self.meta, Column('data', sqlalchemy.BLOB)) + # fmt: on + + def __enter__(self) -> 'SqliteBackend': + # NOTE: deferred transaction + self.transaction = self.connection.begin() + # FIXME this is a bit crap.. is there a nicer way to use another ctx manager here? + self.transaction.__enter__() + return self + + def __exit__(self, *args) -> None: + self.transaction.__exit__(*args) + self.connection.close() + self.engine.dispose() + + def get_old_hash(self) -> Optional[SourceHash]: + # first, try to do as much as possible read-only, benefiting from deferred transaction + old_hashes: Sequence + try: + # not sure if there is a better way... + cursor = self.connection.execute(self.table_hash.select()) + except sqlalchemy.exc.OperationalError as e: + # meh. not sure if this is a good way to handle this.. + if 'no such table: hash' in str(e): + old_hashes = [] + else: + raise e + else: + old_hashes = cursor.fetchall() + + assert len(old_hashes) <= 1, old_hashes # shouldn't happen + + old_hash: Optional[SourceHash] + if len(old_hashes) == 0: + old_hash = None + else: + old_hash = old_hashes[0][0] # returns a tuple... + return old_hash + + def cached_blobs_total(self) -> Optional[int]: + return list(self.connection.execute(sqlalchemy.select(sqlalchemy.func.count()).select_from(self.table_cache)))[0][0] + + def cached_blobs(self) -> Iterator[bytes]: + rows = self.connection.execute(self.table_cache.select()) + # by default, sqlalchemy wraps all results into Row object + # this can cause quite a lot of overhead if you're reading many rows + # it seems that in principle, sqlalchemy supports just returning bare underlying tuple from the dbapi + # but from browsing the code it doesn't seem like this functionality exposed + # if you're looking for cues, see + # - ._source_supports_scalars + # - ._generate_rows + # - ._row_getter + # by using this raw iterator we speed up reading the cache quite a bit + # asked here https://github.com/sqlalchemy/sqlalchemy/discussions/10350 + raw_row_iterator = getattr(rows, '_raw_row_iterator', None) + if raw_row_iterator is None: + warnings.warn("CursorResult._raw_row_iterator method isn't found. This could lead to degraded cache reading performance.") + row_iterator = rows + else: + row_iterator = raw_row_iterator() + + for (blob,) in row_iterator: + yield blob + + def get_exclusive_write(self) -> bool: + # NOTE on recursive calls + # somewhat magically, they should work as expected with no extra database inserts? 
+        # the top level call 'wins' the write transaction and once it's gathered all data, will write it
+        # the 'intermediate' level calls fail to get it and will pass data through
+        # the cached 'bottom' level is read only and will be yielded without a write transaction
+        try:
+            # the first 'write' statement will upgrade the transaction to a write transaction, which might fail due to concurrency
+            # see https://www.sqlite.org/lang_transaction.html
+            # NOTE: because of 'checkfirst=True', only the last .create will guarantee the transaction upgrade to write transaction
+            self.table_hash.create(self.connection, checkfirst=True)
+
+            # 'table' used to be the old 'cache' table name, so we just delete it regardless
+            # otherwise it might overinflate the cache db with stale values
+            self.connection.execute(text('DROP TABLE IF EXISTS `table`'))
+
+            # NOTE: we have to use .drop and then .create (e.g. instead of some sort of replace)
+            # since it's possible to have schema changes in between calls
+            # checkfirst=True because it might be the first time we're using the cache
+            self.table_cache_tmp.drop(self.connection, checkfirst=True)
+            self.table_cache_tmp.create(self.connection)
+        except sqlalchemy.exc.OperationalError as e:
+            if e.code == 'e3q8' and 'database is locked' in str(e):
+                # someone else must have won the write lock
+                # not much we can do here
+                # NOTE: important to close early, otherwise we might hold onto too many file descriptors during yielding
+                # see test_recursive_deep
+                # (normally the connection is closed in SqliteBackend.__exit__)
+                self.connection.close()
+                # in this case all the caller can do is just call the actual function
+                return False
+            else:
+                raise e
+        return True
+
+    def flush_blobs(self, chunk: Sequence[bytes]) -> None:
+        # this gives a huge speedup for inserting
+        # since we don't have to create intermediate dictionaries
+        # TODO move this to __init__?
+        insert_into_table_cache_tmp_raw = str(self.table_cache_tmp.insert().compile(dialect=sqlite.dialect(paramstyle='qmark')))
+        # I also tried setting paramstyle='qmark' in create_engine, but it seems to be ignored :(
+        # idk what benefit sqlalchemy gives at this point, seems to just complicate things
+        self.connection.exec_driver_sql(insert_into_table_cache_tmp_raw, [(c,) for c in chunk])
+
+    def finalize(self, new_hash: SourceHash) -> None:
+        # delete the hash first, so if we are interrupted somewhere, it mismatches next time and everything is recomputed
+        # pylint: disable=no-value-for-parameter
+        self.connection.execute(self.table_hash.delete())
+
+        # checkfirst is necessary since it might not have existed in the first place
+        # e.g. the first time we use the cache
+        self.table_cache.drop(self.connection, checkfirst=True)
+
+        # meh https://docs.sqlalchemy.org/en/14/faq/metadata_schema.html#does-sqlalchemy-support-alter-table-create-view-create-trigger-schema-upgrade-functionality
+        # also seems like sqlalchemy doesn't have any primitives to escape table names.. sigh
+        self.connection.execute(text(f"ALTER TABLE `{self.table_cache_tmp.name}` RENAME TO `{self.table_cache.name}`"))
+
+        # pylint: disable=no-value-for-parameter
+        self.connection.execute(self.table_hash.insert().values([{'value': new_hash}]))
diff --git a/src/cachew/common.py b/src/cachew/common.py
new file mode 100644
index 0000000..6086c32
--- /dev/null
+++ b/src/cachew/common.py
@@ -0,0 +1,2 @@
+# TODO better name to represent what it means?
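+# (a string that fingerprints everything a cached call depends on; cachew compares it
+# against the previously stored hash to decide whether the cached data is still valid)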
+SourceHash = str diff --git a/src/cachew/tests/test_cachew.py b/src/cachew/tests/test_cachew.py index ebd00f2..6645467 100644 --- a/src/cachew/tests/test_cachew.py +++ b/src/cachew/tests/test_cachew.py @@ -2,6 +2,7 @@ from concurrent.futures import ProcessPoolExecutor from dataclasses import dataclass, asdict from datetime import datetime, date, timezone +import hashlib import inspect from itertools import islice, chain from pathlib import Path @@ -21,7 +22,7 @@ import pytest -from .. import cachew, get_logger, NTBinder, CachewException, settings +from .. import cachew, get_logger, NTBinder, CachewException, settings, Backend from .utils import running_on_ci, gc_control @@ -29,6 +30,12 @@ logger = get_logger() +@pytest.fixture(autouse=True) +def set_default_cachew_dir(tmp_path: Path): + tpath = tmp_path / 'cachew_default' + settings.DEFAULT_CACHEW_DIR = tpath + + @pytest.fixture(autouse=True) def throw_on_errors(): # NOTE: in tests we always throw on errors, it's a more reasonable default for testing. @@ -37,6 +44,13 @@ def throw_on_errors(): yield +@pytest.fixture(autouse=True, params=['sqlite', 'file']) +def set_backend(restore_settings, request): + backend = request.param + settings.DEFAULT_BACKEND = backend + yield + + @pytest.fixture def restore_settings(): orig = {k: v for k, v in settings.__dict__.items() if not k.startswith('__')} @@ -100,10 +114,18 @@ def test_custom_hash(tmp_path: Path) -> None: ] calls = 0 + def get_path_version(path: Path): + ns = path.stat().st_mtime_ns + # hmm, this might be unreliable, sometimes mtime doesn't change even after modifications? + # I suppose it takes some time for them to sync or something... + # so let's compute md5 or something in addition.. + md5 = hashlib.md5(path.read_bytes()).digest() + return str((ns, md5)) + # fmt: off @cachew( cache_path=tmp_path, - depends_on=lambda path: path.stat().st_mtime # when path is update, underlying cache would be discarded + depends_on=get_path_version, # when path is updated, underlying cache would be discarded ) # fmt: on def data(path: Path) -> Iterable[UUU]: @@ -809,8 +831,10 @@ def fun() -> Iterator[Union[int, DD]]: assert list(fun()) == [123, DD(456)] -def _concurrent_helper(cache_path: Path, count: int, sleep_s=0.1): - @cachew(cache_path) +# ugh. we need to pass backend here explicitly since it might not get picked up from the fixture +# that sets it in settings. due to multiprocess stuff +def _concurrent_helper(cache_path: Path, count: int, backend: Backend, sleep_s=0.1): + @cachew(cache_path, backend=backend) def test(count: int) -> Iterator[int]: for i in range(count): print(f"{count}: GENERATING {i}") @@ -828,9 +852,9 @@ def fuzz_cachew_impl(): from .. 
import cachew_wrapper patch = '''\ -@@ -740,6 +740,11 @@ - - logger.debug('old hash: %s', old_hash) +@@ -189,6 +189,11 @@ + old_hash = backend.get_old_hash() + logger.debug(f'old hash: {old_hash}') + from random import random + rs = random() * 2 @@ -839,7 +863,7 @@ def fuzz_cachew_impl(): + if new_hash == old_hash: logger.debug('hash matched: loading from cache') - rows = conn.execute(values_table.select()) + yield from cached_items() ''' patchy.patch(cachew_wrapper, patch) yield @@ -855,11 +879,11 @@ def test_concurrent_writes(tmp_path: Path, fuzz_cachew_impl) -> None: # warm up to create the database # FIXME ok, that will be fixed separately with atomic move I suppose - _concurrent_helper(cache_path, 1) + _concurrent_helper(cache_path, 1, settings.DEFAULT_BACKEND) processes = 5 with ProcessPoolExecutor() as pool: - futures = [pool.submit(_concurrent_helper, cache_path, count) for count in range(processes)] + futures = [pool.submit(_concurrent_helper, cache_path, count, settings.DEFAULT_BACKEND) for count in range(processes)] for count, f in enumerate(futures): assert f.result() == [i * i for i in range(count)] @@ -873,13 +897,13 @@ def test_concurrent_reads(tmp_path: Path, fuzz_cachew_impl): count = 10 # warm up - _concurrent_helper(cache_path, count, sleep_s=0) + _concurrent_helper(cache_path, count, settings.DEFAULT_BACKEND, sleep_s=0) processes = 4 start = time.time() with ProcessPoolExecutor() as pool: - futures = [pool.submit(_concurrent_helper, cache_path, count, 1) for _ in range(processes)] + futures = [pool.submit(_concurrent_helper, cache_path, count, settings.DEFAULT_BACKEND, 1) for _ in range(processes)] for f in futures: print(f.result()) @@ -1131,6 +1155,9 @@ def fun() -> Iterator[int]: def test_old_cache_v0_6_3(tmp_path: Path) -> None: + if settings.DEFAULT_BACKEND != 'sqlite': + pytest.skip('this test only makes sense for sqlite backend') + sql = ''' PRAGMA foreign_keys=OFF; BEGIN TRANSACTION; @@ -1181,7 +1208,7 @@ def fun() -> Iterator[int]: assert calls == 3 -def test_early_exit(tmp_path: Path) -> None: +def test_early_exit_simple(tmp_path: Path) -> None: # cachew works on iterators and we'd prefer not to cache if the iterator hasn't been exhausted calls_f = 0
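
For context, a minimal sketch of how the new backend selection is meant to be used. The cache path and the `squares` function below are made up for illustration; the `backend` argument and `settings.DEFAULT_BACKEND` are the knobs added in this diff, and an explicit `backend=` overrides the global default.

    from typing import Iterator

    from cachew import cachew, settings

    # switch the default backend globally...
    settings.DEFAULT_BACKEND = 'file'

    # ...or per function, via the new 'backend' argument (takes precedence over the default)
    @cachew('/tmp/squares.cache', backend='sqlite')
    def squares(n: int) -> Iterator[int]:
        for i in range(n):
            yield i * i

    assert list(squares(100)) == [i * i for i in range(100)]  # first call computes and writes the cache
    assert list(squares(100)) == [i * i for i in range(100)]  # second call is read back from the cache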