Skip to content

Commit

Permalink
Remove superfluous ShuffleSchedulerExtension.barriers (#7389)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Dec 12, 2022
1 parent 19deee3 commit c4af791
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions distributed/shuffle/_shuffle_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,6 @@ class ShuffleSchedulerExtension(SchedulerPlugin):
participating_workers: dict[ShuffleId, set[str]]
tombstones: set[ShuffleId]
erred_shuffles: dict[ShuffleId, Exception]
barriers: dict[ShuffleId, str]

def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
Expand All @@ -629,7 +628,6 @@ def __init__(self, scheduler: Scheduler):
self.participating_workers = {}
self.tombstones = set()
self.erred_shuffles = {}
self.barriers = {}
self.scheduler.add_plugin(self)

def shuffle_ids(self) -> set[ShuffleId]:
Expand All @@ -646,7 +644,7 @@ def barrier_key(cls, shuffle_id: ShuffleId) -> str:

@classmethod
def id_from_key(cls, key: str) -> ShuffleId:
assert "shuffle-barrier-" in key
assert key.startswith("shuffle-barrier-")
return ShuffleId(key.replace("shuffle-barrier-", ""))

def get(
Expand Down Expand Up @@ -674,7 +672,6 @@ def get(
output_workers = set()

name = self.barrier_key(id)
self.barriers[id] = name
mapping = {}

for ts in self.scheduler.tasks[name].dependents:
Expand Down Expand Up @@ -724,7 +721,7 @@ async def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
contact_workers = shuffle_workers.copy()
contact_workers.discard(worker)
affected_shuffles.add(shuffle_id)
name = self.barriers[shuffle_id]
name = self.barrier_key(shuffle_id)
barrier_task = self.scheduler.tasks.get(name)
if barrier_task:
barriers.append(barrier_task)
Expand Down Expand Up @@ -769,11 +766,12 @@ def transition(
) -> None:
if finish != "forgotten":
return
if key not in self.barriers.values():

if not key.startswith("shuffle-barrier-"):
return
shuffle_id = self.id_from_key(key)
if shuffle_id not in self.worker_for:
return

shuffle_id = ShuffleSchedulerExtension.id_from_key(key)
participating_workers = self.participating_workers[shuffle_id]
worker_msgs = {
worker: [
Expand Down Expand Up @@ -806,7 +804,6 @@ def _clean_on_scheduler(self, id: ShuffleId) -> None:
del self.completed_workers[id]
del self.participating_workers[id]
self.erred_shuffles.pop(id, None)
del self.barriers[id]
with contextlib.suppress(KeyError):
del self.heartbeats[id]

Expand Down

0 comments on commit c4af791

Please sign in to comment.