Skip to content

Commit

Permalink
feat(common): add pattern matchers
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Apr 21, 2023
1 parent fc89cc3 commit b515d5c
Show file tree
Hide file tree
Showing 10 changed files with 2,369 additions and 82 deletions.
11 changes: 5 additions & 6 deletions ibis/common/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any

from ibis.common.collections import DotDict
from ibis.common.typing import evaluate_typehint
from ibis.common.typing import get_type_hints
from ibis.common.validators import Validator, any_, frozendict_of, option, tuple_of

EMPTY = inspect.Parameter.empty # marker for missing argument
Expand Down Expand Up @@ -247,6 +247,7 @@ def from_callable(cls, fn, validators=None, return_validator=None):
Signature
"""
sig = super().from_callable(fn)
typehints = get_type_hints(fn)

if validators is None:
validators = {}
Expand All @@ -263,12 +264,11 @@ def from_callable(cls, fn, validators=None, return_validator=None):
name = param.name
kind = param.kind
default = param.default
typehint = param.annotation
typehint = typehints.get(name)

if name in validators:
validator = validators[name]
elif typehint is not EMPTY:
typehint = evaluate_typehint(typehint, fn.__module__)
elif typehint is not None:
validator = Validator.from_typehint(typehint)
else:
validator = None
Expand All @@ -288,8 +288,7 @@ def from_callable(cls, fn, validators=None, return_validator=None):

if return_validator is not None:
return_annotation = return_validator
elif sig.return_annotation is not EMPTY:
typehint = evaluate_typehint(sig.return_annotation, fn.__module__)
elif (typehint := typehints.get("return")) is not None:
return_annotation = Validator.from_typehint(typehint)
else:
return_annotation = EMPTY
Expand Down
29 changes: 24 additions & 5 deletions ibis/common/caching.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,34 @@
from __future__ import annotations

import functools
import weakref
from collections import Counter, defaultdict
from collections.abc import Iterator
from typing import Any, Callable, MutableMapping

import toolz
from bidict import bidict

from ibis.common.collections import FrozenDict
from ibis.common.exceptions import IbisError


def memoize(func: Callable) -> Callable:
"""Memoize a function."""
cache = {}

@functools.wraps(func)
def wrapper(*args, **kwargs):
key = (args, FrozenDict(kwargs))
try:
return cache[key]
except KeyError:
result = func(*args, **kwargs)
cache[key] = result
return result

return wrapper


class WeakCache(MutableMapping):
__slots__ = ('_data',)

Expand All @@ -19,13 +38,13 @@ def __init__(self):
def __setattr__(self, name, value):
raise TypeError(f"can't set {name}")

def __len__(self):
def __len__(self) -> int:
return len(self._data)

def __iter__(self):
def __iter__(self) -> Iterator[Any]:
return iter(self._data)

def __setitem__(self, key, value):
def __setitem__(self, key, value) -> None:
# construct an alternative representation of the key using the id()
# of the key's components, this prevents infinite recursions
identifiers = tuple(id(item) for item in key)
Expand Down Expand Up @@ -78,7 +97,7 @@ def __init__(
self.lookup = lookup
self.finalize = finalize
self.names = defaultdict(generate_name)
self.key = key or toolz.identity
self.key = key or (lambda x: x)

def get(self, key, default=None):
try:
Expand Down
46 changes: 45 additions & 1 deletion ibis/common/collections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from collections.abc import Iterator
from itertools import tee
from types import MappingProxyType
from typing import Any, Hashable, Mapping, TypeVar

Expand Down Expand Up @@ -148,7 +150,7 @@ def __str__(self):
return str(self.__view__)

def __repr__(self):
return f"{self.__class__.__name__}({self.__view__!r})"
return f"{self.__class__.__name__}({dict(self.__view__)!r})"

def __setattr__(self, name: str, _: Any) -> None:
raise TypeError(f"Attribute {name!r} cannot be assigned to frozendict")
Expand Down Expand Up @@ -208,4 +210,46 @@ def __repr__(self):
return f"{self.__class__.__name__}({super().__repr__()})"


class RewindableIterator(Iterator):
"""Iterator that can be rewound to a checkpoint.
Examples
--------
>>> it = RewindableIterator(range(5))
>>> next(it)
0
>>> next(it)
1
>>> it.checkpoint()
>>> next(it)
2
>>> next(it)
3
>>> it.rewind()
>>> next(it)
2
>>> next(it)
3
>>> next(it)
4
"""

def __init__(self, iterable):
self._iterator = iter(iterable)
self._checkpoint = None

def __next__(self):
return next(self._iterator)

def rewind(self):
"""Rewind the iterator to the last checkpoint."""
if self._checkpoint is None:
raise ValueError("No checkpoint to rewind to.")
self._iterator, self._checkpoint = tee(self._checkpoint)

def checkpoint(self):
"""Create a checkpoint of the current iterator state."""
self._iterator, self._checkpoint = tee(self._iterator)


public(frozendict=FrozenDict, dotdict=DotDict)
10 changes: 5 additions & 5 deletions ibis/common/grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ibis.common.annotations import EMPTY, Argument, Attribute, Signature, attribute
from ibis.common.caching import WeakCache
from ibis.common.collections import FrozenDict
from ibis.common.typing import evaluate_typehint
from ibis.common.typing import evaluate_annotations
from ibis.common.validators import Validator


Expand Down Expand Up @@ -45,10 +45,10 @@ def __new__(metacls, clsname, bases, dct, **kwargs):
signatures.append(parent.__signature__)

# collection type annotations and convert them to validators
module = dct.get('__module__')
annots = dct.get('__annotations__', {})
for name, typehint in annots.items():
typehint = evaluate_typehint(typehint, module)
module_name = dct.get('__module__')
annotations = dct.get('__annotations__', {})
typehints = evaluate_annotations(annotations, module_name)
for name, typehint in typehints.items():
validator = Validator.from_typehint(typehint)
if name in dct:
dct[name] = Argument.default(dct[name], validator, typehint=typehint)
Expand Down
Loading

0 comments on commit b515d5c

Please sign in to comment.