Skip to content

Commit

Permalink
test(common): add more complicated test case for generic type coercions
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Feb 28, 2023
1 parent a1c46a2 commit 91bef71
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 24 deletions.
6 changes: 6 additions & 0 deletions ibis/common/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ def __eq__(self, other):
and self._validator == other._validator
)

def __repr__(self):
return (
f"{self.__class__.__name__}(validator={self._validator!r}, "
f"default={self._default!r}, typehint={self._typehint!r})"
)

def validate(self, arg, **kwargs):
if self._validator is None:
return arg
Expand Down
8 changes: 8 additions & 0 deletions ibis/common/tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
is_int = instance_of(int)


def test_argument_repr():
argument = Argument(is_int, typehint=int, default=None)
assert repr(argument) == (
"Argument(validator=instance_of(<class 'int'>,), default=None, "
"typehint=<class 'int'>)"
)


def test_default_argument():
annotation = Argument.default(validator=int, default=3)
assert annotation.validate(1) == 1
Expand Down
121 changes: 106 additions & 15 deletions ibis/common/tests/test_grounds.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import pickle
import weakref
from typing import Sequence, TypeVar
from typing import Mapping, Sequence, Tuple, TypeVar

import pytest

Expand All @@ -25,7 +25,7 @@
Immutable,
Singleton,
)
from ibis.common.validators import Coercible, instance_of, option, validator
from ibis.common.validators import Coercible, Validator, instance_of, option, validator
from ibis.tests.util import assert_pickle_roundtrip

is_any = instance_of(object)
Expand Down Expand Up @@ -85,14 +85,16 @@ class VariadicArgsAndKeywords(Concrete):


T = TypeVar('T')
K = TypeVar('K')
V = TypeVar('V')


class List(Concrete, Sequence[T], Coercible):
@classmethod
def __coerce__(self, value):
value = tuple(value)
if value:
head, *rest = value
def __coerce__(self, values):
values = tuple(values)
if values:
head, *rest = values
return ConsList(head, rest)
else:
return EmptyList()
Expand Down Expand Up @@ -120,6 +122,66 @@ def __len__(self):
return len(self.rest) + 1


class Map(Concrete, Mapping[K, V], Coercible):
@classmethod
def __coerce__(self, pairs):
pairs = dict(pairs)
if pairs:
head_key = next(iter(pairs))
head_value = pairs.pop(head_key)
rest = pairs
return ConsMap((head_key, head_value), rest)
else:
return EmptyMap()


class EmptyMap(Map[K, V]):
def __getitem__(self, key):
raise KeyError(key)

def __iter__(self):
return iter(())

def __len__(self):
return 0


class ConsMap(Map[K, V]):
head: Tuple[K, V]
rest: Map[K, V]

def __getitem__(self, key):
if key == self.head[0]:
return self.head[1]
else:
return self.rest[key]

def __iter__(self):
yield self.head[0]
yield from self.rest

def __len__(self):
return len(self.rest) + 1


class Integer(int, Coercible):
@classmethod
def __coerce__(cls, value):
return Integer(value)


class Float(float, Coercible):
@classmethod
def __coerce__(cls, value):
return Float(value)


class MyExpr(Concrete):
a: Integer
b: List[Float]
c: Map[str, Integer]


def test_annotable():
class InBetween(BetweenSimple):
pass
Expand Down Expand Up @@ -199,7 +261,37 @@ class Op(Annotable):


def test_annotable_with_recursive_generic_type_annotations():
pass
# testing cons list
validator = Validator.from_typehint(List[Integer])
values = ["1", 2.0, 3]
result = validator(values)
expected = ConsList(1, ConsList(2, ConsList(3, EmptyList())))
assert result == expected
assert result[0] == 1
assert result[1] == 2
assert result[2] == 3
assert len(result) == 3
with pytest.raises(IndexError):
result[3]

# testing cons map
validator = Validator.from_typehint(Map[Integer, Float])
values = {"1": 2, 3: "4.0", 5: 6.0}
result = validator(values)
expected = ConsMap((1, 2.0), ConsMap((3, 4.0), ConsMap((5, 6.0), EmptyMap())))
assert result == expected
assert result[1] == 2.0
assert result[3] == 4.0
assert result[5] == 6.0
assert len(result) == 3
with pytest.raises(KeyError):
result[7]

# testing both encapsulated in a class
expr = MyExpr(a="1", b=["2.0", 3, True], c={"a": "1", "b": 2, "c": 3.0})
assert expr.a == 1
assert expr.b == ConsList(2.0, ConsList(3.0, ConsList(1.0, EmptyList())))
assert expr.c == ConsMap(("a", 1), ConsMap(("b", 2), ConsMap(("c", 3), EmptyMap())))


