From 883b9484f86441e858f7c0f51a7c8e60fa1647fc Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 24 Jun 2022 12:43:47 +0200 Subject: [PATCH 1/2] Partial revert of compute-task message format --- distributed/scheduler.py | 13 +++++++++---- distributed/tests/test_scheduler.py | 15 +++++++++++++++ distributed/worker_state_machine.py | 15 ++++++++++----- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8980bbc10d9..79fccd4a51c 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -7321,16 +7321,21 @@ def _task_to_msg( dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies }, "nbytes": {dts.key: dts.nbytes for dts in ts.dependencies}, - "run_spec": ts.run_spec, + "run_spec": None, + "function": None, + "args": None, + "kwargs": None, "resource_restrictions": ts.resource_restrictions, "actor": ts.actor, "annotations": ts.annotations, } if state.validate: assert all(msg["who_has"].values()) - if isinstance(msg["run_spec"], dict): - assert set(msg["run_spec"]).issubset({"function", "args", "kwargs"}) - assert msg["run_spec"].get("function") + + if isinstance(ts.run_spec, dict): + msg.update(ts.run_spec) + else: + msg["run_spec"] = ts.run_spec return msg diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 4c3d43ab711..31ec49c3f05 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -3641,3 +3641,18 @@ async def test_worker_state_unique_regardless_of_address(s, w): async def test_scheduler_close_fast_deprecated(s, w): with pytest.warns(FutureWarning): await s.close(fast=True) + + +def test_runspec_regression_sync(): + # https://github.com/dask/distributed/issues/6624 + + da = pytest.importorskip("dask.array") + np = pytest.importorskip("numpy") + with Client(): + v = da.random.random((20, 20), chunks=(5, 5)) + + overlapped = da.map_overlap(np.sum, v, depth=2, boundary="reflect") + # This computation is somehow broken but we want to avoid catching any + # serialization errors that result in KilledWorker + with pytest.raises(IndexError): + overlapped.compute() diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index f55b468c9f8..9945e544898 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -652,7 +652,10 @@ class ComputeTaskEvent(StateMachineEvent): nbytes: dict[str, int] priority: tuple[int, ...] duration: float - run_spec: SerializedTask + run_spec: SerializedTask | None + function: bytes | None + args: bytes | tuple | list | None | None + kwargs: bytes | dict[str, Any] | None resource_restrictions: dict[str, float] actor: bool annotations: dict @@ -663,19 +666,21 @@ def __post_init__(self) -> None: if isinstance(self.priority, list): # type: ignore[unreachable] self.priority = tuple(self.priority) # type: ignore[unreachable] - if isinstance(self.run_spec, dict): - self.run_spec = SerializedTask(**self.run_spec) # type: ignore[unreachable] + if self.run_spec is None: + self.run_spec = SerializedTask( + function=self.function, args=self.args, kwargs=self.kwargs + ) elif not isinstance(self.run_spec, SerializedTask): self.run_spec = SerializedTask(task=self.run_spec) # type: ignore[unreachable] def to_loggable(self, *, handled: float) -> StateMachineEvent: out = copy(self) out.handled = handled - out.run_spec = SerializedTask(task=None) + out.run_spec = SerializedTask(task=None, function=None, args=None, kwargs=None) return out def _after_from_dict(self) -> None: - self.run_spec = SerializedTask(task=None) + self.run_spec = SerializedTask(task=None, function=None, args=None, kwargs=None) @dataclass From 27f9f67e5f308c9149760f49a1580d93485360db Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 24 Jun 2022 14:26:54 +0200 Subject: [PATCH 2/2] fix tests --- .../tests/test_worker_state_machine.py | 9 +++++++-- distributed/worker_state_machine.py | 19 +++++++++++++++---- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index eac95602f9f..1252b302525 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -232,12 +232,14 @@ def test_computetask_to_dict(): nbytes={"y": 123}, priority=(0,), duration=123.45, - # Automatically converted to SerializedTask on init - run_spec={"function": b"blob", "args": b"blob"}, + run_spec=None, resource_restrictions={}, actor=False, annotations={}, stimulus_id="test", + function=b"blob", + args=b"blob", + kwargs=None, ) assert ev.run_spec == SerializedTask(function=b"blob", args=b"blob") ev2 = ev.to_loggable(handled=11.22) @@ -258,6 +260,9 @@ def test_computetask_to_dict(): "annotations": {}, "stimulus_id": "test", "handled": 11.22, + "function": None, + "args": None, + "kwargs": None, } ev3 = StateMachineEvent.from_dict(d) assert isinstance(ev3, ComputeTaskEvent) diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 9945e544898..b740b42e28d 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -666,19 +666,30 @@ def __post_init__(self) -> None: if isinstance(self.priority, list): # type: ignore[unreachable] self.priority = tuple(self.priority) # type: ignore[unreachable] - if self.run_spec is None: + if self.function is not None: + assert self.run_spec is None self.run_spec = SerializedTask( function=self.function, args=self.args, kwargs=self.kwargs ) elif not isinstance(self.run_spec, SerializedTask): - self.run_spec = SerializedTask(task=self.run_spec) # type: ignore[unreachable] + self.run_spec = SerializedTask(task=self.run_spec) - def to_loggable(self, *, handled: float) -> StateMachineEvent: + def _to_dict(self, *, exclude: Container[str] = ()) -> dict: + return StateMachineEvent._to_dict(self._clean(), exclude=exclude) + + def _clean(self) -> StateMachineEvent: out = copy(self) - out.handled = handled + out.function = None + out.kwargs = None + out.args = None out.run_spec = SerializedTask(task=None, function=None, args=None, kwargs=None) return out + def to_loggable(self, *, handled: float) -> StateMachineEvent: + out = self._clean() + out.handled = handled + return out + def _after_from_dict(self) -> None: self.run_spec = SerializedTask(task=None, function=None, args=None, kwargs=None)