From 4488144cd52879da2fd43ac23b431180e0f3b19e Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 20 May 2022 01:27:37 +0100 Subject: [PATCH] Revisit tests mocking gather_dep (#6385) --- distributed/tests/test_worker.py | 170 ++++++++++--------------------- 1 file changed, 56 insertions(+), 114 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 6bf2e2fa9b7..0bbcf001ba9 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3187,6 +3187,30 @@ async def test_task_flight_compute_oserror(c, s, a, b): assert_story(sum_story, expected_sum_story, strict=True) +class BlockedGatherDep(Worker): + def __init__(self, *args, **kwargs): + self.in_gather_dep = asyncio.Event() + self.block_gather_dep = asyncio.Event() + super().__init__(*args, **kwargs) + + async def gather_dep(self, *args, **kwargs): + self.in_gather_dep.set() + await self.block_gather_dep.wait() + return await super().gather_dep(*args, **kwargs) + + +class BlockedGetData(Worker): + def __init__(self, *args, **kwargs): + self.in_get_data = asyncio.Event() + self.block_get_data = asyncio.Event() + super().__init__(*args, **kwargs) + + async def get_data(self, comm, *args, **kwargs): + self.in_get_data.set() + await self.block_get_data.wait() + return await super().get_data(comm, *args, **kwargs) + + @gen_cluster(client=True, nthreads=[]) async def test_gather_dep_cancelled_rescheduled(c, s): """At time of writing, the gather_dep implementation filtered tasks again @@ -3206,42 +3230,16 @@ async def test_gather_dep_cancelled_rescheduled(c, s): See also test_gather_dep_do_not_handle_response_of_not_requested_tasks """ - in_gather_dep = asyncio.Event() - gather_dep_finished = asyncio.Event() - block_gather_dep = asyncio.Lock() - await block_gather_dep.acquire() - - class InstrumentedWorker(Worker): - async def gather_dep(self, *args, **kwargs): - in_gather_dep.set() - async with block_gather_dep: - try: - return await super().gather_dep(*args, **kwargs) - finally: - gather_dep_finished.set() - - block_get_data = asyncio.Lock() - in_get_data = asyncio.Event() - - class BlockedGetData(Worker): - async def get_data(self, comm, *args, **kwargs): - in_get_data.set() - async with block_get_data: - return await super().get_data(comm, *args, **kwargs) - async with BlockedGetData(s.address) as a: - async with InstrumentedWorker(s.address) as b: + async with BlockedGatherDep(s.address) as b: fut1 = c.submit(inc, 1, workers=[a.address], key="f1") fut2 = c.submit(inc, fut1, workers=[a.address], key="f2") - await fut2 - await block_get_data.acquire() + await wait(fut2) fut4 = c.submit(sum, fut1, fut2, workers=[b.address], key="f4") fut3 = c.submit(inc, fut1, workers=[b.address], key="f3") - fut2_key = fut2.key - - await _wait_for_state(fut2_key, b, "flight") - await in_gather_dep.wait() + await _wait_for_state(fut2.key, b, "flight") + await b.in_gather_dep.wait() fut4.release() while fut4.key in b.tasks: @@ -3249,19 +3247,16 @@ async def get_data(self, comm, *args, **kwargs): assert b.tasks[fut2.key].state == "cancelled" - block_gather_dep.release() - - await in_get_data.wait() + b.block_gather_dep.set() + await a.in_get_data.wait() fut4 = c.submit(sum, [fut1, fut2], workers=[b.address], key="f4") - while b.tasks[fut2.key].state != "flight": - await asyncio.sleep(0.1) - block_get_data.release() - await gather_dep_finished.wait() + await _wait_for_state(fut2.key, b, "flight") + + a.block_get_data.set() + await wait([fut3, fut4]) f2_story = b.story(fut2.key) assert f2_story - await fut3 - await fut4 @gen_cluster(client=True, nthreads=[("", 1)]) @@ -3272,90 +3267,54 @@ async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a) potential rescheduling or data leaks. This test may become obsolete if the implementation changes significantly. """ - in_gather_dep = asyncio.Event() - gather_dep_finished = asyncio.Event() - block_gather_dep = asyncio.Lock() - await block_gather_dep.acquire() - - class InstrumentedWorker(Worker): - async def gather_dep(self, *args, **kwargs): - in_gather_dep.set() - async with block_gather_dep: - try: - return await super().gather_dep(*args, **kwargs) - finally: - gather_dep_finished.set() - - async with InstrumentedWorker(s.address) as b: + async with BlockedGatherDep(s.address) as b: fut1 = c.submit(inc, 1, workers=[a.address], key="f1") fut2 = c.submit(inc, fut1, workers=[a.address], key="f2") await fut2 fut4 = c.submit(sum, fut1, fut2, workers=[b.address], key="f4") fut3 = c.submit(inc, fut1, workers=[b.address], key="f3") - fut2_key = fut2.key - - await _wait_for_state(fut2_key, b, "flight") - - await in_gather_dep.wait() + await b.in_gather_dep.wait() + assert b.tasks[fut2.key].state == "flight" fut4.release() while fut4.key in b.tasks: - await asyncio.sleep(0.05) + await asyncio.sleep(0.01) assert b.tasks[fut2.key].state == "cancelled" - block_gather_dep.release() - await gather_dep_finished.wait() + b.block_gather_dep.set() + await fut3 assert fut2.key not in b.tasks f2_story = b.story(fut2.key) assert f2_story assert not any("missing-dep" in msg for msg in f2_story) - await fut3 @gen_cluster( client=True, nthreads=[("", 1)], - config={ - "distributed.comm.recent-messages-log-length": 1000, - }, + config={"distributed.comm.recent-messages-log-length": 1000}, ) async def test_gather_dep_no_longer_in_flight_tasks(c, s, a): - in_gather_dep = asyncio.Event() - gather_dep_finished = asyncio.Event() - block_gather_dep = asyncio.Lock() - await block_gather_dep.acquire() - - class InstrumentedWorker(Worker): - async def gather_dep(self, *args, **kwargs): - in_gather_dep.set() - async with block_gather_dep: - try: - return await super().gather_dep(*args, **kwargs) - finally: - gather_dep_finished.set() - - async with InstrumentedWorker(s.address) as b: + async with BlockedGatherDep(s.address) as b: fut1 = c.submit(inc, 1, workers=[a.address], key="f1") fut2 = c.submit(sum, fut1, fut1, workers=[b.address], key="f2") - fut1_key = fut1.key - - await _wait_for_state(fut1_key, b, "flight") - await in_gather_dep.wait() + await _wait_for_state(fut1.key, b, "flight") + await b.in_gather_dep.wait() fut2.release() while fut2.key in b.tasks: - await asyncio.sleep(0) + await asyncio.sleep(0.01) assert b.tasks[fut1.key].state == "cancelled" - block_gather_dep.release() - await gather_dep_finished.wait() + b.block_gather_dep.set() + while fut2.key in b.tasks: + await asyncio.sleep(0.01) - assert fut2.key not in b.tasks f1_story = b.story(fut1.key) f2_story = b.story(fut2.key) assert f1_story @@ -3365,53 +3324,36 @@ async def gather_dep(self, *args, **kwargs): @pytest.mark.parametrize("intermediate_state", ["resumed", "cancelled"]) @pytest.mark.parametrize("close_worker", [False, True]) -@gen_cluster(client=True) +@gen_cluster(client=True, config={"distributed.comm.timeouts.connect": "500ms"}) async def test_deadlock_cancelled_after_inflight_before_gather_from_worker( c, s, a, x, intermediate_state, close_worker ): """If a task was transitioned to in-flight, the gather-dep coroutine was scheduled but a cancel request came in before gather_data_from_worker was issued this might corrupt the state machine if the cancelled key is not - properly handled""" - - in_gather_dep = asyncio.Event() - gather_dep_finished = asyncio.Event() - block_gather_dep = asyncio.Lock() - await block_gather_dep.acquire() - - class InstrumentedWorker(Worker): - async def gather_dep(self, *args, **kwargs): - in_gather_dep.set() - async with block_gather_dep: - try: - return await super().gather_dep(*args, **kwargs) - finally: - gather_dep_finished.set() - + properly handled + """ fut1 = c.submit(slowinc, 1, workers=[a.address], key="f1") fut1B = c.submit(slowinc, 2, workers=[x.address], key="f1B") fut2 = c.submit(sum, [fut1, fut1B], workers=[x.address], key="f2") await fut2 - async with InstrumentedWorker(s.address, name="b") as b: + async with BlockedGatherDep(s.address, name="b") as b: fut3 = c.submit(inc, fut2, workers=[b.address], key="f3") - fut2_key = fut2.key - - await _wait_for_state(fut2_key, b, "flight") + await _wait_for_state(fut2.key, b, "flight") s.set_restrictions(worker={fut1B.key: a.address, fut2.key: b.address}) - await in_gather_dep.wait() + await b.in_gather_dep.wait() await s.remove_worker( address=x.address, safe=True, close=close_worker, stimulus_id="test" ) - await _wait_for_state(fut2_key, b, intermediate_state) + await _wait_for_state(fut2.key, b, intermediate_state) - block_gather_dep.release() - await gather_dep_finished.wait() + b.block_gather_dep.set() await fut3