diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py
index 73389f4038282..61442a6014ebc 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -18,6 +18,7 @@
 from __future__ import annotations

+import datetime
 import json
 import logging
 import re
@@ -30,6 +31,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence

 import kubernetes
+from deprecated import deprecated
 from kubernetes.client import CoreV1Api, V1Pod, models as k8s
 from kubernetes.stream import stream
 from urllib3.exceptions import HTTPError
@@ -68,7 +70,6 @@
     EMPTY_XCOM_RESULT,
     OnFinishAction,
     PodLaunchFailedException,
-    PodLaunchTimeoutException,
     PodManager,
     PodNotFoundException,
     PodOperatorHookProtocol,
@@ -79,7 +80,6 @@
 from airflow.settings import pod_mutation_hook
 from airflow.utils import yaml
 from airflow.utils.helpers import prune_dict, validate_key
-from airflow.utils.timezone import utcnow
 from airflow.version import version as airflow_version

 if TYPE_CHECKING:
@@ -656,7 +656,7 @@ def execute_async(self, context: Context):

     def invoke_defer_method(self, last_log_time: DateTime | None = None):
         """Redefine triggers which are being used in child classes."""
-        trigger_start_time = utcnow()
+        trigger_start_time = datetime.datetime.now(tz=datetime.timezone.utc)
         self.defer(
             trigger=KubernetesPodTrigger(
                 pod_name=self.pod.metadata.name,  # type: ignore[union-attr]
@@ -678,117 +678,87 @@ def invoke_defer_method(self, last_log_time: DateTime | None = None):
             method_name="trigger_reentry",
         )

-    @staticmethod
-    def raise_for_trigger_status(event: dict[str, Any]) -> None:
-        """Raise exception if pod is not in expected state."""
-        if event["status"] == "error":
-            error_type = event["error_type"]
-            description = event["description"]
-            if error_type == "PodLaunchTimeoutException":
-                raise PodLaunchTimeoutException(description)
-            else:
-                raise AirflowException(description)
-
     def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
         """
         Point of re-entry from trigger.

-        If ``logging_interval`` is None, then at this point the pod should be done and we'll just fetch
+        If ``logging_interval`` is None, then at this point, the pod should be done, and we'll just fetch
         the logs and exit.

-        If ``logging_interval`` is not None, it could be that the pod is still running and we'll just
+        If ``logging_interval`` is not None, it could be that the pod is still running, and we'll just
         grab the latest logs and defer back to the trigger again.
""" - remote_pod = None + self.pod = None try: - self.pod_request_obj = self.build_pod_request_obj(context) - self.pod = self.find_pod( - namespace=self.namespace or self.pod_request_obj.metadata.namespace, - context=context, - ) + pod_name = event["name"] + pod_namespace = event["namespace"] - # we try to find pod before possibly raising so that on_kill will have `pod` attr - self.raise_for_trigger_status(event) + self.pod = self.hook.get_pod(pod_name, pod_namespace) if not self.pod: raise PodNotFoundException("Could not find pod after resuming from deferral") - if self.get_logs: - last_log_time = event and event.get("last_log_time") - if last_log_time: - self.log.info("Resuming logs read from time %r", last_log_time) - pod_log_status = self.pod_manager.fetch_container_logs( - pod=self.pod, - container_name=self.BASE_CONTAINER_NAME, - follow=self.logging_interval is None, - since_time=last_log_time, - ) - if pod_log_status.running: - self.log.info("Container still running; deferring again.") - self.invoke_defer_method(pod_log_status.last_log_time) - - if self.do_xcom_push: - result = self.extract_xcom(pod=self.pod) - remote_pod = self.pod_manager.await_pod_completion(self.pod) - except TaskDeferred: - raise - except Exception: - self.cleanup( - pod=self.pod or self.pod_request_obj, - remote_pod=remote_pod, - ) - raise - self.cleanup( - pod=self.pod or self.pod_request_obj, - remote_pod=remote_pod, - ) - if self.do_xcom_push: - return result - - def execute_complete(self, context: Context, event: dict, **kwargs): - self.log.debug("Triggered with event: %s", event) - pod = None - try: - pod = self.hook.get_pod( - event["name"], - event["namespace"], - ) - if self.callbacks: + if self.callbacks and event["status"] != "running": self.callbacks.on_operator_resuming( - pod=pod, event=event, client=self.client, mode=ExecutionMode.SYNC + pod=self.pod, event=event, client=self.client, mode=ExecutionMode.SYNC ) + if event["status"] in ("error", "failed", "timeout"): - # fetch some logs when pod is failed - if self.get_logs: - self.write_logs(pod) - if "stack_trace" in event: - message = f"{event['message']}\n{event['stack_trace']}" - else: - message = event["message"] if self.do_xcom_push: - # In the event of base container failure, we need to kill the xcom sidecar. 
-                # We disregard xcom output and do that here
+                    _ = self.extract_xcom(pod=self.pod)
+
+                message = event.get("stack_trace", event["message"])
                 raise AirflowException(message)
-        elif event["status"] == "success":
-            # fetch some logs when pod is executed successfully
+
+            elif event["status"] == "running":
                 if self.get_logs:
-                    self.write_logs(pod)
+                    last_log_time = event.get("last_log_time")
+                    self.log.info("Resuming logs read from time %r", last_log_time)
+
+                    pod_log_status = self.pod_manager.fetch_container_logs(
+                        pod=self.pod,
+                        container_name=self.BASE_CONTAINER_NAME,
+                        follow=self.logging_interval is None,
+                        since_time=last_log_time,
+                    )
+                    if pod_log_status.running:
+                        self.log.info("Container still running; deferring again.")
+                        self.invoke_defer_method(pod_log_status.last_log_time)
+                    else:
+                        self.invoke_defer_method()
+
+            elif event["status"] == "success":
                 if self.do_xcom_push:
-                    xcom_sidecar_output = self.extract_xcom(pod=pod)
+                    xcom_sidecar_output = self.extract_xcom(pod=self.pod)
                     return xcom_sidecar_output
+                return
+        except TaskDeferred:
+            raise
         finally:
-            istio_enabled = self.is_istio_enabled(pod)
-            # Skip await_pod_completion when the event is 'timeout' due to the pod can hang
-            # on the ErrImagePull or ContainerCreating step and it will never complete
-            if event["status"] != "timeout":
-                pod = self.pod_manager.await_pod_completion(pod, istio_enabled, self.base_container_name)
-            if pod is not None:
-                self.post_complete_action(
-                    pod=pod,
-                    remote_pod=pod,
-                )
+            self._clean(event)
+
+    def _clean(self, event: dict[str, Any]):
+        if event["status"] == "running":
+            return
+        if self.get_logs:
+            self.write_logs(self.pod)
+        istio_enabled = self.is_istio_enabled(self.pod)
+        # Skip await_pod_completion when the event is 'timeout' because the pod can hang
+        # on the ErrImagePull or ContainerCreating step and never complete
+        if event["status"] != "timeout":
+            self.pod = self.pod_manager.await_pod_completion(
+                self.pod, istio_enabled, self.base_container_name
+            )
+        if self.pod is not None:
+            self.post_complete_action(
+                pod=self.pod,
+                remote_pod=self.pod,
+            )
+
+    @deprecated(reason="use `trigger_reentry` instead.", category=AirflowProviderDeprecationWarning)
+    def execute_complete(self, context: Context, event: dict, **kwargs):
+        self.trigger_reentry(context=context, event=event)

     def write_logs(self, pod: k8s.V1Pod):
         try:
diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py
index e34a73f146fe2..c9b1e62226541 100644
--- a/airflow/providers/cncf/kubernetes/triggers/pod.py
+++ b/airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -30,10 +30,8 @@
     OnFinishAction,
     PodLaunchTimeoutException,
     PodPhase,
-    container_is_running,
 )
 from airflow.triggers.base import BaseTrigger, TriggerEvent
