From 33ae2efe758361dfe63f22f0ecda826a82bde9ed Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 24 May 2022 18:42:18 +0200 Subject: [PATCH 01/75] Add background tasks and rename ongoing_coroutines to ongoing_comm_handlers --- distributed/core.py | 24 ++++++++++++++++-------- distributed/tests/test_core.py | 4 ++-- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 2a2a12521ff..f67abb330d0 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -187,7 +187,8 @@ def __init__( self.monitor = SystemMonitor() self.counters = None self.digests = None - self._ongoing_coroutines = set() + self._ongoing_background_tasks = set() + self._ongoing_comm_handlers = set() self._event_finished = asyncio.Event() self.listeners = [] @@ -218,7 +219,7 @@ def stop() -> bool: self.counters = defaultdict(partial(Counter, loop=self.io_loop)) - self.periodic_callbacks = dict() + self.periodic_callbacks = {} pc = PeriodicCallback( self.monitor.update, @@ -583,8 +584,8 @@ async def handle_comm(self, comm): result = asyncio.create_task( result, name=f"handle-comm-{address}-{op}" ) - self._ongoing_coroutines.add(result) - result.add_done_callback(self._ongoing_coroutines.remove) + self._ongoing_comm_handlers.add(result) + result.add_done_callback(self._ongoing_comm_handlers.remove) result = await result elif inspect.isawaitable(result): raise RuntimeError( @@ -673,6 +674,11 @@ async def handle_stream(self, comm, extra=None): await comm.close() assert comm.closed() + def add_background_task(self, coro) -> None: + task = asyncio.create_task(coro()) + self._ongoing_background_tasks.add(task) + task.add_done_callback(self._ongoing_background_tasks.remove) + async def close(self, timeout=None): for pc in self.periodic_callbacks.values(): pc.stop() @@ -686,20 +692,22 @@ async def close(self, timeout=None): _stops.add(future) await asyncio.gather(*_stops) - def _ongoing_tasks(): + def _ongoing_comm_handlers(): return ( - t for t in self._ongoing_coroutines if t is not asyncio.current_task() + t + for t in self._ongoing_comm_handlers + if t is not asyncio.current_task() ) # TODO: Deal with exceptions try: # Give the handlers a bit of time to finish gracefully await asyncio.wait_for( - asyncio.gather(*_ongoing_tasks(), return_exceptions=True), 1 + asyncio.gather(*_ongoing_comm_handlers(), return_exceptions=True), 1 ) except asyncio.TimeoutError: # the timeout on gather should've cancelled all the tasks - await asyncio.gather(*_ongoing_tasks(), return_exceptions=True) + await asyncio.gather(*_ongoing_comm_handlers(), return_exceptions=True) await self.rpc.close() await asyncio.gather(*[comm.close() for comm in list(self._comms)]) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 7422181d024..9ce06b8a149 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -928,7 +928,7 @@ async def sleep(comm=None): comm = await remote.live_comm() await comm.write({"op": "sleep"}) - await async_wait_for(lambda: not server._ongoing_coroutines, 10) + await async_wait_for(lambda: not server._ongoing_comm_handlers, 10) listeners = server.listeners assert len(listeners) == len(ports) @@ -942,7 +942,7 @@ async def sleep(comm=None): await assert_cannot_connect(f"tcp://{ip}:{port}") # weakref set/dict should be cleaned up - assert not len(server._ongoing_coroutines) + assert not len(server._ongoing_comm_handlers) @gen_test() From e680be971360b12ec2a6a5959894eccda1da1398 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 25 May 2022 17:18:14 +0200 Subject: [PATCH 02/75] Replace add_callback and call_later --- distributed/core.py | 32 +++++++++------- distributed/scheduler.py | 82 ++++++++++++++++++++-------------------- distributed/worker.py | 41 +++++++++++--------- 3 files changed, 82 insertions(+), 73 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index f67abb330d0..cd70d426530 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -15,11 +15,10 @@ from contextlib import suppress from enum import Enum from functools import partial -from typing import Callable, ClassVar, TypedDict, TypeVar +from typing import Callable, ClassVar, Coroutine, TypedDict, TypeVar import tblib from tlz import merge -from tornado import gen from tornado.ioloop import IOLoop, PeriodicCallback import dask @@ -248,10 +247,10 @@ def stop() -> bool: self.thread_id = 0 - def set_thread_ident(): + async def set_thread_ident(): self.thread_id = threading.get_ident() - self.io_loop.add_callback(set_thread_ident) + self.add_background_task(set_thread_ident()) self._startup_lock = asyncio.Lock() self.__startup_exc: Exception | None = None self.__started = asyncio.Event() @@ -344,16 +343,17 @@ def start_periodic_callbacks(self): """ self._last_tick = time() - def start_pcs(): + async def start_pcs(): for pc in self.periodic_callbacks.values(): if not pc.is_running(): pc.start() - self.io_loop.add_callback(start_pcs) + self.add_background_task(start_pcs()) def stop(self): if not self.__stopped: self.__stopped = True + for listener in self.listeners: # Delay closing the server socket until the next IO loop tick. # Otherwise race conditions can appear if an event handler @@ -362,7 +362,7 @@ def stop(self): # The demonstrator for this is Worker.terminate(), which # closes the server socket in response to an incoming message. # See https://github.com/tornadoweb/tornado/issues/2069 - self.io_loop.add_callback(listener.stop) + self.add_background_task(asyncio.coroutine(listener.stop)()) @property def listener(self): @@ -653,8 +653,8 @@ async def handle_stream(self, comm, extra=None): break handler = self.stream_handlers[op] if is_coroutine_function(handler): - self.loop.add_callback(handler, **merge(extra, msg)) - await gen.sleep(0) + self.add_background_task(handler(**merge(extra, msg))) + await asyncio.sleep(0) else: handler(**merge(extra, msg)) else: @@ -674,8 +674,16 @@ async def handle_stream(self, comm, extra=None): await comm.close() assert comm.closed() - def add_background_task(self, coro) -> None: - task = asyncio.create_task(coro()) + def add_background_task(self, coro: Coroutine, delay: float | None = None) -> None: + if delay is not None: + + async def _delay(coro, delay): + await asyncio.sleep(delay) + await coro + + coro = _delay(coro, delay) + + task = asyncio.create_task(coro) self._ongoing_background_tasks.add(task) task.add_done_callback(self._ongoing_background_tasks.remove) @@ -877,12 +885,10 @@ async def _close_comm(comm): tasks = [] for comm in list(self.comms): if comm and not comm.closed(): - # IOLoop.current().add_callback(_close_comm, comm) task = asyncio.ensure_future(_close_comm(comm)) tasks.append(task) for comm in list(self._created): if comm and not comm.closed(): - # IOLoop.current().add_callback(_close_comm, comm) task = asyncio.ensure_future(_close_comm(comm)) tasks.append(task) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index c08a79d73a8..a99b873d11f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -28,7 +28,6 @@ Set, ) from contextlib import suppress -from datetime import timedelta from functools import partial from numbers import Number from typing import Any, ClassVar, Literal, cast @@ -3321,7 +3320,7 @@ async def start_unsafe(self): for k, v in self.services.items(): logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port)) - self.loop.add_callback(self.reevaluate_occupancy) + self.add_background_task(self.reevaluate_occupancy()) if self.scheduler_file: with open(self.scheduler_file, "w") as f: @@ -4244,7 +4243,7 @@ async def remove_worker(self, address, stimulus_id, safe=False, close=True): self.bandwidth_workers.pop((address, w), None) self.bandwidth_workers.pop((w, address), None) - def remove_worker_from_events(): + async def remove_worker_from_events(): # If the worker isn't registered anymore after the delay, remove from events if address not in self.workers and address in self.events: del self.events[address] @@ -4252,7 +4251,8 @@ def remove_worker_from_events(): cleanup_delay = parse_timedelta( dask.config.get("distributed.scheduler.events-cleanup-delay") ) - self.loop.call_later(cleanup_delay, remove_worker_from_events) + + self.add_background_task(remove_worker_from_events(), cleanup_delay) logger.debug("Removed worker %s", ws) return "OK" @@ -4604,7 +4604,7 @@ def remove_client(self, client: str, stimulus_id: str = None) -> None: except Exception as e: logger.exception(e) - def remove_client_from_events(): + async def remove_client_from_events(): # If the client isn't registered anymore after the delay, remove from events if client not in self.clients and client in self.events: del self.events[client] @@ -4612,7 +4612,8 @@ def remove_client_from_events(): cleanup_delay = parse_timedelta( dask.config.get("distributed.scheduler.events-cleanup-delay") ) - self.loop.call_later(cleanup_delay, remove_client_from_events) + + self.add_background_task(remove_client_from_events(), cleanup_delay) def send_task_to_worker(self, worker, ts: TaskState, duration: float = -1): """Send a single computational task to a worker""" @@ -4879,10 +4880,10 @@ def worker_send(self, worker: str, msg: dict[str, Any]) -> None: try: stream_comms[worker].send(msg) except (CommClosedError, AttributeError): - self.loop.add_callback( - self.remove_worker, - address=worker, - stimulus_id=f"worker-send-comm-fail-{time()}", + self.add_background_task( + self.remove_worker( + address=worker, stimulus_id=f"worker-send-comm-fail-{time()}" + ) ) def client_send(self, client, msg): @@ -4928,10 +4929,11 @@ def send_all(self, client_msgs: dict, worker_msgs: dict): # worker already gone pass except (CommClosedError, AttributeError): - self.loop.add_callback( - self.remove_worker, - address=worker, - stimulus_id=f"send-all-comm-fail-{time()}", + self.add_background_task( + self.remove_worker( + address=worker, + stimulus_id=f"send-all-comm-fail-{time()}", + ) ) ############################ @@ -6930,7 +6932,7 @@ async def get_worker_monitor_info(self, recent=False, starts=None): # Cleanup # ########### - def reevaluate_occupancy(self, worker_index: int = 0): + async def reevaluate_occupancy(self, worker_index: int = 0): """Periodically reassess task duration time The expected duration of a task can change over time. Unfortunately we @@ -6946,33 +6948,29 @@ def reevaluate_occupancy(self, worker_index: int = 0): think about. """ try: - if self.status == Status.closed: - return - last = time() - next_time = timedelta(seconds=0.1) - - if self.proc.cpu_percent() < 50: - workers: list = list(self.workers.values()) - nworkers: int = len(workers) - i: int - for i in range(nworkers): - ws: WorkerState = workers[worker_index % nworkers] - worker_index += 1 - try: - if ws is None or not ws.processing: - continue - self._reevaluate_occupancy_worker(ws) - finally: - del ws # lose ref - - duration = time() - last - if duration > 0.005: # 5ms since last release - next_time = timedelta(seconds=duration * 5) # 25ms gap - break + while self.status != Status.closed: + last = time() + delay = 0.1 - self.loop.add_timeout( - next_time, self.reevaluate_occupancy, worker_index=worker_index - ) + if self.proc.cpu_percent() < 50: + workers: list = list(self.workers.values()) + nworkers: int = len(workers) + i: int + for i in range(nworkers): + ws: WorkerState = workers[worker_index % nworkers] + worker_index += 1 + try: + if ws is None or not ws.processing: + continue + self._reevaluate_occupancy_worker(ws) + finally: + del ws # lose ref + + duration = time() - last + if duration > 0.005: # 5ms since last release + delay = duration * 5 # 25ms gap + break + await asyncio.sleep(delay) except Exception: logger.error("Error in reevaluate occupancy", exc_info=True) @@ -7004,7 +7002,7 @@ def check_idle(self): "Scheduler closing after being idle for %s", format_time(self.idle_timeout), ) - self.loop.add_callback(self.close) + self.add_background_task(self.close()) def adaptive_target(self, target_duration=None): """Desired number of workers based on the current workload diff --git a/distributed/worker.py b/distributed/worker.py index 1795bb80c59..70f0b66b3a1 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -890,7 +890,7 @@ def __init__( if self.lifetime: self.lifetime += (random.random() * 2 - 1) * lifetime_stagger - self.io_loop.call_later(self.lifetime, self.close_gracefully) + self.add_background_task(self.close_gracefully(), self.lifetime) self._async_instructions = set() @@ -952,7 +952,9 @@ def log_event(self, topic, msg): if self.thread_id == threading.get_ident(): self.batched_stream.send(full_msg) else: - self.loop.add_callback(self.batched_stream.send, full_msg) + self.add_background_task( + asyncio.coroutine(self.batched_stream.send)(full_msg) + ) @property def executing_count(self) -> int: @@ -983,21 +985,23 @@ def status(self, value): if prev_status == Status.paused and value == Status.running: self.handle_stimulus(UnpauseEvent(stimulus_id=stimulus_id)) - def _send_worker_status_change(self, stimulus_id: str) -> None: - if ( + async def _send_worker_status_change(self, stimulus_id: str) -> None: + while not ( self.batched_stream and self.batched_stream.comm and not self.batched_stream.comm.closed() ): - self.batched_stream.send( - { - "op": "worker-status-change", - "status": self._status.name, - "stimulus_id": stimulus_id, - }, - ) - elif self._status != Status.closed: - self.loop.call_later(0.05, self._send_worker_status_change, stimulus_id) + if self._status == Status.closed: + return + await asyncio.sleep(0.05) + + self.batched_stream.send( + { + "op": "worker-status-change", + "status": self._status.name, + "stimulus_id": stimulus_id, + }, + ) async def get_metrics(self) -> dict: try: @@ -1180,7 +1184,7 @@ async def _register_with_scheduler(self): self.batched_stream.start(comm) self.periodic_callbacks["keep-alive"].start() self.periodic_callbacks["heartbeat"].start() - self.loop.add_callback(self.handle_scheduler, comm) + self.add_background_task(self.handle_scheduler(comm)) def _update_latency(self, latency): self.latency = latency * 0.05 + self.latency * 0.95 @@ -1638,7 +1642,7 @@ async def batched_send_connect(): bcomm.start(comm) - self.loop.add_callback(batched_send_connect) + self.add_background_task(batched_send_connect()) self.stream_comms[address].send(msg) @@ -3335,7 +3339,7 @@ def done_event(): # Avoid hammering the worker. If there are multiple replicas # available, immediately try fetching from a different worker. self.busy_workers.add(worker) - self.io_loop.call_later(0.15, self._readd_busy_worker, worker) + self.add_background_task(self._readd_busy_worker(worker), delay=0.15) refresh_who_has = set() @@ -3373,7 +3377,7 @@ def done_event(): self.update_who_has(who_has) @log_errors - def _readd_busy_worker(self, worker: str) -> None: + async def _readd_busy_worker(self, worker: str) -> None: self.busy_workers.remove(worker) self.handle_stimulus( GatherDepDoneEvent(stimulus_id=f"readd-busy-worker-{time()}") @@ -3470,7 +3474,8 @@ def handle_worker_status_change(self, status: str, stimulus_id: str) -> None: "Invalid Worker.status transition: %s -> %s", self._status, new_status ) # Reiterate the current status to the scheduler to restore sync - self._send_worker_status_change(stimulus_id) + self.add_background_task(self._send_worker_status_change(stimulus_id)) + else: # Update status and send confirmation to the Scheduler (see status.setter) self.status = new_status From 0af0741023a0957ee9824e3fc0e50652e8150ef9 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 25 May 2022 17:37:02 +0200 Subject: [PATCH 03/75] Move close out of background tasks since it would cancel itself --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a99b873d11f..f714ae20c48 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -7002,7 +7002,7 @@ def check_idle(self): "Scheduler closing after being idle for %s", format_time(self.idle_timeout), ) - self.add_background_task(self.close()) + asyncio.create_task(self.close()) def adaptive_target(self, target_duration=None): """Desired number of workers based on the current workload From d9ce34fdd056142bdf04bec6531afd13a3817fdb Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 25 May 2022 18:18:46 +0200 Subject: [PATCH 04/75] Fix issue with non-running IO loop --- distributed/core.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index cd70d426530..d925e646fd2 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -247,10 +247,10 @@ def stop() -> bool: self.thread_id = 0 - async def set_thread_ident(): + def set_thread_ident(): self.thread_id = threading.get_ident() - self.add_background_task(set_thread_ident()) + self.io_loop.add_callback(set_thread_ident) self._startup_lock = asyncio.Lock() self.__startup_exc: Exception | None = None self.__started = asyncio.Event() @@ -700,6 +700,23 @@ async def close(self, timeout=None): _stops.add(future) await asyncio.gather(*_stops) + def _ongoing_background_tasks(): + return ( + t + for t in self._ongoing_background_tasks + if t is not asyncio.current_task() + ) + + # TODO: Deal with exceptions + try: + # Give the handlers a bit of time to finish gracefully + await asyncio.wait_for( + asyncio.gather(*_ongoing_background_tasks(), return_exceptions=True), 1 + ) + except asyncio.TimeoutError: + # the timeout on gather should've cancelled all the tasks + await asyncio.gather(*_ongoing_background_tasks(), return_exceptions=True) + def _ongoing_comm_handlers(): return ( t From 89f70db666f5c234d00aae96136ca511b4671820 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 25 May 2022 19:53:20 +0200 Subject: [PATCH 05/75] Remove deprecated asyncio.coroutine --- distributed/core.py | 5 ++++- distributed/worker.py | 8 +++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index d925e646fd2..a0d786cb6bd 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -362,7 +362,10 @@ def stop(self): # The demonstrator for this is Worker.terminate(), which # closes the server socket in response to an incoming message. # See https://github.com/tornadoweb/tornado/issues/2069 - self.add_background_task(asyncio.coroutine(listener.stop)()) + async def _stop_listener(): + listener.stop() + + self.add_background_task(_stop_listener()) @property def listener(self): diff --git a/distributed/worker.py b/distributed/worker.py index 70f0b66b3a1..ef38990af7e 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -952,9 +952,11 @@ def log_event(self, topic, msg): if self.thread_id == threading.get_ident(): self.batched_stream.send(full_msg) else: - self.add_background_task( - asyncio.coroutine(self.batched_stream.send)(full_msg) - ) + + async def _send_batched_stream(batched_stream): + batched_stream.send(full_msg) + + self.add_background_task(_send_batched_stream(self.batched_stream)) @property def executing_count(self) -> int: From 76edf23d70b61def32627ff530041a0f26bbf5ab Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 27 May 2022 13:15:12 +0200 Subject: [PATCH 06/75] Add delay decorator to delay async function evaluation --- distributed/core.py | 10 +--------- distributed/scheduler.py | 5 +++-- distributed/utils.py | 8 ++++++++ distributed/worker.py | 7 ++++--- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index a0d786cb6bd..d0ec01bdee4 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -677,15 +677,7 @@ async def handle_stream(self, comm, extra=None): await comm.close() assert comm.closed() - def add_background_task(self, coro: Coroutine, delay: float | None = None) -> None: - if delay is not None: - - async def _delay(coro, delay): - await asyncio.sleep(delay) - await coro - - coro = _delay(coro, delay) - + def add_background_task(self, coro: Coroutine) -> None: task = asyncio.create_task(coro) self._ongoing_background_tasks.add(task) task.add_done_callback(self._ongoing_background_tasks.remove) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f714ae20c48..395a6d87329 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -88,6 +88,7 @@ from distributed.utils import ( All, TimeoutError, + delay, empty_context, get_fileno_limit, key_split, @@ -4252,7 +4253,7 @@ async def remove_worker_from_events(): dask.config.get("distributed.scheduler.events-cleanup-delay") ) - self.add_background_task(remove_worker_from_events(), cleanup_delay) + self.add_background_task(delay(remove_worker_from_events, cleanup_delay)()) logger.debug("Removed worker %s", ws) return "OK" @@ -4613,7 +4614,7 @@ async def remove_client_from_events(): dask.config.get("distributed.scheduler.events-cleanup-delay") ) - self.add_background_task(remove_client_from_events(), cleanup_delay) + self.add_background_task(delay(remove_client_from_events, cleanup_delay)()) def send_task_to_worker(self, worker, ts: TaskState, duration: float = -1): """Send a single computational task to a worker""" diff --git a/distributed/utils.py b/distributed/utils.py index e24e45b0c86..94404c206a0 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1732,3 +1732,11 @@ def is_python_shutting_down() -> bool: from distributed import _python_shutting_down return _python_shutting_down + + +def delay(func, delay): + async def wrapper(*args, **kwargs): + await asyncio.sleep(delay) + return await func(*args, **kwargs) + + return wrapper diff --git a/distributed/worker.py b/distributed/worker.py index ef38990af7e..b9356ea402a 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -82,6 +82,7 @@ LRU, TimeoutError, _maybe_complex, + delay, get_ip, has_arg, import_file, @@ -890,7 +891,7 @@ def __init__( if self.lifetime: self.lifetime += (random.random() * 2 - 1) * lifetime_stagger - self.add_background_task(self.close_gracefully(), self.lifetime) + self.add_background_task(delay(self.close_gracefully, self.lifetime)()) self._async_instructions = set() @@ -983,7 +984,7 @@ def status(self, value): prev_status = self.status ServerNode.status.__set__(self, value) stimulus_id = f"worker-status-change-{time()}" - self._send_worker_status_change(stimulus_id) + self.add_background_task(self._send_worker_status_change(stimulus_id)) if prev_status == Status.paused and value == Status.running: self.handle_stimulus(UnpauseEvent(stimulus_id=stimulus_id)) @@ -3341,7 +3342,7 @@ def done_event(): # Avoid hammering the worker. If there are multiple replicas # available, immediately try fetching from a different worker. self.busy_workers.add(worker) - self.add_background_task(self._readd_busy_worker(worker), delay=0.15) + self.add_background_task(delay(self._readd_busy_worker, 0.15)(worker)) refresh_who_has = set() From e415d97123d1096ebdfda8039f3dbcb13c58bdef Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 27 May 2022 13:20:15 +0200 Subject: [PATCH 07/75] Replace add_callback in nanny --- distributed/nanny.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/nanny.py b/distributed/nanny.py index 66351f9d881..73d5668c3fe 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -506,7 +506,7 @@ def run(self, comm, *args, **kwargs): return run(self, comm, *args, **kwargs) def _on_exit_sync(self, exitcode): - self.loop.add_callback(self._on_exit, exitcode) + self.add_background_task(self._on_exit(exitcode)) @log_errors async def _on_exit(self, exitcode): @@ -603,7 +603,7 @@ async def _log_event(self, topic, msg): ) def log_event(self, topic, msg): - self.loop.add_callback(self._log_event, topic, msg) + self.add_background_task(self._log_event(topic, msg)) class WorkerProcess: From 3d688ad5acb941f77384302c3e8e99e9a3654c2f Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 27 May 2022 13:33:44 +0200 Subject: [PATCH 08/75] Add docstring and rename to create_background_task --- distributed/core.py | 10 ++++++---- distributed/nanny.py | 4 ++-- distributed/scheduler.py | 10 +++++----- distributed/utils.py | 8 +++++--- distributed/worker.py | 16 +++++++++------- 5 files changed, 27 insertions(+), 21 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index d0ec01bdee4..f749208691e 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -348,7 +348,7 @@ async def start_pcs(): if not pc.is_running(): pc.start() - self.add_background_task(start_pcs()) + self.create_background_task(start_pcs()) def stop(self): if not self.__stopped: @@ -365,7 +365,7 @@ def stop(self): async def _stop_listener(): listener.stop() - self.add_background_task(_stop_listener()) + self.create_background_task(_stop_listener()) @property def listener(self): @@ -656,7 +656,9 @@ async def handle_stream(self, comm, extra=None): break handler = self.stream_handlers[op] if is_coroutine_function(handler): - self.add_background_task(handler(**merge(extra, msg))) + self.create_background_task( + handler(**merge(extra, msg)) + ) await asyncio.sleep(0) else: handler(**merge(extra, msg)) @@ -677,7 +679,7 @@ async def handle_stream(self, comm, extra=None): await comm.close() assert comm.closed() - def add_background_task(self, coro: Coroutine) -> None: + def create_background_task(self, coro: Coroutine) -> None: task = asyncio.create_task(coro) self._ongoing_background_tasks.add(task) task.add_done_callback(self._ongoing_background_tasks.remove) diff --git a/distributed/nanny.py b/distributed/nanny.py index 73d5668c3fe..71165aa22ee 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -506,7 +506,7 @@ def run(self, comm, *args, **kwargs): return run(self, comm, *args, **kwargs) def _on_exit_sync(self, exitcode): - self.add_background_task(self._on_exit(exitcode)) + self.create_background_task(self._on_exit(exitcode)) @log_errors async def _on_exit(self, exitcode): @@ -603,7 +603,7 @@ async def _log_event(self, topic, msg): ) def log_event(self, topic, msg): - self.add_background_task(self._log_event(topic, msg)) + self.create_background_task(self._log_event(topic, msg)) class WorkerProcess: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 395a6d87329..6d23db1e0bf 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3321,7 +3321,7 @@ async def start_unsafe(self): for k, v in self.services.items(): logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port)) - self.add_background_task(self.reevaluate_occupancy()) + self.create_background_task(self.reevaluate_occupancy()) if self.scheduler_file: with open(self.scheduler_file, "w") as f: @@ -4253,7 +4253,7 @@ async def remove_worker_from_events(): dask.config.get("distributed.scheduler.events-cleanup-delay") ) - self.add_background_task(delay(remove_worker_from_events, cleanup_delay)()) + self.create_background_task(delay(remove_worker_from_events, cleanup_delay)()) logger.debug("Removed worker %s", ws) return "OK" @@ -4614,7 +4614,7 @@ async def remove_client_from_events(): dask.config.get("distributed.scheduler.events-cleanup-delay") ) - self.add_background_task(delay(remove_client_from_events, cleanup_delay)()) + self.create_background_task(delay(remove_client_from_events, cleanup_delay)()) def send_task_to_worker(self, worker, ts: TaskState, duration: float = -1): """Send a single computational task to a worker""" @@ -4881,7 +4881,7 @@ def worker_send(self, worker: str, msg: dict[str, Any]) -> None: try: stream_comms[worker].send(msg) except (CommClosedError, AttributeError): - self.add_background_task( + self.create_background_task( self.remove_worker( address=worker, stimulus_id=f"worker-send-comm-fail-{time()}" ) @@ -4930,7 +4930,7 @@ def send_all(self, client_msgs: dict, worker_msgs: dict): # worker already gone pass except (CommClosedError, AttributeError): - self.add_background_task( + self.create_background_task( self.remove_worker( address=worker, stimulus_id=f"send-all-comm-fail-{time()}", diff --git a/distributed/utils.py b/distributed/utils.py index 94404c206a0..8f577f6f8fc 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -32,7 +32,7 @@ from types import ModuleType from typing import TYPE_CHECKING from typing import Any as AnyType -from typing import ClassVar, TypeVar, overload +from typing import ClassVar, Coroutine, TypeVar, overload import click import tblib.pickling_support @@ -1734,9 +1734,11 @@ def is_python_shutting_down() -> bool: return _python_shutting_down -def delay(func, delay): +def delay(corofunc: Callable[..., Coroutine], delay: float): + """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" + async def wrapper(*args, **kwargs): await asyncio.sleep(delay) - return await func(*args, **kwargs) + return await corofunc(*args, **kwargs) return wrapper diff --git a/distributed/worker.py b/distributed/worker.py index b9356ea402a..feda84b99cc 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -891,7 +891,7 @@ def __init__( if self.lifetime: self.lifetime += (random.random() * 2 - 1) * lifetime_stagger - self.add_background_task(delay(self.close_gracefully, self.lifetime)()) + self.create_background_task(delay(self.close_gracefully, self.lifetime)()) # type: ignore self._async_instructions = set() @@ -957,7 +957,7 @@ def log_event(self, topic, msg): async def _send_batched_stream(batched_stream): batched_stream.send(full_msg) - self.add_background_task(_send_batched_stream(self.batched_stream)) + self.create_background_task(_send_batched_stream(self.batched_stream)) @property def executing_count(self) -> int: @@ -984,7 +984,7 @@ def status(self, value): prev_status = self.status ServerNode.status.__set__(self, value) stimulus_id = f"worker-status-change-{time()}" - self.add_background_task(self._send_worker_status_change(stimulus_id)) + self.create_background_task(self._send_worker_status_change(stimulus_id)) if prev_status == Status.paused and value == Status.running: self.handle_stimulus(UnpauseEvent(stimulus_id=stimulus_id)) @@ -1187,7 +1187,7 @@ async def _register_with_scheduler(self): self.batched_stream.start(comm) self.periodic_callbacks["keep-alive"].start() self.periodic_callbacks["heartbeat"].start() - self.add_background_task(self.handle_scheduler(comm)) + self.create_background_task(self.handle_scheduler(comm)) def _update_latency(self, latency): self.latency = latency * 0.05 + self.latency * 0.95 @@ -1645,7 +1645,7 @@ async def batched_send_connect(): bcomm.start(comm) - self.add_background_task(batched_send_connect()) + self.create_background_task(batched_send_connect()) self.stream_comms[address].send(msg) @@ -3342,7 +3342,9 @@ def done_event(): # Avoid hammering the worker. If there are multiple replicas # available, immediately try fetching from a different worker. self.busy_workers.add(worker) - self.add_background_task(delay(self._readd_busy_worker, 0.15)(worker)) + self.create_background_task( + delay(self._readd_busy_worker, 0.15)(worker) + ) refresh_who_has = set() @@ -3477,7 +3479,7 @@ def handle_worker_status_change(self, status: str, stimulus_id: str) -> None: "Invalid Worker.status transition: %s -> %s", self._status, new_status ) # Reiterate the current status to the scheduler to restore sync - self.add_background_task(self._send_worker_status_change(stimulus_id)) + self.create_background_task(self._send_worker_status_change(stimulus_id)) else: # Update status and send confirmation to the Scheduler (see status.setter) From d91a4cf5923df21eeb7271a138056d9021f5bc45 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 27 May 2022 17:40:52 +0200 Subject: [PATCH 09/75] Factor functionality out into TaskGroup and adjust interface to avoid unawaited coroutine warning --- distributed/core.py | 76 ++++++++++++++++++++++++++-------------- distributed/nanny.py | 4 +-- distributed/scheduler.py | 24 ++++++------- distributed/utils.py | 2 +- distributed/worker.py | 17 ++++----- 5 files changed, 71 insertions(+), 52 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index f749208691e..5586f426a6d 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -15,7 +15,7 @@ from contextlib import suppress from enum import Enum from functools import partial -from typing import Callable, ClassVar, Coroutine, TypedDict, TypeVar +from typing import Callable, ClassVar, TypedDict, TypeVar import tblib from tlz import merge @@ -37,6 +37,7 @@ from distributed.metrics import time from distributed.system_monitor import SystemMonitor from distributed.utils import ( + delayed, get_traceback, has_keyword, is_coroutine_function, @@ -108,6 +109,44 @@ def _expects_comm(func: Callable) -> bool: return False +class TaskGroup: + def __init__(self): + self.closed = False + self._ongoing_tasks = set() + + def call_soon(self, afunc, *args, **kwargs): + if not self.closed: + task = asyncio.create_task(afunc(*args, **kwargs)) + self._ongoing_tasks.add(task) + task.add_done_callback(self._ongoing_tasks.remove) + + def call_later(self, delay, afunc, *args, **kwargs): + self.call_soon(delayed(afunc, delay), *args, **kwargs) + + @property + def _cancellable_tasks(self): + return (t for t in self._ongoing_tasks if t is not asyncio.current_task()) + + async def cancel(self): + self.close() + return await asyncio.gather(*self._cancellable_tasks, return_exceptions=True) + + def close(self): + self.closed = True + + async def stop(self, timeout=1): + self.close() + try: + # Give the tasks a bit of time to finish gracefully + return await asyncio.wait_for( + asyncio.gather(*self._cancellable_tasks, return_exceptions=True), + timeout, + ) + except asyncio.TimeoutError: + # the timeout on gather should've cancelled all the tasks + return await self.cancel() + + class Server: """Dask Distributed Server @@ -186,7 +225,7 @@ def __init__( self.monitor = SystemMonitor() self.counters = None self.digests = None - self._ongoing_background_tasks = set() + self._background_tasks = TaskGroup() self._ongoing_comm_handlers = set() self._event_finished = asyncio.Event() @@ -348,7 +387,7 @@ async def start_pcs(): if not pc.is_running(): pc.start() - self.create_background_task(start_pcs()) + self._call_soon(start_pcs) def stop(self): if not self.__stopped: @@ -365,7 +404,7 @@ def stop(self): async def _stop_listener(): listener.stop() - self.create_background_task(_stop_listener()) + self._call_soon(_stop_listener) @property def listener(self): @@ -656,9 +695,7 @@ async def handle_stream(self, comm, extra=None): break handler = self.stream_handlers[op] if is_coroutine_function(handler): - self.create_background_task( - handler(**merge(extra, msg)) - ) + self._call_soon(handler, **merge(extra, msg)) await asyncio.sleep(0) else: handler(**merge(extra, msg)) @@ -679,10 +716,11 @@ async def handle_stream(self, comm, extra=None): await comm.close() assert comm.closed() - def create_background_task(self, coro: Coroutine) -> None: - task = asyncio.create_task(coro) - self._ongoing_background_tasks.add(task) - task.add_done_callback(self._ongoing_background_tasks.remove) + def _call_later(self, delay, afunc, *args, **kwargs): + self._background_tasks.call_later(delay, afunc, *args, **kwargs) + + def _call_soon(self, afunc, *args, **kwargs): + self._background_tasks.call_soon(afunc, *args, **kwargs) async def close(self, timeout=None): for pc in self.periodic_callbacks.values(): @@ -697,22 +735,8 @@ async def close(self, timeout=None): _stops.add(future) await asyncio.gather(*_stops) - def _ongoing_background_tasks(): - return ( - t - for t in self._ongoing_background_tasks - if t is not asyncio.current_task() - ) - # TODO: Deal with exceptions - try: - # Give the handlers a bit of time to finish gracefully - await asyncio.wait_for( - asyncio.gather(*_ongoing_background_tasks(), return_exceptions=True), 1 - ) - except asyncio.TimeoutError: - # the timeout on gather should've cancelled all the tasks - await asyncio.gather(*_ongoing_background_tasks(), return_exceptions=True) + await self._background_tasks.stop(timeout=1) def _ongoing_comm_handlers(): return ( diff --git a/distributed/nanny.py b/distributed/nanny.py index 71165aa22ee..2daba622cb8 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -506,7 +506,7 @@ def run(self, comm, *args, **kwargs): return run(self, comm, *args, **kwargs) def _on_exit_sync(self, exitcode): - self.create_background_task(self._on_exit(exitcode)) + self._call_soon(self._on_exit, exitcode) @log_errors async def _on_exit(self, exitcode): @@ -603,7 +603,7 @@ async def _log_event(self, topic, msg): ) def log_event(self, topic, msg): - self.create_background_task(self._log_event(topic, msg)) + self._call_soon(self._log_event, topic, msg) class WorkerProcess: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6d23db1e0bf..16b516573a2 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -88,7 +88,6 @@ from distributed.utils import ( All, TimeoutError, - delay, empty_context, get_fileno_limit, key_split, @@ -3321,7 +3320,7 @@ async def start_unsafe(self): for k, v in self.services.items(): logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port)) - self.create_background_task(self.reevaluate_occupancy()) + self._call_soon(self.reevaluate_occupancy) if self.scheduler_file: with open(self.scheduler_file, "w") as f: @@ -4253,7 +4252,7 @@ async def remove_worker_from_events(): dask.config.get("distributed.scheduler.events-cleanup-delay") ) - self.create_background_task(delay(remove_worker_from_events, cleanup_delay)()) + self._call_later(cleanup_delay, remove_worker_from_events) logger.debug("Removed worker %s", ws) return "OK" @@ -4614,7 +4613,7 @@ async def remove_client_from_events(): dask.config.get("distributed.scheduler.events-cleanup-delay") ) - self.create_background_task(delay(remove_client_from_events, cleanup_delay)()) + self._call_later(cleanup_delay, remove_client_from_events) def send_task_to_worker(self, worker, ts: TaskState, duration: float = -1): """Send a single computational task to a worker""" @@ -4881,10 +4880,10 @@ def worker_send(self, worker: str, msg: dict[str, Any]) -> None: try: stream_comms[worker].send(msg) except (CommClosedError, AttributeError): - self.create_background_task( - self.remove_worker( - address=worker, stimulus_id=f"worker-send-comm-fail-{time()}" - ) + self._call_soon( + self.remove_worker, + address=worker, + stimulus_id=f"worker-send-comm-fail-{time()}", ) def client_send(self, client, msg): @@ -4930,11 +4929,10 @@ def send_all(self, client_msgs: dict, worker_msgs: dict): # worker already gone pass except (CommClosedError, AttributeError): - self.create_background_task( - self.remove_worker( - address=worker, - stimulus_id=f"send-all-comm-fail-{time()}", - ) + self._call_soon( + self.remove_worker, + address=worker, + stimulus_id=f"send-all-comm-fail-{time()}", ) ############################ diff --git a/distributed/utils.py b/distributed/utils.py index 8f577f6f8fc..841278e67fa 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1734,7 +1734,7 @@ def is_python_shutting_down() -> bool: return _python_shutting_down -def delay(corofunc: Callable[..., Coroutine], delay: float): +def delayed(corofunc: Callable[..., Coroutine], delay: float): """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" async def wrapper(*args, **kwargs): diff --git a/distributed/worker.py b/distributed/worker.py index feda84b99cc..9628abd884c 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -82,7 +82,6 @@ LRU, TimeoutError, _maybe_complex, - delay, get_ip, has_arg, import_file, @@ -891,7 +890,7 @@ def __init__( if self.lifetime: self.lifetime += (random.random() * 2 - 1) * lifetime_stagger - self.create_background_task(delay(self.close_gracefully, self.lifetime)()) # type: ignore + self._call_later(self.lifetime, self.close_gracefully) self._async_instructions = set() @@ -957,7 +956,7 @@ def log_event(self, topic, msg): async def _send_batched_stream(batched_stream): batched_stream.send(full_msg) - self.create_background_task(_send_batched_stream(self.batched_stream)) + self._call_soon(_send_batched_stream, self.batched_stream) @property def executing_count(self) -> int: @@ -984,7 +983,7 @@ def status(self, value): prev_status = self.status ServerNode.status.__set__(self, value) stimulus_id = f"worker-status-change-{time()}" - self.create_background_task(self._send_worker_status_change(stimulus_id)) + self._call_soon(self._send_worker_status_change, stimulus_id) if prev_status == Status.paused and value == Status.running: self.handle_stimulus(UnpauseEvent(stimulus_id=stimulus_id)) @@ -1187,7 +1186,7 @@ async def _register_with_scheduler(self): self.batched_stream.start(comm) self.periodic_callbacks["keep-alive"].start() self.periodic_callbacks["heartbeat"].start() - self.create_background_task(self.handle_scheduler(comm)) + self._call_soon(self.handle_scheduler, comm) def _update_latency(self, latency): self.latency = latency * 0.05 + self.latency * 0.95 @@ -1645,7 +1644,7 @@ async def batched_send_connect(): bcomm.start(comm) - self.create_background_task(batched_send_connect()) + self._call_soon(batched_send_connect) self.stream_comms[address].send(msg) @@ -3342,9 +3341,7 @@ def done_event(): # Avoid hammering the worker. If there are multiple replicas # available, immediately try fetching from a different worker. self.busy_workers.add(worker) - self.create_background_task( - delay(self._readd_busy_worker, 0.15)(worker) - ) + self._call_later(0.15, self._readd_busy_worker, worker) refresh_who_has = set() @@ -3479,7 +3476,7 @@ def handle_worker_status_change(self, status: str, stimulus_id: str) -> None: "Invalid Worker.status transition: %s -> %s", self._status, new_status ) # Reiterate the current status to the scheduler to restore sync - self.create_background_task(self._send_worker_status_change(stimulus_id)) + self._call_soon(self._send_worker_status_change, stimulus_id) else: # Update status and send confirmation to the Scheduler (see status.setter) From 8312041c23d9b16c636468e9a1289a67e811037e Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 27 May 2022 18:50:29 +0200 Subject: [PATCH 10/75] Rename --- distributed/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 5586f426a6d..4e4e6e0f997 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -225,7 +225,7 @@ def __init__( self.monitor = SystemMonitor() self.counters = None self.digests = None - self._background_tasks = TaskGroup() + self._ongoing_background_tasks = TaskGroup() self._ongoing_comm_handlers = set() self._event_finished = asyncio.Event() @@ -717,10 +717,10 @@ async def handle_stream(self, comm, extra=None): assert comm.closed() def _call_later(self, delay, afunc, *args, **kwargs): - self._background_tasks.call_later(delay, afunc, *args, **kwargs) + self._ongoing_background_tasks.call_later(delay, afunc, *args, **kwargs) def _call_soon(self, afunc, *args, **kwargs): - self._background_tasks.call_soon(afunc, *args, **kwargs) + self._ongoing_background_tasks.call_soon(afunc, *args, **kwargs) async def close(self, timeout=None): for pc in self.periodic_callbacks.values(): @@ -736,7 +736,7 @@ async def close(self, timeout=None): await asyncio.gather(*_stops) # TODO: Deal with exceptions - await self._background_tasks.stop(timeout=1) + await self._ongoing_background_tasks.stop(timeout=1) def _ongoing_comm_handlers(): return ( From 488b1119ee0eb52cb3274d72a83888d43a265d0c Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 27 May 2022 19:09:16 +0200 Subject: [PATCH 11/75] Revert changes to log_event --- distributed/tests/test_client.py | 2 +- distributed/worker.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index fe2fed711a5..363b40434ee 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5151,7 +5151,7 @@ def long_running(lock, entered): assert s.total_occupancy == 0 assert ws.occupancy == 0 - s.reevaluate_occupancy(0) + s._call_soon(s.reevaluate_occupancy, 0) assert s.workers[a.address].occupancy == 0 await l.release() diff --git a/distributed/worker.py b/distributed/worker.py index 9628abd884c..057bedd0d4e 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -952,11 +952,7 @@ def log_event(self, topic, msg): if self.thread_id == threading.get_ident(): self.batched_stream.send(full_msg) else: - - async def _send_batched_stream(batched_stream): - batched_stream.send(full_msg) - - self._call_soon(_send_batched_stream, self.batched_stream) + self.loop.add_callback(self.batched_stream.send, full_msg) @property def executing_count(self) -> int: From 9155235e232d9bfe34691e3b3e999b0c5126f6a0 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 30 May 2022 11:45:18 +0200 Subject: [PATCH 12/75] Revert changes to lifetime callback --- distributed/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index 057bedd0d4e..7e0cb160440 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -890,7 +890,7 @@ def __init__( if self.lifetime: self.lifetime += (random.random() * 2 - 1) * lifetime_stagger - self._call_later(self.lifetime, self.close_gracefully) + self.io_loop.call_later(self.lifetime, self.close_gracefully) self._async_instructions = set() From 62c9383b3f25a77e236cd0639d8c1a6a0a5bd1af Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 30 May 2022 11:56:05 +0200 Subject: [PATCH 13/75] Fix test --- distributed/tests/test_nanny.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 6e52b38f9a3..79743a6bcd5 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -578,7 +578,7 @@ async def test_scheduler_crash_doesnt_restart(s, a): bcomm.abort() await s.close() - while a.status != Status.closing_gracefully: + while a.status not in {Status.closing_gracefully, Status.closed}: await asyncio.sleep(0.01) await a.finished() From 96c634606c6a1aaa1d269ca75945b2513da16cff Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 30 May 2022 12:15:29 +0200 Subject: [PATCH 14/75] Ignore cancelled error when awaiting finished() --- distributed/worker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index 7e0cb160440..627a725b669 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1453,7 +1453,10 @@ async def close( # nanny+worker, the nanny must be notified first. ==> Remove kwarg # nanny, see also Scheduler.retire_workers if self.status in (Status.closed, Status.closing, Status.failed): - await self.finished() + try: + await self.finished() + except asyncio.CancelledError: + pass return if self.status == Status.init: From 0248587ef4ba384ef2d401cbc016eae3a50b6cee Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 30 May 2022 15:14:15 +0200 Subject: [PATCH 15/75] Fix test --- distributed/deploy/tests/test_local.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 85c1f21b30e..8e75c84b106 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -996,6 +996,7 @@ async def test_no_dangling_asyncio_tasks(): async with LocalCluster(asynchronous=True, processes=False, dashboard_address=":0"): await asyncio.sleep(0.01) + await asyncio.sleep(0.01) tasks = asyncio.all_tasks() assert tasks == start From 3a5695a404171f00017d44ef3b85e0598acd625f Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 30 May 2022 15:27:01 +0200 Subject: [PATCH 16/75] Fix test for adaptive scaling by adjusting wait condition --- distributed/deploy/tests/test_adaptive.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 27b2bd3f6f4..893c5a9a14c 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -225,7 +225,11 @@ async def test_adapt_quickly(): await cluster - while len(cluster.scheduler.workers) > 1 or len(cluster.worker_spec) > 1: + while ( + len(cluster.scheduler.workers) > 1 + or len(cluster.worker_spec) > 1 + or len(cluster.workers) > 1 + ): await asyncio.sleep(0.01) # Don't scale up for large sequential computations From aaa3cd4a9d1e523b2e69b9b0798ac5e96daf6511 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 30 May 2022 17:00:27 +0200 Subject: [PATCH 17/75] Enable tmate for remote debugging of flaking tests --- .github/workflows/tests.yaml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 41d97885498..27e95174239 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -124,6 +124,7 @@ jobs: --leaks=fds,processes,threads \ --junitxml reports/pytest.xml -o junit_suite_name=$TEST_ID \ --cov=distributed --cov-report=xml \ + -k test_quiet_close_process \ | tee reports/stdout - name: Generate junit XML report in case of pytest-timeout @@ -138,9 +139,9 @@ jobs: python continuous_integration/scripts/parse_stdout.py < reports/stdout > reports/pytest.xml fi - # - name: Debug with tmate on failure - # if: ${{ failure() }} - # uses: mxschmitt/action-tmate@v3 + - name: Debug with tmate on failure + if: ${{ failure() }} + uses: mxschmitt/action-tmate@v3 - name: Coverage uses: codecov/codecov-action@v1 From b27178392b69186b473b7bb6a65b99b5aaa74195 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 30 May 2022 17:50:16 +0200 Subject: [PATCH 18/75] Re-raise exception unless cancelled --- distributed/core.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 4e4e6e0f997..f876f6df7e8 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -134,14 +134,19 @@ async def cancel(self): def close(self): self.closed = True - async def stop(self, timeout=1): + async def stop(self, timeout=0): self.close() try: # Give the tasks a bit of time to finish gracefully - return await asyncio.wait_for( + futures = await asyncio.wait_for( asyncio.gather(*self._cancellable_tasks, return_exceptions=True), timeout, ) + for future in futures: + try: + future.exception() + except asyncio.CancelledError: + pass except asyncio.TimeoutError: # the timeout on gather should've cancelled all the tasks return await self.cancel() @@ -748,9 +753,14 @@ def _ongoing_comm_handlers(): # TODO: Deal with exceptions try: # Give the handlers a bit of time to finish gracefully - await asyncio.wait_for( + futures = await asyncio.wait_for( asyncio.gather(*_ongoing_comm_handlers(), return_exceptions=True), 1 ) + for future in futures: + try: + future.exception() + except asyncio.CancelledError: + pass except asyncio.TimeoutError: # the timeout on gather should've cancelled all the tasks await asyncio.gather(*_ongoing_comm_handlers(), return_exceptions=True) From 54599649235752a8fff4dd330b3890422a7e8904 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 30 May 2022 18:17:56 +0200 Subject: [PATCH 19/75] Re-raise exception unless cancelled --- distributed/core.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index f876f6df7e8..86fd127d6fa 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -129,7 +129,12 @@ def _cancellable_tasks(self): async def cancel(self): self.close() - return await asyncio.gather(*self._cancellable_tasks, return_exceptions=True) + futures = await asyncio.gather(*self._cancellable_tasks, return_exceptions=True) + for future in futures: + try: + future.exception() + except asyncio.CancelledError: + pass def close(self): self.closed = True @@ -763,7 +768,15 @@ def _ongoing_comm_handlers(): pass except asyncio.TimeoutError: # the timeout on gather should've cancelled all the tasks - await asyncio.gather(*_ongoing_comm_handlers(), return_exceptions=True) + futures = await asyncio.gather( + *_ongoing_comm_handlers(), return_exceptions=True + ) + + for future in futures: + try: + future.exception() + except asyncio.CancelledError: + pass await self.rpc.close() await asyncio.gather(*[comm.close() for comm in list(self._comms)]) From 773188ad968d450628512386e98a6bd5998f2e4d Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 30 May 2022 21:13:43 +0200 Subject: [PATCH 20/75] Fix '_GatheringFuture exception was never retrieved' --- distributed/core.py | 47 +++++++++++---------------------------------- 1 file changed, 11 insertions(+), 36 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 86fd127d6fa..465b02bc149 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -127,34 +127,23 @@ def call_later(self, delay, afunc, *args, **kwargs): def _cancellable_tasks(self): return (t for t in self._ongoing_tasks if t is not asyncio.current_task()) - async def cancel(self): - self.close() - futures = await asyncio.gather(*self._cancellable_tasks, return_exceptions=True) - for future in futures: - try: - future.exception() - except asyncio.CancelledError: - pass - def close(self): self.closed = True - async def stop(self, timeout=0): + async def stop(self, timeout=1): self.close() try: # Give the tasks a bit of time to finish gracefully - futures = await asyncio.wait_for( - asyncio.gather(*self._cancellable_tasks, return_exceptions=True), + gather = asyncio.gather( + *self._cancellable_tasks, + return_exceptions=True, + ) + await asyncio.wait_for( + gather, timeout, ) - for future in futures: - try: - future.exception() - except asyncio.CancelledError: - pass except asyncio.TimeoutError: - # the timeout on gather should've cancelled all the tasks - return await self.cancel() + return await gather class Server: @@ -758,25 +747,11 @@ def _ongoing_comm_handlers(): # TODO: Deal with exceptions try: # Give the handlers a bit of time to finish gracefully - futures = await asyncio.wait_for( - asyncio.gather(*_ongoing_comm_handlers(), return_exceptions=True), 1 - ) - for future in futures: - try: - future.exception() - except asyncio.CancelledError: - pass + gather = asyncio.gather(*_ongoing_comm_handlers(), return_exceptions=True) + await asyncio.wait_for(gather, 1) except asyncio.TimeoutError: # the timeout on gather should've cancelled all the tasks - futures = await asyncio.gather( - *_ongoing_comm_handlers(), return_exceptions=True - ) - - for future in futures: - try: - future.exception() - except asyncio.CancelledError: - pass + await gather await self.rpc.close() await asyncio.gather(*[comm.close() for comm in list(self._comms)]) From 734e893e9b64df71d5e37e861da33fea31c7fdf9 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 30 May 2022 21:49:51 +0200 Subject: [PATCH 21/75] Catch cancellederror on cancelled gather --- distributed/core.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 465b02bc149..fcbaf9d8708 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -143,7 +143,10 @@ async def stop(self, timeout=1): timeout, ) except asyncio.TimeoutError: - return await gather + try: + await gather + except asyncio.CancelledError: + pass class Server: @@ -751,7 +754,10 @@ def _ongoing_comm_handlers(): await asyncio.wait_for(gather, 1) except asyncio.TimeoutError: # the timeout on gather should've cancelled all the tasks - await gather + try: + await gather + except asyncio.CancelledError: + pass await self.rpc.close() await asyncio.gather(*[comm.close() for comm in list(self._comms)]) From bc45f99c010096a08e66501a3180039ba6b4439a Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 30 May 2022 21:56:22 +0200 Subject: [PATCH 22/75] Revert changes to tests.yaml --- .github/workflows/tests.yaml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 27e95174239..41d97885498 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -124,7 +124,6 @@ jobs: --leaks=fds,processes,threads \ --junitxml reports/pytest.xml -o junit_suite_name=$TEST_ID \ --cov=distributed --cov-report=xml \ - -k test_quiet_close_process \ | tee reports/stdout - name: Generate junit XML report in case of pytest-timeout @@ -139,9 +138,9 @@ jobs: python continuous_integration/scripts/parse_stdout.py < reports/stdout > reports/pytest.xml fi - - name: Debug with tmate on failure - if: ${{ failure() }} - uses: mxschmitt/action-tmate@v3 + # - name: Debug with tmate on failure + # if: ${{ failure() }} + # uses: mxschmitt/action-tmate@v3 - name: Coverage uses: codecov/codecov-action@v1 From a932b79e245676df71407dc93e0269286707294b Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 31 May 2022 14:19:52 +0200 Subject: [PATCH 23/75] Replace _ongoing_comm_handlers with TaskGroup --- distributed/core.py | 65 ++++++++++++++++++--------------------------- 1 file changed, 26 insertions(+), 39 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index fcbaf9d8708..df750eb97d2 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -114,14 +114,17 @@ def __init__(self): self.closed = False self._ongoing_tasks = set() - def call_soon(self, afunc, *args, **kwargs): - if not self.closed: - task = asyncio.create_task(afunc(*args, **kwargs)) - self._ongoing_tasks.add(task) - task.add_done_callback(self._ongoing_tasks.remove) + def call_soon(self, afunc, *args, **kwargs) -> asyncio.Task | None: + if self.closed: + return None + + task = asyncio.create_task(afunc(*args, **kwargs)) + self._ongoing_tasks.add(task) + task.add_done_callback(self._ongoing_tasks.remove) + return task - def call_later(self, delay, afunc, *args, **kwargs): - self.call_soon(delayed(afunc, delay), *args, **kwargs) + def call_later(self, delay, afunc, *args, **kwargs) -> asyncio.Task | None: + return self.call_soon(delayed(afunc, delay), *args, **kwargs) @property def _cancellable_tasks(self): @@ -228,7 +231,7 @@ def __init__( self.counters = None self.digests = None self._ongoing_background_tasks = TaskGroup() - self._ongoing_comm_handlers = set() + self._ongoing_comm_handlers = TaskGroup() self._event_finished = asyncio.Event() self.listeners = [] @@ -620,21 +623,21 @@ async def handle_comm(self, comm): logger.debug("Calling into handler %s", handler.__name__) try: - if _expects_comm(handler): - result = handler(comm, **msg) - else: - result = handler(**msg) - if inspect.iscoroutine(result): - result = asyncio.create_task( - result, name=f"handle-comm-{address}-{op}" - ) - self._ongoing_comm_handlers.add(result) - result.add_done_callback(self._ongoing_comm_handlers.remove) + if inspect.iscoroutinefunction(handler): + if _expects_comm(handler): + result = self._ongoing_comm_handlers.call_soon( + handler, comm, **msg + ) + else: + result = self._ongoing_comm_handlers.call_soon( + handler, **msg + ) result = await result - elif inspect.isawaitable(result): - raise RuntimeError( - f"Comm handler returned unknown awaitable. Expected coroutine, instead got {type(result)}" - ) + else: + if _expects_comm(handler): + result = handler(comm, **msg) + else: + result = handler(**msg) except CommClosedError: if self.status == Status.running: logger.info("Lost connection to %r", address, exc_info=True) @@ -740,24 +743,8 @@ async def close(self, timeout=None): # TODO: Deal with exceptions await self._ongoing_background_tasks.stop(timeout=1) - def _ongoing_comm_handlers(): - return ( - t - for t in self._ongoing_comm_handlers - if t is not asyncio.current_task() - ) - # TODO: Deal with exceptions - try: - # Give the handlers a bit of time to finish gracefully - gather = asyncio.gather(*_ongoing_comm_handlers(), return_exceptions=True) - await asyncio.wait_for(gather, 1) - except asyncio.TimeoutError: - # the timeout on gather should've cancelled all the tasks - try: - await gather - except asyncio.CancelledError: - pass + await self._ongoing_comm_handlers.stop(timeout=1) await self.rpc.close() await asyncio.gather(*[comm.close() for comm in list(self._comms)]) From bb9d71b5f5208355cb6228ab824d1dee277a389f Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 31 May 2022 14:44:56 +0200 Subject: [PATCH 24/75] Add docstrings --- distributed/core.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/distributed/core.py b/distributed/core.py index df750eb97d2..62c0e8c7992 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -110,11 +110,20 @@ def _expects_comm(func: Callable) -> bool: class TaskGroup: + """Collection tracking all currently running tasks within a group""" + + #: If True, the group is closed and does not allow adding new tasks. + closed: bool + def __init__(self): self.closed = False self._ongoing_tasks = set() def call_soon(self, afunc, *args, **kwargs) -> asyncio.Task | None: + """Schedule the coroutine function `afunc` to be executed with `args` + arguments and `kwargs` keyword arguments as an asyncio.Task. + Returns the Task object. + """ if self.closed: return None @@ -124,6 +133,10 @@ def call_soon(self, afunc, *args, **kwargs) -> asyncio.Task | None: return task def call_later(self, delay, afunc, *args, **kwargs) -> asyncio.Task | None: + """Schedule the coroutine function `afunc` to be executed after `delay` seconds with `args` + arguments and `kwargs` keyword arguments as an asyncio.Task. + Returns the Task object. + """ return self.call_soon(delayed(afunc, delay), *args, **kwargs) @property @@ -131,12 +144,17 @@ def _cancellable_tasks(self): return (t for t in self._ongoing_tasks if t is not asyncio.current_task()) def close(self): + """Closes the task group so that no new tasks can be scheduled. + Existing tasks continue to run. + """ self.closed = True async def stop(self, timeout=1): + """Closes the task group and waits `timeout` seconds for all tasks to gracefully finish. + After the timeout, all remaining tasks are cancelled. + """ self.close() try: - # Give the tasks a bit of time to finish gracefully gather = asyncio.gather( *self._cancellable_tasks, return_exceptions=True, From eda6b7317d1a45ba5ab0b9b0117faae71d4dbae8 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 31 May 2022 15:22:34 +0200 Subject: [PATCH 25/75] Improved docs --- distributed/core.py | 61 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 62c0e8c7992..dc064151a1d 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -15,7 +15,7 @@ from contextlib import suppress from enum import Enum from functools import partial -from typing import Callable, ClassVar, TypedDict, TypeVar +from typing import Callable, ClassVar, Coroutine, TypedDict, TypeVar import tblib from tlz import merge @@ -110,7 +110,7 @@ def _expects_comm(func: Callable) -> bool: class TaskGroup: - """Collection tracking all currently running tasks within a group""" + """Collection tracking all currently running asynchronous tasks within a group""" #: If True, the group is closed and does not allow adding new tasks. closed: bool @@ -119,10 +119,27 @@ def __init__(self): self.closed = False self._ongoing_tasks = set() - def call_soon(self, afunc, *args, **kwargs) -> asyncio.Task | None: - """Schedule the coroutine function `afunc` to be executed with `args` - arguments and `kwargs` keyword arguments as an asyncio.Task. - Returns the Task object. + def call_soon( + self, afunc: Callable[..., Coroutine], *args, **kwargs + ) -> asyncio.Task | None: + """Schedule a coroutine function to be executed as an `asyncio.Task`. + + The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments + as an `asyncio.Task`. + + Parameters + ---------- + afunc + Coroutine function to schedule. + *args + Arguments to be passed to `afunc`. + **kwargs + Keyword arguments to be passed to `afunc` + + Returns + ------- + asyncio.Task | None + The scheduled Task object, or None if the group is closed. """ if self.closed: return None @@ -132,10 +149,29 @@ def call_soon(self, afunc, *args, **kwargs) -> asyncio.Task | None: task.add_done_callback(self._ongoing_tasks.remove) return task - def call_later(self, delay, afunc, *args, **kwargs) -> asyncio.Task | None: - """Schedule the coroutine function `afunc` to be executed after `delay` seconds with `args` - arguments and `kwargs` keyword arguments as an asyncio.Task. - Returns the Task object. + def call_later( + self, delay: int, afunc: Callable[..., Coroutine], *args, **kwargs + ) -> asyncio.Task | None: + """Schedule a coroutine function to be executed after `delay` seconds as an `asyncio.Task`. + + The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments + as an `asyncio.Task` that is executed after `delay` seconds. + + Parameters + ---------- + delay + Delay in seconds. + afunc + Coroutine function to schedule. + *args + Arguments to be passed to `afunc`. + **kwargs + Keyword arguments to be passed to `afunc` + + Returns + ------- + asyncio.Task | None + The scheduled Task object, or None if the group is closed. """ return self.call_soon(delayed(afunc, delay), *args, **kwargs) @@ -145,12 +181,15 @@ def _cancellable_tasks(self): def close(self): """Closes the task group so that no new tasks can be scheduled. + Existing tasks continue to run. """ self.closed = True async def stop(self, timeout=1): - """Closes the task group and waits `timeout` seconds for all tasks to gracefully finish. + """Close the group and stop all currently running tasks. + + Closes the task group and waits `timeout` seconds for all tasks to gracefully finish. After the timeout, all remaining tasks are cancelled. """ self.close() From 79970d79bd51c9bbd6ef65d134d2b4fa3b5470ea Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 31 May 2022 15:56:34 +0200 Subject: [PATCH 26/75] Improve typing --- distributed/core.py | 15 ++++++++------- distributed/utils.py | 7 ++++++- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index dc064151a1d..798a5f206c3 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -15,7 +15,7 @@ from contextlib import suppress from enum import Enum from functools import partial -from typing import Callable, ClassVar, Coroutine, TypedDict, TypeVar +from typing import Callable, ClassVar, TypedDict, TypeVar import tblib from tlz import merge @@ -37,6 +37,7 @@ from distributed.metrics import time from distributed.system_monitor import SystemMonitor from distributed.utils import ( + CoroutineFunctionType, delayed, get_traceback, has_keyword, @@ -115,12 +116,12 @@ class TaskGroup: #: If True, the group is closed and does not allow adding new tasks. closed: bool - def __init__(self): + def __init__(self) -> None: self.closed = False - self._ongoing_tasks = set() + self._ongoing_tasks: set[asyncio.Task] = set() def call_soon( - self, afunc: Callable[..., Coroutine], *args, **kwargs + self, afunc: CoroutineFunctionType, *args, **kwargs ) -> asyncio.Task | None: """Schedule a coroutine function to be executed as an `asyncio.Task`. @@ -150,7 +151,7 @@ def call_soon( return task def call_later( - self, delay: int, afunc: Callable[..., Coroutine], *args, **kwargs + self, delay: int, afunc: CoroutineFunctionType, *args, **kwargs ) -> asyncio.Task | None: """Schedule a coroutine function to be executed after `delay` seconds as an `asyncio.Task`. @@ -179,14 +180,14 @@ def call_later( def _cancellable_tasks(self): return (t for t in self._ongoing_tasks if t is not asyncio.current_task()) - def close(self): + def close(self) -> None: """Closes the task group so that no new tasks can be scheduled. Existing tasks continue to run. """ self.closed = True - async def stop(self, timeout=1): + async def stop(self, timeout=1) -> None: """Close the group and stop all currently running tasks. Closes the task group and waits `timeout` seconds for all tasks to gracefully finish. diff --git a/distributed/utils.py b/distributed/utils.py index 841278e67fa..fae309a303c 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1734,7 +1734,12 @@ def is_python_shutting_down() -> bool: return _python_shutting_down -def delayed(corofunc: Callable[..., Coroutine], delay: float): +CoroutineFunctionType = TypeVar( + "CoroutineFunctionType", bound=Callable[..., Coroutine[AnyType, AnyType, AnyType]] +) + + +def delayed(corofunc: CoroutineFunctionType, delay: float): """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" async def wrapper(*args, **kwargs): From bdb1b70a9930b838b3987eff20d25942d56e4bcc Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 31 May 2022 16:19:37 +0200 Subject: [PATCH 27/75] Fix typing of delayed --- distributed/utils.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index fae309a303c..d2678ebf856 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -32,7 +32,7 @@ from types import ModuleType from typing import TYPE_CHECKING from typing import Any as AnyType -from typing import ClassVar, Coroutine, TypeVar, overload +from typing import ClassVar, Coroutine, TypeVar, cast, overload import click import tblib.pickling_support @@ -1734,16 +1734,14 @@ def is_python_shutting_down() -> bool: return _python_shutting_down -CoroutineFunctionType = TypeVar( - "CoroutineFunctionType", bound=Callable[..., Coroutine[AnyType, AnyType, AnyType]] -) +CoroutineFunctionType = TypeVar("CoroutineFunctionType", bound=Callable[..., Coroutine]) -def delayed(corofunc: CoroutineFunctionType, delay: float): +def delayed(corofunc: CoroutineFunctionType, delay: float) -> CoroutineFunctionType: """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" async def wrapper(*args, **kwargs): await asyncio.sleep(delay) return await corofunc(*args, **kwargs) - return wrapper + return cast(CoroutineFunctionType, wrapper) From d3f04a5a5990a4b15f08e2d0b36a678db2e47963 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 31 May 2022 17:07:11 +0200 Subject: [PATCH 28/75] Adjust stop() --- distributed/core.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 798a5f206c3..801f1b96e19 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -194,20 +194,22 @@ async def stop(self, timeout=1) -> None: After the timeout, all remaining tasks are cancelled. """ self.close() - try: - gather = asyncio.gather( + + # Wrap to avoid Python3.8 issue, + # see https://github.com/dask/distributed/pull/6478#discussion_r885696827 + async def _gather(): + return await asyncio.gather( *self._cancellable_tasks, return_exceptions=True, ) + + try: await asyncio.wait_for( - gather, + _gather(), timeout, ) except asyncio.TimeoutError: - try: - await gather - except asyncio.CancelledError: - pass + pass class Server: From 0532a21aea3e07f0db37308601c988e6a7d6eb07 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 31 May 2022 17:13:18 +0200 Subject: [PATCH 29/75] Adjust stop() --- distributed/core.py | 22 +++++++++------------- distributed/utils.py | 7 ++++--- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 801f1b96e19..66790e80a50 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -15,7 +15,7 @@ from contextlib import suppress from enum import Enum from functools import partial -from typing import Callable, ClassVar, TypedDict, TypeVar +from typing import Callable, ClassVar, Coroutine, TypedDict, TypeVar import tblib from tlz import merge @@ -37,7 +37,6 @@ from distributed.metrics import time from distributed.system_monitor import SystemMonitor from distributed.utils import ( - CoroutineFunctionType, delayed, get_traceback, has_keyword, @@ -121,7 +120,7 @@ def __init__(self) -> None: self._ongoing_tasks: set[asyncio.Task] = set() def call_soon( - self, afunc: CoroutineFunctionType, *args, **kwargs + self, afunc: Callable[..., Coroutine], *args, **kwargs ) -> asyncio.Task | None: """Schedule a coroutine function to be executed as an `asyncio.Task`. @@ -151,7 +150,7 @@ def call_soon( return task def call_later( - self, delay: int, afunc: CoroutineFunctionType, *args, **kwargs + self, delay: int, afunc: Callable[..., Coroutine], *args, **kwargs ) -> asyncio.Task | None: """Schedule a coroutine function to be executed after `delay` seconds as an `asyncio.Task`. @@ -195,17 +194,14 @@ async def stop(self, timeout=1) -> None: """ self.close() - # Wrap to avoid Python3.8 issue, - # see https://github.com/dask/distributed/pull/6478#discussion_r885696827 - async def _gather(): - return await asyncio.gather( - *self._cancellable_tasks, - return_exceptions=True, - ) - + # Wrap gather in task to avoid Python3.8 issue, + # see https://github.com/dask/distributed/pull/6478#discussion_r885757056 + gather_task = asyncio.create_task( + asyncio.gather(*self._cancellable_tasks, return_exceptions=True) + ) try: await asyncio.wait_for( - _gather(), + gather_task, timeout, ) except asyncio.TimeoutError: diff --git a/distributed/utils.py b/distributed/utils.py index d2678ebf856..b8740cee5a6 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -75,6 +75,10 @@ P = ParamSpec("P") T = TypeVar("T") + CoroutineFunctionType = TypeVar( + "CoroutineFunctionType", bound=Callable[..., Coroutine] + ) + no_default = "__no_default__" @@ -1734,9 +1738,6 @@ def is_python_shutting_down() -> bool: return _python_shutting_down -CoroutineFunctionType = TypeVar("CoroutineFunctionType", bound=Callable[..., Coroutine]) - - def delayed(corofunc: CoroutineFunctionType, delay: float) -> CoroutineFunctionType: """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" From ac53b942efcadc5c33a7bd01d58a260a9f3642bd Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 31 May 2022 17:42:23 +0200 Subject: [PATCH 30/75] Make call_soon and call_later public --- distributed/core.py | 52 +++++++++++++++++++++++++++++++++++----- distributed/nanny.py | 4 ++-- distributed/scheduler.py | 10 ++++---- distributed/worker.py | 10 ++++---- 4 files changed, 58 insertions(+), 18 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 66790e80a50..de2653bb543 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -150,7 +150,7 @@ def call_soon( return task def call_later( - self, delay: int, afunc: Callable[..., Coroutine], *args, **kwargs + self, delay: float, afunc: Callable[..., Coroutine], *args, **kwargs ) -> asyncio.Task | None: """Schedule a coroutine function to be executed after `delay` seconds as an `asyncio.Task`. @@ -448,7 +448,7 @@ async def start_pcs(): if not pc.is_running(): pc.start() - self._call_soon(start_pcs) + self.call_soon(start_pcs) def stop(self): if not self.__stopped: @@ -465,7 +465,7 @@ def stop(self): async def _stop_listener(): listener.stop() - self._call_soon(_stop_listener) + self.call_soon(_stop_listener) @property def listener(self): @@ -756,7 +756,7 @@ async def handle_stream(self, comm, extra=None): break handler = self.stream_handlers[op] if is_coroutine_function(handler): - self._call_soon(handler, **merge(extra, msg)) + self.call_soon(handler, **merge(extra, msg)) await asyncio.sleep(0) else: handler(**merge(extra, msg)) @@ -777,10 +777,50 @@ async def handle_stream(self, comm, extra=None): await comm.close() assert comm.closed() - def _call_later(self, delay, afunc, *args, **kwargs): + def call_later( + self, delay: float, afunc: Callable[..., Coroutine], *args, **kwargs + ) -> None: + """Schedule a coroutine function to be asynchronously executed after `delay` seconds. + + The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments + to be asynchronously executed after `delay` seconds. + + Parameters + ---------- + delay + Delay in seconds. + afunc + Coroutine function to schedule. + *args + Arguments to be passed to `afunc`. + **kwargs + Keyword arguments to be passed to `afunc` + + Returns + ------- + None + """ + self._ongoing_background_tasks.call_later(delay, afunc, *args, **kwargs) - def _call_soon(self, afunc, *args, **kwargs): + def call_soon(self, afunc: Callable[..., Coroutine], *args, **kwargs) -> None: + """Schedule a coroutine function to be executed asynchronously. + + The coroutine function `afunc` is scheduled asynchronously with `args` arguments and `kwargs` keyword arguments. + + Parameters + ---------- + afunc + Coroutine function to schedule. + *args + Arguments to be passed to `afunc`. + **kwargs + Keyword arguments to be passed to `afunc` + + Returns + ------- + None + """ self._ongoing_background_tasks.call_soon(afunc, *args, **kwargs) async def close(self, timeout=None): diff --git a/distributed/nanny.py b/distributed/nanny.py index 2daba622cb8..08d5cee6736 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -506,7 +506,7 @@ def run(self, comm, *args, **kwargs): return run(self, comm, *args, **kwargs) def _on_exit_sync(self, exitcode): - self._call_soon(self._on_exit, exitcode) + self.call_soon(self._on_exit, exitcode) @log_errors async def _on_exit(self, exitcode): @@ -603,7 +603,7 @@ async def _log_event(self, topic, msg): ) def log_event(self, topic, msg): - self._call_soon(self._log_event, topic, msg) + self.call_soon(self._log_event, topic, msg) class WorkerProcess: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 16b516573a2..a34d5c5c6de 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3320,7 +3320,7 @@ async def start_unsafe(self): for k, v in self.services.items(): logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port)) - self._call_soon(self.reevaluate_occupancy) + self.call_soon(self.reevaluate_occupancy) if self.scheduler_file: with open(self.scheduler_file, "w") as f: @@ -4252,7 +4252,7 @@ async def remove_worker_from_events(): dask.config.get("distributed.scheduler.events-cleanup-delay") ) - self._call_later(cleanup_delay, remove_worker_from_events) + self.call_later(cleanup_delay, remove_worker_from_events) logger.debug("Removed worker %s", ws) return "OK" @@ -4613,7 +4613,7 @@ async def remove_client_from_events(): dask.config.get("distributed.scheduler.events-cleanup-delay") ) - self._call_later(cleanup_delay, remove_client_from_events) + self.call_later(cleanup_delay, remove_client_from_events) def send_task_to_worker(self, worker, ts: TaskState, duration: float = -1): """Send a single computational task to a worker""" @@ -4880,7 +4880,7 @@ def worker_send(self, worker: str, msg: dict[str, Any]) -> None: try: stream_comms[worker].send(msg) except (CommClosedError, AttributeError): - self._call_soon( + self.call_soon( self.remove_worker, address=worker, stimulus_id=f"worker-send-comm-fail-{time()}", @@ -4929,7 +4929,7 @@ def send_all(self, client_msgs: dict, worker_msgs: dict): # worker already gone pass except (CommClosedError, AttributeError): - self._call_soon( + self.call_soon( self.remove_worker, address=worker, stimulus_id=f"send-all-comm-fail-{time()}", diff --git a/distributed/worker.py b/distributed/worker.py index 627a725b669..e907166d88e 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -979,7 +979,7 @@ def status(self, value): prev_status = self.status ServerNode.status.__set__(self, value) stimulus_id = f"worker-status-change-{time()}" - self._call_soon(self._send_worker_status_change, stimulus_id) + self.call_soon(self._send_worker_status_change, stimulus_id) if prev_status == Status.paused and value == Status.running: self.handle_stimulus(UnpauseEvent(stimulus_id=stimulus_id)) @@ -1182,7 +1182,7 @@ async def _register_with_scheduler(self): self.batched_stream.start(comm) self.periodic_callbacks["keep-alive"].start() self.periodic_callbacks["heartbeat"].start() - self._call_soon(self.handle_scheduler, comm) + self.call_soon(self.handle_scheduler, comm) def _update_latency(self, latency): self.latency = latency * 0.05 + self.latency * 0.95 @@ -1643,7 +1643,7 @@ async def batched_send_connect(): bcomm.start(comm) - self._call_soon(batched_send_connect) + self.call_soon(batched_send_connect) self.stream_comms[address].send(msg) @@ -3340,7 +3340,7 @@ def done_event(): # Avoid hammering the worker. If there are multiple replicas # available, immediately try fetching from a different worker. self.busy_workers.add(worker) - self._call_later(0.15, self._readd_busy_worker, worker) + self.call_later(0.15, self._readd_busy_worker, worker) refresh_who_has = set() @@ -3475,7 +3475,7 @@ def handle_worker_status_change(self, status: str, stimulus_id: str) -> None: "Invalid Worker.status transition: %s -> %s", self._status, new_status ) # Reiterate the current status to the scheduler to restore sync - self._call_soon(self._send_worker_status_change, stimulus_id) + self.call_soon(self._send_worker_status_change, stimulus_id) else: # Update status and send confirmation to the Scheduler (see status.setter) From 2f7d90c326ae63bfaa2c826f02f94fb675553d69 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 31 May 2022 17:44:03 +0200 Subject: [PATCH 31/75] Make TypeVar private --- distributed/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index b8740cee5a6..de7f76ee6c3 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -75,8 +75,8 @@ P = ParamSpec("P") T = TypeVar("T") - CoroutineFunctionType = TypeVar( - "CoroutineFunctionType", bound=Callable[..., Coroutine] + _CoroutineFunctionType = TypeVar( + "_CoroutineFunctionType", bound=Callable[..., Coroutine] ) @@ -1738,11 +1738,11 @@ def is_python_shutting_down() -> bool: return _python_shutting_down -def delayed(corofunc: CoroutineFunctionType, delay: float) -> CoroutineFunctionType: +def delayed(corofunc: _CoroutineFunctionType, delay: float) -> _CoroutineFunctionType: """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" async def wrapper(*args, **kwargs): await asyncio.sleep(delay) return await corofunc(*args, **kwargs) - return cast(CoroutineFunctionType, wrapper) + return cast(_CoroutineFunctionType, wrapper) From 4f739b22abd9638e1b01ffa74a27da5a0e0f06b6 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 31 May 2022 17:45:10 +0200 Subject: [PATCH 32/75] Rename TaskGroup to AsyncTaskGroup --- distributed/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index de2653bb543..9de45b5ba82 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -109,7 +109,7 @@ def _expects_comm(func: Callable) -> bool: return False -class TaskGroup: +class AsyncTaskGroup: """Collection tracking all currently running asynchronous tasks within a group""" #: If True, the group is closed and does not allow adding new tasks. @@ -286,8 +286,8 @@ def __init__( self.monitor = SystemMonitor() self.counters = None self.digests = None - self._ongoing_background_tasks = TaskGroup() - self._ongoing_comm_handlers = TaskGroup() + self._ongoing_background_tasks = AsyncTaskGroup() + self._ongoing_comm_handlers = AsyncTaskGroup() self._event_finished = asyncio.Event() self.listeners = [] From f6348eef2ec05389894d73d3800ea8d6f8da84d1 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 31 May 2022 17:53:20 +0200 Subject: [PATCH 33/75] Minor --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a34d5c5c6de..02011cecee8 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -7001,7 +7001,7 @@ def check_idle(self): "Scheduler closing after being idle for %s", format_time(self.idle_timeout), ) - asyncio.create_task(self.close()) + self.call_soon(self.close) def adaptive_target(self, target_duration=None): """Desired number of workers based on the current workload From b66fae612ec1f6c5a23558d210d8cad67ad8514f Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 31 May 2022 19:03:54 +0200 Subject: [PATCH 34/75] Fix stop() without running tasks --- distributed/core.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 9de45b5ba82..b16f64cf7cd 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -143,8 +143,8 @@ def call_soon( """ if self.closed: return None - - task = asyncio.create_task(afunc(*args, **kwargs)) + coro = afunc(*args, **kwargs) + task = asyncio.create_task(coro) self._ongoing_tasks.add(task) task.add_done_callback(self._ongoing_tasks.remove) return task @@ -175,10 +175,6 @@ def call_later( """ return self.call_soon(delayed(afunc, delay), *args, **kwargs) - @property - def _cancellable_tasks(self): - return (t for t in self._ongoing_tasks if t is not asyncio.current_task()) - def close(self) -> None: """Closes the task group so that no new tasks can be scheduled. @@ -194,18 +190,22 @@ async def stop(self, timeout=1) -> None: """ self.close() - # Wrap gather in task to avoid Python3.8 issue, - # see https://github.com/dask/distributed/pull/6478#discussion_r885757056 - gather_task = asyncio.create_task( - asyncio.gather(*self._cancellable_tasks, return_exceptions=True) - ) - try: - await asyncio.wait_for( - gather_task, - timeout, + tasks_to_stop = [ + t for t in self._ongoing_tasks if t is not asyncio.current_task() + ] + if tasks_to_stop: + # Wrap gather in task to avoid Python3.8 issue, + # see https://github.com/dask/distributed/pull/6478#discussion_r885757056 + gather_task = asyncio.create_task( + asyncio.gather(*tasks_to_stop, return_exceptions=True) ) - except asyncio.TimeoutError: - pass + try: + await asyncio.wait_for( + gather_task, + timeout, + ) + except asyncio.TimeoutError: + pass class Server: From b064ba1bdff6f4243674c8c47ade5ae0fcd1cea3 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 1 Jun 2022 08:20:56 +0200 Subject: [PATCH 35/75] Wrap gather in async def --- distributed/core.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index b16f64cf7cd..5f6f488c0ad 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -40,7 +40,7 @@ delayed, get_traceback, has_keyword, - is_coroutine_function, + iscoroutinefunction, recursive_to_dict, truncate_exception, ) @@ -143,8 +143,7 @@ def call_soon( """ if self.closed: return None - coro = afunc(*args, **kwargs) - task = asyncio.create_task(coro) + task = asyncio.create_task(afunc(*args, **kwargs)) self._ongoing_tasks.add(task) task.add_done_callback(self._ongoing_tasks.remove) return task @@ -195,13 +194,13 @@ async def stop(self, timeout=1) -> None: ] if tasks_to_stop: # Wrap gather in task to avoid Python3.8 issue, - # see https://github.com/dask/distributed/pull/6478#discussion_r885757056 - gather_task = asyncio.create_task( - asyncio.gather(*tasks_to_stop, return_exceptions=True) - ) + # see https://github.com/dask/distributed/pull/6478#discussion_r885696827 + async def gather(): + return await asyncio.gather(*tasks_to_stop, return_exceptions=True) + try: await asyncio.wait_for( - gather_task, + gather(), timeout, ) except asyncio.TimeoutError: @@ -755,7 +754,7 @@ async def handle_stream(self, comm, extra=None): closed = True break handler = self.stream_handlers[op] - if is_coroutine_function(handler): + if iscoroutinefunction(handler): self.call_soon(handler, **merge(extra, msg)) await asyncio.sleep(0) else: From 52061d904e62c8dc685c018f0ac122fb30a8f4bd Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 1 Jun 2022 08:29:46 +0200 Subject: [PATCH 36/75] Add __len__ to AsyncTaskGroup --- distributed/core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/distributed/core.py b/distributed/core.py index 5f6f488c0ad..ecce742a01e 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -206,6 +206,9 @@ async def gather(): except asyncio.TimeoutError: pass + def __len__(self): + return len(self._ongoing_tasks) + class Server: """Dask Distributed Server From 1e41e20ac64fa92830b30bef680d6488c5b4f419 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 1 Jun 2022 09:34:43 +0200 Subject: [PATCH 37/75] Fix comm handling and delayed typing --- distributed/core.py | 50 ++++++++++++++++++++++++++++---------------- distributed/utils.py | 12 ++++++----- 2 files changed, 39 insertions(+), 23 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index ecce742a01e..a7766bb86d5 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -119,6 +119,27 @@ def __init__(self) -> None: self.closed = False self._ongoing_tasks: set[asyncio.Task] = set() + def schedule(self, coro: Coroutine) -> asyncio.Task | None: + """Schedules a coroutine object to be executed as an `asyncio.Task`. + + Parameters + ---------- + coro : Coroutine + Coroutine object to schedule. + + Returns + ------- + asyncio.Task | None + The scheduled Task object, or None if the group is closed. + """ + if self.closed: + coro.close() + return None + task = asyncio.create_task(coro) + self._ongoing_tasks.add(task) + task.add_done_callback(self._ongoing_tasks.remove) + return task + def call_soon( self, afunc: Callable[..., Coroutine], *args, **kwargs ) -> asyncio.Task | None: @@ -143,10 +164,7 @@ def call_soon( """ if self.closed: return None - task = asyncio.create_task(afunc(*args, **kwargs)) - self._ongoing_tasks.add(task) - task.add_done_callback(self._ongoing_tasks.remove) - return task + return self.schedule(afunc(*args, **kwargs)) def call_later( self, delay: float, afunc: Callable[..., Coroutine], *args, **kwargs @@ -681,21 +699,17 @@ async def handle_comm(self, comm): logger.debug("Calling into handler %s", handler.__name__) try: - if inspect.iscoroutinefunction(handler): - if _expects_comm(handler): - result = self._ongoing_comm_handlers.call_soon( - handler, comm, **msg - ) - else: - result = self._ongoing_comm_handlers.call_soon( - handler, **msg - ) - result = await result + if _expects_comm(handler): + result = handler(comm, **msg) else: - if _expects_comm(handler): - result = handler(comm, **msg) - else: - result = handler(**msg) + result = handler(**msg) + if inspect.iscoroutine(result): + result = self._ongoing_comm_handlers.schedule(result) + result = await result + elif inspect.isawaitable(result): + raise RuntimeError( + f"Comm handler returned unknown awaitable. Expected coroutine, instead got {type(result)}" + ) except CommClosedError: if self.status == Status.running: logger.info("Lost connection to %r", address, exc_info=True) diff --git a/distributed/utils.py b/distributed/utils.py index de7f76ee6c3..07f32334ee7 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -20,7 +20,7 @@ import xml.etree.ElementTree from asyncio import TimeoutError from collections import OrderedDict, UserDict, deque -from collections.abc import Callable, Collection, Container, KeysView, ValuesView +from collections.abc import Collection, Container, KeysView, ValuesView from concurrent.futures import CancelledError, ThreadPoolExecutor # noqa: F401 from contextlib import contextmanager, suppress from contextvars import ContextVar @@ -32,7 +32,7 @@ from types import ModuleType from typing import TYPE_CHECKING from typing import Any as AnyType -from typing import ClassVar, Coroutine, TypeVar, cast, overload +from typing import Callable, ClassVar, Coroutine, TypeVar, cast, overload import click import tblib.pickling_support @@ -75,9 +75,6 @@ P = ParamSpec("P") T = TypeVar("T") - _CoroutineFunctionType = TypeVar( - "_CoroutineFunctionType", bound=Callable[..., Coroutine] - ) no_default = "__no_default__" @@ -1738,6 +1735,11 @@ def is_python_shutting_down() -> bool: return _python_shutting_down +_CoroutineFunctionType = TypeVar( + "_CoroutineFunctionType", bound=Callable[..., Coroutine] +) + + def delayed(corofunc: _CoroutineFunctionType, delay: float) -> _CoroutineFunctionType: """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" From 688aaa11a1b93737da2a56c4fe58840b4c03e0db Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 1 Jun 2022 11:21:01 +0200 Subject: [PATCH 38/75] Fix test --- distributed/tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 363b40434ee..3570e150216 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5151,7 +5151,7 @@ def long_running(lock, entered): assert s.total_occupancy == 0 assert ws.occupancy == 0 - s._call_soon(s.reevaluate_occupancy, 0) + s.call_soon(s.reevaluate_occupancy, 0) assert s.workers[a.address].occupancy == 0 await l.release() From f3d9fa9f2f2c26a3524a044fbb2cdc011a78ec20 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 1 Jun 2022 12:07:46 +0200 Subject: [PATCH 39/75] Add unit tests for AsyncTaskGroup --- distributed/tests/test_core.py | 144 +++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 9ce06b8a149..c1d9d33ee65 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -12,6 +12,7 @@ from distributed.comm.core import CommClosedError from distributed.core import ( + AsyncTaskGroup, ConnectionPool, Server, Status, @@ -73,6 +74,149 @@ def echo_no_serialize(comm, x): return {"result": x} +def test_async_task_group_initialization(): + group = AsyncTaskGroup() + assert not group.closed + assert len(group) == 0 + + +@gen_test() +async def test_async_task_group_call_soon_executes_task_in_background(): + group = AsyncTaskGroup() + ev = asyncio.Event() + flag = False + + async def set_flag(): + nonlocal flag + await ev.wait() + flag = True + + task = group.call_soon(set_flag) + assert task is not None + assert len(group) == 1 + ev.set() + await task + assert len(group) == 0 + assert flag + + +@gen_test() +async def test_async_task_group_call_later_executes_delayed_task_in_background(): + group = AsyncTaskGroup() + flag = False + + async def set_flag(): + nonlocal flag + flag = True + + start = time() + task = group.call_later(1, set_flag) + assert task is not None + assert len(group) == 1 + await task + end = time() + assert len(group) == 0 + assert flag + assert end - start > 1 + + +def test_async_task_group_close_closes(): + group = AsyncTaskGroup() + group.close() + assert group.closed + + +@gen_test() +async def test_async_task_group_close_does_not_cancel_existing_tasks(): + group = AsyncTaskGroup() + + ev = asyncio.Event() + flag = False + + async def set_flag(): + nonlocal flag + await ev.wait() + flag = True + return True + + task = group.call_soon(set_flag) + + group.close() + + assert not task.cancelled() + assert len(group) == 1 + + ev.set() + await task + assert task.result() + assert len(group) == 0 + + +@gen_test() +async def test_async_task_group_close_prohibits_new_tasks(): + group = AsyncTaskGroup() + group.close() + + ev = asyncio.Event() + flag = False + + async def set_flag(): + nonlocal flag + await ev.wait() + flag = True + return True + + task = group.call_soon(set_flag) + assert task is None + assert len(group) == 0 + + task = group.call_later(1, set_flag) + assert task is None + assert len(group) == 0 + + await asyncio.sleep(0.01) + assert not flag + + +@gen_test() +async def test_async_task_group_stop_allows_shutdown(): + group = AsyncTaskGroup() + + flag = False + + async def set_flag(): + nonlocal flag + while not group.closed: + asyncio.sleep(0.01) + flag = True + return True + + task = group.call_soon(set_flag) + assert len(group) == 1 + await group.stop(timeout=1) + assert not task.cancelled() + assert flag + assert task.result() + + +@gen_test() +async def test_async_task_group_stop_cancels_long_running(): + group = AsyncTaskGroup() + + flag = False + + async def set_flag(): + nonlocal flag + flag = True + return True + + task = group.call_later(10, set_flag) + assert len(group) == 1 + await group.stop(timeout=1) + assert task.cancelled() + assert not flag + + @gen_test() async def test_server_status_is_always_enum(): """Assignments with strings is forbidden""" From 683d8379d59df548a6480b6ca2bb6f4f6561cbf4 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 1 Jun 2022 21:57:21 +0200 Subject: [PATCH 40/75] Replace more add_callback's --- distributed/worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index 0c86933389c..23e4763a08c 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -212,7 +212,7 @@ def wrapper(self, *args, **kwargs): }, ) logger.exception(e) - self.loop.add_callback(_force_close, self) + self.call_soon(_force_close, self) raise return wrapper @@ -5005,7 +5005,7 @@ async def run(server, comm, function, args=(), kwargs=None, wait=True): if wait: result = await function(*args, **kwargs) else: - server.loop.add_callback(function, *args, **kwargs) + server.call_soon(function, *args, **kwargs) result = None except Exception as e: From 350c2b4186f657d9551b21d2c8609f6d51ccee0d Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 1 Jun 2022 22:02:30 +0200 Subject: [PATCH 41/75] Add hopefully unreachable error --- distributed/core.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/distributed/core.py b/distributed/core.py index a7766bb86d5..aae45c24434 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -224,6 +224,11 @@ async def gather(): except asyncio.TimeoutError: pass + if self._ongoing_tasks: + raise RuntimeError( + f"Expected all ongoing tasks to be cancelled and removed, found {self._ongoing_tasks}." + ) + def __len__(self): return len(self._ongoing_tasks) From 43c7b3c5a0886427ba19204bea339586c53814b2 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 2 Jun 2022 09:45:32 +0200 Subject: [PATCH 42/75] Rollback to add_callback in fail_hard --- distributed/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index 23e4763a08c..98235ba0efc 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -212,7 +212,7 @@ def wrapper(self, *args, **kwargs): }, ) logger.exception(e) - self.call_soon(_force_close, self) + self.loop.add_callback(_force_close, self) raise return wrapper From 936710ad8532b05e9feca32c72c3f2bc46851235 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 2 Jun 2022 11:18:28 +0200 Subject: [PATCH 43/75] Retrigger CI to check for flake From f46bc592711652cd51c563283eb78daf2b3b3393 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 2 Jun 2022 12:28:28 +0200 Subject: [PATCH 44/75] Clean up closing logic --- distributed/core.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index aae45c24434..7ca33cac87a 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -190,6 +190,8 @@ def call_later( asyncio.Task | None The scheduled Task object, or None if the group is closed. """ + if self.closed: + return None return self.call_soon(delayed(afunc, delay), *args, **kwargs) def close(self) -> None: @@ -207,9 +209,14 @@ async def stop(self, timeout=1) -> None: """ self.close() - tasks_to_stop = [ - t for t in self._ongoing_tasks if t is not asyncio.current_task() - ] + current_task = asyncio.current_task() + if current_task: + self._ongoing_tasks.discard( + current_task + ) #: Discard to avoid cancelling the current task + + tasks_to_stop = list(self._ongoing_tasks) + if tasks_to_stop: # Wrap gather in task to avoid Python3.8 issue, # see https://github.com/dask/distributed/pull/6478#discussion_r885696827 @@ -222,7 +229,7 @@ async def gather(): timeout, ) except asyncio.TimeoutError: - pass + await asyncio.gather(*tasks_to_stop, return_exceptions=True) if self._ongoing_tasks: raise RuntimeError( From 06749bb71c71be890987d33b31b2df8634ad3d4c Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 2 Jun 2022 14:52:09 +0200 Subject: [PATCH 45/75] Add locking --- distributed/core.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 7ca33cac87a..397eb9e8122 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -117,6 +117,7 @@ class AsyncTaskGroup: def __init__(self) -> None: self.closed = False + self._lock = threading.Lock() self._ongoing_tasks: set[asyncio.Task] = set() def schedule(self, coro: Coroutine) -> asyncio.Task | None: @@ -132,13 +133,14 @@ def schedule(self, coro: Coroutine) -> asyncio.Task | None: asyncio.Task | None The scheduled Task object, or None if the group is closed. """ - if self.closed: - coro.close() - return None - task = asyncio.create_task(coro) - self._ongoing_tasks.add(task) - task.add_done_callback(self._ongoing_tasks.remove) - return task + with self._lock: + if self.closed: + coro.close() + return None + task = asyncio.create_task(coro) + task.add_done_callback(self._ongoing_tasks.remove) + self._ongoing_tasks.add(task) + return task def call_soon( self, afunc: Callable[..., Coroutine], *args, **kwargs @@ -162,8 +164,6 @@ def call_soon( asyncio.Task | None The scheduled Task object, or None if the group is closed. """ - if self.closed: - return None return self.schedule(afunc(*args, **kwargs)) def call_later( @@ -190,8 +190,6 @@ def call_later( asyncio.Task | None The scheduled Task object, or None if the group is closed. """ - if self.closed: - return None return self.call_soon(delayed(afunc, delay), *args, **kwargs) def close(self) -> None: @@ -199,7 +197,8 @@ def close(self) -> None: Existing tasks continue to run. """ - self.closed = True + with self._lock: + self.closed = True async def stop(self, timeout=1) -> None: """Close the group and stop all currently running tasks. @@ -210,12 +209,7 @@ async def stop(self, timeout=1) -> None: self.close() current_task = asyncio.current_task() - if current_task: - self._ongoing_tasks.discard( - current_task - ) #: Discard to avoid cancelling the current task - - tasks_to_stop = list(self._ongoing_tasks) + tasks_to_stop = [t for t in self._ongoing_tasks if t is not current_task] if tasks_to_stop: # Wrap gather in task to avoid Python3.8 issue, @@ -231,7 +225,7 @@ async def gather(): except asyncio.TimeoutError: await asyncio.gather(*tasks_to_stop, return_exceptions=True) - if self._ongoing_tasks: + if [t for t in self._ongoing_tasks if t is not current_task]: raise RuntimeError( f"Expected all ongoing tasks to be cancelled and removed, found {self._ongoing_tasks}." ) From 9ca5bea4d8a4f9d56fa8102afb94795a87ebe45a Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 2 Jun 2022 15:41:51 +0200 Subject: [PATCH 46/75] Fix invalid worker states by not tracking handle_scheduler --- distributed/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index c4616531c06..bac600e7ada 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1202,7 +1202,7 @@ async def _register_with_scheduler(self): self.batched_stream.start(comm) self.periodic_callbacks["keep-alive"].start() self.periodic_callbacks["heartbeat"].start() - self.call_soon(self.handle_scheduler, comm) + self.loop.add_callback(self.handle_scheduler, comm) def _update_latency(self, latency): self.latency = latency * 0.05 + self.latency * 0.95 From daa845a1c186378928b7539129250a25a95edbbe Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 2 Jun 2022 17:27:42 +0200 Subject: [PATCH 47/75] Fix call_later test by using monotonic clock and subtracting clock resolution --- distributed/tests/test_core.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index c1d9d33ee65..c2a3948e422 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -3,6 +3,7 @@ import os import socket import threading +import time as timemod import weakref import pytest @@ -109,15 +110,15 @@ async def set_flag(): nonlocal flag flag = True - start = time() + start = timemod.monotonic() task = group.call_later(1, set_flag) assert task is not None assert len(group) == 1 await task - end = time() + end = timemod.monotonic() assert len(group) == 0 assert flag - assert end - start > 1 + assert end - start > 1 - timemod.get_clock_info("monotonic").resolution def test_async_task_group_close_closes(): From 075a5395015ae9084d03b477f4b1eb48664d9a5f Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 7 Jun 2022 17:56:32 +0200 Subject: [PATCH 48/75] Drop lock and ensure AsyncTaskGroup is called from the correct thread --- distributed/core.py | 47 +++++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 31742292f21..0107713dace 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -109,7 +109,26 @@ def _expects_comm(func: Callable) -> bool: return False -class AsyncTaskGroup: +class _LoopBoundMixin: + """Backport of the private asyncio.mixins._LoopBoundMixin from 3.11""" + + _global_lock = threading.Lock() + + _loop = None + + def _get_loop(self): + loop = asyncio.get_running_loop() + + if self._loop is None: + with self._global_lock: + if self._loop is None: + self._loop = loop + if loop is not self._loop: + raise RuntimeError(f"{self!r} is bound to a different event loop") + return loop + + +class AsyncTaskGroup(_LoopBoundMixin): """Collection tracking all currently running asynchronous tasks within a group""" #: If True, the group is closed and does not allow adding new tasks. @@ -117,7 +136,6 @@ class AsyncTaskGroup: def __init__(self) -> None: self.closed = False - self._lock = threading.Lock() self._ongoing_tasks: set[asyncio.Task] = set() def schedule(self, coro: Coroutine) -> asyncio.Task | None: @@ -125,22 +143,20 @@ def schedule(self, coro: Coroutine) -> asyncio.Task | None: Parameters ---------- - coro : Coroutine + coro Coroutine object to schedule. Returns ------- - asyncio.Task | None The scheduled Task object, or None if the group is closed. """ - with self._lock: - if self.closed: - coro.close() - return None - task = asyncio.create_task(coro) - task.add_done_callback(self._ongoing_tasks.remove) - self._ongoing_tasks.add(task) - return task + if self.closed: + coro.close() + return None + task = self._get_loop().create_task(coro) + task.add_done_callback(self._ongoing_tasks.remove) + self._ongoing_tasks.add(task) + return task def call_soon( self, afunc: Callable[..., Coroutine], *args, **kwargs @@ -161,7 +177,6 @@ def call_soon( Returns ------- - asyncio.Task | None The scheduled Task object, or None if the group is closed. """ return self.schedule(afunc(*args, **kwargs)) @@ -187,7 +202,6 @@ def call_later( Returns ------- - asyncio.Task | None The scheduled Task object, or None if the group is closed. """ return self.call_soon(delayed(afunc, delay), *args, **kwargs) @@ -197,8 +211,7 @@ def close(self) -> None: Existing tasks continue to run. """ - with self._lock: - self.closed = True + self.closed = True async def stop(self, timeout=1) -> None: """Close the group and stop all currently running tasks. @@ -208,7 +221,7 @@ async def stop(self, timeout=1) -> None: """ self.close() - current_task = asyncio.current_task() + current_task = asyncio.current_task(self._get_loop()) tasks_to_stop = [t for t in self._ongoing_tasks if t is not current_task] if tasks_to_stop: From 8df26a57f9861d14946f23d269e02a2bc4224ec3 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 12:37:25 +0200 Subject: [PATCH 49/75] Raise exception if coro cannot be scheduled --- distributed/core.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 0107713dace..b2f51a9053d 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -128,6 +128,10 @@ def _get_loop(self): return loop +class AsyncTaskGroupClosedError(RuntimeError): + pass + + class AsyncTaskGroup(_LoopBoundMixin): """Collection tracking all currently running asynchronous tasks within a group""" @@ -138,7 +142,7 @@ def __init__(self) -> None: self.closed = False self._ongoing_tasks: set[asyncio.Task] = set() - def schedule(self, coro: Coroutine) -> asyncio.Task | None: + def schedule(self, coro: Coroutine) -> asyncio.Task: """Schedules a coroutine object to be executed as an `asyncio.Task`. Parameters @@ -148,11 +152,18 @@ def schedule(self, coro: Coroutine) -> asyncio.Task | None: Returns ------- - The scheduled Task object, or None if the group is closed. + The scheduled Task object. + + Raises + ------ + AsyncTaskGroupClosedError + If the task group is closed. """ if self.closed: coro.close() - return None + raise AsyncTaskGroupClosedError( + "Cannot schedule a new coroutine as the group is already closed." + ) task = self._get_loop().create_task(coro) task.add_done_callback(self._ongoing_tasks.remove) self._ongoing_tasks.add(task) @@ -160,7 +171,7 @@ def schedule(self, coro: Coroutine) -> asyncio.Task | None: def call_soon( self, afunc: Callable[..., Coroutine], *args, **kwargs - ) -> asyncio.Task | None: + ) -> asyncio.Task: """Schedule a coroutine function to be executed as an `asyncio.Task`. The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments @@ -177,13 +188,18 @@ def call_soon( Returns ------- - The scheduled Task object, or None if the group is closed. + The scheduled Task object. + + Raises + ------ + AsyncTaskGroupClosedError + If the task group is closed. """ return self.schedule(afunc(*args, **kwargs)) def call_later( self, delay: float, afunc: Callable[..., Coroutine], *args, **kwargs - ) -> asyncio.Task | None: + ) -> asyncio.Task: """Schedule a coroutine function to be executed after `delay` seconds as an `asyncio.Task`. The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments @@ -202,7 +218,12 @@ def call_later( Returns ------- - The scheduled Task object, or None if the group is closed. + The scheduled Task object. + + Raises + ------ + AsyncTaskGroupClosedError + If the task group is closed. """ return self.call_soon(delayed(afunc, delay), *args, **kwargs) From a77c61ec76e0903fcc854f972831db3af57abccd Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 12:41:47 +0200 Subject: [PATCH 50/75] AsyncTaskGroupClosedError --- distributed/core.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index b2f51a9053d..d5f1c5024fa 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -160,7 +160,6 @@ def schedule(self, coro: Coroutine) -> asyncio.Task: If the task group is closed. """ if self.closed: - coro.close() raise AsyncTaskGroupClosedError( "Cannot schedule a new coroutine as the group is already closed." ) @@ -195,6 +194,10 @@ def call_soon( AsyncTaskGroupClosedError If the task group is closed. """ + if self.closed: + raise AsyncTaskGroupClosedError( + "Cannot schedule a new coroutine function as the group is already closed." + ) return self.schedule(afunc(*args, **kwargs)) def call_later( @@ -747,7 +750,11 @@ async def handle_comm(self, comm): else: result = handler(**msg) if inspect.iscoroutine(result): - result = self._ongoing_comm_handlers.schedule(result) + try: + result = self._ongoing_comm_handlers.schedule(result) + except AsyncTaskGroupClosedError: + result.close() # TODO: Don't call coroutinefunctions that we can't await + return result = await result elif inspect.isawaitable(result): raise RuntimeError( From 50f5851a9d878f121f7ae768b23bb5dd660e540a Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 13:05:53 +0200 Subject: [PATCH 51/75] Add comment --- distributed/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/core.py b/distributed/core.py index d5f1c5024fa..6d70675b898 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -194,7 +194,7 @@ def call_soon( AsyncTaskGroupClosedError If the task group is closed. """ - if self.closed: + if self.closed: # Avoid creating a coroutine raise AsyncTaskGroupClosedError( "Cannot schedule a new coroutine function as the group is already closed." ) From 22d0952927c2df3217cf2f77c15e1346e346319c Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 13:12:20 +0200 Subject: [PATCH 52/75] Use ParamSpec --- distributed/utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index 6fb847d6eb8..4699094f311 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -32,7 +32,7 @@ from types import ModuleType from typing import TYPE_CHECKING from typing import Any as AnyType -from typing import Callable, ClassVar, Coroutine, TypeVar, cast, overload +from typing import Callable, ClassVar, Coroutine, TypeVar, overload import click import tblib.pickling_support @@ -1715,16 +1715,17 @@ def is_python_shutting_down() -> bool: return _python_shutting_down -_CoroutineFunctionType = TypeVar( - "_CoroutineFunctionType", bound=Callable[..., Coroutine] -) +_P = ParamSpec("_P") +_T = TypeVar("_T") -def delayed(corofunc: _CoroutineFunctionType, delay: float) -> _CoroutineFunctionType: +def delayed( + corofunc: Callable[_P, Coroutine[AnyType, AnyType, _T]], delay: float +) -> Callable[_P, Coroutine[AnyType, AnyType, _T]]: """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" async def wrapper(*args, **kwargs): await asyncio.sleep(delay) return await corofunc(*args, **kwargs) - return cast(_CoroutineFunctionType, wrapper) + return wrapper From cd04ab2c9bbcfc96b34451e18426bb924a9214e7 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 13:15:32 +0200 Subject: [PATCH 53/75] Fix ParamSpec --- distributed/utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index 4699094f311..1e6d70c8a09 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1715,13 +1715,9 @@ def is_python_shutting_down() -> bool: return _python_shutting_down -_P = ParamSpec("_P") -_T = TypeVar("_T") - - def delayed( - corofunc: Callable[_P, Coroutine[AnyType, AnyType, _T]], delay: float -) -> Callable[_P, Coroutine[AnyType, AnyType, _T]]: + corofunc: Callable[P, Coroutine[AnyType, AnyType, T]], delay: float +) -> Callable[P, Coroutine[AnyType, AnyType, T]]: """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" async def wrapper(*args, **kwargs): From 2daae465e6c6566c700a4b2e7564296beb415d1e Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 14:14:59 +0200 Subject: [PATCH 54/75] Fix tests --- distributed/scheduler.py | 4 ++-- distributed/tests/test_core.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 30c4b137be7..8c2fb528e33 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4628,8 +4628,8 @@ async def remove_client_from_events(): cleanup_delay = parse_timedelta( dask.config.get("distributed.scheduler.events-cleanup-delay") ) - - self.call_later(cleanup_delay, remove_client_from_events) + if not self._ongoing_background_tasks.closed: + self.call_later(cleanup_delay, remove_client_from_events) def send_task_to_worker(self, worker, ts: TaskState, duration: float = -1): """Send a single computational task to a worker""" diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index c2a3948e422..1d1b762f04d 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -14,6 +14,7 @@ from distributed.comm.core import CommClosedError from distributed.core import ( AsyncTaskGroup, + AsyncTaskGroupClosedError, ConnectionPool, Server, Status, @@ -167,12 +168,12 @@ async def set_flag(): flag = True return True - task = group.call_soon(set_flag) - assert task is None + with pytest.raises(AsyncTaskGroupClosedError): + group.call_soon(set_flag) assert len(group) == 0 - task = group.call_later(1, set_flag) - assert task is None + with pytest.raises(AsyncTaskGroupClosedError): + task = group.call_later(1, set_flag) assert len(group) == 0 await asyncio.sleep(0.01) From 55bcdd3ca73e55086be8099a024cf02ae40b6100 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 14:26:28 +0200 Subject: [PATCH 55/75] Drop wrapper methods to highlight which group is being used --- distributed/core.py | 54 ++++------------------------------------ distributed/nanny.py | 4 +-- distributed/scheduler.py | 16 +++++++----- distributed/worker.py | 2 +- 4 files changed, 18 insertions(+), 58 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 25d1c53fc83..7f4a8a94d84 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -520,7 +520,7 @@ async def start_pcs(): if not pc.is_running(): pc.start() - self.call_soon(start_pcs) + self._ongoing_background_tasks.call_soon(start_pcs) def stop(self): if not self.__stopped: @@ -537,7 +537,7 @@ def stop(self): async def _stop_listener(): listener.stop() - self.call_soon(_stop_listener) + self._ongoing_background_tasks.call_soon(_stop_listener) @property def listener(self): @@ -828,7 +828,9 @@ async def handle_stream(self, comm, extra=None): break handler = self.stream_handlers[op] if iscoroutinefunction(handler): - self.call_soon(handler, **merge(extra, msg)) + self._ongoing_background_tasks.call_soon( + handler, **merge(extra, msg) + ) await asyncio.sleep(0) else: handler(**merge(extra, msg)) @@ -849,52 +851,6 @@ async def handle_stream(self, comm, extra=None): await comm.close() assert comm.closed() - def call_later( - self, delay: float, afunc: Callable[..., Coroutine], *args, **kwargs - ) -> None: - """Schedule a coroutine function to be asynchronously executed after `delay` seconds. - - The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments - to be asynchronously executed after `delay` seconds. - - Parameters - ---------- - delay - Delay in seconds. - afunc - Coroutine function to schedule. - *args - Arguments to be passed to `afunc`. - **kwargs - Keyword arguments to be passed to `afunc` - - Returns - ------- - None - """ - - self._ongoing_background_tasks.call_later(delay, afunc, *args, **kwargs) - - def call_soon(self, afunc: Callable[..., Coroutine], *args, **kwargs) -> None: - """Schedule a coroutine function to be executed asynchronously. - - The coroutine function `afunc` is scheduled asynchronously with `args` arguments and `kwargs` keyword arguments. - - Parameters - ---------- - afunc - Coroutine function to schedule. - *args - Arguments to be passed to `afunc`. - **kwargs - Keyword arguments to be passed to `afunc` - - Returns - ------- - None - """ - self._ongoing_background_tasks.call_soon(afunc, *args, **kwargs) - async def close(self, timeout=None): for pc in self.periodic_callbacks.values(): pc.stop() diff --git a/distributed/nanny.py b/distributed/nanny.py index 1d66f03d300..70eab0c0b38 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -498,7 +498,7 @@ def run(self, comm, *args, **kwargs): return run(self, comm, *args, **kwargs) def _on_exit_sync(self, exitcode): - self.call_soon(self._on_exit, exitcode) + self._ongoing_background_tasks.call_soon(self._on_exit, exitcode) @log_errors async def _on_exit(self, exitcode): @@ -595,7 +595,7 @@ async def _log_event(self, topic, msg): ) def log_event(self, topic, msg): - self.call_soon(self._log_event, topic, msg) + self._ongoing_background_tasks.call_soon(self._log_event, topic, msg) class WorkerProcess: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8c2fb528e33..32596d0c7ed 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3325,7 +3325,7 @@ async def start_unsafe(self): for k, v in self.services.items(): logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port)) - self.call_soon(self.reevaluate_occupancy) + self._ongoing_background_tasks.call_soon(self.reevaluate_occupancy) if self.scheduler_file: with open(self.scheduler_file, "w") as f: @@ -4268,7 +4268,9 @@ async def remove_worker_from_events(): dask.config.get("distributed.scheduler.events-cleanup-delay") ) - self.call_later(cleanup_delay, remove_worker_from_events) + self._ongoing_background_tasks.call_later( + cleanup_delay, remove_worker_from_events + ) logger.debug("Removed worker %s", ws) return "OK" @@ -4629,7 +4631,9 @@ async def remove_client_from_events(): dask.config.get("distributed.scheduler.events-cleanup-delay") ) if not self._ongoing_background_tasks.closed: - self.call_later(cleanup_delay, remove_client_from_events) + self._ongoing_background_tasks.call_later( + cleanup_delay, remove_client_from_events + ) def send_task_to_worker(self, worker, ts: TaskState, duration: float = -1): """Send a single computational task to a worker""" @@ -4911,7 +4915,7 @@ def worker_send(self, worker: str, msg: dict[str, Any]) -> None: try: stream_comms[worker].send(msg) except (CommClosedError, AttributeError): - self.call_soon( + self._ongoing_background_tasks.call_soon( self.remove_worker, address=worker, stimulus_id=f"worker-send-comm-fail-{time()}", @@ -4960,7 +4964,7 @@ def send_all(self, client_msgs: dict, worker_msgs: dict): # worker already gone pass except (CommClosedError, AttributeError): - self.call_soon( + self._ongoing_background_tasks.call_soon( self.remove_worker, address=worker, stimulus_id=f"send-all-comm-fail-{time()}", @@ -7036,7 +7040,7 @@ def check_idle(self): "Scheduler closing after being idle for %s", format_time(self.idle_timeout), ) - self.call_soon(self.close) + self._ongoing_background_tasks.call_soon(self.close) def adaptive_target(self, target_duration=None): """Desired number of workers based on the current workload diff --git a/distributed/worker.py b/distributed/worker.py index 10215337ac4..e7faa9479a8 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1683,7 +1683,7 @@ async def batched_send_connect(): bcomm.start(comm) - self.call_soon(batched_send_connect) + self._ongoing_background_tasks.call_soon(batched_send_connect) self.stream_comms[address].send(msg) From 6c4536c5396723adee0f80e3431094e41a98a583 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 15:18:09 +0200 Subject: [PATCH 56/75] Fix tests errors after cleanup --- distributed/tests/test_client.py | 2 +- distributed/worker.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 828e30a4afb..bcf000a0a39 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5129,7 +5129,7 @@ def long_running(lock, entered): assert s.total_occupancy == 0 assert ws.occupancy == 0 - s.call_soon(s.reevaluate_occupancy, 0) + s._ongoing_background_tasks.call_soon(s.reevaluate_occupancy, 0) assert s.workers[a.address].occupancy == 0 await l.release() diff --git a/distributed/worker.py b/distributed/worker.py index e7faa9479a8..d150f33e72b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -4951,7 +4951,7 @@ async def run(server, comm, function, args=(), kwargs=None, wait=True): if wait: result = await function(*args, **kwargs) else: - server.call_soon(function, *args, **kwargs) + server._ongoing_background_tasks.call_soon(function, *args, **kwargs) result = None except Exception as e: From 53643cd388e8e9a53caca718a180ef175892597b Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 16:35:53 +0200 Subject: [PATCH 57/75] Fix inproc cancel handling and simplify pc startup --- distributed/comm/inproc.py | 7 ++++-- distributed/core.py | 34 ++++++++++++-------------- distributed/deploy/tests/test_local.py | 1 - 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index 041c9d530a8..97797b554d1 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -103,7 +103,7 @@ def get_nowait(self): raise QueueEmpty return q.popleft() - def get(self): + async def get(self): assert not self._read_future, "Only one reader allowed" fut = Future() q = self._q @@ -111,7 +111,10 @@ def get(self): fut.set_result(q.popleft()) else: self._read_future = fut - return fut + try: + return await fut + finally: + self._read_future = None def put_nowait(self, value): q = self._q diff --git a/distributed/core.py b/distributed/core.py index 7f4a8a94d84..30db09c9903 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -37,6 +37,7 @@ from distributed.metrics import time from distributed.system_monitor import SystemMonitor from distributed.utils import ( + NoOpAwaitable, delayed, get_traceback, has_keyword, @@ -511,16 +512,17 @@ def start_periodic_callbacks(self): """Start Periodic Callbacks consistently This starts all PeriodicCallbacks stored in self.periodic_callbacks if - they are not yet running. It does this safely on the IOLoop. + they are not yet running. It does this safely by checking that it is using the + correct event loop. """ - self._last_tick = time() - - async def start_pcs(): - for pc in self.periodic_callbacks.values(): - if not pc.is_running(): - pc.start() + if self.io_loop.asyncio_loop is not asyncio.get_running_loop(): + raise RuntimeError(f"{self!r} is bound to a different event loop") - self._ongoing_background_tasks.call_soon(start_pcs) + self._last_tick = time() + for pc in self.periodic_callbacks.values(): + if not pc.is_running(): + logger.info("Starting periodic callback {pc!r}") + pc.start() def stop(self): if not self.__stopped: @@ -534,10 +536,7 @@ def stop(self): # The demonstrator for this is Worker.terminate(), which # closes the server socket in response to an incoming message. # See https://github.com/tornadoweb/tornado/issues/2069 - async def _stop_listener(): - listener.stop() - - self._ongoing_background_tasks.call_soon(_stop_listener) + listener.stop() @property def listener(self): @@ -662,7 +661,11 @@ async def listen(self, port_or_addr=None, allow_offload=True, **kwargs): ) self.listeners.append(listener) - async def handle_comm(self, comm): + def handle_comm(self, comm): + self._ongoing_background_tasks.call_soon(self._handle_comm, comm) + return NoOpAwaitable() + + async def _handle_comm(self, comm): """Dispatch new communications to coroutine-handlers Handlers is a dictionary mapping operation names to functions or @@ -756,11 +759,6 @@ async def handle_comm(self, comm): else: result = handler(**msg) if inspect.iscoroutine(result): - try: - result = self._ongoing_comm_handlers.schedule(result) - except AsyncTaskGroupClosedError: - result.close() # TODO: Don't call coroutinefunctions that we can't await - return result = await result elif inspect.isawaitable(result): raise RuntimeError( diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index bb1e1becf6e..781dc29a131 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -1024,7 +1024,6 @@ async def test_no_dangling_asyncio_tasks(): async with LocalCluster(asynchronous=True, processes=False, dashboard_address=":0"): await asyncio.sleep(0.01) - await asyncio.sleep(0.01) tasks = asyncio.all_tasks() assert tasks == start From 2494e40657d01a64636033b8874dd70ffdf84157 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 16:39:34 +0200 Subject: [PATCH 58/75] Drop AsyncTaskGroup.schedule() --- distributed/core.py | 31 ++++--------------------------- 1 file changed, 4 insertions(+), 27 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 30db09c9903..7a8f7dea01d 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -143,32 +143,6 @@ def __init__(self) -> None: self.closed = False self._ongoing_tasks: set[asyncio.Task] = set() - def schedule(self, coro: Coroutine) -> asyncio.Task: - """Schedules a coroutine object to be executed as an `asyncio.Task`. - - Parameters - ---------- - coro - Coroutine object to schedule. - - Returns - ------- - The scheduled Task object. - - Raises - ------ - AsyncTaskGroupClosedError - If the task group is closed. - """ - if self.closed: - raise AsyncTaskGroupClosedError( - "Cannot schedule a new coroutine as the group is already closed." - ) - task = self._get_loop().create_task(coro) - task.add_done_callback(self._ongoing_tasks.remove) - self._ongoing_tasks.add(task) - return task - def call_soon( self, afunc: Callable[..., Coroutine], *args, **kwargs ) -> asyncio.Task: @@ -199,7 +173,10 @@ def call_soon( raise AsyncTaskGroupClosedError( "Cannot schedule a new coroutine function as the group is already closed." ) - return self.schedule(afunc(*args, **kwargs)) + task = self._get_loop().create_task(afunc(*args, **kwargs)) + task.add_done_callback(self._ongoing_tasks.remove) + self._ongoing_tasks.add(task) + return task def call_later( self, delay: float, afunc: Callable[..., Coroutine], *args, **kwargs From 2751c52216caa8130db4577e31acf5f545c2f866 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 16:59:35 +0200 Subject: [PATCH 59/75] Add comment --- distributed/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/core.py b/distributed/core.py index 7a8f7dea01d..cf82cc7bff7 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -639,6 +639,7 @@ async def listen(self, port_or_addr=None, allow_offload=True, **kwargs): self.listeners.append(listener) def handle_comm(self, comm): + """Start a background task that dispatches new communications to coroutine-handlers""" self._ongoing_background_tasks.call_soon(self._handle_comm, comm) return NoOpAwaitable() From 93ff2ba0e83899977408e9db37f601b6d8e86bdf Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 17:45:01 +0200 Subject: [PATCH 60/75] Fix f-string --- distributed/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/core.py b/distributed/core.py index cf82cc7bff7..34b9c3c39f1 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -498,7 +498,7 @@ def start_periodic_callbacks(self): self._last_tick = time() for pc in self.periodic_callbacks.values(): if not pc.is_running(): - logger.info("Starting periodic callback {pc!r}") + logger.info(f"Starting periodic callback {pc!r}") pc.start() def stop(self): From e76489c8a950e4850605375bb0a1a24bf0f704c5 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 18:57:55 +0200 Subject: [PATCH 61/75] Abort comm if we cannot handle it --- distributed/core.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/distributed/core.py b/distributed/core.py index 34b9c3c39f1..0a4d681865c 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -640,7 +640,10 @@ async def listen(self, port_or_addr=None, allow_offload=True, **kwargs): def handle_comm(self, comm): """Start a background task that dispatches new communications to coroutine-handlers""" - self._ongoing_background_tasks.call_soon(self._handle_comm, comm) + try: + self._ongoing_background_tasks.call_soon(self._handle_comm, comm) + except AsyncTaskGroupClosedError: + comm.abort() return NoOpAwaitable() async def _handle_comm(self, comm): From 09a3b2cccb46175b25768aca218a9e5f8ce733b9 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 19:03:34 +0200 Subject: [PATCH 62/75] Fix listener.stop --- distributed/core.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 0a4d681865c..da53b53a273 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -506,14 +506,13 @@ def stop(self): self.__stopped = True for listener in self.listeners: - # Delay closing the server socket until the next IO loop tick. - # Otherwise race conditions can appear if an event handler - # for an accept() call is already scheduled by the IO loop, - # raising EBADF. - # The demonstrator for this is Worker.terminate(), which - # closes the server socket in response to an incoming message. - # See https://github.com/tornadoweb/tornado/issues/2069 - listener.stop() + + async def stop_listener(listener): + v = listener.stop() + if inspect.isawaitable(v): + await v + + self._ongoing_background_tasks.call_soon(stop_listener, listener) @property def listener(self): From a7df081f53eb43882c76d350b61b0667ad00cabf Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 19:10:28 +0200 Subject: [PATCH 63/75] Add comment --- distributed/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/core.py b/distributed/core.py index da53b53a273..ca7c8034b82 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -238,6 +238,7 @@ async def gather(): timeout, ) except asyncio.TimeoutError: + # The timeout on gather has cancelled the tasks, so this will not hang indefinitely await asyncio.gather(*tasks_to_stop, return_exceptions=True) if [t for t in self._ongoing_tasks if t is not current_task]: From 8873ccbf0d6b0c067b7784902b5d5bd51644db69 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 8 Jun 2022 19:12:56 +0200 Subject: [PATCH 64/75] Test idempotency --- distributed/tests/test_core.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 1d1b762f04d..4962e02940f 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -127,6 +127,10 @@ def test_async_task_group_close_closes(): group.close() assert group.closed + # Test idempotency + group.close() + assert group.closed + @gen_test() async def test_async_task_group_close_does_not_cancel_existing_tasks(): From 6a308babc692ae21d25975ba7b9f9b5b554e8a8a Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Mon, 20 Jun 2022 14:57:41 +0100 Subject: [PATCH 65/75] do not return tasks from AsyncTaskGroup.call_soon and AsyncTaskGroup.call_later it's very easy to accidentally call "await group.call_soon(set_flag)" which is always wrong, this is especially bad because some type checkers will warn you that you didn't await your awaitables here also this prevents issues like "task = group.call_soon(set_flag); ... await task" where it's easy to confuse a cancellation coming from asyncio.current_task().cancel() and task.cancelled() which should be handled differently --- distributed/core.py | 35 ++++++++++++++++------- distributed/tests/test_core.py | 52 ++++++++++++++++++++-------------- distributed/utils.py | 7 ++--- 3 files changed, 57 insertions(+), 37 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index ca7c8034b82..45057fd6b47 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -11,11 +11,11 @@ import warnings import weakref from collections import defaultdict, deque -from collections.abc import Container +from collections.abc import Container, Coroutine from contextlib import suppress from enum import Enum from functools import partial -from typing import Callable, ClassVar, Coroutine, TypedDict, TypeVar +from typing import TYPE_CHECKING, Any, Callable, ClassVar, TypedDict, TypeVar import tblib from tlz import merge @@ -46,6 +46,14 @@ truncate_exception, ) +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + P = ParamSpec("P") + R = TypeVar("R") + T = TypeVar("T") + Coro = Coroutine[Any, Any, T] + class Status(Enum): """ @@ -141,11 +149,11 @@ class AsyncTaskGroup(_LoopBoundMixin): def __init__(self) -> None: self.closed = False - self._ongoing_tasks: set[asyncio.Task] = set() + self._ongoing_tasks: set[asyncio.Task[None]] = set() def call_soon( - self, afunc: Callable[..., Coroutine], *args, **kwargs - ) -> asyncio.Task: + self, afunc: Callable[P, Coro[None]], /, *args: P.args, **kwargs: P.kwargs + ) -> None: """Schedule a coroutine function to be executed as an `asyncio.Task`. The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments @@ -162,7 +170,7 @@ def call_soon( Returns ------- - The scheduled Task object. + None Raises ------ @@ -176,11 +184,16 @@ def call_soon( task = self._get_loop().create_task(afunc(*args, **kwargs)) task.add_done_callback(self._ongoing_tasks.remove) self._ongoing_tasks.add(task) - return task + return None def call_later( - self, delay: float, afunc: Callable[..., Coroutine], *args, **kwargs - ) -> asyncio.Task: + self, + delay: float, + afunc: Callable[P, Coro[None]], + /, + *args: P.args, + **kwargs: P.kwargs, + ) -> None: """Schedule a coroutine function to be executed after `delay` seconds as an `asyncio.Task`. The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments @@ -199,14 +212,14 @@ def call_later( Returns ------- - The scheduled Task object. + The None Raises ------ AsyncTaskGroupClosedError If the task group is closed. """ - return self.call_soon(delayed(afunc, delay), *args, **kwargs) + self.call_soon(delayed(afunc, delay), *args, **kwargs) def close(self) -> None: """Closes the task group so that no new tasks can be scheduled. diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 4962e02940f..754729810b9 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -82,6 +82,11 @@ def test_async_task_group_initialization(): assert len(group) == 0 +async def _wait_for_n_loop_cycles(n): + for i in range(n): + await asyncio.sleep(0) + + @gen_test() async def test_async_task_group_call_soon_executes_task_in_background(): group = AsyncTaskGroup() @@ -93,11 +98,10 @@ async def set_flag(): await ev.wait() flag = True - task = group.call_soon(set_flag) - assert task is not None + assert group.call_soon(set_flag) is None assert len(group) == 1 ev.set() - await task + await _wait_for_n_loop_cycles(2) assert len(group) == 0 assert flag @@ -112,10 +116,11 @@ async def set_flag(): flag = True start = timemod.monotonic() - task = group.call_later(1, set_flag) - assert task is not None + assert group.call_later(1, set_flag) is None assert len(group) == 1 - await task + # the task must complete in exactly 1 event loop cycle + await asyncio.sleep(1) + await _wait_for_n_loop_cycles(2) end = timemod.monotonic() assert len(group) == 0 assert flag @@ -143,18 +148,16 @@ async def set_flag(): nonlocal flag await ev.wait() flag = True - return True + return None - task = group.call_soon(set_flag) + assert group.call_soon(set_flag) is None group.close() - assert not task.cancelled() assert len(group) == 1 ev.set() - await task - assert task.result() + await _wait_for_n_loop_cycles(2) assert len(group) == 0 @@ -177,7 +180,7 @@ async def set_flag(): assert len(group) == 0 with pytest.raises(AsyncTaskGroupClosedError): - task = group.call_later(1, set_flag) + group.call_later(1, set_flag) assert len(group) == 0 await asyncio.sleep(0.01) @@ -188,39 +191,44 @@ async def set_flag(): async def test_async_task_group_stop_allows_shutdown(): group = AsyncTaskGroup() - flag = False + task = None async def set_flag(): - nonlocal flag + nonlocal task while not group.closed: - asyncio.sleep(0.01) - flag = True - return True + await asyncio.sleep(0.01) + task = asyncio.current_task() + return None - task = group.call_soon(set_flag) + assert group.call_soon(set_flag) is None assert len(group) == 1 + # when given a grace period of 1 second tasks are allowed to poll group.stop + # before awaiting other async functions await group.stop(timeout=1) + assert task.done() assert not task.cancelled() - assert flag - assert task.result() @gen_test() async def test_async_task_group_stop_cancels_long_running(): group = AsyncTaskGroup() + task = None flag = False async def set_flag(): + nonlocal task + task = asyncio.current_task() + await asyncio.sleep(10) nonlocal flag flag = True return True - task = group.call_later(10, set_flag) + assert group.call_soon(set_flag) is None assert len(group) == 1 await group.stop(timeout=1) - assert task.cancelled() assert not flag + assert task.cancelled() @gen_test() diff --git a/distributed/utils.py b/distributed/utils.py index 5353a918934..4d17020e82c 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -75,6 +75,7 @@ P = ParamSpec("P") T = TypeVar("T") + Coro = Coroutine[AnyType, AnyType, T] no_default = "__no_default__" @@ -1731,12 +1732,10 @@ def is_python_shutting_down() -> bool: return _python_shutting_down -def delayed( - corofunc: Callable[P, Coroutine[AnyType, AnyType, T]], delay: float -) -> Callable[P, Coroutine[AnyType, AnyType, T]]: +def delayed(corofunc: Callable[P, Coro[T]], delay: float) -> Callable[P, Coro[T]]: """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" - async def wrapper(*args, **kwargs): + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: await asyncio.sleep(delay) return await corofunc(*args, **kwargs) From 35fc707c83688c538c7350059bc0f7e147500413 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 21 Jun 2022 10:52:09 +0100 Subject: [PATCH 66/75] avoid deprecated loop kwarg --- distributed/cli/dask_worker.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 07165f6fc23..6d585c6c119 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -448,14 +448,11 @@ def del_pid_file(): signal_fired = False async def run(): - loop = IOLoop.current() - nannies = [ t( scheduler, scheduler_file=scheduler_file, nthreads=nthreads, - loop=loop, resources=resources, security=sec, contact_address=contact_address, From 753c6cee239c9510790d1683ac8cd5ca8853b0f5 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 21 Jun 2022 10:58:59 +0100 Subject: [PATCH 67/75] don't log periodic callback starts --- distributed/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/distributed/core.py b/distributed/core.py index 45057fd6b47..63b256c7cc2 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -512,7 +512,6 @@ def start_periodic_callbacks(self): self._last_tick = time() for pc in self.periodic_callbacks.values(): if not pc.is_running(): - logger.info(f"Starting periodic callback {pc!r}") pc.start() def stop(self): From 3902842137deab8ab7acdd310e8271e30c0b498d Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 21 Jun 2022 11:13:21 +0100 Subject: [PATCH 68/75] move d.core.delayed into private d.utils._delayed delayed already has a meaning in dask so we should come up with a new name before making it public --- distributed/core.py | 13 +++++++++++-- distributed/utils.py | 13 +------------ 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 63b256c7cc2..1f43d95debb 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -38,7 +38,6 @@ from distributed.system_monitor import SystemMonitor from distributed.utils import ( NoOpAwaitable, - delayed, get_traceback, has_keyword, iscoroutinefunction, @@ -141,6 +140,16 @@ class AsyncTaskGroupClosedError(RuntimeError): pass +def _delayed(corofunc: Callable[P, Coro[T]], delay: float) -> Callable[P, Coro[T]]: + """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" + + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + await asyncio.sleep(delay) + return await corofunc(*args, **kwargs) + + return wrapper + + class AsyncTaskGroup(_LoopBoundMixin): """Collection tracking all currently running asynchronous tasks within a group""" @@ -219,7 +228,7 @@ def call_later( AsyncTaskGroupClosedError If the task group is closed. """ - self.call_soon(delayed(afunc, delay), *args, **kwargs) + self.call_soon(_delayed(afunc, delay), *args, **kwargs) def close(self) -> None: """Closes the task group so that no new tasks can be scheduled. diff --git a/distributed/utils.py b/distributed/utils.py index 4d17020e82c..848f23c1e45 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -32,7 +32,7 @@ from types import ModuleType from typing import TYPE_CHECKING from typing import Any as AnyType -from typing import Callable, ClassVar, Coroutine, TypeVar, overload +from typing import Callable, ClassVar, TypeVar, overload import click import tblib.pickling_support @@ -75,7 +75,6 @@ P = ParamSpec("P") T = TypeVar("T") - Coro = Coroutine[AnyType, AnyType, T] no_default = "__no_default__" @@ -1730,13 +1729,3 @@ def is_python_shutting_down() -> bool: from distributed import _python_shutting_down return _python_shutting_down - - -def delayed(corofunc: Callable[P, Coro[T]], delay: float) -> Callable[P, Coro[T]]: - """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" - - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - await asyncio.sleep(delay) - return await corofunc(*args, **kwargs) - - return wrapper From 2cf6e9156af544d19bb8167dc23a81dd117b76f1 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 21 Jun 2022 11:14:25 +0100 Subject: [PATCH 69/75] remove unused IOLoop import --- distributed/cli/dask_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 6d585c6c119..1c67e86032e 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -14,7 +14,7 @@ import click from tlz import valmap -from tornado.ioloop import IOLoop, TimeoutError +from tornado.ioloop import TimeoutError import dask from dask.system import CPU_COUNT From e0ae0f97500a8a8768d79820904a2bdb42fc2a58 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 21 Jun 2022 11:15:27 +0100 Subject: [PATCH 70/75] back out cancellation changes to inproc this seems to have some difficult to understand consequences and should be handled in https://github.com/dask/distributed/issues/6548 --- distributed/comm/inproc.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index 97797b554d1..041c9d530a8 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -103,7 +103,7 @@ def get_nowait(self): raise QueueEmpty return q.popleft() - async def get(self): + def get(self): assert not self._read_future, "Only one reader allowed" fut = Future() q = self._q @@ -111,10 +111,7 @@ async def get(self): fut.set_result(q.popleft()) else: self._read_future = fut - try: - return await fut - finally: - self._read_future = None + return fut def put_nowait(self, value): q = self._q From 2714ccc13a24dbe3522e63c27cb1d4d162f15f28 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 21 Jun 2022 13:26:50 +0100 Subject: [PATCH 71/75] fix ERROR - Failed while closing connection to 'inproc://...': invalid state --- distributed/comm/inproc.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index 041c9d530a8..71fd71fba09 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -88,6 +88,13 @@ class QueueEmpty(Exception): pass +def _set_result_unless_cancelled(fut, result): + """Helper setting the result only if the future was not cancelled.""" + if fut.cancelled(): + return + fut.set_result(result) + + class Queue: """ A single-reader, single-writer, non-threadsafe, peekable queue. @@ -119,7 +126,7 @@ def put_nowait(self, value): if fut is not None: assert len(q) == 0 self._read_future = None - fut.set_result(value) + _set_result_unless_cancelled(fut, value) else: q.append(value) From e6a0eeb98f7fc1f99c402117421f151f18d580d8 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 21 Jun 2022 14:48:34 +0100 Subject: [PATCH 72/75] remove whitespace/typing import changes to distributed/utils.py --- distributed/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index 848f23c1e45..7ac9676e336 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -20,7 +20,7 @@ import xml.etree.ElementTree from asyncio import TimeoutError from collections import deque -from collections.abc import Collection, Container, KeysView, ValuesView +from collections.abc import Callable, Collection, Container, KeysView, ValuesView from concurrent.futures import CancelledError, ThreadPoolExecutor # noqa: F401 from contextlib import contextmanager, suppress from contextvars import ContextVar @@ -32,7 +32,7 @@ from types import ModuleType from typing import TYPE_CHECKING from typing import Any as AnyType -from typing import Callable, ClassVar, TypeVar, overload +from typing import ClassVar, TypeVar, overload import click import tblib.pickling_support @@ -76,7 +76,6 @@ P = ParamSpec("P") T = TypeVar("T") - no_default = "__no_default__" _forkserver_preload_set = False From 253694d9eddcf2c35a58ace54f6b18e8cb2b45f8 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 22 Jun 2022 13:44:17 +0100 Subject: [PATCH 73/75] complete type annotation for AsyncTaskGroup.stop --- distributed/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/core.py b/distributed/core.py index aa4520e62e0..c8db5453eca 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -237,7 +237,7 @@ def close(self) -> None: """ self.closed = True - async def stop(self, timeout=1) -> None: + async def stop(self, timeout: float = 1) -> None: """Close the group and stop all currently running tasks. Closes the task group and waits `timeout` seconds for all tasks to gracefully finish. From a8244bde99182ffb7abd6508285d566414f95b22 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 22 Jun 2022 19:03:38 +0100 Subject: [PATCH 74/75] lock close to prevent it cancelling concurrent closes and always set the finished event to prevent hangs --- distributed/core.py | 46 ++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index c8db5453eca..02f6b9eb147 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -332,6 +332,7 @@ def __init__( stacklevel=2, ) + self.__close_lock = asyncio.Lock() self._status = Status.init self.handlers = { "identity": self.identity, @@ -852,27 +853,30 @@ async def handle_stream(self, comm, extra=None): assert comm.closed() async def close(self, timeout=None): - for pc in self.periodic_callbacks.values(): - pc.stop() - - if not self.__stopped: - self.__stopped = True - _stops = set() - for listener in self.listeners: - future = listener.stop() - if inspect.isawaitable(future): - _stops.add(future) - await asyncio.gather(*_stops) - - # TODO: Deal with exceptions - await self._ongoing_background_tasks.stop(timeout=1) - - # TODO: Deal with exceptions - await self._ongoing_comm_handlers.stop(timeout=1) - - await self.rpc.close() - await asyncio.gather(*[comm.close() for comm in list(self._comms)]) - self._event_finished.set() + try: + async with self.__close_lock: + for pc in self.periodic_callbacks.values(): + pc.stop() + + if not self.__stopped: + self.__stopped = True + _stops = set() + for listener in self.listeners: + future = listener.stop() + if inspect.isawaitable(future): + _stops.add(future) + await asyncio.gather(*_stops) + + # TODO: Deal with exceptions + await self._ongoing_background_tasks.stop(timeout=1) + + # TODO: Deal with exceptions + await self._ongoing_comm_handlers.stop(timeout=1) + + await self.rpc.close() + await asyncio.gather(*[comm.close() for comm in list(self._comms)]) + finally: + self._event_finished.set() def pingpong(comm): From d04f551d1111cb74c2448f1cd1a65681069d1515 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 23 Jun 2022 10:47:26 +0100 Subject: [PATCH 75/75] remove the close lock, setting the close event in finally should be enough --- distributed/core.py | 42 ++++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 2e8e48a94df..14b37f4281c 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -332,7 +332,6 @@ def __init__( stacklevel=2, ) - self.__close_lock = asyncio.Lock() self._status = Status.init self.handlers = { "identity": self.identity, @@ -867,27 +866,26 @@ async def handle_stream(self, comm, extra=None): async def close(self, timeout=None): try: - async with self.__close_lock: - for pc in self.periodic_callbacks.values(): - pc.stop() - - if not self.__stopped: - self.__stopped = True - _stops = set() - for listener in self.listeners: - future = listener.stop() - if inspect.isawaitable(future): - _stops.add(future) - await asyncio.gather(*_stops) - - # TODO: Deal with exceptions - await self._ongoing_background_tasks.stop(timeout=1) - - # TODO: Deal with exceptions - await self._ongoing_comm_handlers.stop(timeout=1) - - await self.rpc.close() - await asyncio.gather(*[comm.close() for comm in list(self._comms)]) + for pc in self.periodic_callbacks.values(): + pc.stop() + + if not self.__stopped: + self.__stopped = True + _stops = set() + for listener in self.listeners: + future = listener.stop() + if inspect.isawaitable(future): + _stops.add(future) + await asyncio.gather(*_stops) + + # TODO: Deal with exceptions + await self._ongoing_background_tasks.stop(timeout=1) + + # TODO: Deal with exceptions + await self._ongoing_comm_handlers.stop(timeout=1) + + await self.rpc.close() + await asyncio.gather(*[comm.close() for comm in list(self._comms)]) finally: self._event_finished.set()