diff --git a/ChangeLog b/ChangeLog index a11e076644..45476732f2 100644 --- a/ChangeLog +++ b/ChangeLog @@ -11,6 +11,9 @@ Release Date: TBA Closes #895 #899 +* Improve typing.TypedDict inference + + What's New in astroid 2.5? ============================ Release Date: 2021-02-15 diff --git a/astroid/brain/brain_typing.py b/astroid/brain/brain_typing.py index 60557952c8..16b6e84828 100644 --- a/astroid/brain/brain_typing.py +++ b/astroid/brain/brain_typing.py @@ -4,6 +4,7 @@ # Copyright (c) 2018 Bryce Guinta """Astroid hooks for typing.py support.""" +import sys import typing from astroid import ( @@ -12,9 +13,11 @@ extract_node, inference_tip, nodes, + context, InferenceError, ) +PY39 = sys.version_info[:2] >= (3, 9) TYPING_NAMEDTUPLE_BASENAMES = {"NamedTuple", "typing.NamedTuple"} TYPING_TYPEVARS = {"TypeVar", "NewType"} @@ -85,6 +88,28 @@ def infer_typing_attr(node, context=None): return node.infer(context=context) +def _looks_like_typedDict( # pylint: disable=invalid-name + node: nodes.FunctionDef, +) -> bool: + """Check if node is TypedDict FunctionDef.""" + return isinstance(node, nodes.FunctionDef) and node.name == "TypedDict" + + +def infer_typedDict( # pylint: disable=invalid-name + node: nodes.FunctionDef, ctx: context.InferenceContext = None +) -> None: + """Replace TypedDict FunctionDef with ClassDef.""" + class_def = nodes.ClassDef( + name="TypedDict", + doc=node.doc, + lineno=node.lineno, + col_offset=node.col_offset, + parent=node.parent, + ) + class_def.postinit(bases=[], body=[], decorators=None) + node.root().locals["TypedDict"] = [class_def] + + MANAGER.register_transform( nodes.Call, inference_tip(infer_typing_typevar_or_newtype), @@ -93,3 +118,8 @@ def infer_typing_attr(node, context=None): MANAGER.register_transform( nodes.Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript ) + +if PY39: + MANAGER.register_transform( + nodes.FunctionDef, infer_typedDict, _looks_like_typedDict + ) diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py index dc3f25d430..40ee7dc3e3 100644 --- a/tests/unittest_brain.py +++ b/tests/unittest_brain.py @@ -1188,6 +1188,22 @@ def test_typing_namedtuple_dont_crash_on_no_fields(self): inferred = next(node.infer()) self.assertIsInstance(inferred, astroid.Instance) + @test_utils.require_version("3.8") + def test_typedDict(self): + node = builder.extract_node( + """ + from typing import TypedDict + class CustomTD(TypedDict): + var: int + """ + ) + assert len(node.bases) == 1 + inferred_base = next(node.bases[0].infer()) + self.assertIsInstance(inferred_base, nodes.ClassDef, node.as_string()) + typing_module = inferred_base.root() + assert len(typing_module.locals["TypedDict"]) == 1 + assert inferred_base == typing_module.locals["TypedDict"][0] + class ReBrainTest(unittest.TestCase): def test_regex_flags(self):