Skip to content

Commit

Permalink
gh-119180: Avoid going through AST and eval() when possible in annota…
Browse files Browse the repository at this point in the history
…tionlib (#124337)

Often, ForwardRefs represent a single simple name. In that case, we
can avoid going through the overhead of creating AST nodes and code
objects and calling eval(): we can simply look up the name directly
in the relevant namespaces.

Co-authored-by: Victor Stinner <vstinner@python.org>
  • Loading branch information
JelleZijlstra and vstinner authored Sep 25, 2024
1 parent 9d8f2d8 commit 17a544b
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 28 deletions.
79 changes: 52 additions & 27 deletions Lib/annotationlib.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Helpers for introspecting and wrapping annotations."""

import ast
import builtins
import enum
import functools
import keyword
import sys
import types

Expand Down Expand Up @@ -154,8 +156,19 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
globals[param_name] = param
locals.pop(param_name, None)

code = self.__forward_code__
value = eval(code, globals=globals, locals=locals)
arg = self.__forward_arg__
if arg.isidentifier() and not keyword.iskeyword(arg):
if arg in locals:
value = locals[arg]
elif arg in globals:
value = globals[arg]
elif hasattr(builtins, arg):
return getattr(builtins, arg)
else:
raise NameError(arg)
else:
code = self.__forward_code__
value = eval(code, globals=globals, locals=locals)
self.__forward_evaluated__ = True
self.__forward_value__ = value
return value
Expand Down Expand Up @@ -254,7 +267,9 @@ class _Stringifier:
__slots__ = _SLOTS

def __init__(self, node, globals=None, owner=None, is_class=False, cell=None):
assert isinstance(node, ast.AST)
# Either an AST node or a simple str (for the common case where a ForwardRef
# represent a single name).
assert isinstance(node, (ast.AST, str))
self.__arg__ = None
self.__forward_evaluated__ = False
self.__forward_value__ = None
Expand All @@ -267,18 +282,26 @@ def __init__(self, node, globals=None, owner=None, is_class=False, cell=None):
self.__cell__ = cell
self.__owner__ = owner

def __convert(self, other):
def __convert_to_ast(self, other):
if isinstance(other, _Stringifier):
if isinstance(other.__ast_node__, str):
return ast.Name(id=other.__ast_node__)
return other.__ast_node__
elif isinstance(other, slice):
return ast.Slice(
lower=self.__convert(other.start) if other.start is not None else None,
upper=self.__convert(other.stop) if other.stop is not None else None,
step=self.__convert(other.step) if other.step is not None else None,
lower=self.__convert_to_ast(other.start) if other.start is not None else None,
upper=self.__convert_to_ast(other.stop) if other.stop is not None else None,
step=self.__convert_to_ast(other.step) if other.step is not None else None,
)
else:
return ast.Constant(value=other)

def __get_ast(self):
node = self.__ast_node__
if isinstance(node, str):
return ast.Name(id=node)
return node

def __make_new(self, node):
return _Stringifier(
node, self.__globals__, self.__owner__, self.__forward_is_class__
Expand All @@ -292,38 +315,37 @@ def __hash__(self):
def __getitem__(self, other):
# Special case, to avoid stringifying references to class-scoped variables
# as '__classdict__["x"]'.
if (
isinstance(self.__ast_node__, ast.Name)
and self.__ast_node__.id == "__classdict__"
):
if self.__ast_node__ == "__classdict__":
raise KeyError
if isinstance(other, tuple):
elts = [self.__convert(elt) for elt in other]
elts = [self.__convert_to_ast(elt) for elt in other]
other = ast.Tuple(elts)
else:
other = self.__convert(other)
other = self.__convert_to_ast(other)
assert isinstance(other, ast.AST), repr(other)
return self.__make_new(ast.Subscript(self.__ast_node__, other))
return self.__make_new(ast.Subscript(self.__get_ast(), other))

def __getattr__(self, attr):
return self.__make_new(ast.Attribute(self.__ast_node__, attr))
return self.__make_new(ast.Attribute(self.__get_ast(), attr))

def __call__(self, *args, **kwargs):
return self.__make_new(
ast.Call(
self.__ast_node__,
[self.__convert(arg) for arg in args],
self.__get_ast(),
[self.__convert_to_ast(arg) for arg in args],
[
ast.keyword(key, self.__convert(value))
ast.keyword(key, self.__convert_to_ast(value))
for key, value in kwargs.items()
],
)
)

def __iter__(self):
yield self.__make_new(ast.Starred(self.__ast_node__))
yield self.__make_new(ast.Starred(self.__get_ast()))

def __repr__(self):
if isinstance(self.__ast_node__, str):
return self.__ast_node__
return ast.unparse(self.__ast_node__)

def __format__(self, format_spec):
Expand All @@ -332,7 +354,7 @@ def __format__(self, format_spec):
def _make_binop(op: ast.AST):
def binop(self, other):
return self.__make_new(
ast.BinOp(self.__ast_node__, op, self.__convert(other))
ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other))
)

return binop
Expand All @@ -356,7 +378,7 @@ def binop(self, other):
def _make_rbinop(op: ast.AST):
def rbinop(self, other):
return self.__make_new(
ast.BinOp(self.__convert(other), op, self.__ast_node__)
ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast())
)

return rbinop
Expand All @@ -381,9 +403,9 @@ def _make_compare(op):
def compare(self, other):
return self.__make_new(
ast.Compare(
left=self.__ast_node__,
left=self.__get_ast(),
ops=[op],
comparators=[self.__convert(other)],
comparators=[self.__convert_to_ast(other)],
)
)

Expand All @@ -400,7 +422,7 @@ def compare(self, other):

def _make_unary_op(op):
def unary_op(self):
return self.__make_new(ast.UnaryOp(op, self.__ast_node__))
return self.__make_new(ast.UnaryOp(op, self.__get_ast()))

return unary_op

Expand All @@ -422,7 +444,7 @@ def __init__(self, namespace, globals=None, owner=None, is_class=False):

def __missing__(self, key):
fwdref = _Stringifier(
ast.Name(id=key),
key,
globals=self.globals,
owner=self.owner,
is_class=self.is_class,
Expand Down Expand Up @@ -480,7 +502,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
name = freevars[i]
else:
name = "__cell__"
fwdref = _Stringifier(ast.Name(id=name))
fwdref = _Stringifier(name)
new_closure.append(types.CellType(fwdref))
closure = tuple(new_closure)
else:
Expand Down Expand Up @@ -532,7 +554,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
else:
name = "__cell__"
fwdref = _Stringifier(
ast.Name(id=name),
name,
cell=cell,
owner=owner,
globals=annotate.__globals__,
Expand All @@ -555,6 +577,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
result = func(Format.VALUE)
for obj in globals.stringifiers:
obj.__class__ = ForwardRef
if isinstance(obj.__ast_node__, str):
obj.__arg__ = obj.__ast_node__
obj.__ast_node__ = None
return result
elif format == Format.VALUE:
# Should be impossible because __annotate__ functions must not raise
Expand Down
37 changes: 36 additions & 1 deletion Lib/test/test_annotationlib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for the annotations module."""

import annotationlib
import builtins
import collections
import functools
import itertools
Expand Down Expand Up @@ -280,7 +281,14 @@ class Gen[T]:

def test_fwdref_with_module(self):
self.assertIs(ForwardRef("Format", module="annotationlib").evaluate(), Format)
self.assertIs(ForwardRef("Counter", module="collections").evaluate(), collections.Counter)
self.assertIs(
ForwardRef("Counter", module="collections").evaluate(),
collections.Counter
)
self.assertEqual(
ForwardRef("Counter[int]", module="collections").evaluate(),
collections.Counter[int],
)

with self.assertRaises(NameError):
# If globals are passed explicitly, we don't look at the module dict
Expand All @@ -305,6 +313,33 @@ def test_fwdref_value_is_cached(self):
self.assertIs(fr.evaluate(globals={"hello": str}), str)
self.assertIs(fr.evaluate(), str)

def test_fwdref_with_owner(self):
self.assertEqual(
ForwardRef("Counter[int]", owner=collections).evaluate(),
collections.Counter[int],
)

def test_name_lookup_without_eval(self):
# test the codepath where we look up simple names directly in the
# namespaces without going through eval()
self.assertIs(ForwardRef("int").evaluate(), int)
self.assertIs(ForwardRef("int").evaluate(locals={"int": str}), str)
self.assertIs(ForwardRef("int").evaluate(locals={"int": float}, globals={"int": str}), float)
self.assertIs(ForwardRef("int").evaluate(globals={"int": str}), str)
with support.swap_attr(builtins, "int", dict):
self.assertIs(ForwardRef("int").evaluate(), dict)

with self.assertRaises(NameError):
ForwardRef("doesntexist").evaluate()

def test_fwdref_invalid_syntax(self):
fr = ForwardRef("if")
with self.assertRaises(SyntaxError):
fr.evaluate()
fr = ForwardRef("1+")
with self.assertRaises(SyntaxError):
fr.evaluate()


class TestGetAnnotations(unittest.TestCase):
def test_builtin_type(self):
Expand Down

0 comments on commit 17a544b

Please sign in to comment.