Restructure P2PShuffle extensions #7390

Merged · 5 commits · Dec 15, 2022

Changes from 2 commits
6 changes: 2 additions & 4 deletions distributed/shuffle/__init__.py
@@ -1,10 +1,8 @@
 from __future__ import annotations
 
+from distributed.shuffle._scheduler_extension import ShuffleSchedulerExtension
 from distributed.shuffle._shuffle import P2PShuffleLayer, rearrange_by_column_p2p
-from distributed.shuffle._shuffle_extension import (
-    ShuffleSchedulerExtension,
-    ShuffleWorkerExtension,
-)
+from distributed.shuffle._worker_extension import ShuffleWorkerExtension
 
 __all__ = [
     "P2PShuffleLayer",
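The upshot of this hunk is that both extension classes remain importable from the package root even though they now live in separate modules. A quick smoke check (a hypothetical snippet, assuming the post-PR layout is installed):

import distributed.shuffle as shuffle

# Both names should resolve from the package root exactly as before the split,
# but each now comes from its own module.
assert shuffle.ShuffleSchedulerExtension.__module__ == "distributed.shuffle._scheduler_extension"
assert shuffle.ShuffleWorkerExtension.__module__ == "distributed.shuffle._worker_extension"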
236 changes: 236 additions & 0 deletions distributed/shuffle/_scheduler_extension.py
@@ -0,0 +1,236 @@
from __future__ import annotations

import asyncio
import contextlib
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any

from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.shuffle._shuffle import P2PShuffleLayer, ShuffleId

if TYPE_CHECKING:
    from distributed.scheduler import Recs, Scheduler, TaskStateState, WorkerState

logger = logging.getLogger(__name__)


class ShuffleSchedulerExtension(SchedulerPlugin):
    """
    Shuffle extension for the scheduler

    Today this mostly just collects heartbeat messages for the dashboard,
    but in the future it may be responsible for more

    See Also
    --------
    ShuffleWorkerExtension
    """

    scheduler: Scheduler
    worker_for: dict[ShuffleId, dict[int, str]]
    heartbeats: defaultdict[ShuffleId, dict]
    schemas: dict[ShuffleId, bytes]
    columns: dict[ShuffleId, str]
    output_workers: dict[ShuffleId, set[str]]
    completed_workers: dict[ShuffleId, set[str]]
    participating_workers: dict[ShuffleId, set[str]]
    tombstones: set[ShuffleId]
    erred_shuffles: dict[ShuffleId, Exception]
    barriers: dict[ShuffleId, str]

    def __init__(self, scheduler: Scheduler):
        self.scheduler = scheduler
        self.scheduler.handlers.update(
            {
                "shuffle_get": self.get,
                "shuffle_get_participating_workers": self.get_participating_workers,
                "shuffle_register_complete": self.register_complete,
            }
        )
        self.heartbeats = defaultdict(lambda: defaultdict(dict))
        self.worker_for = {}
        self.schemas = {}
        self.columns = {}
        self.output_workers = {}
        self.completed_workers = {}
        self.participating_workers = {}
        self.tombstones = set()
        self.erred_shuffles = {}
        self.barriers = {}
        self.scheduler.add_plugin(self)

    def shuffle_ids(self) -> set[ShuffleId]:
        return set(self.worker_for)

    def heartbeat(self, ws: WorkerState, data: dict) -> None:
        for shuffle_id, d in data.items():
            if shuffle_id in self.output_workers:
                self.heartbeats[shuffle_id][ws.address].update(d)

    def get(
        self,
        id: ShuffleId,
        schema: bytes | None,
        column: str | None,
        npartitions: int | None,
        worker: str,
    ) -> dict:
        if id in self.tombstones:
            return {
                "status": "ERROR",
                "message": f"Shuffle {id} has already been forgotten",
            }
        if exception := self.erred_shuffles.get(id):
            return {"status": "ERROR", "message": str(exception)}

        if id not in self.worker_for:
            assert schema is not None
            assert column is not None
            assert npartitions is not None
            workers = list(self.scheduler.workers)
            output_workers = set()

            name = P2PShuffleLayer.barrier_key(id)
            self.barriers[id] = name
            mapping = {}

            for ts in self.scheduler.tasks[name].dependents:
                part = ts.annotations["shuffle"]
                if ts.worker_restrictions:
                    output_worker = list(ts.worker_restrictions)[0]
                else:
                    output_worker = get_worker_for(part, workers, npartitions)
                mapping[part] = output_worker
                output_workers.add(output_worker)
                self.scheduler.set_restrictions({ts.key: {output_worker}})

            self.worker_for[id] = mapping
            self.schemas[id] = schema
            self.columns[id] = column
            self.output_workers[id] = output_workers
            self.completed_workers[id] = set()
            self.participating_workers[id] = output_workers.copy()

        self.participating_workers[id].add(worker)
        return {
            "status": "OK",
            "worker_for": self.worker_for[id],
            "column": self.columns[id],
            "schema": self.schemas[id],
            "output_workers": self.output_workers[id],
        }

    def get_participating_workers(self, id: ShuffleId) -> list[str]:
        return list(self.participating_workers[id])

    async def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
        affected_shuffles = set()
        broadcasts = []
        from time import time

        recs: Recs = {}
        stimulus_id = f"shuffle-failed-worker-left-{time()}"
        barriers = []
        for shuffle_id, shuffle_workers in self.participating_workers.items():
            if worker not in shuffle_workers:
                continue
            exception = RuntimeError(
                f"Worker {worker} left during active shuffle {shuffle_id}"
            )
            self.erred_shuffles[shuffle_id] = exception
            contact_workers = shuffle_workers.copy()
            contact_workers.discard(worker)
            affected_shuffles.add(shuffle_id)
            name = self.barriers[shuffle_id]
            barrier_task = self.scheduler.tasks.get(name)
            if barrier_task:
                barriers.append(barrier_task)
                broadcasts.append(
                    scheduler.broadcast(
                        msg={
                            "op": "shuffle_fail",
                            "message": str(exception),
                            "shuffle_id": shuffle_id,
                        },
                        workers=list(contact_workers),
                    )
                )

        results = await asyncio.gather(*broadcasts, return_exceptions=True)
        for barrier_task in barriers:
            if barrier_task.state == "memory":
                for dt in barrier_task.dependents:
                    if worker not in dt.worker_restrictions:
                        continue
                    dt.worker_restrictions.clear()
                    recs.update({dt.key: "waiting"})
            # TODO: Do we need to handle other states?
        self.scheduler.transitions(recs, stimulus_id=stimulus_id)

        # Assumption: No new shuffle tasks scheduled on the worker
        # + no existing tasks anymore
        # All task-finished/task-erred are queued up in batched stream

        exceptions = [result for result in results if isinstance(result, Exception)]
        if exceptions:
            # TODO: Do we need to handle errors here?
            raise RuntimeError(exceptions)

    def transition(
        self,
        key: str,
        start: TaskStateState,
        finish: TaskStateState,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        if finish != "forgotten":
            return
        if key not in self.barriers.values():
            return

        shuffle_id = P2PShuffleLayer.id_from_key(key)
        participating_workers = self.participating_workers[shuffle_id]
        worker_msgs = {
            worker: [
                {
                    "op": "shuffle-fail",
                    "shuffle_id": shuffle_id,
                    "message": f"Shuffle {shuffle_id} forgotten",
                }
            ]
            for worker in participating_workers
        }
        self._clean_on_scheduler(shuffle_id)
        self.scheduler.send_all({}, worker_msgs)

    def register_complete(self, id: ShuffleId, worker: str) -> None:
        """Learn from a worker that it has completed all reads of a shuffle"""
        if exception := self.erred_shuffles.get(id):
            raise exception
        if id not in self.completed_workers:
            logger.info("Worker shuffle reported complete after shuffle was removed")
            return
        self.completed_workers[id].add(worker)

    def _clean_on_scheduler(self, id: ShuffleId) -> None:
        self.tombstones.add(id)
        del self.worker_for[id]
        del self.schemas[id]
        del self.columns[id]
        del self.output_workers[id]
        del self.completed_workers[id]
        del self.participating_workers[id]
        self.erred_shuffles.pop(id, None)
        del self.barriers[id]
        with contextlib.suppress(KeyError):
            del self.heartbeats[id]


def get_worker_for(output_partition: int, workers: list[str], npartitions: int) -> str:
    "Get the address of the worker which should hold this output partition number"
    i = len(workers) * output_partition // npartitions
    return workers[i]
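get_worker_for assigns output partitions to workers in contiguous, evenly sized blocks. A small worked example (a standalone copy of the function above; the worker addresses are made up):

def get_worker_for(output_partition: int, workers: list[str], npartitions: int) -> str:
    # len(workers) * output_partition // npartitions maps partition indices
    # onto worker indices in contiguous, evenly sized blocks.
    i = len(workers) * output_partition // npartitions
    return workers[i]

workers = ["tcp://a:1234", "tcp://b:1234"]
assert [get_worker_for(p, workers, npartitions=4) for p in range(4)] == [
    "tcp://a:1234",  # partitions 0 and 1 land on the first worker
    "tcp://a:1234",
    "tcp://b:1234",  # partitions 2 and 3 land on the second worker
    "tcp://b:1234",
]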
25 changes: 20 additions & 5 deletions distributed/shuffle/_shuffle.py
@@ -1,20 +1,23 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, ClassVar, NewType
 
 from dask.base import tokenize
 from dask.highlevelgraph import HighLevelGraph
 from dask.layers import SimpleShuffleLayer
 
-from distributed.shuffle._shuffle_extension import ShuffleId, ShuffleWorkerExtension
-
 logger = logging.getLogger("distributed.shuffle")
 
 if TYPE_CHECKING:
     import pandas as pd
 
     from dask.dataframe import DataFrame
 
+    # circular dependency
+    from distributed.shuffle._worker_extension import ShuffleWorkerExtension
+
+ShuffleId = NewType("ShuffleId", str)
@hendrikmakait (Member, Author) commented on Dec 12, 2022:
I've moved this into _shuffle now since the P2PShuffleLayer is responsible for all ID- or naming-related things. Some helpers from the ShuffleSchedulerExtension moved here as well, which allows us to rely on constants instead of hard-coded strings within methods.

Moving ShuffleId into __init__ would create some cyclic typing dependencies that I'm not too fond of.
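To illustrate the round-trip the new constants enable (a minimal sketch against the post-PR module layout; the shuffle ID value is made up):

from distributed.shuffle._shuffle import P2PShuffleLayer, ShuffleId

sid = ShuffleId("abc123")  # illustrative ID only
key = P2PShuffleLayer.barrier_key(sid)
assert key == "shuffle-barrier-abc123"
# id_from_key inverts barrier_key by stripping the shared prefix.
assert P2PShuffleLayer.id_from_key(key) == sid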



 def _get_worker_extension() -> ShuffleWorkerExtension:
     from distributed import get_worker
@@ -102,6 +105,9 @@ def rearrange_by_column_p2p(


 class P2PShuffleLayer(SimpleShuffleLayer):
+    _BARRIER_PREFIX: ClassVar[str] = "shuffle-barrier-"
+    _TRANSFER_PREFIX: ClassVar[str] = "shuffle-transfer-"
+
     def __init__(
         self,
         name: str,
@@ -132,6 +138,15 @@ def get_split_keys(self) -> list:
         # TODO: This is doing some funky stuff to set priorities but we don't need this
         return []
 
+    @classmethod
+    def barrier_key(cls, shuffle_id: ShuffleId) -> str:
+        return cls._BARRIER_PREFIX + shuffle_id
+
+    @classmethod
+    def id_from_key(cls, key: str) -> ShuffleId:
+        assert cls._BARRIER_PREFIX in key
+        return ShuffleId(key.replace(cls._BARRIER_PREFIX, ""))
+
     def __repr__(self) -> str:
         return (
             f"{type(self).__name__}<name='{self.name}', npartitions={self.npartitions}>"
@@ -152,8 +167,8 @@ def _cull(self, parts_out: list) -> P2PShuffleLayer:
     def _construct_graph(self, deserializing: Any = None) -> dict[tuple | str, tuple]:
         token = tokenize(self.name_input, self.column, self.npartitions, self.parts_out)
         dsk: dict[tuple | str, tuple] = {}
-        barrier_key = "shuffle-barrier-" + token
-        name = "shuffle-transfer-" + token
+        barrier_key = self._BARRIER_PREFIX + token
+        name = self._TRANSFER_PREFIX + token
         transfer_keys = list()
         for i in range(self.npartitions_input):
             transfer_keys.append((name, i))
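For context, these keys are created whenever a DataFrame shuffle takes the P2P path. A hypothetical end-to-end trigger (assumes a running distributed cluster and the shuffle="p2p" keyword as supported in this era of dask; not part of this PR's diff):

from dask.datasets import timeseries
from distributed import Client

client = Client()  # P2P shuffle runs against a distributed scheduler
df = timeseries()  # demo frame that includes an "id" column
# shuffle="p2p" routes through rearrange_by_column_p2p, which builds a
# P2PShuffleLayer whose graph uses the barrier/transfer key prefixes above.
shuffled = df.shuffle("id", shuffle="p2p")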