Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: rewrite and simplify return type inference #46

Merged
merged 1 commit into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 133 additions & 81 deletions src/cachew/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
import importlib.metadata
from dataclasses import dataclass
import functools
import logging
import importlib.metadata
import inspect
import json
import stat
import logging
from pathlib import Path
import sqlite3
import stat
import sys
import time
import sqlite3
from typing import (Any, Callable, List, Optional,
Type, Union, TypeVar, Generic, Sequence, Iterable, Dict, cast,
TYPE_CHECKING, overload)
import dataclasses
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
List,
Literal,
Optional,
Tuple,
Type,
TypeVar,
Union,
Sequence,
cast,
get_args,
get_origin,
overload,
TYPE_CHECKING,
)
import warnings

try:
Expand All @@ -38,11 +55,10 @@ def orjson_dumps(*args, **kwargs): # type: ignore[misc]
logging.exception(e)

from .logging_helper import makeLogger
from .marshall.cachew import CachewMarshall
from .marshall.cachew import CachewMarshall, build_schema
from .utils import (
is_primitive,
is_union,
CachewException,
TypeNotSupported,
)


Expand Down Expand Up @@ -73,38 +89,6 @@ def get_logger() -> logging.Logger:
return makeLogger(__name__)


# https://stackoverflow.com/a/2166841/706389
def is_dataclassish(t: Type) -> bool:
"""
>>> is_dataclassish(int)
False
>>> is_dataclassish(tuple)
False
>>> from typing import NamedTuple
>>> class N(NamedTuple):
... field: int
>>> is_dataclassish(N)
True
>>> from dataclasses import dataclass
>>> @dataclass
... class D:
... field: str
>>> is_dataclassish(D)
True
"""
if dataclasses.is_dataclass(t):
return True
b = t.__bases__
if len(b) != 1 or b[0] != tuple:
return False
f = getattr(t, '_fields', None)
if not isinstance(f, tuple):
return False
# pylint: disable=unidiomatic-typecheck
return all(type(n) == str for n in f)



# TODO better name to represent what it means?
SourceHash = str

Expand Down Expand Up @@ -199,32 +183,87 @@ def mtime_hash(path: Path, *args, **kwargs) -> SourceHash:


Failure = str
Kind = Literal['single', 'multiple']
Inferred = Tuple[Kind, Type[Any]]


def infer_type(func) -> Union[Failure, Type[Any]]:
def infer_return_type(func) -> Union[Failure, Inferred]:
"""
>>> def const() -> int:
... return 123
>>> infer_return_type(const)
('single', <class 'int'>)

>>> from typing import Optional
>>> def first_character(s: str) -> Optional[str]:
... return None if len(s) == 0 else s[0]
>>> kind, opt = infer_return_type(first_character)
>>> # in 3.8, Optional[str] is printed as Union[str, None], so need to hack around this
>>> (kind, opt is Optional[str])
('single', True)

# tuple is an iterable.. but presumably should be treated as a single value
>>> from typing import Tuple
>>> def a_tuple() -> Tuple[int, str]:
... return (123, 'hi')
>>> infer_return_type(a_tuple)
('single', typing.Tuple[int, str])

>>> from typing import Collection, NamedTuple
>>> class Person(NamedTuple):
... name: str
... age: int
>>> def person_provider() -> Collection[Person]:
... return []
>>> infer_type(person_provider)
<class 'cachew.Person'>
>>> infer_return_type(person_provider)
('multiple', <class 'cachew.Person'>)

>>> def single_person() -> Person:
... return Person(name="what", age=-1)
>>> infer_return_type(single_person)
('single', <class 'cachew.Person'>)

>>> from typing import Sequence
>>> def int_provider() -> Sequence[int]:
... return (1, 2, 3)
>>> infer_type(int_provider)
<class 'int'>
>>> infer_return_type(int_provider)
('multiple', <class 'int'>)

>>> from typing import Iterator, Union
>>> def union_provider() -> Iterator[Union[str, int]]:
... yield 1
... yield 'aaa'
>>> infer_type(union_provider)
typing.Union[str, int]
>>> infer_return_type(union_provider)
('multiple', typing.Union[str, int])

# a bit of an edge case
>>> from typing import Tuple
>>> def empty_tuple() -> Iterator[Tuple[()]]:
... yield ()
>>> infer_return_type(empty_tuple)
('multiple', typing.Tuple[()])

... # doctest: +ELLIPSIS

>>> def untyped():
... return 123
>>> infer_return_type(untyped)
'no return type annotation...'

>>> from typing import List
>>> class Custom:
... pass
>>> def unsupported() -> Custom:
... return Custom()
>>> infer_return_type(unsupported)
"can't infer type from <class 'cachew.Custom'>: can't cache <class 'cachew.Custom'>"

>>> def unsupported_list() -> List[Custom]:
... return [Custom()]
>>> infer_return_type(unsupported_list)
"can't infer type from typing.List[cachew.Custom]: can't cache <class 'cachew.Custom'>"
"""
# TODO why not get_type_hints??
annots = get_annotations(func, eval_str=True)
rtype = annots.get('return', None)
if rtype is None:
Expand All @@ -233,25 +272,38 @@ def infer_type(func) -> Union[Failure, Type[Any]]:
def bail(reason: str) -> str:
return f"can't infer type from {rtype}: " + reason

# need to get erased type, otherwise subclass check would fail
if not hasattr(rtype, '__origin__'):
return bail("expected __origin__")
if not issubclass(rtype.__origin__, Iterable):
return bail("not subclassing Iterable")

