Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some cleanup in partial plugin #17423

Merged
merged 2 commits into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,8 @@ def apply_function_plugin(
formal_arg_exprs[formal].append(args[actual])
if arg_names:
formal_arg_names[formal].append(arg_names[actual])
else:
formal_arg_names[formal].append(None)
formal_arg_kinds[formal].append(arg_kinds[actual])

if object_type is None:
Expand Down
31 changes: 24 additions & 7 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Type,
TypeOfAny,
UnboundType,
UninhabitedType,
get_proper_type,
)

Expand Down Expand Up @@ -132,6 +131,9 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
if fn_type is None:
return ctx.default_return_type

# We must normalize from the start to have coherent view together with TypeChecker.
fn_type = fn_type.with_unpacked_kwargs().with_normalized_var_args()

defaulted = fn_type.copy_modified(
arg_kinds=[
(
Expand All @@ -146,10 +148,25 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
# Make up a line number if we don't have one
defaulted.set_line(ctx.default_return_type)

actual_args = [a for param in ctx.args[1:] for a in param]
actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param]
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param]
actual_types = [a for param in ctx.arg_types[1:] for a in param]
# Flatten actual to formal mapping, since this is what check_call() expects.
actual_args = []
actual_arg_kinds = []
actual_arg_names = []
actual_types = []
seen_args = set()
for i, param in enumerate(ctx.args[1:], start=1):
for j, a in enumerate(param):
if a in seen_args:
# Same actual arg can map to multiple formals, but we need to include
# each one only once.
continue
# Here we rely on the fact that expressions are essentially immutable, so
# they can be compared by identity.
seen_args.add(a)
actual_args.append(a)
actual_arg_kinds.append(ctx.arg_kinds[i][j])
actual_arg_names.append(ctx.arg_names[i][j])
actual_types.append(ctx.arg_types[i][j])

# Create a valid context for various ad-hoc inspections in check_call().
call_expr = CallExpr(
Expand Down Expand Up @@ -188,7 +205,7 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
for i, actuals in enumerate(formal_to_actual):
if len(bound.arg_types) == len(fn_type.arg_types):
arg_type = bound.arg_types[i]
if isinstance(get_proper_type(arg_type), UninhabitedType):
if not mypy.checker.is_valid_inferred_type(arg_type):
arg_type = fn_type.arg_types[i] # bit of a hack
else:
# TODO: I assume that bound and fn_type have the same arguments. It appears this isn't
Expand All @@ -210,7 +227,7 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
partial_names.append(fn_type.arg_names[i])

ret_type = bound.ret_type
if isinstance(get_proper_type(ret_type), UninhabitedType):
if not mypy.checker.is_valid_inferred_type(ret_type):
ret_type = fn_type.ret_type # same kind of hack as above

partially_applied = fn_type.copy_modified(
Expand Down
52 changes: 52 additions & 0 deletions test-data/unit/check-functools.test
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,55 @@ def foo(cls3: Type[B[T]]):
reveal_type(functools.partial(cls3, 2)()) # N: Revealed type is "__main__.B[T`-1]" \
# E: Argument 1 to "B" has incompatible type "int"; expected "T"
[builtins fixtures/tuple.pyi]

[case testFunctoolsPartialTypedDictUnpack]
from typing_extensions import TypedDict, Unpack
from functools import partial

class Data(TypedDict, total=False):
x: int

def f(**kwargs: Unpack[Data]) -> None: ...
def g(**kwargs: Unpack[Data]) -> None:
partial(f, **kwargs)()

class MoreData(TypedDict, total=False):
x: int
y: int

def f_more(**kwargs: Unpack[MoreData]) -> None: ...
def g_more(**kwargs: Unpack[MoreData]) -> None:
partial(f_more, **kwargs)()

class Good(TypedDict, total=False):
y: int
class Bad(TypedDict, total=False):
y: str

def h(**kwargs: Unpack[Data]) -> None:
bad: Bad
partial(f_more, **kwargs)(**bad) # E: Argument "y" to "f_more" has incompatible type "str"; expected "int"
good: Good
partial(f_more, **kwargs)(**good)
[builtins fixtures/dict.pyi]

[case testFunctoolsPartialNestedGeneric]
from functools import partial
from typing import Generic, TypeVar, List

T = TypeVar("T")
def get(n: int, args: List[T]) -> T: ...
first = partial(get, 0)

x: List[str]
reveal_type(first(x)) # N: Revealed type is "builtins.str"
reveal_type(first([1])) # N: Revealed type is "builtins.int"

first_kw = partial(get, n=0)
reveal_type(first_kw(args=[1])) # N: Revealed type is "builtins.int"

# TODO: this is indeed invalid, but the error is incomprehensible.
first_kw([1]) # E: Too many positional arguments for "get" \
# E: Too few arguments for "get" \
# E: Argument 1 to "get" has incompatible type "List[int]"; expected "int"
[builtins fixtures/list.pyi]
Loading