Skip to content

Commit

Permalink
WSMR/update_who_has
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed May 16, 2022
1 parent 5ca7a5a commit 96f06dc
Show file tree
Hide file tree
Showing 5 changed files with 270 additions and 69 deletions.
11 changes: 4 additions & 7 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7108,23 +7108,20 @@ def adaptive_target(self, target_duration=None):
to_close = self.workers_to_close()
return len(self.workers) - len(to_close)

def request_acquire_replicas(self, addr: str, keys: list, *, stimulus_id: str):
def request_acquire_replicas(
self, addr: str, keys: Iterable[str], *, stimulus_id: str
):
"""Asynchronously ask a worker to acquire a replica of the listed keys from
other workers. This is a fire-and-forget operation which offers no feedback for
success or failure, and is intended for housekeeping and not for computation.
"""
who_has = {}
for key in keys:
ts = self.tasks[key]
who_has[key] = {ws.address for ws in ts.who_has}

who_has = {key: {ws.address for ws in self.tasks[key].who_has} for key in keys}
if self.validate:
assert all(who_has.values())

self.stream_comms[addr].send(
{
"op": "acquire-replicas",
"keys": keys,
"who_has": who_has,
"stimulus_id": stimulus_id,
},
Expand Down
1 change: 1 addition & 0 deletions distributed/tests/test_stories.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ async def test_worker_story_with_deps(c, s, a, b):
assert stimulus_ids == {"compute-task"}
expected = [
("dep", "ensure-task-exists", "released"),
("dep", "update-who-has", [], [a.address]),
("dep", "released", "fetch", "fetch", {}),
("gather-dependencies", a.address, {"dep"}),
("dep", "fetch", "flight", "flight", {}),
Expand Down
1 change: 1 addition & 0 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2814,6 +2814,7 @@ def __getstate__(self):
story[b],
[
("x", "ensure-task-exists", "released"),
("x", "update-who-has", [], [a]),
("x", "released", "fetch", "fetch", {}),
("gather-dependencies", a, {"x"}),
("x", "fetch", "flight", "flight", {}),
Expand Down
207 changes: 180 additions & 27 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import asyncio
from contextlib import contextmanager
from itertools import chain

import pytest

from distributed import Worker
from distributed.core import Status
from distributed.protocol.serialize import Serialize
from distributed.utils import recursive_to_dict
from distributed.utils_test import assert_story, gen_cluster, inc
Expand All @@ -16,12 +19,13 @@
SendMessageToScheduler,
StateMachineEvent,
TaskState,
TaskStateState,
UniqueTaskHeap,
merge_recs_instructions,
)


async def wait_for_state(key, state, dask_worker):
async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -> None:
while key not in dask_worker.tasks or dask_worker.tasks[key].state != state:
await asyncio.sleep(0.005)

Expand Down Expand Up @@ -245,28 +249,39 @@ def test_executefailure_to_dict():
assert ev3.traceback_text == "tb text"


@gen_cluster(client=True)
async def test_fetch_to_compute(c, s, a, b):
# Block ensure_communicating to ensure we indeed know that the task is in
# fetch and doesn't leave it accidentally
old_out_connections, b.total_out_connections = b.total_out_connections, 0
old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0
@contextmanager
def freeze_inbound_comms(w: Worker):
"""Prevent any task from transitioning from fetch to flight on the worker while
inside the context.
f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
f2 = c.submit(inc, f1, workers=[b.address], key="f2")
This is not the same as setting the worker to Status=paused, which would also
inform the Scheduler and prevent further tasks to be enqueued on the worker.
"""
old_out_connections = w.total_out_connections
old_comm_threshold = w.comm_threshold_bytes
w.total_out_connections = 0
w.comm_threshold_bytes = 0
yield
w.total_out_connections = old_out_connections
w.comm_threshold_bytes = old_comm_threshold
# Jump-start ensure_communicating
w.status = Status.paused
w.status = Status.running

await wait_for_state(f1.key, "fetch", b)
await a.close()

b.total_out_connections = old_out_connections
b.comm_threshold_bytes = old_comm_threshold
@gen_cluster(client=True)
async def test_fetch_to_compute(c, s, a, b):
with freeze_inbound_comms(b):
f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
f2 = c.submit(inc, f1, workers=[b.address], key="f2")
await wait_for_state(f1.key, "fetch", b)
await a.close()

await f2

assert_story(
b.log,
# FIXME: This log should be replaced with an
# StateMachineEvent/Instruction log
# FIXME: This log should be replaced with a StateMachineEvent log
[
(f2.key, "compute-task", "released"),
# This is a "please fetch" request. We don't have anything like
Expand All @@ -285,20 +300,158 @@ async def test_fetch_to_compute(c, s, a, b):

@gen_cluster(client=True)
async def test_fetch_via_amm_to_compute(c, s, a, b):
# Block ensure_communicating to ensure we indeed know that the task is in
# fetch and doesn't leave it accidentally
old_out_connections, b.total_out_connections = b.total_out_connections, 0
old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0

f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
with freeze_inbound_comms(b):
f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
await f1
s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")
await wait_for_state(f1.key, "fetch", b)
await a.close()

await f1
s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")

await wait_for_state(f1.key, "fetch", b)
await a.close()
assert_story(
b.log,
# FIXME: This log should be replaced with a StateMachineEvent log
[
(f1.key, "ensure-task-exists", "released"),
(f1.key, "released", "fetch", "fetch", {}),
(f1.key, "compute-task", "fetch"),
(f1.key, "put-in-memory"),
],
)

b.total_out_connections = old_out_connections
b.comm_threshold_bytes = old_comm_threshold

await f1
@pytest.mark.parametrize("as_deps", [False, True])
@gen_cluster(client=True, nthreads=[("", 1)] * 3)
async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps):
"""
as_deps=True
0. task x is a dependency of y1 and y2
1. scheduler calls handle_compute("y1", who_has={"x": [w2, w3]}) on w1
2. x transitions released -> fetch
3. the network stack is busy, so x does not transition to flight yet.
4. scheduler calls handle_compute("y2", who_has={"x": [w3]}) on w1
5. when x finally reaches the top of the data_needed heap, the w1 will not try
contacting w2
as_deps=False
1. scheduler calls handle_acquire_replicas(who_has={"x": [w2, w3]}) on w1
2. x transitions released -> fetch
3. the network stack is busy, so x does not transition to flight yet.
4. scheduler calls handle_acquire_replicas(who_has={"x": [w3]}) on w1
5. when x finally reaches the top of the data_needed heap, the w1 will not try
contacting w2
"""
x = (await c.scatter({"x": 1}, workers=[w2.address, w3.address], broadcast=True))[
"x"
]
with freeze_inbound_comms(w1):
if as_deps:
y1 = c.submit(inc, x, key="y1", workers=[w1.address])
else:
s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")

await wait_for_state("x", "fetch", w1)
assert w1.tasks["x"].who_has == {w2.address, w3.address}

assert len(s.tasks["x"].who_has) == 2
s.handle_missing_data(
key="x", worker="na", errant_worker=w2.address, stimulus_id="test"
)
assert len(s.tasks["x"].who_has) == 1

if as_deps:
y2 = c.submit(inc, x, key="y2", workers=[w1.address])
else:
s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")

while w1.tasks["x"].who_has != {w3.address}:
await asyncio.sleep(0.01)

await wait_for_state("x", "memory", w1)

assert_story(
w1.story("request-dep"),
[("request-dep", w3.address, {"x"})],
# This tests that there has been no attempt to contact w2.
# If the assumption being tested breaks, this will fail 50% of the times.
strict=True,
)


@gen_cluster(client=True, nthreads=[("", 1)] * 2)
async def test_fetch_to_missing(c, s, a, b):
"""
1. task x is a dependency of y
2. scheduler calls handle_compute("y", who_has={"x": [b]}) on a
3. x transitions released -> fetch -> flight; a connects to b
4. b responds it's busy. x transitions flight -> fetch
5. The busy state triggers an RPC call to Scheduler.who_has
6. the scheduler responds {"x": []}, because w1 in the meantime has lost the key.
7. x is transitioned fetch -> missing
"""
x = (await c.scatter({"x": 1}, workers=[b.address]))["x"]
b.total_in_connections = 0
with freeze_inbound_comms(a):
y = c.submit(inc, x, key="y", workers=[a.address])
await wait_for_state("x", "fetch", a)
# Do not use handle_missing_data, since it would cause the scheduler to call
# handle_free_keys(["y"]) on a
s.remove_replica(ts=s.tasks["x"], ws=s.workers[b.address])
# We used a scheduler internal call, thus corrupting its state.
# Don't crash at the end of the test.
s.validate = False

await wait_for_state("x", "missing", a)
assert_story(
a.story("x"),
[
("x", "ensure-task-exists", "released"),
("x", "update-who-has", [], [b.address]),
("x", "released", "fetch", "fetch", {}),
("gather-dependencies", b.address, {"x"}),
("x", "fetch", "flight", "flight", {}),
("request-dep", b.address, {"x"}),
("busy-gather", b.address, {"x"}),
("x", "flight", "fetch", "fetch", {}),
("x", "update-who-has", [b.address], []), # Called Scheduler.who_has
("x", "fetch", "missing", "missing", {}),
],
# There may be a round of find_missing() after this
strict=False,
)


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_self_denounce_missing_data(c, s, a):
x = c.submit(inc, 1, key="x")
await x

# Manually wipe x from the worker
# (Using an endpoint like remove-replicas would inform the scheduler).
del a.data["x"]
del a.tasks["x"]
a.validate_state()

y = c.submit(inc, x, key="y")
# The scheduler tries computing y, but a responds that x is not available.
# The scheduler kicks off the computation of x and then y from scratch.
assert await y == 3

assert_story(
a.log,
[
# Omitting uninteresting events
("y", "compute-task", "released"),
("x", "ensure-task-exists", "released"),
("x", "released", "fetch", "released", {"x": "missing"}),
("x", "released", "missing", "missing", {}),
("y", "release-key"),
("x", "release-key"),
("x", "compute-task", "released"),
("x", "executing", "memory", "memory", {}),
("y", "compute-task", "released"),
("x", "ensure-task-exists", "memory"),
("y", "executing", "memory", "memory", {}),
],
)
Loading

0 comments on commit 96f06dc

Please sign in to comment.