diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index b68643b844..6a4ec01fe9 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2944,3 +2944,54 @@ async def test_who_has_consistent_remove_replica(c, s, *workers): assert ("missing-dep", f1.key) in a.story(f1.key) assert a.tasks[f1.key].suspicious_count == 0 assert s.tasks[f1.key].suspicious == 0 + + +@gen_cluster(client=True) +async def test_missing_released_zombie_tasks(c, s, a, b): + """ + Ensure that no fetch/flight tasks are left in the task dict of a + worker after everything was released + """ + a.total_in_connections = 0 + f1 = c.submit(inc, 1, key="f1", workers=[a.address]) + f2 = c.submit(inc, f1, key="f2", workers=[b.address]) + key = f1.key + + while key not in b.tasks or b.tasks[key].state != "fetch": + await asyncio.sleep(0.01) + + await a.close(report=False) + + del f1, f2 + + while b.tasks: + await asyncio.sleep(0.01) + + +@gen_cluster(client=True) +async def test_missing_released_zombie_tasks_2(c, s, a, b): + a.total_in_connections = 0 + f1 = c.submit(inc, 1, key="f1", workers=[a.address]) + f2 = c.submit(inc, f1, key="f2", workers=[b.address]) + + while f1.key not in b.tasks: + await asyncio.sleep(0) + + ts = b.tasks[f1.key] + assert ts.state == "fetch" + + # A few things can happen to clear who_has. The dominant process is upon + # connection failure to a worker. Regardless of how the set was cleared, the + # task will be transitioned to missing where the worker is trying to + # reaquire this information from the scheduler. While this is happening on + # worker side, the tasks are released and we want to ensure that no dangling + # zombie tasks are left on the worker + ts.who_has.clear() + + del f1, f2 + + while b.tasks: + await asyncio.sleep(0.01) + + story = b.story(ts) + assert any("missing" in msg for msg in story) diff --git a/distributed/worker.py b/distributed/worker.py index 4e240dfa58..07ed6db0d9 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -191,6 +191,7 @@ def __init__(self, key, runspec=None): self.who_has = set() self.coming_from = None self.waiting_for_data = set() + self.waiters = set() self.resource_restrictions = {} self.exception = None self.exception_text = "" @@ -1824,6 +1825,7 @@ def transition_released_waiting(self, ts, *, stimulus_id): for dep_ts in ts.dependencies: if not dep_ts.state == "memory": ts.waiting_for_data.add(dep_ts) + dep_ts.waiters.add(ts) if ts.waiting_for_data: self.waiting_for_data_count += 1 @@ -2639,19 +2641,6 @@ async def find_missing(self): who_has = {k: v for k, v in who_has.items() if v} self.update_who_has(who_has, stimulus_id=stimulus_id) - if self._missing_dep_flight: - logger.debug( - "No new workers found for %s", self._missing_dep_flight - ) - recommendations = { - dep: "released" - for dep in self._missing_dep_flight - if dep.state == "missing" - } - self.transitions( - recommendations=recommendations, stimulus_id=stimulus_id - ) - finally: # This is quite arbitrary but the heartbeat has scaling implemented self.periodic_callbacks[ @@ -2762,9 +2751,11 @@ def release_key( self.available_resources[resource] += quantity for d in ts.dependencies: - ts.waiting_for_data.discard(ts) - if not d.dependents and d.state in {"flight", "fetch", "missing"}: - recommendations[d] = "released" + ts.waiting_for_data.discard(d) + d.waiters.discard(ts) + + if not d.waiters and d.state in {"flight", "fetch", "missing"}: + recommendations[d] = "forgotten" ts.waiting_for_data.clear() ts.nbytes = None