diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 673e661bd9..e88414f489 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -124,6 +124,16 @@ class CancelScope: cancel_called = attr.ib(default=False) cancelled_caught = attr.ib(default=False) + @staticmethod + def create(deadline, shield): + task = _core.current_task() + scope = CancelScope() + scope._scope_task = task + scope._add_task(task) + scope.deadline = deadline + scope.shield = shield + return scope + def __repr__(self): return "".format(id(self)) @@ -212,6 +222,12 @@ def _exc_filter(self, exc): return None return exc + def _close(self, exc): + self._remove_task(self._scope_task) + if exc is not None: + filtered_exc = MultiError.filter(self._exc_filter, exc) + return filtered_exc + # Note we explicitly avoid @contextmanager since it adds extraneous stack # frames to exceptions. @enable_ki_protection @@ -220,24 +236,21 @@ def __enter__(self): @enable_ki_protection def __exit__(self, etype, exc, tb): - try: + filtered_exc = self._close(exc) + if filtered_exc is None: + return True + elif filtered_exc is exc: + return False + else: # Copied verbatim from MultiErrorCatcher. Python doesn't - # allow us to encapsulate the __context__ fixup. - if exc is not None: - filtered_exc = MultiError.filter(self._exc_filter, exc) - if filtered_exc is exc: - return False - if filtered_exc is None: - return True - old_context = filtered_exc.__context__ - try: - raise filtered_exc - finally: - _, value, _ = sys.exc_info() - assert value is filtered_exc - value.__context__ = old_context - finally: - self._remove_task(self._scope_task) + # allow us to encapsulate this __context__ fixup. + old_context = filtered_exc.__context__ + try: + raise filtered_exc + finally: + _, value, _ = sys.exc_info() + assert value is filtered_exc + value.__context__ = old_context def open_cancel_scope(*, deadline=inf, shield=False): @@ -245,12 +258,7 @@ def open_cancel_scope(*, deadline=inf, shield=False): """ - scope = CancelScope() - scope._scope_task = _core.current_task() - scope._add_task(scope._scope_task) - scope.deadline = deadline - scope.shield = shield - return scope + return CancelScope.create(deadline, shield) ################################################################ @@ -356,27 +364,30 @@ class NurseryManager: @enable_ki_protection async def __aenter__(self): - self._scope_manager = open_cancel_scope() - scope = self._scope_manager.__enter__() - self._nursery = Nursery(current_task(), scope) + self._scope = CancelScope.create(deadline=inf, shield=False) + self._nursery = Nursery(current_task(), self._scope) return self._nursery @enable_ki_protection async def __aexit__(self, etype, exc, tb): new_exc = await self._nursery._nested_child_finished(exc) if new_exc: - try: - if self._scope_manager.__exit__( - type(new_exc), new_exc, new_exc.__traceback__ - ): - return True - except BaseException as scope_manager_exc: - if scope_manager_exc == exc: - return False - raise # scope_manager_exc - raise new_exc + scope_exc = self._scope._close(new_exc) + if scope_exc is None: + return True + elif scope_exc is exc: + return False + else: + # Copied verbatim from MultiErrorCatcher. Python doesn't + # allow us to encapsulate this __context__ fixup. + old_context = scope_exc.__context__ + try: + raise scope_exc + finally: + _, value, _ = sys.exc_info() + assert value is scope_exc + value.__context__ = old_context else: - self._scope_manager.__exit__(None, None, None) return True def __enter__(self):