Send SIGKILL after SIGTERM when passing 95% memory (#6419)
crusaderky authored May 25, 2022
1 parent 8035d36 commit ba39915
Showing 5 changed files with 133 additions and 93 deletions.
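What the change does, in brief: the nanny's memory monitor polls the worker's RSS once per monitor-interval. When usage first crosses the terminate fraction, the nanny sends SIGTERM; if the same process is still alive at a later iteration (for example because user code replaced the SIGTERM handler), it escalates to SIGKILL. A sketch of the configuration involved (illustration only, not part of this diff; the values shown are believed to be the defaults that the tests below assume):

import dask

dask.config.set({
    # Fraction of the memory limit at which the nanny restarts the worker.
    "distributed.worker.memory.terminate": 0.95,
    # How often the nanny's memory monitor polls the worker's RSS.
    "distributed.worker.memory.monitor-interval": "100ms",
})

# Tick N:   RSS > terminate fraction -> the nanny sends SIGTERM
# Tick N+1: same pid still alive     -> the nanny sends SIGKILL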
14 changes: 0 additions & 14 deletions distributed/nanny.py
@@ -16,7 +16,6 @@
from time import sleep as sync_sleep
from typing import TYPE_CHECKING, ClassVar

import psutil
from tornado import gen
from tornado.ioloop import IOLoop

@@ -486,19 +485,6 @@ async def _():
else:
return "OK"

@property
def _psutil_process(self):
pid = self.process.process.pid
try:
self._psutil_process_obj
except AttributeError:
self._psutil_process_obj = psutil.Process(pid)

if self._psutil_process_obj.pid != pid:
self._psutil_process_obj = psutil.Process(pid)

return self._psutil_process_obj

def is_alive(self):
return self.process is not None and self.process.is_alive()

40 changes: 35 additions & 5 deletions distributed/process.py
@@ -1,5 +1,8 @@
from __future__ import annotations

import asyncio
import logging
import multiprocessing
import os
import re
import threading
@@ -50,6 +53,8 @@ class AsyncProcess:
All normally blocking methods are wrapped in Tornado coroutines.
"""

_process: multiprocessing.Process

def __init__(self, loop=None, target=None, name=None, args=(), kwargs={}):
if not callable(target):
raise TypeError(f"`target` needs to be callable, not {type(target)!r}")
@@ -175,7 +180,9 @@ def _run(
target(*args, **kwargs)

@classmethod
def _watch_message_queue(cls, selfref, process, loop, state, q, exit_future):
def _watch_message_queue(
cls, selfref, process: multiprocessing.Process, loop, state, q, exit_future
):
# As multiprocessing.Process is not thread-safe, we run all
# blocking operations from this single loop and ship results
# back to the caller when needed.
@@ -204,7 +211,12 @@ def _start():
if op == "start":
_call_and_set_future(loop, msg["future"], _start)
elif op == "terminate":
# Send SIGTERM
_call_and_set_future(loop, msg["future"], process.terminate)
elif op == "kill":
# Send SIGKILL
_call_and_set_future(loop, msg["future"], process.kill)

elif op == "stop":
break
else:
@@ -240,17 +252,35 @@ def start(self):
self._watch_q.put_nowait({"op": "start", "future": fut})
return fut

def terminate(self):
"""
Terminate the child process.
def terminate(self) -> asyncio.Future[None]:
"""Terminate the child process.
This method returns a future.
See also
--------
multiprocessing.Process.terminate
"""
self._check_closed()
fut = Future()
fut: Future[None] = Future()
self._watch_q.put_nowait({"op": "terminate", "future": fut})
return fut

def kill(self) -> asyncio.Future[None]:
"""Send SIGKILL to the child process.
On Windows, this is the same as terminate().
This method returns a future.
See also
--------
multiprocessing.Process.kill
"""
self._check_closed()
fut: Future[None] = Future()
self._watch_q.put_nowait({"op": "kill", "future": fut})
return fut

async def join(self, timeout=None):
"""
Wait for the child process to exit.
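For callers of AsyncProcess, terminate() and kill() compose naturally into a graceful-stop-with-escalation pattern. A hedged usage sketch (not part of this commit; stop_gracefully and grace are invented names, and it assumes join(timeout=...) returns quietly rather than raising when the timeout elapses):

import asyncio

from distributed.process import AsyncProcess

async def stop_gracefully(proc: AsyncProcess, grace: float = 5.0) -> None:
    await proc.terminate()          # SIGTERM: ask the child to exit
    await proc.join(timeout=grace)  # assumed to return silently on timeout
    if proc.is_alive():             # SIGTERM ignored, or handled slowly
        await proc.kill()           # SIGKILL; alias of terminate() on Windows
        await proc.join()           # on POSIX, SIGKILL cannot be caught or ignored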
22 changes: 21 additions & 1 deletion distributed/tests/test_asyncprocess.py
@@ -193,7 +193,7 @@ async def test_terminate():
await proc.start()
await proc.terminate()

await proc.join(timeout=30)
await proc.join()
assert not proc.is_alive()
assert proc.exitcode in (-signal.SIGTERM, 255)

@@ -312,6 +312,26 @@ async def test_terminate_after_stop():
await proc.start()
await asyncio.sleep(0.1)
await proc.terminate()
await proc.join()


def kill_target(ev):
signal.signal(signal.SIGTERM, signal.SIG_IGN)
ev.set()
sleep(300)


@pytest.mark.skipif(WINDOWS, reason="Needs SIGKILL")
@gen_test()
async def test_kill():
ev = mp_context.Event()
proc = AsyncProcess(target=kill_target, args=(ev,))
await proc.start()
ev.wait()
await proc.kill()
await proc.join()
assert not proc.is_alive()
assert proc.exitcode in (-signal.SIGKILL, 255)


def _worker_process(worker_ready, child_pipe):
105 changes: 43 additions & 62 deletions distributed/tests/test_worker_memory.py
@@ -5,18 +5,18 @@
import logging
import os
import signal
import sys
import threading
from collections import Counter, UserDict
from time import sleep

import psutil
import pytest

import dask.config

import distributed.system
from distributed import Client, Event, KilledWorker, Nanny, Scheduler, Worker, wait
from distributed.compatibility import MACOS
from distributed.compatibility import MACOS, WINDOWS
from distributed.core import Status
from distributed.metrics import monotonic
from distributed.spill import has_zict_210
@@ -684,7 +684,7 @@ async def test_manual_evict_proto(c, s, a):
await asyncio.sleep(0.01)


async def leak_until_restart(c: Client, s: Scheduler, a: Nanny) -> None:
async def leak_until_restart(c: Client, s: Scheduler) -> None:
s.allowed_failures = 0

def leak():
@@ -693,32 +693,25 @@ def leak():
L.append(b"0" * 5_000_000)
sleep(0.01)

assert a.process
assert a.process.process
pid = a.process.pid
addr = a.worker_address
with captured_logger(logging.getLogger("distributed.worker_memory")) as logger:
future = c.submit(leak, key="leak")
while (
not a.process
or not a.process.process
or a.process.pid == pid
or a.worker_address == addr
):
await asyncio.sleep(0.01)
(addr,) = s.workers
pid = (await c.run(os.getpid))[addr]

# Test that the restarting message happened only once;
# see test_slow_terminate below.
assert logger.getvalue() == (
f"Worker {addr} (pid={pid}) exceeded 95% memory budget. Restarting...\n"
)
future = c.submit(leak, key="leak")

# Wait until the worker is restarted
while len(s.workers) != 1 or set(s.workers) == {addr}:
await asyncio.sleep(0.01)

# Test that the process has been properly waited for and not just left there
with pytest.raises(psutil.NoSuchProcess):
psutil.Process(pid)

with pytest.raises(KilledWorker):
await future
assert s.tasks["leak"].suspicious == 1
assert await c.run(lambda dask_worker: "leak" in dask_worker.tasks) == {
a.worker_address: False
}
assert not any(
(await c.run(lambda dask_worker: "leak" in dask_worker.tasks)).values()
)
future.release()
while "leak" in s.tasks:
await asyncio.sleep(0.01)
Expand All @@ -733,61 +726,49 @@ def leak():
config={"distributed.worker.memory.monitor-interval": "10ms"},
)
async def test_nanny_terminate(c, s, a):
await leak_until_restart(c, s, a)
await leak_until_restart(c, s)


@pytest.mark.slow
@pytest.mark.parametrize(
"ignore_sigterm",
[
False,
pytest.param(True, marks=pytest.mark.skipif(WINDOWS, reason="Needs SIGKILL")),
],
)
@gen_cluster(
nthreads=[("", 1)],
client=True,
Worker=Nanny,
worker_kwargs={"memory_limit": "400 MiB"},
config={"distributed.worker.memory.monitor-interval": "10ms"},
)
async def test_disk_cleanup_on_terminate(c, s, a):
"""Test that the spilled data on disk is cleaned up when the nanny kills the worker"""
async def test_disk_cleanup_on_terminate(c, s, a, ignore_sigterm):
"""Test that the spilled data on disk is cleaned up when the nanny kills the worker.
Unlike in a regular worker shutdown, where the worker deletes its own spill
directory, the cleanup in case of termination from the monitor is performed by the
nanny.
The worker may be slow to accept SIGTERM, for whatever reason.
At the next iteration of the memory manager, if the process is still alive, the
nanny sends SIGKILL.
"""
if ignore_sigterm:
await c.run(signal.signal, signal.SIGTERM, signal.SIG_IGN)

fut = c.submit(inc, 1, key="myspill")
await wait(fut)
await c.run(lambda dask_worker: dask_worker.data.evict())

glob_out = await c.run(
lambda dask_worker: glob.glob(dask_worker.local_directory + "/**/myspill")
)
spill_file = glob_out[a.worker_address][0]
assert os.path.exists(spill_file)
await leak_until_restart(c, s, a)
assert not os.path.exists(spill_file)


@pytest.mark.slow
@gen_cluster(
client=True,
Worker=Nanny,
nthreads=[("", 1)],
worker_kwargs={"memory_limit": "400 MiB"},
config={"distributed.worker.memory.monitor-interval": "10ms"},
)
async def test_slow_terminate(c, s, a):
"""A worker is slow to accept SIGTERM, e.g. because the
distributed.diskutils.WorkDir teardown is deleting tens of GB worth of spilled data.
"""
spill_fname = next(iter(glob_out.values()))[0]
assert os.path.exists(spill_fname)

def install_slow_sigterm_handler():
def cb(signo, frame):
# If something sends SIGTERM while the previous SIGTERM handler is running,
# you will eventually get RecursionError.
print(f"Received signal {signo}")
sleep(0.2) # Longer than monitor-interval
print("Leaving handler")
sys.exit(0)

signal.signal(signal.SIGTERM, cb)

await c.run(install_slow_sigterm_handler)
# Test that SIGTERM is only sent once
await leak_until_restart(c, s, a)
# Test that SIGTERM can be sent again after the worker restarts
await leak_until_restart(c, s, a)
await leak_until_restart(c, s)
assert not os.path.exists(spill_fname)


@gen_cluster(
45 changes: 34 additions & 11 deletions distributed/worker_memory.py
@@ -38,6 +38,7 @@
from dask.utils import format_bytes, parse_bytes, parse_timedelta

from distributed import system
from distributed.compatibility import WINDOWS
from distributed.core import Status
from distributed.metrics import monotonic
from distributed.spill import ManualEvictProto, SpillBuffer
@@ -333,32 +334,54 @@ def __init__(

def memory_monitor(self, nanny: Nanny) -> None:
"""Track worker's memory. Restart if it goes above terminate fraction."""
if nanny.status != Status.running:
return # pragma: nocover
if nanny.process is None or nanny.process.process is None:
if (
nanny.status != Status.running
or nanny.process is None
or nanny.process.process is None
or nanny.process.process.pid is None
):
return # pragma: nocover

process = nanny.process.process
try:
proc = nanny._psutil_process
memory = proc.memory_info().rss
memory = psutil.Process(process.pid).memory_info().rss
except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied):
return # pragma: nocover

if process.pid in (self._last_terminated_pid, None):
# We already sent SIGTERM to the worker, but its handler is still running
# since the previous iteration of the memory_monitor - for example, it
# may be taking a long time deleting all the spilled data from disk.
if memory / self.memory_limit <= self.memory_terminate_fraction:
return
self._last_terminated_pid = -1

if memory / self.memory_limit > self.memory_terminate_fraction:
if self._last_terminated_pid != process.pid:
logger.warning(
f"Worker {nanny.worker_address} (pid={process.pid}) exceeded "
f"{self.memory_terminate_fraction * 100:.0f}% memory budget. "
"Restarting...",
)
self._last_terminated_pid = process.pid
process.terminate()
else:
# We already sent SIGTERM to the worker, but the process is still alive
# since the previous iteration of the memory_monitor - for example, some
# user code may have tampered with signal handlers.
# Send SIGKILL for immediate termination.
#
# Note that this should not be a disk-related issue. Unlike in a regular
# worker shutdown, where the worker cleans up its own spill directory, in
# case of SIGTERM no atexit or weakref.finalize callback is triggered
# whatsoever; instead, the nanny cleans up the spill directory *after* the
# worker has been shut down and before starting a new one.
# This is important, as spill directory cleanup may potentially take tens of
# seconds and, if the worker did it, any task that was running and leaking
# would continue to do so for the whole duration of the cleanup, increasing
# the risk of going beyond 100%.
logger.warning(
f"Worker {nanny.worker_address} (pid={process.pid}) is slow to %s",
# On Windows, kill() is an alias to terminate()
"terminate; trying again"
if WINDOWS
else "accept SIGTERM; sending SIGKILL",
)
process.kill()


def parse_memory_limit(
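Condensing the new control flow of memory_monitor() above into a self-contained sketch (class and method names are invented for illustration; the real implementation also guards against a missing or non-running nanny process and logs a warning at each step):

import psutil

class TerminateEscalator:
    """Call poll() once per monitor interval with a multiprocessing.Process."""

    def __init__(self, memory_limit: int, terminate_fraction: float = 0.95):
        self.memory_limit = memory_limit
        self.terminate_fraction = terminate_fraction
        self._last_terminated_pid = -1  # pid we have already sent SIGTERM to

    def poll(self, process) -> None:
        try:
            rss = psutil.Process(process.pid).memory_info().rss
        except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied):
            return  # process already gone, or not inspectable
        if rss / self.memory_limit <= self.terminate_fraction:
            return  # within budget
        if self._last_terminated_pid != process.pid:
            self._last_terminated_pid = process.pid
            process.terminate()  # first offence: SIGTERM
        else:
            process.kill()  # still alive one tick later: SIGKILL

A restarted worker gets a fresh pid, so the pid comparison naturally re-arms the SIGTERM-first path after each restart.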
