Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify infernce tip for typing.Generic and typing.Annotated with __class_getitem__ #931

Merged
merged 5 commits into from
Apr 10, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ Release Date: TBA

* Use ``inference_tip`` for ``typing.TypedDict`` brain.

* Fix mro for classes that inherit from typing.Generic

* Add inference tip for typing.Generic and typing.Annotated with ``__class_getitem__``

Closes PyCQA/pylint#2822


What's New in astroid 2.5.2?
============================
Expand Down
42 changes: 33 additions & 9 deletions astroid/brain/brain_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ class {0}(metaclass=Meta):
)
)

CLASS_GETITEM_TEMPLATE = """
@classmethod
def __class_getitem__(cls, item):
return cls
"""


def looks_like_typing_typevar_or_newtype(node):
func = node.func
Expand Down Expand Up @@ -126,7 +132,9 @@ def _looks_like_typing_subscript(node):
return False


def infer_typing_attr(node, context=None):
def infer_typing_attr(
node: nodes.Subscript, ctx: context.InferenceContext = None
) -> typing.Iterator[nodes.ClassDef]:
"""Infer a typing.X[...] subscript"""
try:
value = next(node.value.infer())
Expand All @@ -142,8 +150,31 @@ def infer_typing_attr(node, context=None):
# (PY37+) handle it separately.
raise UseInferenceDefault

if (
PY37
and isinstance(value, nodes.ClassDef)
and value.qname()
in ("typing.Generic", "typing.Annotated", "typing_extensions.Annotated")
):
# With PY37+ typing.Generic and typing.Annotated (PY39) are subscriptable
# through __class_getitem__. Since astroid can't easily
# infer the native methods, replace them for an easy inference tip
func_to_add = astroid.extract_node(CLASS_GETITEM_TEMPLATE)
value.locals["__class_getitem__"] = [func_to_add]
if (
isinstance(node.parent, nodes.ClassDef)
and node in node.parent.bases
and getattr(node.parent, "__cache", None)
hippo91 marked this conversation as resolved.
Show resolved Hide resolved
):
# node.parent.slots is evaluated and cached before the inference tip
# is first applied. Remove the last result to allow a recalculation of slots
cache = getattr(node.parent, "__cache")
if cache and cache.get(node.parent.slots) is not None:
del cache[node.parent.slots]
return iter([value])

node = extract_node(TYPING_TYPE_TEMPLATE.format(value.qname().split(".")[-1]))
return node.infer(context=context)
return node.infer(context=ctx)


def _looks_like_typedDict( # pylint: disable=invalid-name
Expand All @@ -166,13 +197,6 @@ def infer_typedDict( # pylint: disable=invalid-name
return iter([class_def])


CLASS_GETITEM_TEMPLATE = """
@classmethod
def __class_getitem__(cls, item):
return cls
"""


def _looks_like_typing_alias(node: nodes.Call) -> bool:
"""
Returns True if the node corresponds to a call to _alias function.
Expand Down
34 changes: 34 additions & 0 deletions astroid/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,39 @@ def _c3_merge(sequences, cls, context):
return None


def clean_typing_generic_mro(sequences: List[List["ClassDef"]]) -> None:
"""typing.Generic is allowed to appear multiple times in the initial mro.
The final one however, MUST only contain ONE.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The terms initial and final are a bit confusing. What are their meanings?


This method will check if Generic is in inferred_bases, but also
part of bases_mro. If true, remove it from inferred_bases
as well as its entry the bases_mro.

Format sequences: [[self]] + bases_mro + [inferred_bases]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it could not be clearer to pass 3 arguments instead of one (i.e self, bases_mro and inferred_bases).
What do you think about it?

"""
pos_generic_in_main_bases = -1
# Check if part of inferred_bases
for i, base in enumerate(sequences[-1]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be definitely clearer to manipulate self (or an alias), bases_mro and inferred bases instead of dealing with only one sequence and multiples indexes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to leave list(clean_duplicates_mro(unmerged_mro, self, context)) in _compute_mro untouched. This should be run before the MRO is cleaned from Generics. However, that also means we only have access to the whole sequence. Creating references inside clean_typing_generic_mro might work though.

if base.qname() == "typing.Generic":
pos_generic_in_main_bases = i
break
else:
return
# Check if also part of bases_mro
# Ignore entry for typing.Generic
for i, seq in enumerate(sequences[1:-1]):
if i == pos_generic_in_main_bases:
continue
if any(base.qname() == "typing.Generic" for base in seq):
break
else:
return
# Found multiple Generics in mro, remove entry from inferred_bases
# and the corresponding one from bases_mro
sequences[-1].pop(pos_generic_in_main_bases)
sequences.pop(pos_generic_in_main_bases + 1)


def clean_duplicates_mro(sequences, cls, context):
for sequence in sequences:
names = [
Expand Down Expand Up @@ -2924,6 +2957,7 @@ def _compute_mro(self, context=None):

unmerged_mro = [[self]] + bases_mro + [inferred_bases]
unmerged_mro = list(clean_duplicates_mro(unmerged_mro, self, context))
clean_typing_generic_mro(unmerged_mro)
return _c3_merge(unmerged_mro, self, context)

def mro(self, context=None) -> List["ClassDef"]:
Expand Down
50 changes: 50 additions & 0 deletions tests/unittest_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,56 @@ def test_typing_types(self):
inferred = next(node.infer())
self.assertIsInstance(inferred, nodes.ClassDef, node.as_string())

@test_utils.require_version(minver="3.7")
def test_typing_generic_subscriptable(self):
"""Test typing.Generic is subscriptable with __class_getitem__ (added in PY37)"""
node = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
Generic[T]
"""
)
inferred = next(node.infer())
assert isinstance(inferred, nodes.ClassDef)
assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef)

