Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify infernce tip for typing.Generic and typing.Annotated with __class_getitem__ #931

Merged
merged 5 commits into from
Apr 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ Release Date: TBA

* Use ``inference_tip`` for ``typing.TypedDict`` brain.

* Fix mro for classes that inherit from typing.Generic

* Add inference tip for typing.Generic and typing.Annotated with ``__class_getitem__``

Closes PyCQA/pylint#2822


What's New in astroid 2.5.2?
============================
Expand Down
42 changes: 33 additions & 9 deletions astroid/brain/brain_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ class {0}(metaclass=Meta):
)
)

CLASS_GETITEM_TEMPLATE = """
@classmethod
def __class_getitem__(cls, item):
return cls
"""


def looks_like_typing_typevar_or_newtype(node):
func = node.func
Expand Down Expand Up @@ -126,7 +132,9 @@ def _looks_like_typing_subscript(node):
return False


def infer_typing_attr(node, context=None):
def infer_typing_attr(
node: nodes.Subscript, ctx: context.InferenceContext = None
) -> typing.Iterator[nodes.ClassDef]:
"""Infer a typing.X[...] subscript"""
try:
value = next(node.value.infer())
Expand All @@ -142,8 +150,31 @@ def infer_typing_attr(node, context=None):
# (PY37+) handle it separately.
raise UseInferenceDefault

if (
PY37
and isinstance(value, nodes.ClassDef)
and value.qname()
in ("typing.Generic", "typing.Annotated", "typing_extensions.Annotated")
):
# With PY37+ typing.Generic and typing.Annotated (PY39) are subscriptable
# through __class_getitem__. Since astroid can't easily
# infer the native methods, replace them for an easy inference tip
func_to_add = astroid.extract_node(CLASS_GETITEM_TEMPLATE)
value.locals["__class_getitem__"] = [func_to_add]
if (
isinstance(node.parent, nodes.ClassDef)
and node in node.parent.bases
and getattr(node.parent, "__cache", None)
hippo91 marked this conversation as resolved.
Show resolved Hide resolved
):
# node.parent.slots is evaluated and cached before the inference tip
# is first applied. Remove the last result to allow a recalculation of slots
cache = getattr(node.parent, "__cache")
if cache.get(node.parent.slots) is not None:
del cache[node.parent.slots]
return iter([value])

node = extract_node(TYPING_TYPE_TEMPLATE.format(value.qname().split(".")[-1]))
return node.infer(context=context)
return node.infer(context=ctx)


def _looks_like_typedDict( # pylint: disable=invalid-name
Expand All @@ -166,13 +197,6 @@ def infer_typedDict( # pylint: disable=invalid-name
return iter([class_def])


CLASS_GETITEM_TEMPLATE = """
@classmethod
def __class_getitem__(cls, item):
return cls
"""


def _looks_like_typing_alias(node: nodes.Call) -> bool:
"""
Returns True if the node corresponds to a call to _alias function.
Expand Down
37 changes: 37 additions & 0 deletions astroid/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,42 @@ def _c3_merge(sequences, cls, context):
return None


def clean_typing_generic_mro(sequences: List[List["ClassDef"]]) -> None:
"""A class can inherit from typing.Generic directly, as base,
and as base of bases. The merged MRO must however only contain the last entry.
To prepare for _c3_merge, remove some typing.Generic entries from
sequences if multiple are present.

This method will check if Generic is in inferred_bases and also
part of bases_mro. If true, remove it from inferred_bases
as well as its entry the bases_mro.

Format sequences: [[self]] + bases_mro + [inferred_bases]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it could not be clearer to pass 3 arguments instead of one (i.e self, bases_mro and inferred_bases).
What do you think about it?

"""
bases_mro = sequences[1:-1]
inferred_bases = sequences[-1]
# Check if Generic is part of inferred_bases
for i, base in enumerate(inferred_bases):
if base.qname() == "typing.Generic":
position_in_inferred_bases = i
break
else:
return
# Check if also part of bases_mro
# Ignore entry for typing.Generic
for i, seq in enumerate(bases_mro):
if i == position_in_inferred_bases:
continue
if any(base.qname() == "typing.Generic" for base in seq):
break
else:
return
# Found multiple Generics in mro, remove entry from inferred_bases
# and the corresponding one from bases_mro
inferred_bases.pop(position_in_inferred_bases)
bases_mro.pop(position_in_inferred_bases)


def clean_duplicates_mro(sequences, cls, context):
for sequence in sequences:
names = [
Expand Down Expand Up @@ -2924,6 +2960,7 @@ def _compute_mro(self, context=None):

unmerged_mro = [[self]] + bases_mro + [inferred_bases]
unmerged_mro = list(clean_duplicates_mro(unmerged_mro, self, context))
clean_typing_generic_mro(unmerged_mro)
return _c3_merge(unmerged_mro, self, context)

def mro(self, context=None) -> List["ClassDef"]:
Expand Down
50 changes: 50 additions & 0 deletions tests/unittest_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,56 @@ def test_typing_types(self):
inferred = next(node.infer())
self.assertIsInstance(inferred, nodes.ClassDef, node.as_string())

@test_utils.require_version(minver="3.7")
def test_typing_generic_subscriptable(self):
"""Test typing.Generic is subscriptable with __class_getitem__ (added in PY37)"""
node = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
Generic[T]
"""
)
inferred = next(node.infer())
assert isinstance(inferred, nodes.ClassDef)
assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef)

