Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revisit tests mocking gather_dep #6385

Merged
merged 2 commits into from
May 20, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 56 additions & 114 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3186,6 +3186,30 @@ async def test_task_flight_compute_oserror(c, s, a, b):
assert_story(sum_story, expected_sum_story, strict=True)


class BlockedGatherDep(Worker):
def __init__(self, *args, **kwargs):
self.in_gather_dep = asyncio.Event()
self.block_gather_dep = asyncio.Event()
super().__init__(*args, **kwargs)

async def gather_dep(self, *args, **kwargs):
self.in_gather_dep.set()
await self.block_gather_dep.wait()
return await super().gather_dep(*args, **kwargs)


class BlockedGetData(Worker):
def __init__(self, *args, **kwargs):
self.in_get_data = asyncio.Event()
self.block_get_data = asyncio.Event()
super().__init__(*args, **kwargs)

async def get_data(self, comm, *args, **kwargs):
self.in_get_data.set()
await self.block_get_data.wait()
return await super().get_data(comm, *args, **kwargs)


@gen_cluster(client=True, nthreads=[])
async def test_gather_dep_cancelled_rescheduled(c, s):
"""At time of writing, the gather_dep implementation filtered tasks again
Expand All @@ -3205,62 +3229,33 @@ async def test_gather_dep_cancelled_rescheduled(c, s):

See also test_gather_dep_do_not_handle_response_of_not_requested_tasks
"""
in_gather_dep = asyncio.Event()
gather_dep_finished = asyncio.Event()
block_gather_dep = asyncio.Lock()
await block_gather_dep.acquire()

class InstrumentedWorker(Worker):
async def gather_dep(self, *args, **kwargs):
in_gather_dep.set()
async with block_gather_dep:
try:
return await super().gather_dep(*args, **kwargs)
finally:
gather_dep_finished.set()

block_get_data = asyncio.Lock()
in_get_data = asyncio.Event()

class BlockedGetData(Worker):
async def get_data(self, comm, *args, **kwargs):
in_get_data.set()
async with block_get_data:
return await super().get_data(comm, *args, **kwargs)

async with BlockedGetData(s.address) as a:
async with InstrumentedWorker(s.address) as b:
async with BlockedGatherDep(s.address) as b:
fut1 = c.submit(inc, 1, workers=[a.address], key="f1")
fut2 = c.submit(inc, fut1, workers=[a.address], key="f2")
await fut2
await block_get_data.acquire()
await wait(fut2)
fut4 = c.submit(sum, fut1, fut2, workers=[b.address], key="f4")
fut3 = c.submit(inc, fut1, workers=[b.address], key="f3")

fut2_key = fut2.key

await _wait_for_state(fut2_key, b, "flight")
await in_gather_dep.wait()
await _wait_for_state(fut2.key, b, "flight")
await b.in_gather_dep.wait()

fut4.release()
while fut4.key in b.tasks:
await asyncio.sleep(0)

assert b.tasks[fut2.key].state == "cancelled"

block_gather_dep.release()

await in_get_data.wait()
b.block_gather_dep.set()
await a.in_get_data.wait()

fut4 = c.submit(sum, [fut1, fut2], workers=[b.address], key="f4")
while b.tasks[fut2.key].state != "flight":
await asyncio.sleep(0.1)
block_get_data.release()
await gather_dep_finished.wait()
await _wait_for_state(fut2.key, b, "flight")

a.block_get_data.set()
await wait([fut3, fut4])
f2_story = b.story(fut2.key)
assert f2_story
await fut3
await fut4


