Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove wrong assert in handle compute #6370

Merged
merged 3 commits into from
May 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import distributed
from distributed import Event, Lock, Worker
from distributed.client import wait
from distributed.utils_test import (
_LockedCommPool,
assert_story,
Expand Down Expand Up @@ -396,3 +397,59 @@ def block_execution(event, lock):
await lock_executing.release()

assert await fut2 == 2


@gen_cluster(client=True, nthreads=[("", 1)] * 2)
async def test_cancelled_resumed_after_flight_with_dependencies(c, s, w2, w3):
# See https://github.com/dask/distributed/pull/6327#discussion_r872231090
block_get_data_1 = asyncio.Lock()
enter_get_data_1 = asyncio.Event()
await block_get_data_1.acquire()

class BlockGetDataWorker(Worker):
    """Worker subclass whose ``get_data`` signals entry and then waits on a lock.

    The event is set as soon as ``get_data`` is entered; the request is only
    forwarded to the parent implementation once the lock can be acquired.
    """

    def __init__(self, *args, get_data_event, get_data_lock, **kwargs):
        """Store the synchronisation primitives, then initialise the worker.

        ``get_data_event`` and ``get_data_lock`` are keyword-only; all other
        arguments are passed through unchanged to the parent constructor.
        """
        self._get_data_event, self._get_data_lock = get_data_event, get_data_lock
        super().__init__(*args, **kwargs)

    async def get_data(self, comm, *args, **kwargs):
        """Flag that ``get_data`` was entered, then serve it under the lock."""
        # Let the caller observe that a get_data request has arrived.
        self._get_data_event.set()
        # Hold here until the lock is available, then delegate to the parent.
        async with self._get_data_lock:
            response = await super().get_data(comm, *args, **kwargs)
        return response

async with await BlockGetDataWorker(
s.address,
get_data_event=enter_get_data_1,
get_data_lock=block_get_data_1,
name="w1",
) as w1:

f1 = c.submit(inc, 1, key="f1", workers=[w1.address])
f2 = c.submit(inc, 2, key="f2", workers=[w1.address])
f3 = c.submit(sum, [f1, f2], key="f3", workers=[w1.address])

await wait(f3)
f4 = c.submit(inc, f3, key="f4", workers=[w2.address])

await enter_get_data_1.wait()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Post #6371, you can:

Suggested change:
- await enter_get_data_1.wait()
+ await wait_for_state(f1.key, "flight", w2)
+ await wait_for_state(f2.key, "flight", w2)

  1. get rid of the BlockGetDataWorker subclass
  2. initialise the worker with gen_cluster
  3. use
         event = asyncio.Event()
         w1.rpc = _LockedCommPool(w1.rpc, write_event=event)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually like my version with the worker subclass better. The _LockedCommPool requires much more low-level knowledge and is more brittle, in my opinion. I think it should only be used if nothing else is possible.

s.set_restrictions(
{
f1.key: {w3.address},
f2.key: {w3.address},
f3.key: {w2.address},
}
)
await s.remove_worker(w1.address, "stim-id")

await wait_for_state(f3.key, "resumed", w2)
assert_story(
w2.log,
[
(f3.key, "flight", "released", "cancelled", {}),
# ...
(f3.key, "cancelled", "waiting", "resumed", {}),
],
)
# w1 closed

assert await f4 == 6
12 changes: 4 additions & 8 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,7 @@ def __init__(

# FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117
pc = PeriodicCallback(self.find_missing, 1000) # type: ignore
self._find_missing_running = False
self.periodic_callbacks["find-missing"] = pc

self._address = contact_address
Expand Down Expand Up @@ -1983,13 +1984,6 @@ def handle_compute_task(
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

if self.validate:
# All previously unknown tasks that were created above by
# ensure_tasks_exists() have been transitioned to fetch or flight
assert all(
ts2.state != "released" for ts2 in (ts, *ts.dependencies)
), self.story(ts, *ts.dependencies)

########################
# Worker State Machine #
########################
Expand Down Expand Up @@ -3432,9 +3426,10 @@ def _readd_busy_worker(self, worker: str) -> None:

@log_errors
async def find_missing(self) -> None:
if not self._missing_dep_flight:
if self._find_missing_running or not self._missing_dep_flight:
return
try:
self._find_missing_running = True
if self.validate:
for ts in self._missing_dep_flight:
assert not ts.who_has
Expand All @@ -3452,6 +3447,7 @@ async def find_missing(self) -> None:
self.transitions(recommendations, stimulus_id=stimulus_id)

finally:
self._find_missing_running = False
# This is quite arbitrary but the heartbeat has scaling implemented
self.periodic_callbacks[
"find-missing"
Expand Down