Skip to content

Commit

Permalink
AIP-72: Adding missing supervisor handler for RTIF (apache#45102)
Browse files Browse the repository at this point in the history
  • Loading branch information
amoghrajesh authored Dec 20, 2024
1 parent b6e3d1c commit 7e9d7b1
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 44 deletions.
3 changes: 3 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
GetXCom,
PutVariable,
RescheduleTask,
SetRenderedFields,
SetXCom,
StartupDetails,
TaskState,
Expand Down Expand Up @@ -733,6 +734,8 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index)
elif isinstance(msg, PutVariable):
self.client.variables.set(msg.key, msg.value, msg.description)
elif isinstance(msg, SetRenderedFields):
self.client.task_instances.set_rtif(self.id, msg.rendered_fields)
else:
log.error("Unhandled request", msg=msg)
return
Expand Down
9 changes: 9 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
GetXCom,
PutVariable,
RescheduleTask,
SetRenderedFields,
SetXCom,
TaskState,
VariableResult,
Expand Down Expand Up @@ -882,6 +883,14 @@ def watched_subprocess(self, mocker):
"",
id="patch_task_instance_to_skipped",
),
pytest.param(
SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}),
b"",
"task_instances.set_rtif",
(TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}),
{"ok": True},
id="set_rtif",
),
],
)
def test_handle_requests(
Expand Down
65 changes: 21 additions & 44 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,43 +248,6 @@ def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context):
)


def test_startup_basic_templated_dag(mocked_parse, make_ti_context):
"""Test running a DAG with templated task."""
from airflow.providers.standard.operators.bash import BashOperator

task = BashOperator(
task_id="templated_task",
bash_command="echo 'Logical date is {{ logical_date }}'",
)

what = StartupDetails(
ti=TaskInstance(
id=uuid7(), task_id="templated_task", dag_id="basic_templated_dag", run_id="c", try_number=1
),
file="",
requests_fd=0,
ti_context=make_ti_context(),
)
mocked_parse(what, "basic_templated_dag", task)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = what
startup()

mock_supervisor_comms.send_request.assert_called_once_with(
msg=SetRenderedFields(
rendered_fields={
"bash_command": "echo 'Logical date is {{ logical_date }}'",
"cwd": None,
"env": None,
}
),
log=mock.ANY,
)


@pytest.mark.parametrize(
["task_params", "expected_rendered_fields"],
[
Expand All @@ -311,8 +274,8 @@ def test_startup_basic_templated_dag(mocked_parse, make_ti_context):
),
],
)
def test_startup_dag_with_templated_fields(
mocked_parse, task_params, expected_rendered_fields, make_ti_context
def test_startup_and_run_dag_with_templated_fields(
mocked_parse, task_params, expected_rendered_fields, make_ti_context, time_machine
):
"""Test startup of a DAG with various templated fields."""

Expand All @@ -324,6 +287,10 @@ def __init__(self, *args, **kwargs):
for key, value in task_params.items():
setattr(self, key, value)

def execute(self, context):
for key in self.template_fields:
print(key, getattr(self, key))

task = CustomOperator(task_id="templated_task")

what = StartupDetails(
Expand All @@ -332,18 +299,28 @@ def __init__(self, *args, **kwargs):
requests_fd=0,
ti_context=make_ti_context(),
)
mocked_parse(what, "basic_dag", task)
ti = mocked_parse(what, "basic_dag", task)

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)
with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = what

startup()
mock_supervisor_comms.send_request.assert_called_once_with(
msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
log=mock.ANY,
)
run(ti, log=mock.MagicMock())
expected_calls = [
mock.call.send_request(
msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
log=mock.ANY,
),
mock.call.send_request(
msg=TaskState(end_date=instant, state=TerminalTIState.SUCCESS),
log=mock.ANY,
),
]
mock_supervisor_comms.assert_has_calls(expected_calls)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 7e9d7b1

Please sign in to comment.