Skip to content

Commit

Permalink
feat(common): add support for annotating with coercible types
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Jan 11, 2023
1 parent ddc6603 commit ae4a415
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 11 deletions.
64 changes: 62 additions & 2 deletions ibis/common/tests/test_validators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

import sys
from typing import Dict, List, Optional, Tuple, Union

import pytest
from typing_extensions import Annotated

from ibis.common.validators import (
Coercible,
Validator,
all_of,
any_of,
Expand All @@ -14,6 +18,7 @@
int_,
isin,
list_of,
mapping_of,
min_,
str_,
tuple_of,
Expand Down Expand Up @@ -98,5 +103,60 @@ def endswith_d(x, this):
],
)
def test_validator_from_annotation(annot, expected):
validator = Validator.from_annotation(annot)
assert validator == expected
assert Validator.from_annotation(annot) == expected


@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
def test_validator_from_annotation_uniontype():
# uniontype marks `type1 | type2` annotations and it's different from
# Union[type1, type2]
validator = Validator.from_annotation(str | int | float)
assert validator == any_of((instance_of(str), instance_of(int), instance_of(float)))


class Something(Coercible):
def __init__(self, value):
self.value = value

@classmethod
def __coerce__(cls, obj):
return cls(obj + 1)

def __eq__(self, other):
return type(self) == type(other) and self.value == other.value


class SomethingSimilar(Something):
pass


class SomethingDifferent(Coercible):
@classmethod
def __coerce__(cls, obj):
return obj + 2


def test_coercible():
s = Validator.from_annotation(Something)
assert s(1) == Something(2)
assert s(10) == Something(11)


def test_coercible_checks_type():
s = Validator.from_annotation(SomethingSimilar)
v = Validator.from_annotation(SomethingDifferent)

assert s(1) == SomethingSimilar(2)
assert SomethingDifferent.__coerce__(1) == 3

with pytest.raises(TypeError, match="not an instance of .*SomethingDifferent.*"):
v(1)


def test_mapping_of():
value = {"a": 1, "b": 2}
assert mapping_of(str, int, value, type=dict) == value
assert mapping_of(str, int, value, type=frozendict) == frozendict(value)

with pytest.raises(TypeError, match="Argument must be a mapping"):
mapping_of(str, float, 10, type=dict)
12 changes: 9 additions & 3 deletions ibis/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
if sys.version_info >= (3, 9):

@toolz.memoize
def evaluate_typehint(hint, module_name) -> Any:
def evaluate_typehint(hint, module_name=None) -> Any:
if isinstance(hint, str):
hint = ForwardRef(hint)
if isinstance(hint, ForwardRef):
globalns = sys.modules[module_name].__dict__
if module_name is None:
globalns = {}
else:
globalns = sys.modules[module_name].__dict__
return hint._evaluate(globalns, locals(), frozenset())
else:
return hint
Expand All @@ -26,7 +29,10 @@ def evaluate_typehint(hint, module_name) -> Any:
if isinstance(hint, str):
hint = ForwardRef(hint)
if isinstance(hint, ForwardRef):
globalns = sys.modules[module_name].__dict__
if module_name is None:
globalns = {}
else:
globalns = sys.modules[module_name].__dict__
return hint._evaluate(globalns, locals())
else:
return hint
43 changes: 37 additions & 6 deletions ibis/common/validators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import math
from abc import ABC, abstractmethod
from contextlib import suppress
from typing import Any, Callable, Iterable, Mapping, Sequence, Union

Expand All @@ -12,6 +13,20 @@
from ibis.common.typing import evaluate_typehint
from ibis.util import flatten_iterable, frozendict, is_function, is_iterable

try:
from types import UnionType
except ImportError:
UnionType = object()


class Coercible(ABC):
__slots__ = ()

@classmethod
@abstractmethod
def __coerce__(cls, obj):
...


class Validator(Callable):
"""Abstract base class for defining argument validators."""
Expand All @@ -24,11 +39,14 @@ def from_annotation(cls, annot, module=None):
annot = evaluate_typehint(annot, module)
origin_type = get_origin(annot)

if annot is Any:
return any_
elif origin_type is None:
return instance_of(annot)
elif origin_type is Union:
if origin_type is None:
if annot is Any:
return any_
elif issubclass(annot, Coercible):
return coerced_to(annot)
else:
return instance_of(annot)
elif origin_type is UnionType or origin_type is Union:
inners = map(cls.from_annotation, get_args(annot))
return any_of(tuple(inners))
elif origin_type is Annotated:
Expand All @@ -40,8 +58,13 @@ def from_annotation(cls, annot, module=None):
elif issubclass(origin_type, Mapping):
key_type, value_type = map(cls.from_annotation, get_args(annot))
return mapping_of(key_type, value_type, type=origin_type)
elif issubclass(origin_type, Callable):
# TODO(kszucs): add a more comprehensive callable_with rule here
return instance_of(Callable)
else:
return instance_of(annot)
raise NotImplementedError(
f"Cannot create validator from annotation {annot} {origin_type}"
)


# TODO(kszucs): in order to cache valiadator instances we could subclass
Expand Down Expand Up @@ -96,6 +119,12 @@ def instance_of(klasses, arg, **kwargs):
return arg


@validator
def coerced_to(klass, arg, **kwargs):
value = klass.__coerce__(arg)
return instance_of(klass, value, **kwargs)


class lazy_instance_of(Validator):
"""A version of `instance_of` that accepts qualnames instead of imported classes.
Expand Down Expand Up @@ -194,6 +223,8 @@ def sequence_of(inner, arg, *, type, min_length=0, flatten=False, **kwargs):

@validator
def mapping_of(key_inner, value_inner, arg, *, type, **kwargs):
if not isinstance(arg, Mapping):
raise IbisTypeError('Argument must be a mapping')
return type(
(key_inner(k, **kwargs), value_inner(v, **kwargs)) for k, v in arg.items()
)
Expand Down

0 comments on commit ae4a415

Please sign in to comment.