Skip to content

Commit

Permalink
Clean typing.Generic from mro
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Apr 7, 2021
1 parent 6f275c1 commit 3ff3222
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Release Date: TBA

* Use ``inference_tip`` for ``typing.TypedDict`` brain.

* Fix mro for classes that inherit from typing.Generic

* Add inference tip for typing.Generic and typing.Annotated with ``__class_getitem__``

Closes PyCQA/pylint#2822
Expand Down
34 changes: 34 additions & 0 deletions astroid/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,39 @@ def _c3_merge(sequences, cls, context):
return None


def clean_typing_generic_mro(sequences: List[List["ClassDef"]]) -> None:
"""typing.Generic is allowed to appear multiple times in the initial mro.
The final one however, MUST only contain ONE.
This method will check if Generic is in inferred_bases, but also
part of bases_mro. If true, remove it from inferred_bases
as well as its entry the bases_mro.
Format sequences: [[self]] + bases_mro + [inferred_bases]
"""
pos_generic_in_main_bases = -1
# Check if part of inferred_bases
for i, base in enumerate(sequences[-1]):
if base.qname() == "typing.Generic":
pos_generic_in_main_bases = i
break
else:
return
# Check if also part of bases_mro
# Ignore entry for typing.Generic
for i, seq in enumerate(sequences[1:-1]):
if i == pos_generic_in_main_bases:
continue
if any(base.qname() == "typing.Generic" for base in seq):
break
else:
return
# Found multiple Generics in mro, remove entry from inferred_bases
# and the corresponding one from bases_mro
sequences[-1].pop(pos_generic_in_main_bases)
sequences.pop(pos_generic_in_main_bases + 1)


def clean_duplicates_mro(sequences, cls, context):
for sequence in sequences:
names = [
Expand Down Expand Up @@ -2924,6 +2957,7 @@ def _compute_mro(self, context=None):

unmerged_mro = [[self]] + bases_mro + [inferred_bases]
unmerged_mro = list(clean_duplicates_mro(unmerged_mro, self, context))
clean_typing_generic_mro(unmerged_mro)
return _c3_merge(unmerged_mro, self, context)

def mro(self, context=None) -> List["ClassDef"]:
Expand Down
139 changes: 139 additions & 0 deletions tests/unittest_scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,9 @@ class NodeBase(object):
def assertEqualMro(self, klass, expected_mro):
self.assertEqual([member.name for member in klass.mro()], expected_mro)

def assertEqualMroQName(self, klass, expected_mro):
self.assertEqual([member.qname() for member in klass.mro()], expected_mro)

@unittest.skipUnless(HAS_SIX, "These tests require the six library")
def test_with_metaclass_mro(self):
astroid = builder.parse(
Expand Down Expand Up @@ -1438,6 +1441,142 @@ class C(scope.A, scope.B):
)
self.assertEqualMro(cls, ["C", "A", "B", "object"])

@test_utils.require_version(minver="3.7")
def test_mro_generic_1(self):
cls = builder.extract_node(
"""
import typing
T = typing.TypeVar('T')
class A(typing.Generic[T]): ...
class B: ...
class C(A[T], B): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", "typing.Generic", ".B", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_2(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(Generic[T]): ...
class C(Generic[T], A, B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_3(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(A, Generic[T]): ...
class C(Generic[T]): ...
class D(B[T], C[T], Generic[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".D", ".B", ".A", ".C", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_4(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(Generic[T]): ...
class C(A, Generic[T], B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_5(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T1 = TypeVar('T1')
T2 = TypeVar('T2')
class A(Generic[T1]): ...
class B(Generic[T2]): ...
class C(A[T1], B[T2]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_6(self):
cls = builder.extract_node(
"""
from typing import Generic as TGeneric, TypeVar
T = TypeVar('T')
class Generic: ...
class A(Generic): ...
class B(TGeneric[T]): ...
class C(A, B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".Generic", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_7(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A(): ...
class B(Generic[T]): ...
class C(A, B[T]): ...
class D: ...
class E(C[str], D): ...
"""
)
self.assertEqualMroQName(
cls, [".E", ".C", ".A", ".B", "typing.Generic", ".D", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_error_1(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T1 = TypeVar('T1')
T2 = TypeVar('T2')
class A(Generic[T1], Generic[T2]): ...
"""
)
with self.assertRaises(DuplicateBasesError) as ex:
cls.mro()

@test_utils.require_version(minver="3.7")
def test_mro_generic_error_2(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A(Generic[T]): ...
class B(A[T], A[T]): ...
"""
)
with self.assertRaises(DuplicateBasesError) as ex:
cls.mro()

def test_generator_from_infer_call_result_parent(self):
func = builder.extract_node(
"""
Expand Down

0 comments on commit 3ff3222

Please sign in to comment.