diff --git a/Lib/exception_group.py b/Lib/exception_group.py new file mode 100644 index 00000000000000..0fd11f278447da --- /dev/null +++ b/Lib/exception_group.py @@ -0,0 +1,177 @@ + +import sys +import textwrap +import traceback + + +class ExceptionGroup(BaseException): + def __init__(self, message, *excs): + """ Construct a new ExceptionGroup + + message: The exception Group's error message + excs: sequence of exceptions + """ + assert message is None or isinstance(message, str) + assert all(isinstance(e, BaseException) for e in excs) + + self.message = message + self.excs = excs + super().__init__(self.message) + + @staticmethod + def project(exc, condition, with_complement=False): + """ Split an ExceptionGroup based on an exception predicate + + returns a new ExceptionGroup, match, of the exceptions of exc + for which condition returns true. If with_complement is true, + returns another ExceptionGroup for the exception for which + condition returns false. Note that condition is checked for + exc and nested ExceptionGroups as well, and if it returns true + then the whole ExceptionGroup is considered to be matched. + + match and rest have the same nested structure as exc, but empty + sub-exceptions are not included. They have the same message, + __traceback__, __cause__ and __context__ fields as exc. + + condition: BaseException --> Boolean + with_complement: Bool If True, construct also an EG of the non-matches + """ + + if condition(exc): + return exc, None + elif not isinstance(exc, ExceptionGroup): + return None, exc if with_complement else None + else: + # recurse into ExceptionGroup + match_exc = rest_exc = None + match = [] + rest = [] if with_complement else None + for e in exc.excs: + e_match, e_rest = ExceptionGroup.project( + e, condition, with_complement=with_complement) + + if e_match is not None: + match.append(e_match) + if with_complement and e_rest is not None: + rest.append(e_rest) + + def copy_metadata(src, target): + target.__traceback__ = src.__traceback__ + target.__context__ = src.__context__ + target.__cause__ = src.__cause__ + + if match: + match_exc = ExceptionGroup(exc.message, *match) + copy_metadata(exc, match_exc) + if with_complement and rest: + rest_exc = ExceptionGroup(exc.message, *rest) + copy_metadata(exc, rest_exc) + return match_exc, rest_exc + + def split(self, type): + """ Split an ExceptionGroup to extract exceptions of type E + + type: An exception type + """ + return self.project( + self, lambda e: isinstance(e, type), with_complement=True) + + def subgroup(self, keep): + """ Split an ExceptionGroup to extract only exceptions in keep + + keep: List[BaseException] + """ + match, _ = self.project(self, lambda e: e in keep) + return match + + def __iter__(self): + ''' iterate over the individual exceptions (flattens the tree) ''' + for e in self.excs: + if isinstance(e, ExceptionGroup): + for e_ in e: + yield e_ + else: + yield e + + def __repr__(self): + return f"ExceptionGroup({self.message}, {self.excs})" + + @staticmethod + def catch(types, handler): + return ExceptionGroupCatcher(types, handler) + + +class ExceptionGroupCatcher: + """ Based on trio.MultiErrorCatcher """ + + def __init__(self, types, handler): + """ Context manager to catch and handle ExceptionGroups + + types: the exception types that this handler is interested in + handler: a function that takes an ExceptionGroup of the + matched type and does something with them + + Any rest exceptions are raised at the end as another + exception group + """ + self.types = types + self.handler = handler + + def __enter__(self): + pass + + def __exit__(self, etype, exc, tb): + if exc is not None and isinstance(exc, ExceptionGroup): + match, rest = exc.split(self.types) + + if match is None: + # Let the interpreter reraise the exception + return False + + naked_raise = False + handler_excs = None + try: + naked_raise = self.handler(match) + except (Exception, ExceptionGroup) as e: + handler_excs = e + + if naked_raise or handler_excs is match: + # handler reraised all of the matched exceptions. + # reraise exc as is. + return False + + if handler_excs is None: + if rest is None: + # handled and swallowed all exceptions + # do not raise anything. + return True + else: + # raise the rest exceptions + to_raise = rest + elif rest is None: + to_raise = handler_excs # raise what handler returned + else: + # Merge handler's exceptions with rest + # to_keep: EG subgroup of exc with only those to reraise + # (either not matched or reraised by handler) + to_keep = exc.subgroup( + list(rest) + [e for e in handler_excs if e in match]) + # to_add: new exceptions raised by handler + to_add = handler_excs.subgroup( + [e for e in handler_excs if e not in match]) + if to_add is not None: + to_raise = ExceptionGroup(exc.message, to_keep, to_add) + else: + to_raise = to_keep + + # When we raise to_raise, Python will unconditionally blow + # away its __context__ attribute and replace it with the original + # exc we caught. So after we raise it, we have to pause it while + # it's in flight to put the correct __context__ back. + old_context = to_raise.__context__ + try: + raise to_raise + finally: + _, value, _ = sys.exc_info() + assert value is to_raise + value.__context__ = old_context diff --git a/Lib/test/test_exception_group.py b/Lib/test/test_exception_group.py new file mode 100644 index 00000000000000..96089a46823f23 --- /dev/null +++ b/Lib/test/test_exception_group.py @@ -0,0 +1,491 @@ + +import collections.abc +import functools +import traceback +import unittest +from exception_group import ExceptionGroup +from io import StringIO + + +def newEG(message, raisers, cls=ExceptionGroup): + excs = [] + for r in raisers: + try: + r() + except (Exception, ExceptionGroup) as e: + excs.append(e) + try: + raise cls(message, *excs) + except ExceptionGroup as e: + return e + +def newVE(v): + raise ValueError(v) + +def newTE(t): + raise TypeError(t) + +def newSimpleEG(msg=None): + bind = functools.partial + return newEG(msg, [bind(newVE, 1), bind(newTE, int), bind(newVE, 2)]) + +class MyExceptionGroup(ExceptionGroup): + pass + +def newNestedEG(arg, message=None): + bind = functools.partial + + def level1(i): + return newEG( + 'msg1', + [bind(newVE, i), bind(newTE, int), bind(newVE, i+1)]) + + def raiseExc(e): + raise e + + def level2(i): + return newEG( + 'msg2', + [bind(raiseExc, level1(i)), + bind(raiseExc, level1(i+1)), + bind(newVE, i+2), + ], + cls=MyExceptionGroup) + + def level3(i): + return newEG( + 'msg3', + [bind(raiseExc, level2(i+1)), bind(newVE, i+2)]) + + return level3(arg) + +def extract_traceback(exc, eg): + """ returns the traceback of a single exception + + If exc is in the exception group, return its + traceback as the concatenation of the outputs + of traceback.extract_tb() on each segment of + it traceback (one per each ExceptionGroup that + it belongs to). + """ + if exc not in eg: + return None + e = eg.subgroup([exc]) + result = None + while e is not None: + if isinstance(e, ExceptionGroup): + assert len(e.excs) == 1 and exc in e + r = traceback.extract_tb(e.__traceback__) + if result is not None: + result.extend(r) + else: + result = r + e = e.excs[0] if isinstance(e, ExceptionGroup) else None + return result + +class ExceptionGroupTestBase(unittest.TestCase): + def assertMatchesTemplate(self, exc, template): + """ Assert that the exception matches the template """ + if isinstance(exc, ExceptionGroup): + self.assertIsInstance(template, collections.abc.Sequence) + self.assertEqual(len(exc.excs), len(template)) + for e, t in zip(exc.excs, template): + self.assertMatchesTemplate(e, t) + else: + self.assertIsInstance(template, BaseException) + self.assertEqual(type(exc), type(template)) + self.assertEqual(exc.args, template.args) + + +class ExceptionGroupBasicsTests(ExceptionGroupTestBase): + def test_simple_group(self): + eg = newSimpleEG('simple EG') + + self.assertMatchesTemplate( + eg, [ValueError(1), TypeError(int), ValueError(2)]) + + self.assertEqual(list(eg), list(eg.excs)) # check iteration + + # check message + self.assertEqual(eg.message, 'simple EG') + self.assertEqual(eg.args, ('simple EG',)) + + # check tracebacks + for e in eg: + fname = 'new' + ''.join(filter(str.isupper, type(e).__name__)) + self.assertEqual( + ['newEG', 'newEG', fname], + [f.name for f in extract_traceback(e, eg)]) + + def test_nested_group(self): + eg = newNestedEG(5) + + self.assertMatchesTemplate( + eg, + [ + [ + [ValueError(6), TypeError(int), ValueError(7)], + [ValueError(7), TypeError(int), ValueError(8)], + ValueError(8), + ], + ValueError(7) + ]) + + self.assertEqual(len(list(eg)), 8) # check iteration + + # check tracebacks + all_excs = list(eg) + for e in all_excs[0:6]: + fname = 'new' + ''.join(filter(str.isupper, type(e).__name__)) + self.assertEqual( + [ + 'newEG', 'newEG', 'raiseExc', + 'newEG', 'newEG', 'raiseExc', + 'newEG', 'newEG', fname, + ], + [f.name for f in extract_traceback(e, eg)]) + + self.assertEqual([ + 'newEG', 'newEG', 'raiseExc', 'newEG', 'newEG', 'newVE'], + [f.name for f in extract_traceback(all_excs[6], eg)]) + + self.assertEqual( + ['newEG', 'newEG', 'newVE'], + [f.name for f in extract_traceback(all_excs[7], eg)]) + + +class ExceptionGroupSplitTests(ExceptionGroupTestBase): + def _split_exception_group(self, eg, types): + """ Split an EG and do some sanity checks on the result """ + self.assertIsInstance(eg, ExceptionGroup) + fnames = [t.name for t in traceback.extract_tb(eg.__traceback__)] + all_excs = list(eg) + + match, rest = eg.split(types) + + if match is not None: + self.assertIsInstance(match, ExceptionGroup) + for e in match: + self.assertIsInstance(e, types) + + if rest is not None: + self.assertIsInstance(rest, ExceptionGroup) + for e in rest: + self.assertNotIsInstance(e, types) + + match_len = len(list(match)) if match is not None else 0 + rest_len = len(list(rest)) if rest is not None else 0 + self.assertEqual(len(list(all_excs)), match_len + rest_len) + + for e in all_excs: + # each exception is in eg and exactly one of match and rest + self.assertIn(e, eg) + self.assertNotEqual(match and e in match, rest and e in rest) + + for part in [match, rest]: + if part is not None: + self.assertEqual(eg.message, part.message) + for e in part: + self.assertEqual( + extract_traceback(e, eg), + extract_traceback(e, part)) + + return match, rest + + def test_split_nested(self): + try: + raise newNestedEG(25) + except ExceptionGroup as e: + eg = e + + fnames = ['test_split_nested', 'newEG'] + tb = traceback.extract_tb(eg.__traceback__) + self.assertEqual(fnames, [t.name for t in tb]) + + eg_template = [ + [ + [ValueError(26), TypeError(int), ValueError(27)], + [ValueError(27), TypeError(int), ValueError(28)], + ValueError(28), + ], + ValueError(27) + ] + self.assertMatchesTemplate(eg, eg_template) + + valueErrors_template = [ + [ + [ValueError(26), ValueError(27)], + [ValueError(27), ValueError(28)], + ValueError(28), + ], + ValueError(27) + ] + + typeErrors_template = [[[TypeError(int)], [TypeError(int)]]] + + # Match Nothing + match, rest = self._split_exception_group(eg, SyntaxError) + self.assertTrue(match is None) + self.assertMatchesTemplate(rest, eg_template) + + # Match Everything + match, rest = self._split_exception_group(eg, BaseException) + self.assertMatchesTemplate(match, eg_template) + self.assertTrue(rest is None) + match, rest = self._split_exception_group(eg, (ValueError, TypeError)) + self.assertMatchesTemplate(match, eg_template) + self.assertTrue(rest is None) + + # Match ValueErrors + match, rest = self._split_exception_group(eg, ValueError) + self.assertMatchesTemplate(match, valueErrors_template) + self.assertMatchesTemplate(rest, typeErrors_template) + + # Match TypeErrors + match, rest = self._split_exception_group(eg, (TypeError, SyntaxError)) + self.assertMatchesTemplate(match, typeErrors_template) + self.assertMatchesTemplate(rest, valueErrors_template) + + # Match ExceptionGroup + match, rest = eg.split(ExceptionGroup) + self.assertIs(match, eg) + self.assertIsNone(rest) + + # Match MyExceptionGroup (ExceptionGroup subclass) + match, rest = eg.split(MyExceptionGroup) + self.assertMatchesTemplate(match, [eg_template[0]]) + self.assertMatchesTemplate(rest, [eg_template[1]]) + +class ExceptionGroupCatchTests(ExceptionGroupTestBase): + def setUp(self): + super().setUp() + + try: + raise newNestedEG(35) + except ExceptionGroup as e: + self.eg = e + + fnames = ['setUp', 'newEG'] + tb = traceback.extract_tb(self.eg.__traceback__) + self.assertEqual(fnames, [t.name for t in tb]) + + # templates + self.eg_template = [ + [ + [ValueError(36), TypeError(int), ValueError(37)], + [ValueError(37), TypeError(int), ValueError(38)], + ValueError(38), + ], + ValueError(37) + ] + + self.valueErrors_template = [ + [ + [ValueError(36), ValueError(37)], + [ValueError(37), ValueError(38)], + ValueError(38), + ], + ValueError(37) + ] + + self.typeErrors_template = [[[TypeError(int)], [TypeError(int)]]] + + def checkMatch(self, exc, template, orig_eg): + self.assertMatchesTemplate(exc, template) + for e in exc: + + def f_data(f): + return [f.name, f.lineno] + + new = list(map(f_data, extract_traceback(e, exc))) + if e in orig_eg: + old = list(map(f_data, extract_traceback(e, orig_eg))) + self.assertSequenceEqual(old, new[-len(old):]) + + class BaseHandler: + def __init__(self): + self.caught = None + + def __call__(self, eg): + self.caught = eg + return self.handle(eg) + + def apply_catcher(self, catch, handler_cls, eg): + try: + raised = None + handler = handler_cls() + with ExceptionGroup.catch(catch, handler): + raise eg + except ExceptionGroup as e: + raised = e + return handler.caught, raised + + def test_catch_handler_raises_nothing(self): + eg = self.eg + eg_template = self.eg_template + valueErrors_template = self.valueErrors_template + typeErrors_template = self.typeErrors_template + + class Handler(self.BaseHandler): + def handle(self, eg): + pass + + # ######## Catch nothing: + caught, raised = self.apply_catcher(SyntaxError, Handler, eg) + self.checkMatch(raised, eg_template, eg) + self.assertIsNone(caught) + + # ######## Catch everything: + error_types = (ValueError, TypeError) + caught, raised = self.apply_catcher(error_types, Handler, eg) + self.assertIsNone(raised) + self.checkMatch(caught, eg_template, eg) + + # ######## Catch TypeErrors: + caught, raised = self.apply_catcher(TypeError, Handler, eg) + self.checkMatch(raised, valueErrors_template, eg) + self.checkMatch(caught, typeErrors_template, eg) + + # ######## Catch ValueErrors: + error_types = (ValueError, SyntaxError) + caught, raised = self.apply_catcher(error_types, Handler, eg) + self.checkMatch(raised, typeErrors_template, eg) + self.checkMatch(caught, valueErrors_template, eg) + + def test_catch_handler_adds_new_exceptions(self): + # create a nested exception group + eg = self.eg + eg_template = self.eg_template + valueErrors_template = self.valueErrors_template + typeErrors_template = self.typeErrors_template + + class Handler(self.BaseHandler): + def handle(self, eg): + raise ExceptionGroup( + "msg1", + ValueError('foo'), + ExceptionGroup( + "msg2",SyntaxError('bar'), ValueError('baz'))) + + newErrors_template = [ + ValueError('foo'), [SyntaxError('bar'), ValueError('baz')]] + + # ######## Catch nothing: + caught, raised = self.apply_catcher(SyntaxError, Handler, eg) + self.checkMatch(raised, eg_template, eg) + self.assertIsNone(caught) + + # ######## Catch everything: + error_types = (ValueError, TypeError) + caught, raised = self.apply_catcher(error_types, Handler, eg) + self.checkMatch(raised, newErrors_template, eg) + self.checkMatch(caught, eg_template, eg) + + # ######## Catch TypeErrors: + caught, raised = self.apply_catcher(TypeError, Handler, eg) + self.checkMatch(raised, [valueErrors_template, newErrors_template], eg) + self.checkMatch(caught, typeErrors_template, eg) + + # ######## Catch ValueErrors: + caught, raised = self.apply_catcher((ValueError, OSError), Handler, eg) + self.checkMatch(raised, [typeErrors_template, newErrors_template], eg) + self.checkMatch(caught, valueErrors_template, eg) + + def test_catch_handler_reraise_all_matched(self): + eg = self.eg + eg_template = self.eg_template + valueErrors_template = self.valueErrors_template + typeErrors_template = self.typeErrors_template + + # There are two ways to do this + class Handler1(self.BaseHandler): + def handle(self, eg): + return True + + class Handler2(self.BaseHandler): + def handle(self, eg): + raise eg + + for handler in [Handler1, Handler2]: + # ######## Catch nothing: + caught, raised = self.apply_catcher(SyntaxError, handler, eg) + # handler is never called + self.checkMatch(raised, eg_template, eg) + self.assertIsNone(caught) + + # ######## Catch everything: + error_types = (ValueError, TypeError) + caught, raised = self.apply_catcher(error_types, handler, eg) + self.checkMatch(raised, eg_template, eg) + self.checkMatch(caught, eg_template, eg) + + # ######## Catch TypeErrors: + caught, raised = self.apply_catcher(TypeError, handler, eg) + self.checkMatch(raised, eg_template, eg) + self.checkMatch(caught, typeErrors_template, eg) + + # ######## Catch ValueErrors: + catch = (ValueError, SyntaxError) + caught, raised = self.apply_catcher(catch, handler, eg) + self.checkMatch(raised, eg_template, eg) + self.checkMatch(caught, valueErrors_template, eg) + + def test_catch_handler_reraise_new_and_all_old(self): + eg = self.eg + eg_template = self.eg_template + valueErrors_template = self.valueErrors_template + typeErrors_template = self.typeErrors_template + + class Handler(self.BaseHandler): + def handle(self, eg): + raise ExceptionGroup( + "msg1", + eg, + ValueError('foo'), + ExceptionGroup( + "msg2", SyntaxError('bar'), ValueError('baz'))) + + newErrors_template = [ + ValueError('foo'), [SyntaxError('bar'), ValueError('baz')]] + + # ######## Catch TypeErrors: + caught, raised = self.apply_catcher(TypeError, Handler, eg) + self.checkMatch(raised, [eg_template, newErrors_template], eg) + self.checkMatch(caught, typeErrors_template, eg) + + # ######## Catch ValueErrors: + caught, raised = self.apply_catcher(ValueError, Handler, eg) + self.checkMatch(raised, [eg_template, newErrors_template], eg) + self.checkMatch(caught, valueErrors_template, eg) + + def test_catch_handler_reraise_new_and_some_old(self): + eg = self.eg + eg_template = self.eg_template + valueErrors_template = self.valueErrors_template + typeErrors_template = self.typeErrors_template + + class Handler(self.BaseHandler): + def handle(self, eg): + raise ExceptionGroup( + "msg1", + eg.excs[0], + ValueError('foo'), + ExceptionGroup( + "msg2", SyntaxError('bar'), ValueError('baz'))) + + newErrors_template = [ + ValueError('foo'), [SyntaxError('bar'), ValueError('baz')]] + + # ######## Catch TypeErrors: + caught, raised = self.apply_catcher(TypeError, Handler, eg) + self.checkMatch(raised, [eg_template, newErrors_template], eg) + self.checkMatch(caught, typeErrors_template, eg) + + # ######## Catch ValueErrors: + caught, raised = self.apply_catcher(ValueError, Handler, eg) + # eg.excs[0] is reraised and eg.excs[1] is consumed + self.checkMatch(raised, [[eg_template[0]], newErrors_template], eg) + self.checkMatch(caught, valueErrors_template, eg) + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py index 5df701caf0f01e..c8d5450d57622b 100644 --- a/Lib/test/test_traceback.py +++ b/Lib/test/test_traceback.py @@ -12,6 +12,7 @@ from test.support.script_helper import assert_python_ok import textwrap +import exception_group import traceback @@ -1265,6 +1266,197 @@ def test_traceback_header(self): exc = traceback.TracebackException(Exception, Exception("haven"), None) self.assertEqual(list(exc.format()), ["Exception: haven\n"]) + def test_ExceptionFormatter_factory(self): + def f(): + x = 12 + x/0 + + try: + f() + except: + exc_info = sys.exc_info() + + exc = exc_info[1] + + factory = traceback.ExceptionFormatter + direct = traceback.TracebackException + + self.assertEqual(factory.get(*exc_info), direct(*exc_info)) + self.assertEqual( + factory.from_exception(exc), direct.from_exception(exc)) + + self.assertEqual( + factory.get(*exc_info, limit=1), direct(*exc_info, limit=1)) + self.assertEqual( + factory.from_exception(exc, limit=2), + direct.from_exception(exc, limit=2)) + + self.assertNotEqual( + factory.get(*exc_info, limit=1), direct(*exc_info, limit=2)) + self.assertNotEqual( + factory.from_exception(exc, limit=2), + direct.from_exception(exc, limit=1)) + + self.assertNotEqual( + factory.get(*exc_info, capture_locals=True), direct(*exc_info)) + self.assertNotEqual( + factory.from_exception(exc, capture_locals=True), + direct.from_exception(exc)) + +class TestTracebackExceptionGroup(unittest.TestCase): + def setUp(self): + super().setUp() + self.eg_info = self._get_exception_group() + + def _get_exception_group(self): + def f(): + 1/0 + + def g(v): + raise ValueError(v) + + self.lno_f = f.__code__.co_firstlineno + self.lno_g = g.__code__.co_firstlineno + + try: + try: + try: + f() + except Exception as e: + exc1 = e + try: + g(42) + except Exception as e: + exc2 = e + raise exception_group.ExceptionGroup("eg1", exc1, exc2) + except exception_group.ExceptionGroup as e: + exc3 = e + try: + g(24) + except Exception as e: + exc4 = e + raise exception_group.ExceptionGroup("eg2", exc3, exc4) + except exception_group.ExceptionGroup: + return sys.exc_info() + self.fail('Exception Not Raised') + + def test_single_exception_raises(self): + try: + 1/0 + except: + exc_info = sys.exc_info() + msg = "Expected an ExceptionGroup, got " + with self.assertRaisesRegex(ValueError, msg): + traceback.TracebackExceptionGroup(*exc_info) + with self.assertRaisesRegex(ValueError, msg): + traceback.TracebackExceptionGroup.from_exception(exc_info[1]) + + def test_exception_group_construction(self): + eg_info = self.eg_info + teg1 = traceback.TracebackExceptionGroup(*eg_info) + teg2 = traceback.TracebackExceptionGroup.from_exception(eg_info[1]) + self.assertIsNot(teg1, teg2) + self.assertEqual(teg1, teg2) + + def test_exception_group_format_exception_only(self): + teg = traceback.TracebackExceptionGroup(*self.eg_info) + formatted = ''.join(teg.format_exception_only()).split('\n') + expected = textwrap.dedent(f"""\ + exception_group.ExceptionGroup: eg2 + exception_group.ExceptionGroup: eg1 + ZeroDivisionError: division by zero + ValueError: 42 + ValueError: 24 + """).split('\n') + + self.assertEqual(formatted, expected) + + def test_exception_group_format(self): + teg = traceback.TracebackExceptionGroup(*self.eg_info) + + formatted = ''.join(teg.format()).split('\n') + lno_f = self.lno_f + lno_g = self.lno_g + + expected = textwrap.dedent(f"""\ + Traceback (most recent call last): + File "{__file__}", line {lno_g+23}, in _get_exception_group + raise exception_group.ExceptionGroup("eg2", exc3, exc4) + exception_group.ExceptionGroup: eg2 + ------------------------------------------------------------ + Traceback (most recent call last): + File "{__file__}", line {lno_g+16}, in _get_exception_group + raise exception_group.ExceptionGroup("eg1", exc1, exc2) + exception_group.ExceptionGroup: eg1 + ------------------------------------------------------------ + Traceback (most recent call last): + File "{__file__}", line {lno_g+9}, in _get_exception_group + f() + File "{__file__}", line {lno_f+1}, in f + 1/0 + ZeroDivisionError: division by zero + ------------------------------------------------------------ + Traceback (most recent call last): + File "{__file__}", line {lno_g+13}, in _get_exception_group + g(42) + File "{__file__}", line {lno_g+1}, in g + raise ValueError(v) + ValueError: 42 + ------------------------------------------------------------ + Traceback (most recent call last): + File "{__file__}", line {lno_g+20}, in _get_exception_group + g(24) + File "{__file__}", line {lno_g+1}, in g + raise ValueError(v) + ValueError: 24 + """).split('\n') + + self.assertEqual(formatted, expected) + + def test_comparison(self): + try: + raise self.eg_info[1] + except exception_group.ExceptionGroup: + exc_info = sys.exc_info() + for _ in range(5): + try: + raise exc_info[1] + except: + exc_info = sys.exc_info() + exc = traceback.TracebackExceptionGroup(*exc_info) + exc2 = traceback.TracebackExceptionGroup(*exc_info) + exc3 = traceback.TracebackExceptionGroup(*exc_info, limit=300) + ne = traceback.TracebackExceptionGroup(*exc_info, limit=3) + self.assertIsNot(exc, exc2) + self.assertEqual(exc, exc2) + self.assertEqual(exc, exc3) + self.assertNotEqual(exc, ne) + self.assertNotEqual(exc, object()) + self.assertEqual(exc, ALWAYS_EQ) + + def test_ExceptionFormatter_factory(self): + exc_info = self.eg_info + + factory = traceback.ExceptionFormatter + direct = traceback.TracebackExceptionGroup + + self.assertEqual(factory.get(*exc_info), direct(*exc_info)) + self.assertEqual( + factory.from_exception(exc_info[1]), + direct.from_exception(exc_info[1])) + + self.assertEqual( + factory.get(*exc_info, limit=10), direct(*exc_info, limit=20)) + self.assertEqual( + factory.from_exception(exc_info[1], limit=10), + direct.from_exception(exc_info[1], limit=20)) + + self.assertNotEqual( + factory.get(*exc_info, capture_locals=True), direct(*exc_info)) + self.assertNotEqual( + factory.from_exception(exc_info[1], capture_locals=True), + direct.from_exception(exc_info[1])) + class MiscTest(unittest.TestCase): diff --git a/Lib/traceback.py b/Lib/traceback.py index 457d92511af051..ff0bccb6128412 100644 --- a/Lib/traceback.py +++ b/Lib/traceback.py @@ -1,15 +1,18 @@ """Extract, format and print information about Python stack traces.""" import collections +import exception_group import itertools import linecache import sys +import textwrap __all__ = ['extract_stack', 'extract_tb', 'format_exception', 'format_exception_only', 'format_list', 'format_stack', 'format_tb', 'print_exc', 'format_exc', 'print_exception', 'print_last', 'print_stack', 'print_tb', 'clear_frames', - 'FrameSummary', 'StackSummary', 'TracebackException', + 'ExceptionFormatter', 'FrameSummary', 'StackSummary', + 'TracebackException', 'TracebackExceptionGroup', 'walk_stack', 'walk_tb'] # @@ -71,6 +74,7 @@ def extract_tb(tb, limit=None): """ return StackSummary.extract(walk_tb(tb), limit=limit) + # # Exception formatting and output. # @@ -110,7 +114,7 @@ def print_exception(exc, /, value=_sentinel, tb=_sentinel, limit=None, \ value, tb = _parse_value_tb(exc, value, tb) if file is None: file = sys.stderr - for line in TracebackException( + for line in ExceptionFormatter.get( type(value), value, tb, limit=limit).format(chain=chain): print(line, file=file, end="") @@ -126,7 +130,7 @@ def format_exception(exc, /, value=_sentinel, tb=_sentinel, limit=None, \ printed as does print_exception(). """ value, tb = _parse_value_tb(exc, value, tb) - return list(TracebackException( + return list(ExceptionFormatter.get( type(value), value, tb, limit=limit).format(chain=chain)) @@ -146,7 +150,7 @@ def format_exception_only(exc, /, value=_sentinel): """ if value is _sentinel: value = exc - return list(TracebackException( + return list(ExceptionFormatter.get( type(value), value, None).format_exception_only()) @@ -611,9 +615,10 @@ def format(self, *, chain=True): If chain is not *True*, *__cause__* and *__context__* will not be formatted. - The return value is a generator of strings, each ending in a newline and - some containing internal newlines. `print_exception` is a wrapper around - this method which just prints the lines to a file. + The return value is a generator of strings, each ending in a newline + and some containing internal newlines. `print_exception(e)`, when `e` + is a single exception (rather than an ExceptionGroup), is a wrapper + around this method which just prints the lines to a file. The message indicating which exception occurred is always the last string in the output. @@ -630,3 +635,101 @@ def format(self, *, chain=True): yield 'Traceback (most recent call last):\n' yield from self.stack.format() yield from self.format_exception_only() + + +class TracebackExceptionGroup: + """An exception group ready for rendering. + + We capture enough attributes from the original exception group to this + intermediary form to ensure that no references are held, while still being + able to fully print or format it. + + Use `from_exception()` to create TracebackExceptionGroup instances from exception + objects, or the constructor to create TracebackExceptionGroup instances from + individual components. + + - :attr:`excs` A list of TracebackException objects, one for each exception + in the group. + """ + + SEPARATOR_LINE = '-' * 60 + '\n' + INDENT_SIZE = 3 + + def __init__(self, exc_type, exc_value, exc_traceback, **kwargs): + if not isinstance(exc_value, exception_group.ExceptionGroup): + raise ValueError(f'Expected an ExceptionGroup, got {type(exc_value)}') + self.this = TracebackException( + exc_type, exc_value, exc_traceback, **kwargs) + self.excs = [ + ExceptionFormatter.from_exception(e) for e in exc_value.excs] + + @staticmethod + def from_exception(exc, *args, **kwargs): + """Create a TracebackExceptionGroup from an exceptionGroup.""" + return TracebackExceptionGroup( + type(exc), exc, exc.__traceback__, *args, **kwargs) + + def format(self, *, chain=True): + """Format the exception group. + + The shared part of the traceback is emitted, followed by each + exception in the group, which is expanded recursively. + + If chain is false(y), *__cause__* and *__context__* will not be formatted. + + This is a generator of strings, each ending in a newline + and some containing internal newlines. `print_exception`, when called on + an ExceptionGroup, is a wrapper around this method which just prints the + lines to a file. + """ + # TODO: Add two args to bound - + # (1) the depth of exceptions reported, and + # (2) the number of exceptions reported per level + separator = self.SEPARATOR_LINE + yield from self.this.format(chain=chain) + for exc in self.excs: + yield from self._emit( + exc.format(chain=chain), sep=separator) + + def format_exception_only(self): + yield from self.this.format_exception_only() + for exc in self.excs: + yield from self._emit(exc.format_exception_only()) + + def _emit(self, text_gen, sep=None): + text = ''.join(list(text_gen)) + indent_str = ' ' * self.INDENT_SIZE + if '\n' not in text: + yield indent_str + text + else: + if sep is not None: + yield indent_str + sep + yield textwrap.indent(text, indent_str) + + def __eq__(self, other): + if isinstance(other, TracebackExceptionGroup): + return self.__dict__ == other.__dict__ + return NotImplemented + + +class ExceptionFormatter: + '''Factory functions to get the correct formatter for an exception + + Returns a TracebackException instance for a single exception, and a + TracebackExceptionGroup for an exception group. + ''' + + @staticmethod + def get(exc_type, exc_value, exc_traceback, **kwargs): + if isinstance(exc_value, exception_group.ExceptionGroup): + cls = TracebackExceptionGroup + else: + cls = TracebackException + return cls(exc_type, exc_value, exc_traceback, **kwargs) + + @staticmethod + def from_exception(exc, **kwargs): + if isinstance(exc, exception_group.ExceptionGroup): + return TracebackExceptionGroup.from_exception(exc, **kwargs) + else: + return TracebackException.from_exception(exc, **kwargs)