diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py
index 92ac98587ae..a18594dc6e4 100644
--- a/distributed/tests/test_cancelled_state.py
+++ b/distributed/tests/test_cancelled_state.py
@@ -168,7 +168,6 @@ async def wait_and_raise(*args, **kwargs):
 
     b_story = b.story(fut1.key)
    assert any("receive-dep-failed" in msg for msg in b_story)
-    assert any("missing-dep" in msg for msg in b_story)
     assert any("cancelled" in msg for msg in b_story)
     assert any("resumed" in msg for msg in b_story)
 
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index de3bf2423b9..489f902b9b9 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -3389,3 +3389,24 @@ async def test_tick_interval(c, s, a, b):
     while s.workers[a.address].metrics["event_loop_interval"] < 0.100:
         await asyncio.sleep(0.01)
         time.sleep(0.200)
+
+
+class BreakingWorker(Worker):
+    broke_once = False
+
+    def get_data(self, comm, **kwargs):
+        if not self.broke_once:
+            self.broke_once = True
+            raise OSError("fake error")
+        return super().get_data(comm, **kwargs)
+
+
+@pytest.mark.slow
+@gen_cluster(client=True, Worker=BreakingWorker)
+async def test_broken_comm(c, s, a, b):
+    df = dask.datasets.timeseries(
+        start="2000-01-01",
+        end="2000-01-10",
+    )
+    s = df.shuffle("id", shuffle="tasks")
+    await c.compute(s.size)
diff --git a/distributed/worker.py b/distributed/worker.py
index 7c5bc61ca15..69486b4a099 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -546,6 +546,7 @@ def __init__(
             ("executing", "released"): self.transition_executing_released,
             ("executing", "rescheduled"): self.transition_executing_rescheduled,
             ("fetch", "flight"): self.transition_fetch_flight,
+            ("fetch", "missing"): self.transition_fetch_missing,
             ("fetch", "released"): self.transition_generic_released,
             ("flight", "error"): self.transition_flight_error,
             ("flight", "fetch"): self.transition_flight_fetch,
@@ -1929,6 +1930,14 @@ def transition_flight_missing(
         ts.done = False
         return {}, []
 
+    def transition_fetch_missing(
+        self, ts: TaskState, *, stimulus_id: str
+    ) -> RecsInstrs:
+        ts.state = "missing"
+        self._missing_dep_flight.add(ts)
+        ts.done = False
+        return {}, []
+
     def transition_released_fetch(
         self, ts: TaskState, *, stimulus_id: str
     ) -> RecsInstrs:
@@ -2671,6 +2680,9 @@ def ensure_communicating(self) -> None:
             if ts.state != "fetch":
                 continue
 
+            if self.validate:
+                assert ts.who_has
+
             workers = [w for w in ts.who_has if w not in self.in_flight_workers]
             if not workers:
                 assert ts.priority is not None
@@ -2999,7 +3011,11 @@ async def gather_dep(
             for d in has_what:
                 ts = self.tasks[d]
                 ts.who_has.remove(worker)
-
+                if not ts.who_has:
+                    recommendations[ts] = "missing"
+                    self.log.append(
+                        ("missing-who-has", worker, ts.key, stimulus_id, time())
+                    )
         except Exception as e:
             logger.exception(e)
             if self.batched_stream and LOG_PDB: