diff --git a/distributed/scheduler.py b/distributed/scheduler.py index d4a84c184b5..921508a68e2 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1088,7 +1088,7 @@ def __init__(self, key: str, run_spec: object): self.has_lost_dependencies = False self.host_restrictions = None # type: ignore self.worker_restrictions = None # type: ignore - self.resource_restrictions = None # type: ignore + self.resource_restrictions = {} self.loose_restrictions = False self.actor = False self.prefix = None # type: ignore @@ -2670,14 +2670,12 @@ def valid_workers(self, ts: TaskState) -> set: # set[WorkerState] | None return s def consume_resources(self, ts: TaskState, ws: WorkerState): - if ts.resource_restrictions: - for r, required in ts.resource_restrictions.items(): - ws.used_resources[r] += required + for r, required in ts.resource_restrictions.items(): + ws.used_resources[r] += required def release_resources(self, ts: TaskState, ws: WorkerState): - if ts.resource_restrictions: - for r, required in ts.resource_restrictions.items(): - ws.used_resources[r] -= required + for r, required in ts.resource_restrictions.items(): + ws.used_resources[r] -= required def coerce_hostname(self, host): """ @@ -7092,29 +7090,28 @@ def adaptive_target(self, target_duration=None): to_close = self.workers_to_close() return len(self.workers) - len(to_close) - def request_acquire_replicas(self, addr: str, keys: list, *, stimulus_id: str): + def request_acquire_replicas( + self, addr: str, keys: Iterable[str], *, stimulus_id: str + ) -> None: """Asynchronously ask a worker to acquire a replica of the listed keys from other workers. This is a fire-and-forget operation which offers no feedback for success or failure, and is intended for housekeeping and not for computation. """ - who_has = {} - for key in keys: - ts = self.tasks[key] - who_has[key] = {ws.address for ws in ts.who_has} - + who_has = {key: [ws.address for ws in self.tasks[key].who_has] for key in keys} if self.validate: assert all(who_has.values()) self.stream_comms[addr].send( { "op": "acquire-replicas", - "keys": keys, "who_has": who_has, "stimulus_id": stimulus_id, }, ) - def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str): + def request_remove_replicas( + self, addr: str, keys: list[str], *, stimulus_id: str + ) -> None: """Asynchronously ask a worker to discard its replica of the listed keys. This must never be used to destroy the last replica of a key. This is a fire-and-forget operation, intended for housekeeping and not for computation. @@ -7125,15 +7122,14 @@ def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str): to re-add itself to who_has. If the worker agrees to discard the task, there is no feedback. """ - ws: WorkerState = self.workers[addr] - validate = self.validate + ws = self.workers[addr] # The scheduler immediately forgets about the replica and suggests the worker to # drop it. The worker may refuse, at which point it will send back an add-keys # message to reinstate it. for key in keys: - ts: TaskState = self.tasks[key] - if validate: + ts = self.tasks[key] + if self.validate: # Do not destroy the last copy assert len(ts.who_has) > 1 self.remove_replica(ts, ws) @@ -7314,22 +7310,16 @@ def _task_to_msg( dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies }, "nbytes": {dts.key: dts.nbytes for dts in ts.dependencies}, + "run_spec": ts.run_spec, + "resource_restrictions": ts.resource_restrictions, + "actor": ts.actor, + "annotations": ts.annotations, } if state.validate: assert all(msg["who_has"].values()) - - if ts.resource_restrictions: - msg["resource_restrictions"] = ts.resource_restrictions - if ts.actor: - msg["actor"] = True - - if isinstance(ts.run_spec, dict): - msg.update(ts.run_spec) - else: - msg["task"] = ts.run_spec - - if ts.annotations: - msg["annotations"] = ts.annotations + if isinstance(msg["run_spec"], dict): + assert set(msg["run_spec"]).issubset({"function", "args", "kwargs"}) + assert msg["run_spec"].get("function") return msg diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index f87479f3bfe..1e074bb5e7c 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -24,7 +24,7 @@ slowadd, slowinc, ) -from distributed.worker_state_machine import TaskState +from distributed.worker_state_machine import FreeKeysEvent, TaskState pytestmark = pytest.mark.ci1 @@ -425,7 +425,9 @@ def sink(*args): # artificially, without notifying the scheduler. # This can only succeed if B handles the missing data properly by # removing A from the known sources of keys - a.handle_free_keys(keys=["f1"], stimulus_id="Am I evil?") # Yes, I am! + a.handle_stimulus( + FreeKeysEvent(keys=["f1"], stimulus_id="Am I evil?") + ) # Yes, I am! result_fut = c.submit(sink, futures, workers=x.address) await result_fut diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 4b68f927596..b834b02c746 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -29,6 +29,7 @@ slowidentity, slowinc, ) +from distributed.worker_state_machine import StealRequestEvent pytestmark = pytest.mark.ci1 @@ -867,7 +868,7 @@ async def test_dont_steal_already_released(c, s, a, b): while key in a.tasks and a.tasks[key].state != "released": await asyncio.sleep(0.05) - a.handle_steal_request(key=key, stimulus_id="test") + a.handle_stimulus(StealRequestEvent(key=key, stimulus_id="test")) assert len(a.batched_stream.buffer) == 1 msg = a.batched_stream.buffer[0] assert msg["op"] == "steal-response" diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 3116fb3a8c8..6e9890d1641 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -656,7 +656,7 @@ async def test_log_invalid_transitions(c, s, a): await asyncio.sleep(0.01) ts = a.tasks[xkey] with pytest.raises(InvalidTransition): - a.transition(ts, "foo", stimulus_id="bar") + a._transition(ts, "foo", stimulus_id="bar") while not s.events["invalid-worker-transition"]: await asyncio.sleep(0.01) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index a9d0c27e40f..7dbefe495f3 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import gc import importlib import logging import os @@ -72,6 +73,15 @@ error_message, logger, ) +from distributed.worker_state_machine import ( + AcquireReplicasEvent, + ComputeTaskEvent, + ExecuteFailureEvent, + ExecuteSuccessEvent, + RemoveReplicasEvent, + SerializedTask, + StealRequestEvent, +) pytestmark = pytest.mark.ci1 @@ -1851,31 +1861,67 @@ async def test_story(c, s, w): @gen_cluster(client=True, nthreads=[("", 1)]) async def test_stimulus_story(c, s, a): + # Test that substrings aren't matched by stimulus_story() + f = c.submit(inc, 1, key="f") + f1 = c.submit(lambda: "foo", key="f1") + f2 = c.submit(inc, f1, key="f2") # This will fail + await wait([f, f1, f2]) + + story = a.stimulus_story("f1", "f2") + assert len(story) == 4 + + assert isinstance(story[0], ComputeTaskEvent) + assert story[0].key == "f1" + assert story[0].run_spec == SerializedTask(task=None) # Not logged + + assert isinstance(story[1], ExecuteSuccessEvent) + assert story[1].key == "f1" + assert story[1].value is None # Not logged + assert story[1].handled >= story[0].handled + + assert isinstance(story[2], ComputeTaskEvent) + assert story[2].key == "f2" + assert story[2].who_has == {"f1": (a.address,)} + assert story[2].run_spec == SerializedTask(task=None) # Not logged + assert story[2].handled >= story[1].handled + + assert isinstance(story[3], ExecuteFailureEvent) + assert story[3].key == "f2" + assert story[3].handled >= story[2].handled + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_worker_descopes_data(c, s, a): + """Test that data is released on the worker: + 1. when it's the output of a successful task + 2. when it's the input of a failed task + 3. when it's a local variable in the frame of a failed task + 4. when it's embedded in the exception of a failed task + """ + class C: - pass + instances = weakref.WeakSet() - f = c.submit(C, key="f") # Test that substrings aren't matched by story() - f2 = c.submit(inc, 2, key="f2") - f3 = c.submit(inc, 3, key="f3") - await wait([f, f2, f3]) + def __init__(self): + C.instances.add(self) - # Test that ExecuteSuccessEvent.value is not stored in the the event log - assert isinstance(a.data["f"], C) - ref = weakref.ref(a.data["f"]) - del f - while "f" in a.data: - await asyncio.sleep(0.01) - with profile.lock: - assert ref() is None + def f(x): + y = C() + raise Exception(x, y) + + f1 = c.submit(C, key="f1") + f2 = c.submit(f, f1, key="f2") + await wait([f2]) - story = a.stimulus_story("f", "f2") - assert {ev.key for ev in story} == {"f", "f2"} - assert {ev.type for ev in story} == {C, int} + assert type(a.data["f1"]) is C - prev_handled = story[0].handled - for ev in story[1:]: - assert ev.handled >= prev_handled - prev_handled = ev.handled + del f1 + del f2 + while a.data: + await asyncio.sleep(0.01) + with profile.lock: + gc.collect() + assert not C.instances @gen_cluster(client=True) @@ -2570,7 +2616,7 @@ def __call__(self, *args, **kwargs): await asyncio.sleep(0) ts = s.tasks[fut.key] - a.handle_steal_request(fut.key, stimulus_id="test") + a.handle_stimulus(StealRequestEvent(key=fut.key, stimulus_id="test")) stealing_ext.scheduler.send_task_to_worker(b.address, ts) fut2 = c.submit(inc, fut, workers=[a.address]) @@ -2681,41 +2727,29 @@ async def test_acquire_replicas_many(c, s, *workers): await asyncio.sleep(0.001) -@pytest.mark.slow -@gen_cluster(client=True, Worker=Nanny) -async def test_acquire_replicas_already_in_flight(c, s, *nannies): +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_acquire_replicas_already_in_flight(c, s, a): """Trying to acquire a replica that is already in flight is a no-op""" + async with BlockedGatherDep(s.address) as b: + x = c.submit(inc, 1, workers=[a.address], key="x") + y = c.submit(inc, x, workers=[b.address], key="y") + await b.in_gather_dep.wait() + assert b.tasks["x"].state == "flight" - class SlowToFly: - def __getstate__(self): - sleep(0.9) - return {} - - a, b = s.workers - x = c.submit(SlowToFly, workers=[a], key="x") - await wait(x) - y = c.submit(lambda x: 123, x, workers=[b], key="y") - await asyncio.sleep(0.3) - s.request_acquire_replicas(b, [x.key], stimulus_id=f"test-{time()}") - assert await y == 123 + b.handle_stimulus( + AcquireReplicasEvent(who_has={"x": a.address}, stimulus_id="test") + ) + assert b.tasks["x"].state == "flight" + b.block_gather_dep.set() + assert await y == 3 - story = await c.run(lambda dask_worker: dask_worker.story("x")) - assert_story( - story[b], - [ - ("x", "ensure-task-exists", "released"), - ("x", "released", "fetch", "fetch", {}), - ("gather-dependencies", a, {"x"}), - ("x", "fetch", "flight", "flight", {}), - ("request-dep", a, {"x"}), - ("x", "ensure-task-exists", "flight"), - ("x", "flight", "fetch", "flight", {}), - ("receive-dep", a, {"x"}), - ("x", "put-in-memory"), - ("x", "flight", "memory", "memory", {"y": "ready"}), - ], - strict=True, - ) + assert_story( + b.story("x"), + [ + ("x", "fetch", "flight", "flight", {}), + ("x", "flight", "fetch", "flight", {}), + ], + ) @gen_cluster(client=True) @@ -2873,8 +2907,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers): if w.address == a.tasks[f1.key].coming_from: break - coming_from.handle_remove_replicas([f1.key], "test") - + coming_from.handle_stimulus(RemoveReplicasEvent(keys=[f1.key], stimulus_id="test")) await f2 assert_story(a.story(f1.key), [(f1.key, "missing-dep")]) @@ -3343,7 +3376,7 @@ async def test_log_invalid_transitions(c, s, a): await asyncio.sleep(0.01) ts = a.tasks[xkey] with pytest.raises(InvalidTransition): - a.transition(ts, "foo", stimulus_id="bar") + a._transition(ts, "foo", stimulus_id="bar") while not s.events["invalid-worker-transition"]: await asyncio.sleep(0.01) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index fdc7f01237f..1b1be8200d2 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -17,15 +17,19 @@ inc, ) from distributed.worker_state_machine import ( + AcquireReplicasEvent, + ComputeTaskEvent, ExecuteFailureEvent, ExecuteSuccessEvent, Instruction, ReleaseWorkerDataMsg, RescheduleEvent, RescheduleMsg, + SerializedTask, StateMachineEvent, TaskState, TaskStateState, + UpdateDataEvent, merge_recs_instructions, ) @@ -138,6 +142,70 @@ def test_event_to_dict(): assert ev3 == ev +def test_computetask_to_dict(): + """The potentially very large ComputeTaskEvent.run_spec is not stored in the log""" + ev = ComputeTaskEvent( + key="x", + who_has={"y": ["w1"]}, + nbytes={"y": 123}, + priority=(0,), + duration=123.45, + # Automatically converted to SerializedTask on init + run_spec={"function": b"blob", "args": b"blob"}, + resource_restrictions={}, + actor=False, + annotations={}, + stimulus_id="test", + ) + assert ev.run_spec == SerializedTask(function=b"blob", args=b"blob") + ev2 = ev.to_loggable(handled=11.22) + assert ev2.handled == 11.22 + assert ev2.run_spec == SerializedTask(task=None) + assert ev.run_spec == SerializedTask(function=b"blob", args=b"blob") + d = recursive_to_dict(ev2) + assert d == { + "cls": "ComputeTaskEvent", + "key": "x", + "who_has": {"y": ["w1"]}, + "nbytes": {"y": 123}, + "priority": [0], + "run_spec": [None, None, None, None], + "duration": 123.45, + "resource_restrictions": {}, + "actor": False, + "annotations": {}, + "stimulus_id": "test", + "handled": 11.22, + } + ev3 = StateMachineEvent.from_dict(d) + assert isinstance(ev3, ComputeTaskEvent) + assert ev3.run_spec == SerializedTask(task=None) + assert ev3.priority == (0,) # List is automatically converted back to tuple + + +def test_updatedata_to_dict(): + """The potentially very large UpdateDataEvent.data is not stored in the log""" + ev = UpdateDataEvent( + data={"x": "foo", "y": "bar"}, + report=True, + stimulus_id="test", + ) + ev2 = ev.to_loggable(handled=11.22) + assert ev2.handled == 11.22 + assert ev2.data == {"x": None, "y": None} + d = recursive_to_dict(ev2) + assert d == { + "cls": "UpdateDataEvent", + "data": {"x": None, "y": None}, + "report": True, + "stimulus_id": "test", + "handled": 11.22, + } + ev3 = StateMachineEvent.from_dict(d) + assert isinstance(ev3, UpdateDataEvent) + assert ev3.data == {"x": None, "y": None} + + def test_executesuccess_to_dict(): """The potentially very large ExecuteSuccessEvent.value is not stored in the log""" ev = ExecuteSuccessEvent( @@ -499,18 +567,18 @@ async def test_forget_data_needed(c, s, a, b): @gen_cluster(client=True, nthreads=[("", 1)] * 3) async def test_missing_handle_compute_dependency(c, s, w1, w2, w3): - """Test that it is OK for a dependency to be in state missing if a dependent is asked to be computed""" - + """Test that it is OK for a dependency to be in state missing if a dependent is + asked to be computed + """ w3.periodic_callbacks["find-missing"].stop() f1 = c.submit(inc, 1, key="f1", workers=[w1.address]) f2 = c.submit(inc, 2, key="f2", workers=[w1.address]) await wait_for_state(f1.key, "memory", w1) - w3.handle_acquire_replicas( - keys=[f1.key], who_has={f1.key: [w2.address]}, stimulus_id="acquire" + w3.handle_stimulus( + AcquireReplicasEvent(who_has={f1.key: [w2.address]}, stimulus_id="acquire") ) - await wait_for_state(f1.key, "missing", w3) f3 = c.submit(sum, [f1, f2], key="f3", workers=[w3.address]) @@ -525,10 +593,9 @@ async def test_missing_to_waiting(c, s, w1, w2, w3): f1 = c.submit(inc, 1, key="f1", workers=[w1.address], allow_other_workers=True) await wait_for_state(f1.key, "memory", w1) - w3.handle_acquire_replicas( - keys=[f1.key], who_has={f1.key: [w2.address]}, stimulus_id="acquire" + w3.handle_stimulus( + AcquireReplicasEvent(who_has={f1.key: [w2.address]}, stimulus_id="acquire") ) - await wait_for_state(f1.key, "missing", w3) await w2.close() diff --git a/distributed/worker.py b/distributed/worker.py index 9288e68782e..86e3081bff4 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -107,16 +107,20 @@ WorkerMemoryManager, ) from distributed.worker_state_machine import ( + NO_VALUE, PROCESSING, READY, + AcquireReplicasEvent, AddKeysMsg, AlreadyCancelledEvent, CancelComputeEvent, + ComputeTaskEvent, EnsureCommunicatingAfterTransitions, Execute, ExecuteFailureEvent, ExecuteSuccessEvent, FindMissingEvent, + FreeKeysEvent, GatherDep, GatherDepDoneEvent, Instructions, @@ -127,20 +131,24 @@ RecsInstrs, RefreshWhoHasEvent, ReleaseWorkerDataMsg, + RemoveReplicasEvent, RequestRefreshWhoHasMsg, RescheduleEvent, RescheduleMsg, RetryBusyWorkerEvent, RetryBusyWorkerLater, + SecedeEvent, SendMessageToScheduler, - SerializedTask, StateMachineEvent, + StealRequestEvent, + StealResponseMsg, TaskErredMsg, TaskFinishedMsg, TaskState, TaskStateState, TransitionCounterMaxExceeded, UnpauseEvent, + UpdateDataEvent, merge_recs_instructions, ) @@ -155,8 +163,6 @@ LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") -no_value = "--no-value-sentinel--" - DEFAULT_EXTENSIONS: dict[str, type] = { "pubsub": PubSubWorkerExtension, "shuffle": ShuffleWorkerExtension, @@ -798,7 +804,7 @@ def __init__( "run_coroutine": self.run_coroutine, "get_data": self.get_data, "update_data": self.update_data, - "free_keys": self.handle_free_keys, + "free_keys": self._handle_remote_stimulus(FreeKeysEvent), "terminate": self.close, "ping": pingpong, "upload_file": self.upload_file, @@ -821,13 +827,13 @@ def __init__( stream_handlers = { "close": self.close, - "cancel-compute": self.handle_cancel_compute, - "acquire-replicas": self.handle_acquire_replicas, - "compute-task": self.handle_compute_task, - "free-keys": self.handle_free_keys, - "remove-replicas": self.handle_remove_replicas, - "steal-request": self.handle_steal_request, - "refresh-who-has": self.handle_refresh_who_has, + "cancel-compute": self._handle_remote_stimulus(CancelComputeEvent), + "acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent), + "compute-task": self._handle_remote_stimulus(ComputeTaskEvent), + "free-keys": self._handle_remote_stimulus(FreeKeysEvent), + "remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent), + "steal-request": self._handle_remote_stimulus(StealRequestEvent), + "refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent), "worker-status-change": self.handle_worker_status_change, } @@ -1762,17 +1768,30 @@ async def get_data( # Local Execution # ################### + @functools.singledispatchmethod + def _handle_event(self, ev: StateMachineEvent) -> RecsInstrs: + raise TypeError(ev) # pragma: nocover + def update_data( self, data: dict[str, object], report: bool = True, stimulus_id: str = None, ) -> dict[str, Any]: - if stimulus_id is None: - stimulus_id = f"update-data-{time()}" + self.handle_stimulus( + UpdateDataEvent( + data=data, + report=report, + stimulus_id=stimulus_id or f"update-data-{time()}", + ) + ) + return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} + + @_handle_event.register + def _handle_update_data(self, ev: UpdateDataEvent) -> RecsInstrs: recommendations: Recs = {} instructions: Instructions = [] - for key, value in data.items(): + for key, value in ev.data.items(): try: ts = self.tasks[key] recommendations[ts] = ("memory", value) @@ -1780,25 +1799,27 @@ def update_data( self.tasks[key] = ts = TaskState(key) try: - recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) + recs = self._put_key_in_memory( + ts, value, stimulus_id=ev.stimulus_id + ) except Exception as e: msg = error_message(e) recommendations = {ts: tuple(msg.values())} else: recommendations.update(recs) - self.log.append((key, "receive-from-scatter", stimulus_id, time())) + self.log.append((key, "receive-from-scatter", ev.stimulus_id, time())) - if report: - instructions.append(AddKeysMsg(keys=list(data), stimulus_id=stimulus_id)) + if ev.report: + instructions.append( + AddKeysMsg(keys=list(ev.data), stimulus_id=ev.stimulus_id) + ) - self.transitions(recommendations, stimulus_id=stimulus_id) - self._handle_instructions(instructions) - return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} + return recommendations, instructions - def handle_free_keys(self, keys: list[str], stimulus_id: str) -> None: - """ - Handler to be called by the scheduler. + @_handle_event.register + def _handle_free_keys(self, ev: FreeKeysEvent) -> RecsInstrs: + """Handler to be called by the scheduler. The given keys are no longer referred to and required by the scheduler. The worker is now allowed to release the key, if applicable. @@ -1807,16 +1828,16 @@ def handle_free_keys(self, keys: list[str], stimulus_id: str) -> None: still decide to hold on to the data and task since it is required by an upstream dependency. """ - self.log.append(("free-keys", keys, stimulus_id, time())) + self.log.append(("free-keys", ev.keys, ev.stimulus_id, time())) recommendations: Recs = {} - for key in keys: + for key in ev.keys: ts = self.tasks.get(key) if ts: recommendations[ts] = "released" + return recommendations, [] - self.transitions(recommendations, stimulus_id=stimulus_id) - - def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str: + @_handle_event.register + def _handle_remove_replicas(self, ev: RemoveReplicasEvent) -> RecsInstrs: """Stream handler notifying the worker that it might be holding unreferenced, superfluous data. @@ -1835,37 +1856,29 @@ def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str: For stronger guarantees, see handler free_keys """ - self.log.append(("remove-replicas", keys, stimulus_id, time())) recommendations: Recs = {} + instructions: Instructions = [] rejected = [] - for key in keys: + for key in ev.keys: ts = self.tasks.get(key) if ts is None or ts.state != "memory": continue if not ts.is_protected(): self.log.append( - (ts.key, "remove-replica-confirmed", stimulus_id, time()) + (ts.key, "remove-replica-confirmed", ev.stimulus_id, time()) ) recommendations[ts] = "released" else: rejected.append(key) if rejected: - self.log.append(("remove-replica-rejected", rejected, stimulus_id, time())) - smsg = AddKeysMsg(keys=rejected, stimulus_id=stimulus_id) - self._handle_instructions([smsg]) - - self.transitions(recommendations, stimulus_id=stimulus_id) - - return "OK" + self.log.append( + ("remove-replica-rejected", rejected, ev.stimulus_id, time()) + ) + instructions.append(AddKeysMsg(keys=rejected, stimulus_id=ev.stimulus_id)) - def handle_refresh_who_has( - self, who_has: dict[str, list[str]], stimulus_id: str - ) -> None: - self.handle_stimulus( - RefreshWhoHasEvent(who_has=who_has, stimulus_id=stimulus_id) - ) + return recommendations, instructions async def set_resources(self, **resources) -> None: for r, quantity in resources.items(): @@ -1885,27 +1898,24 @@ async def set_resources(self, **resources) -> None: # Task Management # ################### - def handle_cancel_compute(self, key: str, stimulus_id: str) -> None: - """ - Cancel a task on a best effort basis. This is only possible while a task - is in state `waiting` or `ready`. - Nothing will happen otherwise. - """ - self.handle_stimulus(CancelComputeEvent(key=key, stimulus_id=stimulus_id)) - - def handle_acquire_replicas( - self, - *, - keys: Collection[str], - who_has: dict[str, Collection[str]], - stimulus_id: str, - ) -> None: + @fail_hard + def _handle_remote_stimulus( + self, cls: type[StateMachineEvent] + ) -> Callable[..., None]: + def _(**kwargs): + event = cls(**kwargs) + self.handle_stimulus(event) + + _.__name__ = f"_handle_remote_stimulus({cls.__name__})" + return _ + + @_handle_event.register + def _handle_acquire_replicas(self, ev: AcquireReplicasEvent) -> RecsInstrs: if self.validate: - assert set(keys) == who_has.keys() - assert all(who_has.values()) + assert all(ev.who_has.values()) recommendations: Recs = {} - for key in keys: + for key in ev.who_has: ts = self.ensure_task_exists( key=key, # Transfer this data after all dependency tasks of computations with @@ -1913,17 +1923,13 @@ def handle_acquire_replicas( # computations with low priority (<0). Note that the priority= parameter # of compute() is multiplied by -1 before it reaches TaskState.priority. priority=(1,), - stimulus_id=stimulus_id, + stimulus_id=ev.stimulus_id, ) if ts.state != "memory": recommendations[ts] = "fetch" - self._update_who_has(who_has) - self.transitions(recommendations, stimulus_id=stimulus_id) - - if self.validate: - for key in keys: - assert self.tasks[key].state != "released", self.story(key) + self._update_who_has(ev.who_has) + return recommendations, [] def ensure_task_exists( self, key: str, *, priority: tuple[int, ...], stimulus_id: str @@ -1940,32 +1946,17 @@ def ensure_task_exists( self.log.append((key, "ensure-task-exists", ts.state, stimulus_id, time())) return ts - def handle_compute_task( - self, - *, - key: str, - who_has: dict[str, Collection[str]], - nbytes: dict[str, int], - priority: tuple[int, ...], - duration: float, - function=None, - args=None, - kwargs=None, - task=no_value, # distributed.scheduler.TaskState.run_spec - resource_restrictions: dict[str, float] | None = None, - actor: bool = False, - annotations: dict | None = None, - stimulus_id: str, - ) -> None: + @_handle_event.register + def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs: try: - ts = self.tasks[key] + ts = self.tasks[ev.key] logger.debug( "Asked to compute an already known task %s", - {"task": ts, "stimulus_id": stimulus_id}, + {"task": ts, "stimulus_id": ev.stimulus_id}, ) except KeyError: - self.tasks[key] = ts = TaskState(key) - self.log.append((key, "compute-task", ts.state, stimulus_id, time())) + self.tasks[ev.key] = ts = TaskState(ev.key) + self.log.append((ev.key, "compute-task", ts.state, ev.stimulus_id, time())) recommendations: Recs = {} instructions: Instructions = [] @@ -1978,10 +1969,10 @@ def handle_compute_task( pass elif ts.state == "memory": instructions.append( - self._get_task_finished_msg(ts, stimulus_id=stimulus_id) + self._get_task_finished_msg(ts, stimulus_id=ev.stimulus_id) ) elif ts.state == "error": - instructions.append(TaskErredMsg.from_task(ts, stimulus_id=stimulus_id)) + instructions.append(TaskErredMsg.from_task(ts, stimulus_id=ev.stimulus_id)) elif ts.state in { "released", "fetch", @@ -1992,13 +1983,12 @@ def handle_compute_task( }: recommendations[ts] = "waiting" - ts.run_spec = SerializedTask(function, args, kwargs, task) + ts.run_spec = ev.run_spec - assert isinstance(priority, tuple) - priority = priority + (self.generation,) + priority = ev.priority + (self.generation,) self.generation -= 1 - if actor: + if ev.actor: self.actors[ts.key] = None ts.exception = None @@ -2006,34 +1996,35 @@ def handle_compute_task( ts.exception_text = "" ts.traceback_text = "" ts.priority = priority - ts.duration = duration - if resource_restrictions: - ts.resource_restrictions = resource_restrictions - ts.annotations = annotations + ts.duration = ev.duration + ts.resource_restrictions = ev.resource_restrictions + ts.annotations = ev.annotations if self.validate: - assert who_has.keys() == nbytes.keys() - assert all(who_has.values()) + assert ev.who_has.keys() == ev.nbytes.keys() + assert all(ev.who_has.values()) - for dep_key, dep_workers in who_has.items(): + for dep_key, dep_workers in ev.who_has.items(): dep_ts = self.ensure_task_exists( key=dep_key, priority=priority, - stimulus_id=stimulus_id, + stimulus_id=ev.stimulus_id, ) # link up to child / parents ts.dependencies.add(dep_ts) dep_ts.dependents.add(ts) - for dep_key, value in nbytes.items(): + for dep_key, value in ev.nbytes.items(): self.tasks[dep_key].nbytes = value - self._update_who_has(who_has) - else: # pragma: nocover - raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}") + self._update_who_has(ev.who_has) + else: + raise RuntimeError( # pragma: nocover + f"Unexpected task state encountered for {ts}; " + f"stimulus_id={ev.stimulus_id}; story={self.story(ts)}" + ) - self.transitions(recommendations, stimulus_id=stimulus_id) - self._handle_instructions(instructions) + return recommendations, instructions ######################## # Worker State Machine # @@ -2471,15 +2462,15 @@ def transition_executing_released( return self._ensure_computing() def transition_long_running_memory( - self, ts: TaskState, value=no_value, *, stimulus_id: str + self, ts: TaskState, value=NO_VALUE, *, stimulus_id: str ) -> RecsInstrs: self.executed_count += 1 return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id) def transition_generic_memory( - self, ts: TaskState, value=no_value, *, stimulus_id: str + self, ts: TaskState, value=NO_VALUE, *, stimulus_id: str ) -> RecsInstrs: - if value is no_value and ts.key not in self.data: + if value is NO_VALUE and ts.key not in self.data: raise RuntimeError( f"Tried to transition task {ts} to `memory` without data available" ) @@ -2508,7 +2499,7 @@ def transition_generic_memory( return recs, instructions def transition_executing_memory( - self, ts: TaskState, value=no_value, *, stimulus_id: str + self, ts: TaskState, value=NO_VALUE, *, stimulus_id: str ) -> RecsInstrs: if self.validate: assert ts.state == "executing" or ts.key in self.long_running @@ -2734,6 +2725,12 @@ def _transition( stimulus_id: str, **kwargs, ) -> RecsInstrs: + """Transition a key from its current state to the finish state + + See Also + -------- + Worker.transitions: wrapper around this method + """ if isinstance(finish, tuple): # the concatenated transition path might need to access the tuple assert not args @@ -2843,30 +2840,6 @@ def _transition( ) return recs, instructions - def transition( - self, ts: TaskState, finish: TaskStateState, *, stimulus_id: str, **kwargs - ) -> None: - """Transition a key from its current state to the finish state - - Examples - -------- - >>> self.transition('x', 'waiting', stimulus_id=f"test-{(time()}") - {'x': 'processing'} - - Returns - ------- - Dictionary of recommendations for future transitions - - See Also - -------- - Scheduler.transitions: transitive version of this function - """ - recs, instructions = self._transition( - ts, finish, stimulus_id=stimulus_id, **kwargs - ) - self._handle_instructions(instructions) - self.transitions(recs, stimulus_id=stimulus_id) - def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: """Process transitions until none are left @@ -2899,7 +2872,7 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: def handle_stimulus(self, stim: StateMachineEvent) -> None: if not isinstance(stim, FindMissingEvent): self.stimulus_log.append(stim.to_loggable(handled=time())) - recs, instructions = self.handle_event(stim) + recs, instructions = self._handle_event(stim) self.transitions(recs, stimulus_id=stim.stimulus_id) self._handle_instructions(instructions) @@ -2983,17 +2956,13 @@ def _handle_instructions(self, instructions: Instructions) -> None: else: instructions = [] - def maybe_transition_long_running( - self, ts: TaskState, *, compute_duration: float, stimulus_id: str - ): - if ts.state == "executing": - self.transition( - ts, - "long-running", - compute_duration=compute_duration, - stimulus_id=stimulus_id, - ) - assert ts.state == "long-running" + @_handle_event.register + def _handle_secede(self, ev: SecedeEvent) -> RecsInstrs: + ts = self.tasks.get(ev.key) + if ts and ts.state == "executing": + return {ts: ("long-running", ev.compute_duration)}, [] + else: + return {}, [] def stateof(self, key: str) -> dict[str, Any]: ts = self.tasks[key] @@ -3490,27 +3459,23 @@ def _update_who_has(self, who_has: Mapping[str, Collection[str]]) -> None: ts.who_has = workers - def handle_steal_request(self, key: str, stimulus_id: str) -> None: + @_handle_event.register + def _handle_steal_request(self, ev: StealRequestEvent) -> RecsInstrs: # There may be a race condition between stealing and releasing a task. # In this case the self.tasks is already cleared. The `None` will be # registered as `already-computing` on the other end - ts = self.tasks.get(key) + ts = self.tasks.get(ev.key) state = ts.state if ts is not None else None - - response = { - "op": "steal-response", - "key": key, - "state": state, - "stimulus_id": stimulus_id, - } - self.batched_send(response) + smsg = StealResponseMsg(key=ev.key, state=state, stimulus_id=ev.stimulus_id) if state in READY | {"waiting"}: - assert ts # If task is marked as "constrained" we haven't yet assigned it an # `available_resources` to run on, that happens in # `transition_constrained_executing` - self.transition(ts, "released", stimulus_id=stimulus_id) + assert ts + return {ts: "released"}, [smsg] + else: + return {}, [smsg] def handle_worker_status_change(self, status: str, stimulus_id: str) -> None: new_status = Status.lookup[status] # type: ignore @@ -3864,12 +3829,8 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No stimulus_id=f"execute-unknown-error-{time()}", ) - @functools.singledispatchmethod - def handle_event(self, ev: StateMachineEvent) -> RecsInstrs: - raise TypeError(ev) # pragma: nocover - - @handle_event.register - def _(self, ev: UnpauseEvent) -> RecsInstrs: + @_handle_event.register + def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs: """Emerge from paused status. Do not send this event directly. Instead, just set Worker.status back to running. """ @@ -3879,19 +3840,21 @@ def _(self, ev: UnpauseEvent) -> RecsInstrs: self._ensure_communicating(stimulus_id=ev.stimulus_id), ) - @handle_event.register - def _(self, ev: GatherDepDoneEvent) -> RecsInstrs: + @_handle_event.register + def _handle_gather_dep_done(self, ev: GatherDepDoneEvent) -> RecsInstrs: """Temporary hack - to be removed""" return self._ensure_communicating(stimulus_id=ev.stimulus_id) - @handle_event.register - def _(self, ev: RetryBusyWorkerEvent) -> RecsInstrs: + @_handle_event.register + def _handle_retry_busy_worker(self, ev: RetryBusyWorkerEvent) -> RecsInstrs: self.busy_workers.discard(ev.worker) return self._ensure_communicating(stimulus_id=ev.stimulus_id) - @handle_event.register - def _(self, ev: CancelComputeEvent) -> RecsInstrs: - """Scheduler requested to cancel a task""" + @_handle_event.register + def _handle_cancel_compute(self, ev: CancelComputeEvent) -> RecsInstrs: + """Cancel a task on a best-effort basis. This is only possible while a task + is in state `waiting` or `ready`; nothing will happen otherwise. + """ ts = self.tasks.get(ev.key) if not ts or ts.state not in READY | {"waiting"}: return {}, [] @@ -3902,8 +3865,8 @@ def _(self, ev: CancelComputeEvent) -> RecsInstrs: assert not ts.dependents return {ts: "released"}, [] - @handle_event.register - def _(self, ev: AlreadyCancelledEvent) -> RecsInstrs: + @_handle_event.register + def _handle_already_cancelled(self, ev: AlreadyCancelledEvent) -> RecsInstrs: """Task is already cancelled by the time execute() runs""" # key *must* be still in tasks. Releasing it directly is forbidden # without going through cancelled @@ -3912,8 +3875,8 @@ def _(self, ev: AlreadyCancelledEvent) -> RecsInstrs: ts.done = True return {ts: "released"}, [] - @handle_event.register - def _(self, ev: ExecuteSuccessEvent) -> RecsInstrs: + @_handle_event.register + def _handle_execute_success(self, ev: ExecuteSuccessEvent) -> RecsInstrs: """Task completed successfully""" # key *must* be still in tasks. Releasing it directly is forbidden # without going through cancelled @@ -3926,8 +3889,8 @@ def _(self, ev: ExecuteSuccessEvent) -> RecsInstrs: ts.type = ev.type return {ts: ("memory", ev.value)}, [] - @handle_event.register - def _(self, ev: ExecuteFailureEvent) -> RecsInstrs: + @_handle_event.register + def _handle_execute_failure(self, ev: ExecuteFailureEvent) -> RecsInstrs: """Task execution failed""" # key *must* be still in tasks. Releasing it directly is forbidden # without going through cancelled @@ -3950,8 +3913,8 @@ def _(self, ev: ExecuteFailureEvent) -> RecsInstrs: ) }, [] - @handle_event.register - def _(self, ev: RescheduleEvent) -> RecsInstrs: + @_handle_event.register + def _handle_reschedule(self, ev: RescheduleEvent) -> RecsInstrs: """Task raised Reschedule exception while it was running""" # key *must* be still in tasks. Releasing it directly is forbidden # without going through cancelled @@ -3959,8 +3922,8 @@ def _(self, ev: RescheduleEvent) -> RecsInstrs: assert ts, self.story(ev.key) return {ts: "rescheduled"}, [] - @handle_event.register - def _(self, ev: FindMissingEvent) -> RecsInstrs: + @_handle_event.register + def _handle_find_missing(self, ev: FindMissingEvent) -> RecsInstrs: if not self._missing_dep_flight: return {}, [] @@ -3974,8 +3937,8 @@ def _(self, ev: FindMissingEvent) -> RecsInstrs: ) return {}, [smsg] - @handle_event.register - def _(self, ev: RefreshWhoHasEvent) -> RecsInstrs: + @_handle_event.register + def _handle_refresh_who_has(self, ev: RefreshWhoHasEvent) -> RecsInstrs: self._update_who_has(ev.who_has) recommendations: Recs = {} instructions: Instructions = [] @@ -4588,10 +4551,12 @@ def secede(): tpe_secede() # have this thread secede from the thread pool duration = time() - thread_state.start_time worker.loop.add_callback( - worker.maybe_transition_long_running, - worker.tasks[thread_state.key], - compute_duration=duration, - stimulus_id=f"secede-{thread_state.key}-{time()}", + worker.handle_stimulus, + SecedeEvent( + key=thread_state.key, + compute_duration=duration, + stimulus_id=f"secede-{time()}", + ), ) @@ -4679,7 +4644,7 @@ def loads_function(bytes_object): return pickle.loads(bytes_object) -def _deserialize(function=None, args=None, kwargs=None, task=no_value): +def _deserialize(function=None, args=None, kwargs=None, task=NO_VALUE): """Deserialize task inputs and regularize to func, args, kwargs""" if function is not None: function = loads_function(function) @@ -4688,7 +4653,7 @@ def _deserialize(function=None, args=None, kwargs=None, task=no_value): if kwargs and isinstance(kwargs, bytes): kwargs = pickle.loads(kwargs) - if task is not no_value: + if task is not NO_VALUE: assert not function and not args and not kwargs function = execute_task args = (task,) diff --git a/distributed/worker_client.py b/distributed/worker_client.py index 5a775b38191..bf7beda5ae7 100644 --- a/distributed/worker_client.py +++ b/distributed/worker_client.py @@ -6,6 +6,7 @@ from distributed.metrics import time from distributed.threadpoolexecutor import rejoin, secede from distributed.worker import get_client, get_worker, thread_state +from distributed.worker_state_machine import SecedeEvent @contextmanager @@ -54,11 +55,12 @@ def worker_client(timeout=None, separate_thread=True): duration = time() - thread_state.start_time secede() # have this thread secede from the thread pool worker.loop.add_callback( - worker.transition, - worker.tasks[thread_state.key], - "long-running", - stimulus_id=f"worker-client-secede-{time()}", - compute_duration=duration, + worker.handle_stimulus, + SecedeEvent( + key=thread_state.key, + compute_duration=duration, + stimulus_id=f"worker-client-secede-{time()}", + ), ) yield client diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 65bf2adb0e7..386d399250e 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from collections.abc import Callable, Container +from collections.abc import Collection, Container from copy import copy from dataclasses import dataclass, field from functools import lru_cache @@ -51,11 +51,20 @@ READY: set[TaskStateState] = {"ready", "constrained"} +NO_VALUE = "--no-value-sentinel--" + + class SerializedTask(NamedTuple): - function: Callable - args: tuple - kwargs: dict[str, Any] - task: object # distributed.scheduler.TaskState.run_spec + """Info from distributed.scheduler.TaskState.run_spec + Input to distributed.worker._deserialize + + (function, args kwargs) and task are mutually exclusive + """ + + function: bytes | None = None + args: bytes | tuple | list | None = None + kwargs: bytes | dict[str, Any] | None = None + task: object = NO_VALUE class StartStop(TypedDict, total=False): @@ -357,7 +366,7 @@ class AddKeysMsg(SendMessageToScheduler): op = "add-keys" __slots__ = ("keys",) - keys: list[str] + keys: Collection[str] @dataclass @@ -377,7 +386,23 @@ class RequestRefreshWhoHasMsg(SendMessageToScheduler): op = "request-refresh-who-has" __slots__ = ("keys",) - keys: list[str] + keys: Collection[str] + + +@dataclass +class StealResponseMsg(SendMessageToScheduler): + """Worker->Scheduler response to ``{op: steal-request}`` + + See also + -------- + StealRequestEvent + """ + + op = "steal-response" + + __slots__ = ("key", "state") + key: str + state: TaskStateState | None @dataclass @@ -457,6 +482,39 @@ class GatherDepDoneEvent(StateMachineEvent): __slots__ = () +@dataclass +class ComputeTaskEvent(StateMachineEvent): + key: str + who_has: dict[str, Collection[str]] + nbytes: dict[str, int] + priority: tuple[int, ...] + duration: float + run_spec: SerializedTask + resource_restrictions: dict[str, float] + actor: bool + annotations: dict + __slots__ = tuple(__annotations__) # type: ignore + + def __post_init__(self): + # Fixes after msgpack decode + if isinstance(self.priority, list): + self.priority = tuple(self.priority) + + if isinstance(self.run_spec, dict): + self.run_spec = SerializedTask(**self.run_spec) + elif not isinstance(self.run_spec, SerializedTask): + self.run_spec = SerializedTask(task=self.run_spec) + + def to_loggable(self, *, handled: float) -> StateMachineEvent: + out = copy(self) + out.handled = handled + out.run_spec = SerializedTask(task=None) + return out + + def _after_from_dict(self) -> None: + self.run_spec = SerializedTask(task=None) + + @dataclass class ExecuteSuccessEvent(StateMachineEvent): key: str @@ -555,7 +613,59 @@ class RefreshWhoHasEvent(StateMachineEvent): __slots__ = ("who_has",) # {key: [worker address, ...]} - who_has: dict[str, list[str]] + who_has: dict[str, Collection[str]] + + +@dataclass +class AcquireReplicasEvent(StateMachineEvent): + __slots__ = ("who_has",) + who_has: dict[str, Collection[str]] + + +@dataclass +class RemoveReplicasEvent(StateMachineEvent): + __slots__ = ("keys",) + keys: Collection[str] + + +@dataclass +class FreeKeysEvent(StateMachineEvent): + __slots__ = ("keys",) + keys: Collection[str] + + +@dataclass +class StealRequestEvent(StateMachineEvent): + """Event that requests a worker to release a key because it's now being computed + somewhere else. + + See also + -------- + StealResponseMsg + """ + + __slots__ = ("key",) + key: str + + +@dataclass +class UpdateDataEvent(StateMachineEvent): + __slots__ = ("data", "report") + data: dict[str, object] + report: bool + + def to_loggable(self, *, handled: float) -> StateMachineEvent: + out = copy(self) + out.handled = handled + out.data = dict.fromkeys(self.data) + return out + + +@dataclass +class SecedeEvent(StateMachineEvent): + __slots__ = ("key", "compute_duration") + key: str + compute_duration: float if TYPE_CHECKING: