Skip to content

Commit

Permalink
Merge branch 'main' into WSMR/retry_busy_worker
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed May 20, 2022
2 parents 83c07f4 + fb3589c commit 02e4193
Show file tree
Hide file tree
Showing 11 changed files with 135 additions and 49 deletions.
28 changes: 28 additions & 0 deletions distributed/_signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

import asyncio
import logging
import signal
from typing import Any

logger = logging.getLogger(__name__)


async def wait_for_signals(signals: list[signal.Signals]) -> None:
    """Wait for the passed signals by setting global signal handlers

    A process-wide handler is installed for every signal in ``signals`` and
    the calling coroutine is suspended until the first of them is delivered.

    Parameters
    ----------
    signals
        Signals to wait for. If empty, no handler is installed and the
        coroutine waits forever.

    Notes
    -----
    - Only the handler of a signal that was actually received is restored to
      its previous value; handlers installed for the other signals stay in
      place after this coroutine returns.
    - The received signal is logged by the coroutine once it wakes up, not by
      the signal handler itself. Signals that arrive after the coroutine has
      returned are therefore no longer logged.
    """
    loop = asyncio.get_running_loop()
    event = asyncio.Event()

    old_handlers: dict[int, Any] = {}
    # Signal numbers in order of arrival; filled in by the handler
    received: list[int] = []

    def handle_signal(signum, frame):
        # Do not log or print anything in here. A Python signal handler can
        # interrupt the main thread at an arbitrary point, including in the
        # middle of another logging call; logging locks and buffered streams
        # are not reentrant, which can raise RuntimeError or deadlock.
        received.append(signum)
        # Restore old signal handler to allow for quicker exit
        # if the user sends the signal again.
        signal.signal(signum, old_handlers[signum])
        # asyncio.Event is not thread/signal safe; hand over to the loop
        loop.call_soon_threadsafe(event.set)

    for sig in signals:
        old_handlers[sig] = signal.signal(sig, handle_signal)

    await event.wait()

    # Back in regular coroutine context: logging is safe again.
    # Iterate over a snapshot as the handler may still append to the list.
    for signum in tuple(received):
        logger.info("Received signal %s (%d)", signal.Signals(signum).name, signum)
2 changes: 1 addition & 1 deletion distributed/cli/dask_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tornado.ioloop import IOLoop

