Skip to content

Commit

Permalink
Fix crash with PartialTypes and the enum plugin (#14021)
Browse files Browse the repository at this point in the history
Fixes #12109.

The original issue reported that the bug had to do with the use of the
`--follow-imports=skip` flag. However, it turned out this was a red
herring after closer inspection: I was able to trigger a more minimal
repro both with and without this flag:

```python
from enum import Enum

class Foo(Enum):
    a = []  # E: Need type annotation for "a" (hint: "a: List[<type>] = ...")
    b = None

    def check(self) -> None:
        reveal_type(Foo.a.value)  # N: Revealed type is "<partial list[?]>"
        reveal_type(Foo.b.value)  # N: Revealed type is "<partial None>"
```

The first two `reveal_types` demonstrate the crux of the bug: the enum
plugin does not correctly handle and convert partial types into regular
types when inferring the type of the `.value` field.

This can then cause any number of downstream problems. For example,
suppose we modify `def check(...)` so it runs `reveal_type(self.value)`.
Doing this will trigger a crash in mypy because it makes the enum plugin
eventually try running `is_equivalent(...)` on the two partial types.
But `is_equivalent` does not support partial types, so we crash.

I opted to solve this problem by:

1. Making the enum plugin explicitly call the `fixup_partial_types`
function on all field types. This prevents the code from crashing.

2. Modifies mypy so that Final vars are never marked as being
PartialTypes. Without this, `reveal_type(Foo.b.value)` would report a
type of `Union[Any, None]` instead of just `None`. (Note that all enum
fields are implicitly final).
  • Loading branch information
Michael0x2a authored Nov 7, 2022
1 parent e8de6d1 commit d2a3e66
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 28 deletions.
50 changes: 24 additions & 26 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
erase_to_bound,
erase_to_union_or_bound,
false_only,
fixup_partial_type,
function_type,
get_type_vars,
is_literal_type_like,
Expand Down Expand Up @@ -2738,8 +2739,8 @@ def check_assignment(
# None initializers preserve the partial None type.
return

if is_valid_inferred_type(rvalue_type):
var = lvalue_type.var
var = lvalue_type.var
if is_valid_inferred_type(rvalue_type, is_lvalue_final=var.is_final):
partial_types = self.find_partial_types(var)
if partial_types is not None:
if not self.current_node_deferred:
Expand Down Expand Up @@ -3687,7 +3688,10 @@ def infer_variable_type(
"""Infer the type of initialized variables from initializer type."""
if isinstance(init_type, DeletedType):
self.msg.deleted_as_rvalue(init_type, context)
elif not is_valid_inferred_type(init_type) and not self.no_partial_types:
elif (
not is_valid_inferred_type(init_type, is_lvalue_final=name.is_final)
and not self.no_partial_types
):
# We cannot use the type of the initialization expression for full type
# inference (it's not specific enough), but we might be able to give
# partial type which will be made more specific later. A partial type
Expand Down Expand Up @@ -6114,7 +6118,7 @@ def enter_partial_types(
self.msg.need_annotation_for_var(var, context, self.options.python_version)
self.partial_reported.add(var)
if var.type:
fixed = self.fixup_partial_type(var.type)
fixed = fixup_partial_type(var.type)
var.invalid_partial_type = fixed != var.type
var.type = fixed

Expand Down Expand Up @@ -6145,20 +6149,7 @@ def handle_partial_var_type(
else:
# Defer the node -- we might get a better type in the outer scope
self.handle_cannot_determine_type(node.name, context)
return self.fixup_partial_type(typ)

def fixup_partial_type(self, typ: Type) -> Type:
"""Convert a partial type that we couldn't resolve into something concrete.
This means, for None we make it Optional[Any], and for anything else we
fill in all of the type arguments with Any.
"""
if not isinstance(typ, PartialType):
return typ
if typ.type is None:
return UnionType.make_union([AnyType(TypeOfAny.unannotated), NoneType()])
else:
return Instance(typ.type, [AnyType(TypeOfAny.unannotated)] * len(typ.type.type_vars))
return fixup_partial_type(typ)

def is_defined_in_base_class(self, var: Var) -> bool:
if var.info:
Expand Down Expand Up @@ -7006,20 +6997,27 @@ def infer_operator_assignment_method(typ: Type, operator: str) -> tuple[bool, st
return False, method


def is_valid_inferred_type(typ: Type) -> bool:
"""Is an inferred type valid?
def is_valid_inferred_type(typ: Type, is_lvalue_final: bool = False) -> bool:
"""Is an inferred type valid and needs no further refinement?
Examples of invalid types include the None type or List[<uninhabited>].
Examples of invalid types include the None type (when we are not assigning
None to a final lvalue) or List[<uninhabited>].
When not doing strict Optional checking, all types containing None are
invalid. When doing strict Optional checking, only None and types that are
incompletely defined (i.e. contain UninhabitedType) are invalid.
"""
if isinstance(get_proper_type(typ), (NoneType, UninhabitedType)):
# With strict Optional checking, we *may* eventually infer NoneType when
# the initializer is None, but we only do that if we can't infer a
# specific Optional type. This resolution happens in
# leave_partial_types when we pop a partial types scope.
proper_type = get_proper_type(typ)
if isinstance(proper_type, NoneType):
# If the lvalue is final, we may immediately infer NoneType when the
# initializer is None.
#
# If not, we want to defer making this decision. The final inferred
# type could either be NoneType or an Optional type, depending on
# the context. This resolution happens in leave_partial_types when
# we pop a partial types scope.
return is_lvalue_final
elif isinstance(proper_type, UninhabitedType):
return False
return not typ.accept(NothingSeeker())

Expand Down
3 changes: 2 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
custom_special_method,
erase_to_union_or_bound,
false_only,
fixup_partial_type,
function_type,
is_literal_type_like,
make_simplified_union,
Expand Down Expand Up @@ -2925,7 +2926,7 @@ def find_partial_type_ref_fast_path(self, expr: Expression) -> Type | None:
if isinstance(expr.node, Var):
result = self.analyze_var_ref(expr.node, expr)
if isinstance(result, PartialType) and result.type is not None:
self.chk.store_type(expr, self.chk.fixup_partial_type(result))
self.chk.store_type(expr, fixup_partial_type(result))
return result
return None

Expand Down
3 changes: 2 additions & 1 deletion mypy/plugins/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mypy.nodes import TypeInfo
from mypy.semanal_enum import ENUM_BASES
from mypy.subtypes import is_equivalent
from mypy.typeops import make_simplified_union
from mypy.typeops import fixup_partial_type, make_simplified_union
from mypy.types import CallableType, Instance, LiteralType, ProperType, Type, get_proper_type

ENUM_NAME_ACCESS: Final = {f"{prefix}.name" for prefix in ENUM_BASES} | {
Expand Down Expand Up @@ -77,6 +77,7 @@ def _infer_value_type_with_auto_fallback(
"""
if proper_type is None:
return None
proper_type = get_proper_type(fixup_partial_type(proper_type))
if not (isinstance(proper_type, Instance) and proper_type.type.fullname == "enum.auto"):
return proper_type
assert isinstance(ctx.type, Instance), "An incorrect ctx.type was passed."
Expand Down
15 changes: 15 additions & 0 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
Overloaded,
Parameters,
ParamSpecType,
PartialType,
ProperType,
TupleType,
Type,
Expand Down Expand Up @@ -1016,3 +1017,17 @@ def try_getting_instance_fallback(typ: Type) -> Instance | None:
elif isinstance(typ, TypeVarType):
return try_getting_instance_fallback(typ.upper_bound)
return None


def fixup_partial_type(typ: Type) -> Type:
"""Convert a partial type that we couldn't resolve into something concrete.
This means, for None we make it Optional[Any], and for anything else we
fill in all of the type arguments with Any.
"""
if not isinstance(typ, PartialType):
return typ
if typ.type is None:
return UnionType.make_union([AnyType(TypeOfAny.unannotated), NoneType()])
else:
return Instance(typ.type, [AnyType(TypeOfAny.unannotated)] * len(typ.type.type_vars))
27 changes: 27 additions & 0 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -2100,3 +2100,30 @@ class Some:
class A(Some, Enum):
__labels__ = {1: "1"}
[builtins fixtures/dict.pyi]

[case testEnumWithPartialTypes]
from enum import Enum

class Mixed(Enum):
a = [] # E: Need type annotation for "a" (hint: "a: List[<type>] = ...")
b = None

def check(self) -> None:
reveal_type(Mixed.a.value) # N: Revealed type is "builtins.list[Any]"
reveal_type(Mixed.b.value) # N: Revealed type is "None"

# Inferring Any here instead of a union seems to be a deliberate
# choice; see the testEnumValueInhomogenous case above.
reveal_type(self.value) # N: Revealed type is "Any"

for field in Mixed:
reveal_type(field.value) # N: Revealed type is "Any"
if field.value is None:
pass

class AllPartialList(Enum):
a = [] # E: Need type annotation for "a" (hint: "a: List[<type>] = ...")
b = [] # E: Need type annotation for "b" (hint: "b: List[<type>] = ...")

def check(self) -> None:
reveal_type(self.value) # N: Revealed type is "builtins.list[Any]"

0 comments on commit d2a3e66

Please sign in to comment.