Skip to content

Commit

Permalink
Worker state machine refactor (#5046)
Browse files Browse the repository at this point in the history
Co-authored-by: crusaderky <crusaderky@gmail.com>
  • Loading branch information
fjetter and crusaderky authored Sep 27, 2021
1 parent e5eb40c commit a8d4ffa
Show file tree
Hide file tree
Showing 12 changed files with 1,508 additions and 1,091 deletions.
18 changes: 0 additions & 18 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,24 +161,6 @@ def transition(self, key, start, finish, **kwargs):
kwargs : More options passed when transitioning
"""

def release_key(self, key, state, cause, reason, report):
"""
Called when the worker releases a task.
Parameters
----------
key : string
state : string
State of the released task.
One of waiting, ready, executing, long-running, memory, error.
cause : string or None
Additional information on what triggered the release of the task.
reason : None
Not used.
report : bool
Whether the worker should report the released task to the scheduler.
"""


class NannyPlugin:
"""Interface to extend the Nanny
Expand Down
75 changes: 62 additions & 13 deletions distributed/diagnostics/tests/test_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ def transition(self, key, start, finish, **kwargs):
{"key": key, "start": start, "finish": finish}
)

def release_key(self, key, state, cause, reason, report):
self.observed_notifications.append({"key": key, "state": state})


