Skip to content

Commit

Permalink
feat(common): support generic mapping and sequence type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Jan 11, 2023
1 parent 1c25213 commit ddc6603
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 18 deletions.
5 changes: 5 additions & 0 deletions ibis/common/tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
any_of,
bool_,
dict_of,
frozendict_of,
instance_of,
int_,
isin,
Expand All @@ -17,6 +18,7 @@
str_,
tuple_of,
)
from ibis.util import frozendict


@pytest.mark.parametrize(
Expand All @@ -36,6 +38,7 @@
(any_of((str_, int_(max=8))), "foo", "foo"),
(any_of((str_, int_(max=8))), 7, 7),
(all_of((int_, min_(3), min_(8))), 10, 10),
(dict_of(str_, int_), {"a": 1, "b": 2}, {"a": 1, "b": 2}),
],
)
def test_validators_passing(validator, value, expected):
Expand All @@ -59,6 +62,7 @@ def test_validators_passing(validator, value, expected):
(any_of((str_, int_(max=8))), 3.14),
(any_of((str_, int_(max=8))), 9),
(all_of((int_, min_(3), min_(8))), 7),
(dict_of(int_, str_), {"a": 1, "b": 2}),
],
)
def test_validators_failing(validator, value):
Expand Down Expand Up @@ -90,6 +94,7 @@ def endswith_d(x, this):
(List[int], list_of(instance_of(int))),
(Tuple[int], tuple_of(instance_of(int))),
(Dict[str, float], dict_of(instance_of(str), instance_of(float))),
(frozendict[str, int], frozendict_of(instance_of(str), instance_of(int))),
],
)
def test_validator_from_annotation(annot, expected):
Expand Down
34 changes: 17 additions & 17 deletions ibis/common/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import math
from contextlib import suppress
from typing import Any, Callable, Iterable, Union
from typing import Any, Callable, Iterable, Mapping, Sequence, Union

import toolz
from typing_extensions import Annotated, get_args, get_origin

from ibis.common.dispatch import lazy_singledispatch
from ibis.common.exceptions import IbisTypeError
from ibis.common.typing import evaluate_typehint
from ibis.util import flatten_iterable, is_function, is_iterable
from ibis.util import flatten_iterable, frozendict, is_function, is_iterable


class Validator(Callable):
Expand All @@ -24,23 +24,22 @@ def from_annotation(cls, annot, module=None):
annot = evaluate_typehint(annot, module)
origin_type = get_origin(annot)

if origin_type is Union:
if annot is Any:
return any_
elif origin_type is None:
return instance_of(annot)
elif origin_type is Union:
inners = map(cls.from_annotation, get_args(annot))
return any_of(tuple(inners))
elif origin_type is list:
(inner,) = map(cls.from_annotation, get_args(annot))
return list_of(inner)
elif origin_type is tuple:
(inner,) = map(cls.from_annotation, get_args(annot))
return tuple_of(inner)
elif origin_type is dict:
key_type, value_type = map(cls.from_annotation, get_args(annot))
return dict_of(key_type, value_type)
elif origin_type is Annotated:
annot, *extras = get_args(annot)
return all_of((instance_of(annot), *extras))
elif annot is Any:
return any_
elif issubclass(origin_type, Sequence):
(inner,) = map(cls.from_annotation, get_args(annot))
return sequence_of(inner, type=origin_type)
elif issubclass(origin_type, Mapping):
key_type, value_type = map(cls.from_annotation, get_args(annot))
return mapping_of(key_type, value_type, type=origin_type)
else:
return instance_of(annot)

Expand Down Expand Up @@ -180,7 +179,7 @@ def map_to(mapping, variant, **kwargs):


@validator
def container_of(inner, arg, *, type, min_length=0, flatten=False, **kwargs):
def sequence_of(inner, arg, *, type, min_length=0, flatten=False, **kwargs):
if not is_iterable(arg):
raise IbisTypeError('Argument must be a sequence')

Expand Down Expand Up @@ -222,5 +221,6 @@ def min_(min, arg, **kwargs):
bool_ = instance_of(bool)
none_ = instance_of(type(None))
dict_of = mapping_of(type=dict)
list_of = container_of(type=list)
tuple_of = container_of(type=tuple)
list_of = sequence_of(type=list)
tuple_of = sequence_of(type=tuple)
frozendict_of = mapping_of(type=frozendict)
3 changes: 2 additions & 1 deletion ibis/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

T = TypeVar("T", covariant=True)
U = TypeVar("U", covariant=True)
K = TypeVar("K")
V = TypeVar("V")

# https://www.compart.com/en/unicode/U+22EE
Expand All @@ -47,7 +48,7 @@
HORIZONTAL_ELLIPSIS = "\u2026"


class frozendict(Mapping, Hashable):
class frozendict(Mapping[K, V], Hashable):
__slots__ = ("__view__", "__precomputed_hash__")

def __init__(self, *args, **kwargs):
Expand Down

0 comments on commit ddc6603

Please sign in to comment.