From df63e8ba71e8c17fdbfe9f6361d515e33abf3068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 12 Jul 2023 09:09:45 +0200 Subject: [PATCH] fix(common): disallow type coercion when checking for generic type fields --- ibis/common/patterns.py | 64 +++++++++++------------------- ibis/common/tests/test_patterns.py | 29 ++++++++++++++ 2 files changed, 52 insertions(+), 41 deletions(-) diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index 954c83dd777a..643a7e9bf9c5 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -67,39 +67,7 @@ class Validator(ABC): __slots__ = () @classmethod - def from_typevar(cls, var: TypeVar, bound: AnyType = None) -> Pattern: - """Construct a validator from a type variable. - - This method is called from two places: - 1. `Validator.from_typehint` without additional bound argument - 2. `GenericInstanceOf` with a substituted type parameter given as bound - - This method also ensures that the type variable is covariant, - contravariant and invariant type variables are not supported yet. - - Parameters - ---------- - var - The type variable to construct the pattern from. - bound - An optional bound to use for the type variable. If not provided, - a no-op validator is returned. - - Returns - ------- - pattern - A pattern that matches the given type variable. - """ - if var.__covariant__: - if bound := bound or var.__bound__: - return cls.from_typehint(bound) - else: - return Any() - else: - raise NotImplementedError("Only covariant typevars are supported for now") - - @classmethod - def from_typehint(cls, annot: type) -> Pattern: + def from_typehint(cls, annot: type, allow_coercion: bool = True) -> Pattern: """Construct a validator from a python type annotation. Parameters @@ -107,6 +75,8 @@ def from_typehint(cls, annot: type) -> Pattern: annot The typehint annotation to construct the pattern from. This must be an already evaluated type annotation. + allow_coercion + Whether to use coercion if the typehint is a Coercible type. Returns ------- @@ -124,7 +94,7 @@ def from_typehint(cls, annot: type) -> Pattern: return Any() elif isinstance(annot, type): # the typehint is a concrete type (e.g. int, str, etc.) - if issubclass(annot, Coercible): + if allow_coercion and issubclass(annot, Coercible): # the type implements the Coercible protocol so we try to # coerce the value to the given type rather than checking return CoercedTo(annot) @@ -133,7 +103,14 @@ def from_typehint(cls, annot: type) -> Pattern: elif isinstance(annot, TypeVar): # if the typehint is a type variable we try to construct a # validator from it only if it is covariant and has a bound - return cls.from_typevar(annot) + if not annot.__covariant__: + raise NotImplementedError( + "Only covariant typevars are supported for now" + ) + if annot.__bound__: + return cls.from_typehint(annot.__bound__) + else: + return Any() elif isinstance(annot, Enum): # for enums we check the value against the enum values return EqualTo(annot) @@ -208,7 +185,7 @@ def from_typehint(cls, annot: type) -> Pattern: elif isinstance(origin, GenericMeta): # construct a validator for the generic type, see the specific # Generic* validators for more details - if issubclass(origin, Coercible) and args: + if allow_coercion and issubclass(origin, Coercible) and args: return GenericCoercedTo(annot) else: return GenericInstanceOf(annot) @@ -562,11 +539,16 @@ class GenericInstanceOf(Matcher): def __init__(self, typ): origin = get_origin(typ) typevars = get_bound_typevars(typ) - field_inners = { - attr: Pattern.from_typevar(var, type_) - for var, (attr, type_) in typevars.items() - } - super().__init__(origin, frozendict(field_inners)) + + field_patterns = {} + for var, (attr, type_) in typevars.items(): + if not var.__covariant__: + raise TypeError( + f"Typevar {var} is not covariant, cannot use it in a GenericInstanceOf" + ) + field_patterns[attr] = Pattern.from_typehint(type_, allow_coercion=False) + + super().__init__(origin, frozendict(field_patterns)) def match(self, value, context): if not isinstance(value, self.origin): diff --git a/ibis/common/tests/test_patterns.py b/ibis/common/tests/test_patterns.py index 151d90d7232c..ed2b6f22beda 100644 --- a/ibis/common/tests/test_patterns.py +++ b/ibis/common/tests/test_patterns.py @@ -40,6 +40,7 @@ EqualTo, FrozenDictOf, Function, + GenericInstanceOf, Innermost, InstanceOf, IsIn, @@ -194,6 +195,21 @@ def test_generic_instance_of_with_covariant_typevar(): assert match(My[int, float], My(1, 2.0, "3"), context={}) == {} +def test_generic_instance_of_disallow_nested_coercion(): + class MyString(str, Coercible): + @classmethod + def __coerce__(cls, other): + return cls(str(other)) + + class Box(Generic[T]): + value: T + + p = Pattern.from_typehint(Box[MyString]) + assert isinstance(p, GenericInstanceOf) + assert p.origin == Box + assert p.field_patterns == {"value": InstanceOf(MyString)} + + def test_coerced_to(): class MyInt(int, Coercible): @classmethod @@ -704,6 +720,19 @@ def test_pattern_from_typehint_uniontype(): assert validator == AnyOf(InstanceOf(str), InstanceOf(int), InstanceOf(float)) +def test_pattern_from_typehint_disable_coercion(): + class MyFloat(float, Coercible): + @classmethod + def __coerce__(cls, obj): + return cls(float(obj)) + + p = Pattern.from_typehint(MyFloat, allow_coercion=True) + assert isinstance(p, CoercedTo) + + p = Pattern.from_typehint(MyFloat, allow_coercion=False) + assert isinstance(p, InstanceOf) + + class PlusOne(Coercible): def __init__(self, value): self.value = value