-
-
Notifications
You must be signed in to change notification settings - Fork 719
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Automatically restart P2P shuffles when output worker leaves #7970
Changes from all commits
6658a65
4e7b425
1cfa0d3
fe79aa3
853d953
e058234
8510c06
6138fc6
0e7b051
8e92b85
3be333e
8412d2c
ba67405
9d804dc
0fd66cd
4f456e9
8f02f8d
29405c6
476f132
d34917e
69f38f8
337dc72
2a18033
85c7ab2
263a75b
266732e
8fc28eb
fef0a94
2d2923f
9936857
fd00098
1b7d055
88f11a4
a70d1ec
976ddb3
559075b
c70b4cb
9790360
25a5cc3
593fa02
f4b5f1c
e0d0174
ff711e2
31929ab
f83bc40
6059f91
dca742e
47eacfc
f544611
a357e38
34062b1
bbfa1b9
10126ff
1482153
36143fe
ed69caf
bd74bf2
74edf2e
391042e
788f8a4
c51e39e
cf2ac3c
6b999b0
396f1c2
5070674
938eab1
6cfe848
6c1be53
f904377
160d055
71778e6
00df2e8
f1ccd63
3ac439c
d72596b
f14c71b
90eb9ea
4b4ab50
dd7ac3e
9459019
3213ef5
e4be94f
c1ef0be
434d5a3
1f0ea84
9e7c2e8
36a2c94
5764f45
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from __future__ import annotations | ||
|
||
|
||
class ShuffleClosedError(RuntimeError): | ||
pass |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -6,7 +6,7 @@ | |||||||
import logging | ||||||||
from collections import defaultdict | ||||||||
from collections.abc import Callable, Iterable, Sequence | ||||||||
from dataclasses import dataclass | ||||||||
from dataclasses import dataclass, field | ||||||||
from functools import partial | ||||||||
from itertools import product | ||||||||
from typing import TYPE_CHECKING, Any, ClassVar | ||||||||
|
@@ -34,14 +34,15 @@ | |||||||
logger = logging.getLogger(__name__) | ||||||||
|
||||||||
|
||||||||
@dataclass | ||||||||
@dataclass(eq=False) | ||||||||
class ShuffleState(abc.ABC): | ||||||||
_run_id_iterator: ClassVar[itertools.count] = itertools.count(1) | ||||||||
|
||||||||
id: ShuffleId | ||||||||
run_id: int | ||||||||
output_workers: set[str] | ||||||||
participating_workers: set[str] | ||||||||
_archived_by: str | None = field(default=None, init=False) | ||||||||
|
||||||||
@abc.abstractmethod | ||||||||
def to_msg(self) -> dict[str, Any]: | ||||||||
|
@@ -50,8 +51,11 @@ def to_msg(self) -> dict[str, Any]: | |||||||
def __str__(self) -> str: | ||||||||
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>" | ||||||||
|
||||||||
def __hash__(self) -> int: | ||||||||
return hash(self.run_id) | ||||||||
|
||||||||
@dataclass | ||||||||
|
||||||||
@dataclass(eq=False) | ||||||||
class DataFrameShuffleState(ShuffleState): | ||||||||
type: ClassVar[ShuffleType] = ShuffleType.DATAFRAME | ||||||||
worker_for: dict[int, str] | ||||||||
|
@@ -68,7 +72,7 @@ def to_msg(self) -> dict[str, Any]: | |||||||
} | ||||||||
|
||||||||
|
||||||||
@dataclass | ||||||||
@dataclass(eq=False) | ||||||||
class ArrayRechunkState(ShuffleState): | ||||||||
type: ClassVar[ShuffleType] = ShuffleType.ARRAY_RECHUNK | ||||||||
worker_for: dict[NDIndex, str] | ||||||||
|
@@ -90,19 +94,18 @@ def to_msg(self) -> dict[str, Any]: | |||||||
class ShuffleSchedulerPlugin(SchedulerPlugin): | ||||||||
""" | ||||||||
Shuffle plugin for the scheduler | ||||||||
|
||||||||
This coordinates the individual worker plugins to ensure correctness | ||||||||
and collects heartbeat messages for the dashboard. | ||||||||
|
||||||||
See Also | ||||||||
-------- | ||||||||
ShuffleWorkerPlugin | ||||||||
""" | ||||||||
|
||||||||
scheduler: Scheduler | ||||||||
states: dict[ShuffleId, ShuffleState] | ||||||||
active_shuffles: dict[ShuffleId, ShuffleState] | ||||||||
heartbeats: defaultdict[ShuffleId, dict] | ||||||||
erred_shuffles: dict[ShuffleId, Exception] | ||||||||
_shuffles: defaultdict[ShuffleId, set[ShuffleState]] | ||||||||
_archived_by_stimulus: defaultdict[str, set[ShuffleState]] | ||||||||
|
||||||||
def __init__(self, scheduler: Scheduler): | ||||||||
self.scheduler = scheduler | ||||||||
|
@@ -115,9 +118,10 @@ def __init__(self, scheduler: Scheduler): | |||||||
} | ||||||||
) | ||||||||
self.heartbeats = defaultdict(lambda: defaultdict(dict)) | ||||||||
self.states = {} | ||||||||
self.erred_shuffles = {} | ||||||||
self.active_shuffles = {} | ||||||||
self.scheduler.add_plugin(self, name="shuffle") | ||||||||
self._shuffles = defaultdict(set) | ||||||||
self._archived_by_stimulus = defaultdict(set) | ||||||||
|
||||||||
async def start(self, scheduler: Scheduler) -> None: | ||||||||
worker_plugin = ShuffleWorkerPlugin() | ||||||||
|
@@ -126,18 +130,19 @@ async def start(self, scheduler: Scheduler) -> None: | |||||||
) | ||||||||
|
||||||||
def shuffle_ids(self) -> set[ShuffleId]: | ||||||||
return set(self.states) | ||||||||
return set(self.active_shuffles) | ||||||||
|
||||||||
async def barrier(self, id: ShuffleId, run_id: int) -> None: | ||||||||
shuffle = self.states[id] | ||||||||
shuffle = self.active_shuffles[id] | ||||||||
assert shuffle.run_id == run_id, f"{run_id=} does not match {shuffle}" | ||||||||
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id} | ||||||||
await self.scheduler.broadcast( | ||||||||
msg=msg, workers=list(shuffle.participating_workers) | ||||||||
msg=msg, | ||||||||
workers=list(shuffle.participating_workers), | ||||||||
) | ||||||||
|
||||||||
def restrict_task(self, id: ShuffleId, run_id: int, key: str, worker: str) -> dict: | ||||||||
shuffle = self.states[id] | ||||||||
shuffle = self.active_shuffles[id] | ||||||||
if shuffle.run_id > run_id: | ||||||||
return { | ||||||||
"status": "error", | ||||||||
|
@@ -158,15 +163,19 @@ def heartbeat(self, ws: WorkerState, data: dict) -> None: | |||||||
self.heartbeats[shuffle_id][ws.address].update(d) | ||||||||
|
||||||||
def get(self, id: ShuffleId, worker: str) -> dict[str, Any]: | ||||||||
if exception := self.erred_shuffles.get(id): | ||||||||
return {"status": "error", "message": str(exception)} | ||||||||
state = self.states[id] | ||||||||
if worker not in self.scheduler.workers: | ||||||||
# This should never happen | ||||||||
raise RuntimeError( | ||||||||
f"Scheduler is unaware of this worker {worker!r}" | ||||||||
) # pragma: nocover | ||||||||
state = self.active_shuffles[id] | ||||||||
state.participating_workers.add(worker) | ||||||||
return state.to_msg() | ||||||||
|
||||||||
def get_or_create( | ||||||||
self, | ||||||||
id: ShuffleId, | ||||||||
key: str, | ||||||||
type: str, | ||||||||
worker: str, | ||||||||
spec: dict[str, Any], | ||||||||
|
@@ -178,6 +187,7 @@ def get_or_create( | |||||||
# known by its name. If the name has been mangled, we cannot guarantee | ||||||||
# that the shuffle works as intended and should fail instead. | ||||||||
self._raise_if_barrier_unknown(id) | ||||||||
self._raise_if_task_not_processing(key) | ||||||||
|
||||||||
state: ShuffleState | ||||||||
if type == ShuffleType.DATAFRAME: | ||||||||
|
@@ -186,7 +196,8 @@ def get_or_create( | |||||||
state = self._create_array_rechunk_state(id, spec) | ||||||||
else: # pragma: no cover | ||||||||
raise TypeError(type) | ||||||||
self.states[id] = state | ||||||||
self.active_shuffles[id] = state | ||||||||
self._shuffles[id].add(state) | ||||||||
state.participating_workers.add(worker) | ||||||||
return state.to_msg() | ||||||||
|
||||||||
|
@@ -201,6 +212,11 @@ def _raise_if_barrier_unknown(self, id: ShuffleId) -> None: | |||||||
"into this by leaving a comment at distributed#7816." | ||||||||
) | ||||||||
|
||||||||
def _raise_if_task_not_processing(self, key: str) -> None: | ||||||||
task = self.scheduler.tasks[key] | ||||||||
if task.state != "processing": | ||||||||
raise RuntimeError(f"Expected {task} to be processing, is {task.state}.") | ||||||||
|
||||||||
def _create_dataframe_shuffle_state( | ||||||||
self, id: ShuffleId, spec: dict[str, Any] | ||||||||
) -> DataFrameShuffleState: | ||||||||
|
@@ -309,34 +325,67 @@ def _unset_restriction(self, ts: TaskState) -> None: | |||||||
original_restrictions = ts.annotations.pop("shuffle_original_restrictions") | ||||||||
self.scheduler.set_restrictions({ts.key: original_restrictions}) | ||||||||
|
||||||||
def _restart_recommendations(self, id: ShuffleId) -> Recs: | ||||||||
barrier_task = self.scheduler.tasks[barrier_key(id)] | ||||||||
recs: Recs = {} | ||||||||
|
||||||||
for dt in barrier_task.dependents: | ||||||||
if dt.state == "erred": | ||||||||
return {} | ||||||||
recs.update({dt.key: "released"}) | ||||||||
|
||||||||
if barrier_task.state == "erred": | ||||||||
# This should never happen, a dependent of the barrier should already | ||||||||
# be `erred` | ||||||||
raise RuntimeError( | ||||||||
f"Expected dependents of {barrier_task=} to be 'erred' if " | ||||||||
"the barrier is." | ||||||||
) # pragma: no cover | ||||||||
recs.update({barrier_task.key: "released"}) | ||||||||
|
||||||||
for dt in barrier_task.dependencies: | ||||||||
if dt.state == "erred": | ||||||||
# This should never happen, a dependent of the barrier should already | ||||||||
# be `erred` | ||||||||
raise RuntimeError( | ||||||||
f"Expected barrier and its dependents to be " | ||||||||
f"'erred' if the barrier's dependency {dt} is." | ||||||||
) # pragma: no cover | ||||||||
recs.update({dt.key: "released"}) | ||||||||
return recs | ||||||||
|
||||||||
def _restart_shuffle( | ||||||||
self, id: ShuffleId, scheduler: Scheduler, *, stimulus_id: str | ||||||||
) -> None: | ||||||||
recs = self._restart_recommendations(id) | ||||||||
self.scheduler.transitions(recs, stimulus_id=stimulus_id) | ||||||||
self.scheduler.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) | ||||||||
|
||||||||
def remove_worker( | ||||||||
self, scheduler: Scheduler, worker: str, *, stimulus_id: str, **kwargs: Any | ||||||||
) -> None: | ||||||||
from time import time | ||||||||
|
||||||||
stimulus_id = f"shuffle-failed-worker-left-{time()}" | ||||||||
"""Restart all active shuffles when a participating worker leaves the cluster. | ||||||||
|
||||||||
.. note:: | ||||||||
Due to the order of operations in :meth:`~Scheduler.remove_worker`, the | ||||||||
shuffle may have already been archived by | ||||||||
:meth:`~ShuffleSchedulerPlugin.transition`. In this case, the | ||||||||
``stimulus_id`` is used as a transaction identifier and all archived shuffles | ||||||||
with a matching `stimulus_id` are restarted. | ||||||||
""" | ||||||||
|
||||||||
recs: Recs = {} | ||||||||
for shuffle_id, shuffle in self.states.items(): | ||||||||
# If processing the transactions causes a task to get released, this | ||||||||
# removes the shuffle from self.active_shuffles. Therefore, we must iterate | ||||||||
# over a copy. | ||||||||
for shuffle_id, shuffle in self.active_shuffles.copy().items(): | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Then we iterate over all active shuffles, remove and restart? Why do we not unconditionally restart the archived shuffles after this loop over active shuffles? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not 100% sure I'm following, but what I think you're saying is a very good point. |
||||||||
if worker not in shuffle.participating_workers: | ||||||||
continue | ||||||||
exception = RuntimeError(f"Worker {worker} left during active {shuffle}") | ||||||||
self.erred_shuffles[shuffle_id] = exception | ||||||||
self._fail_on_workers(shuffle, str(exception)) | ||||||||
self._clean_on_scheduler(shuffle_id, stimulus_id) | ||||||||
|
||||||||
barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)] | ||||||||
if barrier_task.state == "memory": | ||||||||
for dt in barrier_task.dependents: | ||||||||
if worker not in dt.worker_restrictions: | ||||||||
continue | ||||||||
self._unset_restriction(dt) | ||||||||
recs.update({dt.key: "waiting"}) | ||||||||
# TODO: Do we need to handle other states? | ||||||||
|
||||||||
# If processing the transactions causes a task to get released, this | ||||||||
# removes the shuffle from self.states. Therefore, we must process them | ||||||||
# outside of the loop. | ||||||||
self.scheduler.transitions(recs, stimulus_id=stimulus_id) | ||||||||
for shuffle in self._archived_by_stimulus.get(stimulus_id, set()): | ||||||||
self._restart_shuffle(shuffle.id, scheduler, stimulus_id=stimulus_id) | ||||||||
|
||||||||
def transition( | ||||||||
self, | ||||||||
|
@@ -347,17 +396,25 @@ def transition( | |||||||
stimulus_id: str, | ||||||||
**kwargs: Any, | ||||||||
) -> None: | ||||||||
"""Clean up scheduler and worker state once a shuffle becomes inactive.""" | ||||||||
if finish not in ("released", "forgotten"): | ||||||||
return | ||||||||
if not key.startswith("shuffle-barrier-"): | ||||||||
return | ||||||||
shuffle_id = id_from_key(key) | ||||||||
try: | ||||||||
shuffle = self.states[shuffle_id] | ||||||||
except KeyError: | ||||||||
return | ||||||||
self._fail_on_workers(shuffle, message=f"{shuffle} forgotten") | ||||||||
self._clean_on_scheduler(shuffle_id) | ||||||||
|
||||||||
if shuffle := self.active_shuffles.get(shuffle_id): | ||||||||
self._fail_on_workers(shuffle, message=f"{shuffle} forgotten") | ||||||||
self._clean_on_scheduler(shuffle_id, stimulus_id=stimulus_id) | ||||||||
|
||||||||
if finish == "forgotten": | ||||||||
shuffles = self._shuffles.pop(shuffle_id, set()) | ||||||||
for shuffle in shuffles: | ||||||||
if shuffle._archived_by: | ||||||||
archived = self._archived_by_stimulus[shuffle._archived_by] | ||||||||
archived.remove(shuffle) | ||||||||
if not archived: | ||||||||
del self._archived_by_stimulus[shuffle._archived_by] | ||||||||
Comment on lines
+416
to
+417
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why remove this (now empty) set from the dict? (AKA, why is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On a long-running cluster, we would keep adding new entries to the dictionary and unnecessarily increase its size. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Individual distributed/distributed/scheduler.py Lines 5688 to 5690 in b7e5f8f
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, ok. |
||||||||
|
||||||||
def _fail_on_workers(self, shuffle: ShuffleState, message: str) -> None: | ||||||||
worker_msgs = { | ||||||||
|
@@ -373,9 +430,12 @@ def _fail_on_workers(self, shuffle: ShuffleState, message: str) -> None: | |||||||
} | ||||||||
self.scheduler.send_all({}, worker_msgs) | ||||||||
|
||||||||
def _clean_on_scheduler(self, id: ShuffleId) -> None: | ||||||||
del self.states[id] | ||||||||
self.erred_shuffles.pop(id, None) | ||||||||
def _clean_on_scheduler(self, id: ShuffleId, stimulus_id: str | None) -> None: | ||||||||
shuffle = self.active_shuffles.pop(id) | ||||||||
if not shuffle._archived_by and stimulus_id: | ||||||||
shuffle._archived_by = stimulus_id | ||||||||
self._archived_by_stimulus[stimulus_id].add(shuffle) | ||||||||
|
||||||||
with contextlib.suppress(KeyError): | ||||||||
del self.heartbeats[id] | ||||||||
|
||||||||
|
@@ -384,9 +444,10 @@ def _clean_on_scheduler(self, id: ShuffleId) -> None: | |||||||
self._unset_restriction(dt) | ||||||||
|
||||||||
def restart(self, scheduler: Scheduler) -> None: | ||||||||
self.states.clear() | ||||||||
self.active_shuffles.clear() | ||||||||
self.heartbeats.clear() | ||||||||
self.erred_shuffles.clear() | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This object now has more state than previously sitting around inside it, is it deliberate that (say) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch, addressed. |
||||||||
self._shuffles.clear() | ||||||||
self._archived_by_stimulus.clear() | ||||||||
|
||||||||
|
||||||||
def get_worker_for_range_sharding( | ||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
|
||
from distributed.exceptions import Reschedule | ||
from distributed.shuffle._arrow import check_dtype_support, check_minimal_arrow_version | ||
from distributed.shuffle._exceptions import ShuffleClosedError | ||
|
||
logger = logging.getLogger("distributed.shuffle") | ||
if TYPE_CHECKING: | ||
|
@@ -69,6 +70,8 @@ def shuffle_transfer( | |
column=column, | ||
parts_out=parts_out, | ||
) | ||
except ShuffleClosedError: | ||
raise Reschedule() | ||
except Exception as e: | ||
raise RuntimeError(f"shuffle_transfer failed during shuffle {id}") from e | ||
|
||
|
@@ -82,6 +85,8 @@ def shuffle_unpack( | |
) | ||
except Reschedule as e: | ||
raise e | ||
except ShuffleClosedError: | ||
raise Reschedule() | ||
Comment on lines
+88
to
+89
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the easiest way of dealing with race conditions between a worker closing down, thus raising the |
||
except Exception as e: | ||
raise RuntimeError(f"shuffle_unpack failed during shuffle {id}") from e | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that
active_shuffles
,_shuffles
and_archived_by_stimulus
are always required to be in sync with one-another. Is there a way to build a datastructure that enforces the relevant invariants? Rather than having to maintain all these separate mappings and remember to update them all correctly?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm certain there is, but I would like to keep the refactoring separate from actual logic changes. Let's keep this for a follow-up PR? (Would you be interested in taking that on?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please open an issue and assign to me, and I will give it a go, thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will do once this is merged 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#8018