args = getattr(rtype, '__args__', None)
if args is None:
return bail("has no __args__")
if len(args) != 1:
return bail(f"wrong number of __args__: {args}")
arg = args[0]
if is_primitive(arg):
return arg
if is_union(arg):
return arg # meh?
if not is_dataclassish(arg):
return bail(f"{arg} is not NamedTuple/dataclass")
return arg
# first we wanna check if the top level type is some sort of iterable that makes sense ot cache
# e.g. List/Sequence/Iterator etc
origin = get_origin(rtype)
return_multiple = False
if origin is not None and origin is not tuple:
# TODO need to check it handles namedtuple correctly..
try:
return_multiple = issubclass(origin, Iterable)
except TypeError:
# that would happen if origin is not a 'proper' type, e.g. is a Union or something
# seems like exception is the easiest way to check
pass

if return_multiple:
# then the actual type to cache will be the argument of the top level one
args = get_args(rtype)
if args is None:
return bail("has no __args__")

if len(args) != 1:
return bail(f"wrong number of __args__: {args}")

(cached_type,) = args
else:
cached_type = rtype

try:
build_schema(Type=cached_type)
except TypeNotSupported as ex:
return bail(f"can't cache {ex.type_}")

return ('multiple' if return_multiple else 'single', cached_type)


# https://stackoverflow.com/questions/653368/how-to-create-a-python-decorator-that-can-be-used-either-with-or-without-paramet
Expand All @@ -267,8 +319,6 @@ def new_dec(*args, **kwargs):
return new_dec




def cachew_error(e: Exception) -> None:
if settings.THROW_ON_ERROR:
raise e
Expand Down Expand Up @@ -307,7 +357,7 @@ def cachew_impl(

:param cache_path: if not set, `cachew.settings.DEFAULT_CACHEW_DIR` will be used.
:param force_file: if set to True, assume `cache_path` is a regular file (instead of a directory)
:param cls: if not set, cachew will attempt to infer it from return type annotation. See :func:`infer_type` and :func:`cachew.tests.test_cachew.test_return_type_inference`.
:param cls: if not set, cachew will attempt to infer it from return type annotation. See :func:`infer_return_type` and :func:`cachew.tests.test_cachew.test_return_type_inference`.
:param depends_on: hash function to determine whether the underlying . Can potentially benefit from the use of side effects (e.g. file modification time). TODO link to test?
:param logger: custom logger, if not specified will use logger named `cachew`. See :func:`get_logger`.
:return: iterator over original or cached items
Expand Down Expand Up @@ -358,7 +408,7 @@ def cachew_impl(
warnings.warn("'hashf' is deprecated. Please use 'depends_on' instead")
depends_on = hashf

cn = cname(func)
cn = callable_name(func)
# todo not very nice that ENABLE check is scattered across two places
if not settings.ENABLE or cache_path is None:
logger.debug('[%s]: cache explicitly disabled (settings.ENABLE is False or cache_path is None)', cn)
Expand All @@ -368,10 +418,10 @@ def cachew_impl(
cache_path = settings.DEFAULT_CACHEW_DIR
logger.debug('[%s]: no cache_path specified, using the default %s', cn, cache_path)

# TODO fuzz infer_type, should never crash?
inferred = infer_type(func)
if isinstance(inferred, Failure):
msg = f"failed to infer cache type: {inferred}. See https://github.com/karlicoss/cachew#features for the list of supported types."
# TODO fuzz infer_return_type, should never crash?
inference_res = infer_return_type(func)
if isinstance(inference_res, Failure):
msg = f"failed to infer cache type: {inference_res}. See https://github.com/karlicoss/cachew#features for the list of supported types."
if cls is None:
ex = CachewException(msg)
cachew_error(ex)
Expand All @@ -380,6 +430,8 @@ def cachew_impl(
# it's ok, assuming user knows better
logger.debug(msg)
else:
(kind, inferred) = inference_res
assert kind == 'multiple' # TODO implement later
if cls is None:
logger.debug('[%s] using inferred type %s', cn, inferred)
cls = inferred
Expand Down Expand Up @@ -433,7 +485,7 @@ def cachew(
cachew = cachew_impl


def cname(func: Callable) -> str:
def callable_name(func: Callable) -> str:
# some functions don't have __module__
mod = getattr(func, '__module__', None) or ''
return f'{mod}:{func.__qualname__}'
Expand All @@ -445,7 +497,7 @@ def cname(func: Callable) -> str:
_DEPENDENCIES = 'dependencies'


@dataclasses.dataclass
@dataclass
class Context(Generic[P]):
func : Callable
cache_path : PathProvider[P]
Expand Down Expand Up @@ -503,7 +555,7 @@ def cachew_wrapper(
chunk_by = C.chunk_by
synthetic_key = C.synthetic_key

cn = cname(func)
cn = callable_name(func)
if not settings.ENABLE:
logger.debug('[%s]: cache explicitly disabled (settings.ENABLE is False)', cn)
yield from func(*args, **kwargs)
Expand Down Expand Up @@ -541,7 +593,7 @@ def cachew_wrapper(
dbp.mkdir(parents=True, exist_ok=True)
dbp = dbp / cn
else:
# already exists, so just use cname if it's a dir
# already exists, so just use callable name if it's a dir
if stat.S_ISDIR(st.st_mode):
dbp = dbp / cn

Expand Down
Loading