Skip to content

Commit

Permalink
pythongh-125507: Call annotate(FORWARDREF) before trying __annotations__
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra committed Oct 15, 2024
1 parent 703227d commit 74f990a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
16 changes: 9 additions & 7 deletions Lib/annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,13 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
for key, val in annos.items()
}
elif format == Format.FORWARDREF:
# In FORWARDREF format, try returning the owner's __annotations__ first,
# if they exist.
if owner is not None:
try:
return _get_dunder_annotations(owner)
except NameError:
pass
# FORWARDREF is implemented similarly to STRING, but there are two changes,
# at the beginning and the end of the process.
# First, while STRING uses an empty dictionary as the namespace, so that all
Expand Down Expand Up @@ -683,13 +690,8 @@ def get_annotations(
# For VALUE, we only look at __annotations__
ann = _get_dunder_annotations(obj)
case Format.FORWARDREF:
# For FORWARDREF, we use __annotations__ if it exists
try:
return dict(_get_dunder_annotations(obj))
except NameError:
pass

# But if __annotations__ threw a NameError, we try calling __annotate__
# First we use call_annotate_function(), which will internally also
# try __annotations__ if the FORWARDREF format is passed.
ann = _get_and_call_annotate(obj, format)
if ann is not None:
return ann
Expand Down
20 changes: 19 additions & 1 deletion Lib/test/test_annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,20 @@ def f(x: int, y: doesntexist):
fwdref.evaluate()
self.assertEqual(fwdref.evaluate(globals={"doesntexist": 1}), 1)

def test_custom_annotate(self):
def __annotate__(format):
return {"a": Format(format).name}

class C:
pass

C.__annotate__ = __annotate__

for format in Format:
with self.subTest(format=format):
anno = annotationlib.get_annotations(C, format=format)
self.assertEqual(anno, {"a": format.name})


class TestSourceFormat(unittest.TestCase):
def test_closure(self):
Expand Down Expand Up @@ -809,7 +823,11 @@ def __annotations__(self):

@property
def __annotate__(self):
return lambda format: {"x": str}
def anno(format):
if format == Format.FORWARDREF:
raise NotImplementedError
return {"x": str}
return anno

hb = HasBoth()
self.assertEqual(
Expand Down

0 comments on commit 74f990a

Please sign in to comment.