Skip to content

Commit

Permalink
Refactor find-missing and refresh_who_has
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed May 16, 2022
1 parent 959ca3f commit 8eb8eea
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 42 deletions.
24 changes: 20 additions & 4 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3022,6 +3022,7 @@ def __init__(
"keep-alive": lambda *args, **kwargs: None,
"log-event": self.log_worker_event,
"worker-status-change": self.handle_worker_status_change,
"refresh-who-has": self.handle_refresh_who_has,
}

client_handlers = {
Expand Down Expand Up @@ -4796,6 +4797,21 @@ def handle_worker_status_change(
else:
self.running.discard(ws)

async def handle_refresh_who_has(
self, keys: Iterable[str], worker: str, stimulus_id: str
) -> None:
"""Asynchronous request (through bulk comms) from a Worker to refresh the
who_has for some keys. Not to be confused with scheduler.who_has, which is a
synchronous RPC request from a Client.
"""
self.stream_comms[worker].send(
{
"op": "refresh-who-has",
"who_has": self.get_who_has(keys),
"stimulus_id": stimulus_id,
},
)

async def handle_worker(self, comm=None, worker=None, stimulus_id=None):
"""
Listen to responses from a single worker
Expand Down Expand Up @@ -6268,13 +6284,13 @@ def get_processing(self, workers=None):
w: [ts.key for ts in ws.processing] for w, ws in self.workers.items()
}

def get_who_has(self, keys=None):
def get_who_has(self, keys: Iterable[str] | None = None) -> dict[str, list[str]]:
if keys is not None:
return {
k: [ws.address for ws in self.tasks[k].who_has]
if k in self.tasks
key: [ws.address for ws in self.tasks[key].who_has]
if key in self.tasks
else []
for k in keys
for key in keys
}
else:
return {
Expand Down
91 changes: 53 additions & 38 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
Execute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
GatherDep,
GatherDepDoneEvent,
Instructions,
Expand All @@ -123,6 +124,8 @@
MissingDataMsg,
Recs,
RecsInstrs,
RefreshWhoHasEvent,
RefreshWhoHasMsg,
ReleaseWorkerDataMsg,
RescheduleEvent,
RescheduleMsg,
Expand Down Expand Up @@ -798,6 +801,7 @@ def __init__(
"compute-task": self.handle_compute_task,
"free-keys": self.handle_free_keys,
"remove-replicas": self.handle_remove_replicas,
"refresh-who-has": self.handle_refresh_who_has,
"steal-request": self.handle_steal_request,
"worker-status-change": self.handle_worker_status_change,
}
Expand Down Expand Up @@ -1822,6 +1826,13 @@ def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str:

return "OK"

def handle_refresh_who_has(
self, who_has: dict[str, list[str]], stimulus_id: str
) -> None:
self.handle_stimulus(
RefreshWhoHasEvent(who_has=who_has, stimulus_id=stimulus_id)
)

async def set_resources(self, **resources) -> None:
for r, quantity in resources.items():
if r in self.total_resources:
Expand Down Expand Up @@ -3423,22 +3434,19 @@ def done_event():
)
)
recommendations[ts] = "fetch"
del data, response
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

if refresh_who_has:
# All workers that hold known replicas of our tasks are busy.
# Try querying the scheduler for unknown ones.
who_has = await retry_operation(
self.scheduler.who_has, keys=refresh_who_has
)
refresh_stimulus_id = f"refresh-who-has-{time()}"
recommendations, instructions = self._update_who_has(
who_has, stimulus_id=refresh_stimulus_id
instructions.append(
RefreshWhoHasMsg(
keys=list(refresh_who_has),
stimulus_id=f"gather-dep-busy-{time()}",
)
)
self.transitions(recommendations, stimulus_id=refresh_stimulus_id)
self._handle_instructions(instructions)

self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

@log_errors
def _readd_busy_worker(self, worker: str) -> None:
Expand All @@ -3449,39 +3457,19 @@ def _readd_busy_worker(self, worker: str) -> None:

@log_errors
async def find_missing(self) -> None:
if not self._missing_dep_flight:
return
try:
if self.validate:
for ts in self._missing_dep_flight:
assert not ts.who_has

stimulus_id = f"find-missing-{time()}"
who_has = await retry_operation(
self.scheduler.who_has,
keys=[ts.key for ts in self._missing_dep_flight],
)
recommendations, instructions = self._update_who_has(
who_has, stimulus_id=stimulus_id
)
for ts in self._missing_dep_flight:
if ts.who_has:
assert ts not in recommendations
recommendations[ts] = "fetch"
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)
self.handle_stimulus(FindMissingEvent(stimulus_id=f"find-missing-{time()}"))

finally:
# This is quite arbitrary but the heartbeat has scaling implemented
self.periodic_callbacks[
"find-missing"
].callback_time = self.periodic_callbacks["heartbeat"].callback_time
# This is quite arbitrary but the heartbeat has scaling implemented
self.periodic_callbacks["find-missing"].callback_time = self.periodic_callbacks[
"heartbeat"
].callback_time

def _update_who_has(
self, who_has: dict[str, Collection[str]], *, stimulus_id: str
self, who_has: Mapping[str, Collection[str]], *, stimulus_id: str
) -> RecsInstrs:
recs: Recs = {}
instructions: Instructions = []
ensure_communicating = False

for key, workers in who_has.items():
ts = self.tasks.get(key)
Expand Down Expand Up @@ -3527,11 +3515,19 @@ def _update_who_has(
self.has_what[worker].add(key)
if ts.state == "fetch":
self.data_needed_per_worker[worker].push(ts)
ensure_communicating = True

ts.who_has = workers
if not workers and ts.state == "fetch":
recs[ts] = "missing"
elif workers and ts.state == "missing":
recs[ts] = "fetch"

if ensure_communicating:
recs, instructions = merge_recs_instructions(
(recs, instructions),
self._ensure_communicating(stimulus_id=stimulus_id),
)
return recs, instructions

def handle_steal_request(self, key: str, stimulus_id: str) -> None:
Expand Down Expand Up @@ -4056,6 +4052,25 @@ def _(self, ev: RescheduleEvent) -> RecsInstrs:
assert ts, self.story(ev.key)
return {ts: "rescheduled"}, []

@handle_event.register
def _(self, ev: FindMissingEvent) -> RecsInstrs:
if not self._missing_dep_flight:
return {}, []

if self.validate:
for ts in self._missing_dep_flight:
assert not ts.who_has

smsg = RefreshWhoHasMsg(
keys=[ts.key for ts in self._missing_dep_flight],
stimulus_id=ev.stimulus_id,
)
return {}, [smsg]

@handle_event.register
def _(self, ev: RefreshWhoHasEvent) -> RecsInstrs:
return self._update_who_has(ev.who_has, stimulus_id=ev.stimulus_id)

def _prepare_args_for_execution(
self, ts: TaskState, args: tuple, kwargs: dict[str, Any]
) -> tuple[tuple, dict[str, Any]]:
Expand Down
39 changes: 39 additions & 0 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,26 @@ class AddKeysMsg(SendMessageToScheduler):
keys: list[str]


@dataclass
class RefreshWhoHasMsg(SendMessageToScheduler):
"""Worker -> Scheduler asynchronous request for updated who_has information.
Not to be confused with the scheduler.who_has synchronous RPC call, which is used
by the Client.
See also
--------
RefreshWhoHasEvent
distributed.scheduler.Scheduler.refresh_who_has
distributed.client.Client.who_has
distributed.scheduler.Scheduler.get_who_has
"""

op = "refresh-who-has"

__slots__ = ("keys",)
keys: list[str]


@dataclass
class StateMachineEvent:
__slots__ = ("stimulus_id", "handled")
Expand Down Expand Up @@ -508,6 +528,25 @@ class RescheduleEvent(StateMachineEvent):
key: str


@dataclass
class FindMissingEvent(StateMachineEvent):
__slots__ = ()


@dataclass
class RefreshWhoHasEvent(StateMachineEvent):
"""Scheduler -> Worker message containing updated who_has information.
See also
--------
RefreshWhoHasMsg
"""

__slots__ = ("who_has",)
# {key: [worker address, ...]}
who_has: dict[str, list[str]]


if TYPE_CHECKING:
# TODO remove quotes (requires Python >=3.9)
# TODO get out of TYPE_CHECKING (requires Python >=3.10)
Expand Down

0 comments on commit 8eb8eea

Please sign in to comment.