Skip to content

Commit

Permalink
stubgen: use PEP 604 unions everywhere (#16519)
Browse files Browse the repository at this point in the history
Fixes #12920
  • Loading branch information
hamdanal authored Jan 13, 2024
1 parent a8741d8 commit e28925d
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 20 deletions.
22 changes: 10 additions & 12 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
=> Generate out/urllib/parse.pyi.
$ stubgen -p urllib
=> Generate stubs for whole urlib package (recursively).
=> Generate stubs for whole urllib package (recursively).
For C modules, you can get more precise function signatures by parsing .rst (Sphinx)
documentation for extra information. For this, use the --doc-dir option:
Expand Down Expand Up @@ -306,6 +306,13 @@ def visit_str_expr(self, node: StrExpr) -> str:
return repr(node.value)

def visit_index_expr(self, node: IndexExpr) -> str:
base_fullname = self.stubgen.get_fullname(node.base)
if base_fullname == "typing.Union":
if isinstance(node.index, TupleExpr):
return " | ".join([item.accept(self) for item in node.index.items])
return node.index.accept(self)
if base_fullname == "typing.Optional":
return f"{node.index.accept(self)} | None"
base = node.base.accept(self)
index = node.index.accept(self)
if len(index) > 2 and index.startswith("(") and index.endswith(")"):
Expand Down Expand Up @@ -682,7 +689,7 @@ def process_decorator(self, o: Decorator) -> None:
self.add_decorator(qualname, require_name=False)

def get_fullname(self, expr: Expression) -> str:
"""Return the full name resolving imports and import aliases."""
"""Return the expression's full name."""
if (
self.analyzed
and isinstance(expr, (NameExpr, MemberExpr))
Expand All @@ -691,16 +698,7 @@ def get_fullname(self, expr: Expression) -> str:
):
return expr.fullname
name = get_qualified_name(expr)
if "." not in name:
real_module = self.import_tracker.module_for.get(name)
real_short = self.import_tracker.reverse_alias.get(name, name)
if real_module is None and real_short not in self.defined_names:
real_module = "builtins" # not imported and not defined, must be a builtin
else:
name_module, real_short = name.split(".", 1)
real_module = self.import_tracker.reverse_alias.get(name_module, name_module)
resolved_name = real_short if real_module is None else f"{real_module}.{real_short}"
return resolved_name
return self.resolve_name(name)

def visit_class_def(self, o: ClassDef) -> None:
self._current_class = o
Expand Down
25 changes: 17 additions & 8 deletions mypy/stubutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,11 @@ def visit_any(self, t: AnyType) -> str:

def visit_unbound_type(self, t: UnboundType) -> str:
s = t.name
fullname = self.stubgen.resolve_name(s)
if fullname == "typing.Union":
return " | ".join([item.accept(self) for item in t.args])
if fullname == "typing.Optional":
return f"{t.args[0].accept(self)} | None"
if self.known_modules is not None and "." in s:
# see if this object is from any of the modules that we're currently processing.
# reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
Expand Down Expand Up @@ -588,14 +593,18 @@ def __init__(
def get_sig_generators(self) -> list[SignatureGenerator]:
return []

def refers_to_fullname(self, name: str, fullname: str | tuple[str, ...]) -> bool:
"""Return True if the variable name identifies the same object as the given fullname(s)."""
if isinstance(fullname, tuple):
return any(self.refers_to_fullname(name, fname) for fname in fullname)
module, short = fullname.rsplit(".", 1)
return self.import_tracker.module_for.get(name) == module and (
name == short or self.import_tracker.reverse_alias.get(name) == short
)
def resolve_name(self, name: str) -> str:
"""Return the full name resolving imports and import aliases."""
if "." not in name:
real_module = self.import_tracker.module_for.get(name)
real_short = self.import_tracker.reverse_alias.get(name, name)
if real_module is None and real_short not in self.defined_names:
real_module = "builtins" # not imported and not defined, must be a builtin
else:
name_module, real_short = name.split(".", 1)
real_module = self.import_tracker.reverse_alias.get(name_module, name_module)
resolved_name = real_short if real_module is None else f"{real_module}.{real_short}"
return resolved_name

def add_name(self, fullname: str, require: bool = True) -> str:
"""Add a name to be imported and return the name reference.
Expand Down
30 changes: 30 additions & 0 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -4139,3 +4139,33 @@ from dataclasses import dataclass
class X(missing.Base):
a: int
def __init__(self, *selfa_, a, **selfa__) -> None: ...

[case testAlwaysUsePEP604Union]
import typing
import typing as t
from typing import Optional, Union, Optional as O, Union as U
import x

union = Union[int, str]
bad_union = Union[int]
nested_union = Optional[Union[int, str]]
not_union = x.Union[int, str]
u = U[int, str]
o = O[int]

def f1(a: Union["int", Optional[tuple[int, t.Optional[int]]]]) -> int: ...
def f2(a: typing.Union[int | x.Union[int, int], O[float]]) -> int: ...

[out]
import x
from _typeshed import Incomplete

union = int | str
bad_union = int
nested_union = int | str | None
not_union: Incomplete
u = int | str
o = int | None

def f1(a: int | tuple[int, int | None] | None) -> int: ...
def f2(a: int | x.Union[int, int] | float | None) -> int: ...

0 comments on commit e28925d

Please sign in to comment.