Skip to content

Commit

Permalink
Format using black
Browse files: browse the repository at this point in the history
  • Loading branch information
cmarteepants committed Jul 12, 2024
1 parent 5705ee4 commit 2952bc2
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 89 deletions.
1 change: 1 addition & 0 deletions dagfactory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
"""Modules and methods to export for easier access"""

from .dagfactory import DagFactory, load_yaml_dags
1 change: 1 addition & 0 deletions dagfactory/__version__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
"""Module contains the version of dag-factory"""

__version__ = "0.19.0"
55 changes: 24 additions & 31 deletions dagfactory/dagbuilder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module contains code for generating tasks and constructing a DAG"""

# pylint: disable=ungrouped-imports
import os
import re
Expand Down Expand Up @@ -197,34 +198,26 @@ def get_dag_params(self) -> Dict[str, Any]:

if utils.check_dict_key(dag_params["default_args"], "sla_miss_callback"):
if isinstance(dag_params["default_args"]["sla_miss_callback"], str):
dag_params["default_args"][
"sla_miss_callback"
]: Callable = import_string(
dag_params["default_args"]["sla_miss_callback"]
dag_params["default_args"]["sla_miss_callback"]: Callable = (
import_string(dag_params["default_args"]["sla_miss_callback"])
)

if utils.check_dict_key(dag_params["default_args"], "on_success_callback"):
if isinstance(dag_params["default_args"]["on_success_callback"], str):
dag_params["default_args"][
"on_success_callback"
]: Callable = import_string(
dag_params["default_args"]["on_success_callback"]
dag_params["default_args"]["on_success_callback"]: Callable = (
import_string(dag_params["default_args"]["on_success_callback"])
)

if utils.check_dict_key(dag_params["default_args"], "on_failure_callback"):
if isinstance(dag_params["default_args"]["on_failure_callback"], str):
dag_params["default_args"][
"on_failure_callback"
]: Callable = import_string(
dag_params["default_args"]["on_failure_callback"]
dag_params["default_args"]["on_failure_callback"]: Callable = (
import_string(dag_params["default_args"]["on_failure_callback"])
)

if utils.check_dict_key(dag_params["default_args"], "on_retry_callback"):
if isinstance(dag_params["default_args"]["on_retry_callback"], str):
dag_params["default_args"][
"on_retry_callback"
]: Callable = import_string(
dag_params["default_args"]["on_retry_callback"]
dag_params["default_args"]["on_retry_callback"]: Callable = (
import_string(dag_params["default_args"]["on_retry_callback"])
)

if utils.check_dict_key(dag_params, "sla_miss_callback"):
Expand Down Expand Up @@ -351,11 +344,11 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
" python_callable_file: !!python/name:my_module.my_func"
)
if not task_params.get("python_callable"):
task_params[
"python_callable"
]: Callable = utils.get_python_callable(
task_params["python_callable_name"],
task_params["python_callable_file"],
task_params["python_callable"]: Callable = (
utils.get_python_callable(
task_params["python_callable_name"],
task_params["python_callable_file"],
)
)
# remove dag-factory specific parameters
# Airflow 2.0 doesn't allow these to be passed to operator
Expand Down Expand Up @@ -419,10 +412,10 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
del task_params["response_check_name"]
del task_params["response_check_file"]
else:
task_params[
"response_check"
]: Callable = utils.get_python_callable_lambda(
task_params["response_check_lambda"]
task_params["response_check"]: Callable = (
utils.get_python_callable_lambda(
task_params["response_check_lambda"]
)
)
# remove dag-factory specific parameters
# Airflow 2.0 doesn't allow these to be passed to operator
Expand Down Expand Up @@ -669,18 +662,18 @@ def set_dependencies(
group_id = conf["task_group"].group_id
name = f"{group_id}.{name}"
if conf.get("dependencies"):
source: Union[
BaseOperator, "TaskGroup"
] = tasks_and_task_groups_instances[name]
source: Union[BaseOperator, "TaskGroup"] = (
tasks_and_task_groups_instances[name]
)
for dep in conf["dependencies"]:
if tasks_and_task_groups_config[dep].get("task_group"):
group_id = tasks_and_task_groups_config[dep][
"task_group"
].group_id
dep = f"{group_id}.{dep}"
dep: Union[
BaseOperator, "TaskGroup"
] = tasks_and_task_groups_instances[dep]
dep: Union[BaseOperator, "TaskGroup"] = (
tasks_and_task_groups_instances[dep]
)
source.set_upstream(dep)

@staticmethod
Expand Down
1 change: 1 addition & 0 deletions dagfactory/dagfactory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module contains code for loading a DagFactory config and generating DAGs"""

