Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stubgen: Preserve simple defaults in function signatures #15355

Merged
merged 9 commits into from
Nov 27, 2023
15 changes: 12 additions & 3 deletions mypy/stubdoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,19 @@ def is_valid_type(s: str) -> bool:
class ArgSig:
"""Signature info for a single argument."""

def __init__(self, name: str, type: str | None = None, default: bool = False):
def __init__(
self,
name: str,
type: str | None = None,
*,
default: bool = False,
default_value: str = "...",
) -> None:
self.name = name
self.type = type
# Does this argument have a default value?
self.default = default
self.default_value = default_value

def is_star_arg(self) -> bool:
return self.name.startswith("*") and not self.name.startswith("**")
Expand All @@ -59,6 +67,7 @@ def __eq__(self, other: Any) -> bool:
self.name == other.name
and self.type == other.type
and self.default == other.default
and self.default_value == other.default_value
)
return False

Expand Down Expand Up @@ -119,10 +128,10 @@ def format_sig(
if arg_type:
arg_def += ": " + arg_type
if arg.default:
arg_def += " = ..."
arg_def += f" = {arg.default_value}"

elif arg.default:
arg_def += "=..."
arg_def += f"={arg.default_value}"

args.append(arg_def)

Expand Down
73 changes: 72 additions & 1 deletion mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
NameExpr,
OpExpr,
OverloadedFuncDef,
SetExpr,
Statement,
StrExpr,
TempNode,
Expand Down Expand Up @@ -491,15 +492,21 @@ def _get_func_args(self, o: FuncDef, ctx: FunctionContext) -> list[ArgSig]:
if kind.is_named() and not any(arg.name.startswith("*") for arg in args):
args.append(ArgSig("*"))

default = "..."
if arg_.initializer:
if not typename:
typename = self.get_str_type_of_node(arg_.initializer, True, False)
potential_default, valid = self.get_str_default_of_node(arg_.initializer)
if valid and len(potential_default) <= 200:
default = potential_default
elif kind == ARG_STAR:
name = f"*{name}"
elif kind == ARG_STAR2:
name = f"**{name}"

args.append(ArgSig(name, typename, default=bool(arg_.initializer)))
args.append(
ArgSig(name, typename, default=bool(arg_.initializer), default_value=default)
)

if ctx.class_info is not None and all(
arg.type is None and arg.default is False for arg in args
Expand Down Expand Up @@ -1234,6 +1241,70 @@ def maybe_unwrap_unary_expr(self, expr: Expression) -> Expression:
# This is some other unary expr, we cannot do anything with it (yet?).
return expr

def get_str_default_of_node(self, rvalue: Expression) -> tuple[str, bool]:
"""Get a string representation of the default value of a node.

Returns a 2-tuple of the default and whether or not it is valid.
"""
if isinstance(rvalue, NameExpr):
if rvalue.name in ("None", "True", "False"):
return rvalue.name, True
elif isinstance(rvalue, (IntExpr, FloatExpr)):
return f"{rvalue.value}", True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this work correctly for NaN and infinity?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this work correctly for NaN and infinity?

Do these have a literal syntax? float("nan") is a call expression that is ignored when used as a default value in the current implementation. I'll add a test.

elif isinstance(rvalue, UnaryExpr):
if isinstance(rvalue.expr, (IntExpr, FloatExpr)):
return f"{rvalue.op}{rvalue.expr.value}", True
elif isinstance(rvalue, StrExpr):
return repr(rvalue.value), True
elif isinstance(rvalue, BytesExpr):
return "b" + repr(rvalue.value).replace("\\\\", "\\"), True
elif isinstance(rvalue, TupleExpr):
items_defaults = []
for e in rvalue.items:
e_default, valid = self.get_str_default_of_node(e)
if not valid:
break
items_defaults.append(e_default)
else:
closing = ",)" if len(items_defaults) == 1 else ")"
default = "(" + ", ".join(items_defaults) + closing
return default, True
elif isinstance(rvalue, ListExpr):
items_defaults = []
for e in rvalue.items:
e_default, valid = self.get_str_default_of_node(e)
if not valid:
break
items_defaults.append(e_default)
else:
default = "[" + ", ".join(items_defaults) + "]"
return default, True
elif isinstance(rvalue, SetExpr):
items_defaults = []
for e in rvalue.items:
e_default, valid = self.get_str_default_of_node(e)
if not valid:
break
items_defaults.append(e_default)
else:
if items_defaults:
default = "{" + ", ".join(items_defaults) + "}"
return default, True
elif isinstance(rvalue, DictExpr):
items_defaults = []
for k, v in rvalue.items:
if k is None:
break
k_default, k_valid = self.get_str_default_of_node(k)
v_default, v_valid = self.get_str_default_of_node(v)
if not (k_valid and v_valid):
break
items_defaults.append(f"{k_default}: {v_default}")
else:
default = "{" + ", ".join(items_defaults) + "}"
return default, True
return "...", False

def should_reexport(self, name: str, full_module: str, name_is_alias: bool) -> bool:
is_private = self.is_private_name(name, full_module + "." + name)
if (
Expand Down
89 changes: 69 additions & 20 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -27,45 +27,47 @@ def g(arg) -> None: ...
def f(a, b=2): ...
def g(b=-1, c=0): ...
[out]
def f(a, b: int = ...) -> None: ...
def g(b: int = ..., c: int = ...) -> None: ...
def f(a, b: int = 2) -> None: ...
def g(b: int = -1, c: int = 0) -> None: ...

[case testDefaultArgNone]
def f(x=None): ...
[out]
from _typeshed import Incomplete

def f(x: Incomplete | None = ...) -> None: ...
def f(x: Incomplete | None = None) -> None: ...

[case testDefaultArgBool]
def f(x=True, y=False): ...
[out]
def f(x: bool = ..., y: bool = ...) -> None: ...
def f(x: bool = True, y: bool = False) -> None: ...

[case testDefaultArgBool_inspect]
def f(x=True, y=False): ...
[out]
def f(x: bool = ..., y: bool = ...): ...

[case testDefaultArgStr]
def f(x='foo'): ...
def f(x='foo',y="how's quotes"): ...
[out]
def f(x: str = ...) -> None: ...
def f(x: str = 'foo', y: str = "how's quotes") -> None: ...

[case testDefaultArgStr_inspect]
def f(x='foo'): ...
[out]
def f(x: str = ...): ...

[case testDefaultArgBytes]
def f(x=b'foo'): ...
def f(x=b'foo',y=b"what's up",z=b'\xc3\xa0 la une'): ...
[out]
def f(x: bytes = ...) -> None: ...
def f(x: bytes = b'foo', y: bytes = b"what's up", z: bytes = b'\xc3\xa0 la une') -> None: ...

[case testDefaultArgFloat]
def f(x=1.2): ...
def f(x=1.2,y=1e-6,z=0.0,w=-0.0,v=+1.0): ...
def g(x=float("nan"), y=float("inf"), z=float("-inf")): ...
[out]
def f(x: float = ...) -> None: ...
def f(x: float = 1.2, y: float = 1e-06, z: float = 0.0, w: float = -0.0, v: float = +1.0) -> None: ...
def g(x=..., y=..., z=...) -> None: ...

[case testDefaultArgOther]
def f(x=ord): ...
Expand Down Expand Up @@ -126,10 +128,10 @@ def i(a, *, b=1): ...
def j(a, *, b=1, **c): ...
[out]
def f(a, *b, **c) -> None: ...
def g(a, *b, c: int = ...) -> None: ...
def h(a, *b, c: int = ..., **d) -> None: ...
def i(a, *, b: int = ...) -> None: ...
def j(a, *, b: int = ..., **c) -> None: ...
def g(a, *b, c: int = 1) -> None: ...
def h(a, *b, c: int = 1, **d) -> None: ...
def i(a, *, b: int = 1) -> None: ...
def j(a, *, b: int = 1, **c) -> None: ...

[case testClass]
class A:
Expand Down Expand Up @@ -356,8 +358,8 @@ y: Incomplete
def f(x, *, y=1): ...
def g(x, *, y=1, z=2): ...
[out]
def f(x, *, y: int = ...) -> None: ...
def g(x, *, y: int = ..., z: int = ...) -> None: ...
def f(x, *, y: int = 1) -> None: ...
def g(x, *, y: int = 1, z: int = 2) -> None: ...

[case testProperty]
class A:
Expand Down Expand Up @@ -1285,8 +1287,8 @@ from _typeshed import Incomplete

class A:
x: Incomplete
def __init__(self, a: Incomplete | None = ...) -> None: ...
def method(self, a: Incomplete | None = ...) -> None: ...
def __init__(self, a: Incomplete | None = None) -> None: ...
def method(self, a: Incomplete | None = None) -> None: ...

[case testAnnotationImportsFrom]
import foo
Expand Down Expand Up @@ -2514,7 +2516,7 @@ from _typeshed import Incomplete as _Incomplete

Y: _Incomplete

def g(x: _Incomplete | None = ...) -> None: ...
def g(x: _Incomplete | None = None) -> None: ...

x: _Incomplete

Expand Down Expand Up @@ -3503,7 +3505,7 @@ class P(Protocol):
[case testNonDefaultKeywordOnlyArgAfterAsterisk]
def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ...
[out]
def func(*, non_default_kwarg: bool, default_kwarg: bool = ...): ...
def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ...

[case testNestedGenerator]
def f1():
Expand Down Expand Up @@ -3909,6 +3911,53 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...
class X(_Incomplete): ...
class Y(_Incomplete): ...

[case testIgnoreLongDefaults]
def f(x='abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'): ...

def g(x=b'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'): ...

def h(x=123456789012345678901234567890123456789012345678901234567890\
123456789012345678901234567890123456789012345678901234567890\
123456789012345678901234567890123456789012345678901234567890\
123456789012345678901234567890123456789012345678901234567890): ...

[out]
def f(x: str = ...) -> None: ...
def g(x: bytes = ...) -> None: ...
def h(x: int = ...) -> None: ...

[case testDefaultsOfBuiltinContainers]
def f(x=(), y=(1,), z=(1, 2)): ...
def g(x=[], y=[1, 2]): ...
def h(x={}, y={1: 2, 3: 4}): ...
def i(x={1, 2, 3}): ...
def j(x=[(1,"a"), (2,"b")]): ...

[out]
def f(x=(), y=(1,), z=(1, 2)) -> None: ...
def g(x=[], y=[1, 2]) -> None: ...
def h(x={}, y={1: 2, 3: 4}) -> None: ...
def i(x={1, 2, 3}) -> None: ...
def j(x=[(1, 'a'), (2, 'b')]) -> None: ...

[case testDefaultsOfBuiltinContainersWithNonTrivialContent]
def f(x=(1, u.v), y=(k(),), z=(w,)): ...
def g(x=[1, u.v], y=[k()], z=[w]): ...
def h(x={1: u.v}, y={k(): 2}, z={m: m}, w={**n}): ...
def i(x={u.v, 2}, y={3, k()}, z={w}): ...

[out]
def f(x=..., y=..., z=...) -> None: ...
def g(x=..., y=..., z=...) -> None: ...
def h(x=..., y=..., z=..., w=...) -> None: ...
def i(x=..., y=..., z=...) -> None: ...

[case testDataclass]
import dataclasses
import dataclasses as dcs
Expand Down