diff --git a/astroid/nodes/scoped_nodes.py b/astroid/nodes/scoped_nodes.py index 84fbd09bea..9ad5e864ca 100644 --- a/astroid/nodes/scoped_nodes.py +++ b/astroid/nodes/scoped_nodes.py @@ -43,6 +43,7 @@ import io import itertools import typing +from collections import OrderedDict from typing import List, Optional from astroid import bases @@ -1962,6 +1963,9 @@ def my_meth(self, arg): # a dictionary of class instances attributes _astroid_fields = ("decorators", "bases", "keywords", "body") # name + all_ancestors = {} + direct_ancestors = {} + decorators = None """The decorators that are applied to this class. @@ -2331,13 +2335,26 @@ def ancestors(self, recurs=True, context=None): :returns: The base classes :rtype: iterable(NodeNG) """ + if recurs and context in self.all_ancestors: + yield from self.all_ancestors[context].keys() + elif not recurs and context in self.direct_ancestors: + yield from self.direct_ancestors[context].keys() + # FIXME: should be possible to choose the resolution order # FIXME: inference make infinite loops possible here - yielded = {self} + yielded = OrderedDict() + yielded[self] = None if context is None: context = contextmod.InferenceContext() if not self.bases and self.qname() != "builtins.object": - yield builtin_lookup("object")[1][0] + result = builtin_lookup("object")[1][0] + yielded[result] = None + del yielded[self] + if recurs: + self.all_ancestors[context] = yielded + else: + self.direct_ancestors[context] = yielded + yield result return for stmt in self.bases: @@ -2352,7 +2369,7 @@ def ancestors(self, recurs=True, context=None): if not baseobj.hide: if baseobj in yielded: continue - yielded.add(baseobj) + yielded[baseobj] = None yield baseobj if not recurs: continue @@ -2362,11 +2379,17 @@ def ancestors(self, recurs=True, context=None): break if grandpa in yielded: continue - yielded.add(grandpa) + yielded[grandpa] = None yield grandpa except InferenceError: continue + del yielded[self] + if recurs: + self.all_ancestors[context] = yielded + else: + self.direct_ancestors[context] = yielded + def local_attr_ancestors(self, name, context=None): """Iterate over the parents that define the given name. diff --git a/tests/unittest_scoped_nodes.py b/tests/unittest_scoped_nodes.py index a868e564d4..3bfe3a1d13 100644 --- a/tests/unittest_scoped_nodes.py +++ b/tests/unittest_scoped_nodes.py @@ -1040,7 +1040,7 @@ class Past(Present): astroid = builder.parse(data) past = astroid["Past"] attr = past.getattr("attr") - self.assertEqual(len(attr), 1) + self.assertEqual(len(attr), 1, attr) attr1 = attr[0] self.assertIsInstance(attr1, nodes.AssignName) self.assertEqual(attr1.name, "attr")