From 2bb499d53e642255b4172e884dcce129c0959683 Mon Sep 17 00:00:00 2001 From: Mark Byrne Date: Mon, 7 Aug 2023 18:53:59 +0200 Subject: [PATCH] Infer user-defined enum classes by checking if the class is a subtype of ``enum.Enum``. Closes pylint-dev/pylint#8897 --- ChangeLog | 4 +++ astroid/brain/brain_namedtuple_enum.py | 18 +------------ tests/brain/test_enum.py | 36 ++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/ChangeLog b/ChangeLog index 5b2ff4d0bb..028c0e00f7 100644 --- a/ChangeLog +++ b/ChangeLog @@ -221,6 +221,10 @@ Release date: TBA Closes pylint-dev/pylint#8802 +* Infer user-defined enum classes by checking if the class is a subtype of ``enum.Enum``. + + Closes pylint-dev/pylint#8897 + * Fix inference of functions with ``@functools.lru_cache`` decorators without parentheses. diff --git a/astroid/brain/brain_namedtuple_enum.py b/astroid/brain/brain_namedtuple_enum.py index 7212f89083..4b727b568f 100644 --- a/astroid/brain/brain_namedtuple_enum.py +++ b/astroid/brain/brain_namedtuple_enum.py @@ -20,19 +20,10 @@ AstroidTypeError, AstroidValueError, InferenceError, - MroError, UseInferenceDefault, ) from astroid.manager import AstroidManager -ENUM_BASE_NAMES = { - "Enum", - "IntEnum", - "enum.Enum", - "enum.IntEnum", - "IntFlag", - "enum.IntFlag", -} ENUM_QNAME: Final[str] = "enum.Enum" TYPING_NAMEDTUPLE_QUALIFIED: Final = { "typing.NamedTuple", @@ -644,14 +635,7 @@ def _get_namedtuple_fields(node: nodes.Call) -> str: def _is_enum_subclass(cls: astroid.ClassDef) -> bool: """Return whether cls is a subclass of an Enum.""" - try: - return any( - klass.name in ENUM_BASE_NAMES - and getattr(klass.root(), "name", None) == "enum" - for klass in cls.mro() - ) - except MroError: - return False + return cls.is_subtype_of("enum.Enum") AstroidManager().register_transform( diff --git a/tests/brain/test_enum.py b/tests/brain/test_enum.py index bbdb812cee..38fa55316f 100644 --- a/tests/brain/test_enum.py +++ b/tests/brain/test_enum.py @@ -521,3 +521,39 @@ def __init__(self, mass, radius): mars, radius = enum_members.items assert mars[1].name == "MARS" assert radius[1].name == "radius" + + def test_local_enum_child_class_inference(self) -> None: + """Originally reported in https://github.com/pylint-dev/pylint/issues/8897 + + Test that a user-defined enum class is inferred when it subclasses + another user-defined enum class. + """ + enum_class_node, enum_member_value_node = astroid.extract_node( + """ + import sys + + from enum import Enum + + if sys.version_info >= (3, 11): + from enum import StrEnum + else: + class StrEnum(str, Enum): + pass + + + class Color(StrEnum): #@ + RED = "red" + + + Color.RED.value #@ + """ + ) + assert "RED" in enum_class_node.locals + + enum_members = enum_class_node.locals["__members__"][0].items + assert len(enum_members) == 1 + _, name = enum_members[0] + assert name.name == "RED" + + inferred_enum_member_value_node = next(enum_member_value_node.infer()) + assert inferred_enum_member_value_node.value == "red"