Skip to content

Commit

Permalink
Don't queue tasks on workers
Browse files Browse the repository at this point in the history
With this change, workers are not sent more tasks than they have threads. This eliminates root task overproduction and removes the need for work stealing.

This is optimized for a minimal diff. There is _lots_ we might change/rip out if we went forward with this.

Performance will likely be poor on some workloads without speculative task assignment and root task co-assignment.
  • Loading branch information
gjoseph92 committed Jun 16, 2022
1 parent cb88e3b commit d93fd4b
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 67 deletions.
11 changes: 9 additions & 2 deletions distributed/dashboard/components/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3096,9 +3096,13 @@ def __init__(self, scheduler, **kwargs):
<span style="font-size: 10px; font-family: Monaco, monospace;">@erred</span>
</div>
<div>
<span style="font-size: 14px; font-weight: bold;">Ready:</span>&nbsp;
<span style="font-size: 14px; font-weight: bold;">Processing:</span>&nbsp;
<span style="font-size: 10px; font-family: Monaco, monospace;">@processing</span>
</div>
<div>
<span style="font-size: 14px; font-weight: bold;">No worker:</span>&nbsp;
<span style="font-size: 10px; font-family: Monaco, monospace;">@no_worker</span>
</div>
""",
)
self.root.add_tools(hover)
Expand All @@ -3112,6 +3116,7 @@ def update(self):
"released": {},
"processing": {},
"waiting": {},
"no_worker": {},
}

for tp in self.scheduler.task_prefixes.values():
Expand All @@ -3122,6 +3127,7 @@ def update(self):
state["released"][tp.name] = active_states["released"]
state["processing"][tp.name] = active_states["processing"]
state["waiting"][tp.name] = active_states["waiting"]
state["no_worker"][tp.name] = active_states["no-worker"]

state["all"] = {k: sum(v[k] for v in state.values()) for k in state["memory"]}

Expand All @@ -3134,7 +3140,7 @@ def update(self):

totals = {
k: sum(state[k].values())
for k in ["all", "memory", "erred", "released", "waiting"]
for k in ["all", "memory", "erred", "released", "waiting", "no_worker"]
}
totals["processing"] = totals["all"] - sum(
v for k, v in totals.items() if k != "all"
Expand All @@ -3144,6 +3150,7 @@ def update(self):
"Progress -- total: %(all)s, "
"in-memory: %(memory)s, processing: %(processing)s, "
"waiting: %(waiting)s, "
"no worker: %(no_worker)s, "
"erred: %(erred)s" % totals
)

Expand Down
2 changes: 1 addition & 1 deletion distributed/distributed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ distributed:
idle-timeout: null # Shut down after this duration, like "1h" or "30 minutes"
transition-log-length: 100000
events-log-length: 100000
work-stealing: True # workers should steal tasks from each other
work-stealing: False # whether workers may steal queued tasks from each other
work-stealing-interval: 100ms # Callback time for work stealing
worker-ttl: "5 minutes" # like '60s'. Time to live for workers. They must heartbeat faster than this
pickle: True # Is the scheduler allowed to deserialize arbitrary bytestrings
Expand Down
157 changes: 95 additions & 62 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,14 +792,6 @@ class TaskGroup:
#: The result types of this TaskGroup
types: set[str]

#: The worker most recently assigned a task from this group, or None when the group
#: is not identified to be root-like by `SchedulerState.decide_worker`.
last_worker: WorkerState | None

#: If `last_worker` is not None, the number of times that worker should be assigned
#: subsequent tasks until a new worker is chosen.
last_worker_tasks_left: int

prefix: TaskPrefix | None
start: float
stop: float
Expand All @@ -819,8 +811,6 @@ def __init__(self, name: str):
self.start = 0.0
self.stop = 0.0
self.all_durations = defaultdict(float)
self.last_worker = None
self.last_worker_tasks_left = 0

def add_duration(self, action: str, start: float, stop: float) -> None:
duration = stop - start
Expand Down Expand Up @@ -1223,7 +1213,7 @@ class SchedulerState:
* **tasks:** ``{task key: TaskState}``
Tasks currently known to the scheduler
* **unrunnable:** ``{TaskState}``
Tasks in the "no-worker" state
Tasks in the "no-worker" state in priority order
* **workers:** ``{worker key: WorkerState}``
Workers currently connected to the scheduler
Expand Down Expand Up @@ -1283,7 +1273,7 @@ def __init__(
host_info: dict,
resources: dict,
tasks: dict,
unrunnable: set,
unrunnable: SortedSet,
validate: bool,
plugins: Iterable[SchedulerPlugin] = (),
transition_counter_max: int | Literal[False] = False,
Expand Down Expand Up @@ -1683,10 +1673,10 @@ def transition_no_worker_waiting(self, key, stimulus_id):
ts.state = "waiting"

if not ts.waiting_on:
if self.workers:
if self.idle:
recommendations[key] = "processing"
else:
self.unrunnable.add(ts)
self.unrunnable.add(ts) # TODO avoid remove-add
ts.state = "no-worker"

return recommendations, client_msgs, worker_msgs
Expand Down Expand Up @@ -1753,59 +1743,40 @@ def decide_worker(self, ts: TaskState) -> WorkerState | None:
Otherwise, we pick the least occupied worker, or pick from all workers
in a round-robin fashion.
"""
if not self.workers:
if not self.idle:
return None

