Skip to content

Commit

Permalink
Add support for TestBehavior.BUILD
Browse files Browse the repository at this point in the history
Closes: #892
  • Loading branch information
tatiana committed Dec 12, 2024
1 parent 2fa5d01 commit df279d7
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 12 deletions.
37 changes: 25 additions & 12 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from cosmos.constants import (
DBT_COMPILE_TASK_ID,
DEFAULT_DBT_RESOURCES,
SUPPORTED_BUILD_RESOURCES,
TESTABLE_DBT_RESOURCES,
DbtResourceType,
ExecutionMode,
Expand Down Expand Up @@ -135,6 +136,7 @@ def create_task_metadata(
dbt_dag_task_group_identifier: str,
use_task_group: bool = False,
source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE,
test_behavior: TestBehavior = TestBehavior.AFTER_ALL,
) -> TaskMetadata | None:
"""
Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.
Expand All @@ -148,33 +150,43 @@ def create_task_metadata(
If it is False, then use the name as a prefix for the task id, otherwise do not.
:returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
"""
dbt_resource_to_class = {
DbtResourceType.MODEL: "DbtRun",
DbtResourceType.SNAPSHOT: "DbtSnapshot",
DbtResourceType.SEED: "DbtSeed",
DbtResourceType.TEST: "DbtTest",
DbtResourceType.SOURCE: "DbtSource",
}
if test_behavior == TestBehavior.BUILD:
dbt_resource_to_class = {
DbtResourceType.MODEL: "DbtBuild",
DbtResourceType.SNAPSHOT: "DbtBuild",
DbtResourceType.SEED: "DbtBuild",
DbtResourceType.TEST: "DbtTest",
DbtResourceType.SOURCE: "DbtSource",
}
else:
dbt_resource_to_class = {
DbtResourceType.MODEL: "DbtRun",
DbtResourceType.SNAPSHOT: "DbtSnapshot",
DbtResourceType.SEED: "DbtSeed",
DbtResourceType.TEST: "DbtTest",
DbtResourceType.SOURCE: "DbtSource",
}
args = {**args, **{"models": node.resource_name}}

if DbtResourceType(node.resource_type) in DEFAULT_DBT_RESOURCES and node.resource_type in dbt_resource_to_class:
extra_context = {
"dbt_node_config": node.context_dict,
"dbt_dag_task_group_identifier": dbt_dag_task_group_identifier,
}
if node.resource_type == DbtResourceType.MODEL:
task_id = f"{node.name}_run"
if use_task_group is True:
if test_behavior == TestBehavior.BUILD and node.resource_type in SUPPORTED_BUILD_RESOURCES:
task_id = f"{node.name}_{node.resource_type.value}_build"
elif node.resource_type == DbtResourceType.MODEL:
if use_task_group:
task_id = "run"
else:
task_id = f"{node.name}_run"
elif node.resource_type == DbtResourceType.SOURCE:
if (source_rendering_behavior == SourceRenderingBehavior.NONE) or (
source_rendering_behavior == SourceRenderingBehavior.WITH_TESTS_OR_FRESHNESS
and node.has_freshness is False
and node.has_test is False
):
return None
# TODO: https://github.com/astronomer/astronomer-cosmos
# pragma: no cover
task_id = f"{node.name}_source"
args["select"] = f"source:{node.resource_name}"
args.pop("models")
Expand Down Expand Up @@ -234,6 +246,7 @@ def generate_task_or_group(
dbt_dag_task_group_identifier=_get_dbt_dag_task_group_identifier(dag, task_group),
use_task_group=use_task_group,
source_rendering_behavior=source_rendering_behavior,
test_behavior=test_behavior,
)

# In most cases, we'll map one DBT node to one Airflow task
Expand Down
9 changes: 9 additions & 0 deletions cosmos/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class TestBehavior(Enum):
Behavior of the tests.
"""

BUILD = "build"
NONE = "none"
AFTER_EACH = "after_each"
AFTER_ALL = "after_all"
Expand Down Expand Up @@ -144,6 +145,14 @@ def _missing_value_(cls, value): # type: ignore

DEFAULT_DBT_RESOURCES = DbtResourceType.__members__.values()

# According to the dbt documentation (https://docs.getdbt.com/reference/commands/build), build also supports test nodes.
# However, in the context of Cosmos, we will run test nodes together with the respective models/seeds/snapshots nodes
# (via `dbt build` on the parent resource), which is why DbtResourceType.TEST is deliberately absent from this list.
SUPPORTED_BUILD_RESOURCES = [
    DbtResourceType.MODEL,
    DbtResourceType.SNAPSHOT,
    DbtResourceType.SEED,
]

# dbt test runs tests defined on models, sources, snapshots, and seeds.
# It expects that you have already created those resources through the appropriate commands.
# https://docs.getdbt.com/reference/commands/test
Expand Down
47 changes: 47 additions & 0 deletions dev/dags/example_cosmos_dbt_build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
An example Airflow DAG that illustrates using the dbt build to run both models/seeds/sources and their respective tests.
"""

import os
from datetime import datetime
from pathlib import Path

from cosmos import DbtDag, ProfileConfig, ProjectConfig, RenderConfig
from cosmos.constants import TestBehavior
from cosmos.profiles import PostgresUserPasswordProfileMapping

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))

