From 7ebd1d99240448eebc346370ce73dece78c92e0d Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Thu, 23 Jun 2022 16:18:39 -0600 Subject: [PATCH] Fix co-assignment for binary operations Bit of a hack, but closes https://github.com/dask/distributed/issues/6597. I'd like to have a better metric for the batch size, but I think this is about as good as we can get. Any reasonably large number will do here. --- distributed/scheduler.py | 31 +++++++++-------------- distributed/tests/test_scheduler.py | 39 ++++++++++++++++------------- 2 files changed, 33 insertions(+), 37 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index adc3af8194..f039348533 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -804,14 +804,6 @@ class TaskGroup: #: The result types of this TaskGroup types: set[str] - #: The worker most recently assigned a task from this group, or None when the group - #: is not identified to be root-like by `SchedulerState.decide_worker`. - last_worker: WorkerState | None - - #: If `last_worker` is not None, the number of times that worker should be assigned - #: subsequent tasks until a new worker is chosen. - last_worker_tasks_left: int - prefix: TaskPrefix | None start: float stop: float @@ -831,8 +823,6 @@ def __init__(self, name: str): self.start = 0.0 self.stop = 0.0 self.all_durations = defaultdict(float) - self.last_worker = None - self.last_worker_tasks_left = 0 def add_duration(self, action: str, start: float, stop: float) -> None: duration = stop - start @@ -1269,6 +1259,8 @@ class SchedulerState: "extensions", "host_info", "idle", + "last_root_worker", + "last_root_worker_tasks_left", "n_tasks", "queued", "resources", @@ -1337,6 +1329,8 @@ def __init__( self.total_nthreads = 0 self.total_occupancy = 0.0 self.unknown_durations: dict[str, set[TaskState]] = {} + self.last_root_worker: WorkerState | None = None + self.last_root_worker_tasks_left: int = 0 self.queued = queued self.unrunnable = unrunnable self.validate = validate @@ -1807,25 +1801,24 @@ def decide_worker( recommendations[ts.key] = "no-worker" return None - lws = tg.last_worker + lws = self.last_root_worker if not ( lws - and tg.last_worker_tasks_left + and self.last_root_worker_tasks_left and self.workers.get(lws.address) is lws ): # Last-used worker is full or unknown; pick a new worker for the next few tasks - ws = min(pool, key=partial(self.worker_objective, ts)) - tg.last_worker_tasks_left = math.floor( + ws = self.last_root_worker = min( + pool, key=lambda ws: len(ws.processing) / ws.nthreads + ) + # TODO better batching metric (`len(tg)` is not necessarily the total number of root tasks!) 
+ self.last_root_worker_tasks_left = math.floor( (len(tg) / self.total_nthreads) * ws.nthreads ) else: ws = lws - # Record `last_worker`, or clear it on the final task - tg.last_worker = ( - ws if tg.states["released"] + tg.states["waiting"] > 1 else None - ) - tg.last_worker_tasks_left -= 1 + self.last_root_worker_tasks_left -= 1 return ws if not self.idle: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 509cad3712..a4fbf6242c 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -134,7 +134,6 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c): assert x.key in a.data or x.key in b.data -# @pytest.mark.skip("Current queuing does not support co-assignment") @pytest.mark.parametrize("ndeps", [0, 1, 4]) @pytest.mark.parametrize( "nthreads", @@ -151,10 +150,6 @@ def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads): "distributed.scheduler.work-stealing": False, "distributed.scheduler.worker-saturation": float("inf"), }, - scheduler_kwargs=dict( # TODO remove - dashboard=True, - dashboard_address=":8787", - ), ) async def test_decide_worker_coschedule_order_neighbors_(c, s, *workers): r""" @@ -254,6 +249,24 @@ def random(**kwargs): test_decide_worker_coschedule_order_neighbors_() +@pytest.mark.parametrize("ngroups", [1, 2, 3, 5]) +@gen_cluster( + client=True, + nthreads=[("", 1), ("", 1)], + config={ + "distributed.scheduler.worker-saturation": float("inf"), + }, +) +async def test_decide_worker_coschedule_order_binary_op(c, s, a, b, ngroups): + roots = [[delayed(i, name=f"x-{n}-{i}") for i in range(8)] for n in range(ngroups)] + zs = [sum(rs) for rs in zip(*roots)] + + await c.gather(c.compute(zs)) + + assert not a.incoming_transfer_log, [l["keys"] for l in a.incoming_transfer_log] + assert not b.incoming_transfer_log, [l["keys"] for l in b.incoming_transfer_log] + + @pytest.mark.slow @gen_cluster( client=True, @@ -381,17 +394,7 @@ async def _test_saturation_factor(c, s, a, b): @pytest.mark.skip("Current queuing does not support co-assignment") -@pytest.mark.parametrize( - "saturation_factor", - [ - 1.0, - 2.0, - pytest.param( - float("inf"), - marks=pytest.mark.skip("https://github.com/dask/distributed/issues/6597"), - ), - ], -) +@pytest.mark.parametrize("saturation_factor", [1.0, 2.0, float("inf")]) @gen_cluster( client=True, nthreads=[("", 2), ("", 1)], @@ -406,8 +409,8 @@ async def test_oversaturation_multiple_task_groups(c, s, a, b, saturation_factor assert not a.incoming_transfer_log, [l["keys"] for l in a.incoming_transfer_log] assert not b.incoming_transfer_log, [l["keys"] for l in b.incoming_transfer_log] - assert len(a.tasks) == 18 - assert len(b.tasks) == 9 + assert len(a.state.tasks) == 18 + assert len(b.state.tasks) == 9 @pytest.mark.slow
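
For reviewers: the scheduler change above replaces the per-TaskGroup `last_worker` / `last_worker_tasks_left` bookkeeping with a single scheduler-wide `last_root_worker` / `last_root_worker_tasks_left` pair, so interleaved root tasks from different groups (e.g. the two input collections of a binary operation, as in https://github.com/dask/distributed/issues/6597) keep landing on the same worker until its batch quota runs out. The standalone sketch below illustrates that batching behaviour under simplified assumptions; it is not the scheduler code itself — the Worker class and assign_root_tasks helper are invented for the example, and the batch size is computed from the full task list rather than per task group (the TODO above notes that `len(tg)` is only an approximation of the total number of root tasks).

import math
from dataclasses import dataclass, field


@dataclass
class Worker:
    """Minimal stand-in for distributed's WorkerState (illustrative only)."""

    name: str
    nthreads: int
    assigned: list = field(default_factory=list)


def assign_root_tasks(tasks, workers):
    """Sketch of the batched co-assignment used in this patch.

    A single ``last_worker`` plus a countdown of remaining slots is kept
    across *all* root tasks.  When the countdown reaches zero, the least
    loaded worker (fewest assigned tasks per thread) is picked for the
    next batch.
    """
    total_nthreads = sum(w.nthreads for w in workers)
    last_worker = None
    tasks_left = 0

    for task in tasks:
        if not (last_worker and tasks_left):
            # Pick a new worker for the next batch of root tasks.
            last_worker = min(
                workers, key=lambda w: len(w.assigned) / w.nthreads
            )
            # Batch size ~ this worker's fair share of the whole batch
            # (the patch uses len(tg) / total_nthreads * ws.nthreads).
            tasks_left = math.floor(
                len(tasks) / total_nthreads * last_worker.nthreads
            )
        last_worker.assigned.append(task)
        tasks_left -= 1
    return workers


for w in assign_root_tasks(
    [f"x-{i}" for i in range(8)], [Worker("a", 1), Worker("b", 1)]
):
    print(w.name, w.assigned)

Run as-is, the sketch assigns x-0 through x-3 to worker "a" and x-4 through x-7 to worker "b" — contiguous blocks rather than round-robin — which is the co-assignment property the new test asserts by checking that both workers' incoming_transfer_log stay empty.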