Skip to content

Commit

Permalink
Support for hasattr() checks (#13544)
Browse files Browse the repository at this point in the history
Fixes #1424
Fixes mypyc/mypyc#939

Not that I really like `hasattr()` but this issue comes up surprisingly often. Also it looks like we can offer a simple solution that will cover 95% of use cases using `extra_attrs` for instances. Essentially the logic is like this:
* In the if branch, keep types that already has a valid attribute as is, for other inject an attribute with `Any` type using fallbacks.
* In the else branch, remove types that already have a valid attribute, while keeping the rest.
  • Loading branch information
ilevkivskyi authored Aug 29, 2022
1 parent c2949e9 commit b29051c
Show file tree
Hide file tree
Showing 12 changed files with 406 additions and 41 deletions.
100 changes: 98 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
true_only,
try_expanding_sum_type_to_union,
try_getting_int_literals_from_type,
try_getting_str_literals,
try_getting_str_literals_from_type,
tuple_fallback,
)
Expand Down Expand Up @@ -4701,7 +4702,7 @@ def _make_fake_typeinfo_and_full_name(
return None

curr_module.names[full_name] = SymbolTableNode(GDEF, info)
return Instance(info, [])
return Instance(info, [], extra_attrs=instances[0].extra_attrs or instances[1].extra_attrs)

def intersect_instance_callable(self, typ: Instance, callable_type: CallableType) -> Instance:
"""Creates a fake type that represents the intersection of an Instance and a CallableType.
Expand All @@ -4728,7 +4729,7 @@ def intersect_instance_callable(self, typ: Instance, callable_type: CallableType

cur_module.names[gen_name] = SymbolTableNode(GDEF, info)

return Instance(info, [])
return Instance(info, [], extra_attrs=typ.extra_attrs)

def make_fake_callable(self, typ: Instance) -> Instance:
"""Produce a new type that makes type Callable with a generic callable type."""
Expand Down Expand Up @@ -5032,6 +5033,12 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM
if literal(expr) == LITERAL_TYPE:
vartype = self.lookup_type(expr)
return self.conditional_callable_type_map(expr, vartype)
elif refers_to_fullname(node.callee, "builtins.hasattr"):
if len(node.args) != 2: # the error will be reported elsewhere
return {}, {}
attr = try_getting_str_literals(node.args[1], self.lookup_type(node.args[1]))
if literal(expr) == LITERAL_TYPE and attr and len(attr) == 1:
return self.hasattr_type_maps(expr, self.lookup_type(expr), attr[0])
elif isinstance(node.callee, RefExpr):
if node.callee.type_guard is not None:
# TODO: Follow keyword args or *args, **kwargs
Expand Down Expand Up @@ -6239,6 +6246,95 @@ class Foo(Enum):
and member_type.fallback.type == parent_type.type_object()
)

def add_any_attribute_to_type(self, typ: Type, name: str) -> Type:
"""Inject an extra attribute with Any type using fallbacks."""
orig_typ = typ
typ = get_proper_type(typ)
any_type = AnyType(TypeOfAny.unannotated)
if isinstance(typ, Instance):
result = typ.copy_with_extra_attr(name, any_type)
# For instances, we erase the possible module name, so that restrictions
# become anonymous types.ModuleType instances, allowing hasattr() to
# have effect on modules.
assert result.extra_attrs is not None
result.extra_attrs.mod_name = None
return result
if isinstance(typ, TupleType):
fallback = typ.partial_fallback.copy_with_extra_attr(name, any_type)
return typ.copy_modified(fallback=fallback)
if isinstance(typ, CallableType):
fallback = typ.fallback.copy_with_extra_attr(name, any_type)
return typ.copy_modified(fallback=fallback)
if isinstance(typ, TypeType) and isinstance(typ.item, Instance):
return TypeType.make_normalized(self.add_any_attribute_to_type(typ.item, name))
if isinstance(typ, TypeVarType):
return typ.copy_modified(
upper_bound=self.add_any_attribute_to_type(typ.upper_bound, name),
values=[self.add_any_attribute_to_type(v, name) for v in typ.values],
)
if isinstance(typ, UnionType):
with_attr, without_attr = self.partition_union_by_attr(typ, name)
return make_simplified_union(
with_attr + [self.add_any_attribute_to_type(typ, name) for typ in without_attr]
)
return orig_typ

def hasattr_type_maps(
self, expr: Expression, source_type: Type, name: str
) -> tuple[TypeMap, TypeMap]:
"""Simple support for hasattr() checks.
Essentially the logic is following:
* In the if branch, keep types that already has a valid attribute as is,
for other inject an attribute with `Any` type.
* In the else branch, remove types that already have a valid attribute,
while keeping the rest.
"""
if self.has_valid_attribute(source_type, name):
return {expr: source_type}, {}

source_type = get_proper_type(source_type)
if isinstance(source_type, UnionType):
_, without_attr = self.partition_union_by_attr(source_type, name)
yes_map = {expr: self.add_any_attribute_to_type(source_type, name)}
return yes_map, {expr: make_simplified_union(without_attr)}

type_with_attr = self.add_any_attribute_to_type(source_type, name)
if type_with_attr != source_type:
return {expr: type_with_attr}, {}
return {}, {}

def partition_union_by_attr(
self, source_type: UnionType, name: str
) -> tuple[list[Type], list[Type]]:
with_attr = []
without_attr = []
for item in source_type.items:
if self.has_valid_attribute(item, name):
with_attr.append(item)
else:
without_attr.append(item)
return with_attr, without_attr

def has_valid_attribute(self, typ: Type, name: str) -> bool:
if isinstance(get_proper_type(typ), AnyType):
return False
with self.msg.filter_errors() as watcher:
analyze_member_access(
name,
typ,
TempNode(AnyType(TypeOfAny.special_form)),
False,
False,
False,
self.msg,
original_type=typ,
chk=self,
# This is not a real attribute lookup so don't mess with deferring nodes.
no_deferral=True,
)
return not watcher.has_new_errors()


class CollectArgTypes(TypeTraverserVisitor):
"""Collects the non-nested argument types in a set."""
Expand Down
2 changes: 2 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,8 @@ def module_type(self, node: MypyFile) -> Instance:
module_attrs = {}
immutable = set()
for name, n in node.names.items():
if not n.module_public:
continue
if isinstance(n.node, Var) and n.node.is_final:
immutable.add(name)
typ = self.chk.determine_type_of_member(n)
Expand Down
16 changes: 15 additions & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
chk: mypy.checker.TypeChecker,
self_type: Type | None,
module_symbol_table: SymbolTable | None = None,
no_deferral: bool = False,
) -> None:
self.is_lvalue = is_lvalue
self.is_super = is_super
Expand All @@ -100,6 +101,7 @@ def __init__(
self.msg = msg
self.chk = chk
self.module_symbol_table = module_symbol_table
self.no_deferral = no_deferral

def named_type(self, name: str) -> Instance:
return self.chk.named_type(name)
Expand All @@ -124,6 +126,7 @@ def copy_modified(
self.chk,
self.self_type,
self.module_symbol_table,
self.no_deferral,
)
if messages is not None:
mx.msg = messages
Expand All @@ -149,6 +152,7 @@ def analyze_member_access(
in_literal_context: bool = False,
self_type: Type | None = None,
module_symbol_table: SymbolTable | None = None,
no_deferral: bool = False,
) -> Type:
"""Return the type of attribute 'name' of 'typ'.
Expand Down Expand Up @@ -183,6 +187,7 @@ def analyze_member_access(
chk=chk,
self_type=self_type,
module_symbol_table=module_symbol_table,
no_deferral=no_deferral,
)
result = _analyze_member_access(name, typ, mx, override_info)
possible_literal = get_proper_type(result)
Expand Down Expand Up @@ -540,6 +545,11 @@ def analyze_member_var_access(
return AnyType(TypeOfAny.special_form)

# Could not find the member.
if itype.extra_attrs and name in itype.extra_attrs.attrs:
# For modules use direct symbol table lookup.
if not itype.extra_attrs.mod_name:
return itype.extra_attrs.attrs[name]

if mx.is_super:
mx.msg.undefined_in_superclass(name, mx.context)
return AnyType(TypeOfAny.from_error)
Expand Down Expand Up @@ -744,7 +754,7 @@ def analyze_var(
else:
result = expanded_signature
else:
if not var.is_ready:
if not var.is_ready and not mx.no_deferral:
mx.not_ready_callback(var.name, mx.context)
# Implicit 'Any' type.
result = AnyType(TypeOfAny.special_form)
Expand Down Expand Up @@ -858,6 +868,10 @@ def analyze_class_attribute_access(

node = info.get(name)
if not node:
if itype.extra_attrs and name in itype.extra_attrs.attrs:
# For modules use direct symbol table lookup.
if not itype.extra_attrs.mod_name:
return itype.extra_attrs.attrs[name]
if info.fallback_to_any:
return apply_class_attr_hook(mx, hook, AnyType(TypeOfAny.special_form))
return None
Expand Down
24 changes: 22 additions & 2 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
from mypy.erasetype import erase_type
from mypy.maptype import map_instance_to_supertype
from mypy.state import state
from mypy.subtypes import is_callable_compatible, is_equivalent, is_proper_subtype, is_subtype
from mypy.subtypes import (
is_callable_compatible,
is_equivalent,
is_proper_subtype,
is_same_type,
is_subtype,
)
from mypy.typeops import is_recursive_pair, make_simplified_union, tuple_fallback
from mypy.types import (
AnyType,
Expand Down Expand Up @@ -61,11 +67,25 @@ def meet_types(s: Type, t: Type) -> ProperType:
"""Return the greatest lower bound of two types."""
if is_recursive_pair(s, t):
# This case can trigger an infinite recursion, general support for this will be
# tricky so we use a trivial meet (like for protocols).
# tricky, so we use a trivial meet (like for protocols).
return trivial_meet(s, t)
s = get_proper_type(s)
t = get_proper_type(t)

if isinstance(s, Instance) and isinstance(t, Instance) and s.type == t.type:
# Code in checker.py should merge any extra_items where possible, so we
# should have only compatible extra_items here. We check this before
# the below subtype check, so that extra_attrs will not get erased.
if is_same_type(s, t) and (s.extra_attrs or t.extra_attrs):
if s.extra_attrs and t.extra_attrs:
if len(s.extra_attrs.attrs) > len(t.extra_attrs.attrs):
# Return the one that has more precise information.
return s
return t
if s.extra_attrs:
return s
return t

if not isinstance(s, UnboundType) and not isinstance(t, UnboundType):
if is_proper_subtype(s, t, ignore_promotions=True):
return s
Expand Down
6 changes: 3 additions & 3 deletions mypy/server/objgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def get_edges(o: object) -> Iterator[tuple[object, object]]:
# in closures and self pointers to other objects

if hasattr(e, "__closure__"):
yield (s, "__closure__"), e.__closure__ # type: ignore[union-attr]
yield (s, "__closure__"), e.__closure__
if hasattr(e, "__self__"):
se = e.__self__ # type: ignore[union-attr]
se = e.__self__
if se is not o and se is not type(o) and hasattr(s, "__self__"):
yield s.__self__, se # type: ignore[attr-defined]
yield s.__self__, se
else:
if not type(e) in TYPE_BLACKLIST:
yield s, e
Expand Down
36 changes: 34 additions & 2 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
TupleType,
Type,
TypeAliasType,
TypedDictType,
TypeOfAny,
TypeQuery,
TypeType,
Expand Down Expand Up @@ -104,7 +105,7 @@ def tuple_fallback(typ: TupleType) -> Instance:
raise NotImplementedError
else:
items.append(item)
return Instance(info, [join_type_list(items)])
return Instance(info, [join_type_list(items)], extra_attrs=typ.partial_fallback.extra_attrs)


def get_self_type(func: CallableType, default_self: Instance | TupleType) -> Type | None:
Expand Down Expand Up @@ -462,7 +463,20 @@ def make_simplified_union(
):
simplified_set = try_contracting_literals_in_union(simplified_set)

return get_proper_type(UnionType.make_union(simplified_set, line, column))
result = get_proper_type(UnionType.make_union(simplified_set, line, column))

# Step 4: At last, we erase any (inconsistent) extra attributes on instances.
extra_attrs_set = set()
for item in items:
instance = try_getting_instance_fallback(item)
if instance and instance.extra_attrs:
extra_attrs_set.add(instance.extra_attrs)

fallback = try_getting_instance_fallback(result)
if len(extra_attrs_set) > 1 and fallback:
fallback.extra_attrs = None

return result


def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[Type]:
Expand Down Expand Up @@ -984,3 +998,21 @@ def separate_union_literals(t: UnionType) -> tuple[Sequence[LiteralType], Sequen
union_items.append(item)

return literal_items, union_items


def try_getting_instance_fallback(typ: Type) -> Instance | None:
"""Returns the Instance fallback for this type if one exists or None."""
typ = get_proper_type(typ)
if isinstance(typ, Instance):
return typ
elif isinstance(typ, TupleType):
return typ.partial_fallback
elif isinstance(typ, TypedDictType):
return typ.fallback
elif isinstance(typ, FunctionLike):
return typ.fallback
elif isinstance(typ, LiteralType):
return typ.fallback
elif isinstance(typ, TypeVarType):
return try_getting_instance_fallback(typ.upper_bound)
return None
Loading

0 comments on commit b29051c

Please sign in to comment.