Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-119180: Avoid going through AST and eval() when possible in annotationlib #124337

Merged
merged 11 commits into from
Sep 25, 2024
77 changes: 50 additions & 27 deletions Lib/annotationlib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Helpers for introspecting and wrapping annotations."""

import ast
import builtins
import enum
import functools
import sys
Expand Down Expand Up @@ -154,8 +155,19 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
globals[param_name] = param
locals.pop(param_name, None)

code = self.__forward_code__
value = eval(code, globals=globals, locals=locals)
arg = self.__forward_arg__
if arg.isidentifier():
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
if arg in locals:
value = locals[arg]
elif globals is not None and arg in globals:
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
value = globals[arg]
elif hasattr(builtins, arg):
return getattr(builtins, arg)
else:
raise NameError(arg)
else:
code = self.__forward_code__
value = eval(code, globals=globals, locals=locals)
self.__forward_evaluated__ = True
self.__forward_value__ = value
return value
Expand Down Expand Up @@ -254,7 +266,9 @@ class _Stringifier:
__slots__ = _SLOTS

def __init__(self, node, globals=None, owner=None, is_class=False, cell=None):
assert isinstance(node, ast.AST)
# Either an AST node or a simple str (for the common case where a ForwardRef
# represent a single name).
assert isinstance(node, (ast.AST, str))
self.__arg__ = None
self.__forward_evaluated__ = False
self.__forward_value__ = None
Expand All @@ -267,18 +281,25 @@ def __init__(self, node, globals=None, owner=None, is_class=False, cell=None):
self.__cell__ = cell
self.__owner__ = owner

def __convert(self, other):
def __convert_to_ast(self, other):
if isinstance(other, _Stringifier):
if isinstance(other.__ast_node__, str):
return ast.Name(id=other.__ast_node__)
return other.__ast_node__
elif isinstance(other, slice):
return ast.Slice(
lower=self.__convert(other.start) if other.start is not None else None,
upper=self.__convert(other.stop) if other.stop is not None else None,
step=self.__convert(other.step) if other.step is not None else None,
lower=self.__convert_to_ast(other.start) if other.start is not None else None,
upper=self.__convert_to_ast(other.stop) if other.stop is not None else None,
step=self.__convert_to_ast(other.step) if other.step is not None else None,
)
else:
return ast.Constant(value=other)

def __get_ast(self):
if isinstance(self.__ast_node__, str):
return ast.Name(id=self.__ast_node__)
return self.__ast_node__
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved

def __make_new(self, node):
return _Stringifier(
node, self.__globals__, self.__owner__, self.__forward_is_class__
Expand All @@ -292,38 +313,37 @@ def __hash__(self):
def __getitem__(self, other):
# Special case, to avoid stringifying references to class-scoped variables
# as '__classdict__["x"]'.
if (
isinstance(self.__ast_node__, ast.Name)
and self.__ast_node__.id == "__classdict__"
):
if self.__ast_node__ == "__classdict__":
raise KeyError
if isinstance(other, tuple):
elts = [self.__convert(elt) for elt in other]
elts = [self.__convert_to_ast(elt) for elt in other]
other = ast.Tuple(elts)
else:
other = self.__convert(other)
other = self.__convert_to_ast(other)
assert isinstance(other, ast.AST), repr(other)
return self.__make_new(ast.Subscript(self.__ast_node__, other))
return self.__make_new(ast.Subscript(self.__get_ast(), other))

def __getattr__(self, attr):
return self.__make_new(ast.Attribute(self.__ast_node__, attr))
return self.__make_new(ast.Attribute(self.__get_ast(), attr))

def __call__(self, *args, **kwargs):
return self.__make_new(
ast.Call(
self.__ast_node__,
[self.__convert(arg) for arg in args],
self.__get_ast(),
[self.__convert_to_ast(arg) for arg in args],
[
ast.keyword(key, self.__convert(value))
ast.keyword(key, self.__convert_to_ast(value))
for key, value in kwargs.items()
],
)
)

def __iter__(self):
yield self.__make_new(ast.Starred(self.__ast_node__))
yield self.__make_new(ast.Starred(self.__get_ast()))

def __repr__(self):
if isinstance(self.__ast_node__, str):
return self.__ast_node__
return ast.unparse(self.__ast_node__)

def __format__(self, format_spec):
Expand All @@ -332,7 +352,7 @@ def __format__(self, format_spec):
def _make_binop(op: ast.AST):
def binop(self, other):
return self.__make_new(
ast.BinOp(self.__ast_node__, op, self.__convert(other))
ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other))
)

return binop
Expand All @@ -356,7 +376,7 @@ def binop(self, other):
def _make_rbinop(op: ast.AST):
def rbinop(self, other):
return self.__make_new(
ast.BinOp(self.__convert(other), op, self.__ast_node__)
ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast())
)

return rbinop
Expand All @@ -381,9 +401,9 @@ def _make_compare(op):
def compare(self, other):
return self.__make_new(
ast.Compare(
left=self.__ast_node__,
left=self.__get_ast(),
ops=[op],
comparators=[self.__convert(other)],
comparators=[self.__convert_to_ast(other)],
)
)

Expand All @@ -400,7 +420,7 @@ def compare(self, other):

def _make_unary_op(op):
def unary_op(self):
return self.__make_new(ast.UnaryOp(op, self.__ast_node__))
return self.__make_new(ast.UnaryOp(op, self.__get_ast()))

return unary_op

Expand All @@ -422,7 +442,7 @@ def __init__(self, namespace, globals=None, owner=None, is_class=False):

def __missing__(self, key):
fwdref = _Stringifier(
ast.Name(id=key),
key,
globals=self.globals,
owner=self.owner,
is_class=self.is_class,
Expand Down Expand Up @@ -480,7 +500,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
name = freevars[i]
else:
name = "__cell__"
fwdref = _Stringifier(ast.Name(id=name))
fwdref = _Stringifier(name)
new_closure.append(types.CellType(fwdref))
closure = tuple(new_closure)
else:
Expand Down Expand Up @@ -532,7 +552,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
else:
name = "__cell__"
fwdref = _Stringifier(
ast.Name(id=name),
name,
cell=cell,
owner=owner,
globals=annotate.__globals__,
Expand All @@ -555,6 +575,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
result = func(Format.VALUE)
for obj in globals.stringifiers:
obj.__class__ = ForwardRef
if isinstance(obj.__ast_node__, str):
obj.__arg__ = obj.__ast_node__
obj.__ast_node__ = None
return result
elif format == Format.VALUE:
# Should be impossible because __annotate__ functions must not raise
Expand Down
28 changes: 27 additions & 1 deletion Lib/test/test_annotationlib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for the annotations module."""

