Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support inference of Enum subclasses. #1121

Merged
merged 4 commits into from
Aug 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ Release date: TBA
* Import from ``astroid.node_classes`` and ``astroid.scoped_nodes`` has been deprecated in favor of
``astroid.nodes``. Only the imports from ``astroid.nodes`` will work in astroid 3.0.0.

* Add support for arbitrary Enum subclass hierachies

Closes PyCQA/pylint#533
Closes PyCQA/pylint#2224
Closes PyCQA/pylint#2626


What's New in astroid 2.6.7?
============================
Release date: TBA
Expand Down
30 changes: 19 additions & 11 deletions astroid/brain/brain_namedtuple_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
import keyword
from textwrap import dedent

import astroid
from astroid import arguments, inference_tip, nodes, util
from astroid.builder import AstroidBuilder, extract_node
from astroid.exceptions import (
AstroidTypeError,
AstroidValueError,
InferenceError,
MroError,
UseInferenceDefault,
)
from astroid.manager import AstroidManager
Expand Down Expand Up @@ -354,9 +356,7 @@ def __mul__(self, other):

def infer_enum_class(node):
"""Specific inference for enums."""
for basename in node.basenames:
# TODO: doesn't handle subclasses yet. This implementation
# is a hack to support enums.
for basename in (b for cls in node.mro() for b in cls.basenames):
if basename not in ENUM_BASE_NAMES:
continue
if node.root().name == "enum":
Expand Down Expand Up @@ -417,9 +417,9 @@ def name(self):
# should result in some nice symbolic execution
classdef += INT_FLAG_ADDITION_METHODS.format(name=target.name)

fake = AstroidBuilder(AstroidManager()).string_build(classdef)[
target.name
]
fake = AstroidBuilder(
AstroidManager(), apply_transforms=False
).string_build(classdef)[target.name]
fake.parent = target.parent
for method in node.mymethods():
fake.locals[method.name] = [method]
Expand Down Expand Up @@ -544,18 +544,26 @@ def infer_typing_namedtuple(node, context=None):
return infer_named_tuple(node, context)


def _is_enum_subclass(cls: astroid.ClassDef) -> bool:
"""Return whether cls is a subclass of an Enum."""
try:
return any(
klass.name in ENUM_BASE_NAMES
and getattr(klass.root(), "name", None) == "enum"
for klass in cls.mro()
)
except MroError:
return False


