remove server close background task grace period #6633

Merged 8 commits on Jul 1, 2022
Changes from 5 commits
64 changes: 28 additions & 36 deletions distributed/core.py
@@ -237,36 +237,26 @@ def close(self) -> None:
         """
         self.closed = True
 
-    async def stop(self, timeout: float = 1) -> None:
+    async def stop(self) -> None:
         """Close the group and stop all currently running tasks.
 
-        Closes the task group and waits `timeout` seconds for all tasks to gracefully finish.
-        After the timeout, all remaining tasks are cancelled.
+        Closes the task group and cancels all tasks. All tasks are cancelled
+        an additional time for each time this task is cancelled.
         """
         self.close()
 
         current_task = asyncio.current_task(self._get_loop())
-        tasks_to_stop = [t for t in self._ongoing_tasks if t is not current_task]
-
-        if tasks_to_stop:
-            # Wrap gather in task to avoid Python3.8 issue,
-            # see https://github.com/dask/distributed/pull/6478#discussion_r885696827
-            async def gather():
-                return await asyncio.gather(*tasks_to_stop, return_exceptions=True)
-
-            try:
-                await asyncio.wait_for(
-                    gather(),
-                    timeout,
-                )
-            except asyncio.TimeoutError:
-                # The timeout on gather has cancelled the tasks, so this will not hang indefinitely
-                await asyncio.gather(*tasks_to_stop, return_exceptions=True)
-
-        if [t for t in self._ongoing_tasks if t is not current_task]:
-            raise RuntimeError(
-                f"Expected all ongoing tasks to be cancelled and removed, found {self._ongoing_tasks}."
-            )
+        err = None
+        while tasks_to_stop := (self._ongoing_tasks - {current_task}):
+            for task in tasks_to_stop:
+                task.cancel()
+            try:
+                await asyncio.wait(tasks_to_stop)
+            except asyncio.CancelledError as e:
+                err = e
+
+        if err is not None:
+            raise err
Collaborator:
What's the reasoning behind calling task.cancel() multiple times?

Member Author:
If the current task is cancelled and any child task suppresses the cancellation, then the child tasks will leak from the task group.
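The leak described above can be demonstrated in isolation. The sketch below is standalone asyncio code, not from this PR (`stubborn` is a made-up coroutine): a child task swallows the first CancelledError, so a single round of cancellation leaves it running, and a second cancel is needed, which is what the new `while` loop in `AsyncTaskGroup.stop` accounts for.

```python
import asyncio


async def stubborn():
    try:
        await asyncio.sleep(3600)
    except asyncio.CancelledError:
        pass  # suppress the first cancellation
    # Back to sleep: the task is still alive and would leak from its group.
    await asyncio.sleep(3600)


async def main():
    task = asyncio.create_task(stubborn())
    await asyncio.sleep(0)  # let the task start and block in sleep()

    task.cancel()
    await asyncio.wait({task}, timeout=0.1)
    print("done after one cancel:", task.done())  # False: it leaked

    task.cancel()  # cancel again, as the new stop() loop does
    await asyncio.wait({task})
    print("done after two cancels:", task.done())  # True


asyncio.run(main())
```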


    def __len__(self):
        return len(self._ongoing_tasks)

@@ -359,7 +349,6 @@ def __init__(
         self.counters = None
         self.digests = None
         self._ongoing_background_tasks = AsyncTaskGroup()
-        self._ongoing_comm_handlers = AsyncTaskGroup()
         self._event_finished = asyncio.Event()
 
         self.listeners = []
@@ -523,17 +512,22 @@ def start_periodic_callbacks(self):
         pc.start()
 
     def stop(self):
-        if not self.__stopped:
-            self.__stopped = True
+        if self.__stopped:
+            return
 
-            for listener in self.listeners:
+        self.__stopped = True
+        _stops = set()
+        for listener in self.listeners:
+            future = listener.stop()
+            if inspect.isawaitable(future):
+                _stops.add(future)
 
-                async def stop_listener(listener):
-                    v = listener.stop()
-                    if inspect.isawaitable(v):
-                        await v
+        if _stops:
 
-                self._ongoing_background_tasks.call_soon(stop_listener, listener)
+            async def background_stops():
+                await asyncio.gather(*_stops)
+
+            self._ongoing_background_tasks.call_soon(background_stops)

Comment on lines +522 to +523

Member:
IIUC you introduce this merely to ensure that there are no users outside that provide a listener with an async stop? This smells like we should introduce a deprecation warning and get rid of this.

Member Author (@graingert, Jun 27, 2022):
Yep, it's tricky to work out where to put the deprecation, as right at server stop is a bit late.

Member:
> yep, it's tricky to work out where to put the deprecation, as right at server stop is a bit late.

Well, better than nothing. If anybody is using this, they'll see it in their CI then. We're also not in a rush to remove this, so we can let the warning sit for a while.
Member (@hendrikmakait, Jun 27, 2022):
Should we add a comment/warning here to highlight that these are likely to get cancelled due to the lack of a grace period? Alternatively, we could add a small grace period back in that would allow fast tasks to finish but have less of a performance impact; for example, 100 ms should result in ~5% performance degradation, 20 ms in ~1%.

Member:
I added a PendingDeprecationWarning to https://github.com/dask/distributed/pull/6633/files#r907345989

I have no idea who would implement this asynchronously. I think the deprecation warning there should be sufficient.

Member Author:
I think a regular DeprecationWarning would be better; otherwise you have to go via PendingDeprecationWarning -> DeprecationWarning -> removal.
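The deprecation discussed here could take roughly the following shape. This is a hypothetical sketch, not the code merged in this PR: `stop_listeners` and `LegacyListener` are made-up names. The idea is to warn whenever a listener's `stop()` returns an awaitable, while still awaiting it for backwards compatibility.

```python
import asyncio
import inspect
import warnings


class LegacyListener:
    # Hypothetical listener that still implements the deprecated async stop().
    async def stop(self):
        await asyncio.sleep(0)


def stop_listeners(listeners):
    """Collect awaitables from listener.stop(), warning about async stops."""
    stops = set()
    for listener in listeners:
        future = listener.stop()
        if inspect.isawaitable(future):
            warnings.warn(
                f"{type(listener).__name__}.stop() returned an awaitable; "
                "asynchronous listener.stop() is deprecated",
                DeprecationWarning,
            )
            stops.add(future)
    return stops


async def main():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        stops = stop_listeners([LegacyListener()])
        await asyncio.gather(*stops)
    print(len(caught), caught[0].category.__name__)  # 1 DeprecationWarning


asyncio.run(main())
```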


     @property
     def listener(self):

@@ -874,13 +868,11 @@ async def close(self, timeout=None):
             future = listener.stop()
             if inspect.isawaitable(future):
                 _stops.add(future)
-        await asyncio.gather(*_stops)
-
-        # TODO: Deal with exceptions
-        await self._ongoing_background_tasks.stop(timeout=1)
+        if _stops:
+            await asyncio.gather(*_stops)
 
         # TODO: Deal with exceptions
-        await self._ongoing_comm_handlers.stop(timeout=1)
+        await self._ongoing_background_tasks.stop()
 
         await self.rpc.close()
         await asyncio.gather(*[comm.close() for comm in list(self._comms)])
13 changes: 8 additions & 5 deletions distributed/tests/test_cancelled_state.py
@@ -156,8 +156,10 @@ def blockable_compute(x, lock):
         await block_compute.acquire()
 
         # Close in scheduler to ensure we transition and reschedule task properly
-        await s.close_worker(worker=a.address, stimulus_id="test")
-        await wait_for_state(fut1.key, "resumed", b)
+        await asyncio.gather(
+            wait_for_state(fut1.key, "resumed", b, interval=0),
+            s.close_worker(worker=a.address, stimulus_id="test"),
+        )
 
         block_get_data.release()
         await block_compute.release()
@@ -415,9 +417,10 @@ async def get_data(self, comm, *args, **kwargs):
                 f3.key: {w2.address},
             }
         )
-        await s.remove_worker(w1.address, stimulus_id="stim-id")
-
-        await wait_for_state(f3.key, "resumed", w2)
+        await asyncio.gather(
+            wait_for_state(f3.key, "resumed", w2, interval=0),
+            s.remove_worker(w1.address, stimulus_id="stim-id"),
+        )
         assert_story(
             w2.state.log,
             [
31 changes: 16 additions & 15 deletions distributed/tests/test_core.py
@@ -40,7 +40,6 @@
     assert_can_connect_locally_4,
     assert_can_connect_locally_6,
     assert_cannot_connect,
-    async_wait_for,
     captured_logger,
     gen_cluster,
     gen_test,
@@ -190,25 +189,21 @@ async def set_flag():
 
 
 @gen_test()
-async def test_async_task_group_stop_allows_shutdown():
+async def test_async_task_group_stop_disallows_shutdown():
     group = AsyncTaskGroup()
 
     task = None
 
     async def set_flag():
         nonlocal task
-        while not group.closed:
-            await asyncio.sleep(0.01)
         task = asyncio.current_task()
         return None
 
     assert group.call_soon(set_flag) is None
     assert len(group) == 1
-    # when given a grace period of 1 second tasks are allowed to poll group.stop
-    # before awaiting other async functions
-    await group.stop(timeout=1)
-    assert task.done()
-    assert not task.cancelled()
+    # tasks are not given a grace period, and are not even allowed to start
+    # if the group is closed immediately
+    await group.stop()
+    assert task is None


@gen_test()
@@ -217,20 +212,24 @@ async def test_async_task_group_stop_cancels_long_running():
 
     task = None
     flag = False
+    started = asyncio.Event()
 
     async def set_flag():
         nonlocal task
         task = asyncio.current_task()
+        started.set()
         await asyncio.sleep(10)
         nonlocal flag
         flag = True
         return True
 
     assert group.call_soon(set_flag) is None
     assert len(group) == 1
-    await group.stop(timeout=1)
-    assert not flag
+    await started.wait()
+    await group.stop()
+    assert task
+    assert task.cancelled()
+    assert not flag

@gen_test()
Expand Down Expand Up @@ -1065,9 +1064,12 @@ async def test_close_properly():
GH4704
"""

sleep_started = asyncio.Event()

async def sleep(comm=None):
# We want to ensure this is actually canceled therefore don't give it a
# chance to actually complete
sleep_started.set()
await asyncio.sleep(2000000)

server = await Server({"sleep": sleep})
@@ -1087,8 +1089,7 @@ async def sleep(comm=None):
 
     comm = await remote.live_comm()
     await comm.write({"op": "sleep"})
-
-    await async_wait_for(lambda: not server._ongoing_comm_handlers, 10)
+    await sleep_started.wait()
 
     listeners = server.listeners
     assert len(listeners) == len(ports)
@@ -1102,7 +1103,7 @@ async def sleep(comm=None):
         await assert_cannot_connect(f"tcp://{ip}:{port}")
 
     # weakref set/dict should be cleaned up
-    assert not len(server._ongoing_comm_handlers)
+    assert not len(server._ongoing_background_tasks)


@gen_test()