Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KPO Maintain backward compatibility for execute_complete and trigger run method #37363

Merged
merged 9 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 30 additions & 56 deletions airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence

import kubernetes
from deprecated import deprecated
from kubernetes.client import CoreV1Api, V1Pod, models as k8s
from kubernetes.stream import stream
from urllib3.exceptions import HTTPError
Expand Down Expand Up @@ -699,7 +700,7 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
If ``logging_interval`` is not None, it could be that the pod is still running and we'll just
grab the latest logs and defer back to the trigger again.
"""
remote_pod = None
self.pod = None
try:
self.pod_request_obj = self.build_pod_request_obj(context)
self.pod = self.find_pod(
Expand All @@ -713,6 +714,11 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
if not self.pod:
raise PodNotFoundException("Could not find pod after resuming from deferral")

if self.callbacks:
self.callbacks.on_operator_resuming(
pod=self.pod, event=event, client=self.client, mode=ExecutionMode.SYNC
)

if self.get_logs:
last_log_time = event and event.get("last_log_time")
if last_log_time:
Expand All @@ -729,66 +735,34 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:

if self.do_xcom_push:
result = self.extract_xcom(pod=self.pod)
remote_pod = self.pod_manager.await_pod_completion(self.pod)
if event["status"] in ("error", "failed", "timeout"):
raise AirflowException(event)
return result
# self.pod = self.pod_manager.await_pod_completion(self.pod)
except TaskDeferred:
raise
except Exception:
self.cleanup(
pod=self.pod or self.pod_request_obj,
remote_pod=remote_pod,
finally:
self._clean(event)

def _clean(self, event: dict[str, Any]):
    """Finalize pod handling for a trigger event whose status is not ``"running"``.

    :param event: payload of the trigger event; its ``"status"`` key decides what is done.

    For a ``"running"`` status nothing happens. Otherwise the pod is awaited until
    completion (except for ``"timeout"``) and, when a pod object is available,
    ``post_complete_action`` is invoked with it.
    """
    status = event["status"]
    if status == "running":
        # The pod has not finished yet, so there is nothing to finalize.
        return

    uses_istio = self.is_istio_enabled(self.pod)

    # On 'timeout' the pod may be stuck in ErrImagePull or ContainerCreating and would
    # never complete, so waiting for completion is skipped for that status only.
    if status != "timeout":
        self.pod = self.pod_manager.await_pod_completion(
            self.pod, uses_istio, self.base_container_name
        )

    if self.pod is None:
        return
    self.post_complete_action(pod=self.pod, remote_pod=self.pod)
raise
self.cleanup(
pod=self.pod or self.pod_request_obj,
remote_pod=remote_pod,
)
if self.do_xcom_push:
return result

@deprecated(reason="use `trigger_reentry` instead.", category=AirflowProviderDeprecationWarning)
def execute_complete(self, context: Context, event: dict, **kwargs):
    """Resume the operator after deferral; kept only for backward compatibility.

    All of the work (looking up the pod, resume callbacks, log fetching, XCom
    extraction and cleanup) is done by ``trigger_reentry``. This wrapper exists so
    that code calling the old method name keeps working.

    :param context: task execution context, forwarded unchanged.
    :param event: payload of the trigger event, forwarded unchanged.
    :param kwargs: accepted for signature compatibility and otherwise ignored.
    :return: whatever ``trigger_reentry`` returns, so that a pushed XCom value is
        not lost when this deprecated entry point is used.
    """
    self.log.debug("Triggered with event: %s", event)
    # Return the delegate's result: dropping it would silently discard the XCom value.
    return self.trigger_reentry(context=context, event=event)

def write_logs(self, pod: k8s.V1Pod):
try:
Expand Down
29 changes: 23 additions & 6 deletions airflow/providers/cncf/kubernetes/triggers/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,32 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
},
)

def _get_terminal_event(self, state) -> TriggerEvent:
    """Build the trigger event for a pod that is already in a terminal state.

    :param state: pod phase to translate into an event status.
    :return: an event whose ``"status"`` is ``"success"`` when ``state`` equals
        ``PodPhase.SUCCEEDED`` and ``"failed"`` otherwise, together with the pod's
        namespace and name.
    """
    succeeded = state == PodPhase.SUCCEEDED
    payload = {
        "status": "success" if succeeded else "failed",
        "namespace": self.pod_namespace,
        "name": self.pod_name,
    }
    return TriggerEvent(payload)

async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current pod status and yield a TriggerEvent."""
self.log.info("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace)
try:
state = await self._wait_for_pod_start()
if state in PodPhase.terminal_states:
event = TriggerEvent(
{"status": "done", "namespace": self.pod_namespace, "pod_name": self.pod_name}
)
event = self._get_terminal_event(state)
else:
event = await self._wait_for_container_completion()
yield event
except PodLaunchTimeoutException as e:
description = self._format_exception_description(e)
yield TriggerEvent(
{
"status": "timeout",
"error_type": e.__class__.__name__,
"description": description,
}
)
except Exception as e:
description = self._format_exception_description(e)
yield TriggerEvent(
Expand Down Expand Up @@ -215,9 +229,12 @@ async def _wait_for_container_completion(self) -> TriggerEvent:
while True:
pod = await self.hook.get_pod(self.pod_name, self.pod_namespace)
if not container_is_running(pod=pod, container_name=self.base_container_name):
return TriggerEvent(
{"status": "done", "namespace": self.pod_namespace, "pod_name": self.pod_name}
)
container_state = self.define_container_state(pod)
if container_state == ContainerState.TERMINATED:
state = "success"
else:
state = "failed"
return TriggerEvent({"status": state, "namespace": self.pod_namespace, "name": self.pod_name})
if time_get_more_logs and timezone.utcnow() > time_get_more_logs:
return TriggerEvent({"status": "running", "last_log_time": self.last_log_time})
await asyncio.sleep(self.poll_interval)
Expand Down
8 changes: 4 additions & 4 deletions tests/providers/cncf/kubernetes/triggers/test_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigg

expected_event = TriggerEvent(
{
"pod_name": POD_NAME,
"name": POD_NAME,
"namespace": NAMESPACE,
"status": "done",
}
Expand Down Expand Up @@ -188,7 +188,7 @@ async def test_run_loop_return_failed_event(self, mock_hook, mock_method, trigge

expected_event = TriggerEvent(
{
"pod_name": POD_NAME,
"name": POD_NAME,
"namespace": NAMESPACE,
"status": "done",
}
Expand Down Expand Up @@ -236,7 +236,7 @@ async def test_logging_in_trigger_when_fail_should_execute_successfully(
"logging_interval, exp_event",
[
param(0, {"status": "running", "last_log_time": DateTime(2022, 1, 1)}, id="short_interval"),
param(None, {"status": "done", "namespace": mock.ANY, "pod_name": mock.ANY}, id="no_interval"),
param(None, {"status": "done", "namespace": mock.ANY, "name": mock.ANY}, id="no_interval"),
],
)
@mock.patch(
Expand Down Expand Up @@ -325,4 +325,4 @@ async def test_run_loop_return_timeout_event(

generator = trigger.run()
actual = await generator.asend(None)
assert actual == TriggerEvent({"status": "done", "namespace": NAMESPACE, "pod_name": POD_NAME})
assert actual == TriggerEvent({"status": "done", "namespace": NAMESPACE, "name": POD_NAME})
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def test_run_loop_return_success_event_should_execute_successfully(

expected_event = TriggerEvent(
{
"pod_name": POD_NAME,
"name": POD_NAME,
"namespace": NAMESPACE,
"status": "done",
}
Expand All @@ -144,7 +144,7 @@ async def test_run_loop_return_failed_event_should_execute_successfully(

expected_event = TriggerEvent(
{
"pod_name": POD_NAME,
"name": POD_NAME,
"namespace": NAMESPACE,
"status": "done",
}
Expand Down
Loading