Skip to content

Commit

Permalink
Clean typing.Generic from mro
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Apr 7, 2021
1 parent 354c46e commit dedf3b9
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Release Date: TBA

* Use ``inference_tip`` for ``typing.TypedDict`` brain.

* Fix mro for classes that inherit from typing.Generic

* Add inference tip for typing.Generic and typing.Annotated with ``__class_getitem__``

Closes PyCQA/pylint#2822
Expand Down
34 changes: 34 additions & 0 deletions astroid/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,39 @@ def _c3_merge(sequences, cls, context):
return None


def clean_typing_generic_mro(sequences: List[List["ClassDef"]]) -> None:
"""typing.Generic is allowed to appear multiple times in the initial mro.
The final one however, MUST only contain ONE occurrence.
This method will check if the Generic is in the inferred_bases, but also
part of the bases_mro. If true, remove it from the inferred_bases
as well as its entry the bases_mro.
Format sequences: [[self]] + bases_mro + [inferred_bases]
"""
pos_generic_in_main_bases = -1
# Check if part of inferred_bases
for i, base in enumerate(sequences[-1]):
if base.qname() == "typing.Generic":
pos_generic_in_main_bases = i
break
else:
return
# Check if also part of the bases_mro
# Ignore entry for typing.Generic
for i, seq in enumerate(sequences[1:-1]):
if i == pos_generic_in_main_bases:
continue
if any(base.qname() == "typing.Generic" for base in seq):
break
else:
return
# Found multiple Generics in mro, remove entry from inferred_bases
# and the coresponding one from bases_mro
sequences[-1].pop(pos_generic_in_main_bases)
sequences.pop(pos_generic_in_main_bases + 1)


def clean_duplicates_mro(sequences, cls, context):
for sequence in sequences:
names = [
Expand Down Expand Up @@ -2924,6 +2957,7 @@ def _compute_mro(self, context=None):

unmerged_mro = [[self]] + bases_mro + [inferred_bases]
unmerged_mro = list(clean_duplicates_mro(unmerged_mro, self, context))
clean_typing_generic_mro(unmerged_mro)
return _c3_merge(unmerged_mro, self, context)

def mro(self, context=None) -> List["ClassDef"]:
Expand Down
139 changes: 139 additions & 0 deletions tests/unittest_scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,9 @@ class NodeBase(object):
def assertEqualMro(self, klass, expected_mro):
self.assertEqual([member.name for member in klass.mro()], expected_mro)

def assertEqualMroQName(self, klass, expected_mro):
self.assertEqual([member.qname() for member in klass.mro()], expected_mro)

@unittest.skipUnless(HAS_SIX, "These tests require the six library")
def test_with_metaclass_mro(self):
astroid = builder.parse(
Expand Down Expand Up @@ -1438,6 +1441,142 @@ class C(scope.A, scope.B):
)
self.assertEqualMro(cls, ["C", "A", "B", "object"])

@test_utils.require_version(minver="3.7")
def test_mro_generic_1(self):
cls = builder.extract_node(
"""
import typing
T = typing.TypeVar('T')
class A(typing.Generic[T]): ...
class B: ...
class C(A[T], B): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", "typing.Generic", ".B", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_2(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(Generic[T]): ...
class C(Generic[T], A, B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_3(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(A, Generic[T]): ...
class C(Generic[T]): ...
class D(B[T], C[T], Generic[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".D", ".B", ".A", ".C", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_4(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A: ...
class B(Generic[T]): ...
class C(A, Generic[T], B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_5(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T1 = TypeVar('T1')
T2 = TypeVar('T2')
class A(Generic[T1]): ...
class B(Generic[T2]): ...
class C(A[T1], B[T2]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_6(self):
cls = builder.extract_node(
"""
from typing import Generic as TGeneric, TypeVar
T = TypeVar('T')
class Generic: ...
class A(Generic): ...
class B(TGeneric[T]): ...
class C(A, B[T]): ...
"""
)
self.assertEqualMroQName(
cls, [".C", ".A", ".Generic", ".B", "typing.Generic", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_7(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A(): ...
class B(Generic[T]): ...
class C(A, B[T]): ...
class D: ...
class E(C[str], D): ...
"""
)
self.assertEqualMroQName(
cls, [".E", ".C", ".A", ".B", "typing.Generic", ".D", "builtins.object"]
)

@test_utils.require_version(minver="3.7")
def test_mro_generic_error_1(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T1 = TypeVar('T1')
T2 = TypeVar('T2')
class A(Generic[T1], Generic[T2]): ...
"""
)
with self.assertRaises(DuplicateBasesError) as ex:
cls.mro()

@test_utils.require_version(minver="3.7")
def test_mro_generic_error_2(self):
cls = builder.extract_node(
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class A(Generic[T]): ...
class B(A[T], A[T]): ...
"""
)
with self.assertRaises(DuplicateBasesError) as ex:
cls.mro()

def test_generator_from_infer_call_result_parent(self):
func = builder.extract_node(
"""
Expand Down

0 comments on commit dedf3b9

Please sign in to comment.