diff --git a/tests/test_base.py b/tests/test_base.py index 00ecbcd1..512b0d7a 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -689,6 +689,17 @@ async def foo(): asyncio.wait_for(foo(), timeout=float('inf'))) self.assertEqual(res, 123) + def test_shutdown_default_executor(self): + if not hasattr(self.loop, "shutdown_default_executor"): + raise unittest.SkipTest() + + async def foo(): + await self.loop.run_in_executor(None, time.sleep, .1) + + self.loop.run_until_complete(foo()) + self.loop.run_until_complete( + self.loop.shutdown_default_executor()) + class TestBaseUV(_TestBase, UVTestCase): diff --git a/uvloop/includes/stdlib.pxi b/uvloop/includes/stdlib.pxi index 474c0826..1debe718 100644 --- a/uvloop/includes/stdlib.pxi +++ b/uvloop/includes/stdlib.pxi @@ -134,6 +134,7 @@ cdef int ssl_SSL_ERROR_WANT_WRITE = ssl.SSL_ERROR_WANT_WRITE cdef int ssl_SSL_ERROR_SYSCALL = ssl.SSL_ERROR_SYSCALL cdef uint64_t MAIN_THREAD_ID = threading.main_thread().ident +cdef threading_Thread = threading.Thread cdef int subprocess_PIPE = subprocess.PIPE cdef int subprocess_STDOUT = subprocess.STDOUT diff --git a/uvloop/loop.pxd b/uvloop/loop.pxd index f36d1e2f..86b9e5d0 100644 --- a/uvloop/loop.pxd +++ b/uvloop/loop.pxd @@ -83,6 +83,8 @@ cdef class Loop: object _asyncgens bint _asyncgens_shutdown_called + bint _executor_shutdown_called + char _recv_buffer[UV_STREAM_RECV_BUF_SIZE] bint _recv_buffer_in_use diff --git a/uvloop/loop.pyx b/uvloop/loop.pyx index a1ea92cd..b576eabe 100644 --- a/uvloop/loop.pyx +++ b/uvloop/loop.pyx @@ -186,6 +186,8 @@ cdef class Loop: # Set to True when `loop.shutdown_asyncgens` is called. self._asyncgens_shutdown_called = False + # Set to True when `loop.shutdown_default_executor` is called. + self._executor_shutdown_called = False self._servers = set() @@ -591,6 +593,7 @@ cdef class Loop: self.handler_idle = None self.handler_check__exec_writes = None + self._executor_shutdown_called = True executor = self._default_executor if executor is not None: self._default_executor = None @@ -2669,6 +2672,8 @@ cdef class Loop: if executor is None: executor = self._default_executor + # Only check when the default executor is being used + self._check_default_executor() if executor is None: executor = cc_ThreadPoolExecutor() self._default_executor = executor @@ -3090,6 +3095,10 @@ cdef class Loop: await waiter return udp, protocol + def _check_default_executor(self): + if self._executor_shutdown_called: + raise RuntimeError('Executor shutdown has been called') + def _asyncgen_finalizer_hook(self, agen): self._asyncgens.discard(agen) if not self.is_closed(): @@ -3131,6 +3140,27 @@ cdef class Loop: 'asyncgen': agen }) + @cython.iterable_coroutine + async def shutdown_default_executor(self): + """Schedule the shutdown of the default executor.""" + self._executor_shutdown_called = True + if self._default_executor is None: + return + future = self.create_future() + thread = threading_Thread(target=self._do_shutdown, args=(future,)) + thread.start() + try: + await future + finally: + thread.join() + + def _do_shutdown(self, future): + try: + self._default_executor.shutdown(wait=True) + self.call_soon_threadsafe(future.set_result, None) + except Exception as ex: + self.call_soon_threadsafe(future.set_exception, ex) + cdef void __loop_alloc_buffer(uv.uv_handle_t* uvhandle, size_t suggested_size,