profile_config = ProfileConfig(
profile_name="default",
target_name="dev",
profile_mapping=PostgresUserPasswordProfileMapping(
conn_id="example_conn",
profile_args={"schema": "public"},
disable_event_tracking=True,
),
)

# [START build_example]
example_cosmos_dbt_build = DbtDag(
# dbt/cosmos-specific parameters
project_config=ProjectConfig(
DBT_ROOT_PATH / "jaffle_shop",
),
render_config=RenderConfig(
test_behavior=TestBehavior.BUILD,
),
profile_config=profile_config,
operator_args={
"install_deps": True, # install any necessary dependencies before running any dbt command
"full_refresh": True, # used only in dbt commands that support this flag
},
# normal dag parameters
schedule_interval="@daily",
start_date=datetime(2023, 1, 1),
catchup=False,
dag_id="example_cosmos_dbt_build",
default_args={"retries": 2},
)
# [END build_example]
45 changes: 45 additions & 0 deletions tests/airflow/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,51 @@ def test_build_airflow_graph_with_after_all():
assert dag.leaves[0].select == ["tag:some"]


@pytest.mark.skipif(
    version.parse(airflow_version) < version.parse("2.4"),
    reason="Airflow DAG did not have task_group_dict until the 2.4 release",
)
@pytest.mark.integration
def test_build_airflow_graph_with_build():
    """With TestBehavior.BUILD, each model/seed renders a single ``dbt build``
    task (suffixed ``_build``) and no per-resource test task groups are created."""
    with DAG("test-id", start_date=datetime(2022, 1, 1)) as dag:
        task_args = {
            "project_dir": SAMPLE_PROJ_PATH,
            "conn_id": "fake_conn",
            "profile_config": ProfileConfig(
                profile_name="default",
                target_name="default",
                profile_mapping=PostgresUserPasswordProfileMapping(
                    conn_id="fake_conn",
                    profile_args={"schema": "public"},
                ),
            ),
        }
        render_config = RenderConfig(
            select=["tag:some"],
            test_behavior=TestBehavior.BUILD,
            source_rendering_behavior=SOURCE_RENDERING_BEHAVIOR,
        )
        build_airflow_graph(
            nodes=sample_nodes,
            dag=dag,
            execution_mode=ExecutionMode.LOCAL,
            test_indirect_selection=TestIndirectSelection.EAGER,
            task_args=task_args,
            dbt_project_name="astro_shop",
            render_config=render_config,
        )
    topological_sort = [task.task_id for task in dag.topological_sort()]
    expected_sort = ["seed_parent_seed_build", "parent_model_build", "child_model_build", "child2_v2_model_build"]
    assert topological_sort == expected_sort

    # BUILD runs tests inside each build task, so no test task groups are expected.
    task_groups = dag.task_group_dict
    assert len(task_groups) == 0

    # Both children of `parent` are leaves; compare as a set since Airflow does not
    # guarantee leaf ordering. (The previous assertions compared dag.leaves[0] against
    # two different task ids — which could never both hold — and referenced task ids
    # ("orders"/"customers") that do not appear in expected_sort at all.)
    assert len(dag.leaves) == 2
    leaf_task_ids = {leaf.task_id for leaf in dag.leaves}
    assert leaf_task_ids == {"child_model_build", "child2_v2_model_build"}


@pytest.mark.integration
@patch("airflow.hooks.base.BaseHook.get_connection", new=MagicMock())
def test_build_airflow_graph_with_dbt_compile_task():
Expand Down

0 comments on commit df279d7

Please sign in to comment.