Commit: Remove EnsureCommunicatingAfterTransitions (dask#6462)

crusaderky committed Jun 23, 2022
1 parent dc019ed commit 7c40e1b
Showing 2 changed files with 105 additions and 122 deletions.
42 changes: 36 additions & 6 deletions distributed/tests/test_worker_state_machine.py
@@ -28,6 +28,7 @@
    ComputeTaskEvent,
    ExecuteFailureEvent,
    ExecuteSuccessEvent,
+    GatherDep,
    Instruction,
    RecommendationsConflict,
    RefreshWhoHasEvent,
@@ -104,14 +105,16 @@ def test_WorkerState__to_dict(ws):
"busy_workers": [],
"constrained": [],
"data": {"y": None},
"data_needed": ["x"],
"data_needed_per_worker": {"127.0.0.1:1235": ["x"]},
"data_needed": [],
"data_needed_per_worker": {"127.0.0.1:1235": []},
"executing": [],
"in_flight_tasks": [],
"in_flight_workers": {},
"in_flight_tasks": ["x"],
"in_flight_workers": {"127.0.0.1:1235": ["x"]},
"log": [
["x", "ensure-task-exists", "released", "s1"],
["x", "released", "fetch", "fetch", {}, "s1"],
["gather-dependencies", "127.0.0.1:1235", ["x"], "s1"],
["x", "fetch", "flight", "flight", {}, "s1"],
["y", "put-in-memory", "s2"],
["y", "receive-from-scatter", "s2"],
],
@@ -135,10 +138,11 @@ def test_WorkerState__to_dict(ws):
        ],
        "tasks": {
            "x": {
+                "coming_from": "127.0.0.1:1235",
                "key": "x",
                "nbytes": 123,
                "priority": [1],
-                "state": "fetch",
+                "state": "flight",
                "who_has": ["127.0.0.1:1235"],
            },
            "y": {
@@ -147,7 +151,7 @@ def test_WorkerState__to_dict(ws):
"state": "memory",
},
},
"transition_counter": 1,
"transition_counter": 2,
}
assert actual == expect

@@ -817,3 +821,29 @@ async def test_deprecated_worker_attributes(s, a, b):
    )
    with pytest.warns(FutureWarning, match=msg):
        assert a.in_flight_tasks == 0
+
+
+@pytest.mark.parametrize(
+    "nbytes,n_in_flight",
+    [
+        # Note: target_message_size = 50e6 bytes
+        (int(10e6), 3),
+        (int(20e6), 2),
+        (int(30e6), 1),
+    ],
+)
+def test_aggregate_gather_deps(ws, nbytes, n_in_flight):
+    instructions = ws.handle_stimulus(
+        AcquireReplicasEvent(
+            who_has={
+                "x1": ["127.0.0.1:1235"],
+                "x2": ["127.0.0.1:1235"],
+                "x3": ["127.0.0.1:1235"],
+            },
+            nbytes={"x1": nbytes, "x2": nbytes, "x3": nbytes},
+            stimulus_id="test",
+        )
+    )
+    assert len(instructions) == 1
+    assert isinstance(instructions[0], GatherDep)
+    assert len(ws.in_flight_tasks) == n_in_flight
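
The parametrization in the new test reflects the batching rule in _select_keys_for_gather: keys queued for the same worker are folded into a single GatherDep instruction until the next key would push the message past target_message_size (50e6 bytes). A minimal sketch of that arithmetic, assuming equally-sized keys; the helper name is hypothetical, and the real code operates on TaskState objects and per-worker queues rather than a plain list of sizes:

def n_keys_in_first_gather(sizes: list[int], target: int = int(50e6)) -> int:
    """How many keys fit in the first GatherDep message (illustrative only)."""
    total = n = 0
    for nbytes in sizes:
        if n > 0 and total + nbytes > target:
            break  # the next key would exceed the target; it is fetched later
        total += nbytes
        n += 1
    return n

assert n_keys_in_first_gather([int(10e6)] * 3) == 3  # 30e6 <= 50e6
assert n_keys_in_first_gather([int(20e6)] * 3) == 2  # 40e6 <= 50e6 < 60e6
assert n_keys_in_first_gather([int(30e6)] * 3) == 1  # 30e6 <= 50e6 < 60e6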
185 changes: 69 additions & 116 deletions distributed/worker_state_machine.py
@@ -343,12 +343,6 @@ class RetryBusyWorkerLater(Instruction):
    worker: str


-@dataclass
-class EnsureCommunicatingAfterTransitions(Instruction):
-    __slots__ = ()
-
-
@dataclass
class SendMessageToScheduler(Instruction):
    #: Matches a key in Scheduler.stream_handlers
    op: ClassVar[str]
@@ -1489,13 +1483,7 @@ def _transition_generic_fetch(self, ts: TaskState, stimulus_id: str) -> RecsInstrs:
        self.data_needed.add(ts)
        for w in ts.who_has:
            self.data_needed_per_worker[w].add(ts)
-
-        # This is the same as `return self._ensure_communicating()`, except that when
-        # many tasks transition to fetch at the same time, e.g. from a single
-        # compute-task or acquire-replicas command from the scheduler, it allows
-        # clustering the transfers into less GatherDep instructions; see
-        # _select_keys_for_gather().
-        return {}, [EnsureCommunicatingAfterTransitions(stimulus_id=stimulus_id)]
+        return {}, []

    def _transition_missing_waiting(
        self, ts: TaskState, *, stimulus_id: str
@@ -2276,18 +2264,30 @@ def _transitions(self, recommendations: Recs, *, stimulus_id: str) -> Instructions:
        reach a steady state
        """
        instructions = []
-
-        remaining_recs = recommendations.copy()
        tasks = set()
-        while remaining_recs:
-            ts, finish = remaining_recs.popitem()
-            tasks.add(ts)
-            a_recs, a_instructions = self._transition(
-                ts, finish, stimulus_id=stimulus_id
-            )

-            remaining_recs.update(a_recs)
-            instructions += a_instructions
+        def process_recs(recs: Recs) -> None:
+            while recs:
+                ts, finish = recs.popitem()
+                tasks.add(ts)
+                a_recs, a_instructions = self._transition(
+                    ts, finish, stimulus_id=stimulus_id
+                )
+                recs.update(a_recs)
+                instructions.extend(a_instructions)
+
+        process_recs(recommendations.copy())
+
+        # We could call _ensure_communicating after we change something that could
+        # trigger a new call to gather_dep (e.g. on transitions to fetch,
+        # GatherDepDoneEvent, or RetryBusyWorkerEvent). However, doing so we'd
+        # potentially call it too early, before all tasks have transitioned to fetch.
+        # This in turn would hurt aggregation of multiple tasks into a single GatherDep
+        # instruction.
+        # Read: https://github.com/dask/distributed/issues/6497
+        a_recs, a_instructions = self._ensure_communicating(stimulus_id=stimulus_id)
+        instructions += a_instructions
+        process_recs(a_recs)

        if self.validate:
            # Full state validation is very expensive
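
The new comment above captures the design decision behind this commit: recommendations are drained to a steady state first, and _ensure_communicating then gets a single chance to batch every pending fetch. A toy model of that two-phase control flow, with simplified stand-in types (the real method dispatches through self._transition and mutates WorkerState):

from typing import Callable

Recs = dict[str, str]  # simplified: task key -> recommended next state

def transitions(
    recommendations: Recs,
    transition: Callable[[str, str], Recs],
    ensure_communicating: Callable[[], Recs],
) -> None:
    def process_recs(recs: Recs) -> None:
        # Drain to a fixpoint; each transition may recommend further transitions.
        while recs:
            key, finish = recs.popitem()
            recs.update(transition(key, finish))

    # Phase 1: e.g. a single acquire-replicas event sends many tasks to "fetch".
    process_recs(recommendations.copy())
    # Phase 2: only now pick workers and batch keys into GatherDep instructions;
    # running this any earlier would split one large transfer into many small ones.
    process_recs(ensure_communicating())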
@@ -2531,10 +2531,7 @@ def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs:
                self.has_what[ev.worker].discard(ts.key)
                recommendations[ts] = "fetch"

-        return merge_recs_instructions(
-            (recommendations, []),
-            self._ensure_communicating(stimulus_id=ev.stimulus_id),
-        )
+        return recommendations, []

    @_handle_event.register
    def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs:
@@ -2563,10 +2560,7 @@ def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs:
                    )
                )

-        return merge_recs_instructions(
-            (recommendations, instructions),
-            self._ensure_communicating(stimulus_id=ev.stimulus_id),
-        )
+        return recommendations, instructions

    @_handle_event.register
    def _handle_gather_dep_network_failure(
@@ -2593,10 +2587,7 @@ def _handle_gather_dep_network_failure(
                self.log.append((ts.key, "missing-dep", ev.stimulus_id, time()))
                recommendations[ts] = "fetch"

-        return merge_recs_instructions(
-            (recommendations, []),
-            self._ensure_communicating(stimulus_id=ev.stimulus_id),
-        )
+        return recommendations, []

    @_handle_event.register
    def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
@@ -2614,10 +2605,7 @@ def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
            for ts in self._gather_dep_done_common(ev)
        }

-        return merge_recs_instructions(
-            (recommendations, []),
-            self._ensure_communicating(stimulus_id=ev.stimulus_id),
-        )
+        return recommendations, []

    @_handle_event.register
    def _handle_secede(self, ev: SecedeEvent) -> RecsInstrs:
@@ -2657,15 +2645,12 @@ def _handle_pause(self, ev: PauseEvent) -> RecsInstrs:
    def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs:
        """Emerge from paused status"""
        self.running = True
-        return merge_recs_instructions(
-            self._ensure_computing(),
-            self._ensure_communicating(stimulus_id=ev.stimulus_id),
-        )
+        return self._ensure_computing()

    @_handle_event.register
    def _handle_retry_busy_worker(self, ev: RetryBusyWorkerEvent) -> RecsInstrs:
        self.busy_workers.discard(ev.worker)
-        return self._ensure_communicating(stimulus_id=ev.stimulus_id)
+        return {}, []

    @_handle_event.register
    def _handle_cancel_compute(self, ev: CancelComputeEvent) -> RecsInstrs:
@@ -2767,17 +2752,13 @@ def _handle_refresh_who_has(self, ev: RefreshWhoHasEvent) -> RecsInstrs:

if ts.who_has and ts.state == "missing":
recommendations[ts] = "fetch"
elif ts.who_has and ts.state == "fetch":
# We potentially just acquired new replicas whereas all previously known
# workers are in flight or busy. We're deliberately not testing the
# minute use cases here for the sake of simplicity; instead we rely on
# _ensure_communicating to be a no-op when there's nothing to do.
recommendations, instructions = merge_recs_instructions(
(recommendations, instructions),
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)
elif not ts.who_has and ts.state == "fetch":
recommendations[ts] = "missing"
# Note: if ts.who_has and ts.state == "fetch", we may have just acquired new
# replicas whereas all previously known workers are in flight or busy. We
# rely on _transitions to call _ensure_communicating every time, even in
# absence of recommendations, to potentially kick off a new call to
# gather_dep.

return recommendations, instructions

@@ -3058,73 +3039,45 @@ def handle_stimulus(self, *stims: StateMachineEvent) -> None:
"""
instructions = self.state.handle_stimulus(*stims)

while instructions:
ensure_communicating: EnsureCommunicatingAfterTransitions | None = None
for inst in instructions:
task: asyncio.Task | None = None

if isinstance(inst, SendMessageToScheduler):
self.batched_send(inst.to_dict())

elif isinstance(inst, EnsureCommunicatingAfterTransitions):
# A single compute-task or acquire-replicas command may cause
# multiple tasks to transition to fetch; this in turn means that we
# will receive multiple instances of this instruction.
# _ensure_communicating is a no-op if it runs twice in a row; we're
# not calling it inside the for loop to avoid a O(n^2) condition
# when
# 1. there are many fetches queued because all workers are in flight
# 2. a single compute-task or acquire-replicas command just sent
# many dependencies to fetch at once.
ensure_communicating = inst

elif isinstance(inst, GatherDep):
assert inst.to_gather
keys_str = ", ".join(peekn(27, inst.to_gather)[0])
if len(keys_str) > 80:
keys_str = keys_str[:77] + "..."
task = asyncio.create_task(
self.gather_dep(
inst.worker,
inst.to_gather,
total_nbytes=inst.total_nbytes,
stimulus_id=inst.stimulus_id,
),
name=f"gather_dep({inst.worker}, {{{keys_str}}})",
)

elif isinstance(inst, Execute):
task = asyncio.create_task(
self.execute(inst.key, stimulus_id=inst.stimulus_id),
name=f"execute({inst.key})",
)

elif isinstance(inst, RetryBusyWorkerLater):
task = asyncio.create_task(
self.retry_busy_worker_later(inst.worker),
name=f"retry_busy_worker_later({inst.worker})",
)
for inst in instructions:
task: asyncio.Task | None = None

if isinstance(inst, SendMessageToScheduler):
self.batched_send(inst.to_dict())

elif isinstance(inst, GatherDep):
assert inst.to_gather
keys_str = ", ".join(peekn(27, inst.to_gather)[0])
if len(keys_str) > 80:
keys_str = keys_str[:77] + "..."
task = asyncio.create_task(
self.gather_dep(
inst.worker,
inst.to_gather,
total_nbytes=inst.total_nbytes,
stimulus_id=inst.stimulus_id,
),
name=f"gather_dep({inst.worker}, {{{keys_str}}})",
)

else:
raise TypeError(inst) # pragma: nocover

if task is not None:
self._async_instructions.add(task)
task.add_done_callback(self._handle_stimulus_from_task)

if ensure_communicating:
# Potentially re-fill instructions, causing a second iteration of `while
# instructions` at the top of this method
# FIXME access to private methods
# https://github.com/dask/distributed/issues/6497
recs, instructions = self.state._ensure_communicating(
stimulus_id=ensure_communicating.stimulus_id
elif isinstance(inst, Execute):
task = asyncio.create_task(
self.execute(inst.key, stimulus_id=inst.stimulus_id),
name=f"execute({inst.key})",
)
instructions += self.state._transitions(
recs, stimulus_id=ensure_communicating.stimulus_id

elif isinstance(inst, RetryBusyWorkerLater):
task = asyncio.create_task(
self.retry_busy_worker_later(inst.worker),
name=f"retry_busy_worker_later({inst.worker})",
)

else:
return
raise TypeError(inst) # pragma: nocover

if task is not None:
self._async_instructions.add(task)
task.add_done_callback(self._handle_stimulus_from_task)

async def close(self, timeout: float = 30) -> None:
"""Cancel all asynchronous instructions"""
