Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Worker state machine refactor #5046

Merged
merged 8 commits into from
Sep 27, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions distributed/cfexecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def map(self, fn, *iterables, **kwargs):
raise TypeError("unexpected arguments to map(): %s" % sorted(kwargs))

fs = self._client.map(fn, *iterables, **self._kwargs)
if isinstance(fs, list):
# Below iterator relies on this being a generator to cancel
# remaining futures
fs = (val for val in fs)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This chunk should disappear after merging from main


# Yield must be hidden in closure so that the tasks are submitted
# before the first iterator value is required.
Expand Down
18 changes: 0 additions & 18 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,24 +157,6 @@ def transition(self, key, start, finish, **kwargs):
kwargs : More options passed when transitioning
"""

def release_key(self, key, state, cause, reason, report):
"""
Called when the worker releases a task.

Parameters
----------
key : string
state : string
State of the released task.
One of waiting, ready, executing, long-running, memory, error.
cause : string or None
Additional information on what triggered the release of the task.
reason : None
Not used.
report : bool
Whether the worker should report the released task to the scheduler.
"""


class NannyPlugin:
"""Interface to extend the Nanny
Expand Down
75 changes: 62 additions & 13 deletions distributed/diagnostics/tests/test_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ def transition(self, key, start, finish, **kwargs):
{"key": key, "start": start, "finish": finish}
)

def release_key(self, key, state, cause, reason, report):
self.observed_notifications.append({"key": key, "state": state})


