diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 932e6ead37947..2cdcbd8423725 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -68,6 +68,7 @@ GetXCom, PutVariable, RescheduleTask, + SetRenderedFields, SetXCom, StartupDetails, TaskState, @@ -733,6 +734,8 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index) elif isinstance(msg, PutVariable): self.client.variables.set(msg.key, msg.value, msg.description) + elif isinstance(msg, SetRenderedFields): + self.client.task_instances.set_rtif(self.id, msg.rendered_fields) else: log.error("Unhandled request", msg=msg) return diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 73ac6dea630ba..03d2bab939657 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -46,6 +46,7 @@ GetXCom, PutVariable, RescheduleTask, + SetRenderedFields, SetXCom, TaskState, VariableResult, @@ -882,6 +883,14 @@ def watched_subprocess(self, mocker): "", id="patch_task_instance_to_skipped", ), + pytest.param( + SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}), + b"", + "task_instances.set_rtif", + (TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}), + {"ok": True}, + id="set_rtif", + ), ], ) def test_handle_requests( diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 96ac89db5cd9d..2c08a9b97cde9 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -248,43 +248,6 @@ def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context): ) -def test_startup_basic_templated_dag(mocked_parse, make_ti_context): - """Test running a DAG with templated task.""" - from airflow.providers.standard.operators.bash import BashOperator - - task = BashOperator( - task_id="templated_task", - bash_command="echo 'Logical date is {{ logical_date }}'", - ) - - what = StartupDetails( - ti=TaskInstance( - id=uuid7(), task_id="templated_task", dag_id="basic_templated_dag", run_id="c", try_number=1 - ), - file="", - requests_fd=0, - ti_context=make_ti_context(), - ) - mocked_parse(what, "basic_templated_dag", task) - - with mock.patch( - "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True - ) as mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = what - startup() - - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SetRenderedFields( - rendered_fields={ - "bash_command": "echo 'Logical date is {{ logical_date }}'", - "cwd": None, - "env": None, - } - ), - log=mock.ANY, - ) - - @pytest.mark.parametrize( ["task_params", "expected_rendered_fields"], [ @@ -311,8 +274,8 @@ def test_startup_basic_templated_dag(mocked_parse, make_ti_context): ), ], ) -def test_startup_dag_with_templated_fields( - mocked_parse, task_params, expected_rendered_fields, make_ti_context +def test_startup_and_run_dag_with_templated_fields( + mocked_parse, task_params, expected_rendered_fields, make_ti_context, time_machine ): """Test startup of a DAG with various templated fields.""" @@ -324,6 +287,10 @@ def __init__(self, *args, **kwargs): for key, value in task_params.items(): setattr(self, key, value) + def execute(self, context): + for key in self.template_fields: + print(key, getattr(self, key)) + task = CustomOperator(task_id="templated_task") what = StartupDetails( @@ -332,18 +299,28 @@ def __init__(self, *args, **kwargs): requests_fd=0, ti_context=make_ti_context(), ) - mocked_parse(what, "basic_dag", task) + ti = mocked_parse(what, "basic_dag", task) + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) with mock.patch( "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True ) as mock_supervisor_comms: mock_supervisor_comms.get_message.return_value = what startup() - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SetRenderedFields(rendered_fields=expected_rendered_fields), - log=mock.ANY, - ) + run(ti, log=mock.MagicMock()) + expected_calls = [ + mock.call.send_request( + msg=SetRenderedFields(rendered_fields=expected_rendered_fields), + log=mock.ANY, + ), + mock.call.send_request( + msg=TaskState(end_date=instant, state=TerminalTIState.SUCCESS), + log=mock.ANY, + ), + ] + mock_supervisor_comms.assert_has_calls(expected_calls) @pytest.mark.parametrize(