from distributed import Scheduler
from distributed.cli.utils import wait_for_signals
from distributed._signals import wait_for_signals
from distributed.preloading import validate_preload_argv
from distributed.proctitle import (
enable_proctitle_on_children,
Expand Down
2 changes: 1 addition & 1 deletion distributed/cli/dask_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from dask.system import CPU_COUNT

from distributed import Nanny
from distributed.cli.utils import wait_for_signals
from distributed._signals import wait_for_signals
from distributed.comm import get_address_host_port
from distributed.deploy.utils import nprocesses_nthreads
from distributed.preloading import validate_preload_argv
Expand Down
31 changes: 4 additions & 27 deletions distributed/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,10 @@
from __future__ import annotations

import asyncio
import logging
import signal
from typing import Any
import warnings

from tornado.ioloop import IOLoop

logger = logging.getLogger(__name__)


async def wait_for_signals(signals: list[signal.Signals]) -> None:
    """Wait for the passed signals by setting global signal handlers

    A process-wide handler is installed for every signal in ``signals`` and
    the calling coroutine is suspended until the first of them is delivered.

    Parameters
    ----------
    signals
        Signals to wait for. If empty, no handler is installed and the
        coroutine waits forever.

    Notes
    -----
    - Only the handler of a signal that was actually received is restored to
      its previous value; handlers installed for the other signals stay in
      place after this coroutine returns.
    - The received signal is logged by the coroutine once it wakes up, not by
      the signal handler itself. Signals that arrive after the coroutine has
      returned are therefore no longer logged.
    """
    loop = asyncio.get_running_loop()
    event = asyncio.Event()

    old_handlers: dict[int, Any] = {}
    # Signal numbers in order of arrival; filled in by the handler
    received: list[int] = []

    def handle_signal(signum, frame):
        # Do not log or print anything in here. A Python signal handler can
        # interrupt the main thread at an arbitrary point, including in the
        # middle of another logging call; logging locks and buffered streams
        # are not reentrant, which can raise RuntimeError or deadlock.
        received.append(signum)
        # Restore old signal handler to allow for quicker exit
        # if the user sends the signal again.
        signal.signal(signum, old_handlers[signum])
        # asyncio.Event is not thread/signal safe; hand over to the loop
        loop.call_soon_threadsafe(event.set)

    for sig in signals:
        old_handlers[sig] = signal.signal(sig, handle_signal)

    await event.wait()

    # Back in regular coroutine context: logging is safe again.
    # Iterate over a snapshot as the handler may still append to the list.
    for signum in tuple(received):
        logger.info("Received signal %s (%d)", signal.Signals(signum).name, signum)
# Module-level statement: the DeprecationWarning is issued when this module is
# imported, not when one of its functions is called.
# NOTE(review): stacklevel=2 is presumably meant to attribute the warning to
# the module performing the import rather than to this file - confirm.
warnings.warn(
    "the distributed.cli.utils module is deprecated", DeprecationWarning, stacklevel=2
)


def install_signal_handlers(loop=None, cleanup=None):
Expand Down
3 changes: 3 additions & 0 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def add_client(self, scheduler: Scheduler, client: str) -> None:
def remove_client(self, scheduler: Scheduler, client: str) -> None:
    """Run when a client disconnects

    The default implementation does nothing; subclasses override this hook
    to react to the event.

    Parameters
    ----------
    scheduler
        The scheduler this plugin is attached to.
    client
        The client that disconnected (presumably its identifier; verify
        against the caller in the scheduler).
    """

def log_event(self, name, msg) -> None:
    """Run when an event is logged

    The default implementation does nothing; subclasses override this hook
    to observe events. It is invoked with the same ``name`` and ``msg`` that
    were passed to the scheduler's own ``log_event``.

    Parameters
    ----------
    name
        Name under which the event was logged, e.g. ``"foo"``.
    msg
        Payload of the event, e.g. ``123``. No specific type is required by
        this interface.
    """


class WorkerPlugin:
"""Interface to extend the Worker
Expand Down
22 changes: 21 additions & 1 deletion distributed/diagnostics/tests/test_scheduler_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from distributed import Scheduler, SchedulerPlugin, Worker
from distributed import Scheduler, SchedulerPlugin, Worker, get_worker
from distributed.utils_test import gen_cluster, gen_test, inc


Expand Down Expand Up @@ -178,3 +178,23 @@ def start(self, scheduler):
assert "distributed.scheduler.pickle" in msg

assert n_plugins == len(s.plugins)


@gen_cluster(client=True)
async def test_log_event_plugin(c, s, a, b):
    """An event logged from inside a task reaches ``SchedulerPlugin.log_event``"""

    class EventPlugin(SchedulerPlugin):
        """Record every ``(name, msg)`` pair on the scheduler it is attached to"""

        async def start(self, scheduler: Scheduler) -> None:
            scheduler._recorded_events = []  # type: ignore
            self.scheduler = scheduler

        def log_event(self, name, msg):
            recorded = self.scheduler._recorded_events
            recorded.append((name, msg))

    def log_from_task():
        # Executed on a worker: get_worker() returns the Worker running it
        get_worker().log_event("foo", 123)

    await c.register_scheduler_plugin(EventPlugin())
    await c.submit(log_from_task)

    expected = ("foo", 123)
    assert expected in s._recorded_events
6 changes: 3 additions & 3 deletions distributed/http/scheduler/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def get(self):

class RetireWorkersHandler(RequestHandler):
async def post(self):
self.set_header("Content-Type", "text/json")
self.set_header("Content-Type", "application/json")
scheduler = self.server
try:
params = json.loads(self.request.body)
Expand All @@ -32,7 +32,7 @@ async def post(self):

class GetWorkersHandler(RequestHandler):
def get(self):
self.set_header("Content-Type", "text/json")
self.set_header("Content-Type", "application/json")
scheduler = self.server
try:
response = {
Expand All @@ -50,7 +50,7 @@ def get(self):

class AdaptiveTargetHandler(RequestHandler):
def get(self):
self.set_header("Content-Type", "text/json")
self.set_header("Content-Type", "application/json")
scheduler = self.server
try:
desired_workers = scheduler.adaptive_target()
Expand Down
6 changes: 3 additions & 3 deletions distributed/http/scheduler/tests/test_scheduler_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ async def test_retire_workers(c, s, a, b):
json=params,
) as resp:
assert resp.status == 200
assert resp.headers["Content-Type"] == "text/json"
assert resp.headers["Content-Type"] == "application/json"
retired_workers_info = json.loads(await resp.text())
assert len(retired_workers_info) == 2

Expand All @@ -280,7 +280,7 @@ async def test_get_workers(c, s, a, b):
"http://localhost:%d/api/v1/get_workers" % s.http_server.port
) as resp:
assert resp.status == 200
assert resp.headers["Content-Type"] == "text/json"
assert resp.headers["Content-Type"] == "application/json"
workers_info = json.loads(await resp.text())["workers"]
workers_address = [worker.get("address") for worker in workers_info]
assert set(workers_address) == {a.address, b.address}
Expand All @@ -293,6 +293,6 @@ async def test_adaptive_target(c, s, a, b):
"http://localhost:%d/api/v1/adaptive_target" % s.http_server.port
) as resp:
assert resp.status == 200
assert resp.headers["Content-Type"] == "text/json"
assert resp.headers["Content-Type"] == "application/json"
num_workers = json.loads(await resp.text())["workers"]
assert num_workers == 0
6 changes: 6 additions & 0 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6954,6 +6954,12 @@ def log_event(self, name, msg):
self.event_counts[name] += 1
self._report_event(name, event)

for plugin in list(self.plugins.values()):
try:
plugin.log_event(name, msg)
except Exception:
logger.info("Plugin failed with exception", exc_info=True)

def _report_event(self, name, event):
for client in self.event_subscriber[name]:
self.report(
Expand Down
66 changes: 61 additions & 5 deletions distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import distributed
from distributed import Event, Lock, Worker
from distributed.client import wait
from distributed.utils_test import (
_LockedCommPool,
assert_story,
Expand Down Expand Up @@ -264,11 +265,12 @@ async def test_in_flight_lost_after_resumed(c, s, b):
block_get_data = asyncio.Lock()
in_get_data = asyncio.Event()

await block_get_data.acquire()
lock_executing = Lock()

def block_execution(lock):
with lock:
return
return 1

class BlockedGetData(Worker):
async def get_data(self, comm, *args, **kwargs):
Expand All @@ -281,15 +283,12 @@ async def get_data(self, comm, *args, **kwargs):
block_execution,
lock_executing,
workers=[a.address],
allow_other_workers=True,
key="fut1",
)
# Ensure fut1 is in memory but block any further execution afterwards to
# ensure we control when the recomputation happens
await fut1
await wait(fut1)
await lock_executing.acquire()
in_get_data.clear()
await block_get_data.acquire()
fut2 = c.submit(inc, fut1, workers=[b.address], key="fut2")

# This ensures that B already fetches the task, i.e. after this the task
Expand All @@ -298,6 +297,7 @@ async def get_data(self, comm, *args, **kwargs):
assert fut1.key in b.tasks
assert b.tasks[fut1.key].state == "flight"

s.set_restrictions({fut1.key: [a.address, b.address]})
# It is removed, i.e. get_data is guaranteed to fail and f1 is scheduled
# to be recomputed on B
await s.remove_worker(a.address, "foo", close=False, safe=True)
Expand Down Expand Up @@ -396,3 +396,59 @@ def block_execution(event, lock):
await lock_executing.release()

assert await fut2 == 2


@gen_cluster(client=True, nthreads=[("", 1)] * 2)
async def test_cancelled_resumed_after_flight_with_dependencies(c, s, w2, w3):
    """A task in flight is cancelled and then resumed on the fetching worker.

    Scenario, in the order the test drives it:

    1. ``f1``, ``f2`` and ``f3 = sum([f1, f2])`` are computed on an extra
       worker ``w1`` whose ``get_data`` is blocked by a lock that this test
       acquires up front and never releases.
    2. ``f4 = inc(f3)`` is submitted to ``w2``; the test then waits until
       ``w1.get_data`` has been entered.
    3. Restrictions are rewritten so that ``f1``/``f2`` may only run on ``w3``
       and ``f3`` only on ``w2``; then ``w1`` is removed from the scheduler.
    4. ``f3`` on ``w2`` is expected to go ``flight -> cancelled`` and later
       ``cancelled -> resumed``, and ``f4`` must still produce ``6``.
    """
    # See https://github.com/dask/distributed/pull/6327#discussion_r872231090
    block_get_data_1 = asyncio.Lock()
    enter_get_data_1 = asyncio.Event()
    # Hold the lock from the start so that w1.get_data cannot proceed
    await block_get_data_1.acquire()

    class BlockGetDataWorker(Worker):
        # Worker whose get_data() signals an event on entry and then has to
        # acquire a lock before delegating to the regular implementation.
        def __init__(self, *args, get_data_event, get_data_lock, **kwargs):
            # Set before super().__init__() so the attributes already exist
            # by the time the base class is initialised
            self._get_data_event = get_data_event
            self._get_data_lock = get_data_lock
            super().__init__(*args, **kwargs)

        async def get_data(self, comm, *args, **kwargs):
            # Let the test know that a peer has asked this worker for data...
            self._get_data_event.set()
            # ...but do not answer until the lock can be acquired
            async with self._get_data_lock:
                return await super().get_data(comm, *args, **kwargs)

    async with await BlockGetDataWorker(
        s.address,
        get_data_event=enter_get_data_1,
        get_data_lock=block_get_data_1,
        name="w1",
    ) as w1:

        # Everything f4 depends on is initially pinned to w1
        f1 = c.submit(inc, 1, key="f1", workers=[w1.address])
        f2 = c.submit(inc, 2, key="f2", workers=[w1.address])
        f3 = c.submit(sum, [f1, f2], key="f3", workers=[w1.address])

        await wait(f3)
        # f4 is pinned to w2 while its dependency f3 was pinned to w1
        f4 = c.submit(inc, f3, key="f4", workers=[w2.address])

        # Wait until w1.get_data has been entered; it is now stuck on the lock
        await enter_get_data_1.wait()
        # Re-pin the tasks away from w1 before removing it
        s.set_restrictions(
            {
                f1.key: {w3.address},
                f2.key: {w3.address},
                f3.key: {w2.address},
            }
        )
        await s.remove_worker(w1.address, "stim-id")

        await wait_for_state(f3.key, "resumed", w2)
        # w2's log must contain both transitions of f3, in this order
        assert_story(
            w2.log,
            [
                (f3.key, "flight", "released", "cancelled", {}),
                # ...
                (f3.key, "cancelled", "waiting", "resumed", {}),
            ],
        )
    # w1 closed

    # Presumably inc adds one: inc(1) + inc(2) == 5 and inc(5) == 6
    assert await f4 == 6
12 changes: 4 additions & 8 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,7 @@ def __init__(

# FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117
pc = PeriodicCallback(self.find_missing, 1000) # type: ignore
self._find_missing_running = False
self.periodic_callbacks["find-missing"] = pc

self._address = contact_address
Expand Down Expand Up @@ -1985,13 +1986,6 @@ def handle_compute_task(
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

if self.validate:
# All previously unknown tasks that were created above by
# ensure_tasks_exists() have been transitioned to fetch or flight
assert all(
ts2.state != "released" for ts2 in (ts, *ts.dependencies)
), self.story(ts, *ts.dependencies)

########################
# Worker State Machine #
########################
Expand Down Expand Up @@ -3442,9 +3436,10 @@ async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None

@log_errors
async def find_missing(self) -> None:
if not self._missing_dep_flight:
if self._find_missing_running or not self._missing_dep_flight:
return
try:
self._find_missing_running = True
if self.validate:
for ts in self._missing_dep_flight:
assert not ts.who_has
Expand All @@ -3462,6 +3457,7 @@ async def find_missing(self) -> None:
self.transitions(recommendations, stimulus_id=stimulus_id)

finally:
self._find_missing_running = False
# This is quite arbitrary but the heartbeat has scaling implemented
self.periodic_callbacks[
"find-missing"
Expand Down

0 comments on commit 02e4193

Please sign in to comment.