diff --git a/ibis/common/tests/test_validators.py b/ibis/common/tests/test_validators.py index 428389fc5a26..09dee33f6c16 100644 --- a/ibis/common/tests/test_validators.py +++ b/ibis/common/tests/test_validators.py @@ -1,9 +1,13 @@ +from __future__ import annotations + +import sys from typing import Dict, List, Optional, Tuple, Union import pytest from typing_extensions import Annotated from ibis.common.validators import ( + Coercible, Validator, all_of, any_of, @@ -14,6 +18,7 @@ int_, isin, list_of, + mapping_of, min_, str_, tuple_of, @@ -98,5 +103,60 @@ def endswith_d(x, this): ], ) def test_validator_from_annotation(annot, expected): - validator = Validator.from_annotation(annot) - assert validator == expected + assert Validator.from_annotation(annot) == expected + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") +def test_validator_from_annotation_uniontype(): + # uniontype marks `type1 | type2` annotations and it's different from + # Union[type1, type2] + validator = Validator.from_annotation(str | int | float) + assert validator == any_of((instance_of(str), instance_of(int), instance_of(float))) + + +class Something(Coercible): + def __init__(self, value): + self.value = value + + @classmethod + def __coerce__(cls, obj): + return cls(obj + 1) + + def __eq__(self, other): + return type(self) == type(other) and self.value == other.value + + +class SomethingSimilar(Something): + pass + + +class SomethingDifferent(Coercible): + @classmethod + def __coerce__(cls, obj): + return obj + 2 + + +def test_coercible(): + s = Validator.from_annotation(Something) + assert s(1) == Something(2) + assert s(10) == Something(11) + + +def test_coercible_checks_type(): + s = Validator.from_annotation(SomethingSimilar) + v = Validator.from_annotation(SomethingDifferent) + + assert s(1) == SomethingSimilar(2) + assert SomethingDifferent.__coerce__(1) == 3 + + with pytest.raises(TypeError, match="not an instance of .*SomethingDifferent.*"): + v(1) + + +def test_mapping_of(): + value = {"a": 1, "b": 2} + assert mapping_of(str, int, value, type=dict) == value + assert mapping_of(str, int, value, type=frozendict) == frozendict(value) + + with pytest.raises(TypeError, match="Argument must be a mapping"): + mapping_of(str, float, 10, type=dict) diff --git a/ibis/common/typing.py b/ibis/common/typing.py index c7a04eafff0d..e6043cd180f8 100644 --- a/ibis/common/typing.py +++ b/ibis/common/typing.py @@ -10,11 +10,14 @@ if sys.version_info >= (3, 9): @toolz.memoize - def evaluate_typehint(hint, module_name) -> Any: + def evaluate_typehint(hint, module_name=None) -> Any: if isinstance(hint, str): hint = ForwardRef(hint) if isinstance(hint, ForwardRef): - globalns = sys.modules[module_name].__dict__ + if module_name is None: + globalns = {} + else: + globalns = sys.modules[module_name].__dict__ return hint._evaluate(globalns, locals(), frozenset()) else: return hint @@ -26,7 +29,10 @@ def evaluate_typehint(hint, module_name) -> Any: if isinstance(hint, str): hint = ForwardRef(hint) if isinstance(hint, ForwardRef): - globalns = sys.modules[module_name].__dict__ + if module_name is None: + globalns = {} + else: + globalns = sys.modules[module_name].__dict__ return hint._evaluate(globalns, locals()) else: return hint diff --git a/ibis/common/validators.py b/ibis/common/validators.py index 30569c7c2b96..7fd3fc6701cf 100644 --- a/ibis/common/validators.py +++ b/ibis/common/validators.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +from abc import ABC, abstractmethod from contextlib import suppress from typing import Any, Callable, Iterable, Mapping, Sequence, Union @@ -12,6 +13,20 @@ from ibis.common.typing import evaluate_typehint from ibis.util import flatten_iterable, frozendict, is_function, is_iterable +try: + from types import UnionType +except ImportError: + UnionType = object() + + +class Coercible(ABC): + __slots__ = () + + @classmethod + @abstractmethod + def __coerce__(cls, obj): + ... + class Validator(Callable): """Abstract base class for defining argument validators.""" @@ -24,11 +39,14 @@ def from_annotation(cls, annot, module=None): annot = evaluate_typehint(annot, module) origin_type = get_origin(annot) - if annot is Any: - return any_ - elif origin_type is None: - return instance_of(annot) - elif origin_type is Union: + if origin_type is None: + if annot is Any: + return any_ + elif issubclass(annot, Coercible): + return coerced_to(annot) + else: + return instance_of(annot) + elif origin_type is UnionType or origin_type is Union: inners = map(cls.from_annotation, get_args(annot)) return any_of(tuple(inners)) elif origin_type is Annotated: @@ -40,8 +58,13 @@ def from_annotation(cls, annot, module=None): elif issubclass(origin_type, Mapping): key_type, value_type = map(cls.from_annotation, get_args(annot)) return mapping_of(key_type, value_type, type=origin_type) + elif issubclass(origin_type, Callable): + # TODO(kszucs): add a more comprehensive callable_with rule here + return instance_of(Callable) else: - return instance_of(annot) + raise NotImplementedError( + f"Cannot create validator from annotation {annot} {origin_type}" + ) # TODO(kszucs): in order to cache valiadator instances we could subclass @@ -96,6 +119,12 @@ def instance_of(klasses, arg, **kwargs): return arg +@validator +def coerced_to(klass, arg, **kwargs): + value = klass.__coerce__(arg) + return instance_of(klass, value, **kwargs) + + class lazy_instance_of(Validator): """A version of `instance_of` that accepts qualnames instead of imported classes. @@ -194,6 +223,8 @@ def sequence_of(inner, arg, *, type, min_length=0, flatten=False, **kwargs): @validator def mapping_of(key_inner, value_inner, arg, *, type, **kwargs): + if not isinstance(arg, Mapping): + raise IbisTypeError('Argument must be a mapping') return type( (key_inner(k, **kwargs), value_inner(v, **kwargs)) for k, v in arg.items() )