diff --git a/distributed/scheduler.py b/distributed/scheduler.py index c13f529253..911bf4fba3 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2168,17 +2168,17 @@ def transition_memory_released(self, key, safe: bint = False): dts._waiting_on.add(ts) # XXX factor this out? + ts_nbytes: Py_ssize_t = ts.get_nbytes() + worker_msg = { + "op": "delete-data", + "keys": [key], + "report": False, + } for ws in ts._who_has: ws._has_what.remove(ts) - ws._nbytes -= ts.get_nbytes() - ts._group._nbytes_in_memory -= ts.get_nbytes() - worker_msgs[ws._address] = [ - { - "op": "delete-data", - "keys": [key], - "report": False, - } - ] + ws._nbytes -= ts_nbytes + ts._group._nbytes_in_memory -= ts_nbytes + worker_msgs[ws._address] = [worker_msg] ts._who_has.clear() @@ -4053,9 +4053,10 @@ def stimulus_missing_data( if cts is not None and cts._state == "memory": # couldn't find this ws: WorkerState + cts_nbytes: Py_ssize_t = cts.get_nbytes() for ws in cts._who_has: # TODO: this behavior is extreme ws._has_what.remove(cts) - ws._nbytes -= cts.get_nbytes() + ws._nbytes -= cts_nbytes cts._who_has.clear() recommendations[cause] = "released" @@ -4865,12 +4866,13 @@ async def gather(self, comm=None, keys=None, serializers=None): ) if not workers or ts is None: continue + ts_nbytes: Py_ssize_t = ts.get_nbytes() for worker in workers: ws = parent._workers_dv.get(worker) if ws is not None and ts in ws._has_what: ws._has_what.remove(ts) ts._who_has.remove(ws) - ws._nbytes -= ts.get_nbytes() + ws._nbytes -= ts_nbytes self.transitions({key: "released"}) self.log_event("all", {"action": "gather", "count": len(keys)}) @@ -5566,7 +5568,12 @@ def add_keys(self, comm=None, worker=None, keys=()): return "OK" def update_data( - self, comm=None, who_has=None, nbytes=None, client=None, serializers=None + self, + comm=None, + who_has=None, + nbytes: dict = None, + client=None, + serializers=None, ): """ Learn that new data has entered the network from an external source @@ -5587,12 +5594,15 @@ def update_data( if ts is None: ts: TaskState = self.new_task(key, None, "memory") ts.state = "memory" - if key in nbytes: - ts.set_nbytes(nbytes[key]) + ts_nbytes: Py_ssize_t = nbytes.get(key, -1) + if ts_nbytes >= 0: + ts.set_nbytes(ts_nbytes) + else: + ts_nbytes = ts.get_nbytes() for w in workers: ws: WorkerState = parent._workers_dv[w] if ts not in ws._has_what: - ws._nbytes += ts.get_nbytes() + ws._nbytes += ts_nbytes ws._has_what.add(ts) ts._who_has.add(ws) self.report( @@ -6712,13 +6722,14 @@ def _propagate_forgotten( ts._dependencies.clear() ts._waiting_on.clear() + ts_nbytes: Py_ssize_t = ts.get_nbytes() if ts._who_has: - ts._group._nbytes_in_memory -= ts.get_nbytes() + ts._group._nbytes_in_memory -= ts_nbytes ws: WorkerState for ws in ts._who_has: ws._has_what.remove(ts) - ws._nbytes -= ts.get_nbytes() + ws._nbytes -= ts_nbytes w: str = ws._address if w in state._workers_dv: # in case worker has died worker_msgs[w] = [{"op": "delete-data", "keys": [key], "report": False}]