Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stubgen: generate valid dataclass stubs #15625

Merged
merged 4 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
OverloadedFuncDef,
Statement,
StrExpr,
TempNode,
TupleExpr,
TypeInfo,
UnaryExpr,
Expand Down Expand Up @@ -650,6 +651,7 @@ def __init__(
self.defined_names: set[str] = set()
# Short names of methods defined in the body of the current class
self.method_names: set[str] = set()
self.processing_dataclass = False

def visit_mypy_file(self, o: MypyFile) -> None:
self.module = o.fullname # Current module being processed
Expand Down Expand Up @@ -699,6 +701,9 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
self.clear_decorators()

def visit_func_def(self, o: FuncDef) -> None:
if self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated:
# Skip methods generated by the @dataclass decorator
return
if (
self.is_private_name(o.name, o.fullname)
or self.is_not_in_all(o.name)
Expand Down Expand Up @@ -890,6 +895,9 @@ def visit_class_def(self, o: ClassDef) -> None:
if not self._indent and self._state != EMPTY:
sep = len(self._output)
self.add("\n")
decorators = self.get_class_decorators(o)
for d in decorators:
self.add(f"{self._indent}@{d}\n")
self.add(f"{self._indent}class {o.name}")
self.record_name(o.name)
base_types = self.get_base_types(o)
Expand Down Expand Up @@ -921,6 +929,7 @@ def visit_class_def(self, o: ClassDef) -> None:
else:
self._state = CLASS
self.method_names = set()
self.processing_dataclass = False

def get_base_types(self, cdef: ClassDef) -> list[str]:
"""Get list of base classes for a class."""
Expand Down Expand Up @@ -967,6 +976,21 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
base_types.append(f"{name}={value.accept(p)}")
return base_types

def get_class_decorators(self, cdef: ClassDef) -> list[str]:
decorators: list[str] = []
p = AliasPrinter(self)
for d in cdef.decorators:
if self.is_dataclass(d):
decorators.append(d.accept(p))
self.import_tracker.require_name(get_qualified_name(d))
self.processing_dataclass = True
return decorators

def is_dataclass(self, expr: Expression) -> bool:
if isinstance(expr, CallExpr):
expr = expr.callee
return self.get_fullname(expr) == "dataclasses.dataclass"

def visit_block(self, o: Block) -> None:
# Unreachable statements may be partially uninitialized and that may
# cause trouble.
Expand Down Expand Up @@ -1323,8 +1347,14 @@ def get_init(
# Final without type argument is invalid in stubs.
final_arg = self.get_str_type_of_node(rvalue)
typename += f"[{final_arg}]"
elif self.processing_dataclass:
# attribute without annotation is not a dataclass field, don't add annotation.
return f"{self._indent}{lvalue} = ...\n"
else:
typename = self.get_str_type_of_node(rvalue)
if self.processing_dataclass and not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
# dataclass field with default value, keep the initializer.
return f"{self._indent}{lvalue}: {typename} = ...\n"
return f"{self._indent}{lvalue}: {typename}\n"

def add(self, string: str) -> None:
Expand Down
70 changes: 70 additions & 0 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -3317,3 +3317,73 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...

class X(_Incomplete): ...
class Y(_Incomplete): ...


[case testDataclass]
import dataclasses
import dataclasses as dcs
from dataclasses import dataclass
from dataclasses import dataclass as dc

@dataclasses.dataclass
class X:
a: int
b: str = "hello"
non_field = None

@dcs.dataclass
class Y: ...

@dataclass
class Z: ...

@dc
class W: ...

[out]
import dataclasses
import dataclasses as dcs
from dataclasses import dataclass, dataclass as dc

@dataclasses.dataclass
class X:
a: int
b: str = ...
non_field = ...

@dcs.dataclass
class Y: ...
@dataclass
class Z: ...
@dc
class W: ...

[case testDataclassWithKeywords]
from dataclasses import dataclass

@dataclass(init=False)
class X: ...

[out]
from dataclasses import dataclass

@dataclass(init=False)
class X: ...

[case testDataclassWithExplicitGeneratedMethodsOverrides_semanal]
from dataclasses import dataclass

@dataclass
class X:
a: int
def __init__(self, a: int, b: str = ...) -> None: ...
def __post_init__(self) -> None: ...

[out]
from dataclasses import dataclass

@dataclass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that init=False is assumed here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the dataclass documentation,

init: If true (the default), a __init__() method will be generated.
If the class already defines __init__(), this parameter is ignored.

So it is ignored in this case. This was not the goal of the test anyway, it was to test that user defined methods are always included in the stub.

class X:
a: int
def __init__(self, a: int, b: str = ...) -> None: ...
def __post_init__(self) -> None: ...