From 934264d1730329b99b0d8fa8519f9ea359b58a41 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Mon, 22 Feb 2021 11:00:48 +0100 Subject: [PATCH 1/3] Fix duplicate bases error for typing._GenericAlias * Fixes #905 --- ChangeLog | 4 ++++ astroid/scoped_nodes.py | 7 ++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/ChangeLog b/ChangeLog index 519fa249f3..13a4507129 100644 --- a/ChangeLog +++ b/ChangeLog @@ -11,6 +11,10 @@ Release Date: TBA Closes #895 #899 +* Fixed duplicate bases error for "typing._GenericAlias" (false-positive) + + Closes #905 + What's New in astroid 2.5? ============================ Release Date: 2021-02-15 diff --git a/astroid/scoped_nodes.py b/astroid/scoped_nodes.py index 269b16373d..8cb72dbb2a 100644 --- a/astroid/scoped_nodes.py +++ b/astroid/scoped_nodes.py @@ -105,7 +105,12 @@ def clean_duplicates_mro(sequences, cls, context): (node.lineno, node.qname()) if node.name else None for node in sequence ] last_index = dict(map(reversed, enumerate(names))) - if names and names[0] is not None and last_index[names[0]] != 0: + if ( + names + and names[0] is not None + and last_index[names[0]] != 0 + and names[0][1] != "typing._GenericAlias" + ): raise exceptions.DuplicateBasesError( message="Duplicates found in MROs {mros} for {cls!r}.", mros=sequences, From 3bbf4d24275eb62ff8268bd2b7bd29926276ceaa Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 23 Feb 2021 01:48:29 +0100 Subject: [PATCH 2/3] Add test cases for duplicate bases error --- tests/unittest_scoped_nodes.py | 56 ++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tests/unittest_scoped_nodes.py b/tests/unittest_scoped_nodes.py index 104cdae1a7..7ef68a3e77 100644 --- a/tests/unittest_scoped_nodes.py +++ b/tests/unittest_scoped_nodes.py @@ -1436,6 +1436,62 @@ class C(scope.A, scope.B): ) self.assertEqualMro(cls, ["C", "A", "B", "object"]) + @test_utils.require_version("3.7", "3.9") + def test_mro_with_duplicate_generic_alias(self): + """Catch false positive. Assert no error is thrown.""" + cls = builder.extract_node( + """ + import abc + from typing import Sized, Iterable + class AbstractRoute(abc.ABC): + pass + class AbstractResource(Sized, Iterable["AbstractRoute"]): + pass + class IndexView(AbstractResource): + def __init__(self): + self.var = 1 + """ + ) + self.assertEqualMro( + cls, + [ + "IndexView", + "AbstractResource", + "_GenericAlias", + "_Final", + "_GenericAlias", + "object", + ], + ) + + @test_utils.require_version("3.9") + def test_mro_with_duplicate_generic_alias_2(self): + cls = builder.extract_node( + """ + import abc + from typing import Sized, Iterable + class AbstractRoute(abc.ABC): + pass + class AbstractResource(Sized, Iterable["AbstractRoute"]): + pass + class IndexView(AbstractResource): + def __init__(self): + self.var = 1 + """ + ) + self.assertEqualMro( + cls, + [ + "IndexView", + "AbstractResource", + "_SpecialGenericAlias", + "_BaseGenericAlias", + "_Final", + "_SpecialGenericAlias", + "object", + ], + ) + def test_generator_from_infer_call_result_parent(self): func = builder.extract_node( """ From b151e807806d2f188c45ee450a7d84535e75ed4c Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Sat, 27 Feb 2021 15:46:18 +0100 Subject: [PATCH 3/3] Use flag to guard DuplicateBasesError --- ChangeLog | 2 +- astroid/scoped_nodes.py | 31 +++++++++++++++------- tests/unittest_scoped_nodes.py | 47 +++++++++++++++------------------- 3 files changed, 43 insertions(+), 37 deletions(-) diff --git a/ChangeLog b/ChangeLog index 13a4507129..8c15f95abe 100644 --- a/ChangeLog +++ b/ChangeLog @@ -11,7 +11,7 @@ Release Date: TBA Closes #895 #899 -* Fixed duplicate bases error for "typing._GenericAlias" (false-positive) +* Use flag to guard DuplicateBasesError Closes #905 diff --git a/astroid/scoped_nodes.py b/astroid/scoped_nodes.py index 8cb72dbb2a..73e08d0b01 100644 --- a/astroid/scoped_nodes.py +++ b/astroid/scoped_nodes.py @@ -99,17 +99,19 @@ def _c3_merge(sequences, cls, context): return None -def clean_duplicates_mro(sequences, cls, context): +def clean_duplicates_mro( + sequences, cls, context, raise_duplicate_bases_error: bool = True +): for sequence in sequences: names = [ (node.lineno, node.qname()) if node.name else None for node in sequence ] last_index = dict(map(reversed, enumerate(names))) if ( - names + raise_duplicate_bases_error is True + and names and names[0] is not None and last_index[names[0]] != 0 - and names[0][1] != "typing._GenericAlias" ): raise exceptions.DuplicateBasesError( message="Duplicates found in MROs {mros} for {cls!r}.", @@ -2821,7 +2823,7 @@ def slots(self): def grouped_slots(): # Not interested in object, since it can't have slots. - for cls in self.mro()[:-1]: + for cls in self.mro(raise_duplicate_bases_error=False)[:-1]: try: cls_slots = cls._slots() except NotImplementedError: @@ -2876,7 +2878,7 @@ def _inferred_bases(self, context=None): else: yield from baseobj.bases - def _compute_mro(self, context=None): + def _compute_mro(self, context=None, raise_duplicate_bases_error: bool = True): inferred_bases = list(self._inferred_bases(context=context)) bases_mro = [] for base in inferred_bases: @@ -2884,7 +2886,10 @@ def _compute_mro(self, context=None): continue try: - mro = base._compute_mro(context=context) + mro = base._compute_mro( + context=context, + raise_duplicate_bases_error=raise_duplicate_bases_error, + ) bases_mro.append(mro) except NotImplementedError: # Some classes have in their ancestors both newstyle and @@ -2896,10 +2901,16 @@ def _compute_mro(self, context=None): bases_mro.append(ancestors) unmerged_mro = [[self]] + bases_mro + [inferred_bases] - unmerged_mro = list(clean_duplicates_mro(unmerged_mro, self, context)) + unmerged_mro = list( + clean_duplicates_mro( + unmerged_mro, self, context, raise_duplicate_bases_error + ) + ) return _c3_merge(unmerged_mro, self, context) - def mro(self, context=None) -> List["ClassDef"]: + def mro( + self, context=None, raise_duplicate_bases_error: bool = True + ) -> List["ClassDef"]: """Get the method resolution order, using C3 linearization. :returns: The list of ancestors, sorted by the mro. @@ -2907,7 +2918,9 @@ def mro(self, context=None) -> List["ClassDef"]: :raises DuplicateBasesError: Duplicate bases in the same class base :raises InconsistentMroError: A class' MRO is inconsistent """ - return self._compute_mro(context=context) + return self._compute_mro( + context=context, raise_duplicate_bases_error=raise_duplicate_bases_error + ) def bool_value(self, context=None): """Determine the boolean value of this node. diff --git a/tests/unittest_scoped_nodes.py b/tests/unittest_scoped_nodes.py index 7ef68a3e77..cd77afe7f7 100644 --- a/tests/unittest_scoped_nodes.py +++ b/tests/unittest_scoped_nodes.py @@ -1270,8 +1270,18 @@ class NodeBase(object): assert len(slots) == 3, slots assert [slot.value for slot in slots] == ["a", "b", "c"] - def assertEqualMro(self, klass, expected_mro): - self.assertEqual([member.name for member in klass.mro()], expected_mro) + def assertEqualMro( + self, klass, expected_mro, raise_duplicate_bases_error: bool = True + ): + self.assertEqual( + [ + member.name + for member in klass.mro( + raise_duplicate_bases_error=raise_duplicate_bases_error + ) + ], + expected_mro, + ) @unittest.skipUnless(HAS_SIX, "These tests require the six library") def test_with_metaclass_mro(self): @@ -1441,40 +1451,24 @@ def test_mro_with_duplicate_generic_alias(self): """Catch false positive. Assert no error is thrown.""" cls = builder.extract_node( """ - import abc - from typing import Sized, Iterable - class AbstractRoute(abc.ABC): - pass - class AbstractResource(Sized, Iterable["AbstractRoute"]): - pass - class IndexView(AbstractResource): + from typing import Sized, Hashable + class Derived(Sized, Hashable): def __init__(self): self.var = 1 """ ) self.assertEqualMro( cls, - [ - "IndexView", - "AbstractResource", - "_GenericAlias", - "_Final", - "_GenericAlias", - "object", - ], + ["Derived", "_GenericAlias", "_Final", "object"], + raise_duplicate_bases_error=False, ) @test_utils.require_version("3.9") def test_mro_with_duplicate_generic_alias_2(self): cls = builder.extract_node( """ - import abc - from typing import Sized, Iterable - class AbstractRoute(abc.ABC): - pass - class AbstractResource(Sized, Iterable["AbstractRoute"]): - pass - class IndexView(AbstractResource): + from typing import Sized, Hashable + class Derived(Sized, Hashable): def __init__(self): self.var = 1 """ @@ -1482,14 +1476,13 @@ def __init__(self): self.assertEqualMro( cls, [ - "IndexView", - "AbstractResource", + "Derived", "_SpecialGenericAlias", "_BaseGenericAlias", "_Final", - "_SpecialGenericAlias", "object", ], + raise_duplicate_bases_error=False, ) def test_generator_from_infer_call_result_parent(self):