-from airflow.utils import timezone

 if TYPE_CHECKING:
     from kubernetes_asyncio.client.models import V1Pod
@@ -160,22 +158,49 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
         self.log.info("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace)
         try:
             state = await self._wait_for_pod_start()
-            if state in PodPhase.terminal_states:
+            if state == ContainerState.TERMINATED:
                 event = TriggerEvent(
-                    {"status": "done", "namespace": self.pod_namespace, "pod_name": self.pod_name}
+                    {
+                        "status": "success",
+                        "namespace": self.pod_namespace,
+                        "name": self.pod_name,
+                        "message": "All containers inside pod have started successfully.",
+                    }
+                )
+            elif state == ContainerState.FAILED:
+
event = TriggerEvent( + { + "status": "failed", + "namespace": self.pod_namespace, + "name": self.pod_name, + "message": "pod failed", + } ) else: event = await self._wait_for_container_completion() yield event + return + except PodLaunchTimeoutException as e: + message = self._format_exception_description(e) + yield TriggerEvent( + { + "name": self.pod_name, + "namespace": self.pod_namespace, + "status": "timeout", + "message": message, + } + ) except Exception as e: - description = self._format_exception_description(e) yield TriggerEvent( { + "name": self.pod_name, + "namespace": self.pod_namespace, "status": "error", - "error_type": e.__class__.__name__, - "description": description, + "message": str(e), + "stack_trace": traceback.format_exc(), } ) + return def _format_exception_description(self, exc: Exception) -> Any: if isinstance(exc, PodLaunchTimeoutException): @@ -189,14 +214,13 @@ def _format_exception_description(self, exc: Exception) -> Any: description += f"\ntrigger traceback:\n{curr_traceback}" return description - async def _wait_for_pod_start(self) -> Any: + async def _wait_for_pod_start(self) -> ContainerState: """Loops until pod phase leaves ``PENDING`` If timeout is reached, throws error.""" - start_time = timezone.utcnow() - timeout_end = start_time + datetime.timedelta(seconds=self.startup_timeout) - while timeout_end > timezone.utcnow(): + delta = datetime.datetime.now(tz=datetime.timezone.utc) - self.trigger_start_time + while self.startup_timeout >= delta.total_seconds(): pod = await self.hook.get_pod(self.pod_name, self.pod_namespace) if not pod.status.phase == "Pending": - return pod.status.phase + return self.define_container_state(pod) self.log.info("Still waiting for pod to start. The pod state is %s", pod.status.phase) await asyncio.sleep(self.poll_interval) raise PodLaunchTimeoutException("Pod did not leave 'Pending' phase within specified timeout") @@ -208,18 +232,30 @@ async def _wait_for_container_completion(self) -> TriggerEvent: Waits until container is no longer in running state. If trigger is configured with a logging period, then will emit an event to resume the task for the purpose of fetching more logs. 
""" - time_begin = timezone.utcnow() + time_begin = datetime.datetime.now(tz=datetime.timezone.utc) time_get_more_logs = None if self.logging_interval is not None: time_get_more_logs = time_begin + datetime.timedelta(seconds=self.logging_interval) while True: pod = await self.hook.get_pod(self.pod_name, self.pod_namespace) - if not container_is_running(pod=pod, container_name=self.base_container_name): + container_state = self.define_container_state(pod) + if container_state == ContainerState.TERMINATED: + return TriggerEvent( + {"status": "success", "namespace": self.pod_namespace, "name": self.pod_name} + ) + elif container_state == ContainerState.FAILED: + return TriggerEvent( + {"status": "failed", "namespace": self.pod_namespace, "name": self.pod_name} + ) + if time_get_more_logs and datetime.datetime.now(tz=datetime.timezone.utc) > time_get_more_logs: return TriggerEvent( - {"status": "done", "namespace": self.pod_namespace, "pod_name": self.pod_name} + { + "status": "running", + "last_log_time": self.last_log_time, + "namespace": self.pod_namespace, + "name": self.pod_name, + } ) - if time_get_more_logs and timezone.utcnow() > time_get_more_logs: - return TriggerEvent({"status": "running", "last_log_time": self.last_log_time}) await asyncio.sleep(self.poll_interval) def _get_async_hook(self) -> AsyncKubernetesHook: diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py b/tests/providers/cncf/kubernetes/operators/test_pod.py index c27cd231465cb..faa21eb7d75fc 100644 --- a/tests/providers/cncf/kubernetes/operators/test_pod.py +++ b/tests/providers/cncf/kubernetes/operators/test_pod.py @@ -35,7 +35,6 @@ from airflow.providers.cncf.kubernetes.secret import Secret from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger from airflow.providers.cncf.kubernetes.utils.pod_manager import ( - PodLaunchTimeoutException, PodLoggingStatus, PodPhase, ) @@ -1973,41 +1972,39 @@ def test_cleanup_log_pod_spec_on_failure(self, log_pod_spec_on_failure, expect_m with pytest.raises(AirflowException, match=expect_match): k.cleanup(pod, pod) - @mock.patch( - "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.raise_for_trigger_status" - ) - @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.find_pod") + @mock.patch(f"{HOOK_CLASS}.get_pod") @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion") @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs") def test_get_logs_running( self, fetch_container_logs, await_pod_completion, - find_pod, - raise_for_trigger_status, + get_pod, ): """When logs fetch exits with status running, raise task deferred""" pod = MagicMock() - find_pod.return_value = pod + get_pod.return_value = pod op = KubernetesPodOperator(task_id="test_task", name="test-pod", get_logs=True) await_pod_completion.return_value = None fetch_container_logs.return_value = PodLoggingStatus(True, None) with pytest.raises(TaskDeferred): - op.trigger_reentry(create_context(op), None) + op.trigger_reentry( + create_context(op), + event={"name": TEST_NAME, "namespace": TEST_NAMESPACE, "status": "running"}, + ) fetch_container_logs.is_called_with(pod, "base") @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup") - @mock.patch( - "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.raise_for_trigger_status" - ) 
@mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.find_pod") @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs") - def test_get_logs_not_running(self, fetch_container_logs, find_pod, raise_for_trigger_status, cleanup): + def test_get_logs_not_running(self, fetch_container_logs, find_pod, cleanup): pod = MagicMock() find_pod.return_value = pod op = KubernetesPodOperator(task_id="test_task", name="test-pod", get_logs=True) fetch_container_logs.return_value = PodLoggingStatus(False, None) - op.trigger_reentry(create_context(op), None) + op.trigger_reentry( + create_context(op), event={"name": TEST_NAME, "namespace": TEST_NAMESPACE, "status": "success"} + ) fetch_container_logs.is_called_with(pod, "base") @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup") @@ -2016,14 +2013,15 @@ def test_trigger_error(self, find_pod, cleanup): """Assert that trigger_reentry raise exception in case of error""" find_pod.return_value = MagicMock() op = KubernetesPodOperator(task_id="test_task", name="test-pod", get_logs=True) - with pytest.raises(PodLaunchTimeoutException): + with pytest.raises(AirflowException): context = create_context(op) op.trigger_reentry( context, { - "status": "error", - "error_type": "PodLaunchTimeoutException", - "description": "any message", + "status": "timeout", + "message": "any message", + "name": TEST_NAME, + "namespace": TEST_NAMESPACE, }, ) diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py b/tests/providers/cncf/kubernetes/triggers/test_pod.py index d12100e4e35c7..bed52811fc675 100644 --- a/tests/providers/cncf/kubernetes/triggers/test_pod.py +++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py @@ -122,9 +122,10 @@ async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigg expected_event = TriggerEvent( { - "pod_name": POD_NAME, - "namespace": NAMESPACE, - "status": "done", + "status": "success", + "namespace": "default", + "name": "test-pod-name", + "message": "All containers inside pod have started successfully.", } ) actual_event = await trigger.run().asend(None) @@ -132,16 +133,11 @@ async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigg assert actual_event == expected_event @pytest.mark.asyncio - @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.container_is_running") - @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook.get_pod") - @mock.patch(f"{TRIGGER_PATH}._wait_for_pod_start") + @mock.patch(f"{TRIGGER_PATH}.define_container_state") @mock.patch(f"{TRIGGER_PATH}.hook") - async def test_run_loop_return_waiting_event( - self, mock_hook, mock_method, mock_get_pod, mock_container_is_running, trigger, caplog - ): + async def test_run_loop_return_waiting_event(self, mock_hook, mock_method, trigger, caplog): mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) mock_method.return_value = ContainerState.WAITING - mock_container_is_running.return_value = True caplog.set_level(logging.INFO) @@ -153,16 +149,11 @@ async def test_run_loop_return_waiting_event( assert f"Sleeping for {POLL_INTERVAL} seconds." 
@pytest.mark.asyncio - @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.container_is_running") - @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook.get_pod") - @mock.patch(f"{TRIGGER_PATH}._wait_for_pod_start") + @mock.patch(f"{TRIGGER_PATH}.define_container_state") @mock.patch(f"{TRIGGER_PATH}.hook") - async def test_run_loop_return_running_event( - self, mock_hook, mock_method, mock_get_pod, mock_container_is_running, trigger, caplog - ): + async def test_run_loop_return_running_event(self, mock_hook, mock_method, trigger, caplog): mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) mock_method.return_value = ContainerState.RUNNING - mock_container_is_running.return_value = True caplog.set_level(logging.INFO) @@ -187,11 +178,7 @@ async def test_run_loop_return_failed_event(self, mock_hook, mock_method, trigge mock_method.return_value = ContainerState.FAILED expected_event = TriggerEvent( - { - "pod_name": POD_NAME, - "namespace": NAMESPACE, - "status": "done", - } + {"status": "failed", "namespace": "default", "name": "test-pod-name", "message": "pod failed"} ) actual_event = await trigger.run().asend(None) @@ -210,8 +197,14 @@ async def test_logging_in_trigger_when_exception_should_execute_successfully( generator = trigger.run() actual = await generator.asend(None) - actual_stack_trace = actual.payload.pop("description") - assert actual_stack_trace.startswith("Trigger KubernetesPodTrigger failed with exception Exception") + actual_stack_trace = actual.payload.pop("stack_trace") + assert ( + TriggerEvent( + {"name": POD_NAME, "namespace": NAMESPACE, "status": "error", "message": "Test exception"} + ) + == actual + ) + assert actual_stack_trace.startswith("Traceback (most recent call last):") @pytest.mark.asyncio @mock.patch(f"{TRIGGER_PATH}.define_container_state") @@ -235,16 +228,24 @@ async def test_logging_in_trigger_when_fail_should_execute_successfully( @pytest.mark.parametrize( "logging_interval, exp_event", [ - param(0, {"status": "running", "last_log_time": DateTime(2022, 1, 1)}, id="short_interval"), - param(None, {"status": "done", "namespace": mock.ANY, "pod_name": mock.ANY}, id="no_interval"), + param( + 0, + { + "status": "running", + "last_log_time": DateTime(2022, 1, 1), + "name": POD_NAME, + "namespace": NAMESPACE, + }, + id="short_interval", + ), ], ) - @mock.patch( - "kubernetes_asyncio.client.CoreV1Api.read_namespaced_pod", - new=get_read_pod_mock_containers([1, 1, None, None]), - ) - @mock.patch("kubernetes_asyncio.config.load_kube_config") - async def test_running_log_interval(self, load_kube_config, logging_interval, exp_event): + @mock.patch(f"{TRIGGER_PATH}.define_container_state") + @mock.patch(f"{TRIGGER_PATH}._wait_for_pod_start") + @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.AsyncKubernetesHook.get_pod") + async def test_running_log_interval( + self, mock_get_pod, mock_wait_for_pod_start, define_container_state, logging_interval, exp_event + ): """ If log interval given, should emit event with running status and last log time. Otherwise, should make it to second loop and emit "done" event. @@ -254,14 +255,15 @@ async def test_running_log_interval(self, load_kube_config, logging_interval, ex interval is None, the second "running" status will just result in continuation of the loop. And when in the next loop we get a non-running status, the trigger fires a "done" event. 
""" + define_container_state.return_value = "running" trigger = KubernetesPodTrigger( - pod_name=mock.ANY, - pod_namespace=mock.ANY, - trigger_start_time=mock.ANY, - base_container_name=mock.ANY, + pod_name=POD_NAME, + pod_namespace=NAMESPACE, + trigger_start_time=datetime.datetime.now(tz=datetime.timezone.utc), + base_container_name=BASE_CONTAINER_NAME, startup_timeout=5, poll_interval=1, - logging_interval=logging_interval, + logging_interval=1, last_log_time=DateTime(2022, 1, 1), ) assert await trigger.run().__anext__() == TriggerEvent(exp_event) @@ -306,12 +308,12 @@ def test_define_container_state_should_execute_successfully( @pytest.mark.asyncio @pytest.mark.parametrize("container_state", [ContainerState.WAITING, ContainerState.UNDEFINED]) - @mock.patch(f"{TRIGGER_PATH}._wait_for_pod_start") + @mock.patch(f"{TRIGGER_PATH}.define_container_state") @mock.patch(f"{TRIGGER_PATH}.hook") async def test_run_loop_return_timeout_event( self, mock_hook, mock_method, trigger, caplog, container_state ): - trigger.trigger_start_time = TRIGGER_START_TIME - datetime.timedelta(seconds=5) + trigger.trigger_start_time = TRIGGER_START_TIME - datetime.timedelta(minutes=2) mock_hook.get_pod.return_value = self._mock_pod_result( mock.MagicMock( status=mock.MagicMock( @@ -325,4 +327,14 @@ async def test_run_loop_return_timeout_event( generator = trigger.run() actual = await generator.asend(None) - assert actual == TriggerEvent({"status": "done", "namespace": NAMESPACE, "pod_name": POD_NAME}) + assert ( + TriggerEvent( + { + "name": POD_NAME, + "namespace": NAMESPACE, + "status": "timeout", + "message": "Pod did not leave 'Pending' phase within specified timeout", + } + ) + == actual + ) diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py index ca7b7ba3588fd..c6a2d4e72fcbb 100644 --- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py @@ -108,19 +108,20 @@ def test_serialize_should_execute_successfully(self, trigger): } @pytest.mark.asyncio - @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state") + @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start") @mock.patch(f"{TRIGGER_GKE_PATH}.hook") async def test_run_loop_return_success_event_should_execute_successfully( - self, mock_hook, mock_method, trigger + self, mock_hook, mock_wait_pod, trigger ): mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) - mock_method.return_value = ContainerState.TERMINATED + mock_wait_pod.return_value = ContainerState.TERMINATED expected_event = TriggerEvent( { - "pod_name": POD_NAME, + "name": POD_NAME, "namespace": NAMESPACE, - "status": "done", + "status": "success", + "message": "All containers inside pod have started successfully.", } ) actual_event = await trigger.run().asend(None) @@ -128,10 +129,10 @@ async def test_run_loop_return_success_event_should_execute_successfully( assert actual_event == expected_event @pytest.mark.asyncio - @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state") + @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start") @mock.patch(f"{TRIGGER_GKE_PATH}.hook") async def test_run_loop_return_failed_event_should_execute_successfully( - self, mock_hook, mock_method, trigger + self, mock_hook, mock_wait_pod, trigger ): mock_hook.get_pod.return_value = self._mock_pod_result( mock.MagicMock( @@ -140,13 +141,14 @@ async def test_run_loop_return_failed_event_should_execute_successfully( ) ) ) - 
mock_method.return_value = ContainerState.FAILED + mock_wait_pod.return_value = ContainerState.FAILED expected_event = TriggerEvent( { - "pod_name": POD_NAME, + "name": POD_NAME, "namespace": NAMESPACE, - "status": "done", + "status": "failed", + "message": "pod failed", } ) actual_event = await trigger.run().asend(None) @@ -154,18 +156,15 @@ async def test_run_loop_return_failed_event_should_execute_successfully( assert actual_event == expected_event @pytest.mark.asyncio - @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.container_is_running") - @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook.get_pod") @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start") + @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state") @mock.patch(f"{TRIGGER_GKE_PATH}.hook") async def test_run_loop_return_waiting_event_should_execute_successfully( - self, mock_hook, mock_method, mock_get_pod, mock_container_is_running, trigger, caplog + self, mock_hook, mock_method, mock_wait_pod, trigger, caplog ): mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) - mock_method.return_value = ContainerState.RUNNING - mock_container_is_running.return_value = True + mock_method.return_value = ContainerState.WAITING - trigger.logging_interval = 10 caplog.set_level(logging.INFO) task = asyncio.create_task(trigger.run().__anext__()) @@ -176,15 +175,13 @@ async def test_run_loop_return_waiting_event_should_execute_successfully( assert f"Sleeping for {POLL_INTERVAL} seconds." @pytest.mark.asyncio - @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.container_is_running") - @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook.get_pod") @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start") + @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state") @mock.patch(f"{TRIGGER_GKE_PATH}.hook") async def test_run_loop_return_running_event_should_execute_successfully( - self, mock_hook, mock_method, mock_get_pod, mock_container_is_running, trigger, caplog + self, mock_hook, mock_method, mock_wait_pod, trigger, caplog ): mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) - mock_container_is_running.return_value = True mock_method.return_value = ContainerState.RUNNING caplog.set_level(logging.INFO) @@ -197,9 +194,10 @@ async def test_run_loop_return_running_event_should_execute_successfully( assert f"Sleeping for {POLL_INTERVAL} seconds." @pytest.mark.asyncio + @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start") @mock.patch(f"{TRIGGER_GKE_PATH}.hook") async def test_logging_in_trigger_when_exception_should_execute_successfully( - self, mock_hook, trigger, caplog + self, mock_hook, mock_wait_pod, trigger, caplog ): """ Test that GKEStartPodTrigger fires the correct event in case of an error. @@ -208,9 +206,14 @@ async def test_logging_in_trigger_when_exception_should_execute_successfully( generator = trigger.run() actual = await generator.asend(None) - - actual_stack_trace = actual.payload.pop("description") - assert actual_stack_trace.startswith("Trigger GKEStartPodTrigger failed with exception Exception") + actual_stack_trace = actual.payload.pop("stack_trace") + assert ( + TriggerEvent( + {"name": POD_NAME, "namespace": NAMESPACE, "status": "error", "message": "Test exception"} + ) + == actual + ) + assert actual_stack_trace.startswith("Traceback (most recent call last):") @pytest.mark.asyncio @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")