Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor find_missing and refresh_who_has #6348

Merged
merged 1 commit into from
Jun 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3031,6 +3031,7 @@ def __init__(
"keep-alive": lambda *args, **kwargs: None,
"log-event": self.log_worker_event,
"worker-status-change": self.handle_worker_status_change,
"request-refresh-who-has": self.handle_request_refresh_who_has,
}

client_handlers = {
Expand Down Expand Up @@ -4782,6 +4783,21 @@ def handle_worker_status_change(
else:
self.running.discard(ws)

async def handle_request_refresh_who_has(
self, keys: Iterable[str], worker: str, stimulus_id: str
) -> None:
"""Asynchronous request (through bulk comms) from a Worker to refresh the
who_has for some keys. Not to be confused with scheduler.who_has, which is a
synchronous RPC request from a Client.
"""
self.stream_comms[worker].send(
{
"op": "refresh-who-has",
"who_has": self.get_who_has(keys),
"stimulus_id": stimulus_id,
},
)

async def handle_worker(self, comm=None, worker=None, stimulus_id=None):
"""
Listen to responses from a single worker
Expand Down Expand Up @@ -6230,13 +6246,13 @@ def get_processing(self, workers=None):
w: [ts.key for ts in ws.processing] for w, ws in self.workers.items()
}

def get_who_has(self, keys=None):
def get_who_has(self, keys: Iterable[str] | None = None) -> dict[str, list[str]]:
if keys is not None:
return {
k: [ws.address for ws in self.tasks[k].who_has]
if k in self.tasks
key: [ws.address for ws in self.tasks[key].who_has]
if key in self.tasks
else []
for k in keys
for key in keys
}
else:
return {
Expand Down
118 changes: 74 additions & 44 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
Execute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
GatherDep,
GatherDepDoneEvent,
Instructions,
Expand All @@ -123,7 +124,9 @@
MissingDataMsg,
Recs,
RecsInstrs,
RefreshWhoHasEvent,
ReleaseWorkerDataMsg,
RequestRefreshWhoHasMsg,
RescheduleEvent,
RescheduleMsg,
SendMessageToScheduler,
Expand Down Expand Up @@ -813,6 +816,7 @@ def __init__(
"free-keys": self.handle_free_keys,
"remove-replicas": self.handle_remove_replicas,
"steal-request": self.handle_steal_request,
"refresh-who-has": self.handle_refresh_who_has,
"worker-status-change": self.handle_worker_status_change,
}

Expand Down Expand Up @@ -840,9 +844,7 @@ def __init__(
)
self.periodic_callbacks["keep-alive"] = pc

# FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117
pc = PeriodicCallback(self.find_missing, 1000) # type: ignore
self._find_missing_running = False
pc = PeriodicCallback(self.find_missing, 1000)
self.periodic_callbacks["find-missing"] = pc

self._address = contact_address
Expand Down Expand Up @@ -1839,6 +1841,13 @@ def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str:

return "OK"

def handle_refresh_who_has(
self, who_has: dict[str, list[str]], stimulus_id: str
) -> None:
self.handle_stimulus(
RefreshWhoHasEvent(who_has=who_has, stimulus_id=stimulus_id)
)

async def set_resources(self, **resources) -> None:
for r, quantity in resources.items():
if r in self.total_resources:
Expand Down Expand Up @@ -2849,7 +2858,8 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None:

@log_errors
def handle_stimulus(self, stim: StateMachineEvent) -> None:
self.stimulus_log.append(stim.to_loggable(handled=time()))
if not isinstance(stim, FindMissingEvent):
self.stimulus_log.append(stim.to_loggable(handled=time()))
Comment on lines +2861 to +2862
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume you're logging this because it is flooding the log otherwise?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. Plus, it's really not interesting to see in the log even when relevant. Note that the response from the scheduler is logged.

recs, instructions = self.handle_event(stim)
self.transitions(recs, stimulus_id=stim.stimulus_id)
self._handle_instructions(instructions)
Expand Down Expand Up @@ -2991,11 +3001,8 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:
if ts.state != "fetch" or ts.key in all_keys_to_gather:
continue

if not ts.who_has:
recommendations[ts] = "missing"
continue

if self.validate:
assert ts.who_has
assert self.address not in ts.who_has

workers = [
Expand Down Expand Up @@ -3348,7 +3355,7 @@ def done_event():
self.busy_workers.add(worker)
self.io_loop.call_later(0.15, self._readd_busy_worker, worker)

refresh_who_has = set()
refresh_who_has = []

for d in self.in_flight_workers.pop(worker):
ts = self.tasks[d]
Expand All @@ -3358,7 +3365,7 @@ def done_event():
elif busy:
recommendations[ts] = "fetch"
if not ts.who_has - self.busy_workers:
refresh_who_has.add(ts.key)
refresh_who_has.append(d)
elif ts not in recommendations:
ts.who_has.discard(worker)
self.has_what[worker].discard(ts.key)
Expand All @@ -3371,17 +3378,19 @@ def done_event():
)
)
recommendations[ts] = "fetch"
del data, response
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

if refresh_who_has:
# All workers that hold known replicas of our tasks are busy.
# Try querying the scheduler for unknown ones.
who_has = await retry_operation(
self.scheduler.who_has, keys=refresh_who_has
instructions.append(
RequestRefreshWhoHasMsg(
keys=refresh_who_has,
stimulus_id=f"gather-dep-busy-{time()}",
)
)
self._update_who_has(who_has)

self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

@log_errors
def _readd_busy_worker(self, worker: str) -> None:
Expand All @@ -3391,33 +3400,13 @@ def _readd_busy_worker(self, worker: str) -> None:
)

@log_errors
async def find_missing(self) -> None:
if self._find_missing_running or not self._missing_dep_flight:
return
try:
self._find_missing_running = True
if self.validate:
for ts in self._missing_dep_flight:
assert not ts.who_has
def find_missing(self) -> None:
self.handle_stimulus(FindMissingEvent(stimulus_id=f"find-missing-{time()}"))

stimulus_id = f"find-missing-{time()}"
who_has = await retry_operation(
self.scheduler.who_has,
keys=[ts.key for ts in self._missing_dep_flight],
)
self._update_who_has(who_has)
recommendations: Recs = {}
for ts in self._missing_dep_flight:
if ts.who_has:
recommendations[ts] = "fetch"
self.transitions(recommendations, stimulus_id=stimulus_id)

finally:
self._find_missing_running = False
# This is quite arbitrary but the heartbeat has scaling implemented
self.periodic_callbacks[
"find-missing"
].callback_time = self.periodic_callbacks["heartbeat"].callback_time
# This is quite arbitrary but the heartbeat has scaling implemented
self.periodic_callbacks["find-missing"].callback_time = self.periodic_callbacks[
"heartbeat"
].callback_time

def _update_who_has(self, who_has: Mapping[str, Collection[str]]) -> None:
for key, workers in who_has.items():
Expand Down Expand Up @@ -3965,6 +3954,47 @@ def _(self, ev: RescheduleEvent) -> RecsInstrs:
assert ts, self.story(ev.key)
return {ts: "rescheduled"}, []

@handle_event.register
def _(self, ev: FindMissingEvent) -> RecsInstrs:
if not self._missing_dep_flight:
return {}, []

if self.validate:
assert not any(ts.who_has for ts in self._missing_dep_flight)

smsg = RequestRefreshWhoHasMsg(
keys=[ts.key for ts in self._missing_dep_flight],
stimulus_id=ev.stimulus_id,
)
return {}, [smsg]

@handle_event.register
def _(self, ev: RefreshWhoHasEvent) -> RecsInstrs:
self._update_who_has(ev.who_has)
recommendations: Recs = {}
instructions: Instructions = []

for key in ev.who_has:
ts = self.tasks.get(key)
if not ts:
continue

if ts.who_has and ts.state == "missing":
recommendations[ts] = "fetch"
elif ts.who_has and ts.state == "fetch":
# We potentially just acquired new replicas whereas all previously known
# workers are in flight or busy. We're deliberately not testing the
# minute use cases here for the sake of simplicity; instead we rely on
# _ensure_communicating to be a no-op when there's nothing to do.
recommendations, instructions = merge_recs_instructions(
(recommendations, instructions),
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)
elif not ts.who_has and ts.state == "fetch":
recommendations[ts] = "missing"

return recommendations, instructions

def _prepare_args_for_execution(
self, ts: TaskState, args: tuple, kwargs: dict[str, Any]
) -> tuple[tuple, dict[str, Any]]:
Expand Down Expand Up @@ -4190,8 +4220,8 @@ def validate_task_fetch(self, ts):
assert self.address not in ts.who_has
assert not ts.done
assert ts in self.data_needed
# Note: ts.who_has may be have been emptied by _update_who_has, but the task
# won't transition to missing until it reaches the top of the data_needed heap.
assert ts.who_has

for w in ts.who_has:
assert ts.key in self.has_what[w]
assert ts in self.data_needed_per_worker[w]
Expand Down
39 changes: 39 additions & 0 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,26 @@ class AddKeysMsg(SendMessageToScheduler):
keys: list[str]


@dataclass
class RequestRefreshWhoHasMsg(SendMessageToScheduler):
"""Worker -> Scheduler asynchronous request for updated who_has information.
Not to be confused with the scheduler.who_has synchronous RPC call, which is used
by the Client.
See also
--------
RefreshWhoHasEvent
distributed.scheduler.Scheduler.request_refresh_who_has
distributed.client.Client.who_has
distributed.scheduler.Scheduler.get_who_has
"""

op = "request-refresh-who-has"

__slots__ = ("keys",)
keys: list[str]


@dataclass
class StateMachineEvent:
__slots__ = ("stimulus_id", "handled")
Expand Down Expand Up @@ -533,6 +553,25 @@ class RescheduleEvent(StateMachineEvent):
key: str


@dataclass
class FindMissingEvent(StateMachineEvent):
__slots__ = ()


@dataclass
class RefreshWhoHasEvent(StateMachineEvent):
"""Scheduler -> Worker message containing updated who_has information.
See also
--------
RequestRefreshWhoHasMsg
"""

__slots__ = ("who_has",)
# {key: [worker address, ...]}
who_has: dict[str, list[str]]


if TYPE_CHECKING:
# TODO remove quotes (requires Python >=3.9)
# TODO get out of TYPE_CHECKING (requires Python >=3.10)
Expand Down