From 4f7b8e9d15e8f64b46b4e227d18a3504b06e1aa2 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 27 Apr 2022 12:11:38 +0100 Subject: [PATCH] revisit test --- distributed/tests/test_worker_memory.py | 29 ++++++++++++++++++------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index f2789562dc2..01fea447cbf 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -2,6 +2,7 @@ import asyncio import logging +import threading from collections import Counter, UserDict from time import sleep @@ -649,10 +650,11 @@ def leak(): }, ) async def test_pause_while_spilling(c, s, a): - N = 50 + N_PAUSE = 3 + N_TOTAL = 5 def get_process_memory(): - if len(a.data) < N: + if len(a.data) < N_PAUSE: # Don't trigger spilling until after all tasks have completed return 0 elif a.data.fast and not a.data.slow: @@ -665,20 +667,31 @@ def get_process_memory(): a.monitor.get_process_memory = get_process_memory class SlowSpill: - def __init__(self, _, paused: bool = False): - self.paused = paused + def __init__(self, _): + # Can't pickle a Semaphore, so instead of a default value, we create it + # here. Don't worry about race conditions; the worker is single-threaded. + if not hasattr(type(self), "sem"): + type(self).sem = threading.Semaphore(N_PAUSE) + # Block if there are N_PAUSE tasks in a.data.fast + self.sem.acquire() def __reduce__(self): paused = distributed.get_worker().status == Status.paused if not paused: sleep(0.1) - return SlowSpill, (None, paused) + self.sem.release() + return bool, (paused,) - futs = c.map(SlowSpill, range(N)) - while len(a.data.slow) < N: + futs = c.map(SlowSpill, range(N_TOTAL)) + while len(a.data.slow) < N_PAUSE + 1: await asyncio.sleep(0.01) + assert a.status == Status.paused - assert any(sp.paused for sp in a.data.values()) + # Worker should have become paused after the first `SlowSpill` was evicted, because + # the spill to disk took longer than the memory monitor interval. + assert len(a.data.fast) == 0 + assert len(a.data.slow) == N_PAUSE + 1 + assert sum(paused is True for paused in a.data.slow.values()) == N_PAUSE @pytest.mark.slow