From faae3c0f82f87d1a217ac6c1135f037b898d7cac Mon Sep 17 00:00:00 2001 From: Ali Hamdan Date: Sat, 8 Jul 2023 12:04:03 +0200 Subject: [PATCH 1/3] stubgen: generate valid dataclass stubs Fixes #12441 --- mypy/stubgen.py | 30 ++++++++++++++++ test-data/unit/stubgen.test | 70 +++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 229559ac8120..9758e4c14973 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -101,6 +101,7 @@ OverloadedFuncDef, Statement, StrExpr, + TempNode, TupleExpr, TypeInfo, UnaryExpr, @@ -650,6 +651,7 @@ def __init__( self.defined_names: set[str] = set() # Short names of methods defined in the body of the current class self.method_names: set[str] = set() + self.processing_dataclass = False def visit_mypy_file(self, o: MypyFile) -> None: self.module = o.fullname # Current module being processed @@ -699,6 +701,9 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None: self.clear_decorators() def visit_func_def(self, o: FuncDef) -> None: + if self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated: + # Skip methods generated by the @dataclass decorator + return if ( self.is_private_name(o.name, o.fullname) or self.is_not_in_all(o.name) @@ -890,6 +895,9 @@ def visit_class_def(self, o: ClassDef) -> None: if not self._indent and self._state != EMPTY: sep = len(self._output) self.add("\n") + decorators = self.get_class_decorators(o) + for d in decorators: + self.add(f"{self._indent}@{d}\n") self.add(f"{self._indent}class {o.name}") self.record_name(o.name) base_types = self.get_base_types(o) @@ -921,6 +929,7 @@ def visit_class_def(self, o: ClassDef) -> None: else: self._state = CLASS self.method_names = set() + self.processing_dataclass = False def get_base_types(self, cdef: ClassDef) -> list[str]: """Get list of base classes for a class.""" @@ -967,6 +976,21 @@ def get_base_types(self, cdef: ClassDef) -> list[str]: base_types.append(f"{name}={value.accept(p)}") return base_types + def get_class_decorators(self, cdef: ClassDef) -> list[str]: + decorators: list[str] = [] + p = AliasPrinter(self) + for d in cdef.decorators: + if self.is_dataclass(d): + decorators.append(d.accept(p)) + self.import_tracker.require_name(get_qualified_name(d)) + self.processing_dataclass = True + return decorators + + def is_dataclass(self, expr: Expression) -> bool: + if isinstance(expr, CallExpr): + expr = expr.callee + return self.get_fullname(expr) == "dataclasses.dataclass" + def visit_block(self, o: Block) -> None: # Unreachable statements may be partially uninitialized and that may # cause trouble. @@ -1323,8 +1347,14 @@ def get_init( # Final without type argument is invalid in stubs. final_arg = self.get_str_type_of_node(rvalue) typename += f"[{final_arg}]" + elif self.processing_dataclass: + # attribute without annotation is not a dataclass field, don't add annotation. + return f"{self._indent}{lvalue} = ...\n" else: typename = self.get_str_type_of_node(rvalue) + if self.processing_dataclass and not (isinstance(rvalue, TempNode) and rvalue.no_rhs): + # dataclass field with default value, keep the initializer. + return f"{self._indent}{lvalue}: {typename} = ...\n" return f"{self._indent}{lvalue}: {typename}\n" def add(self, string: str) -> None: diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index e1818dc4c4bc..1e73df3ebb60 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -3317,3 +3317,73 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ... class X(_Incomplete): ... class Y(_Incomplete): ... + + +[case testDataclass] +import dataclasses +import dataclasses as dcs +from dataclasses import dataclass +from dataclasses import dataclass as dc + +@dataclasses.dataclass +class X: + a: int + b: str = "hello" + non_field = None + +@dcs.dataclass +class Y: ... + +@dataclass +class Z: ... + +@dc +class W: ... + +[out] +import dataclasses +import dataclasses as dcs +from dataclasses import dataclass, dataclass as dc + +@dataclasses.dataclass +class X: + a: int + b: str = ... + non_field = ... + +@dcs.dataclass +class Y: ... +@dataclass +class Z: ... +@dc +class W: ... + +[case testDataclassWithKeywords] +from dataclasses import dataclass + +@dataclass(init=False) +class X: ... + +[out] +from dataclasses import dataclass + +@dataclass(init=False) +class X: ... + +[case testDataclassWithExplicitGeneratedMethodsOverrides_semanal] +from dataclasses import dataclass + +@dataclass +class X: + a: int + def __init__(self, a: int, b: str = ...) -> None: ... + def __post_init__(self) -> None: ... + +[out] +from dataclasses import dataclass + +@dataclass +class X: + a: int + def __init__(self, a: int, b: str = ...) -> None: ... + def __post_init__(self) -> None: ... From 691b8e9550b52fb096c548eca0ac65bc7dc891f9 Mon Sep 17 00:00:00 2001 From: Ali Hamdan Date: Fri, 14 Jul 2023 11:16:38 +0200 Subject: [PATCH 2/3] Keep `__init__` in dataclasses and add more tests We cannot safely remove `__init__` and depend on the plugin because its signature depends on dataclass field assignments to `dataclasses.field` and these assignments are not included in the stub --- mypy/stubgen.py | 15 ++++++- test-data/unit/stubgen.test | 90 ++++++++++++++++++++++++++++++++----- 2 files changed, 93 insertions(+), 12 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 9758e4c14973..6d5134ec9ec4 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -701,8 +701,11 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None: self.clear_decorators() def visit_func_def(self, o: FuncDef) -> None: - if self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated: - # Skip methods generated by the @dataclass decorator + is_dataclass_generated = ( + self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated + ) + if is_dataclass_generated and o.name != "__init__": + # Skip methods generated by the @dataclass decorator (except for __init__) return if ( self.is_private_name(o.name, o.fullname) @@ -769,6 +772,12 @@ def visit_func_def(self, o: FuncDef) -> None: else: arg = name + annotation args.append(arg) + if o.name == "__init__" and is_dataclass_generated and "**" in args: + # The dataclass plugin generates invalid nameless "*" and "**" arguments + new_name = "".join(a.split(":", 1)[0] for a in args).replace("*", "") + args[args.index("*")] = f"*{new_name}_" # this name is guaranteed to be unique + args[args.index("**")] = f"**{new_name}__" # same here + retname = None if o.name != "__init__" and isinstance(o.unanalyzed_type, CallableType): if isinstance(get_proper_type(o.unanalyzed_type.ret_type), AnyType): @@ -1413,6 +1422,8 @@ def is_private_name(self, name: str, fullname: str | None = None) -> bool: return False if fullname in EXTRA_EXPORTED: return False + if name == "_": + return False return name.startswith("_") and (not name.endswith("__") or name in IGNORED_DUNDERS) def is_private_member(self, fullname: str) -> bool: diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 1e73df3ebb60..5fed8cd2e61f 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -3318,17 +3318,25 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ... class X(_Incomplete): ... class Y(_Incomplete): ... - [case testDataclass] import dataclasses import dataclasses as dcs -from dataclasses import dataclass +from dataclasses import dataclass, InitVar, KW_ONLY from dataclasses import dataclass as dc +from typing import ClassVar @dataclasses.dataclass class X: a: int b: str = "hello" + c: ClassVar + d: ClassVar = 200 + f: list[int] = field(init=False, default_factory=list) + g: int = field(default=2, kw_only=True) + _: KW_ONLY + h: int = 1 + i: InitVar[str] + j: InitVar = 100 non_field = None @dcs.dataclass @@ -3340,15 +3348,27 @@ class Z: ... @dc class W: ... +@dataclass(init=False, repr=False) +class V: ... + [out] import dataclasses import dataclasses as dcs -from dataclasses import dataclass, dataclass as dc +from dataclasses import InitVar, KW_ONLY, dataclass, dataclass as dc +from typing import ClassVar @dataclasses.dataclass class X: a: int b: str = ... + c: ClassVar + d: ClassVar = ... + f: list[int] = ... + g: int = ... + _: KW_ONLY + h: int = ... + i: InitVar[str] + j: InitVar = ... non_field = ... @dcs.dataclass @@ -3357,18 +3377,51 @@ class Y: ... class Z: ... @dc class W: ... +@dataclass(init=False, repr=False) +class V: ... -[case testDataclassWithKeywords] -from dataclasses import dataclass +[case testDataclass_semanal] +from dataclasses import dataclass, InitVar, KW_ONLY +from typing import ClassVar -@dataclass(init=False) -class X: ... +@dataclass +class X: + a: int + b: str = "hello" + c: ClassVar + d: ClassVar = 200 + f: list[int] = field(init=False, default_factory=list) + g: int = field(default=2, kw_only=True) + _: KW_ONLY + h: int = 1 + i: InitVar[str] + j: InitVar = 100 + non_field = None + +@dataclass(init=False, repr=False, frozen=True) +class Y: ... [out] -from dataclasses import dataclass +from dataclasses import InitVar, KW_ONLY, dataclass +from typing import ClassVar -@dataclass(init=False) -class X: ... +@dataclass +class X: + a: int + b: str = ... + c: ClassVar + d: ClassVar = ... + f: list[int] = ... + g: int = ... + _: KW_ONLY + h: int = ... + i: InitVar[str] + j: InitVar = ... + non_field = ... + def __init__(self, a, b, f, g, *, h, i, j) -> None: ... + +@dataclass(init=False, repr=False, frozen=True) +class Y: ... [case testDataclassWithExplicitGeneratedMethodsOverrides_semanal] from dataclasses import dataclass @@ -3387,3 +3440,20 @@ class X: a: int def __init__(self, a: int, b: str = ...) -> None: ... def __post_init__(self) -> None: ... + +[case testDataclassInheritsFromAny_semanal] +from dataclasses import dataclass +import missing + +@dataclass +class X(missing.Base): + a: int + +[out] +import missing +from dataclasses import dataclass + +@dataclass +class X(missing.Base): + a: int + def __init__(self, *selfa_, a, **selfa__) -> None: ... From be3028ea5558572fd644613efd1bf890d6f18def Mon Sep 17 00:00:00 2001 From: Ali Hamdan Date: Sat, 26 Aug 2023 22:07:24 +0200 Subject: [PATCH 3/3] Fix tests running on older python versions --- mypy/test/teststubgen.py | 11 ++++++++++ test-data/unit/stubgen.test | 42 +++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/mypy/test/teststubgen.py b/mypy/test/teststubgen.py index 79d380785a39..7e30515ac892 100644 --- a/mypy/test/teststubgen.py +++ b/mypy/test/teststubgen.py @@ -724,11 +724,22 @@ def run_case_inner(self, testcase: DataDrivenTestCase) -> None: def parse_flags(self, program_text: str, extra: list[str]) -> Options: flags = re.search("# flags: (.*)$", program_text, flags=re.MULTILINE) + pyversion = None if flags: flag_list = flags.group(1).split() + for i, flag in enumerate(flag_list): + if flag.startswith("--python-version="): + pyversion = flag.split("=", 1)[1] + del flag_list[i] + break else: flag_list = [] options = parse_options(flag_list + extra) + if pyversion: + # A hack to allow testing old python versions with new language constructs + # This should be rarely used in general as stubgen output should not be version-specific + major, minor = pyversion.split(".", 1) + options.pyversion = (int(major), int(minor)) if "--verbose" not in flag_list: options.quiet = True else: diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 64b759d0c960..828680fadcf2 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -3576,6 +3576,48 @@ class W: ... class V: ... [case testDataclass_semanal] +from dataclasses import dataclass, InitVar +from typing import ClassVar + +@dataclass +class X: + a: int + b: str = "hello" + c: ClassVar + d: ClassVar = 200 + f: list[int] = field(init=False, default_factory=list) + g: int = field(default=2, kw_only=True) + h: int = 1 + i: InitVar[str] + j: InitVar = 100 + non_field = None + +@dataclass(init=False, repr=False, frozen=True) +class Y: ... + +[out] +from dataclasses import InitVar, dataclass +from typing import ClassVar + +@dataclass +class X: + a: int + b: str = ... + c: ClassVar + d: ClassVar = ... + f: list[int] = ... + g: int = ... + h: int = ... + i: InitVar[str] + j: InitVar = ... + non_field = ... + def __init__(self, a, b, f, g, h, i, j) -> None: ... + +@dataclass(init=False, repr=False, frozen=True) +class Y: ... + +[case testDataclassWithKwOnlyField_semanal] +# flags: --python-version=3.10 from dataclasses import dataclass, InitVar, KW_ONLY from typing import ClassVar