Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix duplicate bases error (MROs) with GenericAlias #910

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ Release Date: TBA

Closes #895 #899

* Use flag to guard DuplicateBasesError

Closes #905

What's New in astroid 2.5?
============================
Release Date: 2021-02-15
Expand Down
34 changes: 26 additions & 8 deletions astroid/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,20 @@ def _c3_merge(sequences, cls, context):
return None


def clean_duplicates_mro(sequences, cls, context):
def clean_duplicates_mro(
sequences, cls, context, raise_duplicate_bases_error: bool = True
):
for sequence in sequences:
names = [
(node.lineno, node.qname()) if node.name else None for node in sequence
]
last_index = dict(map(reversed, enumerate(names)))
if names and names[0] is not None and last_index[names[0]] != 0:
if (
raise_duplicate_bases_error is True
and names
and names[0] is not None
and last_index[names[0]] != 0
):
raise exceptions.DuplicateBasesError(
message="Duplicates found in MROs {mros} for {cls!r}.",
mros=sequences,
Expand Down Expand Up @@ -2816,7 +2823,7 @@ def slots(self):

def grouped_slots():
# Not interested in object, since it can't have slots.
for cls in self.mro()[:-1]:
for cls in self.mro(raise_duplicate_bases_error=False)[:-1]:
try:
cls_slots = cls._slots()
except NotImplementedError:
Expand Down Expand Up @@ -2871,15 +2878,18 @@ def _inferred_bases(self, context=None):
else:
yield from baseobj.bases

def _compute_mro(self, context=None):
def _compute_mro(self, context=None, raise_duplicate_bases_error: bool = True):
inferred_bases = list(self._inferred_bases(context=context))
bases_mro = []
for base in inferred_bases:
if base is self:
continue

try:
mro = base._compute_mro(context=context)
mro = base._compute_mro(
context=context,
raise_duplicate_bases_error=raise_duplicate_bases_error,
)
bases_mro.append(mro)
except NotImplementedError:
# Some classes have in their ancestors both newstyle and
Expand All @@ -2891,18 +2901,26 @@ def _compute_mro(self, context=None):
bases_mro.append(ancestors)

unmerged_mro = [[self]] + bases_mro + [inferred_bases]
unmerged_mro = list(clean_duplicates_mro(unmerged_mro, self, context))
unmerged_mro = list(
clean_duplicates_mro(
unmerged_mro, self, context, raise_duplicate_bases_error
)
)
return _c3_merge(unmerged_mro, self, context)

def mro(self, context=None) -> List["ClassDef"]:
def mro(
self, context=None, raise_duplicate_bases_error: bool = True
) -> List["ClassDef"]:
"""Get the method resolution order, using C3 linearization.

:returns: The list of ancestors, sorted by the mro.
:rtype: list(NodeNG)
:raises DuplicateBasesError: Duplicate bases in the same class base
:raises InconsistentMroError: A class' MRO is inconsistent
"""
return self._compute_mro(context=context)
return self._compute_mro(
context=context, raise_duplicate_bases_error=raise_duplicate_bases_error
)

def bool_value(self, context=None):
"""Determine the boolean value of this node.
Expand Down
53 changes: 51 additions & 2 deletions tests/unittest_scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,8 +1270,18 @@ class NodeBase(object):
assert len(slots) == 3, slots
assert [slot.value for slot in slots] == ["a", "b", "c"]

def assertEqualMro(self, klass, expected_mro):
self.assertEqual([member.name for member in klass.mro()], expected_mro)
def assertEqualMro(
self, klass, expected_mro, raise_duplicate_bases_error: bool = True
):
self.assertEqual(
[
member.name
for member in klass.mro(
raise_duplicate_bases_error=raise_duplicate_bases_error
)
],
expected_mro,
)

@unittest.skipUnless(HAS_SIX, "These tests require the six library")
def test_with_metaclass_mro(self):
Expand Down Expand Up @@ -1436,6 +1446,45 @@ class C(scope.A, scope.B):
)
self.assertEqualMro(cls, ["C", "A", "B", "object"])

@test_utils.require_version("3.7", "3.9")
def test_mro_with_duplicate_generic_alias(self):
"""Catch false positive. Assert no error is thrown."""
cls = builder.extract_node(
"""
from typing import Sized, Hashable
class Derived(Sized, Hashable):
def __init__(self):
self.var = 1
"""
)
self.assertEqualMro(
cls,
["Derived", "_GenericAlias", "_Final", "object"],
raise_duplicate_bases_error=False,
)

@test_utils.require_version("3.9")
def test_mro_with_duplicate_generic_alias_2(self):
cls = builder.extract_node(
"""
from typing import Sized, Hashable
class Derived(Sized, Hashable):
def __init__(self):
self.var = 1
"""
)
self.assertEqualMro(
cls,
[
"Derived",
"_SpecialGenericAlias",
"_BaseGenericAlias",
"_Final",
"object",
],
raise_duplicate_bases_error=False,
)

def test_generator_from_infer_call_result_parent(self):
func = builder.extract_node(
"""
Expand Down