AstroidManager().register_transform(
nodes.Call, inference_tip(infer_named_tuple), _looks_like_namedtuple
)
AstroidManager().register_transform(
nodes.Call, inference_tip(infer_enum), _looks_like_enum
)
AstroidManager().register_transform(
nodes.ClassDef,
infer_enum_class,
predicate=lambda cls: any(
basename for basename in cls.basenames if basename in ENUM_BASE_NAMES
),
nodes.ClassDef, infer_enum_class, predicate=_is_enum_subclass
)
AstroidManager().register_transform(
nodes.ClassDef, inference_tip(infer_typing_namedtuple_class), _has_namedtuple_base
Expand Down
9 changes: 7 additions & 2 deletions astroid/nodes/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
from astroid.interpreter.dunder_lookup import lookup
from astroid.interpreter.objectmodel import ClassModel, FunctionModel, ModuleModel
from astroid.manager import AstroidManager
from astroid.nodes import node_classes
from astroid.nodes import Const, node_classes

ITER_METHODS = ("__iter__", "__getitem__")
EXCEPTION_BASE_CLASSES = frozenset({"Exception", "BaseException"})
Expand Down Expand Up @@ -2962,7 +2962,12 @@ def _inferred_bases(self, context=None):

for stmt in self.bases:
try:
baseobj = next(stmt.infer(context=context.clone()))
# Find the first non-None inferred base value
baseobj = next(
b
for b in stmt.infer(context=context.clone())
if not (isinstance(b, Const) and b.value is None)
)
except (InferenceError, StopIteration):
continue
if isinstance(baseobj, bases.Instance):
Expand Down
109 changes: 104 additions & 5 deletions tests/unittest_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,24 @@ def __init__(self, name, enum_list):
test = next(enumeration.igetattr("test"))
self.assertEqual(test.value, 42)

def test_user_enum_false_positive(self):
# Test that a user-defined class named Enum is not considered a builtin enum.
ast_node = astroid.extract_node(
"""
class Enum:
pass

class Color(Enum):
red = 1

Color.red #@
"""
)
inferred = ast_node.inferred()
self.assertEqual(len(inferred), 1)
self.assertIsInstance(inferred[0], astroid.Const)
self.assertEqual(inferred[0].value, 1)

def test_ignores_with_nodes_from_body_of_enum(self):
code = """
import enum
Expand Down Expand Up @@ -1051,6 +1069,91 @@ def func(self):
assert isinstance(inferred, bases.Instance)
assert inferred.pytype() == ".TrickyEnum.value"

def test_enum_subclass_member_name(self):
ast_node = astroid.extract_node(
"""
from enum import Enum

class EnumSubclass(Enum):
pass

class Color(EnumSubclass):
red = 1

Color.red.name #@
"""
)
inferred = ast_node.inferred()
self.assertEqual(len(inferred), 1)
self.assertIsInstance(inferred[0], astroid.Const)
self.assertEqual(inferred[0].value, "red")

def test_enum_subclass_member_value(self):
ast_node = astroid.extract_node(
"""
from enum import Enum

class EnumSubclass(Enum):
pass

class Color(EnumSubclass):
red = 1

Color.red.value #@
"""
)
inferred = ast_node.inferred()
self.assertEqual(len(inferred), 1)
self.assertIsInstance(inferred[0], astroid.Const)
self.assertEqual(inferred[0].value, 1)

def test_enum_subclass_member_method(self):
# See Pylint issue #2626
ast_node = astroid.extract_node(
"""
from enum import Enum

class EnumSubclass(Enum):
def hello_pylint(self) -> str:
return self.name

class Color(EnumSubclass):
red = 1

Color.red.hello_pylint() #@
"""
)
inferred = ast_node.inferred()
self.assertEqual(len(inferred), 1)
self.assertIsInstance(inferred[0], astroid.Const)
self.assertEqual(inferred[0].value, "red")

def test_enum_subclass_different_modules(self):
# See Pylint issue #2626
astroid.extract_node(
"""
from enum import Enum

class EnumSubclass(Enum):
pass
""",
"a",
)
ast_node = astroid.extract_node(
"""
from a import EnumSubclass

class Color(EnumSubclass):
red = 1

Color.red.value #@
"""
)
inferred = ast_node.inferred()
self.assertEqual(len(inferred), 1)
self.assertIsInstance(inferred[0], astroid.Const)
self.assertEqual(inferred[0].value, 1)


@unittest.skipUnless(HAS_DATEUTIL, "This test requires the dateutil library.")
class DateutilBrainTest(unittest.TestCase):
Expand Down Expand Up @@ -1568,7 +1671,7 @@ def test_typing_annotated_subscriptable(self):

@test_utils.require_version(minver="3.7")
def test_typing_generic_slots(self):
"""Test cache reset for slots if Generic subscript is inferred."""
"""Test slots for Generic subclass."""
node = builder.extract_node(
"""
from typing import Generic, TypeVar
Expand All @@ -1580,10 +1683,6 @@ def __init__(self, value):
"""
)
inferred = next(node.infer())
assert len(inferred.slots()) == 0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Pierre-Sassoulas this test failed after I removed the additional check: len(inferred.slots()) is 1 at this line. It seems the caching behaviour being tested no longer applies, since mro is called earlier, so the bases of the ClassDef are inferred earlier as well.

I'm not sure if this is good or bad, but it's definitely a change in behaviour. This functionality was introduced in #931, so you might want to check that PR and with the author?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cdce8p was the slots length supposed to be zero in the following code ? Do you remember what could be the problem if it's not ?

from typing import Generic, TypeVar

T = TypeVar('T')


class A(Generic[T]):
    __slots__ = ['value']
    
     def __init__(self, value):
         self.value = value

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember correctly, it's safe to remove that part of the check. As @david-yz-liu already mentioned, it was testing the, in that case invalid, caching of slots. Previously, slots would have been empty as the base wouldn't have been inferred as Generic just yet. What's important is that slots eventually contains the correct value, as tested a few lines below.

It seems as this behavior change is a side effect of this PR.

# Only after the subscript base is inferred and the inference tip applied,
# will slots contain the correct value
next(node.bases[0].infer())
slots = inferred.slots()
assert len(slots) == 1
assert isinstance(slots[0], nodes.Const)
Expand Down