import annotationlib
import builtins
import collections
import functools
import itertools
Expand Down Expand Up @@ -280,7 +281,13 @@ class Gen[T]:

def test_fwdref_with_module(self):
self.assertIs(ForwardRef("Format", module="annotationlib").evaluate(), Format)
self.assertIs(ForwardRef("Counter", module="collections").evaluate(), collections.Counter)
self.assertIs(
ForwardRef("Counter", module="collections").evaluate(), collections.Counter
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
)
self.assertEqual(
ForwardRef("Counter[int]", module="collections").evaluate(),
collections.Counter[int],
)

with self.assertRaises(NameError):
# If globals are passed explicitly, we don't look at the module dict
Expand All @@ -305,6 +312,25 @@ def test_fwdref_value_is_cached(self):
self.assertIs(fr.evaluate(globals={"hello": str}), str)
self.assertIs(fr.evaluate(), str)

def test_fwdref_with_owner(self):
self.assertEqual(
ForwardRef("Counter[int]", owner=collections).evaluate(),
collections.Counter[int],
)

def test_name_lookup_without_eval(self):
# test the codepath where we look up simple names directly in the
# namespaces without going through eval()
self.assertIs(ForwardRef("int").evaluate(), int)
self.assertIs(ForwardRef("int").evaluate(locals={"int": str}), str)
self.assertIs(ForwardRef("int").evaluate(locals={"int": float}, globals={"int": str}), float)
self.assertIs(ForwardRef("int").evaluate(globals={"int": str}), str)
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
with support.swap_attr(builtins, "int", dict):
self.assertIs(ForwardRef("int").evaluate(), dict)

with self.assertRaises(NameError):
ForwardRef("doesntexist").evaluate()


class TestGetAnnotations(unittest.TestCase):
def test_builtin_type(self):
Expand Down
Loading