Skip to content

Commit

Permalink
Improve usage of outer context for inference (#5699)
Browse files Browse the repository at this point in the history
Fixes #4872 
Fixes #3876
Fixes #2678 
Fixes #5199 
Fixes #5493 
(It also fixes a bunch of similar issues previously closed as duplicates, except one, see below).

This PR fixes a problems when mypy commits to soon to using outer context for type inference. This is done by:
* Postponing inference to inner (argument) context in situations where type inferred from outer (return) context doesn't satisfy bounds or constraints.
* Adding a special case for situation where optional return is inferred against optional context. In such situation, unwrapping the optional is a better idea in 99% of cases. (Note: this doesn't affect type safety, only gives empirically more reasonable inferred types.)

In general, instead of adding a special case, it would be better to use inner and outer context at the same time, but this a big change (see comment in code), and using the simple special case fixes majority of issues. Among reported issues, only #5311 will stay unfixed.
  • Loading branch information
ilevkivskyi authored Oct 2, 2018
1 parent baa4725 commit 626ff68
Show file tree
Hide file tree
Showing 8 changed files with 393 additions and 47 deletions.
38 changes: 30 additions & 8 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@


def apply_generic_arguments(callable: CallableType, orig_types: Sequence[Optional[Type]],
msg: MessageBuilder, context: Context) -> CallableType:
msg: MessageBuilder, context: Context,
skip_unsatisfied: bool = False) -> CallableType:
"""Apply generic type arguments to a callable type.
For example, applying [int] to 'def [T] (T) -> T' results in
'def (int) -> int'.
Note that each type can be None; in this case, it will not be applied.
If `skip_unsatisfied` is True, then just skip the types that don't satisfy type variable
bound or constraints, instead of giving an error.
"""
tvars = callable.variables
assert len(tvars) == len(orig_types)
Expand All @@ -25,7 +29,9 @@ def apply_generic_arguments(callable: CallableType, orig_types: Sequence[Optiona
for i, type in enumerate(types):
assert not isinstance(type, PartialType), "Internal error: must never apply partial type"
values = callable.variables[i].values
if values and type:
if type is None:
continue
if values:
if isinstance(type, AnyType):
continue
if isinstance(type, TypeVarType) and type.values:
Expand All @@ -34,15 +40,31 @@ def apply_generic_arguments(callable: CallableType, orig_types: Sequence[Optiona
if all(any(is_same_type(v, v1) for v in values)
for v1 in type.values):
continue
matching = []
for value in values:
if mypy.subtypes.is_subtype(type, value):
types[i] = value
break
matching.append(value)
if matching:
best = matching[0]
# If there are more than one matching value, we select the narrowest
for match in matching[1:]:
if mypy.subtypes.is_subtype(match, best):
best = match
types[i] = best
else:
msg.incompatible_typevar_value(callable, type, callable.variables[i].name, context)
upper_bound = callable.variables[i].upper_bound
if type and not mypy.subtypes.is_subtype(type, upper_bound):
msg.incompatible_typevar_value(callable, type, callable.variables[i].name, context)
if skip_unsatisfied:
types[i] = None
else:
msg.incompatible_typevar_value(callable, type, callable.variables[i].name,
context)
else:
upper_bound = callable.variables[i].upper_bound
if not mypy.subtypes.is_subtype(type, upper_bound):
if skip_unsatisfied:
types[i] = None
else:
msg.incompatible_typevar_value(callable, type, callable.variables[i].name,
context)

# Create a map from type variable id to target type.
id_to_type = {} # type: Dict[TypeVarId, Type]
Expand Down
14 changes: 2 additions & 12 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
Type, AnyType, CallableType, FunctionLike, Overloaded, TupleType, TypedDictType,
Instance, NoneTyp, strip_type, TypeType, TypeOfAny,
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef,
true_only, false_only, function_type, is_named_instance, union_items, TypeQuery
true_only, false_only, function_type, is_named_instance, union_items, TypeQuery,
is_optional, remove_optional
)
from mypy.sametypes import is_same_type, is_same_types
from mypy.messages import MessageBuilder, make_inferred_type_note
Expand Down Expand Up @@ -3792,17 +3793,6 @@ def is_literal_none(n: Expression) -> bool:
return isinstance(n, NameExpr) and n.fullname == 'builtins.None'


def is_optional(t: Type) -> bool:
return isinstance(t, UnionType) and any(isinstance(e, NoneTyp) for e in t.items)


def remove_optional(typ: Type) -> Type:
if isinstance(typ, UnionType):
return UnionType.make_union([t for t in typ.items if not isinstance(t, NoneTyp)])
else:
return typ


def is_literal_not_implemented(n: Expression) -> bool:
return isinstance(n, NameExpr) and n.fullname == 'builtins.NotImplemented'

Expand Down
61 changes: 41 additions & 20 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
from mypy.types import (
Type, AnyType, CallableType, Overloaded, NoneTyp, TypeVarDef,
TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, true_only,
false_only, is_named_instance, function_type, callable_type, FunctionLike, StarType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny,
true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike,
StarType, is_optional, remove_optional, is_invariant_instance
)
from mypy.nodes import (
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
Expand All @@ -30,7 +31,7 @@
ConditionalExpr, ComparisonExpr, TempNode, SetComprehension,
DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr,
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, ClassDef, Block, SymbolNode,
TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode,
ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, LITERAL_TYPE, REVEAL_TYPE
)
from mypy.literals import literal
Expand Down Expand Up @@ -819,20 +820,36 @@ def infer_function_type_arguments_using_context(
# valid results.
erased_ctx = replace_meta_vars(ctx, ErasedType())
ret_type = callable.ret_type
if isinstance(ret_type, TypeVarType):
if ret_type.values or (not isinstance(ctx, Instance) or
not ctx.args):
# The return type is a type variable. If it has values, we can't easily restrict
# type inference to conform to the valid values. If it's unrestricted, we could
# infer a too general type for the type variable if we use context, and this could
# result in confusing and spurious type errors elsewhere.
#
# Give up and just use function arguments for type inference. As an exception,
# if the context is a generic instance type, actually use it as context, as
# this *seems* to usually be the reasonable thing to do.
#
# See also github issues #462 and #360.
ret_type = NoneTyp()
if is_optional(ret_type) and is_optional(ctx):
# If both the context and the return type are optional, unwrap the optional,
# since in 99% cases this is what a user expects. In other words, we replace
# Optional[T] <: Optional[int]
# with
# T <: int
# while the former would infer T <: Optional[int].
ret_type = remove_optional(ret_type)
erased_ctx = remove_optional(erased_ctx)
#
# TODO: Instead of this hack and the one below, we need to use outer and
# inner contexts at the same time. This is however not easy because of two
# reasons:
# * We need to support constraints like [1 <: 2, 2 <: X], i.e. with variables
# on both sides. (This is not too hard.)
# * We need to update all the inference "infrastructure", so that all
# variables in an expression are inferred at the same time.
# (And this is hard, also we need to be careful with lambdas that require
# two passes.)
if isinstance(ret_type, TypeVarType) and not is_invariant_instance(ctx):
# Another special case: the return type is a type variable. If it's unrestricted,
# we could infer a too general type for the type variable if we use context,
# and this could result in confusing and spurious type errors elsewhere.
#
# Give up and just use function arguments for type inference. As an exception,
# if the context is an invariant instance type, actually use it as context, as
# this *seems* to usually be the reasonable thing to do.
#
# See also github issues #462 and #360.
return callable.copy_modified()
args = infer_type_arguments(callable.type_var_ids(), ret_type, erased_ctx)
# Only substitute non-Uninhabited and non-erased types.
new_args = [] # type: List[Optional[Type]]
Expand All @@ -841,7 +858,10 @@ def infer_function_type_arguments_using_context(
new_args.append(None)
else:
new_args.append(arg)
return self.apply_generic_arguments(callable, new_args, error_context)
# Don't show errors after we have only used the outer context for inference.
# We will use argument context to infer more variables.
return self.apply_generic_arguments(callable, new_args, error_context,
skip_unsatisfied=True)

def infer_function_type_arguments(self, callee_type: CallableType,
args: List[Expression],
Expand Down Expand Up @@ -1609,9 +1629,10 @@ def check_arg(caller_type: Type, original_caller_type: Type, caller_kind: int,
return False

def apply_generic_arguments(self, callable: CallableType, types: Sequence[Optional[Type]],
context: Context) -> CallableType:
context: Context, skip_unsatisfied: bool = False) -> CallableType:
"""Simple wrapper around mypy.applytype.apply_generic_arguments."""
return applytype.apply_generic_arguments(callable, types, self.msg, context)
return applytype.apply_generic_arguments(callable, types, self.msg, context,
skip_unsatisfied=skip_unsatisfied)

def visit_member_expr(self, e: MemberExpr, is_lvalue: bool = False) -> Type:
"""Visit member expression (of form e.id)."""
Expand Down
17 changes: 17 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1918,6 +1918,23 @@ def union_items(typ: Type) -> List[Type]:
return [typ]


def is_invariant_instance(tp: Type) -> bool:
if not isinstance(tp, Instance) or not tp.args:
return False
return any(v.variance == INVARIANT for v in tp.type.defn.type_vars)


def is_optional(t: Type) -> bool:
return isinstance(t, UnionType) and any(isinstance(e, NoneTyp) for e in t.items)


def remove_optional(typ: Type) -> Type:
if isinstance(typ, UnionType):
return UnionType.make_union([t for t in typ.items if not isinstance(t, NoneTyp)])
else:
return typ


names = globals().copy() # type: Final
names.pop('NOT_READY', None)
deserialize_map = {
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ def fun2(v: Vec[T], scale: T) -> Vec[T]:
return v

reveal_type(fun1([(1, 1)])) # E: Revealed type is 'builtins.int*'
fun1(1) # E: Argument 1 to "fun1" has incompatible type "int"; expected "List[Tuple[int, int]]"
fun1(1) # E: Argument 1 to "fun1" has incompatible type "int"; expected "List[Tuple[bool, bool]]"
fun1([(1, 'x')]) # E: Cannot infer type argument 1 of "fun1"

reveal_type(fun2([(1, 1)], 1)) # E: Revealed type is 'builtins.list[Tuple[builtins.int*, builtins.int*]]'
Expand Down
Loading

0 comments on commit 626ff68

Please sign in to comment.