diff --git a/ibis/common/annotations.py b/ibis/common/annotations.py index 81e60cf67160..6e1fa62d54ea 100644 --- a/ibis/common/annotations.py +++ b/ibis/common/annotations.py @@ -46,6 +46,12 @@ def __eq__(self, other): and self._validator == other._validator ) + def __repr__(self): + return ( + f"{self.__class__.__name__}(validator={self._validator!r}, " + f"default={self._default!r}, typehint={self._typehint!r})" + ) + def validate(self, arg, **kwargs): if self._validator is None: return arg diff --git a/ibis/common/tests/test_annotations.py b/ibis/common/tests/test_annotations.py index b9664b7a40fd..e35d37fcf963 100644 --- a/ibis/common/tests/test_annotations.py +++ b/ibis/common/tests/test_annotations.py @@ -13,6 +13,14 @@ is_int = instance_of(int) +def test_argument_repr(): + argument = Argument(is_int, typehint=int, default=None) + assert repr(argument) == ( + "Argument(validator=instance_of(,), default=None, " + "typehint=)" + ) + + def test_default_argument(): annotation = Argument.default(validator=int, default=3) assert annotation.validate(1) == 1 diff --git a/ibis/common/tests/test_grounds.py b/ibis/common/tests/test_grounds.py index b739f262f944..11f511876f46 100644 --- a/ibis/common/tests/test_grounds.py +++ b/ibis/common/tests/test_grounds.py @@ -1,7 +1,7 @@ import copy import pickle import weakref -from typing import Sequence, TypeVar +from typing import Mapping, Sequence, Tuple, TypeVar import pytest @@ -25,7 +25,7 @@ Immutable, Singleton, ) -from ibis.common.validators import Coercible, instance_of, option, validator +from ibis.common.validators import Coercible, Validator, instance_of, option, validator from ibis.tests.util import assert_pickle_roundtrip is_any = instance_of(object) @@ -85,14 +85,16 @@ class VariadicArgsAndKeywords(Concrete): T = TypeVar('T') +K = TypeVar('K') +V = TypeVar('V') class List(Concrete, Sequence[T], Coercible): @classmethod - def __coerce__(self, value): - value = tuple(value) - if value: - head, *rest = value + def __coerce__(self, values): + values = tuple(values) + if values: + head, *rest = values return ConsList(head, rest) else: return EmptyList() @@ -120,6 +122,66 @@ def __len__(self): return len(self.rest) + 1 +class Map(Concrete, Mapping[K, V], Coercible): + @classmethod + def __coerce__(self, pairs): + pairs = dict(pairs) + if pairs: + head_key = next(iter(pairs)) + head_value = pairs.pop(head_key) + rest = pairs + return ConsMap((head_key, head_value), rest) + else: + return EmptyMap() + + +class EmptyMap(Map[K, V]): + def __getitem__(self, key): + raise KeyError(key) + + def __iter__(self): + return iter(()) + + def __len__(self): + return 0 + + +class ConsMap(Map[K, V]): + head: Tuple[K, V] + rest: Map[K, V] + + def __getitem__(self, key): + if key == self.head[0]: + return self.head[1] + else: + return self.rest[key] + + def __iter__(self): + yield self.head[0] + yield from self.rest + + def __len__(self): + return len(self.rest) + 1 + + +class Integer(int, Coercible): + @classmethod + def __coerce__(cls, value): + return Integer(value) + + +class Float(float, Coercible): + @classmethod + def __coerce__(cls, value): + return Float(value) + + +class MyExpr(Concrete): + a: Integer + b: List[Float] + c: Map[str, Integer] + + def test_annotable(): class InBetween(BetweenSimple): pass @@ -199,7 +261,37 @@ class Op(Annotable): def test_annotable_with_recursive_generic_type_annotations(): - pass + # testing cons list + validator = Validator.from_typehint(List[Integer]) + values = ["1", 2.0, 3] + result = validator(values) + expected = ConsList(1, ConsList(2, ConsList(3, EmptyList()))) + assert result == expected + assert result[0] == 1 + assert result[1] == 2 + assert result[2] == 3 + assert len(result) == 3 + with pytest.raises(IndexError): + result[3] + + # testing cons map + validator = Validator.from_typehint(Map[Integer, Float]) + values = {"1": 2, 3: "4.0", 5: 6.0} + result = validator(values) + expected = ConsMap((1, 2.0), ConsMap((3, 4.0), ConsMap((5, 6.0), EmptyMap()))) + assert result == expected + assert result[1] == 2.0 + assert result[3] == 4.0 + assert result[5] == 6.0 + assert len(result) == 3 + with pytest.raises(KeyError): + result[7] + + # testing both encapsulated in a class + expr = MyExpr(a="1", b=["2.0", 3, True], c={"a": "1", "b": 2, "c": 3.0}) + assert expr.a == 1 + assert expr.b == ConsList(2.0, ConsList(3.0, ConsList(1.0, EmptyList()))) + assert expr.c == ConsMap(("a", 1), ConsMap(("b", 2), ConsMap(("c", 3), EmptyMap()))) def test_composition_of_annotable_and_immutable(): @@ -600,10 +692,7 @@ class Reduction(Annotable): class Sum(VersionedOp, Reduction): where = optional(is_bool, default=False) - # assert ( - # str(Sum.__signature__) - # == "(arg: instance_of(,), version: instance_of(,), where: option(instance_of(,),default=False) = None)" - # ) + assert tuple(Sum.__signature__.parameters.keys()) == ("arg", "version", "where") def test_multiple_inheritance_optional_argument_order(): @@ -618,10 +707,12 @@ class Between(Value, ConditionalOp): max = is_int how = optional(is_str, default="strict") - # assert ( - # str(Between.__signature__) - # == "(min: instance_of(,), max: instance_of(,), how: option(instance_of(,),default='strict') = None, where: option(instance_of(,),default=False) = None)" - # ) + assert tuple(Between.__signature__.parameters.keys()) == ( + "min", + "max", + "how", + "where", + ) def test_immutability(): diff --git a/ibis/common/tests/test_validators.py b/ibis/common/tests/test_validators.py index 86d16a56bb79..8f368546bf72 100644 --- a/ibis/common/tests/test_validators.py +++ b/ibis/common/tests/test_validators.py @@ -134,8 +134,16 @@ def endswith_d(x, this): ), ), (Tuple[int, ...], tuple_of(instance_of(int), type=coerced_to(tuple))), - (Dict[str, float], dict_of(instance_of(str), instance_of(float))), - (frozendict[str, int], frozendict_of(instance_of(str), instance_of(int))), + ( + Dict[str, float], + dict_of(instance_of(str), instance_of(float), type=coerced_to(dict)), + ), + ( + frozendict[str, int], + frozendict_of( + instance_of(str), instance_of(int), type=coerced_to(frozendict) + ), + ), (Literal["alpha", "beta", "gamma"], isin(("alpha", "beta", "gamma"))), ( Callable[[str, int], str], diff --git a/ibis/common/validators.py b/ibis/common/validators.py index 9c0b85d05253..01c258f1c433 100644 --- a/ibis/common/validators.py +++ b/ibis/common/validators.py @@ -104,11 +104,11 @@ def from_typehint(cls, annot: type) -> Validator: inners = tuple(map(cls.from_typehint, args)) return tuple_of(inners, type=coerced_to(origin)) elif issubclass(origin, Sequence): - (inner,) = map(cls.from_typehint, args) - return sequence_of(inner, type=coerced_to(origin)) + (value_inner,) = map(cls.from_typehint, args) + return sequence_of(value_inner, type=coerced_to(origin)) elif issubclass(origin, Mapping): key_inner, value_inner = map(cls.from_typehint, args) - return mapping_of(key_inner, value_inner, type=origin) + return mapping_of(key_inner, value_inner, type=coerced_to(origin)) elif issubclass(origin, Callable): if args: arg_inners = tuple(map(cls.from_typehint, args[0])) @@ -496,8 +496,14 @@ def tuple_of(inner: Validator | tuple[Validator], arg: Any, *, type=tuple, **kwa The coerced tuple containing validated elements. """ if isinstance(inner, tuple): + if is_iterable(arg): + arg = tuple(arg) + else: + raise IbisTypeError('Argument must be a sequence') + if len(inner) != len(arg): raise IbisTypeError(f'Argument must has length {len(inner)}') + return type(validator(item, **kwargs) for validator, item in zip(inner, arg)) else: return sequence_of(inner, arg, type=type, **kwargs) @@ -535,7 +541,7 @@ def mapping_of( if not isinstance(arg, Mapping): raise IbisTypeError('Argument must be a mapping') return type( - (key_inner(k, **kwargs), value_inner(v, **kwargs)) for k, v in arg.items() + {key_inner(k, **kwargs): value_inner(v, **kwargs) for k, v in arg.items()} ) diff --git a/ibis/expr/operations/temporal.py b/ibis/expr/operations/temporal.py index a5ca683e7cb8..48f6da613d46 100644 --- a/ibis/expr/operations/temporal.py +++ b/ibis/expr/operations/temporal.py @@ -128,9 +128,6 @@ class ExtractTemporalField(TemporalUnary): output_dtype = dt.int32 -ExtractTimestampField = ExtractTemporalField - - @public class ExtractDateField(ExtractTemporalField): arg = rlz.one_of([rlz.date, rlz.timestamp]) @@ -412,3 +409,6 @@ class BetweenTime(Between): arg = rlz.one_of([rlz.timestamp, rlz.time]) lower_bound = rlz.one_of([rlz.time, rlz.string]) upper_bound = rlz.one_of([rlz.time, rlz.string]) + + +public(ExtractTimestampField=ExtractTemporalField)