Skip to content

Commit

Permalink
Implementing background infrastructure for recursive types: Part 1 (#…
Browse files Browse the repository at this point in the history
…7330)

During planning discussions one of the main concerns about recursive types was the fact that we have hundreds of places where certain types are special-cased using `isinstance()`, and fixing all of them will take weeks.

So I did a little experiment this weekend, to understand how bad it _actually_ is. I wrote a simple mypy plugin for mypy self-check, and it discovered 800+ such call sites. This looks pretty bad, but it turns out that fixing half of them (roughly 400 plugin errors) took me less than 2 days. This is kind of a triumph of our tooling :-) (i.e. mypy plugin + PyCharm plugin).

Taking into account the results of this experiment, I propose to actually go ahead and implement recursive types. Here are some comments:

* There will be four subsequent PRs: second part of `isinstance()` cleanup, implementing visitors and related methods everywhere, actual core implementation, adding extra tests for tricky recursion patterns.
* The core idea of implementation stays the same as we discussed with @JukkaL: `TypeAliasType` and `TypeAlias` node will essentially match logic between `Instance` and `TypeInfo` (but structurally, as for protocols)
* I wanted to make `PlaceholderType` a non-`ProperType`, but it didn't work immediately because we call `make_union()` during semantic analysis. If this seems important, this can be done with a bit more effort.
* I make `TypeType.item` a proper type (following PEP 484, only very limited things can be passed to `Type[...]`). I also make `UnionType.items` proper types, mostly because of `make_simplified_union()`. Finally, I make `FuncBase.type` a proper type, because I think a type alias can never appear there.
* It is sometimes hard to decide exactly where to call `get_proper_type()`; I tried to balance calling it neither too soon nor too late, depending on each individual case. Please review, I am open to modifying the logic in some places.
  • Loading branch information
ilevkivskyi authored Aug 16, 2019
1 parent 7fb7e26 commit e04bf78
Show file tree
Hide file tree
Showing 43 changed files with 659 additions and 320 deletions.
70 changes: 70 additions & 0 deletions misc/proper_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from mypy.plugin import Plugin, FunctionContext
from mypy.types import Type, Instance, CallableType, UnionType, get_proper_type

import os.path
from typing_extensions import Type as typing_Type
from typing import Optional, Callable

# Base names of source files in which isinstance() calls are not checked
# by this plugin (the hook returns early for them).
FILE_WHITELIST = [
    'checker.py',
    'checkexpr.py',
    'checkmember.py',
    'messages.py',
    'semanal.py',
    'typeanal.py',
]


class ProperTypePlugin(Plugin):
    """Plugin that checks types are expanded before they are special-cased.

    Many call sites special-case a particular kind of type like this:

        if isinstance(typ, UnionType):
            ...  # special-case union

    Once TypeAliasType exists and aliases are no longer expanded immediately,
    such checks become dangerous: typ may be, for example, an alias to a
    union, and the isinstance() test would then miss it.
    """

    def get_function_hook(self, fullname: str
                          ) -> Optional[Callable[[FunctionContext], Type]]:
        """Return the isinstance() checking hook, or None for any other function."""
        return isinstance_proper_hook if fullname == 'builtins.isinstance' else None


def isinstance_proper_hook(ctx: FunctionContext) -> Type:
    """Report isinstance() calls applied to a possibly unexpanded type.

    The hook never changes the inferred type of the call: it always returns
    ctx.default_return_type and only reports an error through ctx.api.fail()
    when the first argument is an improper type (see is_improper_type) and
    the second argument is not one of the special-cased safe targets.

    Calls located in files listed in FILE_WHITELIST are not checked.
    """
    if os.path.split(ctx.api.path)[-1] in FILE_WHITELIST:
        return ctx.default_return_type
    # A malformed call (e.g. isinstance(x) with the class argument missing)
    # leaves the second argument slot absent or empty. There is nothing to
    # inspect then, so bail out instead of raising IndexError below.
    if len(ctx.arg_types) < 2 or not ctx.arg_types[1]:
        return ctx.default_return_type
    for arg in ctx.arg_types[0]:
        if is_improper_type(arg):
            right = get_proper_type(ctx.arg_types[1][0])
            if isinstance(right, CallableType) and right.is_type_obj():
                if right.type_object().fullname() in ('mypy.types.Type',
                                                      'mypy.types.ProperType',
                                                      'mypy.types.TypeAliasType'):
                    # Special case: things like assert isinstance(typ, ProperType) are always OK.
                    return ctx.default_return_type
                if right.type_object().fullname() in ('mypy.types.UnboundType',
                                                      'mypy.types.TypeVarType'):
                    # Special case: these are not valid targets for a type alias and thus safe.
                    return ctx.default_return_type
            ctx.api.fail('Never apply isinstance() to unexpanded types;'
                         ' use mypy.types.get_proper_type() first', ctx.context)
    return ctx.default_return_type