@gen_cluster(client=True, nthreads=[])
async def test_create_with_client(c, s):
Expand Down Expand Up @@ -107,11 +104,12 @@ async def test_create_on_construction(c, s, a, b):
@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_normal_task_transitions_called(c, s, w):
expected_notifications = [
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "task", "state": "memory"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand All @@ -127,11 +125,12 @@ def failing(x):
raise Exception()

expected_notifications = [
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "error"},
{"key": "task", "state": "error"},
{"key": "task", "start": "error", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand All @@ -147,11 +146,12 @@ def failing(x):
)
async def test_superseding_task_transitions_called(c, s, w):
expected_notifications = [
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "constrained"},
{"key": "task", "start": "constrained", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "task", "state": "memory"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand All @@ -166,16 +166,18 @@ async def test_dependent_tasks(c, s, w):
dsk = {"dep": 1, "task": (inc, "dep")}

expected_notifications = [
{"key": "dep", "start": "new", "finish": "waiting"},
{"key": "dep", "start": "released", "finish": "waiting"},
{"key": "dep", "start": "waiting", "finish": "ready"},
{"key": "dep", "start": "ready", "finish": "executing"},
{"key": "dep", "start": "executing", "finish": "memory"},
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "dep", "state": "memory"},
{"key": "task", "state": "memory"},
{"key": "dep", "start": "memory", "finish": "released"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
{"key": "dep", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand Down Expand Up @@ -203,6 +205,53 @@ class MyCustomPlugin(WorkerPlugin):
assert next(iter(w.plugins)).startswith("MyCustomPlugin-")


def test_release_key_deprecated():
    """An overridden ``release_key`` hook is still invoked, but using the
    plugin raises a ``DeprecationWarning``."""

    class ReleaseKeyDeprecated(WorkerPlugin):
        def __init__(self):
            # Flipped by release_key; verified again in teardown so the test
            # fails if the hook never fired
            self._called = False

        def release_key(self, key, state, cause, reason, report):
            # Ensure that the handler still works
            self._called = True
            assert state == "memory"
            assert key == "task"

        def teardown(self, worker):
            # Fails the cluster test if release_key was never invoked
            assert self._called
            return super().teardown(worker)

    @gen_cluster(client=True, nthreads=[("", 1)])
    async def test(c, s, a):

        await c.register_worker_plugin(ReleaseKeyDeprecated())
        fut = await c.submit(inc, 1, key="task")
        assert fut == 2

    # NOTE(review): "depreacted" is misspelled; it presumably mirrors the
    # warning text emitted by the implementation — fix both sides together.
    with pytest.deprecated_call(
        match="The `WorkerPlugin.release_key` hook is depreacted"
    ):
        test()


def test_assert_no_warning_no_overload():
    """Assert we do not receive a deprecation warning if we do not overload any
    methods
    """
    import warnings

    class Dummy(WorkerPlugin):
        pass

    @gen_cluster(client=True, nthreads=[("", 1)])
    async def test(c, s, a):

        await c.register_worker_plugin(Dummy())
        fut = await c.submit(inc, 1, key="task")
        assert fut == 2

    # ``pytest.warns(None)`` only *recorded* warnings — it never asserted their
    # absence — and it is deprecated (an error in pytest 7).  Escalating
    # warnings to errors actually enforces what the docstring claims.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        test()


@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_WorkerPlugin_overwrite(c, s, w):
class MyCustomPlugin(WorkerPlugin):
Expand Down
34 changes: 19 additions & 15 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3001,7 +3001,11 @@ def transition_processing_released(self, key):
w: str = _remove_from_processing(self, ts)
if w:
worker_msgs[w] = [
{"op": "free-keys", "keys": [key], "reason": "Processing->Released"}
{
"op": "free-keys",
"keys": [key],
"reason": f"processing-released-{time()}",
}
]

ts.state = "released"
Expand Down Expand Up @@ -5367,7 +5371,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
self.log.append(("missing", key, errant_worker))

ts: TaskState = parent._tasks.get(key)
if ts is None or not ts._who_has:
if ts is None:
return
ws: WorkerState = parent._workers_dv.get(errant_worker)
if ws is not None and ws in ts._who_has:
Expand All @@ -5380,17 +5384,14 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
else:
self.transitions({key: "forgotten"})

def release_worker_data(self, comm=None, keys=None, worker=None):
def release_worker_data(self, comm=None, key=None, worker=None):
parent: SchedulerState = cast(SchedulerState, self)
ws: WorkerState = parent._workers_dv.get(worker)
if not ws:
ts: TaskState = parent._tasks.get(key)
if not ws or not ts:
return
tasks: set = {parent._tasks[k] for k in keys if k in parent._tasks}
removed_tasks: set = tasks.intersection(ws._has_what)

ts: TaskState
recommendations: dict = {}
for ts in removed_tasks:
if ts in ws._has_what:
del ws._has_what[ts]
ws._nbytes -= ts.get_nbytes()
wh: set = ts._who_has
Expand Down Expand Up @@ -6709,7 +6710,7 @@ def add_keys(self, comm=None, worker=None, keys=()):
if worker not in parent._workers_dv:
return "not found"
ws: WorkerState = parent._workers_dv[worker]
superfluous_data = []
redundant_replicas = []
for key in keys:
ts: TaskState = parent._tasks.get(key)
if ts is not None and ts._state == "memory":
Expand All @@ -6718,14 +6719,15 @@ def add_keys(self, comm=None, worker=None, keys=()):
ws._has_what[ts] = None
ts._who_has.add(ws)
else:
superfluous_data.append(key)
if superfluous_data:
redundant_replicas.append(key)

if redundant_replicas:
self.worker_send(
worker,
{
"op": "superfluous-data",
"keys": superfluous_data,
"reason": f"Add keys which are not in-memory {superfluous_data}",
"op": "remove-replicas",
"keys": redundant_replicas,
"stimulus_id": f"redundant-replicas-{time()}",
},
)

Expand Down Expand Up @@ -7867,6 +7869,8 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->
"key": ts._key,
"priority": ts._priority,
"duration": duration,
"stimulus_id": f"compute-task-{time()}",
"who_has": {},
}
if ts._resource_restrictions:
msg["resource_restrictions"] = ts._resource_restrictions
Expand Down
10 changes: 9 additions & 1 deletion distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,15 @@ async def move_task_confirm(self, key=None, worker=None, state=None):
return

# Victim had already started execution, reverse stealing
if state in ("memory", "executing", "long-running", None):
if state in (
"memory",
"executing",
"long-running",
"released",
"cancelled",
"resumed",
None,
):
self.log(("already-computing", key, victim.address, thief.address))
self.scheduler.check_idle_saturated(thief)
self.scheduler.check_idle_saturated(victim)
Expand Down
133 changes: 133 additions & 0 deletions distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import asyncio
from unittest import mock

import distributed
from distributed.core import CommClosedError
from distributed.utils_test import _LockedCommPool, gen_cluster, inc, slowinc


async def wait_for_state(key, state, dask_worker):
    """Poll until task *key* exists on *dask_worker* and is in *state*."""
    while True:
        ts = dask_worker.tasks.get(key)
        if ts is not None and ts.state == state:
            return
        await asyncio.sleep(0.005)


async def wait_for_cancelled(key, dask_worker):
    """Poll until task *key* reaches state ``cancelled``.

    Fails the calling test if the task disappears from the worker before
    ever being observed as cancelled.
    """
    while True:
        if key not in dask_worker.tasks:
            # Task vanished without passing through "cancelled" — test failure
            assert False
        if dask_worker.tasks[key].state == "cancelled":
            return
        await asyncio.sleep(0.005)


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_abort_execution_release(c, s, a):
    # Start a slow task, then release it mid-execution; the worker should
    # move the task into "cancelled" rather than dropping it outright.
    future = c.submit(slowinc, 1, delay=1)
    await wait_for_state(future.key, "executing", a)
    future.release()
    await wait_for_cancelled(future.key, a)


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_abort_execution_reschedule(c, s, a):
    # Cancel a task while it is executing ...
    future = c.submit(slowinc, 1, delay=1)
    await wait_for_state(future.key, "executing", a)
    future.release()
    await wait_for_cancelled(future.key, a)
    # ... then resubmit the same computation; it must still complete
    future = c.submit(slowinc, 1, delay=0.1)
    await future


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_abort_execution_add_as_dependency(c, s, a):
    # Cancel a task while it is executing on the only worker
    future = c.submit(slowinc, 1, delay=1)
    await wait_for_state(future.key, "executing", a)
    future.release()
    await wait_for_cancelled(future.key, a)

    # The cancelled key may immediately reappear as a dependency of new work
    future = c.submit(slowinc, 1, delay=1)
    future = c.submit(slowinc, future, delay=1)
    await future


@gen_cluster(client=True)
async def test_abort_execution_to_fetch(c, s, a, b):
    # Cancel f1 while worker `a` is still computing it
    f = c.submit(slowinc, 1, delay=2, key="f1", workers=[a.worker_address])
    await wait_for_state(f.key, "executing", a)
    f.release()
    await wait_for_cancelled(f.key, a)

    # Resubmit the same key to the other worker with a smaller delay.  Worker
    # `a` then has to flip f1 from execute to fetch; instead of doing so, it
    # may simply reuse the result it is already computing.
    f = c.submit(inc, 1, key="f1", workers=[b.worker_address])
    f = c.submit(inc, f, workers=[a.worker_address], key="f2")
    await f


@gen_cluster(client=True)
async def test_worker_find_missing(c, s, a, b):
    future = c.submit(inc, 1, workers=[a.address])
    await future
    # Drop the result behind the cluster's back: bypassing the proper API
    # deliberately leaves the scheduler with stale who-has information
    del a.data[future.key]
    del a.tasks[future.key]

    # No worker actually holds the data; the scheduler must recompute it
    assert await c.submit(inc, future, workers=[b.address]) == 3


@gen_cluster(client=True)
async def test_worker_stream_died_during_comm(c, s, a, b):
    """A peer dying mid-transfer is logged and the computation still recovers."""
    # Intercept worker b's outgoing comms so we can stall its fetch request
    write_queue = asyncio.Queue()
    write_event = asyncio.Event()
    b.rpc = _LockedCommPool(
        b.rpc,
        write_queue=write_queue,
        write_event=write_event,
    )
    fut = c.submit(inc, 1, workers=[a.address], allow_other_workers=True)
    await fut
    # b needs fut's result, which only a holds, so b must fetch it from a
    res = c.submit(inc, fut, workers=[b.address])

    # Wait until b attempts the transfer, kill a, then release the request
    await write_queue.get()
    await a.close()
    write_event.set()

    # allow_other_workers=True lets the scheduler recompute fut elsewhere
    await res
    assert any("receive-dep-failed" in msg for msg in b.log)


@gen_cluster(client=True)
async def test_flight_to_executing_via_cancelled_resumed(c, s, a, b):
    """A task in flight on b whose source worker dies passes through
    cancelled/resumed and is eventually recomputed successfully."""
    # Held lock stalls every data-fetch attempt until we release it below
    lock = asyncio.Lock()
    await lock.acquire()

    async def wait_and_raise(*args, **kwargs):
        # Blocks on the lock, then simulates the peer's comm breaking
        async with lock:
            raise CommClosedError()

    with mock.patch.object(
        distributed.worker,
        "get_data_from_worker",
        side_effect=wait_and_raise,
    ):
        fut1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True)
        fut2 = c.submit(inc, fut1, workers=[b.address])

        # b starts fetching fut1 from a; the patched fetch keeps it in flight
        await wait_for_state(fut1.key, "flight", b)

        # Close in scheduler to ensure we transition and reschedule task properly
        await s.close_worker(worker=a.address)
        await wait_for_state(fut1.key, "resumed", b)

    # Unblock the stalled fetch so the failure propagates and recovery runs
    lock.release()
    assert await fut2 == 3

    # The worker's story must show the full flight -> cancelled -> resumed arc
    b_story = b.story(fut1.key)
    assert any("receive-dep-failed" in msg for msg in b_story)
    assert any("missing-dep" in msg for msg in b_story)
    assert any("cancelled" in msg for msg in b_story)
    assert any("resumed" in msg for msg in b_story)
Loading

0 comments on commit a8d4ffa

Please sign in to comment.