Skip to content

Commit

Permalink
Tests for work stealing (#308)
Browse files Browse the repository at this point in the history
Co-authored-by: Florian Jetter <fjetter@users.noreply.github.com>
  • Loading branch information
hendrikmakait and fjetter committed Sep 14, 2022
1 parent 8c8fa38 commit 7cf89ff
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 6 deletions.
27 changes: 21 additions & 6 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,25 @@ def _measure_durations(client):
yield _measure_durations


@pytest.fixture(scope="function")
def benchmark_all(benchmark_memory, benchmark_task_durations, benchmark_time):
    """Combine the memory, task-duration and time benchmarks into one.

    The fixture value is a callable that takes a client and returns a
    context manager; everything executed inside that context is covered
    by all three benchmark fixtures at once.

    Example:
    >>> def test_example(benchmark_all):
    ...     ...
    ...     with benchmark_all(client):
    ...         client.compute(my_computation)
    """

    @contextlib.contextmanager
    def _all_benchmarks(client):
        # Entered in the order memory -> task durations -> time and
        # exited in reverse. ``benchmark_time`` is used directly as a
        # context manager; the other two are called with the client.
        with benchmark_memory(client):
            with benchmark_task_durations(client):
                with benchmark_time:
                    yield

    yield _all_benchmarks


# ############################################### #
# END BENCHMARKING RELATED #
# ############################################### #
Expand Down Expand Up @@ -299,19 +318,15 @@ def small_cluster(request):
def small_client(
small_cluster,
upload_cluster_dump,
benchmark_task_durations,
benchmark_memory,
benchmark_time,
benchmark_all,
):
with Client(small_cluster) as client:
small_cluster.scale(10)
client.wait_for_workers(10)
client.restart()

with upload_cluster_dump(client, small_cluster):
with benchmark_memory(client), benchmark_task_durations(
client
), benchmark_time:
with benchmark_all(client):
yield client


Expand Down
101 changes: 101 additions & 0 deletions tests/benchmarks/test_work_stealing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import time

import dask.array as da
import distributed
import numpy as np
import pytest
from coiled.v2 import Cluster
from dask import delayed, utils
from distributed import Client
from tornado.ioloop import PeriodicCallback


def test_trivial_workload_should_not_cause_work_stealing(small_client):
    """Fan 10000 no-op tasks out of a single shared root task.

    The root task produces a string of ``parse_bytes("1MiB")`` characters;
    every dependent task receives it together with its index and returns
    ``None``. The test only computes and gathers the results.
    """
    payload_size = utils.parse_bytes("1MiB")
    root = delayed(lambda n: "x" * n)(payload_size, dask_key_name="root")

    results = []
    for i in range(10000):
        results.append(delayed(lambda *args: None)(root, i))

    small_client.gather(small_client.compute(results))


@pytest.mark.xfail(
    distributed.__version__ == "2022.6.0",
    reason="https://github.com/dask/distributed/issues/6624",
)
def test_work_stealing_on_scaling_up(test_name_uuid, benchmark_all):
    """Start a slow computation on one worker, then scale up to 20 workers.

    The computation is submitted to a single-worker cluster, left running
    for a while, and only then is the cluster scaled. The test finishes
    once the result of the computation is available.
    """

    def func1(chunk):
        # Slow task: sleep 5 s per chunk. Chunks whose shape sums to zero
        # are returned immediately to make initialization fast
        # (presumably dask probes the function with an empty array
        # -- TODO confirm).
        if sum(chunk.shape) == 0:
            return chunk
        time.sleep(5)
        return chunk

    def func2(chunk):
        return chunk

    cluster_kwargs = dict(
        name=test_name_uuid,
        n_workers=1,
        wait_for_workers=True,
        worker_vm_types=["t3.medium"],
    )

    with Cluster(**cluster_kwargs) as cluster, Client(cluster) as client:
        with benchmark_all(client):
            data = da.zeros((30, 30, 30), chunks=5)
            slow = data.map_overlap(func1, depth=1, dtype=data.dtype)
            result = slow.map_overlap(func2, depth=1, dtype=data.dtype)
            future = client.compute(result)

            print("started computation")

            # Two separate waits (11 s, then 5 s): an intermediate
            # scale-up to 4 workers between them is currently disabled.
            time.sleep(11)
            time.sleep(5)

            print("scaling to 20 workers")
            cluster.scale(20)

            # Block until the computation has finished; the value itself
            # is not inspected.
            future.result()


def test_work_stealing_on_inhomogeneous_workload(small_client):
    """Run 500 sleep tasks whose durations are drawn from a log-normal.

    Each task sleeps for its drawn duration, capped at 60 seconds, and
    returns that duration. The test only computes and gathers the results.
    """
    # Fixed seed so the same delays are drawn on every run.
    np.random.seed(42)
    delays = np.random.lognormal(1, 1.3, 500)

    def clog(n):
        # Cap the sleep so a single large draw cannot exceed 60 s.
        time.sleep(min(n, 60))
        return n

    lazy_clog = delayed(clog)
    results = [lazy_clog(duration) for duration in delays]

    small_client.gather(small_client.compute(results))


def test_work_stealing_on_straggling_worker(test_name_uuid, benchmark_all):
    """Run 1000 one-second tasks on 10 workers, one of which is clogged.

    A periodic callback that sleeps for one second is installed on the
    first worker reported by the scheduler, after which the tasks are
    computed and gathered.
    """

    def clog():
        # Sleeps synchronously inside the periodic callback (presumably
        # stalling the worker's event loop -- TODO confirm).
        time.sleep(1)

    @delayed
    def slowinc(i, delay):
        time.sleep(delay)
        return i + 1

    def install_clogging_callback(dask_worker):
        # Schedule ``clog`` with a callback time of 1500 and register it
        # under the worker's periodic callbacks before starting it.
        callback = PeriodicCallback(clog, 1500)
        dask_worker.periodic_callbacks["clog"] = callback
        callback.start()

    cluster_kwargs = dict(
        name=test_name_uuid,
        n_workers=10,
        worker_vm_types=["t3.medium"],
        wait_for_workers=True,
    )

    with Cluster(**cluster_kwargs) as cluster, Client(cluster) as client:
        with benchmark_all(client):
            # The first worker listed by the scheduler becomes the straggler.
            workers = client.scheduler_info()["workers"]
            straggler = list(workers)[0]
            client.run(install_clogging_callback, workers=[straggler])

            results = [slowinc(i, delay=1) for i in range(1000)]
            client.gather(client.compute(results))

0 comments on commit 7cf89ff

Please sign in to comment.