tg = ts.group
valid_workers = self.valid_workers(ts)

if (
valid_workers is not None
and not valid_workers
and not ts.loose_restrictions
):
self.unrunnable.add(ts)
ts.state = "no-worker"
return None

# Group is larger than cluster with few dependencies?
# Minimize future data transfers.
if (
valid_workers is None
and len(tg) > self.total_nthreads * 2
and len(tg.dependencies) < 5
and sum(map(len, tg.dependencies)) < 5
):
ws = tg.last_worker

if not (ws and tg.last_worker_tasks_left and ws.address in self.workers):
# Last-used worker is full or unknown; pick a new worker for the next few tasks
ws = min(
(self.idle or self.workers).values(),
key=partial(self.worker_objective, ts),
)
tg.last_worker_tasks_left = math.floor(
(len(tg) / self.total_nthreads) * ws.nthreads
)

# Record `last_worker`, or clear it on the final task
tg.last_worker = (
ws if tg.states["released"] + tg.states["waiting"] > 1 else None
# self.unrunnable.add(ts)
# ts.state = "no-worker"
raise NotImplementedError(
"not entirely sure if this needs to be a separate case or not"
)
tg.last_worker_tasks_left -= 1
return ws
return None

if ts.dependencies or valid_workers is not None:
if len(self.idle) < len(self.workers):
if valid_workers:
# TODO make `valid_workers` handle this for us
valid_workers.intersection_update(self.idle.values())
else:
valid_workers = set(self.idle.values())

ws = decide_worker(
ts,
self.workers.values(),
self.idle.values(), # TODO actually all workers??
valid_workers,
partial(self.worker_objective, ts),
)
else:
# Fastpath when there are no related tasks or restrictions
worker_pool = self.idle or self.workers
wp_vals = worker_pool.values()
wp_vals = self.idle.values()
n_workers: int = len(wp_vals)
if n_workers < 20: # smart but linear in small case
ws = min(wp_vals, key=operator.attrgetter("occupancy"))
Expand All @@ -1826,6 +1797,8 @@ def decide_worker(self, ts: TaskState) -> WorkerState | None:

if self.validate and ws is not None:
assert ws.address in self.workers
assert ws.address in self.idle, (ws, list(self.idle.values()))
assert len(ws.processing) < ws.nthreads

return ws

Expand All @@ -1847,6 +1820,7 @@ def transition_waiting_processing(self, key, stimulus_id):

ws = self.decide_worker(ts)
if ws is None:
recommendations[ts.key] = "no-worker"
return recommendations, client_msgs, worker_msgs
worker = ws.address

Expand All @@ -1873,6 +1847,30 @@ def transition_waiting_processing(self, key, stimulus_id):
pdb.set_trace()
raise

def transition_waiting_no_worker(self, key, stimulus_id):
    """Move the task at *key* from the "waiting" state to "no-worker".

    The task is parked in ``self.unrunnable`` so it can be reconsidered
    for scheduling once a suitable worker becomes available.

    Returns
    -------
    Empty ``(recommendations, client_msgs, worker_msgs)`` dicts — this
    transition produces no follow-up work by itself.
    """
    try:
        ts: TaskState = self.tasks[key]

        if self.validate:
            # A task entering "no-worker" must be runnable (every
            # dependency already in memory) but not assigned, tracked as
            # unrunnable, or failed.
            assert not ts.waiting_on
            assert not ts.who_has
            assert not ts.exception_blame
            assert not ts.processing_on
            assert not ts.has_lost_dependencies
            assert ts not in self.unrunnable
            assert all(dts.who_has for dts in ts.dependencies)

        ts.state = "no-worker"
        self.unrunnable.add(ts)

        return {}, {}, {}
    except Exception as exc:
        logger.exception(exc)
        if LOG_PDB:
            import pdb

            pdb.set_trace()
        raise

def transition_waiting_memory(
self,
key,
Expand Down Expand Up @@ -1999,7 +1997,7 @@ def transition_processing_memory(
if nbytes is not None:
ts.set_nbytes(nbytes)

_remove_from_processing(self, ts)
_remove_from_processing(self, ts, recommendations)

_add_to_memory(
self, ts, ws, recommendations, client_msgs, type=type, typename=typename
Expand Down Expand Up @@ -2232,7 +2230,7 @@ def transition_processing_released(self, key, stimulus_id):
assert not ts.waiting_on
assert self.tasks[key].state == "processing"

w: str = _remove_from_processing(self, ts)
w: str = _remove_from_processing(self, ts, recommendations)
if w:
worker_msgs[w] = [
{
Expand Down Expand Up @@ -2301,7 +2299,7 @@ def transition_processing_erred(
ws = ts.processing_on
ws.actors.remove(ts)

w = _remove_from_processing(self, ts)
w = _remove_from_processing(self, ts, recommendations)

ts.erred_on.add(w or worker) # type: ignore
if exception is not None:
Expand Down Expand Up @@ -2494,6 +2492,7 @@ def transition_released_forgotten(self, key, stimulus_id):
("released", "waiting"): transition_released_waiting,
("waiting", "released"): transition_waiting_released,
("waiting", "processing"): transition_waiting_processing,
("waiting", "no-worker"): transition_waiting_no_worker,
("waiting", "memory"): transition_waiting_memory,
("processing", "released"): transition_processing_released,
("processing", "memory"): transition_processing_memory,
Expand Down Expand Up @@ -2548,9 +2547,11 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0):
- Saturated: have enough work to stay busy
- Idle: do not have enough work to stay busy
They are considered saturated if they both have enough tasks to occupy
all of their threads, and if the expected runtime of those tasks is
large enough.
They are considered idle if they don't have enough tasks to occupy
all of their threads.
Idle and saturated are not dichotomous; a worker can be neither
idle nor saturated. TODO should this be the case?
This is useful for load balancing and adaptivity.
"""
Expand All @@ -2565,7 +2566,7 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0):

idle = self.idle
saturated = self.saturated
if p < nc or occ < nc * avg / 2:
if p < nc:
idle[ws.address] = ws
saturated.discard(ws)
else:
Expand Down Expand Up @@ -2662,6 +2663,7 @@ def valid_workers(self, ts: TaskState) -> set: # set[WorkerState] | None
else:
s &= ww

# TODO restrict to idle & running
if s is None:
if len(self.running) < len(self.workers):
return self.running.copy()
Expand All @@ -2672,6 +2674,22 @@ def valid_workers(self, ts: TaskState) -> set: # set[WorkerState] | None

return s

def is_valid_worker_for_task(self, ws: WorkerState, ts: TaskState) -> bool:
    """Return True if worker *ws* satisfies all of *ts*'s placement restrictions.

    Checks, cheapest first: explicit worker-address restrictions, host
    restrictions (hostnames normalized via ``self.coerce_hostname``), and
    per-resource requirements.

    Parameters
    ----------
    ws : WorkerState
        Candidate worker.
    ts : TaskState
        Task whose restrictions are evaluated.
    """
    if ts.worker_restrictions and ws.address not in ts.worker_restrictions:
        return False
    if ts.host_restrictions and ws.host not in set(
        map(self.coerce_hostname, ts.host_restrictions)
    ):
        return False
    if ts.resource_restrictions:
        if not ws.resources:
            # common case fastpath
            return False
        for resource, required in ts.resource_restrictions.items():
            # BUG FIX: the previous walrus-based check
            # `(supplied := ws.resources.get(resource)) and supplied < required`
            # short-circuited when the worker lacked the resource entirely
            # (get() -> None, falsy) or advertised 0 of it, wrongly
            # accepting the worker. Treat a missing resource as supply 0.
            if ws.resources.get(resource, 0) < required:
                return False
    return True

def consume_resources(self, ts: TaskState, ws: WorkerState):
for r, required in ts.resource_restrictions.items():
ws.used_resources[r] += required
Expand Down Expand Up @@ -2765,14 +2783,14 @@ def bulk_schedule_after_adding_worker(self, ws: WorkerState):
ordering, so the recommendations are sorted by priority order here.
"""
ts: TaskState
tasks = []
for ts in self.unrunnable:
tasks: list[TaskState] = []
for ts in self.unrunnable: # NOTE: priority ordered
valid: set = self.valid_workers(ts)
if valid is None or ws in valid:
tasks.append(ts)
# These recommendations will generate {"op": "compute-task"} messages
# to the worker in reversed order
tasks.sort(key=operator.attrgetter("priority"), reverse=True)
if len(tasks) == ws.nthreads:
break
# FIXME why were these in reverse priority order before??
return {ts.key: "waiting" for ts in tasks}


Expand Down Expand Up @@ -2980,7 +2998,7 @@ def __init__(
self.generation = 0
self._last_client = None
self._last_time = 0
unrunnable = set()
unrunnable = SortedSet(key=operator.attrgetter("priority"))

self.datasets = {}

Expand Down Expand Up @@ -7131,9 +7149,13 @@ def request_remove_replicas(
)


def _remove_from_processing(state: SchedulerState, ts: TaskState) -> str | None:
def _remove_from_processing(
state: SchedulerState, ts: TaskState, recommendations: dict[str, str]
) -> str | None:
"""Remove *ts* from the set of processing tasks.
Recommend the next unrunnable task is scheduled.
See also
--------
Scheduler._set_duration_estimate
Expand All @@ -7157,6 +7179,17 @@ def _remove_from_processing(state: SchedulerState, ts: TaskState) -> str | None:
state.check_idle_saturated(ws)
state.release_resources(ts, ws)

assert len(ws.processing) < ws.nthreads, (len(ws.processing), ws.nthreads)
assert ws.address in state.idle

uts: TaskState
for uts in state.unrunnable:
if state.is_valid_worker_for_task(
ws, uts
): # TODO linear search is inefficient with restrictions
recommendations[uts.key] = "waiting"
break

return ws.address


Expand Down
Loading

0 comments on commit d93fd4b

Please sign in to comment.