Skip to content

Commit

Permalink
gh-95882: fix regression in the traceback of exceptions propagated fr…
Browse files Browse the repository at this point in the history
…om inside a contextlib context manager (GH-95883)

(cherry picked from commit b3722ca)

Co-authored-by: Thomas Grainger <tagrain@gmail.com>
  • Loading branch information
miss-islington and graingert authored Jan 3, 2023
1 parent b99ac1d commit 861cdef
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 4 deletions.
5 changes: 4 additions & 1 deletion Lib/contextlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def __exit__(self, typ, value, traceback):
isinstance(value, StopIteration)
and exc.__cause__ is value
):
exc.__traceback__ = traceback
value.__traceback__ = traceback
return False
raise
except BaseException as exc:
Expand Down Expand Up @@ -228,6 +228,7 @@ async def __aexit__(self, typ, value, traceback):
except RuntimeError as exc:
# Don't re-raise the passed in exception. (issue27122)
if exc is value:
exc.__traceback__ = traceback
return False
# Avoid suppressing if a Stop(Async)Iteration exception
# was passed to athrow() and later wrapped into a RuntimeError
Expand All @@ -239,6 +240,7 @@ async def __aexit__(self, typ, value, traceback):
isinstance(value, (StopIteration, StopAsyncIteration))
and exc.__cause__ is value
):
value.__traceback__ = traceback
return False
raise
except BaseException as exc:
Expand All @@ -250,6 +252,7 @@ async def __aexit__(self, typ, value, traceback):
# and the __exit__() protocol.
if exc is not value:
raise
exc.__traceback__ = traceback
return False
raise RuntimeError("generator didn't stop after athrow()")

Expand Down
30 changes: 27 additions & 3 deletions Lib/test/test_contextlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,39 @@ def f():
self.assertEqual(frames[0].line, '1/0')

# Repeat with RuntimeError (which goes through a different code path)
class RuntimeErrorSubclass(RuntimeError):
pass

try:
with f():
raise NotImplementedError(42)
except NotImplementedError as e:
raise RuntimeErrorSubclass(42)
except RuntimeErrorSubclass as e:
frames = traceback.extract_tb(e.__traceback__)

self.assertEqual(len(frames), 1)
self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
self.assertEqual(frames[0].line, 'raise NotImplementedError(42)')
self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')

class StopIterationSubclass(StopIteration):
pass

for stop_exc in (
StopIteration('spam'),
StopIterationSubclass('spam'),
):
with self.subTest(type=type(stop_exc)):
try:
with f():
raise stop_exc
except type(stop_exc) as e:
self.assertIs(e, stop_exc)
frames = traceback.extract_tb(e.__traceback__)
else:
self.fail(f'{stop_exc} was suppressed')

self.assertEqual(len(frames), 1)
self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
self.assertEqual(frames[0].line, 'raise stop_exc')

def test_contextmanager_no_reraise(self):
@contextmanager
Expand Down
57 changes: 57 additions & 0 deletions Lib/test/test_contextlib_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import functools
from test import support
import unittest
import traceback

from test.test_contextlib import TestBaseExitStack

Expand Down Expand Up @@ -125,6 +126,62 @@ async def woohoo():
raise ZeroDivisionError()
self.assertEqual(state, [1, 42, 999])

@_async_test
async def test_contextmanager_traceback(self):
@asynccontextmanager
async def f():
yield

try:
async with f():
1/0
except ZeroDivisionError as e:
frames = traceback.extract_tb(e.__traceback__)

self.assertEqual(len(frames), 1)
self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
self.assertEqual(frames[0].line, '1/0')

# Repeat with RuntimeError (which goes through a different code path)
class RuntimeErrorSubclass(RuntimeError):
pass

try:
async with f():
raise RuntimeErrorSubclass(42)
except RuntimeErrorSubclass as e:
frames = traceback.extract_tb(e.__traceback__)

self.assertEqual(len(frames), 1)
self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')

class StopIterationSubclass(StopIteration):
pass

class StopAsyncIterationSubclass(StopAsyncIteration):
pass

for stop_exc in (
StopIteration('spam'),
StopAsyncIteration('ham'),
StopIterationSubclass('spam'),
StopAsyncIterationSubclass('spam')
):
with self.subTest(type=type(stop_exc)):
try:
async with f():
raise stop_exc
except type(stop_exc) as e:
self.assertIs(e, stop_exc)
frames = traceback.extract_tb(e.__traceback__)
else:
self.fail(f'{stop_exc} was suppressed')

self.assertEqual(len(frames), 1)
self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
self.assertEqual(frames[0].line, 'raise stop_exc')

@_async_test
async def test_contextmanager_no_reraise(self):
@asynccontextmanager
Expand Down
1 change: 1 addition & 0 deletions Misc/ACKS
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ Hans de Graaff
Tim Graham
Kim Gräsman
Alex Grönholm
Thomas Grainger
Nathaniel Gray
Eddy De Greef
Duane Griffin
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a 3.11 regression in :func:`~contextlib.asynccontextmanager`, which caused it to propagate exceptions with incorrect tracebacks and fix a 3.11 regression in :func:`~contextlib.contextmanager`, which caused it to propagate exceptions with incorrect tracebacks for :exc:`StopIteration`.

0 comments on commit 861cdef

Please sign in to comment.