Commit

Added BaseJobPydantic support and more tests
mhenc committed Apr 5, 2023
1 parent f97a02a commit f872dc5
Showing 6 changed files with 97 additions and 5 deletions.
2 changes: 1 addition & 1 deletion airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -92,7 +92,7 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
     log.debug("Calling method %s.", method_name)
     try:
         output = handler(**params)
-        output_json = BaseSerialization.serialize(output)
+        output_json = BaseSerialization.serialize(output, use_pydantic_models=True)
         log.debug("Returning response")
         return Response(
             response=json.dumps(output_json or "{}"), headers={"Content-Type": "application/json"}
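Note: since the endpoint now serializes handler output with use_pydantic_models=True, a caller of /internal_api/v1/rpcapi has to pass the same flag when unpacking the JSON body, otherwise deserialize falls through to its TypeError branch for the pydantic-tagged types. A minimal client-side sketch (`response` is a placeholder for the HTTP response object; the same pattern appears in the new test in test_rpc_api_endpoint.py further down):

import json

from airflow.serialization.serialized_objects import BaseSerialization

# `response` stands in for an HTTP response returned by the endpoint above.
result = BaseSerialization.deserialize(json.loads(response.data), use_pydantic_models=True)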
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
@@ -51,6 +51,7 @@ class DagAttributeTypes(str, Enum):
     XCOM_REF = "xcomref"
     DATASET = "dataset"
     SIMPLE_TASK_INSTANCE = "simple_task_instance"
+    BASE_JOB = "base_job"
     TASK_INSTANCE = "task_instance"
     DAG_RUN = "dag_run"
     DATA_SET = "data_set"
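For reference, BaseSerialization wraps every tagged value in a small envelope keyed by Encoding.TYPE ("__type") and Encoding.VAR ("__var"), so a job serialized with the new tag would be expected to look roughly like the sketch below. The fields shown inside __var are illustrative only; the real payload is whatever BaseJobPydantic.from_orm(job).dict() returns.

# Hypothetical shape of a serialized job using the new DAT.BASE_JOB tag.
serialized_job = {
    "__type": "base_job",  # DagAttributeTypes.BASE_JOB.value
    "__var": {"id": 123, "job_type": "SchedulerJob", "state": "running"},  # illustrative fields
}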
17 changes: 13 additions & 4 deletions airflow/serialization/serialized_objects.py
@@ -38,6 +38,8 @@
 from airflow.configuration import conf
 from airflow.datasets import Dataset
 from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError
+from airflow.jobs.base_job import BaseJob
+from airflow.jobs.pydantic.base_job import BaseJobPydantic
 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
 from airflow.models.connection import Connection
 from airflow.models.dag import DAG, create_timetable
@@ -467,7 +469,12 @@ def serialize(
         elif isinstance(var, Dataset):
             return cls._encode(dict(uri=var.uri, extra=var.extra), type_=DAT.DATASET)
         elif isinstance(var, SimpleTaskInstance):
-            return cls._encode(cls.serialize(var.__dict__, strict=strict), type_=DAT.SIMPLE_TASK_INSTANCE)
+            return cls._encode(
+                cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
+                type_=DAT.SIMPLE_TASK_INSTANCE,
+            )
+        elif use_pydantic_models and isinstance(var, BaseJob):
+            return cls._encode(BaseJobPydantic.from_orm(var).dict(), type_=DAT.BASE_JOB)
         elif use_pydantic_models and isinstance(var, TaskInstance):
             return cls._encode(TaskInstancePydantic.from_orm(var).dict(), type_=DAT.TASK_INSTANCE)
         elif use_pydantic_models and isinstance(var, DagRun):
@@ -528,11 +535,13 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
             return Dataset(**var)
         elif type_ == DAT.SIMPLE_TASK_INSTANCE:
             return SimpleTaskInstance(**cls.deserialize(var))
-        elif use_pydantic_models and type == DAT.TASK_INSTANCE:
+        elif use_pydantic_models and type_ == DAT.BASE_JOB:
+            return BaseJobPydantic.parse_obj(var)
+        elif use_pydantic_models and type_ == DAT.TASK_INSTANCE:
             return TaskInstancePydantic.parse_obj(var)
-        elif use_pydantic_models and type == DAT.DAG_RUN:
+        elif use_pydantic_models and type_ == DAT.DAG_RUN:
             return DagRunPydantic.parse_obj(var)
-        elif use_pydantic_models and type == DAT.DATA_SET:
+        elif use_pydantic_models and type_ == DAT.DATA_SET:
             return DatasetPydantic.parse_obj(var)
         else:
             raise TypeError(f"Invalid type {type_!s} in deserialization.")
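Taken together, the serializer changes mean a BaseJob travelling through the internal API is converted to its Pydantic counterpart on the way out and parsed back into BaseJobPydantic on the way in. A minimal round-trip sketch, not part of the commit: it assumes a BaseJob can be instantiated standalone and that BaseJobPydantic validates its default fields; in real code the job would come from the ORM via a session.

from airflow.jobs.base_job import BaseJob
from airflow.jobs.pydantic.base_job import BaseJobPydantic
from airflow.serialization.serialized_objects import BaseSerialization

job = BaseJob()  # hypothetical standalone instance, for illustration only
encoded = BaseSerialization.serialize(job, use_pydantic_models=True)   # DAT.BASE_JOB envelope
decoded = BaseSerialization.deserialize(encoded, use_pydantic_models=True)
assert isinstance(decoded, BaseJobPydantic)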
19 changes: 19 additions & 0 deletions tests/api_internal/endpoints/test_rpc_api_endpoint.py
@@ -23,7 +23,11 @@
 import pytest
 from flask import Flask
 
+from airflow.models.pydantic.taskinstance import TaskInstancePydantic
+from airflow.models.taskinstance import TaskInstance
+from airflow.operators.empty import EmptyOperator
 from airflow.serialization.serialized_objects import BaseSerialization
+from airflow.utils.state import State
 from airflow.www import app
 from tests.test_utils.config import conf_vars
 from tests.test_utils.decorators import dont_initialize_flask_app_submodules
@@ -110,6 +114,21 @@ def test_method(self, input_data, method_result, method_params, expected_mock, e

         expected_mock.assert_called_once_with(**method_params)
 
+    def test_method_with_pydantic_serialized_object(self):
+        ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING)
+        mock_test_method.return_value = ti
+
+        response = self.client.post(
+            "/internal_api/v1/rpcapi",
+            headers={"Content-Type": "application/json"},
+            data=json.dumps({"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}),
+        )
+        assert response.status_code == 200
+        print(response.data)
+        response_data = BaseSerialization.deserialize(json.loads(response.data), use_pydantic_models=True)
+        expected_data = TaskInstancePydantic.from_orm(ti)
+        assert response_data == expected_data
+
     def test_method_with_exception(self):
         mock_test_method.side_effect = ValueError("Error!!!")
         data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}
45 changes: 45 additions & 0 deletions tests/api_internal/test_internal_api_call.py
@@ -25,7 +25,11 @@
 import requests
 
 from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call
+from airflow.models.pydantic.taskinstance import TaskInstancePydantic
+from airflow.models.taskinstance import TaskInstance
+from airflow.operators.empty import EmptyOperator
 from airflow.serialization.serialized_objects import BaseSerialization
+from airflow.utils.state import State
 from tests.test_utils.config import conf_vars
 
 
@@ -81,6 +85,14 @@ def fake_method_with_params(dag_id: str, task_id: int, session) -> str:
     def fake_class_method_with_params(cls, dag_id: str, session) -> str:
         return f"local-classmethod-call-with-params-{dag_id}"
 
+    @staticmethod
+    @internal_api_call
+    def fake_class_method_with_serialized_params(
+        ti: TaskInstance | TaskInstancePydantic,
+        session,
+    ) -> str:
+        return f"local-classmethod-call-with-serialized-{ti.task_id}"
+
     @conf_vars(
         {
             ("core", "database_access_isolation"): "false",
@@ -200,3 +212,36 @@ def test_remote_classmethod_call_with_params(self, mock_requests):
             data=expected_data,
             headers={"Content-Type": "application/json"},
         )
+
+    @conf_vars(
+        {
+            ("core", "database_access_isolation"): "true",
+            ("core", "internal_api_url"): "http://localhost:8888",
+        }
+    )
+    @mock.patch("airflow.api_internal.internal_api_call.requests")
+    def test_remote_call_with_serialized_model(self, mock_requests):
+        response = requests.Response()
+        response.status_code = 200
+
+        response._content = json.dumps(BaseSerialization.serialize("remote-call"))
+
+        mock_requests.post.return_value = response
+        ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING)
+
+        result = TestInternalApiCall.fake_class_method_with_serialized_params(ti, session="session")
+
+        assert result == "remote-call"
+        expected_data = json.dumps(
+            {
+                "jsonrpc": "2.0",
+                "method": "tests.api_internal.test_internal_api_call.TestInternalApiCall."
+                "fake_class_method_with_serialized_params",
+                "params": json.dumps(BaseSerialization.serialize({"ti": ti}, use_pydantic_models=True)),
+            }
+        )
+        mock_requests.post.assert_called_once_with(
+            url="http://localhost:8888/internal_api/v1/rpcapi",
+            data=expected_data,
+            headers={"Content-Type": "application/json"},
+        )
18 changes: 18 additions & 0 deletions tests/serialization/test_serialized_objects.py
@@ -20,6 +20,10 @@
 import pytest
 
 from airflow.exceptions import SerializationError
+from airflow.models.pydantic.taskinstance import TaskInstancePydantic
+from airflow.models.taskinstance import TaskInstance
+from airflow.operators.empty import EmptyOperator
+from airflow.utils.state import State
 from tests import REPO_ROOT
 
 
@@ -76,3 +80,17 @@ class Test:
     BaseSerialization.serialize(obj)  # does not raise
     with pytest.raises(SerializationError, match="Encountered unexpected type"):
         BaseSerialization.serialize(obj, strict=True)  # now raises
+
+
+def test_use_pydantic_models():
+    """If use_pydantic_models=True the TaskInstance object should be serialized to TaskInstancePydantic."""
+
+    from airflow.serialization.serialized_objects import BaseSerialization
+
+    ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING)
+    obj = [[ti]]  # nested to verify recursive behavior
+
+    serialized = BaseSerialization.serialize(obj, use_pydantic_models=True)  # does not raise
+    deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True)  # does not raise
+
+    assert isinstance(deserialized[0][0], TaskInstancePydantic)
