diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 6e028c6c7a..7462d1cf01 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -667,7 +667,6 @@ class Runner: deadlines = attr.ib(default=attr.Factory(SortedDict)) init_task = attr.ib(default=None) - init_task_result = attr.ib(default=None) system_nursery = attr.ib(default=None) system_context = attr.ib(default=None) main_task = attr.ib(default=None) @@ -931,22 +930,21 @@ def task_exited(self, task, result): while task._cancel_stack: task._cancel_stack[-1]._remove_task(task) self.tasks.remove(task) - if task._parent_nursery is None: + if task is self.main_task: + self.main_task_result = result + self.system_nursery.cancel_scope.cancel() + self.system_nursery._child_finished(task, Value(None)) + elif task is self.init_task: + # If the init task crashed, then something is very wrong and we + # let the error propagate. (It'll eventually be wrapped in a + # TrioInternalError.) + result.unwrap() # the init task should be the last task to exit. If not, then - # something is very wrong. Probably it hit some unexpected error, - # in which case we re-raise the error (which will later get - # converted to a TrioInternalError, but at least we'll get a - # traceback). Otherwise, raise a new error. + # something is very wrong. if self.tasks: # pragma: no cover - result.unwrap() raise TrioInternalError else: task._parent_nursery._child_finished(task, result) - if task is self.main_task: - self.main_task_result = result - self.system_nursery.cancel_scope.cancel() - if task is self.init_task: - self.init_task_result = result if self.instruments: self.instrument("task_exited", task) @@ -973,7 +971,10 @@ def spawn_system_task(self, async_fn, *args, name=None): * By default, system tasks have :exc:`KeyboardInterrupt` protection *enabled*. If you want your task to be interruptible by control-C, - then you need to use :func:`disable_ki_protection` explicitly. + then you need to use :func:`disable_ki_protection` explicitly (and + come up with some plan for what to do with a + :exc:`KeyboardInterrupt`, given that system tasks aren't allowed to + raise exceptions). * System tasks do not inherit context variables from their creator. @@ -993,40 +994,21 @@ def spawn_system_task(self, async_fn, *args, name=None): """ - async def system_task_wrapper(async_fn, args): - PASS = ( - Cancelled, KeyboardInterrupt, GeneratorExit, TrioInternalError - ) - - def excfilter(exc): - if isinstance(exc, PASS): - return exc - else: - new_exc = TrioInternalError("system task crashed") - new_exc.__cause__ = exc - return new_exc - - with MultiError.catch(excfilter): - await async_fn(*args) - - if name is None: - name = async_fn return self.spawn_impl( - system_task_wrapper, - (async_fn, args), - self.system_nursery, - name, - system_task=True, + async_fn, args, self.system_nursery, name, system_task=True ) async def init(self, async_fn, args): async with open_nursery() as system_nursery: self.system_nursery = system_nursery - self.main_task = self.spawn_impl( - async_fn, args, system_nursery, None - ) + try: + self.main_task = self.spawn_impl( + async_fn, args, system_nursery, None + ) + except BaseException as exc: + self.main_task_result = Error(exc) + system_nursery.cancel_scope.cancel() self.entry_queue.spawn() - return self.main_task_result.unwrap() ################ # Outside context problems @@ -1326,7 +1308,7 @@ def run( with closing(runner): # The main reason this is split off into its own function # is just to get rid of this extra indentation. - result = run_impl(runner, async_fn, args) + run_impl(runner, async_fn, args) except TrioInternalError: raise except BaseException as exc: @@ -1335,7 +1317,7 @@ def run( ) from exc finally: GLOBAL_RUN_CONTEXT.__dict__.clear() - return result.unwrap() + return runner.main_task_result.unwrap() finally: # To guarantee that we never swallow a KeyboardInterrupt, we have to # check for pending ones once more after leaving the context manager: @@ -1504,8 +1486,6 @@ def run_impl(runner, async_fn, args): runner.instrument("after_task_step", task) del GLOBAL_RUN_CONTEXT.task - return runner.init_task_result - ################################################################ # Other public API functions diff --git a/trio/_core/tests/test_run.py b/trio/_core/tests/test_run.py index 6b6cecff66..d35d4c7da7 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/tests/test_run.py @@ -960,15 +960,14 @@ async def main(): _core.spawn_system_task(system_task) await sleep_forever() - with pytest.raises(_core.MultiError) as excinfo: + with pytest.raises(_core.TrioInternalError) as excinfo: _core.run(main) - assert len(excinfo.value.exceptions) == 2 - cause_types = set() - for exc in excinfo.value.exceptions: - assert type(exc) is _core.TrioInternalError - cause_types.add(type(exc.__cause__)) - assert cause_types == {KeyError, ValueError} + me = excinfo.value.__cause__ + assert isinstance(me, _core.MultiError) + assert len(me.exceptions) == 2 + for exc in me.exceptions: + assert isinstance(exc, (KeyError, ValueError)) def test_system_task_crash_plus_Cancelled(): @@ -1005,10 +1004,9 @@ async def main(): _core.spawn_system_task(ki) await sleep_forever() - # KI doesn't get wrapped with TrioInternalError - with pytest.raises(KeyboardInterrupt): + with pytest.raises(_core.TrioInternalError) as excinfo: _core.run(main) - + assert isinstance(excinfo.value.__cause__, KeyboardInterrupt) # This used to fail because checkpoint was a yield followed by an immediate # reschedule. So we had: