Skip to content

Commit

Permalink
typing: adding check and fixing (#250)
Browse files Browse the repository at this point in the history
* typing: adding check and fixing

* timedelta

* refactor

* typo

* more

* note

* names

* typo

* lint

* fixed loading

* save

* save
  • Loading branch information
Borda authored Oct 26, 2024
1 parent 84fe27a commit 82d2341
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 75 deletions.
12 changes: 7 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,6 @@ repos:
# https://prettier.io/docs/en/options.html#print-width
args: ["--print-width=79"]

# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.8.0
# hooks:
# - id: mypy

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
hooks:
Expand All @@ -70,6 +65,13 @@ repos:
name: Ruff check
args: ["--fix"]

# it needs to be after formatting hooks because the lines might be changed
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
- id: mypy
files: "src/*"

- repo: https://github.com/tox-dev/pyproject-fmt
rev: 2.2.4
hooks:
Expand Down
7 changes: 4 additions & 3 deletions src/cachier/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import datetime
import hashlib
import os
import pickle
import threading
from collections.abc import Mapping
from dataclasses import dataclass, replace
from datetime import datetime, timedelta
from typing import Any, Optional, Union

from ._types import Backend, HashFunc, Mongetter
Expand All @@ -27,7 +27,7 @@ class Params:
hash_func: HashFunc = _default_hash_func
backend: Backend = "pickle"
mongetter: Optional[Mongetter] = None
stale_after: datetime.timedelta = datetime.timedelta.max
stale_after: timedelta = timedelta.max
next_time: bool = False
cache_dir: Union[str, os.PathLike] = "~/.cachier/"
pickle_reload: bool = True
Expand Down Expand Up @@ -100,7 +100,8 @@ def set_global_params(**params: Mapping) -> None:
if hasattr(cachier.config._global_params, k)
}
cachier.config._global_params = replace(
cachier.config._global_params, **valid_params
cachier.config._global_params,
**valid_params, # type: ignore[arg-type]
)


Expand Down
6 changes: 3 additions & 3 deletions src/cachier/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
# http://www.opensource.org/licenses/MIT-license
# Copyright (c) 2016, Shay Palachy <shaypal5@gmail.com>

