diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index bcb5512e4f..ca57c8e9e1 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2411,7 +2411,13 @@ async def test_hold_on_to_replicas(c, s, *workers): await asyncio.sleep(0.01) -@gen_cluster(client=True) +@gen_cluster( + client=True, + nthreads=[ + ("", 1), + ("", 1), + ], +) async def test_worker_reconnects_mid_compute(c, s, a, b): """Ensure that, if a worker disconnects while computing a result, the scheduler will still accept the result. @@ -2479,7 +2485,13 @@ def fast_on_a(lock): await asyncio.sleep(0.001) -@gen_cluster(client=True) +@gen_cluster( + client=True, + nthreads=[ + ("", 1), + ("", 1), + ], +) async def test_worker_reconnects_mid_compute_multiple_states_on_scheduler(c, s, a, b): """ Ensure that a reconnecting worker does not break the scheduler regardless of @@ -2494,6 +2506,7 @@ async def test_worker_reconnects_mid_compute_multiple_states_on_scheduler(c, s, # different states f1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True) f2 = c.submit(inc, f1, workers=[a.address], allow_other_workers=True) + await f1 a_address = a.address a.periodic_callbacks["heartbeat"].stop() @@ -2522,14 +2535,29 @@ def fast_on_a(lock): while f3.key not in a.tasks: await asyncio.sleep(0.01) + story_before = s.story(f1.key) await s.stream_comms[a.address].close() + # Release f1 which triggers a release cycle of all tasks such that + # they are rescheduled on B. However, at this time, B will never be + # able to compute f3 / fast_on_a since it is locked on that worker. + # The only way to get f3 to complete is for Worker A to reconnect. + f1.release() assert len(s.workers) == 1 - while s.tasks[f1.key].state != "released": - await asyncio.sleep(0) + story = s.story(f1.key) + while len(story) == len(story_before): + story = s.story(f1.key) + await asyncio.sleep(0.1) + + next = story[len(story_before)] + assert next[:3] == (f1.key, "memory", "released") + a.heartbeat_active = False await a.heartbeat() - assert len(s.workers) == 2 + + while len(s.workers) != 2: + await asyncio.sleep(0.01) + # Since B is locked, this is ensured to originate from A await f3