Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor find_missing and refresh_who_has #3

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3024,6 +3024,7 @@ def __init__(
"keep-alive": lambda *args, **kwargs: None,
"log-event": self.log_worker_event,
"worker-status-change": self.handle_worker_status_change,
"request-refresh-who-has": self.handle_request_refresh_who_has,
}

client_handlers = {
Expand Down Expand Up @@ -4766,6 +4767,21 @@ def handle_worker_status_change(
else:
self.running.discard(ws)

async def handle_request_refresh_who_has(
self, keys: Iterable[str], worker: str, stimulus_id: str
) -> None:
"""Asynchronous request (through bulk comms) from a Worker to refresh the
who_has for some keys. Not to be confused with scheduler.who_has, which is a
synchronous RPC request from a Client.
"""
self.stream_comms[worker].send(
{
"op": "refresh-who-has",
"who_has": self.get_who_has(keys),
"stimulus_id": stimulus_id,
},
)

async def handle_worker(self, comm=None, worker=None, stimulus_id=None):
"""
Listen to responses from a single worker
Expand Down Expand Up @@ -6214,13 +6230,13 @@ def get_processing(self, workers=None):
w: [ts.key for ts in ws.processing] for w, ws in self.workers.items()
}

def get_who_has(self, keys=None):
def get_who_has(self, keys: Iterable[str] | None = None) -> dict[str, list[str]]:
if keys is not None:
return {
k: [ws.address for ws in self.tasks[k].who_has]
if k in self.tasks
key: [ws.address for ws in self.tasks[key].who_has]
if key in self.tasks
else []
for k in keys
for key in keys
}
else:
return {
Expand Down
108 changes: 70 additions & 38 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
Execute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
GatherDep,
GatherDepDoneEvent,
Instructions,
Expand All @@ -123,7 +124,9 @@
MissingDataMsg,
Recs,
RecsInstrs,
RefreshWhoHasEvent,
ReleaseWorkerDataMsg,
RequestRefreshWhoHasMsg,
RescheduleEvent,
RescheduleMsg,
SendMessageToScheduler,
Expand Down Expand Up @@ -813,6 +816,7 @@ def __init__(
"free-keys": self.handle_free_keys,
"remove-replicas": self.handle_remove_replicas,
"steal-request": self.handle_steal_request,
"refresh-who-has": self.handle_refresh_who_has,
"worker-status-change": self.handle_worker_status_change,
}

Expand Down Expand Up @@ -840,9 +844,7 @@ def __init__(
)
self.periodic_callbacks["keep-alive"] = pc

# FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117
pc = PeriodicCallback(self.find_missing, 1000) # type: ignore
self._find_missing_running = False
pc = PeriodicCallback(self.find_missing, 1000)
self.periodic_callbacks["find-missing"] = pc

self._address = contact_address
Expand Down Expand Up @@ -1836,6 +1838,13 @@ def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str:

return "OK"

def handle_refresh_who_has(
self, who_has: dict[str, list[str]], stimulus_id: str
) -> None:
self.handle_stimulus(
RefreshWhoHasEvent(who_has=who_has, stimulus_id=stimulus_id)
)

async def set_resources(self, **resources) -> None:
for r, quantity in resources.items():
if r in self.total_resources:
Expand Down Expand Up @@ -2846,7 +2855,8 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None:

@log_errors
def handle_stimulus(self, stim: StateMachineEvent) -> None:
self.stimulus_log.append(stim.to_loggable(handled=time()))
if not isinstance(stim, FindMissingEvent):
self.stimulus_log.append(stim.to_loggable(handled=time()))
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is ugly, but post refactor the alternative is for Worker to put its nose directly into the internal data structures of the WorkerState. The alternative would also require the Worker to autonomously realise that something's stuck on the state and query the scheduler accordingly; I don't think it should own this kind of intelligence.

recs, instructions = self.handle_event(stim)
self.transitions(recs, stimulus_id=stim.stimulus_id)
self._handle_instructions(instructions)
Expand Down Expand Up @@ -3345,7 +3355,7 @@ def done_event():
self.busy_workers.add(worker)
self.io_loop.call_later(0.15, self._readd_busy_worker, worker)

refresh_who_has = set()
refresh_who_has = []

for d in self.in_flight_workers.pop(worker):
ts = self.tasks[d]
Expand All @@ -3355,7 +3365,7 @@ def done_event():
elif busy:
recommendations[ts] = "fetch"
if not ts.who_has - self.busy_workers:
refresh_who_has.add(ts.key)
refresh_who_has.append(d)
elif ts not in recommendations:
ts.who_has.discard(worker)
self.has_what[worker].discard(ts.key)
Expand All @@ -3368,17 +3378,19 @@ def done_event():
)
)
recommendations[ts] = "fetch"
del data, response
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

if refresh_who_has:
# All workers that hold known replicas of our tasks are busy.
# Try querying the scheduler for unknown ones.
who_has = await retry_operation(
self.scheduler.who_has, keys=refresh_who_has
instructions.append(
RequestRefreshWhoHasMsg(
keys=refresh_who_has,
stimulus_id=f"gather-dep-busy-{time()}",
)
)
self._update_who_has(who_has)

self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

@log_errors
def _readd_busy_worker(self, worker: str) -> None:
Expand All @@ -3388,33 +3400,13 @@ def _readd_busy_worker(self, worker: str) -> None:
)

@log_errors
async def find_missing(self) -> None:
if self._find_missing_running or not self._missing_dep_flight:
return
try:
self._find_missing_running = True
if self.validate:
for ts in self._missing_dep_flight:
assert not ts.who_has

stimulus_id = f"find-missing-{time()}"
who_has = await retry_operation(
self.scheduler.who_has,
keys=[ts.key for ts in self._missing_dep_flight],
)
self._update_who_has(who_has)
recommendations: Recs = {}
for ts in self._missing_dep_flight:
if ts.who_has:
recommendations[ts] = "fetch"
self.transitions(recommendations, stimulus_id=stimulus_id)
def find_missing(self) -> None:
self.handle_stimulus(FindMissingEvent(stimulus_id=f"find-missing-{time()}"))

finally:
self._find_missing_running = False
# This is quite arbitrary but the heartbeat has scaling implemented
self.periodic_callbacks[
"find-missing"
].callback_time = self.periodic_callbacks["heartbeat"].callback_time
# This is quite arbitrary but the heartbeat has scaling implemented
self.periodic_callbacks["find-missing"].callback_time = self.periodic_callbacks[
"heartbeat"
].callback_time

def _update_who_has(self, who_has: Mapping[str, Collection[str]]) -> None:
for key, workers in who_has.items():
Expand Down Expand Up @@ -3960,6 +3952,46 @@ def _(self, ev: RescheduleEvent) -> RecsInstrs:
assert ts, self.story(ev.key)
return {ts: "rescheduled"}, []

@handle_event.register
def _(self, ev: FindMissingEvent) -> RecsInstrs:
if not self._missing_dep_flight:
return {}, []

if self.validate:
assert not any(ts.who_has for ts in self._missing_dep_flight)

smsg = RequestRefreshWhoHasMsg(
keys=[ts.key for ts in self._missing_dep_flight],
stimulus_id=ev.stimulus_id,
)
return {}, [smsg]
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, the only reason for this event - which is always triggered by the worker itself - is to encapsulate this logic away from the Worker and into the WorkerState.


@handle_event.register
def _(self, ev: RefreshWhoHasEvent) -> RecsInstrs:
self._update_who_has(ev.who_has)
recommendations: Recs = {}
instructions: Instructions = []

for key in ev.who_has:
ts = self.tasks.get(key)
if not ts:
continue

if ts.who_has and ts.state == "missing":
recommendations[ts] = "fetch"
elif ts.who_has and ts.state == "fetch":
# We potentially just acquired new replicas whereas all previously known
# workers are in flight or busy. We're deliberately not testing the
# minute use cases here for the sake of simplicity; instead we rely on
# _ensure_communicating to be a no-op when there's nothing to do.
instructions.append(
EnsureCommunicatingAfterTransitions(stimulus_id=ev.stimulus_id)
)
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this is triggered specifically for find_missing and refresh_who_has and does not fix dask#6446

elif not ts.who_has and ts.state == "fetch":
recommendations[ts] = "missing"

return recommendations, instructions

def _prepare_args_for_execution(
self, ts: TaskState, args: tuple, kwargs: dict[str, Any]
) -> tuple[tuple, dict[str, Any]]:
Expand Down
39 changes: 39 additions & 0 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,26 @@ class AddKeysMsg(SendMessageToScheduler):
keys: list[str]


@dataclass
class RequestRefreshWhoHasMsg(SendMessageToScheduler):
"""Worker -> Scheduler asynchronous request for updated who_has information.
Not to be confused with the scheduler.who_has synchronous RPC call, which is used
by the Client.

See also
--------
RefreshWhoHasEvent
distributed.scheduler.Scheduler.request_refresh_who_has
distributed.client.Client.who_has
distributed.scheduler.Scheduler.get_who_has
"""

op = "request-refresh-who-has"

__slots__ = ("keys",)
keys: list[str]


@dataclass
class StateMachineEvent:
__slots__ = ("stimulus_id", "handled")
Expand Down Expand Up @@ -533,6 +553,25 @@ class RescheduleEvent(StateMachineEvent):
key: str


@dataclass
class FindMissingEvent(StateMachineEvent):
__slots__ = ()


@dataclass
class RefreshWhoHasEvent(StateMachineEvent):
"""Scheduler -> Worker message containing updated who_has information.

See also
--------
RequestRefreshWhoHasMsg
"""

__slots__ = ("who_has",)
# {key: [worker address, ...]}
who_has: dict[str, list[str]]


if TYPE_CHECKING:
# TODO remove quotes (requires Python >=3.9)
# TODO get out of TYPE_CHECKING (requires Python >=3.10)
Expand Down