Refactor all Worker event handlers
crusaderky committed Jun 1, 2022
1 parent 715d7be commit 052828c
Showing 7 changed files with 448 additions and 262 deletions.
54 changes: 22 additions & 32 deletions distributed/scheduler.py
@@ -1088,7 +1088,7 @@ def __init__(self, key: str, run_spec: object):
         self.has_lost_dependencies = False
         self.host_restrictions = None  # type: ignore
         self.worker_restrictions = None  # type: ignore
-        self.resource_restrictions = None  # type: ignore
+        self.resource_restrictions = {}
         self.loose_restrictions = False
         self.actor = False
         self.prefix = None  # type: ignore
@@ -2670,14 +2670,12 @@ def valid_workers(self, ts: TaskState) -> set:  # set[WorkerState] | None
         return s

     def consume_resources(self, ts: TaskState, ws: WorkerState):
-        if ts.resource_restrictions:
-            for r, required in ts.resource_restrictions.items():
-                ws.used_resources[r] += required
+        for r, required in ts.resource_restrictions.items():
+            ws.used_resources[r] += required

     def release_resources(self, ts: TaskState, ws: WorkerState):
-        if ts.resource_restrictions:
-            for r, required in ts.resource_restrictions.items():
-                ws.used_resources[r] -= required
+        for r, required in ts.resource_restrictions.items():
+            ws.used_resources[r] -= required

     def coerce_hostname(self, host):
         """
@@ -7076,29 +7074,28 @@ def adaptive_target(self, target_duration=None):
         to_close = self.workers_to_close()
         return len(self.workers) - len(to_close)

-    def request_acquire_replicas(self, addr: str, keys: list, *, stimulus_id: str):
+    def request_acquire_replicas(
+        self, addr: str, keys: Iterable[str], *, stimulus_id: str
+    ) -> None:
         """Asynchronously ask a worker to acquire a replica of the listed keys from
         other workers. This is a fire-and-forget operation which offers no feedback for
         success or failure, and is intended for housekeeping and not for computation.
         """
-        who_has = {}
-        for key in keys:
-            ts = self.tasks[key]
-            who_has[key] = {ws.address for ws in ts.who_has}
-
+        who_has = {key: [ws.address for ws in self.tasks[key].who_has] for key in keys}
         if self.validate:
             assert all(who_has.values())

         self.stream_comms[addr].send(
             {
                 "op": "acquire-replicas",
                 "keys": keys,
                 "who_has": who_has,
                 "stimulus_id": stimulus_id,
             },
         )

-    def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str):
+    def request_remove_replicas(
+        self, addr: str, keys: list[str], *, stimulus_id: str
+    ) -> None:
         """Asynchronously ask a worker to discard its replica of the listed keys.
         This must never be used to destroy the last replica of a key. This is a
         fire-and-forget operation, intended for housekeeping and not for computation.
@@ -7109,15 +7106,14 @@ def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str):
         to re-add itself to who_has. If the worker agrees to discard the task, there is
         no feedback.
         """
-        ws: WorkerState = self.workers[addr]
-        validate = self.validate
+        ws = self.workers[addr]

         # The scheduler immediately forgets about the replica and suggests the worker to
         # drop it. The worker may refuse, at which point it will send back an add-keys
         # message to reinstate it.
         for key in keys:
-            ts: TaskState = self.tasks[key]
-            if validate:
+            ts = self.tasks[key]
+            if self.validate:
                 # Do not destroy the last copy
                 assert len(ts.who_has) > 1
             self.remove_replica(ts, ws)
@@ -7298,22 +7294,16 @@ def _task_to_msg(
             dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies
         },
         "nbytes": {dts.key: dts.nbytes for dts in ts.dependencies},
+        "run_spec": ts.run_spec,
+        "resource_restrictions": ts.resource_restrictions,
+        "actor": ts.actor,
+        "annotations": ts.annotations,
     }
     if state.validate:
         assert all(msg["who_has"].values())
-
-    if ts.resource_restrictions:
-        msg["resource_restrictions"] = ts.resource_restrictions
-    if ts.actor:
-        msg["actor"] = True
-
-    if isinstance(ts.run_spec, dict):
-        msg.update(ts.run_spec)
-    else:
-        msg["task"] = ts.run_spec
-
-    if ts.annotations:
-        msg["annotations"] = ts.annotations
+        if isinstance(msg["run_spec"], dict):
+            assert set(msg["run_spec"]).issubset({"function", "args", "kwargs"})
+            assert msg["run_spec"].get("function")

     return msg

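The two `request_*` methods above share a fire-and-forget protocol: the scheduler pushes a message down the worker's batched stream and awaits no reply; a worker that disagrees answers later with its own message (e.g. a refused remove-replicas comes back as add-keys). A rough sketch of the message shape, where `request_acquire_replicas_sketch` and the plain callable standing in for `self.stream_comms[addr].send` are invented for illustration and only the payload keys are taken from the diff:

```python
from __future__ import annotations

from typing import Callable, Iterable


def request_acquire_replicas_sketch(
    who_has_lookup: dict[str, set[str]],  # key -> addresses holding a replica
    send: Callable[[dict], None],  # stand-in for stream_comms[addr].send
    keys: Iterable[str],
    *,
    stimulus_id: str,
) -> None:
    keys = list(keys)
    who_has = {key: sorted(who_has_lookup[key]) for key in keys}
    assert all(who_has.values())  # never ask for a key that nobody holds
    # Fire-and-forget: nothing is awaited and no success/failure comes back
    send(
        {
            "op": "acquire-replicas",
            "keys": keys,
            "who_has": who_has,
            "stimulus_id": stimulus_id,
        }
    )


sent: list[dict] = []
request_acquire_replicas_sketch(
    {"x": {"tcp://10.0.0.1:1234"}}, sent.append, ["x"], stimulus_id="demo"
)
assert sent[0]["who_has"] == {"x": ["tcp://10.0.0.1:1234"]}
```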
6 changes: 4 additions & 2 deletions distributed/tests/test_failed_workers.py
@@ -24,7 +24,7 @@
     slowadd,
     slowinc,
 )
-from distributed.worker_state_machine import TaskState
+from distributed.worker_state_machine import FreeKeysEvent, TaskState

 pytestmark = pytest.mark.ci1

@@ -425,7 +425,9 @@ def sink(*args):
     # artificially, without notifying the scheduler.
     # This can only succeed if B handles the missing data properly by
     # removing A from the known sources of keys
-    a.handle_free_keys(keys=["f1"], stimulus_id="Am I evil?")  # Yes, I am!
+    a.handle_stimulus(
+        FreeKeysEvent(keys=["f1"], stimulus_id="Am I evil?")
+    )  # Yes, I am!
     result_fut = c.submit(sink, futures, workers=x.address)

     await result_fut
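The pattern in the updated test, distilled: per-stimulus handler methods like `handle_free_keys` give way to typed event objects funnelled through a single `handle_stimulus` entry point, which makes every stimulus loggable and replayable. The classes below are invented minimal stand-ins, not the `distributed.worker_state_machine` implementations:

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class WorkerEvent:  # invented stand-in for the real event base class
    stimulus_id: str


@dataclass
class FreeKeysEvent(WorkerEvent):
    keys: list[str] = field(default_factory=list)


class WorkerSketch:  # invented stand-in for the worker state machine
    def __init__(self) -> None:
        self.data: dict[str, object] = {"f1": 2}
        self.stimulus_log: list[WorkerEvent] = []

    def handle_stimulus(self, event: WorkerEvent) -> None:
        # One entry point for every external stimulus: log it first, then
        # dispatch on the event type instead of on per-handler methods
        self.stimulus_log.append(event)
        if isinstance(event, FreeKeysEvent):
            for key in event.keys:
                self.data.pop(key, None)


w = WorkerSketch()
w.handle_stimulus(FreeKeysEvent(keys=["f1"], stimulus_id="test"))
assert "f1" not in w.data and len(w.stimulus_log) == 1
```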
3 changes: 2 additions & 1 deletion distributed/tests/test_steal.py
@@ -29,6 +29,7 @@
     slowidentity,
     slowinc,
 )
+from distributed.worker_state_machine import StealRequestEvent

 pytestmark = pytest.mark.ci1

@@ -867,7 +868,7 @@ async def test_dont_steal_already_released(c, s, a, b):
     while key in a.tasks and a.tasks[key].state != "released":
         await asyncio.sleep(0.05)

-    a.handle_steal_request(key=key, stimulus_id="test")
+    a.handle_stimulus(StealRequestEvent(key=key, stimulus_id="test"))
     assert len(a.batched_stream.buffer) == 1
     msg = a.batched_stream.buffer[0]
     assert msg["op"] == "steal-response"
142 changes: 88 additions & 54 deletions distributed/tests/test_worker.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import asyncio
+import gc
 import importlib
 import logging
 import os
@@ -72,6 +73,14 @@
     error_message,
     logger,
 )
+from distributed.worker_state_machine import (
+    ComputeTaskEvent,
+    ExecuteFailureEvent,
+    ExecuteSuccessEvent,
+    RemoveReplicasEvent,
+    SerializedTask,
+    StealRequestEvent,
+)

 pytestmark = pytest.mark.ci1

@@ -1839,31 +1848,67 @@ async def test_story(c, s, w):

 @gen_cluster(client=True, nthreads=[("", 1)])
 async def test_stimulus_story(c, s, a):
+    # Test that substrings aren't matched by stimulus_story()
+    f = c.submit(inc, 1, key="f")
+    f1 = c.submit(lambda: "foo", key="f1")
+    f2 = c.submit(inc, f1, key="f2")  # This will fail
+    await wait([f, f1, f2])
+
+    story = a.stimulus_story("f1", "f2")
+    assert len(story) == 4
+
+    assert isinstance(story[0], ComputeTaskEvent)
+    assert story[0].key == "f1"
+    assert story[0].run_spec == SerializedTask(task=None)  # Not logged
+
+    assert isinstance(story[1], ExecuteSuccessEvent)
+    assert story[1].key == "f1"
+    assert story[1].value is None  # Not logged
+    assert story[1].handled >= story[0].handled
+
+    assert isinstance(story[2], ComputeTaskEvent)
+    assert story[2].key == "f2"
+    assert story[2].who_has == {"f1": (a.address,)}
+    assert story[2].run_spec == SerializedTask(task=None)  # Not logged
+    assert story[2].handled >= story[1].handled
+
+    assert isinstance(story[3], ExecuteFailureEvent)
+    assert story[3].key == "f2"
+    assert story[3].handled >= story[2].handled
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_worker_descopes_data(c, s, a):
+    """Test that data is released on the worker:
+    1. when it's the output of a successful task
+    2. when it's the input of a failed task
+    3. when it's a local variable in the frame of a failed task
+    4. when it's embedded in the exception of a failed task
+    """
+
     class C:
-        pass
+        instances = weakref.WeakSet()

-    f = c.submit(C, key="f")  # Test that substrings aren't matched by story()
-    f2 = c.submit(inc, 2, key="f2")
-    f3 = c.submit(inc, 3, key="f3")
-    await wait([f, f2, f3])
+        def __init__(self):
+            C.instances.add(self)

-    # Test that ExecuteSuccessEvent.value is not stored in the the event log
-    assert isinstance(a.data["f"], C)
-    ref = weakref.ref(a.data["f"])
-    del f
-    while "f" in a.data:
-        await asyncio.sleep(0.01)
-    with profile.lock:
-        assert ref() is None
+    def f(x):
+        y = C()
+        raise Exception(x, y)

+    f1 = c.submit(C, key="f1")
+    f2 = c.submit(f, f1, key="f2")
+    await wait([f2])

-    story = a.stimulus_story("f", "f2")
-    assert {ev.key for ev in story} == {"f", "f2"}
-    assert {ev.type for ev in story} == {C, int}
+    assert type(a.data["f1"]) is C

-    prev_handled = story[0].handled
-    for ev in story[1:]:
-        assert ev.handled >= prev_handled
-        prev_handled = ev.handled
+    del f1
+    del f2
+    while a.data:
+        await asyncio.sleep(0.01)
+    with profile.lock:
+        gc.collect()
+    assert not C.instances


 @gen_cluster(client=True)
@@ -2558,7 +2603,7 @@ def __call__(self, *args, **kwargs):
         await asyncio.sleep(0)

     ts = s.tasks[fut.key]
-    a.handle_steal_request(fut.key, stimulus_id="test")
+    a.handle_stimulus(StealRequestEvent(key=fut.key, stimulus_id="test"))
     stealing_ext.scheduler.send_task_to_worker(b.address, ts)

     fut2 = c.submit(inc, fut, workers=[a.address])
@@ -2744,41 +2789,31 @@ async def test_acquire_replicas_many(c, s, *workers):
         await asyncio.sleep(0.001)


-@pytest.mark.slow
-@gen_cluster(client=True, Worker=Nanny)
-async def test_acquire_replicas_already_in_flight(c, s, *nannies):
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_acquire_replicas_already_in_flight(c, s, a):
     """Trying to acquire a replica that is already in flight is a no-op"""
+    async with BlockedGatherDep(s.address) as b:
+        x = c.submit(inc, 1, workers=[a.address], key="x")
+        y = c.submit(inc, x, workers=[b.address], key="y")
+        await b.in_gather_dep.wait()
+        assert b.tasks["x"].state == "flight"

-    class SlowToFly:
-        def __getstate__(self):
-            sleep(0.9)
-            return {}
+        s.request_acquire_replicas(b.address, ["x"], stimulus_id=f"test-{time()}")
+        while not b.story("acquire-replicas"):
+            await asyncio.sleep(0.01)

-    a, b = s.workers
-    x = c.submit(SlowToFly, workers=[a], key="x")
-    await wait(x)
-    y = c.submit(lambda x: 123, x, workers=[b], key="y")
-    await asyncio.sleep(0.3)
-    s.request_acquire_replicas(b, [x.key], stimulus_id=f"test-{time()}")
-    assert await y == 123
+        assert b.tasks["x"].state == "flight"
+        b.block_gather_dep.set()
+        assert await y == 3

-    story = await c.run(lambda dask_worker: dask_worker.story("x"))
-    assert_story(
-        story[b],
-        [
-            ("x", "ensure-task-exists", "released"),
-            ("x", "released", "fetch", "fetch", {}),
-            ("gather-dependencies", a, {"x"}),
-            ("x", "fetch", "flight", "flight", {}),
-            ("request-dep", a, {"x"}),
-            ("x", "ensure-task-exists", "flight"),
-            ("x", "flight", "fetch", "flight", {}),
-            ("receive-dep", a, {"x"}),
-            ("x", "put-in-memory"),
-            ("x", "flight", "memory", "memory", {"y": "ready"}),
-        ],
-        strict=True,
-    )
+        assert_story(
+            b.story("x"),
+            [
+                ("x", "fetch", "flight", "flight", {}),
+                ("acquire-replicas", {"x"}),
+                ("x", "flight", "fetch", "flight", {}),
+            ],
+        )


@gen_cluster(client=True)
Expand Down Expand Up @@ -2936,8 +2971,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers):
if w.address == a.tasks[f1.key].coming_from:
break

coming_from.handle_remove_replicas([f1.key], "test")

coming_from.handle_stimulus(RemoveReplicasEvent(keys=[f1.key], stimulus_id="test"))
await f2

assert_story(a.story(f1.key), [(f1.key, "missing-dep")])
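The rewritten test_worker_descopes_data above leans on a reusable trick: track instances in a class-level WeakSet, force a gc.collect() to break the reference cycle through the exception's traceback, then assert the set is empty. A distilled, distributed-free sketch of that technique (all names here are illustrative, not from the codebase):

```python
import gc
import weakref


class C:
    # Class-level WeakSet: membership vanishes as soon as an instance dies
    instances: "weakref.WeakSet[C]" = weakref.WeakSet()

    def __init__(self) -> None:
        C.instances.add(self)


def fails(x: C) -> None:
    y = C()  # local variable in the failing frame
    raise Exception(x, y)  # both objects end up inside the exception


try:
    fails(C())
except Exception:
    pass  # leaving the except block drops the exception and its traceback

gc.collect()  # break the exception -> traceback -> frame reference cycle
assert not C.instances  # every instance was released
```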
(Diffs for the remaining 3 changed files are not shown.)