@gen_cluster(client=True, nthreads=[("", 1)])
Expand All @@ -3271,90 +3266,54 @@ async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a)
potential rescheduling or data leaks.
This test may become obsolete if the implementation changes significantly.
"""
in_gather_dep = asyncio.Event()
gather_dep_finished = asyncio.Event()
block_gather_dep = asyncio.Lock()
await block_gather_dep.acquire()

class InstrumentedWorker(Worker):
async def gather_dep(self, *args, **kwargs):
in_gather_dep.set()
async with block_gather_dep:
try:
return await super().gather_dep(*args, **kwargs)
finally:
gather_dep_finished.set()

async with InstrumentedWorker(s.address) as b:
async with BlockedGatherDep(s.address) as b:
fut1 = c.submit(inc, 1, workers=[a.address], key="f1")
fut2 = c.submit(inc, fut1, workers=[a.address], key="f2")
await fut2
fut4 = c.submit(sum, fut1, fut2, workers=[b.address], key="f4")
fut3 = c.submit(inc, fut1, workers=[b.address], key="f3")

fut2_key = fut2.key

await _wait_for_state(fut2_key, b, "flight")

await in_gather_dep.wait()
await b.in_gather_dep.wait()
assert b.tasks[fut2.key].state == "flight"

fut4.release()
while fut4.key in b.tasks:
await asyncio.sleep(0.05)
await asyncio.sleep(0.01)

assert b.tasks[fut2.key].state == "cancelled"

block_gather_dep.release()
await gather_dep_finished.wait()
b.block_gather_dep.set()

await fut3
assert fut2.key not in b.tasks
f2_story = b.story(fut2.key)
assert f2_story
assert not any("missing-dep" in msg for msg in f2_story)
await fut3


@gen_cluster(
client=True,
nthreads=[("", 1)],
config={
"distributed.comm.recent-messages-log-length": 1000,
},
config={"distributed.comm.recent-messages-log-length": 1000},
)
async def test_gather_dep_no_longer_in_flight_tasks(c, s, a):
in_gather_dep = asyncio.Event()
gather_dep_finished = asyncio.Event()
block_gather_dep = asyncio.Lock()
await block_gather_dep.acquire()

class InstrumentedWorker(Worker):
async def gather_dep(self, *args, **kwargs):
in_gather_dep.set()
async with block_gather_dep:
try:
return await super().gather_dep(*args, **kwargs)
finally:
gather_dep_finished.set()

async with InstrumentedWorker(s.address) as b:
async with BlockedGatherDep(s.address) as b:
fut1 = c.submit(inc, 1, workers=[a.address], key="f1")
fut2 = c.submit(sum, fut1, fut1, workers=[b.address], key="f2")

fut1_key = fut1.key

await _wait_for_state(fut1_key, b, "flight")
await in_gather_dep.wait()
await _wait_for_state(fut1.key, b, "flight")
await b.in_gather_dep.wait()

fut2.release()
while fut2.key in b.tasks:
await asyncio.sleep(0)
await asyncio.sleep(0.01)

assert b.tasks[fut1.key].state == "cancelled"

block_gather_dep.release()
await gather_dep_finished.wait()
b.block_gather_dep.set()
while fut2.key in b.tasks:
await asyncio.sleep(0.01)

assert fut2.key not in b.tasks
f1_story = b.story(fut1.key)
f2_story = b.story(fut2.key)
assert f1_story
Expand All @@ -3364,53 +3323,36 @@ async def gather_dep(self, *args, **kwargs):

@pytest.mark.parametrize("intermediate_state", ["resumed", "cancelled"])
@pytest.mark.parametrize("close_worker", [False, True])
@gen_cluster(client=True)
@gen_cluster(client=True, config={"distributed.comm.timeouts.connect": "500ms"})
async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(
c, s, a, x, intermediate_state, close_worker
):
"""If a task was transitioned to in-flight, the gather-dep coroutine was
scheduled but a cancel request came in before gather_data_from_worker was
issued this might corrupt the state machine if the cancelled key is not
properly handled"""

in_gather_dep = asyncio.Event()
gather_dep_finished = asyncio.Event()
block_gather_dep = asyncio.Lock()
await block_gather_dep.acquire()

class InstrumentedWorker(Worker):
async def gather_dep(self, *args, **kwargs):
in_gather_dep.set()
async with block_gather_dep:
try:
return await super().gather_dep(*args, **kwargs)
finally:
gather_dep_finished.set()

properly handled
"""
fut1 = c.submit(slowinc, 1, workers=[a.address], key="f1")
fut1B = c.submit(slowinc, 2, workers=[x.address], key="f1B")
fut2 = c.submit(sum, [fut1, fut1B], workers=[x.address], key="f2")
await fut2

async with InstrumentedWorker(s.address, name="b") as b:
async with BlockedGatherDep(s.address, name="b") as b:
fut3 = c.submit(inc, fut2, workers=[b.address], key="f3")

fut2_key = fut2.key

await _wait_for_state(fut2_key, b, "flight")
await _wait_for_state(fut2.key, b, "flight")

s.set_restrictions(worker={fut1B.key: a.address, fut2.key: b.address})

await in_gather_dep.wait()
await b.in_gather_dep.wait()

await s.remove_worker(
address=x.address, safe=True, close=close_worker, stimulus_id="test"
)

await _wait_for_state(fut2_key, b, intermediate_state)
await _wait_for_state(fut2.key, b, intermediate_state)

block_gather_dep.release()
await gather_dep_finished.wait()
b.block_gather_dep.set()
await fut3


Expand Down