Skip to content

Commit

Permalink
core: add support for caching single return values, not just iterables
Browse files Browse the repository at this point in the history
should fix #29
  • Loading branch information
karlicoss committed Sep 19, 2023
1 parent 58d93b9 commit 251efb5
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 24 deletions.
86 changes: 63 additions & 23 deletions src/cachew/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ def infer_return_type(func) -> Union[Failure, Inferred]:
>>> infer_return_type(person_provider)
('multiple', <class 'cachew.Person'>)
>>> def single_str() -> str:
... return 'hello'
>>> infer_return_type(single_str)
('single', <class 'str'>)
>>> def single_person() -> Person:
... return Person(name="what", age=-1)
>>> infer_return_type(single_person)
Expand Down Expand Up @@ -217,16 +222,7 @@ def bail(reason: str) -> str:

# first we wanna check if the top level type is some sort of iterable that makes sense ot cache
# e.g. List/Sequence/Iterator etc
origin = get_origin(rtype)
return_multiple = False
if origin is not None and origin is not tuple:
# TODO need to check it handles namedtuple correctly..
try:
return_multiple = issubclass(origin, Iterable)
except TypeError:
# that would happen if origin is not a 'proper' type, e.g. is a Union or something
# seems like exception is the easiest way to check
pass
return_multiple = _returns_multiple(rtype)

if return_multiple:
# then the actual type to cache will be the argument of the top level one
Expand All @@ -249,6 +245,21 @@ def bail(reason: str) -> str:
return ('multiple' if return_multiple else 'single', cached_type)


def _returns_multiple(rtype) -> bool:
origin = get_origin(rtype)
if origin is None:
return False
if origin is tuple:
# usually tuples are more like single values rather than a sequence? (+ this works for namedtuple)
return False
try:
return issubclass(origin, Iterable)
except TypeError:
# that would happen if origin is not a 'proper' type, e.g. is a Union or something
# seems like exception is the easiest way to check
return False


# https://stackoverflow.com/questions/653368/how-to-create-a-python-decorator-that-can-be-used-either-with-or-without-paramet
def doublewrap(f):
@functools.wraps(f)
Expand Down Expand Up @@ -281,7 +292,7 @@ def cachew_impl(
func=None,
cache_path: Optional[PathProvider[P]] = use_default_path,
force_file: bool = False,
cls: Optional[Type] = None,
cls: Optional[Union[Type, Tuple[Kind, Type]]] = None,
depends_on: HashFunction[P] = default_hash,
logger: Optional[logging.Logger] = None,
chunk_by: int = 100,
Expand Down Expand Up @@ -368,34 +379,57 @@ def process(self, msg, kwargs):
cache_path = settings.DEFAULT_CACHEW_DIR
logger.debug(f'no cache_path specified, using the default {cache_path}')

use_kind: Optional[Kind] = None
use_cls: Optional[Type] = None
if cls is not None:
# defensive here since typing. objects passed as cls might fail on isinstance
try:
is_tuple = isinstance(cls, tuple)
except:
is_tuple = False
if is_tuple:
use_kind, use_cls = cls # type: ignore[misc]
else:
use_kind = 'multiple'
use_cls = cls # type: ignore[assignment]

# TODO fuzz infer_return_type, should never crash?
inference_res = infer_return_type(func)
if isinstance(inference_res, Failure):
msg = f"failed to infer cache type: {inference_res}. See https://github.com/karlicoss/cachew#features for the list of supported types."
if cls is None:
if use_cls is None:
ex = CachewException(msg)
cachew_error(ex, logger=logger)
return func
else:
# it's ok, assuming user knows better
logger.debug(msg)
assert use_kind is not None
else:
(kind, inferred) = inference_res
assert kind == 'multiple' # TODO implement later
if cls is None:
logger.debug(f'using inferred type {inferred}')
cls = inferred
(inferred_kind, inferred_cls) = inference_res
if use_cls is None:
logger.debug(f'using inferred type {inferred_kind} {inferred_cls}')
(use_kind, use_cls) = (inferred_kind, inferred_cls)
else:
if cls != inferred:
logger.warning(f"inferred type {inferred} mismatches specified type {cls}")
assert use_kind is not None
if (use_kind, use_cls) != inference_res:
logger.warning(f"inferred type {inference_res} mismatches explicitly specified type {(use_kind, use_cls)}")
# TODO not sure if should be more serious error...

if use_kind == 'single':
# pretend it's an iterable, this is just simpler for cachew_wrapper
def _func(*args, **kwargs):
return [func(*args, **kwargs)]

else:
_func = func

ctx = Context(
# fmt: off
func =func,
func =_func,
cache_path =cache_path,
force_file =force_file,
cls_ =cls,
cls_ =use_cls,
depends_on =depends_on,
logger =logger,
chunk_by =chunk_by,
Expand All @@ -408,7 +442,13 @@ def process(self, msg, kwargs):
@functools.wraps(func)
def binder(*args, **kwargs):
kwargs['_cachew_context'] = ctx
return cachew_wrapper(*args, **kwargs)
res = cachew_wrapper(*args, **kwargs)

if use_kind == 'single':
lres = list(res)
assert len(lres) == 1, lres # shouldn't happen
return lres[0]
return res

return binder

Expand All @@ -428,7 +468,7 @@ def cachew(
cache_path: Optional[PathProvider[P]] = ...,
*,
force_file: bool = ...,
cls: Optional[Type] = ...,
cls: Optional[Union[Type, Tuple[Kind, Type]]] = ...,
depends_on: HashFunction[P] = ...,
logger: Optional[logging.Logger] = ...,
chunk_by: int = ...,
Expand Down
30 changes: 29 additions & 1 deletion src/cachew/tests/test_cachew.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ def H(t: AllTypes):
# TODO should be possible to iterate anonymous tuples too? or just sequences of primitive types?


def test_primitive(tmp_path: Path):
def test_primitive(tmp_path: Path) -> None:
@cachew(tmp_path)
def fun() -> Iterator[str]:
yield 'aba'
Expand All @@ -738,6 +738,34 @@ def fun() -> Iterator[str]:
assert list(fun()) == ['aba', 'caba']


def test_single_value(tmp_path: Path) -> None:
@cachew(tmp_path)
def fun_int() -> int:
return 123

assert fun_int() == 123
assert fun_int() == 123

@cachew(tmp_path, cls=('single', str))
def fun_str():
return 'whatever'

assert fun_str() == 'whatever'
assert fun_str() == 'whatever'

@cachew(tmp_path)
def fun_opt_namedtuple(none: bool) -> Optional[UUU]:
if none:
return None
else:
return UUU(xx=1, yy=2)

assert fun_opt_namedtuple(none=False) == UUU(xx=1, yy=2)
assert fun_opt_namedtuple(none=False) == UUU(xx=1, yy=2)
assert fun_opt_namedtuple(none=True) is None
assert fun_opt_namedtuple(none=True) is None


class O(NamedTuple):
x: int

Expand Down

0 comments on commit 251efb5

Please sign in to comment.