@gen_cluster(client=True, nthreads=[])
async def test_create_with_client(c, s):
Expand Down Expand Up @@ -107,11 +104,12 @@ async def test_create_on_construction(c, s, a, b):
@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_normal_task_transitions_called(c, s, w):
expected_notifications = [
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "task", "state": "memory"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand All @@ -127,11 +125,12 @@ def failing(x):
raise Exception()

expected_notifications = [
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "error"},
{"key": "task", "state": "error"},
{"key": "task", "start": "error", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand All @@ -147,11 +146,12 @@ def failing(x):
)
async def test_superseding_task_transitions_called(c, s, w):
expected_notifications = [
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "constrained"},
{"key": "task", "start": "constrained", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "task", "state": "memory"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand All @@ -166,16 +166,18 @@ async def test_dependent_tasks(c, s, w):
dsk = {"dep": 1, "task": (inc, "dep")}

expected_notifications = [
{"key": "dep", "start": "new", "finish": "waiting"},
{"key": "dep", "start": "released", "finish": "waiting"},
{"key": "dep", "start": "waiting", "finish": "ready"},
{"key": "dep", "start": "ready", "finish": "executing"},
{"key": "dep", "start": "executing", "finish": "memory"},
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "dep", "state": "memory"},
{"key": "task", "state": "memory"},
{"key": "dep", "start": "memory", "finish": "released"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
{"key": "dep", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
Expand Down Expand Up @@ -219,3 +221,50 @@ class MyCustomPlugin(WorkerPlugin):
await c.register_worker_plugin(MyCustomPlugin())
assert len(w.plugins) == 1
assert next(iter(w.plugins)).startswith("MyCustomPlugin-")


def test_release_key_deprecated():
class ReleaseKeyDeprecated(WorkerPlugin):
def __init__(self):
self._called = False

def release_key(self, key, state, cause, reason, report):
# Ensure that the handler still works
self._called = True
assert state == "memory"
assert key == "task"

def teardown(self, worker):
assert self._called
return super().teardown(worker)

@gen_cluster(client=True, nthreads=[("", 1)])
async def test(c, s, a):

await c.register_worker_plugin(ReleaseKeyDeprecated())
fut = await c.submit(inc, 1, key="task")
assert fut == 2

with pytest.deprecated_call(
match="The `WorkerPlugin.release_key` hook is depreacted"
):
test()


def test_assert_no_warning_no_overload():
"""Assert we do not receive a deprecation warning if we do not overload any
methods
"""

class Dummy(WorkerPlugin):
pass

@gen_cluster(client=True, nthreads=[("", 1)])
async def test(c, s, a):

await c.register_worker_plugin(Dummy())
fut = await c.submit(inc, 1, key="task")
assert fut == 2

with pytest.warns(None):
test()
37 changes: 23 additions & 14 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2229,7 +2229,7 @@ def _transition(self, key, finish: str, *args, **kwargs):
self._transition_counter += 1
recommendations, client_msgs, worker_msgs = a
elif "released" not in start_finish:
assert not args and not kwargs
assert not args and not kwargs, start_finish
a_recs: dict
a_cmsgs: dict
a_wmsgs: dict
Expand Down Expand Up @@ -3048,7 +3048,11 @@ def transition_processing_released(self, key):
w: str = _remove_from_processing(self, ts)
if w:
worker_msgs[w] = [
{"op": "free-keys", "keys": [key], "reason": "Processing->Released"}
{
"op": "free-keys",
"keys": [key],
"reason": f"processing-released-{time()}",
}
]

ts.state = "released"
Expand Down Expand Up @@ -5398,7 +5402,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
self.log.append(("missing", key, errant_worker))

ts: TaskState = parent._tasks.get(key)
if ts is None or not ts._who_has:
if ts is None:
return
ws: WorkerState = parent._workers_dv.get(errant_worker)
if ws is not None and ws in ts._who_has:
Expand All @@ -5411,15 +5415,15 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
else:
self.transitions({key: "forgotten"})

def release_worker_data(self, comm=None, keys=None, worker=None):
def release_worker_data(self, comm=None, key=None, worker=None):
parent: SchedulerState = cast(SchedulerState, self)
if worker not in parent._workers_dv:
return
ws: WorkerState = parent._workers_dv[worker]
tasks: set = {parent._tasks[k] for k in keys if k in parent._tasks}
removed_tasks: set = tasks.intersection(ws._has_what)

ts: TaskState
ts = parent._tasks.get(key)
recommendations: dict = {}
for ts in removed_tasks:
if ts and ts in ws._has_what:
del ws._has_what[ts]
ws._nbytes -= ts.get_nbytes()
wh: set = ts._who_has
Expand Down Expand Up @@ -6670,7 +6674,7 @@ def add_keys(self, comm=None, worker=None, keys=()):
if worker not in parent._workers_dv:
return "not found"
ws: WorkerState = parent._workers_dv[worker]
superfluous_data = []
redundant_replicas = []
for key in keys:
ts: TaskState = parent._tasks.get(key)
if ts is not None and ts._state == "memory":
Expand All @@ -6679,14 +6683,15 @@ def add_keys(self, comm=None, worker=None, keys=()):
ws._has_what[ts] = None
ts._who_has.add(ws)
else:
superfluous_data.append(key)
if superfluous_data:
redundant_replicas.append(key)

if redundant_replicas:
self.worker_send(
worker,
{
"op": "superfluous-data",
"keys": superfluous_data,
"reason": f"Add keys which are not in-memory {superfluous_data}",
"op": "remove-replicas",
"keys": redundant_replicas,
"stimulus_id": f"redundant-replicas-{time()}",
},
)

Expand Down Expand Up @@ -7794,6 +7799,8 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->
"key": ts._key,
"priority": ts._priority,
"duration": duration,
"stimulus_id": f"compute-task-{time()}",
"who_has": {},
}
if ts._resource_restrictions:
msg["resource_restrictions"] = ts._resource_restrictions
Expand All @@ -7818,6 +7825,8 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->

if ts._annotations:
msg["annotations"] = ts._annotations

assert "stimulus_id" in msg
return msg


Expand Down
12 changes: 10 additions & 2 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,15 @@ async def move_task_confirm(self, key=None, worker=None, state=None):
return

# Victim had already started execution, reverse stealing
if state in ("memory", "executing", "long-running", None):
if state in (
"memory",
"executing",
"long-running",
"released",
"cancelled",
"resumed",
None,
):
self.log(("already-computing", key, victim.address, thief.address))
self.scheduler.check_idle_saturated(thief)
self.scheduler.check_idle_saturated(victim)
Expand All @@ -256,7 +264,7 @@ async def move_task_confirm(self, key=None, worker=None, state=None):
await self.scheduler.remove_worker(thief.address)
self.log(("confirm", key, victim.address, thief.address))
else:
raise ValueError("Unexpected task state: %s" % state)
raise ValueError(f"Unexpected task state: {ts}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise ValueError(f"Unexpected task state: {ts}")
raise ValueError(f"Unexpected task state: {state}")

except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand Down
Loading