From 9b2cae1551c606914bc0e3846fbd07f3f41926f4 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Mon, 30 Jan 2023 14:50:24 +0000 Subject: [PATCH 01/11] [dataclass_transform] support default parameters --- mypy/nodes.py | 71 +++++++++-- mypy/plugins/dataclasses.py | 110 ++++++++++++------ mypy/semanal.py | 43 ++++++- test-data/unit/check-dataclass-transform.test | 83 +++++++++++++ test-data/unit/fixtures/dataclasses.pyi | 1 + test-data/unit/fixtures/typing-full.pyi | 9 ++ test-data/unit/fixtures/typing-medium.pyi | 9 +- test-data/unit/lib-stub/typing_extensions.pyi | 9 +- 8 files changed, 288 insertions(+), 47 deletions(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index 98976f4fe56a..bb652d01a09c 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -480,13 +480,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: return visitor.visit_import_all(self) -FUNCBASE_FLAGS: Final = [ - "is_property", - "is_class", - "is_static", - "is_final", - "is_dataclass_transform", -] +FUNCBASE_FLAGS: Final = ["is_property", "is_class", "is_static", "is_final"] class FuncBase(Node): @@ -512,7 +506,9 @@ class FuncBase(Node): "is_static", # Uses "@staticmethod" "is_final", # Uses "@final" "_fullname", - "is_dataclass_transform", # Is decorated with "@typing.dataclass_transform" or similar + # Present when a function is decorated with "@typing.dataclass_transform" or similar, and + # records the parameters passed to typing.dataclass_transform for later use + "dataclass_transform_spec", ) def __init__(self) -> None: @@ -531,7 +527,7 @@ def __init__(self) -> None: self.is_final = False # Name with module prefix self._fullname = "" - self.is_dataclass_transform = False + self.dataclass_transform_spec: DataclassTransformSpec | None = None @property @abstractmethod @@ -592,6 +588,11 @@ def serialize(self) -> JsonDict: "fullname": self._fullname, "impl": None if self.impl is None else self.impl.serialize(), "flags": get_flags(self, FUNCBASE_FLAGS), + "dataclass_transform_spec": ( + None + if self.dataclass_transform_spec is None + else self.dataclass_transform_spec.serialize() + ), } @classmethod @@ -610,6 +611,11 @@ def deserialize(cls, data: JsonDict) -> OverloadedFuncDef: assert isinstance(typ, mypy.types.ProperType) res.type = typ res._fullname = data["fullname"] + res.dataclass_transform_spec = ( + DataclassTransformSpec.deserialize(data["dataclass_transform_spec"]) + if data["dataclass_transform_spec"] is not None + else None + ) set_flags(res, data["flags"]) # NOTE: res.info will be set in the fixup phase. return res @@ -810,6 +816,11 @@ def serialize(self) -> JsonDict: "flags": get_flags(self, FUNCDEF_FLAGS), "abstract_status": self.abstract_status, # TODO: Do we need expanded, original_def? + "dataclass_transform_spec": ( + None + if self.dataclass_transform_spec is None + else self.dataclass_transform_spec.serialize() + ), } @classmethod @@ -832,6 +843,11 @@ def deserialize(cls, data: JsonDict) -> FuncDef: ret.arg_names = data["arg_names"] ret.arg_kinds = [ArgKind(x) for x in data["arg_kinds"]] ret.abstract_status = data["abstract_status"] + ret.dataclass_transform_spec = ( + DataclassTransformSpec.deserialize(data["dataclass_transform_spec"]) + if data["dataclass_transform_spec"] is not None + else None + ) # Leave these uninitialized so that future uses will trigger an error del ret.arguments del ret.max_pos @@ -3851,6 +3867,43 @@ def deserialize(cls, data: JsonDict) -> SymbolTable: return st +class DataclassTransformSpec: + """Specifies how a dataclass-like transform should be applied. The fields here are based on the + parameters accepted by `typing.dataclass_transform`.""" + + __slots__ = ("eq_default", "order_default", "kw_only_default", "field_specifiers") + + def __init__( + self, + *, + eq_default: bool | None = None, + order_default: bool | None = None, + kw_only_default: bool | None = None, + field_specifiers: tuple[str, ...] | None = None, + ): + self.eq_default = eq_default if eq_default is not None else True + self.order_default = order_default if order_default is not None else False + self.kw_only_default = kw_only_default if kw_only_default is not None else False + self.field_specifiers = field_specifiers if field_specifiers is not None else () + + def serialize(self) -> JsonDict: + return { + "eq_default": self.eq_default, + "order_default": self.order_default, + "kw_only_default": self.kw_only_default, + "field_specifiers": self.field_specifiers, + } + + @classmethod + def deserialize(cls, data: JsonDict) -> DataclassTransformSpec: + return DataclassTransformSpec( + eq_default=data.get("eq_default"), + order_default=data.get("order_default"), + kw_only_default=data.get("kw_only_default"), + field_specifiers=data.get("field_specifiers"), + ) + + def get_flags(node: Node, names: list[str]) -> list[str]: return [name for name in names if getattr(node, name)] diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 75496d5e56f9..4602c046237e 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -14,13 +14,17 @@ ARG_STAR, ARG_STAR2, MDEF, + SYMBOL_FUNCBASE_TYPES, Argument, AssignmentStmt, CallExpr, Context, + DataclassTransformSpec, + Decorator, Expression, JsonDict, NameExpr, + Node, PlaceholderNode, RefExpr, SymbolTableNode, @@ -56,11 +60,15 @@ # The set of decorators that generate dataclasses. dataclass_makers: Final = {"dataclass", "dataclasses.dataclass"} -# The set of functions that generate dataclass fields. -field_makers: Final = {"dataclasses.field"} SELF_TVAR_NAME: Final = "_DT" +_TRANSFORM_SPEC_FOR_DATACLASSES = DataclassTransformSpec( + eq_default=True, + order_default=False, + kw_only_default=False, + field_specifiers=("dataclasses.Field", "dataclasses.field"), +) class DataclassAttribute: @@ -155,6 +163,7 @@ class DataclassTransformer: def __init__(self, ctx: ClassDefContext) -> None: self._ctx = ctx + self._spec = _get_transform_spec(ctx.reason) def transform(self) -> bool: """Apply all the necessary transformations to the underlying @@ -172,8 +181,8 @@ def transform(self) -> bool: return False decorator_arguments = { "init": _get_decorator_bool_argument(self._ctx, "init", True), - "eq": _get_decorator_bool_argument(self._ctx, "eq", True), - "order": _get_decorator_bool_argument(self._ctx, "order", False), + "eq": _get_decorator_bool_argument(self._ctx, "eq", self._spec.eq_default), + "order": _get_decorator_bool_argument(self._ctx, "order", self._spec.order_default), "frozen": _get_decorator_bool_argument(self._ctx, "frozen", False), "slots": _get_decorator_bool_argument(self._ctx, "slots", False), "match_args": _get_decorator_bool_argument(self._ctx, "match_args", True), @@ -411,7 +420,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: # Second, collect attributes belonging to the current class. current_attr_names: set[str] = set() - kw_only = _get_decorator_bool_argument(ctx, "kw_only", False) + kw_only = _get_decorator_bool_argument(ctx, "kw_only", self._spec.kw_only_default) for stmt in cls.defs.body: # Any assignment that doesn't use the new type declaration # syntax can be ignored out of hand. @@ -461,7 +470,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: if self._is_kw_only_type(node_type): kw_only = True - has_field_call, field_args = _collect_field_args(stmt.rvalue, ctx) + has_field_call, field_args = self._collect_field_args(stmt.rvalue, ctx) is_in_init_param = field_args.get("init") if is_in_init_param is None: @@ -614,6 +623,36 @@ def _add_dataclass_fields_magic_attribute(self) -> None: kind=MDEF, node=var, plugin_generated=True ) + def _collect_field_args( + self, expr: Expression, ctx: ClassDefContext + ) -> tuple[bool, dict[str, Expression]]: + """Returns a tuple where the first value represents whether or not + the expression is a call to dataclass.field and the second is a + dictionary of the keyword arguments that field() was called with. + """ + if ( + isinstance(expr, CallExpr) + and isinstance(expr.callee, RefExpr) + and expr.callee.fullname in self._spec.field_specifiers + ): + # field() only takes keyword arguments. + args = {} + for name, arg, kind in zip(expr.arg_names, expr.args, expr.arg_kinds): + if not kind.is_named(): + if kind.is_named(star=True): + # This means that `field` is used with `**` unpacking, + # the best we can do for now is not to fail. + # TODO: we can infer what's inside `**` and try to collect it. + message = 'Unpacking **kwargs in "field()" is not supported' + else: + message = '"field()" does not accept positional arguments' + ctx.api.fail(message, expr) + return True, {} + assert name is not None + args[name] = arg + return True, args + return False, {} + def dataclass_tag_callback(ctx: ClassDefContext) -> None: """Record that we have a dataclass in the main semantic analysis pass. @@ -631,32 +670,35 @@ def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool: return transformer.transform() -def _collect_field_args( - expr: Expression, ctx: ClassDefContext -) -> tuple[bool, dict[str, Expression]]: - """Returns a tuple where the first value represents whether or not - the expression is a call to dataclass.field and the second is a - dictionary of the keyword arguments that field() was called with. +def _get_transform_spec(reason: Expression) -> DataclassTransformSpec: + """Find the relevant transform parameters from the decorator/parent class/metaclass that + triggered the dataclasses plugin. + + Although the resulting DataclassTransformSpec is based on the typing.dataclass_transform + function, we also use it for traditional dataclasses.dataclass classes as well for simplicity. + In those cases, we return a default spec rather than one based on a call to + `typing.dataclass_transform`. """ - if ( - isinstance(expr, CallExpr) - and isinstance(expr.callee, RefExpr) - and expr.callee.fullname in field_makers - ): - # field() only takes keyword arguments. - args = {} - for name, arg, kind in zip(expr.arg_names, expr.args, expr.arg_kinds): - if not kind.is_named(): - if kind.is_named(star=True): - # This means that `field` is used with `**` unpacking, - # the best we can do for now is not to fail. - # TODO: we can infer what's inside `**` and try to collect it. - message = 'Unpacking **kwargs in "field()" is not supported' - else: - message = '"field()" does not accept positional arguments' - ctx.api.fail(message, expr) - return True, {} - assert name is not None - args[name] = arg - return True, args - return False, {} + node: Node | None = reason + + # The spec only lives on the function/class definition itself, so we need to unwrap down to that + # point + if isinstance(node, CallExpr): + # Decorators may take the form of either @decorator or @decorator(...); unwrap the latter + node = node.callee + if isinstance(node, RefExpr): + # If we see dataclasses.dataclass here, we know that we're not going to find a transform + # spec, so return early. + if node.fullname in dataclass_makers: + return _TRANSFORM_SPEC_FOR_DATACLASSES + node = node.node + if isinstance(node, Decorator): + node = node.func + + if isinstance(node, SYMBOL_FUNCBASE_TYPES) and node.dataclass_transform_spec is not None: + return node.dataclass_transform_spec + + raise AssertionError( + "trying to find dataclass transform spec, but reason is neither dataclasses.dataclass nor " + "decorated with typing.dataclass_transform" + ) diff --git a/mypy/semanal.py b/mypy/semanal.py index 6a483edd7c72..16109c319e33 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -99,6 +99,7 @@ ConditionalExpr, Context, ContinueStmt, + DataclassTransformSpec, Decorator, DelStmt, DictExpr, @@ -307,6 +308,13 @@ # available very early on. CORE_BUILTIN_CLASSES: Final = ["object", "bool", "function"] +KNOWN_DATACLASS_TRANSFORM_PARAMETERS: Final = ( + "eq_default", + "order_default", + "kw_only_default", + "field_specifiers", +) + # Used for tracking incomplete references Tag: _TypeAlias = int @@ -1524,7 +1532,7 @@ def visit_decorator(self, dec: Decorator) -> None: elif isinstance(d, CallExpr) and refers_to_fullname( d.callee, DATACLASS_TRANSFORM_NAMES ): - dec.func.is_dataclass_transform = True + dec.func.dataclass_transform_spec = self.parse_dataclass_transform_spec(d) elif not dec.var.is_property: # We have seen a "non-trivial" decorator before seeing @property, if # we will see a @property later, give an error, as we don't support this. @@ -6452,6 +6460,34 @@ def set_future_import_flags(self, module_name: str) -> None: def is_future_flag_set(self, flag: str) -> bool: return self.modules[self.cur_mod_id].is_future_flag_set(flag) + def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSpec: + """Build a DataclassTransformSpec from the arguments passed to the given call to + typing.dataclass_transform.""" + parameters = DataclassTransformSpec() + for name, value in zip(call.arg_names, call.args): + if name not in KNOWN_DATACLASS_TRANSFORM_PARAMETERS: + self.fail(f"unrecognized dataclass_transform parameter '{name}'", call) + + # field_specifiers is currently the only non-boolean argument; check for it first so + # so the rest of the block can fail through to handling booleans + if name == "field_specifiers": + self.fail("field_specifiers support is currently unimplemented", call) + continue + + boolean = self.parse_bool(value) + if boolean is None: + self.fail(f"{name} argument must be True or False.", call) + continue + + if name == "eq_default": + parameters.eq_default = boolean + elif name == "order_default": + parameters.order_default = boolean + elif name == "kw_only_default": + parameters.kw_only_default = boolean + + return parameters + def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike: if isinstance(sig, CallableType): @@ -6658,4 +6694,7 @@ def is_dataclass_transform_decorator(node: Node | None) -> bool: # We need to unwrap the call for the second variant. return is_dataclass_transform_decorator(node.callee) - return isinstance(node, Decorator) and node.func.is_dataclass_transform + if isinstance(node, Decorator) and node.func.dataclass_transform_spec is not None: + return True + + return False diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 1a25c087c5a6..26401ba90833 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -85,3 +85,86 @@ class B: ... [typing fixtures/typing-medium.pyi] [builtins fixtures/dataclasses.pyi] + +[case testDataclassTransformDefaultParamsMustBeLiterals] +# flags: --python-version 3.7 +from typing import dataclass_transform, Type + +BOOLEAN_CONSTANT = True + +@dataclass_transform(eq_default=BOOLEAN_CONSTANT) # E: eq_default argument must be True or False. +def foo(cls: Type) -> Type: + return cls +@dataclass_transform(eq_default=(not True)) # E: eq_default argument must be True or False. +def bar(cls: Type) -> Type: + return cls + +[typing fixtures/typing-medium.pyi] +[builtins fixtures/dataclasses.pyi] + +[case testDataclassTransformUnrecognizedParamsAreErrors] +# flags: --python-version 3.7 +from typing import dataclass_transform, Type + +BOOLEAN_CONSTANT = True + +@dataclass_transform(nonexistant=True) # E: unrecognized dataclass_transform parameter 'nonexistant' +def foo(cls: Type) -> Type: + return cls + +[typing fixtures/typing-medium.pyi] +[builtins fixtures/dataclasses.pyi] + + +[case testDataclassTransformDefaultParams] +# flags: --python-version 3.7 +from typing import dataclass_transform, Type, Callable + +@dataclass_transform(eq_default=False) +def no_eq(*, order: bool = False) -> Callable[[Type], Type]: + return lambda cls: cls +@no_eq() +class Foo: ... +@no_eq(order=True) +class Bar: ... # E: eq must be True if order is True + + +@dataclass_transform(kw_only_default=True) +def always_use_kw(cls: Type) -> Type: + return cls +@always_use_kw +class Baz: + x: int +Baz(x=5) +Baz(5) # E: Too many positional arguments for "Baz" + +@dataclass_transform(order_default=True) +def ordered(*, eq: bool = True) -> Callable[[Type], Type]: + return lambda cls: cls +@ordered() +class A: + x: int +A(1) > A(2) + +[typing fixtures/typing-medium.pyi] +[builtins fixtures/dataclasses.pyi] + +[case testDataclassTransformFieldSpecifiersDefaultsToEmpty] +# flags: --python-version 3.7 +from dataclasses import field, dataclass +from typing import dataclass_transform, Type + +@dataclass_transform() +def my_dataclass(cls: Type) -> Type: + return cls + +@my_dataclass +class Foo: + foo: int = field(kw_only=True) + +# Does not cause a type error because `dataclasses.field` is not a recognized field specifier by +# default +Foo(5) + +[typing fixtures/typing-medium.pyi] +[builtins fixtures/dataclasses.pyi] diff --git a/test-data/unit/fixtures/dataclasses.pyi b/test-data/unit/fixtures/dataclasses.pyi index 7de40af9cfe7..ab692302a8b6 100644 --- a/test-data/unit/fixtures/dataclasses.pyi +++ b/test-data/unit/fixtures/dataclasses.pyi @@ -18,6 +18,7 @@ class ellipsis: pass class tuple(Generic[_T]): pass class int: pass class float: pass +class bytes: pass class str: pass class bool(int): pass diff --git a/test-data/unit/fixtures/typing-full.pyi b/test-data/unit/fixtures/typing-full.pyi index 04568f7c03f3..1471473249dc 100644 --- a/test-data/unit/fixtures/typing-full.pyi +++ b/test-data/unit/fixtures/typing-full.pyi @@ -181,3 +181,12 @@ class _TypedDict(Mapping[str, object]): def __delitem__(self, k: NoReturn) -> None: ... class _SpecialForm: pass + +def dataclass_transform( + *, + eq_default: bool = ..., + order_default: bool = ..., + kw_only_default: bool = ..., + field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = ..., + **kwargs: Any, +) -> Callable[[T], T]: ... diff --git a/test-data/unit/fixtures/typing-medium.pyi b/test-data/unit/fixtures/typing-medium.pyi index 0d0e13468013..ab3be92d3d9b 100644 --- a/test-data/unit/fixtures/typing-medium.pyi +++ b/test-data/unit/fixtures/typing-medium.pyi @@ -72,4 +72,11 @@ class _SpecialForm: pass TYPE_CHECKING = 1 -def dataclass_transform() -> Callable[[T], T]: ... +def dataclass_transform( + *, + eq_default: bool = ..., + order_default: bool = ..., + kw_only_default: bool = ..., + #field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = ..., + **kwargs: Any, +) -> Callable[[T], T]: ... diff --git a/test-data/unit/lib-stub/typing_extensions.pyi b/test-data/unit/lib-stub/typing_extensions.pyi index 89f7108fe83c..97d02519055d 100644 --- a/test-data/unit/lib-stub/typing_extensions.pyi +++ b/test-data/unit/lib-stub/typing_extensions.pyi @@ -58,4 +58,11 @@ def TypedDict(typename: str, fields: Dict[str, Type[_T]], *, total: Any = ...) - def reveal_type(__obj: T) -> T: pass -def dataclass_transform() -> Callable[[T], T]: ... +def dataclass_transform( + *, + eq_default: bool = ..., + order_default: bool = ..., + kw_only_default: bool = ..., + field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = ..., + **kwargs: Any, +) -> Callable[[T], T]: ... From 85ffb184831284de8cdaa32de78d7ae3544f5cb9 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Wed, 1 Feb 2023 21:17:08 +0000 Subject: [PATCH 02/11] cleanup type stubs a little --- test-data/unit/check-dataclass-transform.test | 14 +++++++------- test-data/unit/fixtures/typing-medium.pyi | 9 --------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 26401ba90833..45703e50eb7b 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -18,7 +18,7 @@ reveal_type(Person) # N: Revealed type is "def (name: builtins.str, age: builti Person('John', 32) Person('Jonh', 21, None) # E: Too many arguments for "Person" -[typing fixtures/typing-medium.pyi] +[typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] [case testDataclassTransformIsFoundInTypingExtensions] @@ -64,7 +64,7 @@ reveal_type(Person) # N: Revealed type is "def (name: builtins.str, age: builti Person('John', 32) Person('John', 21, None) # E: Too many arguments for "Person" -[typing fixtures/typing-medium.pyi] +[typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] [case testDataclassTransformParametersMustBeBoolLiterals] @@ -83,7 +83,7 @@ class A: ... @my_dataclass(order=not False) # E: "order" argument must be True or False. class B: ... -[typing fixtures/typing-medium.pyi] +[typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] [case testDataclassTransformDefaultParamsMustBeLiterals] @@ -99,7 +99,7 @@ def foo(cls: Type) -> Type: def bar(cls: Type) -> Type: return cls -[typing fixtures/typing-medium.pyi] +[typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] [case testDataclassTransformUnrecognizedParamsAreErrors] @@ -112,7 +112,7 @@ BOOLEAN_CONSTANT = True def foo(cls: Type) -> Type: return cls -[typing fixtures/typing-medium.pyi] +[typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] @@ -146,7 +146,7 @@ class A: x: int A(1) > A(2) -[typing fixtures/typing-medium.pyi] +[typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] [case testDataclassTransformFieldSpecifiersDefaultsToEmpty] @@ -166,5 +166,5 @@ class Foo: # default Foo(5) -[typing fixtures/typing-medium.pyi] +[typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] diff --git a/test-data/unit/fixtures/typing-medium.pyi b/test-data/unit/fixtures/typing-medium.pyi index ab3be92d3d9b..863b0703989d 100644 --- a/test-data/unit/fixtures/typing-medium.pyi +++ b/test-data/unit/fixtures/typing-medium.pyi @@ -71,12 +71,3 @@ class ContextManager(Generic[T]): class _SpecialForm: pass TYPE_CHECKING = 1 - -def dataclass_transform( - *, - eq_default: bool = ..., - order_default: bool = ..., - kw_only_default: bool = ..., - #field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = ..., - **kwargs: Any, -) -> Callable[[T], T]: ... From 12c8c74b81389f48cd5f6cff429614ff159e30ec Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Mon, 6 Feb 2023 20:17:05 +0000 Subject: [PATCH 03/11] consolidate functions for finding dataclass transform specs --- mypy/plugins/dataclasses.py | 35 +++++++++++++------------------ mypy/semanal.py | 24 ++------------------- mypy/semanal_main.py | 4 ++-- mypy/semanal_shared.py | 42 +++++++++++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 45 deletions(-) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 4602c046237e..686f47d46d66 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -14,13 +14,11 @@ ARG_STAR, ARG_STAR2, MDEF, - SYMBOL_FUNCBASE_TYPES, Argument, AssignmentStmt, CallExpr, Context, DataclassTransformSpec, - Decorator, Expression, JsonDict, NameExpr, @@ -41,6 +39,7 @@ add_method, deserialize_and_fixup_type, ) +from mypy.semanal_shared import find_dataclass_transform_spec from mypy.server.trigger import make_wildcard_trigger from mypy.state import state from mypy.typeops import map_type_from_supertype @@ -679,26 +678,20 @@ def _get_transform_spec(reason: Expression) -> DataclassTransformSpec: In those cases, we return a default spec rather than one based on a call to `typing.dataclass_transform`. """ - node: Node | None = reason + if _is_dataclasses_decorator(reason): + return _TRANSFORM_SPEC_FOR_DATACLASSES - # The spec only lives on the function/class definition itself, so we need to unwrap down to that - # point - if isinstance(node, CallExpr): - # Decorators may take the form of either @decorator or @decorator(...); unwrap the latter - node = node.callee - if isinstance(node, RefExpr): - # If we see dataclasses.dataclass here, we know that we're not going to find a transform - # spec, so return early. - if node.fullname in dataclass_makers: - return _TRANSFORM_SPEC_FOR_DATACLASSES - node = node.node - if isinstance(node, Decorator): - node = node.func - - if isinstance(node, SYMBOL_FUNCBASE_TYPES) and node.dataclass_transform_spec is not None: - return node.dataclass_transform_spec - - raise AssertionError( + spec = find_dataclass_transform_spec(reason) + assert spec is not None, ( "trying to find dataclass transform spec, but reason is neither dataclasses.dataclass nor " "decorated with typing.dataclass_transform" ) + return spec + + +def _is_dataclasses_decorator(node: Node) -> bool: + if isinstance(node, CallExpr): + node = node.callee + if isinstance(node, RefExpr): + return node.fullname in dataclass_makers + return False diff --git a/mypy/semanal.py b/mypy/semanal.py index 67c404b7984c..51a22eaa2867 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -214,6 +214,7 @@ PRIORITY_FALLBACKS, SemanticAnalyzerInterface, calculate_tuple_fallback, + find_dataclass_transform_spec, has_placeholder, set_callable_name as set_callable_name, ) @@ -1737,7 +1738,7 @@ def apply_class_plugin_hooks(self, defn: ClassDef) -> None: # Special case: if the decorator is itself decorated with # typing.dataclass_transform, apply the hook for the dataclasses plugin # TODO: remove special casing here - if hook is None and is_dataclass_transform_decorator(decorator): + if hook is None and find_dataclass_transform_spec(decorator): hook = dataclasses_plugin.dataclass_tag_callback if hook: hook(ClassDefContext(defn, decorator, self)) @@ -6687,24 +6688,3 @@ def halt(self, reason: str = ...) -> NoReturn: return isinstance(stmt, PassStmt) or ( isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, EllipsisExpr) ) - - -def is_dataclass_transform_decorator(node: Node | None) -> bool: - if isinstance(node, RefExpr): - return is_dataclass_transform_decorator(node.node) - if isinstance(node, CallExpr): - # Like dataclasses.dataclass, transform-based decorators can be applied either with or - # without parameters; ie, both of these forms are accepted: - # - # @typing.dataclass_transform - # class Foo: ... - # @typing.dataclass_transform(eq=True, order=True, ...) - # class Bar: ... - # - # We need to unwrap the call for the second variant. - return is_dataclass_transform_decorator(node.callee) - - if isinstance(node, Decorator) and node.func.dataclass_transform_spec is not None: - return True - - return False diff --git a/mypy/semanal_main.py b/mypy/semanal_main.py index d2dd0e32398d..796a862c35e7 100644 --- a/mypy/semanal_main.py +++ b/mypy/semanal_main.py @@ -41,7 +41,6 @@ from mypy.semanal import ( SemanticAnalyzer, apply_semantic_analyzer_patches, - is_dataclass_transform_decorator, remove_imported_names_from_symtable, ) from mypy.semanal_classprop import ( @@ -51,6 +50,7 @@ check_protocol_status, ) from mypy.semanal_infer import infer_decorator_signature_if_simple +from mypy.semanal_shared import find_dataclass_transform_spec from mypy.semanal_typeargs import TypeArgumentAnalyzer from mypy.server.aststrip import SavedAttributes from mypy.util import is_typeshed_file @@ -467,7 +467,7 @@ def apply_hooks_to_class( # Special case: if the decorator is itself decorated with # typing.dataclass_transform, apply the hook for the dataclasses plugin # TODO: remove special casing here - if hook is None and is_dataclass_transform_decorator(decorator): + if hook is None and find_dataclass_transform_spec(decorator): hook = dataclasses_plugin.dataclass_class_maker_callback if hook: diff --git a/mypy/semanal_shared.py b/mypy/semanal_shared.py index 11c4af314a3b..e28fb4052947 100644 --- a/mypy/semanal_shared.py +++ b/mypy/semanal_shared.py @@ -11,10 +11,15 @@ from mypy import join from mypy.errorcodes import ErrorCode from mypy.nodes import ( + SYMBOL_FUNCBASE_TYPES, + CallExpr, Context, + DataclassTransformSpec, + Decorator, Expression, FuncDef, Node, + RefExpr, SymbolNode, SymbolTable, SymbolTableNode, @@ -341,3 +346,40 @@ def visit_placeholder_type(self, t: PlaceholderType) -> bool: def has_placeholder(typ: Type) -> bool: """Check if a type contains any placeholder types (recursively).""" return typ.accept(HasPlaceholders()) + + +def find_dataclass_transform_spec(node: Node | None) -> DataclassTransformSpec | None: + """ + Find the dataclass transform spec for the given node, if any exists. + + Per PEP 681 (https://peps.python.org/pep-0681/#the-dataclass-transform-decorator), dataclass + transforms can be specified in multiple ways, including decorator functions and + metaclasses/base classes. This function resolves the spec from any of these variants. + """ + + # The spec only lives on the function/class definition itself, so we need to unwrap down to that + # point + if isinstance(node, CallExpr): + # Like dataclasses.dataclass, transform-based decorators can be applied either with or + # without parameters; ie, both of these forms are accepted: + # + # @typing.dataclass_transform + # class Foo: ... + # @typing.dataclass_transform(eq=True, order=True, ...) + # class Bar: ... + # + # We need to unwrap the call for the second variant. + node = node.callee + + if isinstance(node, RefExpr): + node = node.node + + if isinstance(node, Decorator): + # typing.dataclass_transform usage must always result in a Decorator; it always uses the + # `@dataclass_transform(...)` syntax and never `@dataclass_transform` + node = node.func + + if isinstance(node, SYMBOL_FUNCBASE_TYPES): + return node.dataclass_transform_spec + + return None From 9faa37998e91bb8b272b18f7e1f87be4e5780362 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Tue, 7 Feb 2023 15:12:06 +0000 Subject: [PATCH 04/11] add frozen_default, adjust fall-through for parameter parsing --- mypy/nodes.py | 15 ++++++++++++++- mypy/plugins/dataclasses.py | 3 ++- mypy/semanal.py | 16 +++++----------- test-data/unit/check-dataclass-transform.test | 9 +++++++++ 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index bb652d01a09c..161dfd831fb7 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -3871,7 +3871,13 @@ class DataclassTransformSpec: """Specifies how a dataclass-like transform should be applied. The fields here are based on the parameters accepted by `typing.dataclass_transform`.""" - __slots__ = ("eq_default", "order_default", "kw_only_default", "field_specifiers") + __slots__ = ( + "eq_default", + "order_default", + "kw_only_default", + "frozen_default", + "field_specifiers", + ) def __init__( self, @@ -3880,10 +3886,15 @@ def __init__( order_default: bool | None = None, kw_only_default: bool | None = None, field_specifiers: tuple[str, ...] | None = None, + # Specified outside of PEP 681: + # frozen_default was added to CPythonin https://github.com/python/cpython/pull/99958 citing + # positive discussion in typing-sig + frozen_default: bool | None = None, ): self.eq_default = eq_default if eq_default is not None else True self.order_default = order_default if order_default is not None else False self.kw_only_default = kw_only_default if kw_only_default is not None else False + self.frozen_default = frozen_default if frozen_default is not None else False self.field_specifiers = field_specifiers if field_specifiers is not None else () def serialize(self) -> JsonDict: @@ -3891,6 +3902,7 @@ def serialize(self) -> JsonDict: "eq_default": self.eq_default, "order_default": self.order_default, "kw_only_default": self.kw_only_default, + "frozen_only_default": self.frozen_default, "field_specifiers": self.field_specifiers, } @@ -3900,6 +3912,7 @@ def deserialize(cls, data: JsonDict) -> DataclassTransformSpec: eq_default=data.get("eq_default"), order_default=data.get("order_default"), kw_only_default=data.get("kw_only_default"), + frozen_default=data.get("frozen_default"), field_specifiers=data.get("field_specifiers"), ) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 686f47d46d66..988ea6038ffb 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -66,6 +66,7 @@ eq_default=True, order_default=False, kw_only_default=False, + frozen_default=False, field_specifiers=("dataclasses.Field", "dataclasses.field"), ) @@ -182,7 +183,7 @@ def transform(self) -> bool: "init": _get_decorator_bool_argument(self._ctx, "init", True), "eq": _get_decorator_bool_argument(self._ctx, "eq", self._spec.eq_default), "order": _get_decorator_bool_argument(self._ctx, "order", self._spec.order_default), - "frozen": _get_decorator_bool_argument(self._ctx, "frozen", False), + "frozen": _get_decorator_bool_argument(self._ctx, "frozen", self._spec.frozen_default), "slots": _get_decorator_bool_argument(self._ctx, "slots", False), "match_args": _get_decorator_bool_argument(self._ctx, "match_args", True), } diff --git a/mypy/semanal.py b/mypy/semanal.py index 51a22eaa2867..956e76da969f 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -309,13 +309,6 @@ # available very early on. CORE_BUILTIN_CLASSES: Final = ["object", "bool", "function"] -KNOWN_DATACLASS_TRANSFORM_PARAMETERS: Final = ( - "eq_default", - "order_default", - "kw_only_default", - "field_specifiers", -) - # Used for tracking incomplete references Tag: _TypeAlias = int @@ -6476,9 +6469,6 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp typing.dataclass_transform.""" parameters = DataclassTransformSpec() for name, value in zip(call.arg_names, call.args): - if name not in KNOWN_DATACLASS_TRANSFORM_PARAMETERS: - self.fail(f"unrecognized dataclass_transform parameter '{name}'", call) - # field_specifiers is currently the only non-boolean argument; check for it first so # so the rest of the block can fail through to handling booleans if name == "field_specifiers": @@ -6487,7 +6477,7 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp boolean = self.parse_bool(value) if boolean is None: - self.fail(f"{name} argument must be True or False.", call) + self.fail(f"{name} argument must be True or False", call) continue if name == "eq_default": @@ -6496,6 +6486,10 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp parameters.order_default = boolean elif name == "kw_only_default": parameters.kw_only_default = boolean + elif name == "frozen_default": + parameters.frozen_default = boolean + else: + self.fail(f"unrecognized dataclass_transform parameter '{name}'", call) return parameters diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 45703e50eb7b..956382c8d516 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -146,6 +146,15 @@ class A: x: int A(1) > A(2) +@dataclass_transform(frozen_default=True) +def frozen(cls: Type) -> Type: + return cls +@frozen +class B: + x: int +b = B(x=1) +b.x = 2 # E: Property "x" defined in "B" is read-only + [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] From 4d7491351bc76f8e204d332216879f1aef273696 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Tue, 7 Feb 2023 15:38:27 +0000 Subject: [PATCH 05/11] fix error strings in test case (whoops) --- test-data/unit/check-dataclass-transform.test | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 956382c8d516..be496b286db1 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -92,10 +92,10 @@ from typing import dataclass_transform, Type BOOLEAN_CONSTANT = True -@dataclass_transform(eq_default=BOOLEAN_CONSTANT) # E: eq_default argument must be True or False. +@dataclass_transform(eq_default=BOOLEAN_CONSTANT) # E: eq_default argument must be True or False def foo(cls: Type) -> Type: return cls -@dataclass_transform(eq_default=(not True)) # E: eq_default argument must be True or False. +@dataclass_transform(eq_default=(not True)) # E: eq_default argument must be True or False def bar(cls: Type) -> Type: return cls From bc620dd8c0ea2b3dc5508eeb3f86f95af5949ede Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Tue, 7 Feb 2023 16:53:43 +0000 Subject: [PATCH 06/11] change most 3.7 => 3.11 for test cases --- test-data/unit/check-dataclass-transform.test | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index be496b286db1..4c6b048a029a 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -1,5 +1,5 @@ [case testDataclassTransformReusesDataclassLogic] -# flags: --python-version 3.7 +# flags: --python-version 3.11 from typing import dataclass_transform, Type @dataclass_transform() @@ -46,7 +46,7 @@ Person('Jonh', 21, None) # E: Too many arguments for "Person" [builtins fixtures/dataclasses.pyi] [case testDataclassTransformParametersAreApplied] -# flags: --python-version 3.7 +# flags: --python-version 3.11 from typing import dataclass_transform, Callable, Type @dataclass_transform() @@ -68,7 +68,7 @@ Person('John', 21, None) # E: Too many arguments for "Person" [builtins fixtures/dataclasses.pyi] [case testDataclassTransformParametersMustBeBoolLiterals] -# flags: --python-version 3.7 +# flags: --python-version 3.11 from typing import dataclass_transform, Callable, Type @dataclass_transform() @@ -87,7 +87,7 @@ class B: ... [builtins fixtures/dataclasses.pyi] [case testDataclassTransformDefaultParamsMustBeLiterals] -# flags: --python-version 3.7 +# flags: --python-version 3.11 from typing import dataclass_transform, Type BOOLEAN_CONSTANT = True @@ -103,7 +103,7 @@ def bar(cls: Type) -> Type: [builtins fixtures/dataclasses.pyi] [case testDataclassTransformUnrecognizedParamsAreErrors] -# flags: --python-version 3.7 +# flags: --python-version 3.11 from typing import dataclass_transform, Type BOOLEAN_CONSTANT = True @@ -117,7 +117,7 @@ def foo(cls: Type) -> Type: [case testDataclassTransformDefaultParams] -# flags: --python-version 3.7 +# flags: --python-version 3.11 from typing import dataclass_transform, Type, Callable @dataclass_transform(eq_default=False) @@ -159,7 +159,7 @@ b.x = 2 # E: Property "x" defined in "B" is read-only [builtins fixtures/dataclasses.pyi] [case testDataclassTransformFieldSpecifiersDefaultsToEmpty] -# flags: --python-version 3.7 +# flags: --python-version 3.11 from dataclasses import field, dataclass from typing import dataclass_transform, Type From ad359775c64211e7cb78e2b91599fad5f0ce7ed0 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Tue, 7 Feb 2023 16:54:35 +0000 Subject: [PATCH 07/11] add final boolean test case --- test-data/unit/check-dataclass-transform.test | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 4c6b048a029a..d9e19aaa1db4 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -88,9 +88,10 @@ class B: ... [case testDataclassTransformDefaultParamsMustBeLiterals] # flags: --python-version 3.11 -from typing import dataclass_transform, Type +from typing import dataclass_transform, Type, Final BOOLEAN_CONSTANT = True +FINAL_BOOLEAN: Final = True @dataclass_transform(eq_default=BOOLEAN_CONSTANT) # E: eq_default argument must be True or False def foo(cls: Type) -> Type: @@ -98,6 +99,9 @@ def foo(cls: Type) -> Type: @dataclass_transform(eq_default=(not True)) # E: eq_default argument must be True or False def bar(cls: Type) -> Type: return cls +@dataclass_transform(eq_default=FINAL_BOOLEAN) # E: eq_default argument must be True or False +def baz(cls: Type) -> Type: + return cls [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] From f8662601588a16985af93185d34a20115bf6c3ed Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Tue, 7 Feb 2023 16:59:08 +0000 Subject: [PATCH 08/11] fix error message formatting --- mypy/semanal.py | 6 +++--- test-data/unit/check-dataclass-transform.test | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index 956e76da969f..f2875cf66d15 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -6472,12 +6472,12 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp # field_specifiers is currently the only non-boolean argument; check for it first so # so the rest of the block can fail through to handling booleans if name == "field_specifiers": - self.fail("field_specifiers support is currently unimplemented", call) + self.fail('"field_specifiers" support is currently unimplemented', call) continue boolean = self.parse_bool(value) if boolean is None: - self.fail(f"{name} argument must be True or False", call) + self.fail(f'"{name}" argument must be a True or False literal', call) continue if name == "eq_default": @@ -6489,7 +6489,7 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp elif name == "frozen_default": parameters.frozen_default = boolean else: - self.fail(f"unrecognized dataclass_transform parameter '{name}'", call) + self.fail(f'Unrecognized dataclass_transform parameter "{name}"', call) return parameters diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index d9e19aaa1db4..3c68af925f6a 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -93,13 +93,13 @@ from typing import dataclass_transform, Type, Final BOOLEAN_CONSTANT = True FINAL_BOOLEAN: Final = True -@dataclass_transform(eq_default=BOOLEAN_CONSTANT) # E: eq_default argument must be True or False +@dataclass_transform(eq_default=BOOLEAN_CONSTANT) # E: "eq_default" argument must be a True or False literal def foo(cls: Type) -> Type: return cls -@dataclass_transform(eq_default=(not True)) # E: eq_default argument must be True or False +@dataclass_transform(eq_default=(not True)) # E: "eq_default" argument must be a True or False literal def bar(cls: Type) -> Type: return cls -@dataclass_transform(eq_default=FINAL_BOOLEAN) # E: eq_default argument must be True or False +@dataclass_transform(eq_default=FINAL_BOOLEAN) # E: "eq_default" argument must be a True or False literal def baz(cls: Type) -> Type: return cls @@ -112,7 +112,7 @@ from typing import dataclass_transform, Type BOOLEAN_CONSTANT = True -@dataclass_transform(nonexistant=True) # E: unrecognized dataclass_transform parameter 'nonexistant' +@dataclass_transform(nonexistant=True) # E: Unrecognized dataclass_transform parameter "nonexistant" def foo(cls: Type) -> Type: return cls From 91772caeb6a3a627e55f877df03bff3ac1a5a75d Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Tue, 7 Feb 2023 17:08:57 +0000 Subject: [PATCH 09/11] add test case for overriding defaults --- test-data/unit/check-dataclass-transform.test | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 3c68af925f6a..75c123c91296 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -162,6 +162,27 @@ b.x = 2 # E: Property "x" defined in "B" is read-only [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] +[case testDataclassTransformDefaultsCanBeOverridden] +# flags: --python-version 3.11 +from typing import dataclass_transform, Callable, Type + +@dataclass_transform(kw_only_default=True) +def my_dataclass(*, kw_only: bool = True) -> Callable[[Type], Type]: + return lambda cls: cls + +@my_dataclass() +class KwOnly: + x: int +@my_dataclass(kw_only=False) +class KwOptional: + x: int + +KwOnly(5) # E: Too many positional arguments for "KwOnly" +KwOptional(5) + +[typing fixtures/typing-full.pyi] +[builtins fixtures/dataclasses.pyi] + [case testDataclassTransformFieldSpecifiersDefaultsToEmpty] # flags: --python-version 3.11 from dataclasses import field, dataclass From b31d375cb30356415c237fb83184c73c0dd53bec Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Tue, 7 Feb 2023 18:52:33 +0000 Subject: [PATCH 10/11] fix tests are merging latest --- test-data/unit/check-dataclass-transform.test | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 4131609d3384..01e8935b0745 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -129,8 +129,8 @@ def no_eq(*, order: bool = False) -> Callable[[Type], Type]: return lambda cls: cls @no_eq() class Foo: ... -@no_eq(order=True) -class Bar: ... # E: eq must be True if order is True +@no_eq(order=True) # E: "eq" must be True if "order" is True +class Bar: ... @dataclass_transform(kw_only_default=True) From 02b7ceeffc1aa524c97009e30cfafaaef587d807 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Tue, 7 Feb 2023 18:56:55 +0000 Subject: [PATCH 11/11] move dataclass_transform_spec from FuncBase to FuncDef --- mypy/nodes.py | 17 +++-------------- mypy/semanal_shared.py | 3 +-- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index fbef11724a63..534ba7f82607 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -506,9 +506,6 @@ class FuncBase(Node): "is_static", # Uses "@staticmethod" "is_final", # Uses "@final" "_fullname", - # Present when a function is decorated with "@typing.dataclass_transform" or similar, and - # records the parameters passed to typing.dataclass_transform for later use - "dataclass_transform_spec", ) def __init__(self) -> None: @@ -527,7 +524,6 @@ def __init__(self) -> None: self.is_final = False # Name with module prefix self._fullname = "" - self.dataclass_transform_spec: DataclassTransformSpec | None = None @property @abstractmethod @@ -588,11 +584,6 @@ def serialize(self) -> JsonDict: "fullname": self._fullname, "impl": None if self.impl is None else self.impl.serialize(), "flags": get_flags(self, FUNCBASE_FLAGS), - "dataclass_transform_spec": ( - None - if self.dataclass_transform_spec is None - else self.dataclass_transform_spec.serialize() - ), } @classmethod @@ -611,11 +602,6 @@ def deserialize(cls, data: JsonDict) -> OverloadedFuncDef: assert isinstance(typ, mypy.types.ProperType) res.type = typ res._fullname = data["fullname"] - res.dataclass_transform_spec = ( - DataclassTransformSpec.deserialize(data["dataclass_transform_spec"]) - if data["dataclass_transform_spec"] is not None - else None - ) set_flags(res, data["flags"]) # NOTE: res.info will be set in the fixup phase. return res @@ -764,6 +750,8 @@ class FuncDef(FuncItem, SymbolNode, Statement): "deco_line", "is_trivial_body", "is_mypy_only", + # Present only when a function is decorated with @typing.datasclass_transform or similar + "dataclass_transform_spec", ) __match_args__ = ("name", "arguments", "type", "body") @@ -791,6 +779,7 @@ def __init__( self.deco_line: int | None = None # Definitions that appear in if TYPE_CHECKING are marked with this flag. self.is_mypy_only = False + self.dataclass_transform_spec: DataclassTransformSpec | None = None @property def name(self) -> str: diff --git a/mypy/semanal_shared.py b/mypy/semanal_shared.py index e28fb4052947..05edf2ac073f 100644 --- a/mypy/semanal_shared.py +++ b/mypy/semanal_shared.py @@ -11,7 +11,6 @@ from mypy import join from mypy.errorcodes import ErrorCode from mypy.nodes import ( - SYMBOL_FUNCBASE_TYPES, CallExpr, Context, DataclassTransformSpec, @@ -379,7 +378,7 @@ def find_dataclass_transform_spec(node: Node | None) -> DataclassTransformSpec | # `@dataclass_transform(...)` syntax and never `@dataclass_transform` node = node.func - if isinstance(node, SYMBOL_FUNCBASE_TYPES): + if isinstance(node, FuncDef): return node.dataclass_transform_spec return None