import logging
import os
from itertools import chain
Expand Down
5 changes: 2 additions & 3 deletions dagfactory/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module contains various utilities used by dag-factory"""

import ast
import importlib.util
import os
Expand Down Expand Up @@ -212,9 +213,7 @@ def check_template_searchpath(template_searchpath: Union[str, List[str]]) -> boo
return False


def get_expand_partial_kwargs(
task_params: Dict[str, Any]
) -> Tuple[
def get_expand_partial_kwargs(task_params: Dict[str, Any]) -> Tuple[
Dict[str, Any],
Dict[str, Union[Dict[str, Any], Any]],
Dict[str, Union[Dict[str, Any], Any]],
Expand Down
2 changes: 1 addition & 1 deletion examples/expand_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ def example_task_mapping():
def expand_task(x, test_id):
print(test_id)
print(x)
return [x]
return [x]
92 changes: 53 additions & 39 deletions tests/test_dagbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,20 +142,14 @@
"request": {
"operator": "airflow.operators.python_operator.PythonOperator",
"python_callable_name": "example_task_mapping",
"python_callable_file": os.path.realpath(__file__)
"python_callable_file": os.path.realpath(__file__),
},
"process_1": {
"operator": "airflow.operators.python_operator.PythonOperator",
"python_callable_name": "expand_task",
"python_callable_file": os.path.realpath(__file__),
"partial": {
"op_kwargs": {
"test_id": "test"
}
},
"expand": {
"op_args": "request.output"
}
"partial": {"op_kwargs": {"test_id": "test"}},
"expand": {"op_args": "request.output"},
},
},
}
Expand Down Expand Up @@ -617,10 +611,7 @@ def test_make_timetable():
if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG)
timetable = "airflow.timetables.interval.CronDataIntervalTimetable"
timetable_params = {
"cron": "0 8,16 * * 1-5",
"timezone": "UTC"
}
timetable_params = {"cron": "0 8,16 * * 1-5", "timezone": "UTC"}
actual = td.make_timetable(timetable, timetable_params)
assert actual.periodic
assert actual.can_run
Expand All @@ -633,22 +624,31 @@ def test_make_dag_with_callback():

def test_get_dag_params_with_template_searchpath():
from dagfactory import utils
td = dagbuilder.DagBuilder("test_dag", {"template_searchpath": ["./sql"]}, DEFAULT_CONFIG)

td = dagbuilder.DagBuilder(
"test_dag", {"template_searchpath": ["./sql"]}, DEFAULT_CONFIG
)
error_message = "template_searchpath must be absolute paths"
with pytest.raises(Exception, match=error_message):
td.get_dag_params()

td = dagbuilder.DagBuilder("test_dag", {"template_searchpath": ["/sql"]}, DEFAULT_CONFIG)
td = dagbuilder.DagBuilder(
"test_dag", {"template_searchpath": ["/sql"]}, DEFAULT_CONFIG
)
error_message = "template_searchpath must be existing paths"
with pytest.raises(Exception, match=error_message):
td.get_dag_params()

td = dagbuilder.DagBuilder("test_dag", {"template_searchpath": "./sql"}, DEFAULT_CONFIG)

td = dagbuilder.DagBuilder(
"test_dag", {"template_searchpath": "./sql"}, DEFAULT_CONFIG
)
error_message = "template_searchpath must be absolute paths"
with pytest.raises(Exception, match=error_message):
td.get_dag_params()

td = dagbuilder.DagBuilder("test_dag", {"template_searchpath": "/sql"}, DEFAULT_CONFIG)
td = dagbuilder.DagBuilder(
"test_dag", {"template_searchpath": "/sql"}, DEFAULT_CONFIG
)
error_message = "template_searchpath must be existing paths"
with pytest.raises(Exception, match=error_message):
td.get_dag_params()
Expand All @@ -659,31 +659,40 @@ def test_get_dag_params_with_template_searchpath():


def test_get_dag_params_with_render_template_as_native_obj():
td = dagbuilder.DagBuilder("test_dag", {"render_template_as_native_obj": "true"}, DEFAULT_CONFIG)
td = dagbuilder.DagBuilder(
"test_dag", {"render_template_as_native_obj": "true"}, DEFAULT_CONFIG
)
error_message = "render_template_as_native_obj should be bool type!"
with pytest.raises(Exception, match=error_message):
td.get_dag_params()

false = lambda x: print(x)
td = dagbuilder.DagBuilder("test_dag", {"render_template_as_native_obj": false}, DEFAULT_CONFIG)
td = dagbuilder.DagBuilder(
"test_dag", {"render_template_as_native_obj": false}, DEFAULT_CONFIG
)
error_message = "render_template_as_native_obj should be bool type!"
with pytest.raises(Exception, match=error_message):
td.get_dag_params()


def test_make_task_with_duplicated_partial_kwargs():
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG)
td = dagbuilder.DagBuilder(
"test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG
)
operator = "airflow.operators.bash_operator.BashOperator"
task_params = {"task_id": "task_bash",
"bash_command": "echo 2",
"partial": {"bash_command": "echo 4"}
}
task_params = {
"task_id": "task_bash",
"bash_command": "echo 2",
"partial": {"bash_command": "echo 4"},
}
with pytest.raises(Exception):
td.make_task(operator, task_params)


def test_dynamic_task_mapping():
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG)
td = dagbuilder.DagBuilder(
"test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG
)
if version.parse(AIRFLOW_VERSION) < version.parse("2.3.0"):
error_message = "Dynamic task mapping available only in Airflow >= 2.3.0"
with pytest.raises(Exception, match=error_message):
Expand All @@ -694,31 +703,36 @@ def test_dynamic_task_mapping():
"task_id": "process",
"python_callable_name": "expand_task",
"python_callable_file": os.path.realpath(__file__),
"partial": {
"op_kwargs": {
"test_id": "test"
}
},
"expand": {
"op_args": "request.output"
}
"partial": {"op_kwargs": {"test_id": "test"}},
"expand": {"op_args": "request.output"},
}
actual = td.make_task(operator, task_params)
assert isinstance(actual, MappedOperator)


@patch("dagfactory.dagbuilder.PythonOperator", new=MockPythonOperator)
def test_replace_expand_string_with_xcom():
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG)
td = dagbuilder.DagBuilder(
"test_dag", DAG_CONFIG_DYNAMIC_TASK_MAPPING, DEFAULT_CONFIG
)
if version.parse(AIRFLOW_VERSION) < version.parse("2.3.0"):
with pytest.raises(Exception):
td.build()
else:
from airflow.models.xcom_arg import XComArg

task_conf_output = {"expand": {"key_1": "task_1.output"}}
task_conf_xcomarg = {"expand": {"key_1": "XcomArg(task_1)"}}
tasks_dict = {"task_1": MockPythonOperator()}
updated_task_conf_output = dagbuilder.DagBuilder.replace_expand_values(task_conf_output, tasks_dict)
updated_task_conf_xcomarg = dagbuilder.DagBuilder.replace_expand_values(task_conf_xcomarg, tasks_dict)
assert updated_task_conf_output["expand"]["key_1"] == XComArg(tasks_dict["task_1"])
assert updated_task_conf_xcomarg["expand"]["key_1"] == XComArg(tasks_dict["task_1"])
updated_task_conf_output = dagbuilder.DagBuilder.replace_expand_values(
task_conf_output, tasks_dict
)
updated_task_conf_xcomarg = dagbuilder.DagBuilder.replace_expand_values(
task_conf_xcomarg, tasks_dict
)
assert updated_task_conf_output["expand"]["key_1"] == XComArg(
tasks_dict["task_1"]
)
assert updated_task_conf_xcomarg["expand"]["key_1"] == XComArg(
tasks_dict["task_1"]
)
8 changes: 4 additions & 4 deletions tests/test_dagfactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def test_load_config_valid():
},
},
"example_dag4": {
"vars": {'arg1': 'hello', 'arg2': 'hello world'},
"vars": {"arg1": "hello", "arg2": "hello world"},
"tasks": {
"task_1": {
"operator": "airflow.operators.bash_operator.BashOperator",
Expand Down Expand Up @@ -262,7 +262,7 @@ def test_get_dag_configs():
},
},
"example_dag4": {
"vars": {'arg1': 'hello', 'arg2': 'hello world'},
"vars": {"arg1": "hello", "arg2": "hello world"},
"tasks": {
"task_1": {
"operator": "airflow.operators.bash_operator.BashOperator",
Expand Down Expand Up @@ -434,15 +434,15 @@ def test_set_callback_after_loading_config():
def test_load_yaml_dags_fail():
with pytest.raises(Exception):
load_yaml_dags(
globals_dict= globals(),
globals_dict=globals(),
dags_folder="tests/fixtures",
suffix=["invalid_yaml.yml"],
)


def test_load_yaml_dags_succeed():
load_yaml_dags(
globals_dict= globals(),
globals_dict=globals(),
dags_folder="tests/fixtures",
suffix=["dag_factory_variables_as_arguments.yml"],
)
Loading

0 comments on commit 2952bc2

Please sign in to comment.