From c82bba52070093e4bf3ffe7da36dbfdd18974d81 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Sun, 26 Jun 2022 11:59:31 +0100 Subject: [PATCH] Deduplicate data_needed (#6587) --- distributed/tests/test_worker.py | 3 +- distributed/tests/test_worker_memory.py | 4 +- .../tests/test_worker_state_machine.py | 97 +++++- distributed/worker.py | 12 +- distributed/worker_state_machine.py | 296 ++++++++++-------- 5 files changed, 274 insertions(+), 138 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index b1ad5dba87f..fe951f7fff1 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2992,7 +2992,7 @@ def __sizeof__(self): ) await b.in_gather_dep.wait() assert len(b.state.in_flight_tasks) == 5 - assert len(b.state.data_needed) == 5 + assert len(b.state.data_needed[a.address]) == 5 b.block_gather_dep.set() while len(b.data) < 10: await asyncio.sleep(0.01) @@ -3348,7 +3348,6 @@ async def test_Worker__to_dict(c, s, a): "transition_counter", "tasks", "data_needed", - "data_needed_per_worker", } assert d["tasks"]["x"]["key"] == "x" assert d["data"] == {"x": None} diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index 4714a62b0c7..5e46c60c54b 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -527,7 +527,7 @@ def pause_on_unpickle(): # - w and z respectively make x and y go into fetch state. # w has a higher priority than z, therefore w's dependency x has a higher priority # than z's dependency y. - # a.state.data_needed = ["x", "y"] + # a.state.data_needed[b.address] = ["x", "y"] # - ensure_communicating decides to fetch x but not to fetch y together with it, as # it thinks x is 1TB in size # - x fetch->flight; a is added to in_flight_workers @@ -543,7 +543,7 @@ def pause_on_unpickle(): await asyncio.sleep(0.1) assert a.state.tasks["y"].state == "fetch" assert "y" not in a.data - assert [ts.key for ts in a.state.data_needed] == ["y"] + assert [ts.key for ts in a.state.data_needed[b.address]] == ["y"] # Unpausing kicks off ensure_communicating again a.status = Status.running diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index b942dc4d4d7..47c767c68fe 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -31,6 +31,7 @@ ExecuteSuccessEvent, GatherDep, Instruction, + PauseEvent, RecommendationsConflict, RefreshWhoHasEvent, ReleaseWorkerDataMsg, @@ -39,6 +40,7 @@ SerializedTask, StateMachineEvent, TaskState, + UnpauseEvent, UpdateDataEvent, merge_recs_instructions, ) @@ -117,8 +119,7 @@ def test_WorkerState__to_dict(ws): "busy_workers": [], "constrained": [], "data": {"y": None}, - "data_needed": [], - "data_needed_per_worker": {"127.0.0.1:1235": []}, + "data_needed": {}, "executing": [], "in_flight_tasks": ["x"], "in_flight_workers": {"127.0.0.1:1235": ["x"]}, @@ -860,6 +861,9 @@ async def test_deprecated_worker_attributes(s, a, b): with pytest.warns(FutureWarning, match=msg): assert a.in_flight_tasks == 0 + with pytest.warns(FutureWarning, match="attribute has been removed"): + assert a.data_needed == set() + @pytest.mark.parametrize( "nbytes,n_in_flight", @@ -885,3 +889,92 @@ def test_aggregate_gather_deps(ws, nbytes, n_in_flight): assert len(instructions) == 1 assert isinstance(instructions[0], GatherDep) assert len(ws.in_flight_tasks) == n_in_flight + + +def test_gather_priority(ws): + """Test that tasks 
are fetched in the following order:
+
+    1. by task priority
+    2. in case of tie, from local workers first
+    3. in case of tie, from the worker with the most tasks queued
+    4. in case of tie, from a random worker (which is actually deterministic).
+    """
+    ws.total_out_connections = 4
+
+    instructions = ws.handle_stimulus(
+        PauseEvent(stimulus_id="pause"),
+        # Note: tasks fetched by acquire-replicas always have priority=(1, )
+        AcquireReplicasEvent(
+            who_has={
+                # Remote + local
+                "x1": ["127.0.0.2:1", "127.0.0.1:2"],
+                # Remote. After getting x1 from .1, .2 will have fewer tasks than .3
+                "x2": ["127.0.0.2:1"],
+                "x3": ["127.0.0.3:1"],
+                "x4": ["127.0.0.3:1"],
+                # It will be a random choice between .2, .4, .5, .6, and .7
+                "x5": ["127.0.0.4:1"],
+                "x6": ["127.0.0.5:1"],
+                "x7": ["127.0.0.6:1"],
+                # This will be fetched first because it's on the same worker as y
+                "x8": ["127.0.0.7:1"],
+            },
+            # Substantial nbytes prevents total_out_connections from being overridden
+            # by comm_threshold_bytes, but it's less than target_message_size
+            nbytes={f"x{i}": 4 * 2**20 for i in range(1, 9)},
+            stimulus_id="compute1",
+        ),
+        # A higher-priority task, even if scheduled later, is fetched first
+        ComputeTaskEvent(
+            key="z",
+            who_has={"y": ["127.0.0.7:1"]},
+            nbytes={"y": 1},
+            priority=(0,),
+            duration=1.0,
+            run_spec=None,
+            function=None,
+            args=None,
+            kwargs=None,
+            resource_restrictions={},
+            actor=False,
+            annotations={},
+            stimulus_id="compute2",
+        ),
+        UnpauseEvent(stimulus_id="unpause"),
+    )
+
+    assert instructions == [
+        # Highest-priority task first. Lower-priority tasks from the same worker are
+        # shoved into the same instruction (up to 50MB worth)
+        GatherDep(
+            stimulus_id="unpause",
+            worker="127.0.0.7:1",
+            to_gather={"y", "x8"},
+            total_nbytes=1 + 4 * 2**20,
+        ),
+        # Followed by local workers
+        GatherDep(
+            stimulus_id="unpause",
+            worker="127.0.0.1:2",
+            to_gather={"x1"},
+            total_nbytes=4 * 2**20,
+        ),
+        # Followed by remote workers with the most tasks
+        GatherDep(
+            stimulus_id="unpause",
+            worker="127.0.0.3:1",
+            to_gather={"x3", "x4"},
+            total_nbytes=8 * 2**20,
+        ),
+        # Followed by other remote workers, randomly.
+        # Determinism is guaranteed by a statically-seeded random number generator.
+        # FIXME It would not have been deterministic if, instead of multiple keys, we
+        # had used a single key with multiple workers, because sets
+        # (like TaskState.who_has) change order at every interpreter restart.
+ GatherDep( + stimulus_id="unpause", + worker="127.0.0.4:1", + to_gather={"x5"}, + total_nbytes=4 * 2**20, + ), + ] diff --git a/distributed/worker.py b/distributed/worker.py index 68294004ea9..a239f3b72c5 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -859,8 +859,7 @@ def data(self) -> MutableMapping[str, Any]: comm_nbytes = DeprecatedWorkerStateAttribute() comm_threshold_bytes = DeprecatedWorkerStateAttribute() constrained = DeprecatedWorkerStateAttribute() - data_needed = DeprecatedWorkerStateAttribute() - data_needed_per_worker = DeprecatedWorkerStateAttribute() + data_needed_per_worker = DeprecatedWorkerStateAttribute(target="data_needed") executed_count = DeprecatedWorkerStateAttribute() executing_count = DeprecatedWorkerStateAttribute() generation = DeprecatedWorkerStateAttribute() @@ -883,6 +882,15 @@ def data(self) -> MutableMapping[str, Any]: validate_task = DeprecatedWorkerStateAttribute() waiting_for_data_count = DeprecatedWorkerStateAttribute() + @property + def data_needed(self) -> set[TaskState]: + warnings.warn( + "The `Worker.data_needed` attribute has been removed; " + "use `Worker.state.data_needed[address]`", + FutureWarning, + ) + return {ts for tss in self.state.data_needed.values() for ts in tss} + ################## # Administrative # ################## diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 68238b6c75e..33d6ce63093 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -949,14 +949,10 @@ class WorkerState: has_what: defaultdict[str, set[str]] #: The tasks which still require data in order to execute and are in memory on at - #: least another worker, prioritized as a heap. All and only tasks with - #: ``TaskState.state == 'fetch'`` are in this collection. - data_needed: HeapSet[TaskState] - - #: Same as :attr:`data_needed`, individually for every peer worker. A - #: :class:`TaskState` with multiple entries in :attr:`~TaskState.who_has` will - #: appear multiple times here. - data_needed_per_worker: defaultdict[str, HeapSet[TaskState]] + #: least another worker, prioritized as per-worker heaps. All and only tasks with + #: ``TaskState.state == 'fetch'`` are in this collection. A :class:`TaskState` with + #: multiple entries in :attr:`~TaskState.who_has` will appear multiple times here. + data_needed: defaultdict[str, HeapSet[TaskState]] #: Number of bytes to fetch from the same worker in a single call to #: :meth:`BaseWorker.gather_dep`. Multiple small tasks that can be fetched from the @@ -1043,6 +1039,10 @@ class WorkerState: #: In production, it should always be set to False. 
transition_counter_max: int | Literal[False] + #: Statically-seeded random state, used to guarantee determinism whenever a + #: pseudo-random choice is required + rng: random.Random + __slots__ = tuple(__annotations__) def __init__( @@ -1077,8 +1077,7 @@ def __init__( self.running = True self.waiting_for_data_count = 0 self.has_what = defaultdict(set) - self.data_needed = HeapSet(key=operator.attrgetter("priority")) - self.data_needed_per_worker = defaultdict( + self.data_needed = defaultdict( partial(HeapSet[TaskState], key=operator.attrgetter("priority")) ) self.in_flight_workers = {} @@ -1100,6 +1099,7 @@ def __init__( self.transition_counter = 0 self.transition_counter_max = transition_counter_max self.actors = {} + self.rng = random.Random(0) def handle_stimulus(self, *stims: StateMachineEvent) -> Instructions: """Process one or more external events, transition relevant tasks to new states, @@ -1196,12 +1196,12 @@ def _update_who_has(self, who_has: Mapping[str, Collection[str]]) -> None: for worker in ts.who_has - workers: self.has_what[worker].remove(key) if ts.state == "fetch": - self.data_needed_per_worker[worker].remove(ts) + self.data_needed[worker].remove(ts) for worker in workers - ts.who_has: self.has_what[worker].add(key) if ts.state == "fetch": - self.data_needed_per_worker[worker].add(ts) + self.data_needed[worker].add(ts) ts.who_has = workers @@ -1218,9 +1218,8 @@ def _purge_state(self, ts: TaskState) -> None: for worker in ts.who_has: self.has_what[worker].discard(ts.key) - self.data_needed_per_worker[worker].discard(ts) + self.data_needed[worker].discard(ts) ts.who_has.clear() - self.data_needed.discard(ts) self.threads.pop(key, None) @@ -1238,66 +1237,49 @@ def _purge_state(self, ts: TaskState) -> None: self.in_flight_tasks.discard(ts) def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: - if not self.running: + """Transition tasks from fetch to flight, until there are no more tasks in fetch + state or a threshold has been reached. + """ + if not self.running or not self.data_needed: + return {}, [] + if ( + len(self.in_flight_workers) >= self.total_out_connections + and self.comm_nbytes >= self.comm_threshold_bytes + ): return {}, [] - - skipped_worker_in_flight_or_busy = [] recommendations: Recs = {} instructions: Instructions = [] - while self.data_needed and ( - len(self.in_flight_workers) < self.total_out_connections - or self.comm_nbytes < self.comm_threshold_bytes - ): + for worker, available_tasks in self._select_workers_for_gather(): + assert worker != self.address + to_gather_tasks, total_nbytes = self._select_keys_for_gather( + available_tasks + ) + assert to_gather_tasks + to_gather_keys = {ts.key for ts in to_gather_tasks} + logger.debug( - "Ensure communicating. Pending: %d. Connections: %d/%d. Busy: %d", + "Gathering %d tasks from %s; %d more remain. 
" + "Pending workers: %d; connections: %d/%d; busy: %d", + len(to_gather_tasks), + worker, + len(available_tasks), len(self.data_needed), len(self.in_flight_workers), self.total_out_connections, len(self.busy_workers), ) - - ts = self.data_needed.pop() - - if self.validate: - assert ts.state == "fetch" - assert self.address not in ts.who_has - - if not ts.who_has: - recommendations[ts] = "missing" - continue - - workers = [ - w - for w in ts.who_has - if w not in self.in_flight_workers and w not in self.busy_workers - ] - if not workers: - skipped_worker_in_flight_or_busy.append(ts) - continue - - for w in ts.who_has: - self.data_needed_per_worker[w].remove(ts) - - host = get_address_host(self.address) - local = [w for w in workers if get_address_host(w) == host] - worker = random.choice(local or workers) - - to_gather_tasks, total_nbytes = self._select_keys_for_gather(worker, ts) - to_gather_keys = {ts.key for ts in to_gather_tasks} - self.log.append( ("gather-dependencies", worker, to_gather_keys, stimulus_id, time()) ) - self.comm_nbytes += total_nbytes - self.in_flight_workers[worker] = to_gather_keys - for d_ts in to_gather_tasks: + for ts in to_gather_tasks: if self.validate: - assert d_ts.state == "fetch" - assert d_ts not in recommendations - recommendations[d_ts] = ("flight", worker) + assert ts.state == "fetch" + assert worker in ts.who_has + assert ts not in recommendations + recommendations[ts] = ("flight", worker) # A single invocation of _ensure_communicating may generate up to one # GatherDep instruction per worker. Multiple tasks from the same worker may @@ -1313,11 +1295,104 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: ) ) - for ts in skipped_worker_in_flight_or_busy: - self.data_needed.add(ts) + self.in_flight_workers[worker] = to_gather_keys + self.comm_nbytes += total_nbytes + if ( + len(self.in_flight_workers) >= self.total_out_connections + and self.comm_nbytes >= self.comm_threshold_bytes + ): + break return recommendations, instructions + def _select_workers_for_gather(self) -> Iterator[tuple[str, HeapSet[TaskState]]]: + """Helper of _ensure_communicating. + + Yield the peer workers and tasks in data_needed, sorted by: + + 1. By highest-priority task available across all workers + 2. If tied, first by local peer workers, then remote. Note that, if a task is + replicated across multiple host, it may go in a tie with itself. + 3. If still tied, by number of tasks available to be fetched from the host + (see note below) + 4. If still tied, by a random element. This is statically seeded to guarantee + reproducibility. + + FIXME https://github.com/dask/distributed/issues/6620 + You won't get determinism when a single task is replicated on multiple + workers, because TaskState.who_has changes order at every interpreter + restart. + + Omit workers that are either busy or in flight. + Remove peer workers with no tasks from data_needed. + + Note + ---- + Instead of number of tasks, we could've measured total nbytes and/or number of + tasks that only exist on the worker. Raw number of tasks is cruder but simpler. 
+ """ + host = get_address_host(self.address) + heap = [] + + for worker, tasks in list(self.data_needed.items()): + if not tasks: + del self.data_needed[worker] + continue + if worker in self.in_flight_workers or worker in self.busy_workers: + continue + heap.append( + ( + tasks.peek().priority, + get_address_host(worker) != host, # False < True + -len(tasks), + self.rng.random(), + worker, + tasks, + ) + ) + + heapq.heapify(heap) + while heap: + _, is_remote, ntasks_neg, rnd, worker, tasks = heapq.heappop(heap) + # The number of tasks and possibly the top priority task may have changed + # since the last sort, since _select_keys_for_gather may have removed tasks + # that are also replicated on a higher-priority worker. + if not tasks: + del self.data_needed[worker] + elif -ntasks_neg != len(tasks): + heapq.heappush( + heap, + (tasks.peek().priority, is_remote, -len(tasks), rnd, worker, tasks), + ) + else: + yield worker, tasks + if not tasks: # _select_keys_for_gather just emptied it + del self.data_needed[worker] + + def _select_keys_for_gather( + self, available: HeapSet[TaskState] + ) -> tuple[list[TaskState], int]: + """Helper of _ensure_communicating. + + Fetch all tasks that are replicated on the target worker within a single + message, up to target_message_size. + """ + to_gather: list[TaskState] = [] + total_nbytes = 0 + + while available: + ts = available.peek() + # The top-priority task is fetched regardless of its size + if to_gather and total_nbytes + ts.get_nbytes() > self.target_message_size: + break + for worker in ts.who_has: + # This also effectively pops from available + self.data_needed[worker].remove(ts) + to_gather.append(ts) + total_nbytes += ts.get_nbytes() + + return to_gather, total_nbytes + def _ensure_computing(self) -> RecsInstrs: if not self.running: return {}, [] @@ -1461,37 +1536,6 @@ def _put_key_in_memory( self.log.append((ts.key, "put-in-memory", stimulus_id, time())) return recommendations - def _select_keys_for_gather( - self, worker: str, ts: TaskState - ) -> tuple[set[TaskState], int]: - """``_ensure_communicating`` decided to fetch a single task from a worker, - following priority. In order to minimise overhead, request fetching other tasks - from the same worker within the message, following priority for the single - worker but ignoring higher priority tasks from other workers, up to - ``target_message_size``. 
- """ - tss = {ts} - total_bytes = ts.get_nbytes() - tasks = self.data_needed_per_worker[worker] - - while tasks: - ts = tasks.peek() - if self.validate: - assert ts.state == "fetch" - assert worker in ts.who_has - if total_bytes + ts.get_nbytes() > self.target_message_size: - break - tasks.pop() - self.data_needed.remove(ts) - for other_worker in ts.who_has: - if other_worker != worker: - self.data_needed_per_worker[other_worker].remove(ts) - - tss.add(ts) - total_bytes += ts.get_nbytes() - - return tss, total_bytes - ############### # Transitions # ############### @@ -1503,9 +1547,8 @@ def _transition_generic_fetch(self, ts: TaskState, stimulus_id: str) -> RecsInst ts.state = "fetch" ts.done = False assert ts.priority - self.data_needed.add(ts) for w in ts.who_has: - self.data_needed_per_worker[w].add(ts) + self.data_needed[w].add(ts) return {}, [] def _transition_missing_waiting( @@ -1613,9 +1656,8 @@ def _transition_fetch_flight( assert ts.state == "fetch" assert ts.who_has # The task has already been removed by _ensure_communicating - assert ts not in self.data_needed for w in ts.who_has: - assert ts not in self.data_needed_per_worker[w] + assert ts not in self.data_needed[w] ts.done = False ts.state = "flight" @@ -1623,13 +1665,6 @@ def _transition_fetch_flight( self.in_flight_tasks.add(ts) return {}, [] - def _transition_fetch_missing( - self, ts: TaskState, *, stimulus_id: str - ) -> RecsInstrs: - # _ensure_communicating could have just popped this task out of data_needed - self.data_needed.discard(ts) - return self._transition_generic_missing(ts, stimulus_id=stimulus_id) - def _transition_memory_released( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: @@ -2151,7 +2186,7 @@ def _transition_released_forgotten( ("executing", "released"): _transition_executing_released, ("executing", "rescheduled"): _transition_executing_rescheduled, ("fetch", "flight"): _transition_fetch_flight, - ("fetch", "missing"): _transition_fetch_missing, + ("fetch", "missing"): _transition_generic_missing, ("fetch", "released"): _transition_generic_released, ("flight", "error"): _transition_flight_error, ("flight", "fetch"): _transition_flight_fetch, @@ -2500,21 +2535,22 @@ def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs: if self.validate: assert ev.who_has.keys() == ev.nbytes.keys() - assert all(ev.who_has.values()) + for dep_workers in ev.who_has.values(): + assert dep_workers + assert len(dep_workers) == len(set(dep_workers)) - for dep_key, dep_workers in ev.who_has.items(): + for dep_key, nbytes in ev.nbytes.items(): dep_ts = self._ensure_task_exists( key=dep_key, priority=priority, stimulus_id=ev.stimulus_id, ) + self.tasks[dep_key].nbytes = nbytes + # link up to child / parents ts.dependencies.add(dep_ts) dep_ts.dependents.add(ts) - for dep_key, value in ev.nbytes.items(): - self.tasks[dep_key].nbytes = value - self._update_who_has(ev.who_has) else: raise RuntimeError( # pragma: nocover @@ -2549,7 +2585,7 @@ def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs: self.log.append((ts.key, "missing-dep", ev.stimulus_id, time())) if self.validate: assert ts.state != "fetch" - assert ts not in self.data_needed_per_worker[ev.worker] + assert ts not in self.data_needed[ev.worker] ts.who_has.discard(ev.worker) self.has_what[ev.worker].discard(ts.key) recommendations[ts] = "fetch" @@ -2600,16 +2636,24 @@ def _handle_gather_dep_network_failure( either retry a different worker, or ask the scheduler to inform us of a new worker if no other worker is available. 
""" - self.data_needed_per_worker.pop(ev.worker) - for key in self.has_what.pop(ev.worker): - ts = self.tasks[key] - ts.who_has.discard(ev.worker) - recommendations: Recs = {} + for ts in self._gather_dep_done_common(ev): self.log.append((ts.key, "missing-dep", ev.stimulus_id, time())) recommendations[ts] = "fetch" + for ts in self.data_needed.pop(ev.worker, ()): + if self.validate: + assert ts.state == "fetch" + assert ev.worker in ts.who_has + if ts.who_has == {ev.worker}: + # This can override a recommendation from the previous for loop + recommendations[ts] = "missing" + + for key in self.has_what.pop(ev.worker): + ts = self.tasks[key] + ts.who_has.remove(ev.worker) + return recommendations, [] @_handle_event.register @@ -2821,10 +2865,9 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: "ready": self.ready, "constrained": self.constrained, "data": dict.fromkeys(self.data), - "data_needed": [ts.key for ts in self.data_needed.sorted()], - "data_needed_per_worker": { + "data_needed": { w: [ts.key for ts in tss.sorted()] - for w, tss in self.data_needed_per_worker.items() + for w, tss in self.data_needed.items() }, "executing": {ts.key for ts in self.executing}, "long_running": self.long_running, @@ -2888,11 +2931,10 @@ def _validate_task_fetch(self, ts: TaskState) -> None: assert ts.key not in self.data assert self.address not in ts.who_has assert not ts.done - assert ts in self.data_needed - # Note: ts.who_has may be empty; see GatherDepNetworkFailureEvent + assert ts.who_has for w in ts.who_has: assert ts.key in self.has_what[w] - assert ts in self.data_needed_per_worker[w] + assert ts in self.data_needed[w] def _validate_task_missing(self, ts: TaskState) -> None: assert ts.key not in self.data @@ -2916,8 +2958,7 @@ def _validate_task_released(self, ts: TaskState) -> None: assert ts.key not in self.data assert not ts._next assert not ts._previous - assert ts not in self.data_needed - for tss in self.data_needed_per_worker.values(): + for tss in self.data_needed.values(): assert ts not in tss assert ts not in self.executing assert ts not in self.in_flight_tasks @@ -3001,19 +3042,15 @@ def validate_state(self) -> None: assert k in self.tasks, self.story(k) assert worker in self.tasks[k].who_has - for ts in self.data_needed: - assert ts.state == "fetch", self.story(ts) - for worker, tss in self.data_needed_per_worker.items(): + for worker, tss in self.data_needed.items(): for ts in tss: assert ts.state == "fetch" - assert ts in self.data_needed assert worker in ts.who_has # Test that there aren't multiple TaskState objects with the same key in any # Set[TaskState]. See note in TaskState.__hash__. for ts in chain( - self.data_needed, - *self.data_needed_per_worker.values(), + *self.data_needed.values(), self.missing_dep_flight, self.in_flight_tasks, self.executing, @@ -3027,8 +3064,7 @@ def validate_state(self) -> None: assert self.transition_counter < self.transition_counter_max # Test that there aren't multiple TaskState objects with the same key in data_needed - assert len({ts.key for ts in self.data_needed}) == len(self.data_needed) - for tss in self.data_needed_per_worker.values(): + for tss in self.data_needed.values(): assert len({ts.key for ts in tss}) == len(tss)