Skip to content

Commit

Permalink
stimulus_id for all Instructions (#6347)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored May 16, 2022
1 parent 877ef5c commit 5ca7a5a
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 39 deletions.
24 changes: 11 additions & 13 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4726,7 +4726,9 @@ def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None:
if not ts.who_has:
self.transitions({key: "released"}, stimulus_id)

def handle_long_running(self, key=None, worker=None, compute_duration=None):
def handle_long_running(
self, key: str, worker: str, compute_duration: float, stimulus_id: str
) -> None:
"""A task has seceded from the thread pool
We stop the task from being stolen in the future, and change task
Expand All @@ -4735,27 +4737,23 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None):
if key not in self.tasks:
logger.debug("Skipping long_running since key %s was already released", key)
return
ts: TaskState = self.tasks[key]
ts = self.tasks[key]
steal = self.extensions.get("stealing")
if steal is not None:
steal.remove_key_from_stealable(ts)

ws: WorkerState = ts.processing_on
ws = ts.processing_on
if ws is None:
logger.debug("Received long-running signal from duplicate task. Ignoring.")
return

if compute_duration:
old_duration: float = ts.prefix.duration_average
new_duration: float = compute_duration
if old_duration < 0:
avg_duration = new_duration
else:
avg_duration = 0.5 * old_duration + 0.5 * new_duration

ts.prefix.duration_average = avg_duration
old_duration = ts.prefix.duration_average
if old_duration < 0:
ts.prefix.duration_average = compute_duration
else:
ts.prefix.duration_average = (old_duration + compute_duration) / 2

occ: float = ws.processing[ts]
occ = ws.processing[ts]
ws.occupancy -= occ
self.total_occupancy -= occ
# Cannot remove from processing since we're using this for things like
Expand Down
8 changes: 4 additions & 4 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2555,11 +2555,11 @@ def transition_executing_long_running(
self._executing.discard(ts)
self.long_running.add(ts.key)

smsg = LongRunningMsg(
key=ts.key, compute_duration=compute_duration, stimulus_id=stimulus_id
)
return merge_recs_instructions(
(
{},
[LongRunningMsg(key=ts.key, compute_duration=compute_duration)],
),
({}, [smsg]),
self._ensure_computing(),
)

Expand Down
39 changes: 17 additions & 22 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,46 +257,47 @@ def __repr__(self) -> str:
return f"<{type(self).__name__}: {len(self)} items>"


@dataclass
class Instruction:
"""Command from the worker state machine to the Worker, in response to an event"""

__slots__ = ()
__slots__ = ("stimulus_id",)
stimulus_id: str


@dataclass
class GatherDep(Instruction):
__slots__ = ("worker", "to_gather", "total_nbytes")
worker: str
to_gather: set[str]
total_nbytes: int
stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore


@dataclass
class Execute(Instruction):
__slots__ = ("key", "stimulus_id")
__slots__ = ("key",)
key: str
stimulus_id: str


class SendMessageToScheduler(Instruction):
@dataclass
class EnsureCommunicatingAfterTransitions(Instruction):
__slots__ = ()


@dataclass
class SendMessageToScheduler(Instruction):
#: Matches a key in Scheduler.stream_handlers
op: ClassVar[str]
__slots__ = ()

def to_dict(self) -> dict[str, Any]:
"""Convert object to dict so that it can be serialized with msgpack"""
d = {k: getattr(self, k) for k in self.__annotations__}
d["op"] = self.op
d["stimulus_id"] = self.stimulus_id
return d


@dataclass
class EnsureCommunicatingAfterTransitions(Instruction):
__slots__ = ("stimulus_id",)
stimulus_id: str


@dataclass
class TaskFinishedMsg(SendMessageToScheduler):
op = "task-finished"
Expand All @@ -308,7 +309,6 @@ class TaskFinishedMsg(SendMessageToScheduler):
metadata: dict
thread: int | None
startstops: list[StartStop]
stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore

def to_dict(self) -> dict[str, Any]:
Expand All @@ -328,7 +328,6 @@ class TaskErredMsg(SendMessageToScheduler):
traceback_text: str
thread: int | None
startstops: list[StartStop]
stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore

def to_dict(self) -> dict[str, Any]:
Expand All @@ -341,29 +340,26 @@ def to_dict(self) -> dict[str, Any]:
class ReleaseWorkerDataMsg(SendMessageToScheduler):
op = "release-worker-data"

__slots__ = ("key", "stimulus_id")
__slots__ = ("key",)
key: str
stimulus_id: str


@dataclass
class MissingDataMsg(SendMessageToScheduler):
op = "missing-data"

__slots__ = ("key", "errant_worker", "stimulus_id")
__slots__ = ("key", "errant_worker")
key: str
errant_worker: str
stimulus_id: str


# Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception
@dataclass
class RescheduleMsg(SendMessageToScheduler):
op = "reschedule"

__slots__ = ("key", "stimulus_id")
__slots__ = ("key",)
key: str
stimulus_id: str


@dataclass
Expand All @@ -379,9 +375,8 @@ class LongRunningMsg(SendMessageToScheduler):
class AddKeysMsg(SendMessageToScheduler):
op = "add-keys"

__slots__ = ("keys", "stimulus_id")
__slots__ = ("keys",)
keys: list[str]
stimulus_id: str


@dataclass
Expand Down

0 comments on commit 5ca7a5a

Please sign in to comment.