def is_improper_type(typ: Type) -> bool:
    """Tell whether typ denotes a mypy type class that is not a ProperType.

    The argument is first passed through get_proper_type(). A union is
    improper if any of its items is; an instance is improper if its class
    derives from mypy.types.Type but not from mypy.types.ProperType.
    Anything else is not considered improper.
    """
    expanded = get_proper_type(typ)
    if isinstance(expanded, UnionType):
        # One improper member is enough to make the whole union unsafe.
        return any(is_improper_type(item) for item in expanded.items)
    if not isinstance(expanded, Instance):
        return False
    type_info = expanded.type
    return (type_info.has_base('mypy.types.Type')
            and not type_info.has_base('mypy.types.ProperType'))


def plugin(version: str) -> typing_Type[ProperTypePlugin]:
    """Plugin entry point: return the plugin class.

    The version argument is accepted but not used.
    """
    return ProperTypePlugin
8 changes: 5 additions & 3 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import mypy.subtypes
import mypy.sametypes
from mypy.expandtype import expand_type
from mypy.types import Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType
from mypy.types import (
Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType, get_proper_types
)
from mypy.messages import MessageBuilder
from mypy.nodes import Context

Expand All @@ -25,10 +27,10 @@ def apply_generic_arguments(callable: CallableType, orig_types: Sequence[Optiona
assert len(tvars) == len(orig_types)
# Check that inferred type variable values are compatible with allowed
# values and bounds. Also, promote subtype values to allowed values.
types = list(orig_types)
types = get_proper_types(orig_types)
for i, type in enumerate(types):
assert not isinstance(type, PartialType), "Internal error: must never apply partial type"
values = callable.variables[i].values
values = get_proper_types(callable.variables[i].values)
if type is None:
continue
if values:
Expand Down
9 changes: 6 additions & 3 deletions mypy/argmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from typing import List, Optional, Sequence, Callable, Set

from mypy.types import Type, Instance, TupleType, AnyType, TypeOfAny, TypedDictType
from mypy.types import (
Type, Instance, TupleType, AnyType, TypeOfAny, TypedDictType, get_proper_type
)
from mypy import nodes


Expand Down Expand Up @@ -34,7 +36,7 @@ def map_actuals_to_formals(actual_kinds: List[int],
formal_to_actual[fi].append(ai)
elif actual_kind == nodes.ARG_STAR:
# We need to know the actual type to map varargs.
actualt = actual_arg_type(ai)
actualt = get_proper_type(actual_arg_type(ai))
if isinstance(actualt, TupleType):
# A tuple actual maps to a fixed number of formals.
for _ in range(len(actualt.items)):
Expand Down Expand Up @@ -65,7 +67,7 @@ def map_actuals_to_formals(actual_kinds: List[int],
formal_to_actual[formal_kinds.index(nodes.ARG_STAR2)].append(ai)
else:
assert actual_kind == nodes.ARG_STAR2
actualt = actual_arg_type(ai)
actualt = get_proper_type(actual_arg_type(ai))
if isinstance(actualt, TypedDictType):
for name, value in actualt.items.items():
if name in formal_names:
Expand Down Expand Up @@ -153,6 +155,7 @@ def expand_actual_type(self,
This is supposed to be called for each formal, in order. Call multiple times per
formal if multiple actuals map to a formal.
"""
actual_type = get_proper_type(actual_type)
if actual_kind == nodes.ARG_STAR:
if isinstance(actual_type, Instance):
if actual_type.type.fullname() == 'builtins.list':
Expand Down
17 changes: 12 additions & 5 deletions mypy/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from typing import Dict, List, Set, Iterator, Union, Optional, Tuple, cast
from typing_extensions import DefaultDict

from mypy.types import Type, AnyType, PartialType, UnionType, TypeOfAny, NoneType
from mypy.types import (
Type, AnyType, PartialType, UnionType, TypeOfAny, NoneType, get_proper_type
)
from mypy.subtypes import is_subtype
from mypy.join import join_simple
from mypy.sametypes import is_same_type
Expand Down Expand Up @@ -191,7 +193,7 @@ def update_from_options(self, frames: List[Frame]) -> bool:

type = resulting_values[0]
assert type is not None
declaration_type = self.declarations.get(key)
declaration_type = get_proper_type(self.declarations.get(key))
if isinstance(declaration_type, AnyType):
# At this point resulting values can't contain None, see continue above
if not all(is_same_type(type, cast(Type, t)) for t in resulting_values[1:]):
Expand Down Expand Up @@ -246,6 +248,9 @@ def assign_type(self, expr: Expression,
type: Type,
declared_type: Optional[Type],
restrict_any: bool = False) -> None:
type = get_proper_type(type)
declared_type = get_proper_type(declared_type)

if self.type_assignments is not None:
# We are in a multiassign from union, defer the actual binding,
# just collect the types.
Expand All @@ -270,7 +275,7 @@ def assign_type(self, expr: Expression,
# times?
return

enclosing_type = self.most_recent_enclosing_type(expr, type)
enclosing_type = get_proper_type(self.most_recent_enclosing_type(expr, type))
if isinstance(enclosing_type, AnyType) and not restrict_any:
# If x is Any and y is int, after x = y we do not infer that x is int.
# This could be changed.
Expand All @@ -287,7 +292,8 @@ def assign_type(self, expr: Expression,
elif (isinstance(type, AnyType)
and isinstance(declared_type, UnionType)
and any(isinstance(item, NoneType) for item in declared_type.items)
and isinstance(self.most_recent_enclosing_type(expr, NoneType()), NoneType)):
and isinstance(get_proper_type(self.most_recent_enclosing_type(expr, NoneType())),
NoneType)):
# Replace any Nones in the union type with Any
new_items = [type if isinstance(item, NoneType) else item
for item in declared_type.items]
Expand Down Expand Up @@ -320,6 +326,7 @@ def invalidate_dependencies(self, expr: BindableExpression) -> None:
self._cleanse_key(dep)

def most_recent_enclosing_type(self, expr: BindableExpression, type: Type) -> Optional[Type]:
type = get_proper_type(type)
if isinstance(type, AnyType):
return get_declaration(expr)
key = literal_hash(expr)
Expand Down Expand Up @@ -412,7 +419,7 @@ def top_frame_context(self) -> Iterator[Frame]:

def get_declaration(expr: BindableExpression) -> Optional[Type]:
if isinstance(expr, RefExpr) and isinstance(expr.node, Var):
type = expr.node.type
type = get_proper_type(expr.node.type)
if not isinstance(type, PartialType):
return type
return None
8 changes: 4 additions & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
impl_type = None # type: Optional[CallableType]
if defn.impl:
if isinstance(defn.impl, FuncDef):
inner_type = defn.impl.type
inner_type = defn.impl.type # type: Optional[Type]
elif isinstance(defn.impl, Decorator):
inner_type = defn.impl.var.type
else:
Expand Down Expand Up @@ -3650,8 +3650,8 @@ def find_isinstance_check(self, node: Expression
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
# respectively
vartype = type_map[node]
if_type = true_only(vartype)
else_type = false_only(vartype)
if_type = true_only(vartype) # type: Type
else_type = false_only(vartype) # type: Type
ref = node # type: Expression
if_map = {ref: if_type} if not isinstance(if_type, UninhabitedType) else None
else_map = {ref: else_type} if not isinstance(else_type, UninhabitedType) else None
Expand Down Expand Up @@ -4139,7 +4139,7 @@ def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap:
# expressions whose type is refined by both conditions. (We do not
# learn anything about expressions whose type is refined by only
# one condition.)
result = {}
result = {} # type: Dict[Expression, Type]
for n1 in m1:
for n2 in m2:
if literal_hash(n1) == literal_hash(n2):
Expand Down
11 changes: 6 additions & 5 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike,
StarType, is_optional, remove_optional, is_generic_instance
StarType, is_optional, remove_optional, is_generic_instance, get_proper_type
)
from mypy.nodes import (
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
Expand Down Expand Up @@ -587,6 +587,7 @@ def apply_function_plugin(self,
# Apply method plugin
method_callback = self.plugin.get_method_hook(fullname)
assert method_callback is not None # Assume that caller ensures this
object_type = get_proper_type(object_type)
return method_callback(
MethodContext(object_type, formal_arg_types, formal_arg_kinds,
callee.arg_names, formal_arg_names,
Expand All @@ -608,6 +609,7 @@ def apply_method_signature_hook(
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
formal_arg_exprs[formal].append(args[actual])
object_type = get_proper_type(object_type)
return signature_hook(
MethodSigContext(object_type, formal_arg_exprs, callee, context, self.chk))
else:
Expand Down Expand Up @@ -2710,7 +2712,7 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression)
else:
typ = self.accept(index)
if isinstance(typ, UnionType):
key_types = typ.items
key_types = list(typ.items) # type: List[Type]
else:
key_types = [typ]

Expand Down Expand Up @@ -3549,7 +3551,7 @@ def has_member(self, typ: Type, member: str) -> bool:
elif isinstance(typ, TypeType):
# Type[Union[X, ...]] is always normalized to Union[Type[X], ...],
# so we don't need to care about unions here.
item = typ.item
item = typ.item # type: Type
if isinstance(item, TypeVarType):
item = item.upper_bound
if isinstance(item, TupleType):
Expand Down Expand Up @@ -3743,8 +3745,7 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type, # noqa
not is_overlapping_types(known_type, restriction,
prohibit_none_typevar_overlap=True)):
return None
ans = narrow_declared_type(known_type, restriction)
return ans
return narrow_declared_type(known_type, restriction)
return known_type


Expand Down
11 changes: 6 additions & 5 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mypy.types import (
Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, TypeVarDef,
Overloaded, TypeVarType, UnionType, PartialType, UninhabitedType, TypeOfAny, LiteralType,
DeletedType, NoneType, TypeType, function_type, get_type_vars,
DeletedType, NoneType, TypeType, function_type, get_type_vars, get_proper_type
)
from mypy.nodes import (
TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context, MypyFile, TypeVarExpr,
Expand Down Expand Up @@ -371,8 +371,8 @@ def analyze_member_var_access(name: str,
fullname = '{}.{}'.format(method.info.fullname(), name)
hook = mx.chk.plugin.get_attribute_hook(fullname)
if hook:
result = hook(AttributeContext(mx.original_type, result,
mx.context, mx.chk))
result = hook(AttributeContext(get_proper_type(mx.original_type),
result, mx.context, mx.chk))
return result
else:
setattr_meth = info.get_method('__setattr__')
Expand Down Expand Up @@ -511,7 +511,7 @@ def analyze_var(name: str,
mx.msg.read_only_property(name, itype.type, mx.context)
if mx.is_lvalue and var.is_classvar:
mx.msg.cant_assign_to_classvar(name, mx.context)
result = t
result = t # type: Type
if var.is_initialized_in_class and isinstance(t, FunctionLike) and not t.is_type_obj():
if mx.is_lvalue:
if var.is_property:
Expand Down Expand Up @@ -552,7 +552,8 @@ def analyze_var(name: str,
result = analyze_descriptor_access(mx.original_type, result, mx.builtin_type,
mx.msg, mx.context, chk=mx.chk)
if hook:
result = hook(AttributeContext(mx.original_type, result, mx.context, mx.chk))
result = hook(AttributeContext(get_proper_type(mx.original_type),
result, mx.context, mx.chk))
return result


Expand Down
4 changes: 2 additions & 2 deletions mypy/checkstrformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Final, TYPE_CHECKING

from mypy.types import (
Type, AnyType, TupleType, Instance, UnionType, TypeOfAny
Type, AnyType, TupleType, Instance, UnionType, TypeOfAny, get_proper_type
)
from mypy.nodes import (
StrExpr, BytesExpr, UnicodeExpr, TupleExpr, DictExpr, Context, Expression, StarExpr
Expand Down Expand Up @@ -137,7 +137,7 @@ def check_simple_str_interpolation(self, specifiers: List[ConversionSpecifier],
if checkers is None:
return

rhs_type = self.accept(replacements)
rhs_type = get_proper_type(self.accept(replacements))
rep_types = [] # type: List[Type]
if isinstance(rhs_type, TupleType):
rep_types = rhs_type.items
Expand Down
13 changes: 8 additions & 5 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance,
TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType,
UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType,
ProperType, get_proper_type
)
from mypy.maptype import map_instance_to_supertype
import mypy.subtypes
Expand Down Expand Up @@ -88,6 +89,8 @@ def infer_constraints(template: Type, actual: Type,
The constraints are represented as Constraint objects.
"""
template = get_proper_type(template)
actual = get_proper_type(actual)

# If the template is simply a type variable, emit a Constraint directly.
# We need to handle this case before handling Unions for two reasons:
Expand Down Expand Up @@ -199,12 +202,12 @@ def is_same_constraint(c1: Constraint, c2: Constraint) -> bool:
and mypy.sametypes.is_same_type(c1.target, c2.target))


def simplify_away_incomplete_types(types: List[Type]) -> List[Type]:
def simplify_away_incomplete_types(types: Iterable[Type]) -> List[Type]:
complete = [typ for typ in types if is_complete_type(typ)]
if complete:
return complete
else:
return types
return list(types)


def is_complete_type(typ: Type) -> bool:
Expand All @@ -229,9 +232,9 @@ class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]):

# The type that is compared against a template
# TODO: The value may be None. Is that actually correct?
actual = None # type: Type
actual = None # type: ProperType

def __init__(self, actual: Type, direction: int) -> None:
def __init__(self, actual: ProperType, direction: int) -> None:
# Direction must be SUBTYPE_OF or SUPERTYPE_OF.
self.actual = actual
self.direction = direction
Expand Down Expand Up @@ -298,7 +301,7 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
if isinstance(actual, Instance):
instance = actual
erased = erase_typevars(template)
assert isinstance(erased, Instance)
assert isinstance(erased, Instance) # type: ignore
# We always try nominal inference if possible,
# it is much faster than the structural one.
if (self.direction == SUBTYPE_OF and
Expand Down
Loading

0 comments on commit e04bf78

Please sign in to comment.