diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index f2789562dc2..c4ce4fd64a0 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -649,10 +649,8 @@ def leak(): }, ) async def test_pause_while_spilling(c, s, a): - N = 50 - def get_process_memory(): - if len(a.data) < N: + if len(a.data) < 3: # Don't trigger spilling until after all tasks have completed return 0 elif a.data.fast and not a.data.slow: @@ -665,20 +663,25 @@ def get_process_memory(): a.monitor.get_process_memory = get_process_memory class SlowSpill: - def __init__(self, _, paused: bool = False): - self.paused = paused + def __init__(self, _, sem: distributed.Semaphore): + self.sem = sem + # Block if there are 50 tasks in a.data.fast + sem.acquire() def __reduce__(self): paused = distributed.get_worker().status == Status.paused if not paused: - sleep(0.1) - return SlowSpill, (None, paused) + sleep(0.1) # This is 10x the memory monitor interval + self.sem.release() + return bool, (paused,) - futs = c.map(SlowSpill, range(N)) - while len(a.data.slow) < N: + sem = await distributed.Semaphore(3) + futs = c.map(SlowSpill, range(5), sem=sem) + while len(a.data.slow) < 3: await asyncio.sleep(0.01) assert a.status == Status.paused - assert any(sp.paused for sp in a.data.values()) + assert any(sp is True for sp in a.data.slow.values()) + assert sum(ts.state == "ready" for ts in a.tasks.values()) == 2 @pytest.mark.slow