diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index f493fc38da897..5e3931c10d848 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -92,7 +92,7 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse: log.debug("Calling method %.", {method_name}) try: output = handler(**params) - output_json = BaseSerialization.serialize(output) + output_json = BaseSerialization.serialize(output, use_pydantic_models=True) log.debug("Returning response") return Response( response=json.dumps(output_json or "{}"), headers={"Content-Type": "application/json"} diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py index 8ac7ac05a087c..1f8dce26dd265 100644 --- a/airflow/serialization/enums.py +++ b/airflow/serialization/enums.py @@ -51,6 +51,7 @@ class DagAttributeTypes(str, Enum): XCOM_REF = "xcomref" DATASET = "dataset" SIMPLE_TASK_INSTANCE = "simple_task_instance" + BASE_JOB = "base_job" TASK_INSTANCE = "task_instance" DAG_RUN = "dag_run" DATA_SET = "data_set" diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 6503861dcfd0e..b3d783ad690aa 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -38,6 +38,8 @@ from airflow.configuration import conf from airflow.datasets import Dataset from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError +from airflow.jobs.base_job import BaseJob +from airflow.jobs.pydantic.base_job import BaseJobPydantic from airflow.models.baseoperator import BaseOperator, BaseOperatorLink from airflow.models.connection import Connection from airflow.models.dag import DAG, create_timetable @@ -467,7 +469,12 @@ def serialize( elif isinstance(var, Dataset): return cls._encode(dict(uri=var.uri, extra=var.extra), type_=DAT.DATASET) elif 
isinstance(var, SimpleTaskInstance): - return cls._encode(cls.serialize(var.__dict__, strict=strict), type_=DAT.SIMPLE_TASK_INSTANCE) + return cls._encode( + cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models), + type_=DAT.SIMPLE_TASK_INSTANCE, + ) + elif use_pydantic_models and isinstance(var, BaseJob): + return cls._encode(BaseJobPydantic.from_orm(var).dict(), type_=DAT.BASE_JOB) elif use_pydantic_models and isinstance(var, TaskInstance): return cls._encode(TaskInstancePydantic.from_orm(var).dict(), type_=DAT.TASK_INSTANCE) elif use_pydantic_models and isinstance(var, DagRun): @@ -528,11 +535,13 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any: return Dataset(**var) elif type_ == DAT.SIMPLE_TASK_INSTANCE: return SimpleTaskInstance(**cls.deserialize(var)) - elif use_pydantic_models and type == DAT.TASK_INSTANCE: + elif use_pydantic_models and type_ == DAT.BASE_JOB: + return BaseJobPydantic.parse_obj(var) + elif use_pydantic_models and type_ == DAT.TASK_INSTANCE: return TaskInstancePydantic.parse_obj(var) - elif use_pydantic_models and type == DAT.DAG_RUN: + elif use_pydantic_models and type_ == DAT.DAG_RUN: return DagRunPydantic.parse_obj(var) - elif use_pydantic_models and type == DAT.DATA_SET: + elif use_pydantic_models and type_ == DAT.DATA_SET: return DatasetPydantic.parse_obj(var) else: raise TypeError(f"Invalid type {type_!s} in deserialization.") diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py index a0ec216147047..9e09376dd6819 100644 --- a/tests/api_internal/endpoints/test_rpc_api_endpoint.py +++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py @@ -23,7 +23,11 @@ import pytest from flask import Flask +from airflow.models.pydantic.taskinstance import TaskInstancePydantic +from airflow.models.taskinstance import TaskInstance +from airflow.operators.empty import EmptyOperator from airflow.serialization.serialized_objects 
import BaseSerialization +from airflow.utils.state import State from airflow.www import app from tests.test_utils.config import conf_vars from tests.test_utils.decorators import dont_initialize_flask_app_submodules @@ -110,6 +114,20 @@ def test_method(self, input_data, method_result, method_params, expected_mock, e expected_mock.assert_called_once_with(**method_params) + def test_method_with_pydantic_serialized_object(self): + ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING) + mock_test_method.return_value = ti + + response = self.client.post( + "/internal_api/v1/rpcapi", + headers={"Content-Type": "application/json"}, + data=json.dumps({"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}), + ) + assert response.status_code == 200 + response_data = BaseSerialization.deserialize(json.loads(response.data), use_pydantic_models=True) + expected_data = TaskInstancePydantic.from_orm(ti) + assert response_data == expected_data + def test_method_with_exception(self): mock_test_method.side_effect = ValueError("Error!!!") data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""} diff --git a/tests/api_internal/test_internal_api_call.py b/tests/api_internal/test_internal_api_call.py index c96b2bde32a11..e7cd488e66ef5 100644 --- a/tests/api_internal/test_internal_api_call.py +++ b/tests/api_internal/test_internal_api_call.py @@ -25,7 +25,11 @@ import requests from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call +from airflow.models.pydantic.taskinstance import TaskInstancePydantic +from airflow.models.taskinstance import TaskInstance +from airflow.operators.empty import EmptyOperator from airflow.serialization.serialized_objects import BaseSerialization +from airflow.utils.state import State from tests.test_utils.config import conf_vars @@ -81,6 +85,14 @@ def fake_method_with_params(dag_id: str, task_id: int, session) -> str: def
fake_class_method_with_params(cls, dag_id: str, session) -> str: return f"local-classmethod-call-with-params-{dag_id}" + @staticmethod + @internal_api_call + def fake_class_method_with_serialized_params( + ti: TaskInstance | TaskInstancePydantic, + session, + ) -> str: + return f"local-classmethod-call-with-serialized-{ti.task_id}" + @conf_vars( { ("core", "database_access_isolation"): "false", @@ -200,3 +212,36 @@ def test_remote_classmethod_call_with_params(self, mock_requests): data=expected_data, headers={"Content-Type": "application/json"}, ) + + @conf_vars( + { + ("core", "database_access_isolation"): "true", + ("core", "internal_api_url"): "http://localhost:8888", + } + ) + @mock.patch("airflow.api_internal.internal_api_call.requests") + def test_remote_call_with_serialized_model(self, mock_requests): + response = requests.Response() + response.status_code = 200 + + response._content = json.dumps(BaseSerialization.serialize("remote-call")) + + mock_requests.post.return_value = response + ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING) + + result = TestInternalApiCall.fake_class_method_with_serialized_params(ti, session="session") + + assert result == "remote-call" + expected_data = json.dumps( + { + "jsonrpc": "2.0", + "method": "tests.api_internal.test_internal_api_call.TestInternalApiCall." 
+ "fake_class_method_with_serialized_params", + "params": json.dumps(BaseSerialization.serialize({"ti": ti}, use_pydantic_models=True)), + } + ) + mock_requests.post.assert_called_once_with( + url="http://localhost:8888/internal_api/v1/rpcapi", + data=expected_data, + headers={"Content-Type": "application/json"}, + ) diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 3298fb6cbaa30..7ccf16bb6ab65 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -20,6 +20,10 @@ import pytest from airflow.exceptions import SerializationError +from airflow.models.pydantic.taskinstance import TaskInstancePydantic +from airflow.models.taskinstance import TaskInstance +from airflow.operators.empty import EmptyOperator +from airflow.utils.state import State from tests import REPO_ROOT @@ -76,3 +80,17 @@ class Test: BaseSerialization.serialize(obj) # does not raise with pytest.raises(SerializationError, match="Encountered unexpected type"): BaseSerialization.serialize(obj, strict=True) # now raises + + +def test_use_pydantic_models(): + """If use_pydantic_models=True the TaskInstance object should be serialized to TaskInstancePydantic.""" + + from airflow.serialization.serialized_objects import BaseSerialization + + ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING) + obj = [[ti]] # nested to verify recursive behavior + + serialized = BaseSerialization.serialize(obj, use_pydantic_models=True) # does not raise + deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True) # does not raise + + assert isinstance(deserialized[0][0], TaskInstancePydantic)