Deduplicate data_needed (dask#6587)
crusaderky committed Jun 23, 2022
1 parent e3b70da commit 5ede365
Showing 5 changed files with 269 additions and 137 deletions.
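
What the diffs below show: before this commit, `WorkerState` kept a single priority-ordered collection of all tasks to fetch (`data_needed`) plus a parallel per-worker mapping (`data_needed_per_worker`); afterwards there is one structure, keyed by peer address. A minimal sketch of the new shape, assuming only what the tests below exercise (plain `heapq` stands in for the real collection type, and the class names are illustrative):

```python
import heapq
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass(order=True)
class TaskStub:
    priority: tuple          # dask-style priority tuple: lower sorts first
    key: str = field(compare=False)


class WorkerStateSketch:
    def __init__(self) -> None:
        # One priority queue of fetchable tasks per peer worker address
        self.data_needed: defaultdict[str, list[TaskStub]] = defaultdict(list)

    def add_fetch(self, ts: TaskStub, who_has: list[str]) -> None:
        # A task is queued under every peer that holds a replica; each
        # per-peer queue stays independently ordered by task priority.
        for addr in who_has:
            heapq.heappush(self.data_needed[addr], ts)


ws = WorkerStateSketch()
ws.add_fetch(TaskStub((1,), "x"), ["127.0.0.1:1235", "127.0.0.2:1235"])
ws.add_fetch(TaskStub((0,), "y"), ["127.0.0.1:1235"])
assert ws.data_needed["127.0.0.1:1235"][0].key == "y"  # highest priority first
```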
distributed/tests/test_worker.py (3 changes: 1 addition & 2 deletions)
@@ -2992,7 +2992,7 @@ def __sizeof__(self):
     )
     await b.in_gather_dep.wait()
     assert len(b.state.in_flight_tasks) == 5
-    assert len(b.state.data_needed) == 5
+    assert len(b.state.data_needed[a.address]) == 5
     b.block_gather_dep.set()
     while len(b.data) < 10:
         await asyncio.sleep(0.01)
@@ -3348,7 +3348,6 @@ async def test_Worker__to_dict(c, s, a):
         "transition_counter",
         "tasks",
         "data_needed",
-        "data_needed_per_worker",
     }
     assert d["tasks"]["x"]["key"] == "x"
     assert d["data"] == {"x": None}
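The assertion change above is the whole migration pattern: what used to be one global queue is now indexed by the holding worker's address, and a global view requires aggregating over peers. A toy illustration, assuming nothing beyond the mapping shape (addresses and task keys are made up):

```python
# Hypothetical snapshot of the new per-peer layout:
data_needed = {
    "127.0.0.1:1234": ["x1", "x2", "x3", "x4", "x5"],
    "127.0.0.2:1234": ["y1"],
}

# Per-peer count, as in the updated assertion above:
assert len(data_needed["127.0.0.1:1234"]) == 5
# Old-style global count now requires summing over peers:
assert sum(len(q) for q in data_needed.values()) == 6
```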
distributed/tests/test_worker_memory.py (4 changes: 2 additions & 2 deletions)
@@ -527,7 +527,7 @@ def pause_on_unpickle():
     # - w and z respectively make x and y go into fetch state.
     #   w has a higher priority than z, therefore w's dependency x has a higher priority
     #   than z's dependency y.
-    #   a.state.data_needed = ["x", "y"]
+    #   a.state.data_needed[b.address] = ["x", "y"]
     # - ensure_communicating decides to fetch x but not to fetch y together with it, as
     #   it thinks x is 1TB in size
     # - x fetch->flight; a is added to in_flight_workers
@@ -543,7 +543,7 @@ def pause_on_unpickle():
         await asyncio.sleep(0.1)
     assert a.state.tasks["y"].state == "fetch"
     assert "y" not in a.data
-    assert [ts.key for ts in a.state.data_needed] == ["y"]
+    assert [ts.key for ts in a.state.data_needed[b.address]] == ["y"]
 
     # Unpausing kicks off ensure_communicating again
     a.status = Status.running
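The scenario above relies on size-capped batching: the worker keeps pulling tasks from one peer's queue until the next task would exceed the message-size budget, which is why a task believed to be 1TB is fetched on its own. A sketch of that idea, where `pick_batch` is a hypothetical helper rather than the actual worker code:

```python
def pick_batch(queue: list[tuple[str, int]], target_message_size: int) -> list[str]:
    """queue holds (key, nbytes) pairs, already priority-ordered.
    Always take at least one task, then stop before exceeding the budget."""
    batch: list[str] = []
    total = 0
    for key, nbytes in queue:
        if batch and total + nbytes > target_message_size:
            break
        batch.append(key)
        total += nbytes
    return batch


# x is (mistakenly) believed to be 1 TB, so y is not fetched together with it:
assert pick_batch([("x", 2**40), ("y", 1)], 50 * 2**20) == ["x"]
```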
distributed/tests/test_worker_state_machine.py (94 changes: 92 additions & 2 deletions)
@@ -30,6 +30,7 @@
     ExecuteSuccessEvent,
     GatherDep,
     Instruction,
+    PauseEvent,
     RecommendationsConflict,
     RefreshWhoHasEvent,
     ReleaseWorkerDataMsg,
@@ -38,6 +39,7 @@
     SerializedTask,
     StateMachineEvent,
     TaskState,
+    UnpauseEvent,
     UpdateDataEvent,
     merge_recs_instructions,
 )
@@ -105,8 +107,7 @@ def test_WorkerState__to_dict(ws):
         "busy_workers": [],
         "constrained": [],
         "data": {"y": None},
-        "data_needed": [],
-        "data_needed_per_worker": {"127.0.0.1:1235": []},
+        "data_needed": {},
         "executing": [],
         "in_flight_tasks": ["x"],
         "in_flight_workers": {"127.0.0.1:1235": ["x"]},
@@ -822,6 +823,95 @@ async def test_deprecated_worker_attributes(s, a, b):
     with pytest.warns(FutureWarning, match=msg):
         assert a.in_flight_tasks == 0
 
+    with pytest.warns(FutureWarning, match="attribute has been removed"):
+        assert a.data_needed == set()
+
+
+def test_gather_priority(ws):
+    """Test that tasks are fetched in the following order:
+    1. by task priority
+    2. in case of tie, from local workers first
+    3. in case of tie, from the worker with the most tasks queued
+    4. in case of tie, from a random worker (which is actually deterministic).
+    """
+    ws.total_out_connections = 4
+
+    instructions = ws.handle_stimulus(
+        PauseEvent(stimulus_id="pause"),
+        # Note: tasks fetched by acquire-replicas always have priority=(1,)
+        AcquireReplicasEvent(
+            who_has={
+                # Remote + local
+                "x1": ["127.0.0.2:1", "127.0.0.1:2"],
+                # Remote. After getting x1 from .1, .2 will have fewer tasks than .3
+                "x2": ["127.0.0.2:1"],
+                "x3": ["127.0.0.3:1"],
+                "x4": ["127.0.0.3:1"],
+                # It will be a random choice between .2, .4, .5, .6, and .7
+                "x5": ["127.0.0.4:1"],
+                "x6": ["127.0.0.5:1"],
+                "x7": ["127.0.0.6:1"],
+                # This will be fetched first because it's on the same worker as y
+                "x8": ["127.0.0.7:1"],
+            },
+            # Substantial nbytes prevents total_out_connections from being overridden
+            # by comm_threshold_bytes, but it's less than target_message_size
+            nbytes={f"x{i}": 4 * 2**20 for i in range(1, 9)},
+            stimulus_id="compute1",
+        ),
+        # A higher-priority task, even if scheduled later, is fetched first
+        ComputeTaskEvent(
+            key="z",
+            who_has={"y": ["127.0.0.7:1"]},
+            nbytes={"y": 1},
+            priority=(0,),
+            duration=1.0,
+            run_spec=None,
+            resource_restrictions={},
+            actor=False,
+            annotations={},
+            stimulus_id="compute2",
+        ),
+        UnpauseEvent(stimulus_id="unpause"),
+    )
+
+    assert instructions == [
+        # Highest-priority task first. Lower-priority tasks from the same worker are
+        # shoved into the same instruction (up to 50MB worth)
+        GatherDep(
+            stimulus_id="unpause",
+            worker="127.0.0.7:1",
+            to_gather={"y", "x8"},
+            total_nbytes=1 + 4 * 2**20,
+        ),
+        # Followed by local workers
+        GatherDep(
+            stimulus_id="unpause",
+            worker="127.0.0.1:2",
+            to_gather={"x1"},
+            total_nbytes=4 * 2**20,
+        ),
+        # Followed by remote workers with the most tasks
+        GatherDep(
+            stimulus_id="unpause",
+            worker="127.0.0.3:1",
+            to_gather={"x3", "x4"},
+            total_nbytes=8 * 2**20,
+        ),
+        # Followed by other remote workers, randomly.
+        # Determinism is guaranteed by a statically-seeded random number generator.
+        # FIXME It would not have been deterministic if, instead of multiple keys,
+        # we had used a single key with multiple workers, because sets
+        # (like TaskState.who_has) change order at every interpreter restart.
+        GatherDep(
+            stimulus_id="unpause",
+            worker="127.0.0.4:1",
+            to_gather={"x5"},
+            total_nbytes=4 * 2**20,
+        ),
+    ]


 @pytest.mark.parametrize(
     "nbytes,n_in_flight",
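The four-way ordering spelled out in the docstring of `test_gather_priority` above condenses into a single sort key. A sketch under stated assumptions: `Candidate` and `gather_order` are illustrative stand-ins, not the code path the worker actually runs, and locality is approximated by a host-prefix check:

```python
import random
from typing import NamedTuple


class Candidate(NamedTuple):
    addr: str
    top_priority: tuple  # dask-style priority tuple: lower sorts first
    n_queued: int        # tasks currently queued for fetch from this peer


def gather_order(candidates: list[Candidate], local_host: str, seed: int = 0) -> list[Candidate]:
    rng = random.Random(seed)  # statically seeded, hence deterministic
    return sorted(
        candidates,
        key=lambda c: (
            c.top_priority,                     # 1. task priority
            not c.addr.startswith(local_host),  # 2. local workers first
            -c.n_queued,                        # 3. most queued tasks first
            rng.random(),                       # 4. "random", but reproducible
        ),
    )


peers = [
    Candidate("127.0.0.2:1", (1,), 1),
    Candidate("127.0.0.1:2", (1,), 1),  # local: wins the tie at equal priority
    Candidate("127.0.0.7:1", (0,), 2),  # holds the priority-(0,) task: goes first
]
assert [c.addr for c in gather_order(peers, "127.0.0.1")] == [
    "127.0.0.7:1",
    "127.0.0.1:2",
    "127.0.0.2:1",
]
```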
distributed/worker.py (12 changes: 10 additions & 2 deletions)
@@ -859,8 +859,7 @@ def data(self) -> MutableMapping[str, Any]:
     comm_nbytes = DeprecatedWorkerStateAttribute()
     comm_threshold_bytes = DeprecatedWorkerStateAttribute()
     constrained = DeprecatedWorkerStateAttribute()
-    data_needed = DeprecatedWorkerStateAttribute()
-    data_needed_per_worker = DeprecatedWorkerStateAttribute()
+    data_needed_per_worker = DeprecatedWorkerStateAttribute(target="data_needed")
     executed_count = DeprecatedWorkerStateAttribute()
     executing_count = DeprecatedWorkerStateAttribute()
     generation = DeprecatedWorkerStateAttribute()
@@ -883,6 +882,15 @@ def data(self) -> MutableMapping[str, Any]:
     validate_task = DeprecatedWorkerStateAttribute()
     waiting_for_data_count = DeprecatedWorkerStateAttribute()
 
+    @property
+    def data_needed(self) -> set[TaskState]:
+        warnings.warn(
+            "The `Worker.data_needed` attribute has been removed; "
+            "use `Worker.state.data_needed[address]`",
+            FutureWarning,
+        )
+        return {ts for tss in self.state.data_needed.values() for ts in tss}
+
     ##################
     # Administrative #
     ##################
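
Net effect of the two hunks above: `Worker.data_needed_per_worker` is transparently redirected to the new `WorkerState.data_needed`, while reading `Worker.data_needed` itself emits a `FutureWarning` and returns a flattened set. A self-contained toy that reproduces the property's behavior (the sketch classes are stand-ins, not distributed's API):

```python
import warnings


class StateSketch:
    def __init__(self) -> None:
        # Per-peer mapping, as in the new WorkerState.data_needed
        self.data_needed = {"10.0.0.1:1234": {"x"}, "10.0.0.2:1234": {"y"}}


class WorkerSketch:
    def __init__(self) -> None:
        self.state = StateSketch()

    @property
    def data_needed(self) -> set:
        # Mirrors the shim above: warn, then flatten the per-peer queues
        warnings.warn(
            "The `Worker.data_needed` attribute has been removed; "
            "use `Worker.state.data_needed[address]`",
            FutureWarning,
        )
        return {ts for tss in self.state.data_needed.values() for ts in tss}


w = WorkerSketch()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert w.data_needed == {"x", "y"}
assert caught and issubclass(caught[0].category, FutureWarning)
```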
(Diff for the fifth changed file not loaded on the source page.)
