Skip to content

Commit

Permalink
transition_memory_released and get_nbytes() optimizations (#4516)
Browse files Browse the repository at this point in the history
* Annotate `nbytes` in `update_data` as a `dict`

In the one case where this variable is used, it is a `dict`, so go ahead and
annotate it that way. This should speed up usage of this variable when
cythonized.

* Assign `ts.get_nbytes()` to a variable

To avoid calling this method repeatedly in a few cases, assign the
result to a variable and reuse it.

* Create the `delete-data` worker message once in `transition_memory_released`
  • Loading branch information
jakirkham authored Feb 23, 2021
1 parent fdeca21 commit 7146449
Showing 1 changed file with 28 additions and 17 deletions.
45 changes: 28 additions & 17 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2168,17 +2168,17 @@ def transition_memory_released(self, key, safe: bint = False):
dts._waiting_on.add(ts)

# XXX factor this out?
ts_nbytes: Py_ssize_t = ts.get_nbytes()
worker_msg = {
"op": "delete-data",
"keys": [key],
"report": False,
}
for ws in ts._who_has:
ws._has_what.remove(ts)
ws._nbytes -= ts.get_nbytes()
ts._group._nbytes_in_memory -= ts.get_nbytes()
worker_msgs[ws._address] = [
{
"op": "delete-data",
"keys": [key],
"report": False,
}
]
ws._nbytes -= ts_nbytes
ts._group._nbytes_in_memory -= ts_nbytes
worker_msgs[ws._address] = [worker_msg]

ts._who_has.clear()

Expand Down Expand Up @@ -4053,9 +4053,10 @@ def stimulus_missing_data(

if cts is not None and cts._state == "memory": # couldn't find this
ws: WorkerState
cts_nbytes: Py_ssize_t = cts.get_nbytes()
for ws in cts._who_has: # TODO: this behavior is extreme
ws._has_what.remove(cts)
ws._nbytes -= cts.get_nbytes()
ws._nbytes -= cts_nbytes
cts._who_has.clear()
recommendations[cause] = "released"

Expand Down Expand Up @@ -4865,12 +4866,13 @@ async def gather(self, comm=None, keys=None, serializers=None):
)
if not workers or ts is None:
continue
ts_nbytes: Py_ssize_t = ts.get_nbytes()
for worker in workers:
ws = parent._workers_dv.get(worker)
if ws is not None and ts in ws._has_what:
ws._has_what.remove(ts)
ts._who_has.remove(ws)
ws._nbytes -= ts.get_nbytes()
ws._nbytes -= ts_nbytes
self.transitions({key: "released"})

self.log_event("all", {"action": "gather", "count": len(keys)})
Expand Down Expand Up @@ -5566,7 +5568,12 @@ def add_keys(self, comm=None, worker=None, keys=()):
return "OK"

def update_data(
self, comm=None, who_has=None, nbytes=None, client=None, serializers=None
self,
comm=None,
who_has=None,
nbytes: dict = None,
client=None,
serializers=None,
):
"""
Learn that new data has entered the network from an external source
Expand All @@ -5587,12 +5594,15 @@ def update_data(
if ts is None:
ts: TaskState = self.new_task(key, None, "memory")
ts.state = "memory"
if key in nbytes:
ts.set_nbytes(nbytes[key])
ts_nbytes: Py_ssize_t = nbytes.get(key, -1)
if ts_nbytes >= 0:
ts.set_nbytes(ts_nbytes)
else:
ts_nbytes = ts.get_nbytes()
for w in workers:
ws: WorkerState = parent._workers_dv[w]
if ts not in ws._has_what:
ws._nbytes += ts.get_nbytes()
ws._nbytes += ts_nbytes
ws._has_what.add(ts)
ts._who_has.add(ws)
self.report(
Expand Down Expand Up @@ -6712,13 +6722,14 @@ def _propagate_forgotten(
ts._dependencies.clear()
ts._waiting_on.clear()

ts_nbytes: Py_ssize_t = ts.get_nbytes()
if ts._who_has:
ts._group._nbytes_in_memory -= ts.get_nbytes()
ts._group._nbytes_in_memory -= ts_nbytes

ws: WorkerState
for ws in ts._who_has:
ws._has_what.remove(ts)
ws._nbytes -= ts.get_nbytes()
ws._nbytes -= ts_nbytes
w: str = ws._address
if w in state._workers_dv: # in case worker has died
worker_msgs[w] = [{"op": "delete-data", "keys": [key], "report": False}]
Expand Down

0 comments on commit 7146449

Please sign in to comment.