Skip to content

Commit

Permalink
AIP-72: Handle External update TI state in Supervisor
Browse files Browse the repository at this point in the history
- Updated logic to handle externally updated TI state in the Supervisor. These states could have been changed externally via the UI, CLI, API, etc.
- Replaced FASTEST_HEARTBEAT_INTERVAL and SLOWEST_HEARTBEAT_INTERVAL with MIN_HEARTBEAT_INTERVAL and HEARTBEAT_THRESHOLD for clarity and alignment with terminology used in the codebase.
  • Loading branch information
kaxil committed Nov 27, 2024
1 parent 43adccf commit 0990abf
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 11 deletions.
46 changes: 36 additions & 10 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import structlog
from pydantic import TypeAdapter

from airflow.sdk.api.client import Client
from airflow.sdk.api.client import Client, ServerResponseError
from airflow.sdk.api.datamodels._generated import IntermediateTIState, TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
DeferTask,
Expand All @@ -63,9 +63,11 @@
log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor")

# TODO: Pull this from config
SLOWEST_HEARTBEAT_INTERVAL: int = 30
# (previously `[scheduler] local_task_job_heartbeat_sec` with the following as fallback if it is 0:
# `[scheduler] scheduler_zombie_task_threshold`)
HEARTBEAT_THRESHOLD: int = 30
# Don't heartbeat more often than this
FASTEST_HEARTBEAT_INTERVAL: int = 5
MIN_HEARTBEAT_INTERVAL: int = 5


@overload
Expand Down Expand Up @@ -416,10 +418,6 @@ def _monitor_subprocess(self):
- Sends heartbeats to the client to keep the task alive
- Checks if the subprocess has exited
"""
# Until we have a selector for the process, don't poll for more than 10s, just in case it exists but
# doesn't produce any output
max_poll_interval = 10

while self._exit_code is None or len(self.selector.get_map()):
last_heartbeat_ago = time.monotonic() - self._last_heartbeat
# Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible
Expand All @@ -428,8 +426,8 @@ def _monitor_subprocess(self):
0, # Make sure this value is never negative,
min(
# Ensure we heartbeat _at most_ 75% of the way through the zombie threshold time
SLOWEST_HEARTBEAT_INTERVAL - last_heartbeat_ago * 0.75,
max_poll_interval,
HEARTBEAT_THRESHOLD - last_heartbeat_ago * 0.75,
MIN_HEARTBEAT_INTERVAL,
),
)
events = self.selector.select(timeout=max_wait_time)
Expand All @@ -455,10 +453,38 @@ def _check_subprocess_exit(self):

def _send_heartbeat_if_needed(self):
"""Send a heartbeat to the client if heartbeat interval has passed."""
if time.monotonic() - self._last_heartbeat >= FASTEST_HEARTBEAT_INTERVAL:
# If the process has exited, we don't need to send any more heartbeats
if self._exit_code is not None:
return

if time.monotonic() - self._last_heartbeat >= MIN_HEARTBEAT_INTERVAL:
try:
self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid)
self._last_heartbeat = time.monotonic()
except ServerResponseError as e:
# TODO: Should we instead check httpx.HTTPStatusError?
if e.response.status_code == 409:
reason = e.detail.get("reason", "")
if reason == "not_running":
error_msg = e.detail.get(
"message", "Task is no longer in the running state and task should terminate"
)
log.error(error_msg, current_state=e.detail.get("current_state"))
elif reason == "running_elsewhere":
error_msg = e.detail.get("message", "Task is already running elsewhere")
log.error(
error_msg,
current_hostname=e.detail.get("current_hostname"),
current_pid=e.detail.get("current_pid"),
)
elif e.response.status_code == 404:
log.error("Task Instance not found")
else:
# TODO: Handle other errors
raise

# If heartbeating raises an error, kill the subprocess
self.kill(signal.SIGTERM)
except Exception:
log.warning("Failed to send heartbeat", exc_info=True)
# TODO: If we couldn't heartbeat for X times the interval, kill ourselves
Expand Down
51 changes: 50 additions & 1 deletion task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_regular_heartbeat(self, spy_agency: kgb.SpyAgency, monkeypatch):
"""Test that the WatchedSubprocess class regularly sends heartbeat requests, up to a certain frequency"""
import airflow.sdk.execution_time.supervisor

monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "FASTEST_HEARTBEAT_INTERVAL", 0.1)
monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1)

def subprocess_main():
sys.stdin.readline()
Expand Down Expand Up @@ -304,6 +304,55 @@ def handle_request(request: httpx.Request) -> httpx.Response:
"previous_state": "running",
}

def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker):
"""
Test that the Supervisor terminates the subprocess and logs an error when the heartbeat request is
rejected with a 409 because the Task Instance is no longer in the running state.
"""
# Only let ERROR-and-above events through, so the final `captured_logs` comparison can be exact.
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.ERROR))

import airflow.sdk.execution_time.supervisor

# Shrink the minimum heartbeat interval so a heartbeat is attempted almost immediately.
monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1)

def subprocess_main():
# Reads one line from stdin, then sleeps. The sleep is presumably long enough that the process
# only ends early if the supervisor kills it -- TODO confirm.
sys.stdin.readline()
sleep(5)

ti_id = uuid7()

def handle_request(request: httpx.Request) -> httpx.Response:
# Reject every heartbeat for this TI with a 409 "not_running" conflict; any other request
# gets an empty 204 response.
if request.url.path == f"/task-instances/{ti_id}/heartbeat":
return httpx.Response(
409,
json={
"reason": "not_running",
"message": "TI is no longer in the running state and task should terminate",
"current_state": "success",
},
)
return httpx.Response(status_code=204)

proc = WatchedSubprocess.start(
path=os.devnull,
ti=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1),
client=make_client(transport=httpx.MockTransport(handle_request)),
target=subprocess_main,
)

# Wait for the subprocess to finish
# NOTE(review): a negative exit code conventionally means "killed by that signal"; this expects SIGTERM.
assert proc.wait() == -signal.SIGTERM

# Exactly one error event is expected; its message and `current_state` match the 409 response body above.
assert captured_logs == [
{
"current_state": "success",
"event": "TI is no longer in the running state and task should terminate",
"level": "error",
"logger": "supervisor",
"timestamp": mocker.ANY,
}
]


class TestHandleRequest:
@pytest.fixture
Expand Down

0 comments on commit 0990abf

Please sign in to comment.