def test_composition_of_annotable_and_immutable():
Expand Down Expand Up @@ -600,10 +692,7 @@ class Reduction(Annotable):
class Sum(VersionedOp, Reduction):
where = optional(is_bool, default=False)

# assert (
# str(Sum.__signature__)
# == "(arg: instance_of(<class 'object'>,), version: instance_of(<class 'int'>,), where: option(instance_of(<class 'bool'>,),default=False) = None)"
# )
assert tuple(Sum.__signature__.parameters.keys()) == ("arg", "version", "where")


def test_multiple_inheritance_optional_argument_order():
Expand All @@ -618,10 +707,12 @@ class Between(Value, ConditionalOp):
max = is_int
how = optional(is_str, default="strict")

# assert (
# str(Between.__signature__)
# == "(min: instance_of(<class 'int'>,), max: instance_of(<class 'int'>,), how: option(instance_of(<class 'str'>,),default='strict') = None, where: option(instance_of(<class 'bool'>,),default=False) = None)"
# )
assert tuple(Between.__signature__.parameters.keys()) == (
"min",
"max",
"how",
"where",
)


def test_immutability():
Expand Down
12 changes: 10 additions & 2 deletions ibis/common/tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,16 @@ def endswith_d(x, this):
),
),
(Tuple[int, ...], tuple_of(instance_of(int), type=coerced_to(tuple))),
(Dict[str, float], dict_of(instance_of(str), instance_of(float))),
(frozendict[str, int], frozendict_of(instance_of(str), instance_of(int))),
(
Dict[str, float],
dict_of(instance_of(str), instance_of(float), type=coerced_to(dict)),
),
(
frozendict[str, int],
frozendict_of(
instance_of(str), instance_of(int), type=coerced_to(frozendict)
),
),
(Literal["alpha", "beta", "gamma"], isin(("alpha", "beta", "gamma"))),
(
Callable[[str, int], str],
Expand Down
14 changes: 10 additions & 4 deletions ibis/common/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ def from_typehint(cls, annot: type) -> Validator:
inners = tuple(map(cls.from_typehint, args))
return tuple_of(inners, type=coerced_to(origin))
elif issubclass(origin, Sequence):
(inner,) = map(cls.from_typehint, args)
return sequence_of(inner, type=coerced_to(origin))
(value_inner,) = map(cls.from_typehint, args)
return sequence_of(value_inner, type=coerced_to(origin))
elif issubclass(origin, Mapping):
key_inner, value_inner = map(cls.from_typehint, args)
return mapping_of(key_inner, value_inner, type=origin)
return mapping_of(key_inner, value_inner, type=coerced_to(origin))
elif issubclass(origin, Callable):
if args:
arg_inners = tuple(map(cls.from_typehint, args[0]))
Expand Down Expand Up @@ -496,8 +496,14 @@ def tuple_of(inner: Validator | tuple[Validator], arg: Any, *, type=tuple, **kwa
The coerced tuple containing validated elements.
"""
if isinstance(inner, tuple):
if is_iterable(arg):
arg = tuple(arg)
else:
raise IbisTypeError('Argument must be a sequence')

if len(inner) != len(arg):
raise IbisTypeError(f'Argument must has length {len(inner)}')

return type(validator(item, **kwargs) for validator, item in zip(inner, arg))
else:
return sequence_of(inner, arg, type=type, **kwargs)
Expand Down Expand Up @@ -535,7 +541,7 @@ def mapping_of(
if not isinstance(arg, Mapping):
raise IbisTypeError('Argument must be a mapping')
return type(
(key_inner(k, **kwargs), value_inner(v, **kwargs)) for k, v in arg.items()
{key_inner(k, **kwargs): value_inner(v, **kwargs) for k, v in arg.items()}
)


Expand Down
6 changes: 3 additions & 3 deletions ibis/expr/operations/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,6 @@ class ExtractTemporalField(TemporalUnary):
output_dtype = dt.int32


ExtractTimestampField = ExtractTemporalField


@public
class ExtractDateField(ExtractTemporalField):
arg = rlz.one_of([rlz.date, rlz.timestamp])
Expand Down Expand Up @@ -412,3 +409,6 @@ class BetweenTime(Between):
arg = rlz.one_of([rlz.timestamp, rlz.time])
lower_bound = rlz.one_of([rlz.time, rlz.string])
upper_bound = rlz.one_of([rlz.time, rlz.string])


public(ExtractTimestampField=ExtractTemporalField)

0 comments on commit 91bef71

Please sign in to comment.