diff --git a/mypy/checker.py b/mypy/checker.py index f8461fefc55f..c9d2d3ede283 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5350,10 +5350,26 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM return self.hasattr_type_maps(expr, self.lookup_type(expr), attr[0]) elif isinstance(node.callee, RefExpr): if node.callee.type_guard is not None: - # TODO: Follow keyword args or *args, **kwargs + # TODO: Follow *args, **kwargs if node.arg_kinds[0] != nodes.ARG_POS: - self.fail(message_registry.TYPE_GUARD_POS_ARG_REQUIRED, node) - return {}, {} + # the first argument might be used as a kwarg + called_type = get_proper_type(self.lookup_type(node.callee)) + assert isinstance(called_type, (CallableType, Overloaded)) + + # *assuming* the overloaded function is correct, there's a couple cases: + # 1) The first argument has different names, but is pos-only. We don't + # care about this case, the argument must be passed positionally. + # 2) The first argument allows keyword reference, therefore must be the + # same between overloads. + name = called_type.items[0].arg_names[0] + + if name in node.arg_names: + idx = node.arg_names.index(name) + # we want the idx-th variable to be narrowed + expr = collapse_walrus(node.args[idx]) + else: + self.fail(message_registry.TYPE_GUARD_POS_ARG_REQUIRED, node) + return {}, {} if literal(expr) == LITERAL_TYPE: # Note: we wrap the target type, so that we can special case later. # Namely, for isinstance() we use a normal meet, while TypeGuard is diff --git a/mypy/semanal.py b/mypy/semanal.py index 79302b4d08e1..d0802c194943 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -864,6 +864,20 @@ def analyze_func_def(self, defn: FuncDef) -> None: return assert isinstance(result, ProperType) if isinstance(result, CallableType): + # type guards need to have a positional argument, to spec + if ( + result.type_guard + and ARG_POS not in result.arg_kinds[self.is_class_scope() :] + and not defn.is_static + ): + self.fail( + "TypeGuard functions must have a positional argument", + result, + code=codes.VALID_TYPE, + ) + # in this case, we just kind of just ... remove the type guard. + result = result.copy_modified(type_guard=None) + result = self.remove_unpack_kwargs(defn, result) if has_self_type and self.type is not None: info = self.type diff --git a/test-data/unit/check-python38.test b/test-data/unit/check-python38.test index 7e5e0f3cf185..b9c798b9530e 100644 --- a/test-data/unit/check-python38.test +++ b/test-data/unit/check-python38.test @@ -735,6 +735,34 @@ class C(Generic[T]): main:10: note: Revealed type is "builtins.int" main:10: note: Revealed type is "builtins.str" +[case testTypeGuardWithPositionalOnlyArg] +# flags: --python-version 3.8 +from typing_extensions import TypeGuard + +def typeguard(x: object, /) -> TypeGuard[int]: + ... + +n: object +if typeguard(n): + reveal_type(n) +[builtins fixtures/tuple.pyi] +[out] +main:9: note: Revealed type is "builtins.int" + +[case testTypeGuardKeywordFollowingWalrus] +# flags: --python-version 3.8 +from typing import cast +from typing_extensions import TypeGuard + +def typeguard(x: object) -> TypeGuard[int]: + ... + +if typeguard(x=(n := cast(object, "hi"))): + reveal_type(n) +[builtins fixtures/tuple.pyi] +[out] +main:9: note: Revealed type is "builtins.int" + [case testNoCrashOnAssignmentExprClass] class C: [(j := i) for i in [1, 2, 3]] # E: Assignment expression within a comprehension cannot be used in a class body diff --git a/test-data/unit/check-typeguard.test b/test-data/unit/check-typeguard.test index cf72e7033087..39bcb091f09e 100644 --- a/test-data/unit/check-typeguard.test +++ b/test-data/unit/check-typeguard.test @@ -37,8 +37,8 @@ reveal_type(foo) # N: Revealed type is "def (a: builtins.object) -> TypeGuard[b [case testTypeGuardCallArgsNone] from typing_extensions import TypeGuard class Point: pass -# TODO: error on the 'def' line (insufficient args for type guard) -def is_point() -> TypeGuard[Point]: pass + +def is_point() -> TypeGuard[Point]: pass # E: TypeGuard functions must have a positional argument def main(a: object) -> None: if is_point(): reveal_type(a) # N: Revealed type is "builtins.object" @@ -227,13 +227,13 @@ def main(a: object) -> None: from typing_extensions import TypeGuard def is_float(a: object, b: object = 0) -> TypeGuard[float]: pass def main1(a: object) -> None: - # This is debatable -- should we support these cases? + if is_float(a=a, b=1): + reveal_type(a) # N: Revealed type is "builtins.float" - if is_float(a=a, b=1): # E: Type guard requires positional argument - reveal_type(a) # N: Revealed type is "builtins.object" + if is_float(b=1, a=a): + reveal_type(a) # N: Revealed type is "builtins.float" - if is_float(b=1, a=a): # E: Type guard requires positional argument - reveal_type(a) # N: Revealed type is "builtins.object" + # This is debatable -- should we support these cases? ta = (a,) if is_float(*ta): # E: Type guard requires positional argument @@ -597,3 +597,77 @@ def func(names: Tuple[str, ...]): if is_two_element_tuple(names): reveal_type(names) # N: Revealed type is "Tuple[builtins.str, builtins.str]" [builtins fixtures/tuple.pyi] + +[case testTypeGuardErroneousDefinitionFails] +from typing_extensions import TypeGuard + +class Z: + def typeguard(self, *, x: object) -> TypeGuard[int]: # E: TypeGuard functions must have a positional argument + ... + +def bad_typeguard(*, x: object) -> TypeGuard[int]: # E: TypeGuard functions must have a positional argument + ... +[builtins fixtures/tuple.pyi] + +[case testTypeGuardWithKeywordArg] +from typing_extensions import TypeGuard + +class Z: + def typeguard(self, x: object) -> TypeGuard[int]: + ... + +def typeguard(x: object) -> TypeGuard[int]: + ... + +n: object +if typeguard(x=n): + reveal_type(n) # N: Revealed type is "builtins.int" + +if Z().typeguard(x=n): + reveal_type(n) # N: Revealed type is "builtins.int" +[builtins fixtures/tuple.pyi] + +[case testStaticMethodTypeGuard] +from typing_extensions import TypeGuard + +class Y: + @staticmethod + def typeguard(h: object) -> TypeGuard[int]: + ... + +x: object +if Y().typeguard(x): + reveal_type(x) # N: Revealed type is "builtins.int" +if Y.typeguard(x): + reveal_type(x) # N: Revealed type is "builtins.int" +[builtins fixtures/tuple.pyi] +[builtins fixtures/classmethod.pyi] + +[case testTypeGuardKwargFollowingThroughOverloaded] +from typing import overload, Union +from typing_extensions import TypeGuard + +@overload +def typeguard(x: object, y: str) -> TypeGuard[str]: + ... + +@overload +def typeguard(x: object, y: int) -> TypeGuard[int]: + ... + +def typeguard(x: object, y: Union[int, str]) -> Union[TypeGuard[int], TypeGuard[str]]: + ... + +x: object +if typeguard(x=x, y=42): + reveal_type(x) # N: Revealed type is "builtins.int" + +if typeguard(y=42, x=x): + reveal_type(x) # N: Revealed type is "builtins.int" + +if typeguard(x=x, y="42"): + reveal_type(x) # N: Revealed type is "builtins.str" + +if typeguard(y="42", x=x): + reveal_type(x) # N: Revealed type is "builtins.str" +[builtins fixtures/tuple.pyi]