diff --git a/ChangeLog b/ChangeLog index e16b312790..a321d9965c 100644 --- a/ChangeLog +++ b/ChangeLog @@ -19,6 +19,8 @@ Release Date: TBA * Use ``inference_tip`` for ``typing.TypedDict`` brain. +* Fix mro for classes that inherit from typing.Generic + * Add inference tip for typing.Generic and typing.Annotated with ``__class_getitem__`` Closes PyCQA/pylint#2822 diff --git a/astroid/scoped_nodes.py b/astroid/scoped_nodes.py index dd5aa1257a..10faa986d8 100644 --- a/astroid/scoped_nodes.py +++ b/astroid/scoped_nodes.py @@ -103,6 +103,39 @@ def _c3_merge(sequences, cls, context): return None +def clean_typing_generic_mro(sequences: List[List["ClassDef"]]) -> None: + """typing.Generic is allowed to appear multiple times in the initial mro. + The final one however, MUST only contain ONE. + + This method will check if Generic is in inferred_bases, but also + part of bases_mro. If true, remove it from inferred_bases + as well as its entry the bases_mro. + + Format sequences: [[self]] + bases_mro + [inferred_bases] + """ + pos_generic_in_main_bases = -1 + # Check if part of inferred_bases + for i, base in enumerate(sequences[-1]): + if base.qname() == "typing.Generic": + pos_generic_in_main_bases = i + break + else: + return + # Check if also part of bases_mro + # Ignore entry for typing.Generic + for i, seq in enumerate(sequences[1:-1]): + if i == pos_generic_in_main_bases: + continue + if any(base.qname() == "typing.Generic" for base in seq): + break + else: + return + # Found multiple Generics in mro, remove entry from inferred_bases + # and the corresponding one from bases_mro + sequences[-1].pop(pos_generic_in_main_bases) + sequences.pop(pos_generic_in_main_bases + 1) + + def clean_duplicates_mro(sequences, cls, context): for sequence in sequences: names = [ @@ -2924,6 +2957,7 @@ def _compute_mro(self, context=None): unmerged_mro = [[self]] + bases_mro + [inferred_bases] unmerged_mro = list(clean_duplicates_mro(unmerged_mro, self, context)) + clean_typing_generic_mro(unmerged_mro) return _c3_merge(unmerged_mro, self, context) def mro(self, context=None) -> List["ClassDef"]: diff --git a/tests/unittest_scoped_nodes.py b/tests/unittest_scoped_nodes.py index a0f882aee6..b98cd8f836 100644 --- a/tests/unittest_scoped_nodes.py +++ b/tests/unittest_scoped_nodes.py @@ -1275,6 +1275,9 @@ class NodeBase(object): def assertEqualMro(self, klass, expected_mro): self.assertEqual([member.name for member in klass.mro()], expected_mro) + def assertEqualMroQName(self, klass, expected_mro): + self.assertEqual([member.qname() for member in klass.mro()], expected_mro) + @unittest.skipUnless(HAS_SIX, "These tests require the six library") def test_with_metaclass_mro(self): astroid = builder.parse( @@ -1438,6 +1441,142 @@ class C(scope.A, scope.B): ) self.assertEqualMro(cls, ["C", "A", "B", "object"]) + @test_utils.require_version(minver="3.7") + def test_mro_generic_1(self): + cls = builder.extract_node( + """ + import typing + T = typing.TypeVar('T') + class A(typing.Generic[T]): ... + class B: ... + class C(A[T], B): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", "typing.Generic", ".B", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_2(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A: ... + class B(Generic[T]): ... + class C(Generic[T], A, B[T]): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_3(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A: ... + class B(A, Generic[T]): ... + class C(Generic[T]): ... + class D(B[T], C[T], Generic[T]): ... + """ + ) + self.assertEqualMroQName( + cls, [".D", ".B", ".A", ".C", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_4(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A: ... + class B(Generic[T]): ... + class C(A, Generic[T], B[T]): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_5(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T1 = TypeVar('T1') + T2 = TypeVar('T2') + class A(Generic[T1]): ... + class B(Generic[T2]): ... + class C(A[T1], B[T2]): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_6(self): + cls = builder.extract_node( + """ + from typing import Generic as TGeneric, TypeVar + T = TypeVar('T') + class Generic: ... + class A(Generic): ... + class B(TGeneric[T]): ... + class C(A, B[T]): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", ".Generic", ".B", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_7(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A(): ... + class B(Generic[T]): ... + class C(A, B[T]): ... + class D: ... + class E(C[str], D): ... + """ + ) + self.assertEqualMroQName( + cls, [".E", ".C", ".A", ".B", "typing.Generic", ".D", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_error_1(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T1 = TypeVar('T1') + T2 = TypeVar('T2') + class A(Generic[T1], Generic[T2]): ... + """ + ) + with self.assertRaises(DuplicateBasesError) as ex: + cls.mro() + + @test_utils.require_version(minver="3.7") + def test_mro_generic_error_2(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A(Generic[T]): ... + class B(A[T], A[T]): ... + """ + ) + with self.assertRaises(DuplicateBasesError) as ex: + cls.mro() + def test_generator_from_infer_call_result_parent(self): func = builder.extract_node( """