diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 45d7c58bf79..bb019b2f4e4 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -1088,7 +1088,7 @@ def __init__(self, key: str, run_spec: object):
         self.has_lost_dependencies = False
         self.host_restrictions = None  # type: ignore
         self.worker_restrictions = None  # type: ignore
-        self.resource_restrictions = None  # type: ignore
+        self.resource_restrictions = {}
         self.loose_restrictions = False
         self.actor = False
         self.prefix = None  # type: ignore
@@ -2670,14 +2670,12 @@ def valid_workers(self, ts: TaskState) -> set:  # set[WorkerState] | None
         return s
 
     def consume_resources(self, ts: TaskState, ws: WorkerState):
-        if ts.resource_restrictions:
-            for r, required in ts.resource_restrictions.items():
-                ws.used_resources[r] += required
+        for r, required in ts.resource_restrictions.items():
+            ws.used_resources[r] += required
 
     def release_resources(self, ts: TaskState, ws: WorkerState):
-        if ts.resource_restrictions:
-            for r, required in ts.resource_restrictions.items():
-                ws.used_resources[r] -= required
+        for r, required in ts.resource_restrictions.items():
+            ws.used_resources[r] -= required
 
     def coerce_hostname(self, host):
         """
@@ -7060,29 +7058,28 @@ def adaptive_target(self, target_duration=None):
             to_close = self.workers_to_close()
             return len(self.workers) - len(to_close)
 
-    def request_acquire_replicas(self, addr: str, keys: list, *, stimulus_id: str):
+    def request_acquire_replicas(
+        self, addr: str, keys: Iterable[str], *, stimulus_id: str
+    ) -> None:
         """Asynchronously ask a worker to acquire a replica of the listed keys from
         other workers. This is a fire-and-forget operation which offers no feedback for
         success or failure, and is intended for housekeeping and not for computation.
         """
-        who_has = {}
-        for key in keys:
-            ts = self.tasks[key]
-            who_has[key] = {ws.address for ws in ts.who_has}
-
+        who_has = {key: [ws.address for ws in self.tasks[key].who_has] for key in keys}
         if self.validate:
             assert all(who_has.values())
 
         self.stream_comms[addr].send(
             {
                 "op": "acquire-replicas",
-                "keys": keys,
                 "who_has": who_has,
                 "stimulus_id": stimulus_id,
             },
         )
 
-    def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str):
+    def request_remove_replicas(
+        self, addr: str, keys: list[str], *, stimulus_id: str
+    ) -> None:
         """Asynchronously ask a worker to discard its replica of the listed keys.
         This must never be used to destroy the last replica of a key. This is a
         fire-and-forget operation, intended for housekeeping and not for computation.
@@ -7093,15 +7090,14 @@ def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str):
         to re-add itself to who_has. If the worker agrees to discard the task, there is
         no feedback.
         """
-        ws: WorkerState = self.workers[addr]
-        validate = self.validate
+        ws = self.workers[addr]
 
        # The scheduler immediately forgets about the replica and suggests that the
        # worker drop it. The worker may refuse, at which point it will send back an
        # add-keys message to reinstate it.
        for key in keys:
-            ts: TaskState = self.tasks[key]
-            if validate:
+            ts = self.tasks[key]
+            if self.validate:
                # Do not destroy the last copy
                assert len(ts.who_has) > 1
            self.remove_replica(ts, ws)
@@ -7282,22 +7278,16 @@ def _task_to_msg(
            dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies
        },
        "nbytes": {dts.key: dts.nbytes for dts in ts.dependencies},
+        "run_spec": ts.run_spec,
+        "resource_restrictions": ts.resource_restrictions,
+        "actor": ts.actor,
+        "annotations": ts.annotations,
    }
    if state.validate:
        assert all(msg["who_has"].values())
-
-    if ts.resource_restrictions:
-        msg["resource_restrictions"] = ts.resource_restrictions
-    if ts.actor:
-        msg["actor"] = True
-
-    if isinstance(ts.run_spec, dict):
-        msg.update(ts.run_spec)
-    else:
-        msg["task"] = ts.run_spec
-
-    if ts.annotations:
-        msg["annotations"] = ts.annotations
+        if isinstance(msg["run_spec"], dict):
+            assert set(msg["run_spec"]).issubset({"function", "args", "kwargs"})
+            assert msg["run_spec"].get("function")
 
    return msg
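Note on the hunk above: `_task_to_msg` now always emits the same fixed set of fields, because the worker-side `ComputeTaskEvent` introduced later in this patch has no optional fields. A rough sketch of the mapping, with illustrative values (the comms layer strips the `op` key before a handler sees the kwargs):

from distributed.worker_state_machine import ComputeTaskEvent, SerializedTask

msg = {
    "key": "x",
    "who_has": {"y": ["tcp://w1:1234"]},
    "nbytes": {"y": 123},
    "priority": [0],  # msgpack turns tuples into lists on the wire
    "duration": 1.0,
    "run_spec": {"function": b"blob", "args": b"blob"},
    "resource_restrictions": {},
    "actor": False,
    "annotations": {},
    "stimulus_id": "compute-task-1234",
}
ev = ComputeTaskEvent(**msg)
assert ev.priority == (0,)  # __post_init__ restores the tuple
assert ev.run_spec == SerializedTask(function=b"blob", args=b"blob")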
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 9fc2420d6ba..2e95a1c4d73 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -30,6 +30,7 @@
     slowidentity,
     slowinc,
 )
+from distributed.worker_state_machine import StealRequestEvent
 
 pytestmark = pytest.mark.ci1
 
@@ -868,7 +869,7 @@ async def test_dont_steal_already_released(c, s, a, b):
     while key in a.tasks and a.tasks[key].state != "released":
         await asyncio.sleep(0.05)
 
-    a.handle_steal_request(key=key, stimulus_id="test")
+    a.handle_stimulus(StealRequestEvent(key=key, stimulus_id="test"))
     assert len(a.batched_stream.buffer) == 1
     msg = a.batched_stream.buffer[0]
     assert msg["op"] == "steal-response"
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 034e0bf2953..237a8702b18 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -70,6 +70,14 @@
     error_message,
     logger,
 )
+from distributed.worker_state_machine import (
+    ComputeTaskEvent,
+    ExecuteFailureEvent,
+    ExecuteSuccessEvent,
+    RemoveReplicasEvent,
+    SerializedTask,
+    StealRequestEvent,
+)
 
 pytestmark = pytest.mark.ci1
 
@@ -1840,28 +1848,42 @@ async def test_stimulus_story(c, s, a):
     class C:
         pass
 
-    f = c.submit(C, key="f")  # Test that substrings aren't matched by story()
-    f2 = c.submit(inc, 2, key="f2")
-    f3 = c.submit(inc, 3, key="f3")
-    await wait([f, f2, f3])
-
-    # Test that ExecuteSuccessEvent.value is not stored in the the event log
-    assert isinstance(a.data["f"], C)
-    ref = weakref.ref(a.data["f"])
-    del f
-    while "f" in a.data:
+    # Test that substrings aren't matched by stimulus_story()
+    f = c.submit(inc, 1, key="f")
+    f1 = c.submit(C, key="f1")
+    f2 = c.submit(inc, f1, key="f2")  # This will fail
+    await wait([f, f2])
+
+    # Test that the data is not referenced permanently anywhere
+    assert isinstance(a.data["f1"], C)
+    ref = weakref.ref(a.data["f1"])
+    del f1
+    while "f1" in a.data:
         await asyncio.sleep(0.01)
     wait_profiler()
     assert ref() is None
 
-    story = a.stimulus_story("f", "f2")
-    assert {ev.key for ev in story} == {"f", "f2"}
-    assert {ev.type for ev in story} == {C, int}
-
-    prev_handled = story[0].handled
-    for ev in story[1:]:
-        assert ev.handled >= prev_handled
-        prev_handled = ev.handled
+    story = a.stimulus_story("f1", "f2")
+    assert len(story) == 4
+
+    assert isinstance(story[0], ComputeTaskEvent)
+    assert story[0].key == "f1"
+    assert story[0].run_spec == SerializedTask(task=None)  # Not logged
+
+    assert isinstance(story[1], ExecuteSuccessEvent)
+    assert story[1].key == "f1"
+    assert story[1].value is None  # Not logged
+    assert story[1].handled >= story[0].handled
+
+    assert isinstance(story[2], ComputeTaskEvent)
+    assert story[2].key == "f2"
+    assert story[2].who_has == {"f1": (a.address,)}
+    assert story[2].run_spec == SerializedTask(task=None)  # Not logged
+    assert story[2].handled >= story[1].handled
+
+    assert isinstance(story[3], ExecuteFailureEvent)
+    assert story[3].key == "f2"
+    assert story[3].handled >= story[2].handled
 
 
 @gen_cluster(client=True)
@@ -2556,7 +2578,7 @@ def __call__(self, *args, **kwargs):
         await asyncio.sleep(0)
 
     ts = s.tasks[fut.key]
-    a.handle_steal_request(fut.key, stimulus_id="test")
+    a.handle_stimulus(StealRequestEvent(key=fut.key, stimulus_id="test"))
     stealing_ext.scheduler.send_task_to_worker(b.address, ts)
 
     fut2 = c.submit(inc, fut, workers=[a.address])
@@ -2916,8 +2938,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers):
         if w.address == a.tasks[f1.key].coming_from:
             break
 
-    coming_from.handle_remove_replicas([f1.key], "test")
-
+    coming_from.handle_stimulus(RemoveReplicasEvent(keys=[f1.key], stimulus_id="test"))
     await f2
 
     assert_story(a.story(f1.key), [(f1.key, "missing-dep")])
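The typed events asserted on above are also useful interactively. A hypothetical debugging snippet using the same API (`worker` stands for any live Worker instance):

for ev in worker.stimulus_story("f1", "f2"):
    # Events arrive in handling order; bulky payloads are already scrubbed
    print(type(ev).__name__, ev.stimulus_id, ev.handled)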
"y": None} + d = recursive_to_dict(ev2) + assert d == { + "cls": "UpdateDataEvent", + "data": {"x": None, "y": None}, + "report": True, + "stimulus_id": "test", + "handled": 11.22, + } + ev3 = StateMachineEvent.from_dict(d) + assert isinstance(ev3, UpdateDataEvent) + assert ev3.data == {"x": None, "y": None} + + def test_executesuccess_to_dict(): """The potentially very large ExecuteSuccessEvent.value is not stored in the log""" ev = ExecuteSuccessEvent( diff --git a/distributed/worker.py b/distributed/worker.py index 437a5b20773..26261f229fa 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -106,15 +106,19 @@ WorkerMemoryManager, ) from distributed.worker_state_machine import ( + NO_VALUE, PROCESSING, READY, + AcquireReplicasEvent, AddKeysMsg, AlreadyCancelledEvent, CancelComputeEvent, + ComputeTaskEvent, EnsureCommunicatingAfterTransitions, Execute, ExecuteFailureEvent, ExecuteSuccessEvent, + FreeKeysEvent, GatherDep, GatherDepDoneEvent, Instructions, @@ -124,11 +128,14 @@ Recs, RecsInstrs, ReleaseWorkerDataMsg, + RemoveReplicasEvent, RescheduleEvent, RescheduleMsg, + SecedeEvent, SendMessageToScheduler, - SerializedTask, StateMachineEvent, + StealRequestEvent, + StealResponseMsg, TaskErredMsg, TaskFinishedMsg, TaskState, @@ -136,6 +143,7 @@ TransitionCounterMaxExceeded, UniqueTaskHeap, UnpauseEvent, + UpdateDataEvent, merge_recs_instructions, ) @@ -150,8 +158,6 @@ LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") -no_value = "--no-value-sentinel--" - DEFAULT_EXTENSIONS: dict[str, type] = { "pubsub": PubSubWorkerExtension, "shuffle": ShuffleWorkerExtension, @@ -784,7 +790,7 @@ def __init__( "run_coroutine": self.run_coroutine, "get_data": self.get_data, "update_data": self.update_data, - "free_keys": self.handle_free_keys, + "free_keys": self._handle_remote_stimulus(FreeKeysEvent), "terminate": self.terminate, "ping": pingpong, "upload_file": self.upload_file, @@ -805,15 +811,15 @@ def __init__( "get_story": self.get_story, } - stream_handlers = { + stream_handlers: dict[str, Callable] = { "close": self.close, "terminate": self.terminate, - "cancel-compute": self.handle_cancel_compute, - "acquire-replicas": self.handle_acquire_replicas, - "compute-task": self.handle_compute_task, - "free-keys": self.handle_free_keys, - "remove-replicas": self.handle_remove_replicas, - "steal-request": self.handle_steal_request, + "cancel-compute": self._handle_remote_stimulus(CancelComputeEvent), + "acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent), + "compute-task": self._handle_remote_stimulus(ComputeTaskEvent), + "free-keys": self._handle_remote_stimulus(FreeKeysEvent), + "remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent), + "steal-request": self._handle_remote_stimulus(StealRequestEvent), "worker-status-change": self.handle_worker_status_change, } @@ -1731,17 +1737,31 @@ async def get_data( # Local Execution # ################### + @functools.singledispatchmethod + def handle_event(self, ev: StateMachineEvent) -> RecsInstrs: + raise TypeError(ev) # pragma: nocover + def update_data( self, data: dict[str, object], report: bool = True, stimulus_id: str = None, ) -> dict[str, Any]: - if stimulus_id is None: - stimulus_id = f"update-data-{time()}" + self.handle_stimulus( + UpdateDataEvent( + data=data, + report=report, + stimulus_id=stimulus_id or f"update-data-{time()}", + ) + ) + return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} + + @handle_event.register + def _(self, ev: UpdateDataEvent) -> RecsInstrs: 
diff --git a/distributed/worker.py b/distributed/worker.py
index 437a5b20773..26261f229fa 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -106,15 +106,19 @@
     WorkerMemoryManager,
 )
 from distributed.worker_state_machine import (
+    NO_VALUE,
     PROCESSING,
     READY,
+    AcquireReplicasEvent,
     AddKeysMsg,
     AlreadyCancelledEvent,
     CancelComputeEvent,
+    ComputeTaskEvent,
     EnsureCommunicatingAfterTransitions,
     Execute,
     ExecuteFailureEvent,
     ExecuteSuccessEvent,
+    FreeKeysEvent,
     GatherDep,
     GatherDepDoneEvent,
     Instructions,
@@ -124,11 +128,14 @@
     Recs,
     RecsInstrs,
     ReleaseWorkerDataMsg,
+    RemoveReplicasEvent,
     RescheduleEvent,
     RescheduleMsg,
+    SecedeEvent,
     SendMessageToScheduler,
-    SerializedTask,
     StateMachineEvent,
+    StealRequestEvent,
+    StealResponseMsg,
     TaskErredMsg,
     TaskFinishedMsg,
     TaskState,
@@ -136,6 +143,7 @@
     TransitionCounterMaxExceeded,
     UniqueTaskHeap,
     UnpauseEvent,
+    UpdateDataEvent,
     merge_recs_instructions,
 )
 
@@ -150,8 +158,6 @@
 LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
 
-no_value = "--no-value-sentinel--"
-
 DEFAULT_EXTENSIONS: dict[str, type] = {
     "pubsub": PubSubWorkerExtension,
     "shuffle": ShuffleWorkerExtension,
@@ -784,7 +790,7 @@ def __init__(
             "run_coroutine": self.run_coroutine,
             "get_data": self.get_data,
             "update_data": self.update_data,
-            "free_keys": self.handle_free_keys,
+            "free_keys": self._handle_remote_stimulus(FreeKeysEvent),
             "terminate": self.terminate,
             "ping": pingpong,
             "upload_file": self.upload_file,
@@ -805,15 +811,15 @@ def __init__(
             "get_story": self.get_story,
         }
 
-        stream_handlers = {
+        stream_handlers: dict[str, Callable] = {
             "close": self.close,
             "terminate": self.terminate,
-            "cancel-compute": self.handle_cancel_compute,
-            "acquire-replicas": self.handle_acquire_replicas,
-            "compute-task": self.handle_compute_task,
-            "free-keys": self.handle_free_keys,
-            "remove-replicas": self.handle_remove_replicas,
-            "steal-request": self.handle_steal_request,
+            "cancel-compute": self._handle_remote_stimulus(CancelComputeEvent),
+            "acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
+            "compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
+            "free-keys": self._handle_remote_stimulus(FreeKeysEvent),
+            "remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
+            "steal-request": self._handle_remote_stimulus(StealRequestEvent),
             "worker-status-change": self.handle_worker_status_change,
         }
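The wiring behind these handler tables: every remote message is packed into a typed event and pushed through a single entry point, where functools.singledispatchmethod selects the matching handler. A minimal, self-contained sketch of the same pattern (the Demo* names are stand-ins, not the real classes):

import functools
from dataclasses import dataclass

@dataclass
class DemoFreeKeysEvent:  # stand-in for the real FreeKeysEvent
    keys: list
    stimulus_id: str

class DemoWorker:
    def _handle_remote_stimulus(self, cls):
        # Factory: pack the raw comm kwargs into a typed event, then feed the
        # state machine through the single handle_stimulus entry point
        def _(**kwargs):
            return self.handle_stimulus(cls(**kwargs))
        return _

    def handle_stimulus(self, ev):
        return self.handle_event(ev)  # the real worker also applies the result

    @functools.singledispatchmethod
    def handle_event(self, ev):
        raise TypeError(ev)

    @handle_event.register
    def _(self, ev: DemoFreeKeysEvent):
        # Return (recommendations, instructions), as the real handlers do
        return {k: "released" for k in ev.keys}, []

handler = DemoWorker()._handle_remote_stimulus(DemoFreeKeysEvent)
assert handler(keys=["x"], stimulus_id="free-keys-1") == ({"x": "released"}, [])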
@@ -1731,17 +1737,31 @@ async def get_data(
     # Local Execution #
     ###################
 
+    @functools.singledispatchmethod
+    def handle_event(self, ev: StateMachineEvent) -> RecsInstrs:
+        raise TypeError(ev)  # pragma: nocover
+
     def update_data(
         self,
         data: dict[str, object],
         report: bool = True,
         stimulus_id: str = None,
     ) -> dict[str, Any]:
-        if stimulus_id is None:
-            stimulus_id = f"update-data-{time()}"
+        self.handle_stimulus(
+            UpdateDataEvent(
+                data=data,
+                report=report,
+                stimulus_id=stimulus_id or f"update-data-{time()}",
+            )
+        )
+        return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"}
+
+    @handle_event.register
+    def _(self, ev: UpdateDataEvent) -> RecsInstrs:
         recommendations: Recs = {}
         instructions: Instructions = []
-        for key, value in data.items():
+
+        for key, value in ev.data.items():
             try:
                 ts = self.tasks[key]
                 recommendations[ts] = ("memory", value)
@@ -1749,25 +1769,27 @@ def update_data(
             except KeyError:
                 self.tasks[key] = ts = TaskState(key)
 
             try:
-                recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id)
+                recs = self._put_key_in_memory(
+                    ts, value, stimulus_id=ev.stimulus_id
+                )
             except Exception as e:
                 msg = error_message(e)
                 recommendations = {ts: tuple(msg.values())}
             else:
                 recommendations.update(recs)
 
-            self.log.append((key, "receive-from-scatter", stimulus_id, time()))
+            self.log.append((key, "receive-from-scatter", ev.stimulus_id, time()))
 
-        if report:
-            instructions.append(AddKeysMsg(keys=list(data), stimulus_id=stimulus_id))
+        if ev.report:
+            instructions.append(
+                AddKeysMsg(keys=list(ev.data), stimulus_id=ev.stimulus_id)
+            )
 
-        self.transitions(recommendations, stimulus_id=stimulus_id)
-        self._handle_instructions(instructions)
-        return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"}
+        return recommendations, instructions
 
-    def handle_free_keys(self, keys: list[str], stimulus_id: str) -> None:
-        """
-        Handler to be called by the scheduler.
+    @handle_event.register
+    def _(self, ev: FreeKeysEvent) -> RecsInstrs:
+        """Handler to be called by the scheduler.
 
         The given keys are no longer referred to and required by the scheduler.
         The worker is now allowed to release the key, if applicable.
@@ -1776,16 +1798,16 @@ def handle_free_keys(self, keys: list[str], stimulus_id: str) -> None:
         still decide to hold on to the data and task since it is required by an
         upstream dependency.
         """
-        self.log.append(("free-keys", keys, stimulus_id, time()))
+        self.log.append(("free-keys", ev.keys, ev.stimulus_id, time()))
         recommendations: Recs = {}
-        for key in keys:
+        for key in ev.keys:
             ts = self.tasks.get(key)
             if ts:
                 recommendations[ts] = "released"
+        return recommendations, []
 
-        self.transitions(recommendations, stimulus_id=stimulus_id)
-
-    def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str:
+    @handle_event.register
+    def _(self, ev: RemoveReplicasEvent) -> RecsInstrs:
         """Stream handler notifying the worker that it might be holding unreferenced,
         superfluous data.
@@ -1804,30 +1826,30 @@ def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str:
 
         For stronger guarantees, see handler free_keys
         """
-        self.log.append(("remove-replicas", keys, stimulus_id, time()))
+        self.log.append(("remove-replicas", ev.keys, ev.stimulus_id, time()))
         recommendations: Recs = {}
+        instructions: Instructions = []
 
         rejected = []
-        for key in keys:
+        for key in ev.keys:
             ts = self.tasks.get(key)
             if ts is None or ts.state != "memory":
                 continue
             if not ts.is_protected():
                 self.log.append(
-                    (ts.key, "remove-replica-confirmed", stimulus_id, time())
+                    (ts.key, "remove-replica-confirmed", ev.stimulus_id, time())
                 )
                 recommendations[ts] = "released"
             else:
                 rejected.append(key)
 
         if rejected:
-            self.log.append(("remove-replica-rejected", rejected, stimulus_id, time()))
-            smsg = AddKeysMsg(keys=rejected, stimulus_id=stimulus_id)
-            self._handle_instructions([smsg])
-
-        self.transitions(recommendations, stimulus_id=stimulus_id)
+            self.log.append(
+                ("remove-replica-rejected", rejected, ev.stimulus_id, time())
+            )
+            instructions.append(AddKeysMsg(keys=rejected, stimulus_id=ev.stimulus_id))
 
-        return "OK"
+        return recommendations, instructions
 
     async def set_resources(self, **resources) -> None:
         for r, quantity in resources.items():
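Each converted handler above now returns a `(recommendations, instructions)` pair instead of applying side effects itself. Roughly, `handle_stimulus` closes the loop like this (a simplified sketch, not the actual implementation, which also logs the event and guards against errors):

def handle_stimulus(self, stim):
    recs, instructions = self.handle_event(stim)  # pure-ish, easy to unit-test
    self.transitions(recs, stimulus_id=stim.stimulus_id)  # apply state changes
    self._handle_instructions(instructions)  # send messages / start coroutines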
@@ -1847,27 +1869,25 @@ async def set_resources(self, **resources) -> None:
     # Task Management #
     ###################
 
-    def handle_cancel_compute(self, key: str, stimulus_id: str) -> None:
-        """
-        Cancel a task on a best effort basis. This is only possible while a task
-        is in state `waiting` or `ready`.
-        Nothing will happen otherwise.
-        """
-        self.handle_stimulus(CancelComputeEvent(key=key, stimulus_id=stimulus_id))
+    @fail_hard
+    def _handle_remote_stimulus(
+        self, cls: type[StateMachineEvent]
+    ) -> Callable[..., None]:
+        def _(**kwargs):
+            event = cls(**kwargs)
+            self.handle_stimulus(event)
 
-    def handle_acquire_replicas(
-        self,
-        *,
-        keys: Collection[str],
-        who_has: dict[str, Collection[str]],
-        stimulus_id: str,
-    ) -> None:
+        _.__name__ = f"_handle_remote_stimulus({cls.__name__})"
+        return _
+
+    @handle_event.register
+    def _(self, ev: AcquireReplicasEvent) -> RecsInstrs:
+        self.log.append(("acquire-replicas", set(ev.who_has), ev.stimulus_id, time()))
         if self.validate:
-            assert set(keys) == who_has.keys()
-            assert all(who_has.values())
+            assert all(ev.who_has.values())
 
         recommendations: Recs = {}
-        for key in keys:
+        for key in ev.who_has:
             ts = self.ensure_task_exists(
                 key=key,
                 # Transfer this data after all dependency tasks of computations with
                 # default or explicitly high (>0) user priority and before all
                 # computations with low priority (<0). Note that the priority= parameter
                 # of compute() is multiplied by -1 before it reaches TaskState.priority.
                 priority=(1,),
-                stimulus_id=stimulus_id,
+                stimulus_id=ev.stimulus_id,
             )
             if ts.state != "memory":
                 recommendations[ts] = "fetch"
 
-        self.update_who_has(who_has)
-        self.transitions(recommendations, stimulus_id=stimulus_id)
-
-        if self.validate:
-            for key in keys:
-                assert self.tasks[key].state != "released", self.story(key)
+        self.update_who_has(ev.who_has)
+        return recommendations, []
 
     def ensure_task_exists(
         self, key: str, *, priority: tuple[int, ...], stimulus_id: str
@@ -1902,94 +1918,68 @@ def ensure_task_exists(
         self.log.append((key, "ensure-task-exists", ts.state, stimulus_id, time()))
         return ts
 
-    def handle_compute_task(
-        self,
-        *,
-        key: str,
-        who_has: dict[str, Collection[str]],
-        nbytes: dict[str, int],
-        priority: tuple[int, ...],
-        duration: float,
-        function=None,
-        args=None,
-        kwargs=None,
-        task=no_value,  # distributed.scheduler.TaskState.run_spec
-        resource_restrictions: dict[str, float] | None = None,
-        actor: bool = False,
-        annotations: dict | None = None,
-        stimulus_id: str,
-    ) -> None:
+    @handle_event.register
+    def _(self, ev: ComputeTaskEvent) -> RecsInstrs:
         try:
-            ts = self.tasks[key]
+            ts = self.tasks[ev.key]
             logger.debug(
                 "Asked to compute an already known task %s",
-                {"task": ts, "stimulus_id": stimulus_id},
+                {"task": ts, "stimulus_id": ev.stimulus_id},
             )
         except KeyError:
-            self.tasks[key] = ts = TaskState(key)
-        self.log.append((key, "compute-task", ts.state, stimulus_id, time()))
-
-        recommendations: Recs = {}
-        instructions: Instructions = []
+            self.tasks[ev.key] = ts = TaskState(ev.key)
+        self.log.append((ev.key, "compute-task", ts.state, ev.stimulus_id, time()))
 
         if ts.state in READY | {"executing", "long-running", "waiting", "resumed"}:
-            pass
-        elif ts.state == "memory":
-            instructions.append(
-                self._get_task_finished_msg(ts, stimulus_id=stimulus_id)
-            )
-        elif ts.state in {
+            return {}, []
+
+        if ts.state == "memory":
+            return {}, [self._get_task_finished_msg(ts, stimulus_id=ev.stimulus_id)]
+
+        assert ts.state in {
             "released",
             "fetch",
             "flight",
             "missing",
             "cancelled",
             "error",
-        }:
-            recommendations[ts] = "waiting"
-
-            ts.run_spec = SerializedTask(function, args, kwargs, task)
-
-            assert isinstance(priority, tuple)
-            priority = priority + (self.generation,)
-            self.generation -= 1
-
-            if actor:
-                self.actors[ts.key] = None
+        }, ts
 
-            ts.exception = None
-            ts.traceback = None
-            ts.exception_text = ""
-            ts.traceback_text = ""
-            ts.priority = priority
-            ts.duration = duration
-            if resource_restrictions:
-                ts.resource_restrictions = resource_restrictions
-            ts.annotations = annotations
+        ts.run_spec = ev.run_spec
+        priority = ev.priority + (self.generation,)
+        self.generation -= 1
 
-            if self.validate:
-                assert who_has.keys() == nbytes.keys()
-                assert all(who_has.values())
+        if ev.actor:
+            self.actors[ts.key] = None
 
-            for dep_key, dep_workers in who_has.items():
-                dep_ts = self.ensure_task_exists(
-                    key=dep_key,
-                    priority=priority,
-                    stimulus_id=stimulus_id,
-                )
-                # link up to child / parents
-                ts.dependencies.add(dep_ts)
-                dep_ts.dependents.add(ts)
+        ts.exception = None
+        ts.traceback = None
+        ts.exception_text = ""
+        ts.traceback_text = ""
+        ts.priority = priority
+        ts.duration = ev.duration
+        ts.resource_restrictions = ev.resource_restrictions.copy()
+        ts.annotations = ev.annotations
 
-            for dep_key, value in nbytes.items():
-                self.tasks[dep_key].nbytes = value
+        if self.validate:
+            assert ev.who_has.keys() == ev.nbytes.keys()
+            assert all(ev.who_has.values())
+
+        for dep_key, dep_workers in ev.who_has.items():
+            dep_ts = self.ensure_task_exists(
+                key=dep_key,
+                priority=priority,
+                stimulus_id=ev.stimulus_id,
+            )
+            # link up to child / parents
+            ts.dependencies.add(dep_ts)
+            dep_ts.dependents.add(ts)
 
-            self.update_who_has(who_has)
-        else:  # pragma: nocover
-            raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}")
+        for dep_key, value in ev.nbytes.items():
+            self.tasks[dep_key].nbytes = value
 
-        self.transitions(recommendations, stimulus_id=stimulus_id)
-        self._handle_instructions(instructions)
+        self.update_who_has(ev.who_has)
+        return {ts: "waiting"}, []
 
     ########################
     # Worker State Machine #
     ########################
@@ -2407,15 +2397,15 @@ def transition_executing_released(
         return self._ensure_computing()
 
     def transition_long_running_memory(
-        self, ts: TaskState, value=no_value, *, stimulus_id: str
+        self, ts: TaskState, value=NO_VALUE, *, stimulus_id: str
     ) -> RecsInstrs:
         self.executed_count += 1
         return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id)
 
     def transition_generic_memory(
-        self, ts: TaskState, value=no_value, *, stimulus_id: str
+        self, ts: TaskState, value=NO_VALUE, *, stimulus_id: str
     ) -> RecsInstrs:
-        if value is no_value and ts.key not in self.data:
+        if value is NO_VALUE and ts.key not in self.data:
             raise RuntimeError(
                 f"Tried to transition task {ts} to `memory` without data available"
             )
@@ -2444,7 +2434,7 @@ def transition_generic_memory(
         return recs, instructions
 
     def transition_executing_memory(
-        self, ts: TaskState, value=no_value, *, stimulus_id: str
+        self, ts: TaskState, value=NO_VALUE, *, stimulus_id: str
     ) -> RecsInstrs:
         if self.validate:
             assert ts.state == "executing" or ts.key in self.long_running
@@ -2919,17 +2909,13 @@ def _handle_instructions(self, instructions: Instructions) -> None:
             else:
                 instructions = []
 
-    def maybe_transition_long_running(
-        self, ts: TaskState, *, compute_duration: float, stimulus_id: str
-    ):
-        if ts.state == "executing":
-            self.transition(
-                ts,
-                "long-running",
-                compute_duration=compute_duration,
-                stimulus_id=stimulus_id,
-            )
-            assert ts.state == "long-running"
+    @handle_event.register
+    def _(self, ev: SecedeEvent) -> RecsInstrs:
+        ts = self.tasks.get(ev.key)
+        if ts and ts.state == "executing":
+            return {ts: ("long-running", ev.compute_duration)}, []
+        else:
+            return {}, []
 
     def stateof(self, key: str) -> dict[str, Any]:
         ts = self.tasks[key]
@@ -3401,7 +3387,7 @@ async def find_missing(self) -> None:
             "find-missing"
         ].callback_time = self.periodic_callbacks["heartbeat"].callback_time
 
-    def update_who_has(self, who_has: dict[str, Collection[str]]) -> None:
+    def update_who_has(self, who_has: Mapping[str, Collection[str]]) -> None:
         try:
             for dep, workers in who_has.items():
                 if not workers:
@@ -3430,27 +3416,23 @@ def update_who_has(self, who_has: dict[str, Collection[str]]) -> None:
                     pdb.set_trace()
                 raise
 
-    def handle_steal_request(self, key: str, stimulus_id: str) -> None:
+    @handle_event.register
+    def _(self, ev: StealRequestEvent) -> RecsInstrs:
         # There may be a race condition between stealing and releasing a task.
         # In this case self.tasks has already been cleared. The `None` will be
         # registered as `already-computing` on the other end
-        ts = self.tasks.get(key)
+        ts = self.tasks.get(ev.key)
         state = ts.state if ts is not None else None
-
-        response = {
-            "op": "steal-response",
-            "key": key,
-            "state": state,
-            "stimulus_id": stimulus_id,
-        }
-        self.batched_stream.send(response)
+        smsg = StealResponseMsg(key=ev.key, state=state, stimulus_id=ev.stimulus_id)
 
         if state in READY | {"waiting"}:
-            assert ts
             # If task is marked as "constrained" we haven't yet assigned it an
             # `available_resources` to run on, that happens in
             # `transition_constrained_executing`
-            self.transition(ts, "released", stimulus_id=stimulus_id)
+            assert ts
+            return {ts: "released"}, [smsg]
+        else:
+            return {}, [smsg]
 
     def handle_worker_status_change(self, status: str, stimulus_id: str) -> None:
         new_status = Status.lookup[status]  # type: ignore
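The removed inline dict above documents the wire format that `StealResponseMsg` (added in worker_state_machine.py below) keeps emitting, so nothing changes on the scheduler side. For reference, the serialized reply looks like:

{
    "op": "steal-response",
    "key": "x",                  # illustrative values
    "state": "executing",        # or None if the task is already forgotten
    "stimulus_id": "steal-123",
}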
@@ -3862,10 +3844,6 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
                 stimulus_id=f"task-erred-{time()}",
             )
 
-    @functools.singledispatchmethod
-    def handle_event(self, ev: StateMachineEvent) -> RecsInstrs:
-        raise TypeError(ev)  # pragma: nocover
-
     @handle_event.register
     def _(self, ev: UnpauseEvent) -> RecsInstrs:
         """Emerge from paused status. Do not send this event directly. Instead, just set
@@ -3884,7 +3862,9 @@ def _(self, ev: GatherDepDoneEvent) -> RecsInstrs:
 
     @handle_event.register
     def _(self, ev: CancelComputeEvent) -> RecsInstrs:
-        """Scheduler requested to cancel a task"""
+        """Cancel a task on a best effort basis. This is only possible while a task
+        is in state `waiting` or `ready`; nothing will happen otherwise.
+        """
         ts = self.tasks.get(ev.key)
         if not ts or ts.state not in READY | {"waiting"}:
             return {}, []
@@ -4525,12 +4505,13 @@ def secede():
     """
     worker = get_worker()
     tpe_secede()  # have this thread secede from the thread pool
-    duration = time() - thread_state.start_time
     worker.loop.add_callback(
-        worker.maybe_transition_long_running,
-        worker.tasks[thread_state.key],
-        compute_duration=duration,
-        stimulus_id=f"secede-{thread_state.key}-{time()}",
+        worker.handle_stimulus,
+        SecedeEvent(
+            key=thread_state.key,
+            compute_duration=time() - thread_state.start_time,
+            stimulus_id=f"secede-{time()}",
+        ),
     )
 
 
@@ -4618,7 +4599,7 @@ def loads_function(bytes_object):
     return pickle.loads(bytes_object)
 
 
-def _deserialize(function=None, args=None, kwargs=None, task=no_value):
+def _deserialize(function=None, args=None, kwargs=None, task=NO_VALUE):
     """Deserialize task inputs and regularize to func, args, kwargs"""
     if function is not None:
         function = loads_function(function)
@@ -4627,7 +4608,7 @@ def _deserialize(function=None, args=None, kwargs=None, task=no_value):
     if kwargs and isinstance(kwargs, bytes):
         kwargs = pickle.loads(kwargs)
 
-    if task is not no_value:
+    if task is not NO_VALUE:
         assert not function and not args and not kwargs
         function = execute_task
         args = (task,)
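For context, `secede()` is called from inside a task that mostly waits on other tasks, as in the documented launch-tasks-from-tasks pattern; with this patch that call reaches the state machine as a `SecedeEvent`. Usage sketch (the coordinator function is made up):

from distributed import get_client, secede

def coordinator(futures):
    client = get_client()
    secede()  # hand the thread-pool slot back; task becomes long-running
    return client.gather(futures)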
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index f5fa39c0802..0b22bc84d78 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -2,7 +2,7 @@
 
 import heapq
 import sys
-from collections.abc import Callable, Container, Iterator
+from collections.abc import Container, Iterator
 from copy import copy
 from dataclasses import dataclass, field
 from functools import lru_cache
@@ -52,11 +52,20 @@
 READY: set[TaskStateState] = {"ready", "constrained"}
 
 
+NO_VALUE = "--no-value-sentinel--"
+
+
 class SerializedTask(NamedTuple):
-    function: Callable
-    args: tuple
-    kwargs: dict[str, Any]
-    task: object  # distributed.scheduler.TaskState.run_spec
+    """Info from distributed.scheduler.TaskState.run_spec
+    Input to distributed.worker._deserialize
+
+    (function, args, kwargs) and task are mutually exclusive
+    """
+
+    function: bytes | None = None
+    args: bytes | tuple | list | None = None
+    kwargs: bytes | dict[str, Any] | None = None
+    task: object = NO_VALUE
 
 
 class StartStop(TypedDict, total=False):
@@ -379,6 +388,22 @@ class AddKeysMsg(SendMessageToScheduler):
     keys: list[str]
 
 
+@dataclass
+class StealResponseMsg(SendMessageToScheduler):
+    """Worker->Scheduler response to ``{op: steal-request}``
+
+    See also
+    --------
+    StealRequestEvent
+    """
+
+    op = "steal-response"
+
+    __slots__ = ("key", "state")
+    key: str
+    state: TaskStateState | None
+
+
 @dataclass
 class StateMachineEvent:
     __slots__ = ("stimulus_id", "handled")
@@ -450,6 +475,39 @@ class GatherDepDoneEvent(StateMachineEvent):
     __slots__ = ()
 
 
+@dataclass
+class ComputeTaskEvent(StateMachineEvent):
+    key: str
+    who_has: dict[str, Collection[str]]
+    nbytes: dict[str, int]
+    priority: tuple[int, ...]
+    duration: float
+    run_spec: SerializedTask
+    resource_restrictions: dict[str, float]
+    actor: bool
+    annotations: dict
+    __slots__ = tuple(__annotations__)  # type: ignore
+
+    def __post_init__(self):
+        # Fixes after msgpack decode
+        if isinstance(self.priority, list):
+            self.priority = tuple(self.priority)
+
+        if isinstance(self.run_spec, dict):
+            self.run_spec = SerializedTask(**self.run_spec)
+        elif not isinstance(self.run_spec, SerializedTask):
+            self.run_spec = SerializedTask(task=self.run_spec)
+
+    def to_loggable(self, *, handled: float) -> StateMachineEvent:
+        out = copy(self)
+        out.handled = handled
+        out.run_spec = SerializedTask(task=None)
+        return out
+
+    def _after_from_dict(self) -> None:
+        self.run_spec = SerializedTask(task=None)
+
+
 @dataclass
 class ExecuteSuccessEvent(StateMachineEvent):
     key: str
@@ -508,6 +566,58 @@ class RescheduleEvent(StateMachineEvent):
     key: str
 
 
+@dataclass
+class AcquireReplicasEvent(StateMachineEvent):
+    __slots__ = ("who_has",)
+    who_has: dict[str, Collection[str]]
+
+
+@dataclass
+class RemoveReplicasEvent(StateMachineEvent):
+    __slots__ = ("keys",)
+    keys: list[str]
+
+
+@dataclass
+class FreeKeysEvent(StateMachineEvent):
+    __slots__ = ("keys",)
+    keys: list[str]
+
+
+@dataclass
+class StealRequestEvent(StateMachineEvent):
+    """Event that requests a worker to release a key because it's now being computed
+    somewhere else.
+
+    See also
+    --------
+    StealResponseMsg
+    """
+
+    __slots__ = ("key",)
+    key: str
+
+
+@dataclass
+class UpdateDataEvent(StateMachineEvent):
+    __slots__ = ("data", "report")
+    data: dict[str, object]
+    report: bool
+
+    def to_loggable(self, *, handled: float) -> StateMachineEvent:
+        out = copy(self)
+        out.handled = handled
+        out.data = dict.fromkeys(self.data)
+        return out
+
+
+@dataclass
+class SecedeEvent(StateMachineEvent):
+    __slots__ = ("key", "compute_duration")
+    key: str
+    compute_duration: float
+
+
 if TYPE_CHECKING:
     # TODO remove quotes (requires Python >=3.9)
     # TODO get out of TYPE_CHECKING (requires Python >=3.10)
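Finally, a sketch of the two mutually exclusive SerializedTask payloads that `_deserialize` in worker.py regularizes (blobs abbreviated; behavior as in the hunks above):

from distributed.worker_state_machine import NO_VALUE, SerializedTask

# Form 1: pickled function/args/kwargs blobs from a compute-task message
st1 = SerializedTask(function=b"<pickled fn>", args=b"<pickled args>")
assert st1.task is NO_VALUE  # the sentinel default marks "task" as absent

# Form 2: a raw dask task spec; _deserialize wraps it as execute_task(task)
st2 = SerializedTask(task=(sum, [1, 2]))
assert st2.function is None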