Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add signature for attr.evolve #14526

Merged
merged 21 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing_extensions import Final, Literal

import mypy.plugin # To avoid circular imports.
from mypy.checker import TypeChecker
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
from mypy.nodes import (
ARG_NAMED,
Expand Down Expand Up @@ -76,6 +77,7 @@
SELF_TVAR_NAME: Final = "_AT"
MAGIC_ATTR_NAME: Final = "__attrs_attrs__"
MAGIC_ATTR_CLS_NAME_TEMPLATE: Final = "__{}_AttrsAttributes__" # The tuple subclass pattern.
ATTRS_INIT_NAME: Final = "__attrs_init__"


class Converter:
Expand Down Expand Up @@ -325,7 +327,7 @@ def attr_class_maker_callback(

adder = MethodAdder(ctx)
# If __init__ is not being generated, attrs still generates it as __attrs_init__ instead.
_add_init(ctx, attributes, adder, "__init__" if init else "__attrs_init__")
_add_init(ctx, attributes, adder, "__init__" if init else ATTRS_INIT_NAME)
if order:
_add_order(ctx, adder)
if frozen:
Expand Down Expand Up @@ -883,3 +885,63 @@ def add_method(
"""
self_type = self_type if self_type is not None else self.self_type
add_method(self.ctx, method_name, args, ret_type, self_type, tvd)


def _get_attrs_init_type(typ: Type) -> CallableType | None:
"""
If `typ` refers to an attrs class, gets the type of its initializer method.
"""
typ = get_proper_type(typ)
if not isinstance(typ, Instance):
return None
magic_attr = typ.type.get(MAGIC_ATTR_NAME)
if magic_attr is None or not magic_attr.plugin_generated:
return None
init_method = typ.type.get_method("__init__") or typ.type.get_method(ATTRS_INIT_NAME)
if not isinstance(init_method, FuncDef) or not isinstance(init_method.type, CallableType):
return None
return init_method.type


def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
"""
Generates a signature for the 'attr.evolve' function that's specific to the call site
and dependent on the type of the first argument.
"""
if len(ctx.args) != 2:
# Ideally the name and context should be callee's, but we don't have it in FunctionSigContext.
ctx.api.fail(f'"{ctx.default_signature.name}" has unexpected type annotation', ctx.context)
return ctx.default_signature

if len(ctx.args[0]) != 1:
return ctx.default_signature # leave it to the type checker to complain

inst_arg = ctx.args[0][0]

# <hack>
assert isinstance(ctx.api, TypeChecker)
inst_type = ctx.api.expr_checker.accept(inst_arg)
# </hack>
Comment on lines +927 to +930
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JelleZijlstra now that we got this merged 😅 what would be the right place to discuss whether this should be promoted to be formal plugin API?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Open an issue, I suppose.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👉 #14845


inst_type = get_proper_type(inst_type)
if isinstance(inst_type, AnyType):
return ctx.default_signature

# In practice, we're taking the initializer generated by _add_init and tweaking it
# so that (a) its arguments are kw-only & optional, and (b) its return type is the instance's.
attrs_init_type = _get_attrs_init_type(inst_type)
if not attrs_init_type:
ctx.api.fail(
f'Argument 1 to "evolve" has incompatible type "{inst_type}"; expected an attrs class',
ctx.context,
)
return ctx.default_signature

arg_names = attrs_init_type.arg_names.copy()
arg_names[0] = "inst"
return attrs_init_type.copy_modified(
arg_names=arg_names,
arg_kinds=[ARG_POS] + [ARG_NAMED_OPT] * (len(attrs_init_type.arg_kinds) - 1),
ret_type=inst_type,
name=ctx.default_signature.name,
)
10 changes: 10 additions & 0 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AttributeContext,
ClassDefContext,
FunctionContext,
FunctionSigContext,
MethodContext,
MethodSigContext,
Plugin,
Expand Down Expand Up @@ -45,6 +46,15 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
return singledispatch.create_singledispatch_function_callback
return None

def get_function_signature_hook(
self, fullname: str
) -> Callable[[FunctionSigContext], FunctionLike] | None:
from mypy.plugins import attrs

if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
return attrs.evolve_function_sig_callback
return None

def get_method_signature_hook(
self, fullname: str
) -> Callable[[MethodSigContext], FunctionLike] | None:
Expand Down
49 changes: 48 additions & 1 deletion test-data/unit/check-attr.test
Original file line number Diff line number Diff line change
Expand Up @@ -1866,4 +1866,51 @@ reveal_type(D) # N: Revealed type is "def (a: builtins.int, b: builtins.str) ->
D(1, "").a = 2 # E: Cannot assign to final attribute "a"
D(1, "").b = "2" # E: Cannot assign to final attribute "b"

[builtins fixtures/property.pyi]
[builtins fixtures/property.pyi]

[case testEvolve]
from typing import Any
import attr

class Base:
pass

class Derived(Base):
pass

class Other:
pass

@attr.s(auto_attribs=True)
class C:
name: str
b: Base

c = C(name='foo', b=Derived())
c = attr.evolve(c)
c = attr.evolve(c, name='foo')
c = attr.evolve(c, 'foo') # E: Too many positional arguments for "evolve"
c = attr.evolve(c, b=Derived())
c = attr.evolve(c, b=Base())
c = attr.evolve(c, b=Other()) # E: Argument "b" to "evolve" has incompatible type "Other"; expected "Base"
c = attr.evolve(c, name=42) # E: Argument "name" to "evolve" has incompatible type "int"; expected "str"
c = attr.evolve(c, foobar=42) # E: Unexpected keyword argument "foobar" for "evolve"

def f() -> C:
return c


# Determining type of first argument's expression
c = attr.evolve(f(), name='foo')

# First argument type check
attr.evolve(42, name='foo') # E: Argument 1 to "evolve" has incompatible type "Literal[42]?"; expected an attrs class
attr.evolve(None, name='foo') # E: Argument 1 to "evolve" has incompatible type "None"; expected an attrs class

# All bets are off for 'Any'
any: Any
ret = attr.evolve(any, name='foo')
reveal_type(ret) # N: Revealed type is "Any"

[builtins fixtures/dict.pyi]
[typing fixtures/typing-medium.pyi]
4 changes: 2 additions & 2 deletions test-data/unit/fine-grained.test
Original file line number Diff line number Diff line change
Expand Up @@ -1809,8 +1809,8 @@ def f() -> Iterator[None]:
[typing fixtures/typing-medium.pyi]
[builtins fixtures/list.pyi]
[triggered]
2: <b>, __main__
3: <b>, __main__, a
2: <b>, <b[wildcard]>, __main__
3: <b>, <b[wildcard]>, __main__, a
Copy link
Contributor Author

@ikonst ikonst Jan 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So.... the addition of dict in fixtures/list.pyi now causes __annotations__ to exist:

mypy/mypy/semanal.py

Lines 636 to 639 in bac9e77

elif name == "__annotations__":
sym = self.lookup_qualified("__builtins__.dict", Context(), suppress_errors=True)
if not sym:
continue

which then makes __annotations__ one of the names in the diff (reasonably...):

mypy/mypy/server/update.py

Lines 781 to 795 in a08388c

if item.count(".") <= package_nesting_level + 1 and item.split(".")[-1] not in (
"__builtins__",
"__file__",
"__name__",
"__package__",
"__doc__",
):
# Activate catch-all wildcard trigger for top-level module changes (used for
# "from m import *"). This also gets triggered by changes to module-private
# entries, but as these unneeded dependencies only result in extra processing,
# it's a minor problem.
#
# TODO: Some __* names cause mistriggers. Fix the underlying issue instead of
# special casing them here.
diff.add(id + WILDCARD_TAG)

I've created #14547 to discuss this more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in #14550 and #14575.

[out]
main:2: note: Revealed type is "contextlib.GeneratorContextManager[None]"
==
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/attr.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ class complex:
class str: pass
class ellipsis: pass
class tuple: pass
class dict: pass
1 change: 1 addition & 0 deletions test-data/unit/fixtures/bool.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ class bool(int): pass
class float: pass
class str: pass
class ellipsis: pass
class dict: pass
class list(Generic[T]): pass
class property: pass
1 change: 1 addition & 0 deletions test-data/unit/fixtures/callable.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ class str:
def __eq__(self, other: 'str') -> bool: pass
class ellipsis: pass
class list: ...
class dict: pass
1 change: 1 addition & 0 deletions test-data/unit/fixtures/classmethod.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@ class str: pass
class bytes: pass
class bool: pass
class ellipsis: pass
class dict: pass

class tuple(typing.Generic[_T]): pass
1 change: 1 addition & 0 deletions test-data/unit/fixtures/exception.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class int: pass
class str: pass
class bool: pass
class ellipsis: pass
class dict: pass

class BaseException:
def __init__(self, *args: object) -> None: ...
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/list.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,6 @@ class float:
class str:
def __len__(self) -> bool: pass
class bool(int): pass
class dict: pass

property = object() # Dummy definition.
1 change: 1 addition & 0 deletions test-data/unit/fixtures/tuple.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class bool(int): pass
class str: pass # For convenience
class bytes: pass
class bytearray: pass
class dict: pass

class list(Sequence[T], Generic[T]):
@overload
Expand Down
3 changes: 3 additions & 0 deletions test-data/unit/lib-stub/attr/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,6 @@ def field(
order: Optional[bool] = ...,
on_setattr: Optional[object] = ...,
) -> Any: ...

def evolve(inst: _T, **changes: Any) -> _T: ...
def assoc(inst: _T, **changes: Any) -> _T: ...
3 changes: 3 additions & 0 deletions test-data/unit/lib-stub/attrs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,6 @@ def field(
order: Optional[bool] = ...,
on_setattr: Optional[object] = ...,
) -> Any: ...

def evolve(inst: _T, **changes: Any) -> _T: ...
def assoc(inst: _T, **changes: Any) -> _T: ...
96 changes: 48 additions & 48 deletions test-data/unit/merge.test
Original file line number Diff line number Diff line change
Expand Up @@ -669,18 +669,18 @@ TypeInfo<2>(
Mro(target.N<2>, builtins.tuple<3>, typing.Sequence<4>, typing.Iterable<5>, builtins.object<1>)
Names(
_NT<6>
__annotations__<7> (builtins.object<1>)
__doc__<8> (builtins.str<9>)
__match_args__<10> (Tuple[Literal['x']])
__new__<11>
_asdict<12>
_field_defaults<13> (builtins.object<1>)
_field_types<14> (builtins.object<1>)
_fields<15> (Tuple[builtins.str<9>])
_make<16>
_replace<17>
_source<18> (builtins.str<9>)
x<19> (target.A<0>)))
__annotations__<7> (builtins.dict[builtins.str<8>, Any]<9>)
__doc__<10> (builtins.str<8>)
__match_args__<11> (Tuple[Literal['x']])
__new__<12>
_asdict<13>
_field_defaults<14> (builtins.dict[builtins.str<8>, Any]<9>)
_field_types<15> (builtins.dict[builtins.str<8>, Any]<9>)
_fields<16> (Tuple[builtins.str<8>])
_make<17>
_replace<18>
_source<19> (builtins.str<8>)
x<20> (target.A<0>)))
==>
TypeInfo<0>(
Name(target.A)
Expand All @@ -693,19 +693,19 @@ TypeInfo<2>(
Mro(target.N<2>, builtins.tuple<3>, typing.Sequence<4>, typing.Iterable<5>, builtins.object<1>)
Names(
_NT<6>
__annotations__<7> (builtins.object<1>)
__doc__<8> (builtins.str<9>)
__match_args__<10> (Tuple[Literal['x'], Literal['y']])
__new__<11>
_asdict<12>
_field_defaults<13> (builtins.object<1>)
_field_types<14> (builtins.object<1>)
_fields<15> (Tuple[builtins.str<9>, builtins.str<9>])
_make<16>
_replace<17>
_source<18> (builtins.str<9>)
x<19> (target.A<0>)
y<20> (target.A<0>)))
__annotations__<7> (builtins.dict[builtins.str<8>, Any]<9>)
__doc__<10> (builtins.str<8>)
__match_args__<11> (Tuple[Literal['x'], Literal['y']])
__new__<12>
_asdict<13>
_field_defaults<14> (builtins.dict[builtins.str<8>, Any]<9>)
_field_types<15> (builtins.dict[builtins.str<8>, Any]<9>)
_fields<16> (Tuple[builtins.str<8>, builtins.str<8>])
_make<17>
_replace<18>
_source<19> (builtins.str<8>)
x<20> (target.A<0>)
y<21> (target.A<0>)))

[case testNamedTupleOldVersion_typeinfo]
import target
Expand All @@ -730,17 +730,17 @@ TypeInfo<2>(
Mro(target.N<2>, builtins.tuple<3>, typing.Sequence<4>, typing.Iterable<5>, builtins.object<1>)
Names(
_NT<6>
__annotations__<7> (builtins.object<1>)
__doc__<8> (builtins.str<9>)
__new__<10>
_asdict<11>
_field_defaults<12> (builtins.object<1>)
_field_types<13> (builtins.object<1>)
_fields<14> (Tuple[builtins.str<9>])
_make<15>
_replace<16>
_source<17> (builtins.str<9>)
x<18> (target.A<0>)))
__annotations__<7> (builtins.dict[builtins.str<8>, Any]<9>)
__doc__<10> (builtins.str<8>)
__new__<11>
_asdict<12>
_field_defaults<13> (builtins.dict[builtins.str<8>, Any]<9>)
_field_types<14> (builtins.dict[builtins.str<8>, Any]<9>)
_fields<15> (Tuple[builtins.str<8>])
_make<16>
_replace<17>
_source<18> (builtins.str<8>)
x<19> (target.A<0>)))
==>
TypeInfo<0>(
Name(target.A)
Expand All @@ -753,18 +753,18 @@ TypeInfo<2>(
Mro(target.N<2>, builtins.tuple<3>, typing.Sequence<4>, typing.Iterable<5>, builtins.object<1>)
Names(
_NT<6>
__annotations__<7> (builtins.object<1>)
__doc__<8> (builtins.str<9>)
__new__<10>
_asdict<11>
_field_defaults<12> (builtins.object<1>)
_field_types<13> (builtins.object<1>)
_fields<14> (Tuple[builtins.str<9>, builtins.str<9>])
_make<15>
_replace<16>
_source<17> (builtins.str<9>)
x<18> (target.A<0>)
y<19> (target.A<0>)))
__annotations__<7> (builtins.dict[builtins.str<8>, Any]<9>)
__doc__<10> (builtins.str<8>)
__new__<11>
_asdict<12>
_field_defaults<13> (builtins.dict[builtins.str<8>, Any]<9>)
_field_types<14> (builtins.dict[builtins.str<8>, Any]<9>)
_fields<15> (Tuple[builtins.str<8>, builtins.str<8>])
_make<16>
_replace<17>
_source<18> (builtins.str<8>)
x<19> (target.A<0>)
y<20> (target.A<0>)))

[case testUnionType_types]
import target
Expand Down