diff --git a/trio/_core/_run.py b/trio/_core/_run.py index ff30a56cb1..af16ac92a7 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -1101,6 +1101,13 @@ def aborted(raise_cancel: _core.RaiseCancelT) -> Abort: popped = self._parent_task._child_nurseries.pop() assert popped is self + + # don't unnecessarily wrap an exceptiongroup in another exceptiongroup + # see https://github.com/python-trio/trio/issues/2611 + if len(self._pending_excs) == 1 and isinstance( + self._pending_excs[0], BaseExceptionGroup + ): + return self._pending_excs[0] if self._pending_excs: try: return MultiError( diff --git a/trio/_core/_tests/test_run.py b/trio/_core/_tests/test_run.py index 650f8ef77f..361310143e 100644 --- a/trio/_core/_tests/test_run.py +++ b/trio/_core/_tests/test_run.py @@ -40,7 +40,7 @@ ) if sys.version_info < (3, 11): - from exceptiongroup import ExceptionGroup + from exceptiongroup import BaseExceptionGroup, ExceptionGroup T = TypeVar("T") @@ -2538,3 +2538,62 @@ async def test_cancel_scope_no_cancellederror() -> None: raise ExceptionGroup("test", [RuntimeError(), RuntimeError()]) assert not scope.cancelled_caught + + +"""These tests are for fixing https://github.com/python-trio/trio/issues/2611 +where exceptions raised before `task_status.started()` got wrapped twice. +""" + + +async def raise_before(*, task_status: _core.TaskStatus[None]) -> None: + raise ValueError + + +async def raise_after_started(*, task_status: _core.TaskStatus[None]) -> None: + task_status.started() + raise ValueError + + +async def raise_custom_exception_group_before( + *, task_status: _core.TaskStatus[None] +) -> None: + raise ExceptionGroup("my group", [ValueError()]) + + +def _check_exception(exc: pytest.ExceptionInfo[BaseException]) -> None: + assert isinstance(exc.value, BaseExceptionGroup) + assert len(exc.value.exceptions) == 1 + assert isinstance(exc.value.exceptions[0], ValueError) + + +async def _start_raiser( + raiser: Callable[[], Awaitable[None]], strict: bool | None = None +) -> None: + async with _core.open_nursery(strict_exception_groups=strict) as nursery: + await nursery.start(raiser) + + +@pytest.mark.parametrize("strict", [False, True]) +@pytest.mark.parametrize("raiser", [raise_before, raise_after_started]) +async def test_strict_before_started( + strict: bool, raiser: Callable[[], Awaitable[None]] +) -> None: + with pytest.raises(BaseExceptionGroup if strict else ValueError) as exc: + await _start_raiser(raiser, strict) + if strict: + _check_exception(exc) + + +# it was only when run from `trio.run` that the double wrapping happened +@pytest.mark.parametrize("strict", [False, True]) +@pytest.mark.parametrize( + "raiser", [raise_before, raise_after_started, raise_custom_exception_group_before] +) +def test_trio_run_strict_before_started( + strict: bool, raiser: Callable[[], Awaitable[None]] +) -> None: + expect_group = strict or raiser is raise_custom_exception_group_before + with pytest.raises(BaseExceptionGroup if expect_group else ValueError) as exc: + _core.run(_start_raiser, raiser, strict_exception_groups=strict) + if expect_group: + _check_exception(exc)