import datetime
import inspect
import os
import warnings
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from functools import wraps
from typing import Any, Optional, Union
from warnings import warn
Expand Down Expand Up @@ -107,7 +107,7 @@ def cachier(
hash_params: Optional[HashFunc] = None,
backend: Optional[Backend] = None,
mongetter: Optional[Mongetter] = None,
stale_after: Optional[datetime.timedelta] = None,
stale_after: Optional[timedelta] = None,
next_time: Optional[bool] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
pickle_reload: Optional[bool] = None,
Expand Down Expand Up @@ -259,7 +259,7 @@ def func_wrapper(*args, **kwds):
_print("Entry found.")
if _allow_none or entry.value is not None:
_print("Cached result found.")
now = datetime.datetime.now()
now = datetime.now()
if now - entry.time <= _stale_after:
_print("And it is fresh!")
return entry.value
Expand Down
6 changes: 5 additions & 1 deletion src/cachier/cores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ def _get_func_str(func: Callable) -> str:
class _BaseCore:
__metaclass__ = abc.ABCMeta

def __init__(self, hash_func: HashFunc, wait_for_calc_timeout: int):
def __init__(
self,
hash_func: Optional[HashFunc],
wait_for_calc_timeout: Optional[int],
):
self.hash_func = _update_with_defaults(hash_func, "hash_func")
self.wait_for_calc_timeout = wait_for_calc_timeout
self.lock = threading.RLock()
Expand Down
14 changes: 11 additions & 3 deletions src/cachier/cores/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import threading
from datetime import datetime
from typing import Any, Optional, Tuple
from typing import Any, Dict, Optional, Tuple

from .._types import HashFunc
from ..config import CacheEntry
Expand All @@ -12,9 +12,13 @@
class _MemoryCore(_BaseCore):
"""The memory core class for cachier."""

def __init__(self, hash_func: HashFunc, wait_for_calc_timeout: int):
def __init__(
self,
hash_func: Optional[HashFunc],
wait_for_calc_timeout: Optional[int],
):
super().__init__(hash_func, wait_for_calc_timeout)
self.cache = {}
self.cache: Dict[str, CacheEntry] = {}

def _hash_func_key(self, key: str) -> str:
return f"{_get_func_str(self.func)}:{key}"
Expand Down Expand Up @@ -79,8 +83,12 @@ def wait_on_entry_calc(self, key: str) -> Any:
hash_key = self._hash_func_key(key)
with self.lock: # pragma: no cover
entry = self.cache[hash_key]
if entry is None:
return None
if not entry._processing:
return entry.value
if entry._condition is None:
raise RuntimeError("No condition set for entry")
entry._condition.acquire()
entry._condition.wait()
entry._condition.release()
Expand Down
10 changes: 6 additions & 4 deletions src/cachier/cores/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class _MongoCore(_BaseCore):

def __init__(
self,
hash_func: HashFunc,
mongetter: Mongetter,
wait_for_calc_timeout: int,
hash_func: Optional[HashFunc],
mongetter: Optional[Mongetter],
wait_for_calc_timeout: Optional[int],
):
if "pymongo" not in sys.modules:
warnings.warn(
Expand All @@ -48,7 +48,9 @@ def __init__(
stacklevel=2,
) # pragma: no cover

super().__init__(hash_func, wait_for_calc_timeout)
super().__init__(
hash_func=hash_func, wait_for_calc_timeout=wait_for_calc_timeout
)
if mongetter is None:
raise MissingMongetter(
"must specify ``mongetter`` when using the mongo core"
Expand Down
102 changes: 57 additions & 45 deletions src/cachier/cores/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
import pickle # for local caching
from datetime import datetime
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import portalocker # to lock on pickle cache IO
from watchdog.events import PatternMatchingEventHandler
Expand Down Expand Up @@ -68,21 +68,22 @@ def on_modified(self, event) -> None:

def __init__(
self,
hash_func: HashFunc,
pickle_reload: bool,
cache_dir: str,
separate_files: bool,
wait_for_calc_timeout: int,
hash_func: Optional[HashFunc],
pickle_reload: Optional[bool],
cache_dir: Optional[Union[str, os.PathLike]],
separate_files: Optional[bool],
wait_for_calc_timeout: Optional[int],
):
super().__init__(hash_func, wait_for_calc_timeout)
self.cache = None
self._cache_dict: Dict[str, CacheEntry] = {}
self.reload = _update_with_defaults(pickle_reload, "pickle_reload")
self.cache_dir = os.path.expanduser(
_update_with_defaults(cache_dir, "cache_dir")
)
self.separate_files = _update_with_defaults(
separate_files, "separate_files"
)
self._cache_used_fpath = ""

@property
def cache_fname(self) -> str:
Expand Down Expand Up @@ -110,27 +111,30 @@ def _convert_legacy_cache_entry(
_condition=entry.get("condition", None),
)

def _load_cache(self) -> Mapping[str, CacheEntry]:
def _load_cache_dict(self) -> Dict[str, CacheEntry]:
try:
with portalocker.Lock(self.cache_fpath, mode="rb") as cf:
cache = pickle.load(cf) # noqa: S301
self._cache_used_fpath = str(self.cache_fpath)
except (FileNotFoundError, EOFError):
cache = {}
return {
k: _PickleCore._convert_legacy_cache_entry(v)
for k, v in cache.items()
}

def _reload_cache(self) -> None:
def get_cache_dict(self, reload: bool = False) -> Dict[str, CacheEntry]:
if self._cache_used_fpath != self.cache_fpath:
# force reload if the cache file has changed
# this change is dies to using different wrapped function
reload = True
if self._cache_dict and not (self.reload or reload):
return self._cache_dict
with self.lock:
self.cache = self._load_cache()
self._cache_dict = self._load_cache_dict()
return self._cache_dict

def _get_cache(self) -> Dict[str, CacheEntry]:
if not self.cache:
self._reload_cache()
return self.cache

def _get_cache_by_key(
def _load_cache_by_key(
self, key=None, hash_str=None
) -> Optional[CacheEntry]:
fpath = self.cache_fpath
Expand All @@ -152,35 +156,42 @@ def _clear_being_calculated_all_cache_files(self) -> None:
path, name = os.path.split(self.cache_fpath)
for subpath in os.listdir(path):
if subpath.startswith(name):
entry = self._get_cache_by_key(hash_str=subpath.split("_")[-1])
entry = self._load_cache_by_key(
hash_str=subpath.split("_")[-1]
)
if entry is not None:
entry.being_calculated = False
entry._processing = False
self._save_cache(entry, hash_str=subpath.split("_")[-1])

def _save_cache(
self, cache, key: str = None, hash_str: str = None
self,
cache: Union[Dict[str, CacheEntry], CacheEntry],
separate_file_key: Optional[str] = None,
hash_str: Optional[str] = None,
) -> None:
if separate_file_key and not isinstance(cache, CacheEntry):
raise ValueError(
"`separate_file_key` should only be used with a CacheEntry"
)
fpath = self.cache_fpath
if key is not None:
fpath += f"_{key}"
if separate_file_key is not None:
fpath += f"_{separate_file_key}"
elif hash_str is not None:
fpath += f"_{hash_str}"
with self.lock:
self.cache = cache
with portalocker.Lock(fpath, mode="wb") as cache_file:
pickle.dump(cache, cache_file, protocol=4)
if key is None:
self._reload_cache()
with portalocker.Lock(fpath, mode="wb") as cf:
pickle.dump(cache, cf, protocol=4)
# the same as check for separate_file, but changed for typing
if isinstance(cache, dict):
self._cache_dict = cache
self._cache_used_fpath = str(self.cache_fpath)

def get_entry_by_key(
self, key: str, reload: bool = False
) -> Tuple[str, CacheEntry]:
with self.lock:
if self.separate_files:
return key, self._get_cache_by_key(key)
if self.reload or reload:
self._reload_cache()
return key, self._get_cache().get(key, None)
) -> Tuple[str, Optional[CacheEntry]]:
if self.separate_files:
return key, self._load_cache_by_key(key)
return key, self.get_cache_dict(reload).get(key)

def set_entry(self, key: str, func_res: Any) -> None:
key_data = CacheEntry(
Expand All @@ -195,7 +206,7 @@ def set_entry(self, key: str, func_res: Any) -> None:
return # pragma: no cover

with self.lock:
cache = self._get_cache()
cache = self.get_cache_dict()
cache[key] = key_data
self._save_cache(cache)

Expand All @@ -207,21 +218,23 @@ def mark_entry_being_calculated_separate_files(self, key: str) -> None:
stale=False,
_processing=True,
),
key=key,
separate_file_key=key,
)

def mark_entry_not_calculated_separate_files(self, key: str) -> None:
def _mark_entry_not_calculated_separate_files(self, key: str) -> None:
_, entry = self.get_entry_by_key(key)
if entry is None:
return # that's ok, we don't need an entry in that case
entry._processing = False
self._save_cache(entry, key=key)
self._save_cache(entry, separate_file_key=key)

def mark_entry_being_calculated(self, key: str) -> None:
if self.separate_files:
self.mark_entry_being_calculated_separate_files(key)
return # pragma: no cover

with self.lock:
cache = self._get_cache()
cache = self.get_cache_dict()
if key in cache:
cache[key]._processing = True
else:
Expand All @@ -235,24 +248,23 @@ def mark_entry_being_calculated(self, key: str) -> None:

def mark_entry_not_calculated(self, key: str) -> None:
if self.separate_files:
self.mark_entry_not_calculated_separate_files(key)
self._mark_entry_not_calculated_separate_files(key)
with self.lock:
cache = self._get_cache()
cache = self.get_cache_dict()
# that's ok, we don't need an entry in that case
if isinstance(cache, dict) and key in cache:
cache[key]._processing = False
self._save_cache(cache)

def wait_on_entry_calc(self, key: str) -> Any:
if self.separate_files:
entry = self._get_cache_by_key(key)
entry = self._load_cache_by_key(key)
filename = f"{self.cache_fname}_{key}"
else:
with self.lock:
self._reload_cache()
entry = self._get_cache()[key]
entry = self.get_cache_dict()[key]
filename = self.cache_fname
if not entry._processing:
if entry and not entry._processing:
return entry.value
event_handler = _PickleCore.CacheChangeHandler(
filename=filename, core=self, key=key
Expand Down Expand Up @@ -280,7 +292,7 @@ def clear_being_calculated(self) -> None:
return # pragma: no cover

with self.lock:
cache = self._get_cache()
cache = self.get_cache_dict()
for key in cache:
cache[key]._processing = False
self._save_cache(cache)
Loading

0 comments on commit 82d2341

Please sign in to comment.