@test_utils.require_version(minver="3.9")
def test_typing_annotated_subscriptable(self):
"""Test typing.Annotated is subscriptable with __class_getitem__"""
node = builder.extract_node(
"""
import typing
typing.Annotated[str, "data"]
"""
)
inferred = next(node.infer())
assert isinstance(inferred, nodes.ClassDef)
assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef)

@test_utils.require_version(minver="3.7")
def test_typing_generic_slots(self):
"""Test cache reset for slots if Generic subscript is inferred."""
node = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A(Generic[T]):
__slots__ = ['value']
def __init__(self, value):
self.value = value
"""
)
inferred = next(node.infer())
assert len(inferred.slots()) == 0
# Only after the subscript base is inferred and the inference tip applied,
# will slots contain the correct value
next(node.bases[0].infer())
slots = inferred.slots()
assert len(slots) == 1
assert isinstance(slots[0], nodes.Const)
assert slots[0].value == "value"

def test_has_dunder_args(self):
ast_node = builder.extract_node(
"""
Expand Down
139 changes: 139 additions & 0 deletions tests/unittest_scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,9 @@ class NodeBase(object):
def assertEqualMro(self, klass, expected_mro):
self.assertEqual([member.name for member in klass.mro()], expected_mro)

def assertEqualMroQName(self, klass, expected_mro):
self.assertEqual([member.qname() for member in klass.mro()], expected_mro)

@unittest.skipUnless(HAS_SIX, "These tests require the six library")
def test_with_metaclass_mro(self):
astroid = builder.parse(
Expand Down Expand Up @@ -1438,6 +1441,142 @@ class C(scope.A, scope.B):
)
self.assertEqualMro(cls, ["C", "A", "B", "object"])

@test_utils.require_version(minver="3.7")
def test_mro_generic_1(self):
cls = builder.extract_node(
"""
import typing
T = typing.TypeVar('T')
class A(typing.Generic[T]): ...
class B: ...
class C(A[T], B): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", "typing.Generic", ".B", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_2(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(Generic[T]): ...
class C(Generic[T], A, B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_3(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(A, Generic[T]): ...
class C(Generic[T]): ...
class D(B[T], C[T], Generic[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".D", ".B", ".A", ".C", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_4(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(Generic[T]): ...
class C(A, Generic[T], B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_5(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T1 = TypeVar('T1')
T2 = TypeVar('T2')
class A(Generic[T1]): ...
class B(Generic[T2]): ...
class C(A[T1], B[T2]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_6(self):
cls = builder.extract_node(
"""
from typing import Generic as TGeneric, TypeVar
T = TypeVar('T')
class Generic: ...
class A(Generic): ...
class B(TGeneric[T]): ...
class C(A, B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".Generic", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_7(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A(): ...
class B(Generic[T]): ...
class C(A, B[T]): ...
class D: ...
class E(C[str], D): ...
"""
)
self.assertEqualMroQName(
cls, [".E", ".C", ".A", ".B", "typing.Generic", ".D", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_error_1(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T1 = TypeVar('T1')
T2 = TypeVar('T2')
class A(Generic[T1], Generic[T2]): ...
"""
)
with self.assertRaises(DuplicateBasesError) as ex:
cls.mro()

@test_utils.require_version(minver="3.7")
def test_mro_generic_error_2(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A(Generic[T]): ...
class B(A[T], A[T]): ...
"""
)
with self.assertRaises(DuplicateBasesError) as ex:
cls.mro()

def test_generator_from_infer_call_result_parent(self):
func = builder.extract_node(
"""
Expand Down