Skip to content

Commit

Permalink
fix(common): disallow type coercion when checking for generic type fi…
Browse files Browse the repository at this point in the history
…elds
  • Loading branch information
kszucs authored and cpcloud committed Aug 7, 2023
1 parent d4161d7 commit df63e8b
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 41 deletions.
64 changes: 23 additions & 41 deletions ibis/common/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,46 +67,16 @@ class Validator(ABC):
__slots__ = ()

@classmethod
def from_typevar(cls, var: TypeVar, bound: AnyType = None) -> Pattern:
"""Construct a validator from a type variable.
This method is called from two places:
1. `Validator.from_typehint` without additional bound argument
2. `GenericInstanceOf` with a substituted type parameter given as bound
This method also ensures that the type variable is covariant,
contravariant and invariant type variables are not supported yet.
Parameters
----------
var
The type variable to construct the pattern from.
bound
An optional bound to use for the type variable. If not provided,
a no-op validator is returned.
Returns
-------
pattern
A pattern that matches the given type variable.
"""
if var.__covariant__:
if bound := bound or var.__bound__:
return cls.from_typehint(bound)
else:
return Any()
else:
raise NotImplementedError("Only covariant typevars are supported for now")

@classmethod
def from_typehint(cls, annot: type) -> Pattern:
def from_typehint(cls, annot: type, allow_coercion: bool = True) -> Pattern:
"""Construct a validator from a python type annotation.
Parameters
----------
annot
The typehint annotation to construct the pattern from. This must be
an already evaluated type annotation.
allow_coercion
Whether to use coercion if the typehint is a Coercible type.
Returns
-------
Expand All @@ -124,7 +94,7 @@ def from_typehint(cls, annot: type) -> Pattern:
return Any()
elif isinstance(annot, type):
# the typehint is a concrete type (e.g. int, str, etc.)
if issubclass(annot, Coercible):
if allow_coercion and issubclass(annot, Coercible):
# the type implements the Coercible protocol so we try to
# coerce the value to the given type rather than checking
return CoercedTo(annot)
Expand All @@ -133,7 +103,14 @@ def from_typehint(cls, annot: type) -> Pattern:
elif isinstance(annot, TypeVar):
# if the typehint is a type variable we try to construct a
# validator from it only if it is covariant and has a bound
return cls.from_typevar(annot)
if not annot.__covariant__:
raise NotImplementedError(
"Only covariant typevars are supported for now"
)
if annot.__bound__:
return cls.from_typehint(annot.__bound__)
else:
return Any()
elif isinstance(annot, Enum):
# for enums we check the value against the enum values
return EqualTo(annot)
Expand Down Expand Up @@ -208,7 +185,7 @@ def from_typehint(cls, annot: type) -> Pattern:
elif isinstance(origin, GenericMeta):
# construct a validator for the generic type, see the specific
# Generic* validators for more details
if issubclass(origin, Coercible) and args:
if allow_coercion and issubclass(origin, Coercible) and args:
return GenericCoercedTo(annot)
else:
return GenericInstanceOf(annot)
Expand Down Expand Up @@ -562,11 +539,16 @@ class GenericInstanceOf(Matcher):
def __init__(self, typ):
origin = get_origin(typ)
typevars = get_bound_typevars(typ)
field_inners = {
attr: Pattern.from_typevar(var, type_)
for var, (attr, type_) in typevars.items()
}
super().__init__(origin, frozendict(field_inners))

field_patterns = {}
for var, (attr, type_) in typevars.items():
if not var.__covariant__:
raise TypeError(
f"Typevar {var} is not covariant, cannot use it in a GenericInstanceOf"
)
field_patterns[attr] = Pattern.from_typehint(type_, allow_coercion=False)

super().__init__(origin, frozendict(field_patterns))

def match(self, value, context):
if not isinstance(value, self.origin):
Expand Down
29 changes: 29 additions & 0 deletions ibis/common/tests/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
EqualTo,
FrozenDictOf,
Function,
GenericInstanceOf,
Innermost,
InstanceOf,
IsIn,
Expand Down Expand Up @@ -194,6 +195,21 @@ def test_generic_instance_of_with_covariant_typevar():
assert match(My[int, float], My(1, 2.0, "3"), context={}) == {}


def test_generic_instance_of_disallow_nested_coercion():
class MyString(str, Coercible):
@classmethod
def __coerce__(cls, other):
return cls(str(other))

class Box(Generic[T]):
value: T

p = Pattern.from_typehint(Box[MyString])
assert isinstance(p, GenericInstanceOf)
assert p.origin == Box
assert p.field_patterns == {"value": InstanceOf(MyString)}


def test_coerced_to():
class MyInt(int, Coercible):
@classmethod
Expand Down Expand Up @@ -704,6 +720,19 @@ def test_pattern_from_typehint_uniontype():
assert validator == AnyOf(InstanceOf(str), InstanceOf(int), InstanceOf(float))


def test_pattern_from_typehint_disable_coercion():
class MyFloat(float, Coercible):
@classmethod
def __coerce__(cls, obj):
return cls(float(obj))

p = Pattern.from_typehint(MyFloat, allow_coercion=True)
assert isinstance(p, CoercedTo)

p = Pattern.from_typehint(MyFloat, allow_coercion=False)
assert isinstance(p, InstanceOf)


class PlusOne(Coercible):
def __init__(self, value):
self.value = value
Expand Down

0 comments on commit df63e8b

Please sign in to comment.