diff --git a/src/cachew/__init__.py b/src/cachew/__init__.py index ff027bc..308f98c 100644 --- a/src/cachew/__init__.py +++ b/src/cachew/__init__.py @@ -1,17 +1,34 @@ -import importlib.metadata +from dataclasses import dataclass import functools -import logging +import importlib.metadata import inspect import json -import stat +import logging from pathlib import Path +import sqlite3 +import stat import sys import time -import sqlite3 -from typing import (Any, Callable, List, Optional, - Type, Union, TypeVar, Generic, Sequence, Iterable, Dict, cast, - TYPE_CHECKING, overload) -import dataclasses +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + List, + Literal, + Optional, + Tuple, + Type, + TypeVar, + Union, + Sequence, + cast, + get_args, + get_origin, + overload, + TYPE_CHECKING, +) import warnings try: @@ -38,11 +55,10 @@ def orjson_dumps(*args, **kwargs): # type: ignore[misc] logging.exception(e) from .logging_helper import makeLogger -from .marshall.cachew import CachewMarshall +from .marshall.cachew import CachewMarshall, build_schema from .utils import ( - is_primitive, - is_union, CachewException, + TypeNotSupported, ) @@ -73,38 +89,6 @@ def get_logger() -> logging.Logger: return makeLogger(__name__) -# https://stackoverflow.com/a/2166841/706389 -def is_dataclassish(t: Type) -> bool: - """ - >>> is_dataclassish(int) - False - >>> is_dataclassish(tuple) - False - >>> from typing import NamedTuple - >>> class N(NamedTuple): - ... field: int - >>> is_dataclassish(N) - True - >>> from dataclasses import dataclass - >>> @dataclass - ... class D: - ... field: str - >>> is_dataclassish(D) - True - """ - if dataclasses.is_dataclass(t): - return True - b = t.__bases__ - if len(b) != 1 or b[0] != tuple: - return False - f = getattr(t, '_fields', None) - if not isinstance(f, tuple): - return False - # pylint: disable=unidiomatic-typecheck - return all(type(n) == str for n in f) - - - # TODO better name to represent what it means? 
SourceHash = str @@ -199,32 +183,87 @@ def mtime_hash(path: Path, *args, **kwargs) -> SourceHash: Failure = str +Kind = Literal['single', 'multiple'] +Inferred = Tuple[Kind, Type[Any]] -def infer_type(func) -> Union[Failure, Type[Any]]: +def infer_return_type(func) -> Union[Failure, Inferred]: """ + >>> def const() -> int: + ... return 123 + >>> infer_return_type(const) + ('single', <class 'int'>) + + >>> from typing import Optional + >>> def first_character(s: str) -> Optional[str]: + ... return None if len(s) == 0 else s[0] + >>> kind, opt = infer_return_type(first_character) + >>> # in 3.8, Optional[str] is printed as Union[str, None], so need to hack around this + >>> (kind, opt is Optional[str]) + ('single', True) + + # tuple is an iterable.. but presumably should be treated as a single value + >>> from typing import Tuple + >>> def a_tuple() -> Tuple[int, str]: + ... return (123, 'hi') + >>> infer_return_type(a_tuple) + ('single', typing.Tuple[int, str]) + >>> from typing import Collection, NamedTuple >>> class Person(NamedTuple): ... name: str ... age: int >>> def person_provider() -> Collection[Person]: ... return [] - >>> infer_type(person_provider) - <class 'cachew.Person'> + >>> infer_return_type(person_provider) + ('multiple', <class 'cachew.Person'>) + + >>> def single_person() -> Person: + ... return Person(name="what", age=-1) + >>> infer_return_type(single_person) + ('single', <class 'cachew.Person'>) >>> from typing import Sequence >>> def int_provider() -> Sequence[int]: ... return (1, 2, 3) - >>> infer_type(int_provider) - <class 'int'> + >>> infer_return_type(int_provider) + ('multiple', <class 'int'>) >>> from typing import Iterator, Union >>> def union_provider() -> Iterator[Union[str, int]]: ... yield 1 ... yield 'aaa' - >>> infer_type(union_provider) - typing.Union[str, int] + >>> infer_return_type(union_provider) + ('multiple', typing.Union[str, int]) + + # a bit of an edge case + >>> from typing import Tuple + >>> def empty_tuple() -> Iterator[Tuple[()]]: + ... yield () + >>> infer_return_type(empty_tuple) + ('multiple', typing.Tuple[()]) + + ... 
# doctest: +ELLIPSIS + + >>> def untyped(): + ... return 123 + >>> infer_return_type(untyped) + 'no return type annotation...' + + >>> from typing import List + >>> class Custom: + ... pass + >>> def unsupported() -> Custom: + ... return Custom() + >>> infer_return_type(unsupported) + "can't infer type from <class 'cachew.Custom'>: can't cache <class 'cachew.Custom'>" + + >>> def unsupported_list() -> List[Custom]: + ... return [Custom()] + >>> infer_return_type(unsupported_list) + "can't infer type from typing.List[cachew.Custom]: can't cache <class 'cachew.Custom'>" """ + # TODO why not get_type_hints?? annots = get_annotations(func, eval_str=True) rtype = annots.get('return', None) if rtype is None: @@ -233,25 +272,38 @@ def infer_type(func) -> Union[Failure, Type[Any]]: def bail(reason: str) -> str: return f"can't infer type from {rtype}: " + reason - # need to get erased type, otherwise subclass check would fail - if not hasattr(rtype, '__origin__'): - return bail("expected __origin__") - if not issubclass(rtype.__origin__, Iterable): - return bail("not subclassing Iterable") - - args = getattr(rtype, '__args__', None) - if args is None: - return bail("has no __args__") - if len(args) != 1: - return bail(f"wrong number of __args__: {args}") - arg = args[0] - if is_primitive(arg): - return arg - if is_union(arg): - return arg # meh? - if not is_dataclassish(arg): - return bail(f"{arg} is not NamedTuple/dataclass") - return arg + # first we wanna check if the top level type is some sort of iterable that makes sense to cache + # e.g. List/Sequence/Iterator etc + origin = get_origin(rtype) + return_multiple = False + if origin is not None and origin is not tuple: + # TODO need to check it handles namedtuple correctly.. + try: + return_multiple = issubclass(origin, Iterable) + except TypeError: + # that would happen if origin is not a 'proper' type, e.g. 
is a Union or something + # seems like exception is the easiest way to check + pass + + if return_multiple: + # then the actual type to cache will be the argument of the top level one + args = get_args(rtype) + if args is None: + return bail("has no __args__") + + if len(args) != 1: + return bail(f"wrong number of __args__: {args}") + + (cached_type,) = args + else: + cached_type = rtype + + try: + build_schema(Type=cached_type) + except TypeNotSupported as ex: + return bail(f"can't cache {ex.type_}") + + return ('multiple' if return_multiple else 'single', cached_type) # https://stackoverflow.com/questions/653368/how-to-create-a-python-decorator-that-can-be-used-either-with-or-without-paramet @@ -267,8 +319,6 @@ def new_dec(*args, **kwargs): return new_dec - - def cachew_error(e: Exception) -> None: if settings.THROW_ON_ERROR: raise e @@ -307,7 +357,7 @@ def cachew_impl( :param cache_path: if not set, `cachew.settings.DEFAULT_CACHEW_DIR` will be used. :param force_file: if set to True, assume `cache_path` is a regular file (instead of a directory) - :param cls: if not set, cachew will attempt to infer it from return type annotation. See :func:`infer_type` and :func:`cachew.tests.test_cachew.test_return_type_inference`. + :param cls: if not set, cachew will attempt to infer it from return type annotation. See :func:`infer_return_type` and :func:`cachew.tests.test_cachew.test_return_type_inference`. :param depends_on: hash function to determine whether the underlying . Can potentially benefit from the use of side effects (e.g. file modification time). TODO link to test? :param logger: custom logger, if not specified will use logger named `cachew`. See :func:`get_logger`. :return: iterator over original or cached items @@ -358,7 +408,7 @@ def cachew_impl( warnings.warn("'hashf' is deprecated. 
Please use 'depends_on' instead") depends_on = hashf - cn = cname(func) + cn = callable_name(func) # todo not very nice that ENABLE check is scattered across two places if not settings.ENABLE or cache_path is None: logger.debug('[%s]: cache explicitly disabled (settings.ENABLE is False or cache_path is None)', cn) @@ -368,10 +418,10 @@ def cachew_impl( cache_path = settings.DEFAULT_CACHEW_DIR logger.debug('[%s]: no cache_path specified, using the default %s', cn, cache_path) - # TODO fuzz infer_type, should never crash? - inferred = infer_type(func) - if isinstance(inferred, Failure): - msg = f"failed to infer cache type: {inferred}. See https://github.com/karlicoss/cachew#features for the list of supported types." + # TODO fuzz infer_return_type, should never crash? + inference_res = infer_return_type(func) + if isinstance(inference_res, Failure): + msg = f"failed to infer cache type: {inference_res}. See https://github.com/karlicoss/cachew#features for the list of supported types." if cls is None: ex = CachewException(msg) cachew_error(ex) @@ -380,6 +430,8 @@ def cachew_impl( # it's ok, assuming user knows better logger.debug(msg) else: + (kind, inferred) = inference_res + assert kind == 'multiple' # TODO implement later if cls is None: logger.debug('[%s] using inferred type %s', cn, inferred) cls = inferred @@ -433,7 +485,7 @@ def cachew( cachew = cachew_impl -def cname(func: Callable) -> str: +def callable_name(func: Callable) -> str: # some functions don't have __module__ mod = getattr(func, '__module__', None) or '' return f'{mod}:{func.__qualname__}' @@ -445,7 +497,7 @@ def cname(func: Callable) -> str: _DEPENDENCIES = 'dependencies' -@dataclasses.dataclass +@dataclass class Context(Generic[P]): func : Callable cache_path : PathProvider[P] @@ -503,7 +555,7 @@ def cachew_wrapper( chunk_by = C.chunk_by synthetic_key = C.synthetic_key - cn = cname(func) + cn = callable_name(func) if not settings.ENABLE: logger.debug('[%s]: cache explicitly disabled 
(settings.ENABLE is False)', cn) yield from func(*args, **kwargs) @@ -541,7 +593,7 @@ def cachew_wrapper( dbp.mkdir(parents=True, exist_ok=True) dbp = dbp / cn else: - # already exists, so just use cname if it's a dir + # already exists, so just use callable name if it's a dir if stat.S_ISDIR(st.st_mode): dbp = dbp / cn diff --git a/src/cachew/marshall/cachew.py b/src/cachew/marshall/cachew.py index ac97663..bf87625 100644 --- a/src/cachew/marshall/cachew.py +++ b/src/cachew/marshall/cachew.py @@ -28,7 +28,7 @@ Json, T, ) -from ..utils import CachewException +from ..utils import TypeNotSupported, is_namedtuple class CachewMarshall(AbstractMarshall[T]): @@ -227,7 +227,7 @@ def load(self, dct: Json): @dataclass(**SLOTS) -class XDatetime(Schema): +class SDatetime(Schema): def dump(self, obj: datetime) -> Json: iso = obj.isoformat() tz = obj.tzinfo @@ -253,7 +253,7 @@ def load(self, dct: tuple): @dataclass(**SLOTS) -class XDate(Schema): +class SDate(Schema): def dump(self, obj: date) -> Json: return obj.isoformat() @@ -287,19 +287,6 @@ def load(self, dct: str): } -# TODO reuse in legacy? -# https://stackoverflow.com/a/2166841/706389 -def is_namedtuple(t) -> bool: - b = t.__bases__ - if len(b) != 1 or b[0] != tuple: - return False - f = getattr(t, '_fields', None) - if not isinstance(f, tuple): - return False - # pylint: disable=unidiomatic-typecheck - return all(type(n) == str for n in f) # noqa: E721 - - def build_schema(Type) -> Schema: prim = primitives_from.get(Type) if prim is not None: @@ -313,15 +300,13 @@ def build_schema(Type) -> Schema: return SException(type=Type) if issubclass(Type, datetime): - return XDatetime(type=Type) + return SDatetime(type=Type) if issubclass(Type, date): - return XDate(type=Type) + return SDate(type=Type) if not (is_dataclass(Type) or is_namedtuple(Type)): - raise CachewException( - f"{Type} doesn't look like a supported type to cache. See https://github.com/karlicoss/cachew#features for the list of supported types." 
- ) + raise TypeNotSupported(type_=Type) hints = get_type_hints(Type) fields = tuple((k, build_schema(t)) for k, t in hints.items()) return SDataclass( @@ -362,6 +347,10 @@ def build_schema(Type) -> Schema: is_tuplish = origin is tuple or origin is abc.Sequence if is_tuplish: if origin is tuple: + # this is for Tuple[()], which is the way to represent empty tuple + # before python 3.11, get_args for that gives ((),) instead of an empty tuple () as one might expect + if args == ((),): + args = () return STuple( type=Type, args=tuple(build_schema(a) for a in args), @@ -419,6 +408,8 @@ def normalise(x): # TODO customise with cattrs def test_serialize_and_deserialize() -> None: + import pytest + helper = _test_identity # primitives @@ -525,5 +516,15 @@ class WithJson: assert helper(dwinter.date(), date)[0] == '2020-02-03' + # unsupported types + class NotSupported: + pass + + with pytest.raises(RuntimeError, match=".*NotSupported.* isn't supported by cachew"): + helper([NotSupported()], List[NotSupported]) + + # edge cases + helper((), Tuple[()]) + # TODO test type aliases and such?? diff --git a/src/cachew/tests/test_cachew.py b/src/cachew/tests/test_cachew.py index 1134a86..01882f3 100644 --- a/src/cachew/tests/test_cachew.py +++ b/src/cachew/tests/test_cachew.py @@ -275,16 +275,13 @@ def test_unsupported_class(tmp_path: Path) -> None: def fun() -> List[UBad]: return [UBad()] - # now something a bit nastier - # TODO hmm, should really throw at the definition time... 
but can fix later I suppose - @cachew(cache_path=tmp_path) - def fun2() -> Iterable[Union[UGood, UBad]]: - yield UGood(x=1) - yield UBad() - yield UGood(x=2) + with pytest.raises(CachewException, match=".*can't infer type from.*"): - with pytest.raises(CachewException, match=".*doesn't look like a supported type.*"): - list(fun2()) + @cachew(cache_path=tmp_path) + def fun2() -> Iterable[Union[UGood, UBad]]: + yield UGood(x=1) + yield UBad() + yield UGood(x=2) class TE2(NamedTuple): diff --git a/src/cachew/utils.py b/src/cachew/utils.py index 05b956c..9209840 100644 --- a/src/cachew/utils.py +++ b/src/cachew/utils.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from datetime import date, datetime from typing import ( NamedTuple, @@ -26,6 +27,14 @@ class CachewException(RuntimeError): pass +@dataclass +class TypeNotSupported(CachewException): + type_: Type + + def __str__(self) -> str: + return f"{self.type_} isn't supported by cachew. See https://github.com/karlicoss/cachew#features for the list of supported types." + + Types = Union[ Type[str], Type[int], @@ -78,3 +87,17 @@ def is_primitive(cls: Type) -> bool: True """ return cls in PRIMITIVE_TYPES + + +# https://stackoverflow.com/a/2166841/706389 +def is_namedtuple(t) -> bool: + b = getattr(t, '__bases__', None) + if b is None: + return False + if len(b) != 1 or b[0] != tuple: + return False + f = getattr(t, '_fields', None) + if not isinstance(f, tuple): + return False + # pylint: disable=unidiomatic-typecheck + return all(type(n) == str for n in f) # noqa: E721