diff --git a/distributed/shuffle/_exceptions.py b/distributed/shuffle/_exceptions.py new file mode 100644 index 0000000000..57a54a15e7 --- /dev/null +++ b/distributed/shuffle/_exceptions.py @@ -0,0 +1,5 @@ +from __future__ import annotations + + +class ShuffleClosedError(RuntimeError): + pass diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index 911f4f89b9..ec670c0b07 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -6,7 +6,7 @@ import logging from collections import defaultdict from collections.abc import Callable, Iterable, Sequence -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import partial from itertools import product from typing import TYPE_CHECKING, Any, ClassVar @@ -34,7 +34,7 @@ logger = logging.getLogger(__name__) -@dataclass +@dataclass(eq=False) class ShuffleState(abc.ABC): _run_id_iterator: ClassVar[itertools.count] = itertools.count(1) @@ -42,6 +42,7 @@ class ShuffleState(abc.ABC): run_id: int output_workers: set[str] participating_workers: set[str] + _archived_by: str | None = field(default=None, init=False) @abc.abstractmethod def to_msg(self) -> dict[str, Any]: @@ -50,8 +51,11 @@ def to_msg(self) -> dict[str, Any]: def __str__(self) -> str: return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>" + def __hash__(self) -> int: + return hash(self.run_id) -@dataclass + +@dataclass(eq=False) class DataFrameShuffleState(ShuffleState): type: ClassVar[ShuffleType] = ShuffleType.DATAFRAME worker_for: dict[int, str] @@ -68,7 +72,7 @@ def to_msg(self) -> dict[str, Any]: } -@dataclass +@dataclass(eq=False) class ArrayRechunkState(ShuffleState): type: ClassVar[ShuffleType] = ShuffleType.ARRAY_RECHUNK worker_for: dict[NDIndex, str] @@ -90,19 +94,18 @@ def to_msg(self) -> dict[str, Any]: class ShuffleSchedulerPlugin(SchedulerPlugin): """ Shuffle plugin for the scheduler - This coordinates the individual worker plugins to ensure correctness and collects heartbeat messages for the dashboard. - See Also -------- ShuffleWorkerPlugin """ scheduler: Scheduler - states: dict[ShuffleId, ShuffleState] + active_shuffles: dict[ShuffleId, ShuffleState] heartbeats: defaultdict[ShuffleId, dict] - erred_shuffles: dict[ShuffleId, Exception] + _shuffles: defaultdict[ShuffleId, set[ShuffleState]] + _archived_by_stimulus: defaultdict[str, set[ShuffleState]] def __init__(self, scheduler: Scheduler): self.scheduler = scheduler @@ -115,9 +118,10 @@ def __init__(self, scheduler: Scheduler): } ) self.heartbeats = defaultdict(lambda: defaultdict(dict)) - self.states = {} - self.erred_shuffles = {} + self.active_shuffles = {} self.scheduler.add_plugin(self, name="shuffle") + self._shuffles = defaultdict(set) + self._archived_by_stimulus = defaultdict(set) async def start(self, scheduler: Scheduler) -> None: worker_plugin = ShuffleWorkerPlugin() @@ -126,18 +130,19 @@ async def start(self, scheduler: Scheduler) -> None: ) def shuffle_ids(self) -> set[ShuffleId]: - return set(self.states) + return set(self.active_shuffles) async def barrier(self, id: ShuffleId, run_id: int) -> None: - shuffle = self.states[id] + shuffle = self.active_shuffles[id] assert shuffle.run_id == run_id, f"{run_id=} does not match {shuffle}" msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id} await self.scheduler.broadcast( - msg=msg, workers=list(shuffle.participating_workers) + msg=msg, + workers=list(shuffle.participating_workers), ) def restrict_task(self, id: ShuffleId, run_id: int, key: str, worker: str) -> dict: - shuffle = self.states[id] + shuffle = self.active_shuffles[id] if shuffle.run_id > run_id: return { "status": "error", @@ -158,15 +163,19 @@ def heartbeat(self, ws: WorkerState, data: dict) -> None: self.heartbeats[shuffle_id][ws.address].update(d) def get(self, id: ShuffleId, worker: str) -> dict[str, Any]: - if exception := self.erred_shuffles.get(id): - return {"status": "error", "message": str(exception)} - state = self.states[id] + if worker not in self.scheduler.workers: + # This should never happen + raise RuntimeError( + f"Scheduler is unaware of this worker {worker!r}" + ) # pragma: nocover + state = self.active_shuffles[id] state.participating_workers.add(worker) return state.to_msg() def get_or_create( self, id: ShuffleId, + key: str, type: str, worker: str, spec: dict[str, Any], @@ -178,6 +187,7 @@ def get_or_create( # known by its name. If the name has been mangled, we cannot guarantee # that the shuffle works as intended and should fail instead. self._raise_if_barrier_unknown(id) + self._raise_if_task_not_processing(key) state: ShuffleState if type == ShuffleType.DATAFRAME: @@ -186,7 +196,8 @@ def get_or_create( state = self._create_array_rechunk_state(id, spec) else: # pragma: no cover raise TypeError(type) - self.states[id] = state + self.active_shuffles[id] = state + self._shuffles[id].add(state) state.participating_workers.add(worker) return state.to_msg() @@ -201,6 +212,11 @@ def _raise_if_barrier_unknown(self, id: ShuffleId) -> None: "into this by leaving a comment at distributed#7816." ) + def _raise_if_task_not_processing(self, key: str) -> None: + task = self.scheduler.tasks[key] + if task.state != "processing": + raise RuntimeError(f"Expected {task} to be processing, is {task.state}.") + def _create_dataframe_shuffle_state( self, id: ShuffleId, spec: dict[str, Any] ) -> DataFrameShuffleState: @@ -309,34 +325,67 @@ def _unset_restriction(self, ts: TaskState) -> None: original_restrictions = ts.annotations.pop("shuffle_original_restrictions") self.scheduler.set_restrictions({ts.key: original_restrictions}) + def _restart_recommendations(self, id: ShuffleId) -> Recs: + barrier_task = self.scheduler.tasks[barrier_key(id)] + recs: Recs = {} + + for dt in barrier_task.dependents: + if dt.state == "erred": + return {} + recs.update({dt.key: "released"}) + + if barrier_task.state == "erred": + # This should never happen, a dependent of the barrier should already + # be `erred` + raise RuntimeError( + f"Expected dependents of {barrier_task=} to be 'erred' if " + "the barrier is." + ) # pragma: no cover + recs.update({barrier_task.key: "released"}) + + for dt in barrier_task.dependencies: + if dt.state == "erred": + # This should never happen, a dependent of the barrier should already + # be `erred` + raise RuntimeError( + f"Expected barrier and its dependents to be " + f"'erred' if the barrier's dependency {dt} is." + ) # pragma: no cover + recs.update({dt.key: "released"}) + return recs + + def _restart_shuffle( + self, id: ShuffleId, scheduler: Scheduler, *, stimulus_id: str + ) -> None: + recs = self._restart_recommendations(id) + self.scheduler.transitions(recs, stimulus_id=stimulus_id) + self.scheduler.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) + def remove_worker( self, scheduler: Scheduler, worker: str, *, stimulus_id: str, **kwargs: Any ) -> None: - from time import time - - stimulus_id = f"shuffle-failed-worker-left-{time()}" + """Restart all active shuffles when a participating worker leaves the cluster. + + .. note:: + Due to the order of operations in :meth:`~Scheduler.remove_worker`, the + shuffle may have already been archived by + :meth:`~ShuffleSchedulerPlugin.transition`. In this case, the + ``stimulus_id`` is used as a transaction identifier and all archived shuffles + with a matching `stimulus_id` are restarted. + """ - recs: Recs = {} - for shuffle_id, shuffle in self.states.items(): + # If processing the transactions causes a task to get released, this + # removes the shuffle from self.active_shuffles. Therefore, we must iterate + # over a copy. + for shuffle_id, shuffle in self.active_shuffles.copy().items(): if worker not in shuffle.participating_workers: continue exception = RuntimeError(f"Worker {worker} left during active {shuffle}") - self.erred_shuffles[shuffle_id] = exception self._fail_on_workers(shuffle, str(exception)) + self._clean_on_scheduler(shuffle_id, stimulus_id) - barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)] - if barrier_task.state == "memory": - for dt in barrier_task.dependents: - if worker not in dt.worker_restrictions: - continue - self._unset_restriction(dt) - recs.update({dt.key: "waiting"}) - # TODO: Do we need to handle other states? - - # If processing the transactions causes a task to get released, this - # removes the shuffle from self.states. Therefore, we must process them - # outside of the loop. - self.scheduler.transitions(recs, stimulus_id=stimulus_id) + for shuffle in self._archived_by_stimulus.get(stimulus_id, set()): + self._restart_shuffle(shuffle.id, scheduler, stimulus_id=stimulus_id) def transition( self, @@ -347,17 +396,25 @@ def transition( stimulus_id: str, **kwargs: Any, ) -> None: + """Clean up scheduler and worker state once a shuffle becomes inactive.""" if finish not in ("released", "forgotten"): return if not key.startswith("shuffle-barrier-"): return shuffle_id = id_from_key(key) - try: - shuffle = self.states[shuffle_id] - except KeyError: - return - self._fail_on_workers(shuffle, message=f"{shuffle} forgotten") - self._clean_on_scheduler(shuffle_id) + + if shuffle := self.active_shuffles.get(shuffle_id): + self._fail_on_workers(shuffle, message=f"{shuffle} forgotten") + self._clean_on_scheduler(shuffle_id, stimulus_id=stimulus_id) + + if finish == "forgotten": + shuffles = self._shuffles.pop(shuffle_id, set()) + for shuffle in shuffles: + if shuffle._archived_by: + archived = self._archived_by_stimulus[shuffle._archived_by] + archived.remove(shuffle) + if not archived: + del self._archived_by_stimulus[shuffle._archived_by] def _fail_on_workers(self, shuffle: ShuffleState, message: str) -> None: worker_msgs = { @@ -373,9 +430,12 @@ def _fail_on_workers(self, shuffle: ShuffleState, message: str) -> None: } self.scheduler.send_all({}, worker_msgs) - def _clean_on_scheduler(self, id: ShuffleId) -> None: - del self.states[id] - self.erred_shuffles.pop(id, None) + def _clean_on_scheduler(self, id: ShuffleId, stimulus_id: str | None) -> None: + shuffle = self.active_shuffles.pop(id) + if not shuffle._archived_by and stimulus_id: + shuffle._archived_by = stimulus_id + self._archived_by_stimulus[stimulus_id].add(shuffle) + with contextlib.suppress(KeyError): del self.heartbeats[id] @@ -384,9 +444,10 @@ def _clean_on_scheduler(self, id: ShuffleId) -> None: self._unset_restriction(dt) def restart(self, scheduler: Scheduler) -> None: - self.states.clear() + self.active_shuffles.clear() self.heartbeats.clear() - self.erred_shuffles.clear() + self._shuffles.clear() + self._archived_by_stimulus.clear() def get_worker_for_range_sharding( diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index ed9c22bb07..62bb13c90b 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -11,6 +11,7 @@ from distributed.exceptions import Reschedule from distributed.shuffle._arrow import check_dtype_support, check_minimal_arrow_version +from distributed.shuffle._exceptions import ShuffleClosedError logger = logging.getLogger("distributed.shuffle") if TYPE_CHECKING: @@ -69,6 +70,8 @@ def shuffle_transfer( column=column, parts_out=parts_out, ) + except ShuffleClosedError: + raise Reschedule() except Exception as e: raise RuntimeError(f"shuffle_transfer failed during shuffle {id}") from e @@ -82,6 +85,8 @@ def shuffle_unpack( ) except Reschedule as e: raise e + except ShuffleClosedError: + raise Reschedule() except Exception as e: raise RuntimeError(f"shuffle_unpack failed during shuffle {id}") from e diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index 8a3ab7ba72..0f2fcf415e 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -29,6 +29,7 @@ ) from distributed.shuffle._comms import CommShardsBuffer from distributed.shuffle._disk import DiskShardsBuffer +from distributed.shuffle._exceptions import ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter from distributed.shuffle._rechunk import ChunkedAxes, NDIndex, split_axes from distributed.shuffle._shuffle import ShuffleId, ShuffleType @@ -50,10 +51,6 @@ logger = logging.getLogger(__name__) -class ShuffleClosedError(RuntimeError): - pass - - class ShuffleRun(Generic[T_partition_id, T_partition_type]): def __init__( self, @@ -577,6 +574,7 @@ class ShuffleWorkerPlugin(WorkerPlugin): worker: Worker shuffles: dict[ShuffleId, ShuffleRun] _runs: set[ShuffleRun] + _runs_cleanup_condition: asyncio.Condition memory_limiter_comms: ResourceLimiter memory_limiter_disk: ResourceLimiter closed: bool @@ -592,6 +590,7 @@ def setup(self, worker: Worker) -> None: self.worker = worker self.shuffles = {} self._runs = set() + self._runs_cleanup_condition = asyncio.Condition() self.memory_limiter_comms = ResourceLimiter(parse_bytes("100 MiB")) self.memory_limiter_disk = ResourceLimiter(parse_bytes("1 GiB")) self.closed = False @@ -632,6 +631,12 @@ async def shuffle_inputs_done(self, shuffle_id: ShuffleId, run_id: int) -> None: shuffle = await self._get_shuffle_run(shuffle_id, run_id) await shuffle.inputs_done() + async def _close_shuffle_run(self, shuffle: ShuffleRun) -> None: + await shuffle.close() + async with self._runs_cleanup_condition: + self._runs.remove(shuffle) + self._runs_cleanup_condition.notify_all() + def shuffle_fail(self, shuffle_id: ShuffleId, run_id: int, message: str) -> None: """Fails the shuffle run with the message as exception and triggers cleanup. @@ -648,11 +653,9 @@ def shuffle_fail(self, shuffle_id: ShuffleId, run_id: int, message: str) -> None exception = RuntimeError(message) shuffle.fail(exception) - async def _(extension: ShuffleWorkerPlugin, shuffle: ShuffleRun) -> None: - await shuffle.close() - extension._runs.remove(shuffle) - - self.worker._ongoing_background_tasks.call_soon(_, self, shuffle) + self.worker._ongoing_background_tasks.call_soon( + self._close_shuffle_run, shuffle + ) def add_partition( self, @@ -726,6 +729,7 @@ async def _get_or_create_shuffle( self, shuffle_id: ShuffleId, type: ShuffleType, + key: str, **kwargs: Any, ) -> ShuffleRun: """Get or create a shuffle matching the ID and data spec. @@ -736,12 +740,15 @@ async def _get_or_create_shuffle( Unique identifier of the shuffle type: Type of the shuffle operation + key: + Task key triggering the function """ shuffle = self.shuffles.get(shuffle_id, None) if shuffle is None: shuffle = await self._refresh_shuffle( shuffle_id=shuffle_id, type=type, + key=key, kwargs=kwargs, ) @@ -763,6 +770,7 @@ async def _refresh_shuffle( self, shuffle_id: ShuffleId, type: ShuffleType, + key: str, kwargs: dict, ) -> ShuffleRun: ... @@ -771,8 +779,10 @@ async def _refresh_shuffle( self, shuffle_id: ShuffleId, type: ShuffleType | None = None, + key: str | None = None, kwargs: dict | None = None, ) -> ShuffleRun: + result: dict[str, Any] if type is None: result = await self.worker.scheduler.shuffle_get( id=shuffle_id, @@ -782,6 +792,7 @@ async def _refresh_shuffle( assert kwargs is not None result = await self.worker.scheduler.shuffle_get_or_create( id=shuffle_id, + key=key, type=type, spec={ "npartitions": kwargs["npartitions"], @@ -794,6 +805,7 @@ async def _refresh_shuffle( assert kwargs is not None result = await self.worker.scheduler.shuffle_get_or_create( id=shuffle_id, + key=key, type=type, spec=kwargs, worker=self.worker.address, @@ -812,67 +824,94 @@ async def _refresh_shuffle( return existing else: self.shuffles.pop(shuffle_id) - existing.fail(RuntimeError("Stale Shuffle")) + existing.fail( + RuntimeError("{existing!r} stale, expected run_id=={run_id}") + ) async def _( extension: ShuffleWorkerPlugin, shuffle: ShuffleRun ) -> None: await shuffle.close() - extension._runs.remove(shuffle) + async with extension._runs_cleanup_condition: + extension._runs.remove(shuffle) + extension._runs_cleanup_condition.notify_all() self.worker._ongoing_background_tasks.call_soon(_, self, existing) + + shuffle = self._create_shuffle_run(shuffle_id, result) + self.shuffles[shuffle_id] = shuffle + self._runs.add(shuffle) + return shuffle + + def _create_shuffle_run( + self, shuffle_id: ShuffleId, result: dict[str, Any] + ) -> ShuffleRun: shuffle: ShuffleRun if result["type"] == ShuffleType.DATAFRAME: - shuffle = DataFrameShuffleRun( - column=result["column"], - worker_for=result["worker_for"], - output_workers=result["output_workers"], - id=shuffle_id, - run_id=result["run_id"], - directory=os.path.join( - self.worker.local_directory, - f"shuffle-{shuffle_id}-{result['run_id']}", - ), - executor=self._executor, - local_address=self.worker.address, - rpc=self.worker.rpc, - scheduler=self.worker.scheduler, - memory_limiter_disk=self.memory_limiter_disk, - memory_limiter_comms=self.memory_limiter_comms, - ) + shuffle = self._create_dataframe_shuffle_run(shuffle_id, result) elif result["type"] == ShuffleType.ARRAY_RECHUNK: - shuffle = ArrayRechunkRun( - worker_for=result["worker_for"], - output_workers=result["output_workers"], - old=result["old"], - new=result["new"], - id=shuffle_id, - run_id=result["run_id"], - directory=os.path.join( - self.worker.local_directory, - f"shuffle-{shuffle_id}-{result['run_id']}", - ), - executor=self._executor, - local_address=self.worker.address, - rpc=self.worker.rpc, - scheduler=self.worker.scheduler, - memory_limiter_disk=self.memory_limiter_disk, - memory_limiter_comms=self.memory_limiter_comms, - ) + shuffle = self._create_array_rechunk_run(shuffle_id, result) else: # pragma: no cover raise TypeError(result["type"]) - self.shuffles[shuffle_id] = shuffle - self._runs.add(shuffle) return shuffle + def _create_dataframe_shuffle_run( + self, shuffle_id: ShuffleId, result: dict[str, Any] + ) -> DataFrameShuffleRun: + return DataFrameShuffleRun( + column=result["column"], + worker_for=result["worker_for"], + output_workers=result["output_workers"], + id=shuffle_id, + run_id=result["run_id"], + directory=os.path.join( + self.worker.local_directory, + f"shuffle-{shuffle_id}-{result['run_id']}", + ), + executor=self._executor, + local_address=self.worker.address, + rpc=self.worker.rpc, + scheduler=self.worker.scheduler, + memory_limiter_disk=self.memory_limiter_disk, + memory_limiter_comms=self.memory_limiter_comms, + ) + + def _create_array_rechunk_run( + self, shuffle_id: ShuffleId, result: dict[str, Any] + ) -> ArrayRechunkRun: + return ArrayRechunkRun( + worker_for=result["worker_for"], + output_workers=result["output_workers"], + old=result["old"], + new=result["new"], + id=shuffle_id, + run_id=result["run_id"], + directory=os.path.join( + self.worker.local_directory, + f"shuffle-{shuffle_id}-{result['run_id']}", + ), + executor=self._executor, + local_address=self.worker.address, + rpc=self.worker.rpc, + scheduler=self.worker.scheduler, + memory_limiter_disk=self.memory_limiter_disk, + memory_limiter_comms=self.memory_limiter_comms, + ) + async def teardown(self, worker: Worker) -> None: assert not self.closed self.closed = True + while self.shuffles: _, shuffle = self.shuffles.popitem() - await shuffle.close() - self._runs.remove(shuffle) + self.worker._ongoing_background_tasks.call_soon( + self._close_shuffle_run, shuffle + ) + + async with self._runs_cleanup_condition: + await self._runs_cleanup_condition.wait_for(lambda: not self._runs) + try: self._executor.shutdown(cancel_futures=True) except Exception: # pragma: no cover @@ -904,11 +943,13 @@ def get_or_create_shuffle( type: ShuffleType, **kwargs: Any, ) -> ShuffleRun: + key = thread_state.key return sync( self.worker.loop, self._get_or_create_shuffle, shuffle_id, type, + key, **kwargs, ) @@ -920,7 +961,7 @@ def get_output_partition( meta: pd.DataFrame | None = None, ) -> Any: """ - Task: Retrieve a shuffled output partition from the ShuffleExtension. + Task: Retrieve a shuffled output partition from the ShuffleWorkerPlugin. Calling this for a ``shuffle_id`` which is unknown or incomplete is an error. """ diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 664b31d011..e86f06eb83 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -9,7 +9,6 @@ from collections import defaultdict from collections.abc import Mapping from concurrent.futures import ThreadPoolExecutor -from contextlib import AsyncExitStack from itertools import count from typing import Any from unittest import mock @@ -24,8 +23,7 @@ from dask.utils import stringify from distributed.client import Client -from distributed.diagnostics.plugin import SchedulerPlugin -from distributed.scheduler import Scheduler +from distributed.scheduler import KilledWorker, Scheduler from distributed.scheduler import TaskState as SchedulerTaskState from distributed.shuffle._arrow import serialize_table from distributed.shuffle._limiter import ResourceLimiter @@ -49,6 +47,7 @@ ) from distributed.utils import Deadline from distributed.utils_test import ( + async_poll_for, cluster, gen_cluster, gen_test, @@ -72,7 +71,7 @@ async def check_worker_cleanup( worker: Worker, closed: bool = False, interval: float = 0.01, - timeout: int | None = None, + timeout: int | None = 5, ) -> None: """Assert that the worker has no shuffle state""" deadline = Deadline.after(timeout) @@ -91,15 +90,17 @@ async def check_worker_cleanup( async def check_scheduler_cleanup( - scheduler: Scheduler, interval: float = 0.01, timeout: int | None = None + scheduler: Scheduler, interval: float = 0.01, timeout: int | None = 5 ) -> None: """Assert that the scheduler has no shuffle state""" deadline = Deadline.after(timeout) plugin = scheduler.plugins["shuffle"] assert isinstance(plugin, ShuffleSchedulerPlugin) - while plugin.states and not deadline.expired: + while plugin._shuffles and not deadline.expired: await asyncio.sleep(interval) - assert not plugin.states + assert not plugin.active_shuffles + assert not plugin._shuffles, scheduler.tasks + assert not plugin._archived_by_stimulus assert not plugin.heartbeats @@ -300,12 +301,13 @@ async def test_closed_worker_during_transfer(c, s, a, b): freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) await b.close() - with pytest.raises(RuntimeError): - out = await c.compute(out) + x = await x + y = await y + assert x == y await c.close() await check_worker_cleanup(a) @@ -313,6 +315,85 @@ async def test_closed_worker_during_transfer(c, s, a, b): await check_scheduler_cleanup(s) +@gen_cluster( + client=True, + nthreads=[("", 1)] * 2, + config={"distributed.scheduler.allowed-failures": 0}, +) +async def test_restarting_during_transfer_raises_killed_worker(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = c.compute(out.x.size) + await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) + await b.close() + + with pytest.raises(KilledWorker): + await out + + await c.close() + await check_worker_cleanup(a) + await check_worker_cleanup(b, closed=True) + await check_scheduler_cleanup(s) + + +class BlockedGetOrCreateWorkerPlugin(ShuffleWorkerPlugin): + def setup(self, worker: Worker) -> None: + super().setup(worker) + self.in_get_or_create = asyncio.Event() + self.block_get_or_create = asyncio.Event() + + async def _get_or_create_shuffle(self, *args, **kwargs): + self.in_get_or_create.set() + await self.block_get_or_create.wait() + return await super()._get_or_create_shuffle(*args, **kwargs) + + +@gen_cluster( + client=True, + nthreads=[("", 1)] * 2, + config={"distributed.scheduler.allowed-failures": 0}, +) +async def test_get_or_create_from_dangling_transfer(c, s, a, b): + await c.register_worker_plugin(BlockedGetOrCreateWorkerPlugin(), name="shuffle") + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = c.compute(out.x.size) + + shuffle_extA = a.plugins["shuffle"] + shuffle_extB = b.plugins["shuffle"] + shuffle_extB.block_get_or_create.set() + + await shuffle_extA.in_get_or_create.wait() + await b.close() + await async_poll_for( + lambda: not any(ws.processing for ws in s.workers.values()), timeout=5 + ) + + with pytest.raises(KilledWorker): + await out + + await async_poll_for(lambda: not s.plugins["shuffle"].active_shuffles, timeout=5) + assert a.state.tasks + shuffle_extA.block_get_or_create.set() + await async_poll_for(lambda: not a.state.tasks, timeout=10) + + assert not s.plugins["shuffle"].active_shuffles + await check_worker_cleanup(a) + await check_worker_cleanup(b, closed=True) + await c.close() + await check_scheduler_cleanup(s) + + @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)]) async def test_crashed_worker_during_transfer(c, s, a): @@ -325,21 +406,21 @@ async def test_crashed_worker_during_transfer(c, s, a): freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) await wait_until_worker_has_tasks( "shuffle-transfer", killed_worker_address, 1, s ) await n.process.process.kill() - with pytest.raises(RuntimeError): - out = await c.compute(out) + x = await x + y = await y + assert x == y await c.close() await check_worker_cleanup(a) await check_scheduler_cleanup(s) -# TODO: Deduplicate instead of failing: distributed#7324 @gen_cluster(client=True, nthreads=[("", 1)] * 2) async def test_closed_input_only_worker_during_transfer(c, s, a, b): def mock_get_worker_for_range_sharding( @@ -358,12 +439,13 @@ def mock_get_worker_for_range_sharding( freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b, 0.001) await b.close() - with pytest.raises(RuntimeError): - out = await c.compute(out) + x = await x + y = await y + assert x == y await c.close() await check_worker_cleanup(a) @@ -371,7 +453,6 @@ def mock_get_worker_for_range_sharding( await check_scheduler_cleanup(s) -# TODO: Deduplicate instead of failing: distributed#7324 @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)], clean_kwargs={"processes": False}) async def test_crashed_input_only_worker_during_transfer(c, s, a): @@ -393,14 +474,15 @@ def mock_mock_get_worker_for_range_sharding( freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) await wait_until_worker_has_tasks( "shuffle-transfer", n.worker_address, 1, s ) await n.process.process.kill() - with pytest.raises(RuntimeError): - out = await c.compute(out) + x = await x + y = await y + assert x == y await c.close() await check_worker_cleanup(a) @@ -457,7 +539,7 @@ async def test_closed_worker_during_barrier(c, s, a, b): freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) shuffle_id = await wait_until_new_shuffle_is_initialized(s) key = barrier_key(shuffle_id) await wait_for_state(key, "processing", s) @@ -478,9 +560,72 @@ async def test_closed_worker_during_barrier(c, s, a, b): await close_worker.close() alive_shuffle.block_inputs_done.set() + alive_shuffles = alive_worker.extensions["shuffle"].shuffles - with pytest.raises(RuntimeError): - out = await c.compute(out) + def shuffle_restarted(): + try: + return alive_shuffles[shuffle_id].run_id > alive_shuffle.run_id + except KeyError: + return False + + await async_poll_for( + shuffle_restarted, + timeout=5, + ) + restarted_shuffle = alive_shuffles[shuffle_id] + restarted_shuffle.block_inputs_done.set() + + x = await x + y = await y + assert x == y + + await c.close() + await check_worker_cleanup(close_worker, closed=True) + await check_worker_cleanup(alive_worker) + await check_scheduler_cleanup(s) + + +@mock.patch( + "distributed.shuffle._worker_plugin.DataFrameShuffleRun", + BlockedInputsDoneShuffle, +) +@gen_cluster( + client=True, + nthreads=[("", 1)] * 2, + config={"distributed.scheduler.allowed-failures": 0}, +) +async def test_restarting_during_barrier_raises_killed_worker(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = c.compute(out.x.size) + shuffle_id = await wait_until_new_shuffle_is_initialized(s) + key = barrier_key(shuffle_id) + await wait_for_state(key, "processing", s) + shuffleA = get_shuffle_run_from_worker(shuffle_id, a) + shuffleB = get_shuffle_run_from_worker(shuffle_id, b) + await shuffleA.in_inputs_done.wait() + await shuffleB.in_inputs_done.wait() + + ts = s.tasks[key] + processing_worker = a if ts.processing_on.address == a.address else b + if processing_worker == a: + close_worker, alive_worker = a, b + alive_shuffle = shuffleB + + else: + close_worker, alive_worker = b, a + alive_shuffle = shuffleA + await close_worker.close() + + with pytest.raises(KilledWorker): + await out + + alive_shuffle.block_inputs_done.set() await c.close() await check_worker_cleanup(close_worker, closed=True) @@ -501,7 +646,7 @@ async def test_closed_other_worker_during_barrier(c, s, a, b): freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) shuffle_id = await wait_until_new_shuffle_is_initialized(s) key = barrier_key(shuffle_id) @@ -524,9 +669,24 @@ async def test_closed_other_worker_during_barrier(c, s, a, b): await close_worker.close() alive_shuffle.block_inputs_done.set() + alive_shuffles = alive_worker.extensions["shuffle"].shuffles - with pytest.raises(RuntimeError, match="shuffle_barrier failed"): - out = await c.compute(out) + def shuffle_restarted(): + try: + return alive_shuffles[shuffle_id].run_id > alive_shuffle.run_id + except KeyError: + return False + + await async_poll_for( + shuffle_restarted, + timeout=5, + ) + restarted_shuffle = alive_shuffles[shuffle_id] + restarted_shuffle.block_inputs_done.set() + + x = await x + y = await y + assert x == y await c.close() await check_worker_cleanup(close_worker, closed=True) @@ -549,20 +709,34 @@ async def test_crashed_other_worker_during_barrier(c, s, a): freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) shuffle_id = await wait_until_new_shuffle_is_initialized(s) key = barrier_key(shuffle_id) # Ensure that barrier is not executed on the nanny s.set_restrictions({key: {a.address}}) await wait_for_state(key, "processing", s, interval=0) - + shuffles = a.extensions["shuffle"].shuffles shuffle = get_shuffle_run_from_worker(shuffle_id, a) await shuffle.in_inputs_done.wait() await n.process.process.kill() shuffle.block_inputs_done.set() - with pytest.raises(RuntimeError, match="shuffle"): - out = await c.compute(out) + def shuffle_restarted(): + try: + return shuffles[shuffle_id].run_id > shuffle.run_id + except KeyError: + return False + + await async_poll_for( + shuffle_restarted, + timeout=5, + ) + restarted_shuffle = get_shuffle_run_from_worker(shuffle_id, a) + restarted_shuffle.block_inputs_done.set() + + x = await x + y = await y + assert x == y await c.close() await check_worker_cleanup(a) @@ -578,12 +752,39 @@ async def test_closed_worker_during_unpack(c, s, a, b): freq="10 s", ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + x, y = c.compute([df.x.size, out.x.size]) await wait_for_tasks_in_state("shuffle-p2p", "memory", 1, b) await b.close() - with pytest.raises(RuntimeError): - out = await c.compute(out) + x = await x + y = await y + assert x == y + + await c.close() + await check_worker_cleanup(a) + await check_worker_cleanup(b, closed=True) + await check_scheduler_cleanup(s) + + +@gen_cluster( + client=True, + nthreads=[("", 1)] * 2, + config={"distributed.scheduler.allowed-failures": 0}, +) +async def test_restarting_during_unpack_raises_killed_worker(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = c.compute(out.x.size) + await wait_for_tasks_in_state("shuffle-p2p", "memory", 1, b) + await b.close() + + with pytest.raises(KilledWorker): + await out await c.close() await check_worker_cleanup(a) @@ -602,14 +803,15 @@ async def test_crashed_worker_during_unpack(c, s, a): dtypes={"x": float, "y": float}, freq="10 s", ) + x = await c.compute(df.x.size) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() + y = c.compute(out.x.size) + await wait_until_worker_has_tasks("shuffle-p2p", killed_worker_address, 1, s) await n.process.process.kill() - with pytest.raises( - RuntimeError, - ): - out = await c.compute(out) + + y = await y + assert x == y await c.close() await check_worker_cleanup(a) @@ -857,9 +1059,9 @@ async def test_clean_after_forgotten_early(c, s, a, b): await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, a) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) del out - await check_worker_cleanup(a, timeout=2) - await check_worker_cleanup(b, timeout=2) - await check_scheduler_cleanup(s, timeout=2) + await check_worker_cleanup(a) + await check_worker_cleanup(b) + await check_scheduler_cleanup(s) @gen_cluster(client=True) @@ -910,9 +1112,9 @@ async def test_repeat_shuffle_instance(c, s, a, b, wait_until_forgotten): await c.compute(out) - await check_worker_cleanup(a, timeout=2) - await check_worker_cleanup(b, timeout=2) - await check_scheduler_cleanup(s, timeout=2) + await check_worker_cleanup(a) + await check_worker_cleanup(b) + await check_scheduler_cleanup(s) @pytest.mark.parametrize("wait_until_forgotten", [True, False]) @@ -939,9 +1141,9 @@ async def test_repeat_shuffle_operation(c, s, a, b, wait_until_forgotten): await c.compute(dd.shuffle.shuffle(df, "x", shuffle="p2p")) - await check_worker_cleanup(a, timeout=2) - await check_worker_cleanup(b, timeout=2) - await check_scheduler_cleanup(s, timeout=2) + await check_worker_cleanup(a) + await check_worker_cleanup(b) + await check_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -1056,7 +1258,7 @@ async def test_new_worker(c, s, a, b): ) shuffled = dd.shuffle.shuffle(df, "x", shuffle="p2p") persisted = shuffled.persist() - while not s.plugins["shuffle"].states: + while not s.plugins["shuffle"].active_shuffles: await asyncio.sleep(0.001) async with Worker(s.address) as w: @@ -1131,12 +1333,12 @@ async def test_delete_some_results(c, s, a, b): while not s.tasks or not any(ts.state == "memory" for ts in s.tasks.values()): await asyncio.sleep(0.01) - x = x.partitions[: x.npartitions // 2].persist() + x = x.partitions[: x.npartitions // 2] + x = await c.compute(x.size) - await c.compute(x.size) - del x await check_worker_cleanup(a) await check_worker_cleanup(b) + del x await check_scheduler_cleanup(s) @@ -1515,9 +1717,9 @@ async def test_deduplicate_stale_transfer(c, s, a, b, wait_until_forgotten): y = await c.compute(df.x.size) assert x == y - await check_worker_cleanup(a, timeout=2) - await check_worker_cleanup(b, timeout=2) - await check_scheduler_cleanup(s, timeout=2) + await check_worker_cleanup(a) + await check_worker_cleanup(b) + await check_scheduler_cleanup(s) class BlockedBarrierShuffleWorkerPlugin(ShuffleWorkerPlugin): @@ -1571,9 +1773,9 @@ async def test_handle_stale_barrier(c, s, a, b, wait_until_forgotten): y = await y assert x == y - await check_worker_cleanup(a, timeout=2) - await check_worker_cleanup(b, timeout=2) - await check_scheduler_cleanup(s, timeout=2) + await check_worker_cleanup(a) + await check_worker_cleanup(b) + await check_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -1643,9 +1845,29 @@ async def test_shuffle_run_consistency(c, s, a): worker_plugin.block_barrier.set() await out del out + while s.tasks: + await asyncio.sleep(0) + worker_plugin.block_barrier.clear() - await check_worker_cleanup(a, timeout=2) - await check_scheduler_cleanup(s, timeout=2) + out = dd.shuffle.shuffle(df, "y", shuffle="p2p") + out = out.persist() + independent_shuffle_id = await wait_until_new_shuffle_is_initialized(s) + assert shuffle_id != independent_shuffle_id + + independent_shuffle_dict = scheduler_ext.get( + independent_shuffle_id, a.worker_address + ) + + # Check invariant that the new run ID is larger than the previous + # for independent shuffles + assert new_shuffle_dict["run_id"] < independent_shuffle_dict["run_id"] + + worker_plugin.block_barrier.set() + await out + del out + + await check_worker_cleanup(a) + await check_scheduler_cleanup(s) class BlockedShuffleAccessAndFailWorkerPlugin(ShuffleWorkerPlugin): @@ -1748,94 +1970,6 @@ async def test_replace_stale_shuffle(c, s, a, b): await check_scheduler_cleanup(s) -class BlockedRemoveWorkerSchedulerPlugin(SchedulerPlugin): - def __init__(self, scheduler: Scheduler, *args: Any, **kwargs: Any): - self.scheduler = scheduler - super().__init__(*args, **kwargs) - self.in_remove_worker = asyncio.Event() - self.block_remove_worker = asyncio.Event() - self.scheduler.add_plugin(self, name="blocking") - - async def remove_worker(self, *args: Any, **kwargs: Any) -> None: - self.in_remove_worker.set() - await self.block_remove_worker.wait() - - -class BlockedBarrierSchedulerPlugin(ShuffleSchedulerPlugin): - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - self.in_barrier = asyncio.Event() - self.block_barrier = asyncio.Event() - - async def barrier(self, *args: Any, **kwargs: Any) -> None: - self.in_barrier.set() - await self.block_barrier.wait() - await super().barrier(*args, **kwargs) - - -@gen_cluster( - client=True, - nthreads=[], - scheduler_kwargs={ - "extensions": { - "blocking": BlockedRemoveWorkerSchedulerPlugin, - "shuffle": BlockedBarrierSchedulerPlugin, - } - }, -) -async def test_closed_worker_returns_before_barrier(c, s): - async with AsyncExitStack() as stack: - workers = [await stack.enter_async_context(Worker(s.address)) for _ in range(2)] - - df = dask.datasets.timeseries( - start="2000-01-01", - end="2000-01-10", - dtypes={"x": float, "y": float}, - freq="10 s", - ) - out = dd.shuffle.shuffle(df, "x", shuffle="p2p") - out = out.persist() - shuffle_id = await wait_until_new_shuffle_is_initialized(s) - key = barrier_key(shuffle_id) - await wait_for_state(key, "processing", s) - scheduler_plugin = s.plugins["shuffle"] - await scheduler_plugin.in_barrier.wait() - - flushes = [ - get_shuffle_run_from_worker(shuffle_id, w)._flush_comm() for w in workers - ] - await asyncio.gather(*flushes) - - ts = s.tasks[key] - to_close = None - for worker in workers: - if ts.processing_on.address != worker.address: - to_close = worker - break - assert to_close - closed_port = to_close.port - await to_close.close() - - blocking_plugin = s.plugins["blocking"] - assert blocking_plugin.in_remove_worker.is_set() - - workers.append( - await stack.enter_async_context(Worker(s.address, port=closed_port)) - ) - - scheduler_plugin.block_barrier.set() - - with pytest.raises( - RuntimeError, match=f"shuffle_barrier failed .* {shuffle_id}" - ): - await c.compute(out.x.size) - - blocking_plugin.block_remove_worker.set() - await c.close() - await asyncio.gather(*[check_worker_cleanup(w) for w in workers]) - await check_scheduler_cleanup(s) - - @gen_cluster(client=True) async def test_handle_null_partitions_p2p_shuffling(c, s, *workers): data = [