diff --git a/mypy/plugins/attrs.py b/mypy/plugins/attrs.py index 50d2955d2584..6fda965ade8b 100644 --- a/mypy/plugins/attrs.py +++ b/mypy/plugins/attrs.py @@ -6,6 +6,7 @@ from typing_extensions import Final, Literal import mypy.plugin # To avoid circular imports. +from mypy.errorcodes import LITERAL_REQ from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type from mypy.nodes import ( ARG_NAMED, @@ -246,7 +247,11 @@ def _get_decorator_optional_bool_argument( return False if attr_value.fullname == "builtins.None": return None - ctx.api.fail(f'"{name}" argument must be True or False.', ctx.reason) + ctx.api.fail( + f'"{name}" argument must be a True, False, or None literal', + ctx.reason, + code=LITERAL_REQ, + ) return default return default else: diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index 38109892e09d..0acf3e3a6369 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -20,7 +20,11 @@ Var, ) from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface -from mypy.semanal_shared import ALLOW_INCOMPATIBLE_OVERRIDE, set_callable_name +from mypy.semanal_shared import ( + ALLOW_INCOMPATIBLE_OVERRIDE, + require_bool_literal_argument, + set_callable_name, +) from mypy.typeops import ( # noqa: F401 # Part of public API try_getting_str_literals as try_getting_str_literals, ) @@ -54,11 +58,7 @@ def _get_bool_argument(ctx: ClassDefContext, expr: CallExpr, name: str, default: """ attr_value = _get_argument(expr, name) if attr_value: - ret = ctx.api.parse_bool(attr_value) - if ret is None: - ctx.api.fail(f'"{name}" argument must be True or False.', expr) - return default - return ret + return require_bool_literal_argument(ctx.api, attr_value, name, default) return default diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 3feb644dc8ea..872765847073 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -41,7 +41,7 @@ add_method_to_class, deserialize_and_fixup_type, ) -from mypy.semanal_shared import find_dataclass_transform_spec +from mypy.semanal_shared import find_dataclass_transform_spec, require_bool_literal_argument from mypy.server.trigger import make_wildcard_trigger from mypy.state import state from mypy.typeops import map_type_from_supertype @@ -678,11 +678,7 @@ def _get_bool_arg(self, name: str, default: bool) -> bool: # class's keyword arguments (ie `class Subclass(Parent, kwarg1=..., kwarg2=...)`) expression = self._cls.keywords.get(name) if expression is not None: - value = self._api.parse_bool(self._cls.keywords[name]) - if value is not None: - return value - else: - self._api.fail(f'"{name}" argument must be True or False', expression) + return require_bool_literal_argument(self._api, expression, name, default) return default diff --git a/mypy/semanal.py b/mypy/semanal.py index 8dcea36f41b9..8c16b0addd45 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -216,6 +216,7 @@ calculate_tuple_fallback, find_dataclass_transform_spec, has_placeholder, + require_bool_literal_argument, set_callable_name as set_callable_name, ) from mypy.semanal_typeddict import TypedDictAnalyzer @@ -6473,15 +6474,19 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp typing.dataclass_transform.""" parameters = DataclassTransformSpec() for name, value in zip(call.arg_names, call.args): + # Skip any positional args. Note that any such args are invalid, but we can rely on + # typeshed to enforce this and don't need an additional error here. + if name is None: + continue + # field_specifiers is currently the only non-boolean argument; check for it first so # so the rest of the block can fail through to handling booleans if name == "field_specifiers": self.fail('"field_specifiers" support is currently unimplemented', call) continue - boolean = self.parse_bool(value) + boolean = require_bool_literal_argument(self, value, name) if boolean is None: - self.fail(f'"{name}" argument must be a True or False literal', call) continue if name == "eq_default": diff --git a/mypy/semanal_shared.py b/mypy/semanal_shared.py index dd069fbaec98..03efbe6ca1b8 100644 --- a/mypy/semanal_shared.py +++ b/mypy/semanal_shared.py @@ -3,13 +3,13 @@ from __future__ import annotations from abc import abstractmethod -from typing import Callable -from typing_extensions import Final, Protocol +from typing import Callable, overload +from typing_extensions import Final, Literal, Protocol from mypy_extensions import trait from mypy import join -from mypy.errorcodes import ErrorCode +from mypy.errorcodes import LITERAL_REQ, ErrorCode from mypy.nodes import ( CallExpr, ClassDef, @@ -26,6 +26,7 @@ SymbolTableNode, TypeInfo, ) +from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.tvar_scope import TypeVarLikeScope from mypy.type_visitor import ANY_STRATEGY, BoolTypeQuery from mypy.types import ( @@ -420,3 +421,41 @@ def find_dataclass_transform_spec(node: Node | None) -> DataclassTransformSpec | return metaclass_type.type.dataclass_transform_spec return None + + +# Never returns `None` if a default is given +@overload +def require_bool_literal_argument( + api: SemanticAnalyzerInterface | SemanticAnalyzerPluginInterface, + expression: Expression, + name: str, + default: Literal[True] | Literal[False], +) -> bool: + ... + + +@overload +def require_bool_literal_argument( + api: SemanticAnalyzerInterface | SemanticAnalyzerPluginInterface, + expression: Expression, + name: str, + default: None = None, +) -> bool | None: + ... + + +def require_bool_literal_argument( + api: SemanticAnalyzerInterface | SemanticAnalyzerPluginInterface, + expression: Expression, + name: str, + default: bool | None = None, +) -> bool | None: + """Attempt to interpret an expression as a boolean literal, and fail analysis if we can't.""" + value = api.parse_bool(expression) + if value is None: + api.fail( + f'"{name}" argument must be a True or False literal', expression, code=LITERAL_REQ + ) + return default + + return value diff --git a/mypy/semanal_typeddict.py b/mypy/semanal_typeddict.py index 55618318c1e8..acb93edb7d2d 100644 --- a/mypy/semanal_typeddict.py +++ b/mypy/semanal_typeddict.py @@ -31,7 +31,11 @@ TypeInfo, ) from mypy.options import Options -from mypy.semanal_shared import SemanticAnalyzerInterface, has_placeholder +from mypy.semanal_shared import ( + SemanticAnalyzerInterface, + has_placeholder, + require_bool_literal_argument, +) from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type from mypy.types import ( TPDICT_NAMES, @@ -320,10 +324,7 @@ def analyze_typeddict_classdef_fields( self.fail("Right hand side values are not supported in TypedDict", stmt) total: bool | None = True if "total" in defn.keywords: - total = self.api.parse_bool(defn.keywords["total"]) - if total is None: - self.fail('Value of "total" must be True or False', defn) - total = True + total = require_bool_literal_argument(self.api, defn.keywords["total"], "total", True) required_keys = { field for (field, t) in zip(fields, types) @@ -436,11 +437,9 @@ def parse_typeddict_args( ) total: bool | None = True if len(args) == 3: - total = self.api.parse_bool(call.args[2]) + total = require_bool_literal_argument(self.api, call.args[2], "total") if total is None: - return self.fail_typeddict_arg( - 'TypedDict() "total" argument must be True or False', call - ) + return "", [], [], True, [], False dictexpr = args[1] tvar_defs = self.api.get_and_bind_all_tvars([t for k, t in dictexpr.items]) res = self.parse_typeddict_fields_with_types(dictexpr.items, call) diff --git a/test-data/unit/check-attr.test b/test-data/unit/check-attr.test index f555f2ea7011..f6ef289e792e 100644 --- a/test-data/unit/check-attr.test +++ b/test-data/unit/check-attr.test @@ -151,9 +151,9 @@ class D: [case testAttrsNotBooleans] import attr x = True -@attr.s(cmp=x) # E: "cmp" argument must be True or False. +@attr.s(cmp=x) # E: "cmp" argument must be a True, False, or None literal class A: - a = attr.ib(init=x) # E: "init" argument must be True or False. + a = attr.ib(init=x) # E: "init" argument must be a True or False literal [builtins fixtures/bool.pyi] [case testAttrsInitFalse] @@ -1866,4 +1866,4 @@ reveal_type(D) # N: Revealed type is "def (a: builtins.int, b: builtins.str) -> D(1, "").a = 2 # E: Cannot assign to final attribute "a" D(1, "").b = "2" # E: Cannot assign to final attribute "b" -[builtins fixtures/property.pyi] \ No newline at end of file +[builtins fixtures/property.pyi] diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 40f3a4cde5fb..bc8fe1ecf58c 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -83,12 +83,12 @@ class BaseClass: class Metaclass(type): ... BOOL_CONSTANT = True -@my_dataclass(eq=BOOL_CONSTANT) # E: "eq" argument must be True or False. +@my_dataclass(eq=BOOL_CONSTANT) # E: "eq" argument must be a True or False literal class A: ... -@my_dataclass(order=not False) # E: "order" argument must be True or False. +@my_dataclass(order=not False) # E: "order" argument must be a True or False literal class B: ... -class C(BaseClass, eq=BOOL_CONSTANT): ... # E: "eq" argument must be True or False -class D(metaclass=Metaclass, order=not False): ... # E: "order" argument must be True or False +class C(BaseClass, eq=BOOL_CONSTANT): ... # E: "eq" argument must be a True or False literal +class D(metaclass=Metaclass, order=not False): ... # E: "order" argument must be a True or False literal [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 1f200d168a55..e3d6188b643b 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -1084,8 +1084,8 @@ reveal_type(d) \ [case testTypedDictWithInvalidTotalArgument] from mypy_extensions import TypedDict -A = TypedDict('A', {'x': int}, total=0) # E: TypedDict() "total" argument must be True or False -B = TypedDict('B', {'x': int}, total=bool) # E: TypedDict() "total" argument must be True or False +A = TypedDict('A', {'x': int}, total=0) # E: "total" argument must be a True or False literal +B = TypedDict('B', {'x': int}, total=bool) # E: "total" argument must be a True or False literal C = TypedDict('C', {'x': int}, x=False) # E: Unexpected keyword argument "x" for "TypedDict" D = TypedDict('D', {'x': int}, False) # E: Unexpected arguments to TypedDict() [builtins fixtures/dict.pyi] @@ -1179,12 +1179,12 @@ reveal_type(d) # N: Revealed type is "TypedDict('__main__.D', {'x'?: builtins.in [case testTypedDictClassWithInvalidTotalArgument] from mypy_extensions import TypedDict -class D(TypedDict, total=1): # E: Value of "total" must be True or False +class D(TypedDict, total=1): # E: "total" argument must be a True or False literal x: int -class E(TypedDict, total=bool): # E: Value of "total" must be True or False +class E(TypedDict, total=bool): # E: "total" argument must be a True or False literal x: int -class F(TypedDict, total=xyz): # E: Value of "total" must be True or False \ - # E: Name "xyz" is not defined +class F(TypedDict, total=xyz): # E: Name "xyz" is not defined \ + # E: "total" argument must be a True or False literal x: int [builtins fixtures/dict.pyi]