@test_utils.require_version(minver="3.9")
def test_typing_annotated_subscriptable(self):
"""Test typing.Annotated is subscriptable with __class_getitem__"""
node = builder.extract_node(
"""
import typing
typing.Annotated[str, "data"]
"""
)
inferred = next(node.infer())
assert isinstance(inferred, nodes.ClassDef)
assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef)

@test_utils.require_version(minver="3.7")
def test_typing_generic_slots(self):
"""Test cache reset for slots if Generic subscript is inferred."""
node = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A(Generic[T]):
__slots__ = ['value']
def __init__(self, value):
self.value = value
"""
)
inferred = next(node.infer())
assert len(inferred.slots()) == 0
# Only after the subscript base is inferred and the inference tip applied,
# will slots contain the correct value
next(node.bases[0].infer())
slots = inferred.slots()
assert len(slots) == 1
assert isinstance(slots[0], nodes.Const)
assert slots[0].value == "value"

def test_has_dunder_args(self):
ast_node = builder.extract_node(
"""
Expand Down
139 changes: 139 additions & 0 deletions tests/unittest_scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,9 @@ class NodeBase(object):
def assertEqualMro(self, klass, expected_mro):
self.assertEqual([member.name for member in klass.mro()], expected_mro)

def assertEqualMroQName(self, klass, expected_mro):
self.assertEqual([member.qname() for member in klass.mro()], expected_mro)

@unittest.skipUnless(HAS_SIX, "These tests require the six library")
def test_with_metaclass_mro(self):
astroid = builder.parse(
Expand Down Expand Up @@ -1438,6 +1441,142 @@ class C(scope.A, scope.B):
)
self.assertEqualMro(cls, ["C", "A", "B", "object"])

@test_utils.require_version(minver="3.7")
def test_mro_generic_1(self):
cls = builder.extract_node(
"""
import typing
T = typing.TypeVar('T')
class A(typing.Generic[T]): ...
class B: ...
class C(A[T], B): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", "typing.Generic", ".B", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_2(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(Generic[T]): ...
class C(Generic[T], A, B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_3(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(A, Generic[T]): ...
class C(Generic[T]): ...
class D(B[T], C[T], Generic[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".D", ".B", ".A", ".C", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_4(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(Generic[T]): ...
class C(A, Generic[T], B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_5(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T1 = TypeVar('T1')
T2 = TypeVar('T2')
class A(Generic[T1]): ...
class B(Generic[T2]): ...
class C(A[T1], B[T2]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_6(self):
cls = builder.extract_node(
"""
from typing import Generic as TGeneric, TypeVar
T = TypeVar('T')
class Generic: ...
class A(Generic): ...
class B(TGeneric[T]): ...
class C(A, B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".Generic", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_7(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A(): ...
class B(Generic[T]): ...
class C(A, B[T]): ...
class D: ...
class E(C[str], D): ...
"""
)
self.assertEqualMroQName(
cls, [".E", ".C", ".A", ".B", "typing.Generic", ".D", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_error_1(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T1 = TypeVar('T1')
T2 = TypeVar('T2')
class A(Generic[T1], Generic[T2]): ...
"""
)
with self.assertRaises(DuplicateBasesError) as ex:
cls.mro()

@test_utils.require_version(minver="3.7")
def test_mro_generic_error_2(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A(Generic[T]): ...
class B(A[T], A[T]): ...
"""
)
with self.assertRaises(DuplicateBasesError) as ex:
cls.mro()

def test_generator_from_infer_call_result_parent(self):
func = builder.extract_node(
"""
Expand Down