diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 4c9d6c809..8044eade4 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,7 +1,7 @@
 Changelog
 =========
 
-1.9.0a4 (2025-01-29)
+1.9.0a5 (2025-02-03)
 --------------------
 
 Breaking changes
@@ -18,6 +18,7 @@ Features
 * Allow users to opt-out of ``dbtRunner`` during DAG parsing with ``InvocationMode.SUBPROCESS`` by @tatiana in #1495. Check out the `documentation `_.
 * Add structure to support multiple db for async operator execution by @pankajastro in #1483
 * Support overriding the ``profile_config`` per dbt node or folder using config by @tatiana in #1492. More information `here `_.
+* Create and run accurate SQL statements when using ``ExecutionMode.AIRFLOW_ASYNC`` by @pankajkoti, @tatiana and @pankajastro in #1474
 
 Bug Fixes
@@ -27,9 +28,12 @@ Enhancement
 * Fix OpenLineage deprecation warning by @CorsettiS in #1449
 * Move ``DbtRunner`` related functions into ``dbt/runner.py`` module by @tatiana in #1480
+* Add ``on_warning_callback`` to ``DbtSourceKubernetesOperator`` and refactor previous operators by @LuigiCerone in #1501
+
 Others
 
+* Ignore dbt package tests when running Cosmos tests by @tatiana in #1502
 * GitHub Actions Dependabot: #1487
 * Pre-commit updates: #1473, #1493
 
diff --git a/cosmos/__init__.py b/cosmos/__init__.py
index e245fb7e6..7374e9db6 100644
--- a/cosmos/__init__.py
+++ b/cosmos/__init__.py
@@ -6,7 +6,7 @@
 Contains dags, task groups, and operators.
 """
 
-__version__ = "1.9.0a4"
+__version__ = "1.9.0a5"
 
 from cosmos.airflow.dag import DbtDag
 
diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py
index ef742f8eb..2c65361c5 100644
--- a/cosmos/airflow/graph.py
+++ b/cosmos/airflow/graph.py
@@ -11,7 +11,6 @@
 from cosmos.config import RenderConfig
 from cosmos.constants import (
-    DBT_COMPILE_TASK_ID,
     DEFAULT_DBT_RESOURCES,
     SUPPORTED_BUILD_RESOURCES,
     TESTABLE_DBT_RESOURCES,
@@ -392,32 +391,6 @@ def generate_task_or_group(
     return task_or_group
 
 
-def _add_dbt_compile_task(
-    nodes: dict[str, DbtNode],
-    dag: DAG,
-    execution_mode: ExecutionMode,
-    task_args: dict[str, Any],
-    tasks_map: dict[str, Any],
-    task_group: TaskGroup | None,
-) -> None:
-    if execution_mode != ExecutionMode.AIRFLOW_ASYNC:
-        return
-
-    compile_task_metadata = TaskMetadata(
-        id=DBT_COMPILE_TASK_ID,
-        operator_class="cosmos.operators.airflow_async.DbtCompileAirflowAsyncOperator",
-        arguments=task_args,
-        extra_context={"dbt_dag_task_group_identifier": _get_dbt_dag_task_group_identifier(dag, task_group)},
-    )
-    compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=task_group)
-
-    for task_id, task in tasks_map.items():
-        if not task.upstream_list:
-            compile_airflow_task >> task
-
-    tasks_map[DBT_COMPILE_TASK_ID] = compile_airflow_task
-
-
 def _get_dbt_dag_task_group_identifier(dag: DAG, task_group: TaskGroup | None) -> str:
     dag_id = dag.dag_id
     task_group_id = task_group.group_id if task_group else None
@@ -588,7 +561,6 @@ def build_airflow_graph(
             tasks_map[node_id] = test_task
 
     create_airflow_task_dependencies(nodes, tasks_map)
-    _add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group)
 
     return tasks_map
 
diff --git a/cosmos/constants.py b/cosmos/constants.py
index 0513d50d2..a68f5a836 100644
--- a/cosmos/constants.py
+++ b/cosmos/constants.py
@@ -6,6 +6,7 @@
 import aenum
 from packaging.version import Version
 
+BIGQUERY_PROFILE_TYPE = "bigquery"
 DBT_PROFILE_PATH = Path(os.path.expanduser("~")).joinpath(".dbt/profiles.yml")
 DEFAULT_DBT_PROFILE_NAME = "cosmos_profile"
 DEFAULT_DBT_TARGET_NAME = "cosmos_target"
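Editor's note: the changelog entry above refers to ``ExecutionMode.AIRFLOW_ASYNC``, which is selected at the DAG level. The following is a minimal, illustrative sketch of such a DAG, loosely mirroring ``dev/dags/simple_dag_async.py`` from this repository; the project path, connection id, profile arguments, and schedule are assumptions for illustration only, and async execution additionally relies on the ``remote_target_path`` settings, which are not shown here.

from datetime import datetime

from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig
from cosmos.profiles import GoogleCloudServiceAccountFileProfileMapping

# Assumed BigQuery profile; "gcp_conn" and the dataset are placeholders.
profile_config = ProfileConfig(
    profile_name="airflow_db",
    target_name="bq",
    profile_mapping=GoogleCloudServiceAccountFileProfileMapping(
        conn_id="gcp_conn",
        profile_args={"dataset": "release_17"},
    ),
)

simple_dag_async = DbtDag(
    # Assumed dbt project location inside the Airflow deployment.
    project_config=ProjectConfig("/usr/local/airflow/dbt/jaffle_shop"),
    profile_config=profile_config,
    execution_config=ExecutionConfig(execution_mode=ExecutionMode.AIRFLOW_ASYNC),
    operator_args={"location": "northamerica-northeast1"},
    schedule_interval="@daily",
    start_date=datetime(2023, 1, 1),
    catchup=False,
    dag_id="simple_dag_async",
    tags=["simple"],
)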
diff --git a/cosmos/dbt_adapters/__init__.py b/cosmos/dbt_adapters/__init__.py
new file mode 100644
index 000000000..9c4f4dec0
--- /dev/null
+++ b/cosmos/dbt_adapters/__init__.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+from typing import Any
+
+from cosmos.constants import BIGQUERY_PROFILE_TYPE
+from cosmos.dbt_adapters.bigquery import _associate_bigquery_async_op_args, _mock_bigquery_adapter
+
+PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP = {
+    BIGQUERY_PROFILE_TYPE: _mock_bigquery_adapter,
+}
+
+PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP = {
+    BIGQUERY_PROFILE_TYPE: _associate_bigquery_async_op_args,
+}
+
+
+def associate_async_operator_args(async_operator_obj: Any, profile_type: str, **kwargs: Any) -> Any:
+    return PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP[profile_type](async_operator_obj, **kwargs)
diff --git a/cosmos/dbt_adapters/bigquery.py b/cosmos/dbt_adapters/bigquery.py
new file mode 100644
index 000000000..e7876e06b
--- /dev/null
+++ b/cosmos/dbt_adapters/bigquery.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from typing import Any
+
+from cosmos.exceptions import CosmosValueError
+
+
+def _mock_bigquery_adapter() -> None:
+    from typing import Optional, Tuple
+
+    import agate
+    from dbt.adapters.bigquery.connections import BigQueryAdapterResponse, BigQueryConnectionManager
+    from dbt_common.clients.agate_helper import empty_table
+
+    def execute(  # type: ignore[no-untyped-def]
+        self, sql, auto_begin=False, fetch=None, limit: Optional[int] = None
+    ) -> Tuple[BigQueryAdapterResponse, agate.Table]:
+        return BigQueryAdapterResponse("mock_bigquery_adapter_response"), empty_table()
+
+    BigQueryConnectionManager.execute = execute
+
+
+def _associate_bigquery_async_op_args(async_op_obj: Any, **kwargs: Any) -> Any:
+    sql = kwargs.get("sql")
+    if not sql:
+        raise CosmosValueError("Keyword argument 'sql' is required for BigQuery Async operator")
+    async_op_obj.configuration = {
+        "query": {
+            "query": sql,
+            "useLegacySql": False,
+        }
+    }
+    return async_op_obj
diff --git a/cosmos/operators/_asynchronous/base.py b/cosmos/operators/_asynchronous/base.py
index e957c9cac..f8d41b88c 100644
--- a/cosmos/operators/_asynchronous/base.py
+++ b/cosmos/operators/_asynchronous/base.py
@@ -1,9 +1,8 @@
+from __future__ import annotations
+
 import importlib
 import logging
-from abc import ABCMeta
-from typing import Any, Sequence
-
-from airflow.utils.context import Context
+from typing import Any
 
 from cosmos.airflow.graph import _snake_case_to_camelcase
 from cosmos.config import ProfileConfig
@@ -36,11 +35,16 @@ def _create_async_operator_class(profile_type: str, dbt_class: str) -> Any:
     return DbtRunLocalOperator
 
 
-class DbtRunAirflowAsyncFactoryOperator(DbtRunLocalOperator, metaclass=ABCMeta):  # type: ignore[misc]
+class DbtRunAirflowAsyncFactoryOperator(DbtRunLocalOperator):  # type: ignore[misc]
 
-    template_fields: Sequence[str] = DbtRunLocalOperator.template_fields + ("project_dir",)  # type: ignore[operator]
-
-    def __init__(self, project_dir: str, profile_config: ProfileConfig, **kwargs: Any):
+    def __init__(
+        self,
+        project_dir: str,
+        profile_config: ProfileConfig,
+        extra_context: dict[str, object] | None = None,
+        dbt_kwargs: dict[str, object] | None = None,
+        **kwargs: Any,
+    ) -> None:
         self.project_dir = project_dir
         self.profile_config = profile_config
 
@@ -51,7 +55,13 @@ def __init__(self, project_dir: str, profile_config: ProfileConfig, **kwargs: An
         # When using composition instead of inheritance to initialize the async class and run its execute method,
         # Airflow throws
a `DuplicateTaskIdFound` error. DbtRunAirflowAsyncFactoryOperator.__bases__ = (async_operator_class,) - super().__init__(project_dir=project_dir, profile_config=profile_config, **kwargs) + super().__init__( + project_dir=project_dir, + profile_config=profile_config, + extra_context=extra_context, + dbt_kwargs=dbt_kwargs, + **kwargs, + ) def create_async_operator(self) -> Any: @@ -60,6 +70,3 @@ def create_async_operator(self) -> Any: async_class_operator = _create_async_operator_class(profile_type, "DbtRun") return async_class_operator - - def execute(self, context: Context) -> None: - super().execute(context) diff --git a/cosmos/operators/_asynchronous/bigquery.py b/cosmos/operators/_asynchronous/bigquery.py index decbf8d77..1c5dc01a8 100644 --- a/cosmos/operators/_asynchronous/bigquery.py +++ b/cosmos/operators/_asynchronous/bigquery.py @@ -1,22 +1,23 @@ from __future__ import annotations -from pathlib import Path -from typing import TYPE_CHECKING, Any, Sequence +from typing import Any, Sequence -from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +import airflow from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator from airflow.utils.context import Context +from packaging.version import Version from cosmos import settings from cosmos.config import ProfileConfig -from cosmos.exceptions import CosmosValueError -from cosmos.settings import remote_target_path, remote_target_path_conn_id +from cosmos.dataset import get_dataset_alias_name +from cosmos.operators.local import AbstractDbtLocalBase +AIRFLOW_VERSION = Version(airflow.__version__) -class DbtRunAirflowAsyncBigqueryOperator(BigQueryInsertJobOperator): # type: ignore[misc] + +class DbtRunAirflowAsyncBigqueryOperator(BigQueryInsertJobOperator, AbstractDbtLocalBase): # type: ignore[misc] template_fields: Sequence[str] = ( - "full_refresh", "gcp_project", "dataset", "location", @@ -27,6 +28,7 @@ def __init__( project_dir: str, profile_config: ProfileConfig, extra_context: dict[str, Any] | None = None, + dbt_kwargs: dict[str, Any] | None = None, **kwargs: Any, ): self.project_dir = project_dir @@ -36,73 +38,35 @@ def __init__( self.gcp_project = profile["project"] self.dataset = profile["dataset"] self.extra_context = extra_context or {} - self.full_refresh = None - if "full_refresh" in kwargs: - self.full_refresh = kwargs.pop("full_refresh") self.configuration: dict[str, Any] = {} + self.dbt_kwargs = dbt_kwargs or {} + task_id = self.dbt_kwargs.pop("task_id") + AbstractDbtLocalBase.__init__( + self, task_id=task_id, project_dir=project_dir, profile_config=profile_config, **self.dbt_kwargs + ) + if kwargs.get("emit_datasets", True) and settings.enable_dataset_alias and AIRFLOW_VERSION >= Version("2.10"): + from airflow.datasets import DatasetAlias + + # ignoring the type because older versions of Airflow raise the follow error in mypy + # error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str") + dag_id = kwargs.get("dag") + task_group_id = kwargs.get("task_group") + kwargs["outlets"] = [ + DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, self.task_id)) + ] # type: ignore super().__init__( gcp_conn_id=self.gcp_conn_id, configuration=self.configuration, deferrable=True, **kwargs, ) + self.async_context = extra_context or {} + self.async_context["profile_type"] = self.profile_config.get_profile_type() + self.async_context["async_operator"] = BigQueryInsertJobOperator - def get_remote_sql(self) -> str: - if not 
settings.AIRFLOW_IO_AVAILABLE: - raise CosmosValueError(f"Cosmos async support is only available starting in Airflow 2.8 or later.") - from airflow.io.path import ObjectStoragePath - - file_path = self.extra_context["dbt_node_config"]["file_path"] # type: ignore - dbt_dag_task_group_identifier = self.extra_context["dbt_dag_task_group_identifier"] - - remote_target_path_str = str(remote_target_path).rstrip("/") - - if TYPE_CHECKING: # pragma: no cover - assert self.project_dir is not None - - project_dir_parent = str(Path(self.project_dir).parent) - relative_file_path = str(file_path).replace(project_dir_parent, "").lstrip("/") - remote_model_path = f"{remote_target_path_str}/{dbt_dag_task_group_identifier}/compiled/{relative_file_path}" - - object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id) - with object_storage_path.open() as fp: # type: ignore - return fp.read() # type: ignore - - def drop_table_sql(self) -> None: - model_name = self.extra_context["dbt_node_config"]["resource_name"] # type: ignore - sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};" - - hook = BigQueryHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - self.configuration = { - "query": { - "query": sql, - "useLegacySql": False, - } - } - hook.insert_job(configuration=self.configuration, location=self.location, project_id=self.gcp_project) - - def execute(self, context: Context) -> Any | None: + @property + def base_cmd(self) -> list[str]: + return ["run"] - if not self.full_refresh: - raise CosmosValueError("The async execution only supported for full_refresh") - else: - # It may be surprising to some, but the dbt-core --full-refresh argument fully drops the table before populating it - # https://github.com/dbt-labs/dbt-core/blob/5e9f1b515f37dfe6cdae1ab1aa7d190b92490e24/core/dbt/context/base.py#L662-L666 - # https://docs.getdbt.com/reference/resource-configs/full_refresh#recommendation - # We're emulating this behaviour here - # The compiled SQL has several limitations here, but these will be addressed in the PR: https://github.com/astronomer/astronomer-cosmos/pull/1474. 
- self.drop_table_sql() - sql = self.get_remote_sql() - model_name = self.extra_context["dbt_node_config"]["resource_name"] # type: ignore - # prefix explicit create command to create table - sql = f"CREATE TABLE {self.gcp_project}.{self.dataset}.{model_name} AS {sql}" - self.configuration = { - "query": { - "query": sql, - "useLegacySql": False, - } - } - return super().execute(context) + def execute(self, context: Context, **kwargs: Any) -> None: + self.build_and_run_cmd(context=context, run_as_async=True, async_context=self.async_context) diff --git a/cosmos/operators/_asynchronous/databricks.py b/cosmos/operators/_asynchronous/databricks.py index d49fd0be0..6e39bfd7c 100644 --- a/cosmos/operators/_asynchronous/databricks.py +++ b/cosmos/operators/_asynchronous/databricks.py @@ -1,4 +1,5 @@ # TODO: Implement it +from __future__ import annotations from typing import Any diff --git a/cosmos/operators/airflow_async.py b/cosmos/operators/airflow_async.py index de8d041c4..d6b1bda5a 100644 --- a/cosmos/operators/airflow_async.py +++ b/cosmos/operators/airflow_async.py @@ -1,10 +1,12 @@ from __future__ import annotations import inspect +from typing import Any from cosmos.config import ProfileConfig +from cosmos.constants import BIGQUERY_PROFILE_TYPE from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator -from cosmos.operators.base import AbstractDbtBaseOperator +from cosmos.operators.base import AbstractDbtBase from cosmos.operators.local import ( DbtBuildLocalOperator, DbtCloneLocalOperator, @@ -18,81 +20,76 @@ DbtTestLocalOperator, ) -_SUPPORTED_DATABASES = ["bigquery"] +_SUPPORTED_DATABASES = [BIGQUERY_PROFILE_TYPE] -from abc import ABCMeta -from airflow.models.baseoperator import BaseOperator - - -class DbtBaseAirflowAsyncOperator(BaseOperator, metaclass=ABCMeta): - def __init__(self, **kwargs) -> None: # type: ignore - if "location" in kwargs: - kwargs.pop("location") - super().__init__(**kwargs) - - -class DbtBuildAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtBuildLocalOperator): # type: ignore +class DbtBuildAirflowAsyncOperator(DbtBuildLocalOperator): pass -class DbtLSAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtLSLocalOperator): # type: ignore +class DbtLSAirflowAsyncOperator(DbtLSLocalOperator): pass -class DbtSeedAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSeedLocalOperator): # type: ignore +class DbtSeedAirflowAsyncOperator(DbtSeedLocalOperator): pass -class DbtSnapshotAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSnapshotLocalOperator): # type: ignore +class DbtSnapshotAirflowAsyncOperator(DbtSnapshotLocalOperator): pass -class DbtSourceAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSourceLocalOperator): # type: ignore +class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator): pass -class DbtRunAirflowAsyncOperator(DbtRunAirflowAsyncFactoryOperator): # type: ignore +class DbtRunAirflowAsyncOperator(DbtRunAirflowAsyncFactoryOperator): - def __init__( # type: ignore + def __init__( self, project_dir: str, profile_config: ProfileConfig, extra_context: dict[str, object] | None = None, - **kwargs, + **kwargs: Any, ) -> None: # Cosmos attempts to pass many kwargs that async operator simply does not accept. # We need to pop them. 
clean_kwargs = {} - non_async_args = set(inspect.signature(AbstractDbtBaseOperator.__init__).parameters.keys()) + non_async_args = set(inspect.signature(AbstractDbtBase.__init__).parameters.keys()) non_async_args |= set(inspect.signature(DbtLocalBaseOperator.__init__).parameters.keys()) - non_async_args -= {"task_id"} + + dbt_kwargs = {} for arg_key, arg_value in kwargs.items(): - if arg_key not in non_async_args: + if arg_key == "task_id": + clean_kwargs[arg_key] = arg_value + dbt_kwargs[arg_key] = arg_value + elif arg_key not in non_async_args: clean_kwargs[arg_key] = arg_value + else: + dbt_kwargs[arg_key] = arg_value - # The following are the minimum required parameters to run BigQueryInsertJobOperator using the deferrable mode super().__init__( project_dir=project_dir, profile_config=profile_config, extra_context=extra_context, + dbt_kwargs=dbt_kwargs, **clean_kwargs, ) -class DbtTestAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtTestLocalOperator): # type: ignore +class DbtTestAirflowAsyncOperator(DbtTestLocalOperator): pass -class DbtRunOperationAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtRunOperationLocalOperator): # type: ignore +class DbtRunOperationAirflowAsyncOperator(DbtRunOperationLocalOperator): pass -class DbtCompileAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtCompileLocalOperator): # type: ignore +class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator): pass -class DbtCloneAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtCloneLocalOperator): +class DbtCloneAirflowAsyncOperator(DbtCloneLocalOperator): pass diff --git a/cosmos/operators/azure_container_instance.py b/cosmos/operators/azure_container_instance.py index 7f335bd99..aeeec1a23 100644 --- a/cosmos/operators/azure_container_instance.py +++ b/cosmos/operators/azure_container_instance.py @@ -1,12 +1,13 @@ from __future__ import annotations +import inspect from typing import Any, Callable, Sequence from airflow.utils.context import Context from cosmos.config import ProfileConfig from cosmos.operators.base import ( - AbstractDbtBaseOperator, + AbstractDbtBase, DbtBuildMixin, DbtCloneMixin, DbtLSMixin, @@ -28,13 +29,13 @@ ) -class DbtAzureContainerInstanceBaseOperator(AbstractDbtBaseOperator, AzureContainerInstancesOperator): # type: ignore +class DbtAzureContainerInstanceBaseOperator(AbstractDbtBase, AzureContainerInstancesOperator): # type: ignore """ Executes a dbt core cli command in an Azure Container Instance """ template_fields: Sequence[str] = tuple( - list(AbstractDbtBaseOperator.template_fields) + list(AzureContainerInstancesOperator.template_fields) + list(AbstractDbtBase.template_fields) + list(AzureContainerInstancesOperator.template_fields) ) def __init__( @@ -51,19 +52,40 @@ def __init__( **kwargs: Any, ) -> None: self.profile_config = profile_config - super().__init__( - ci_conn_id=ci_conn_id, - resource_group=resource_group, - name=name, - image=image, - region=region, - remove_on_error=remove_on_error, - fail_if_exists=fail_if_exists, - registry_conn_id=registry_conn_id, - **kwargs, + kwargs.update( + { + "ci_conn_id": ci_conn_id, + "resource_group": resource_group, + "name": name, + "image": image, + "region": region, + "remove_on_error": remove_on_error, + "fail_if_exists": fail_if_exists, + "registry_conn_id": registry_conn_id, + } ) - - def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> None: + super().__init__(**kwargs) + # In PR #1474, we refactored cosmos.operators.base.AbstractDbtBase to remove its inheritance from BaseOperator + 
# and eliminated the super().__init__() call. This change was made to resolve conflicts in parent class + # initializations while adding support for ExecutionMode.AIRFLOW_ASYNC. Operators under this mode inherit + # Airflow provider operators that enable deferrable SQL query execution. Since super().__init__() was removed + # from AbstractDbtBase and different parent classes require distinct initialization arguments, we explicitly + # initialize them (including the BaseOperator) here by segregating the required arguments for each parent class. + base_operator_args = set(inspect.signature(AzureContainerInstancesOperator.__init__).parameters.keys()) + base_kwargs = {} + for arg_key, arg_value in kwargs.items(): + if arg_key in base_operator_args: + base_kwargs[arg_key] = arg_value + base_kwargs["task_id"] = kwargs["task_id"] + AzureContainerInstancesOperator.__init__(self, **base_kwargs) + + def build_and_run_cmd( + self, + context: Context, + cmd_flags: list[str] | None = None, + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, + ) -> Any: self.build_command(context, cmd_flags) self.log.info(f"Running command: {self.command}") result = AzureContainerInstancesOperator.execute(self, context) diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index 52fb98bac..18019ab92 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -1,20 +1,21 @@ from __future__ import annotations +import logging import os from abc import ABCMeta, abstractmethod from pathlib import Path from typing import Any, Sequence, Tuple import yaml -from airflow.models.baseoperator import BaseOperator from airflow.utils.context import Context, context_merge from airflow.utils.operator_helpers import context_to_airflow_vars from airflow.utils.strings import to_boolean from cosmos.dbt.executable import get_system_dbt +from cosmos.log import get_logger -class AbstractDbtBaseOperator(BaseOperator, metaclass=ABCMeta): +class AbstractDbtBase(metaclass=ABCMeta): """ Executes a dbt core cli command. 
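Editor's note: the comment block above describes the initialization pattern that recurs in the Docker, Cloud Run, Kubernetes, and local operators further down in this diff: ``AbstractDbtBase`` no longer subclasses ``BaseOperator``, so each concrete operator initializes the dbt base class and its provider-operator parent explicitly, passing each only the keyword arguments it accepts. Below is a simplified, self-contained sketch of that pattern; ``DbtBase``, ``ProviderOperator``, and ``DbtProviderOperator`` are stand-in names, not the actual Cosmos classes.

from __future__ import annotations

import inspect
from typing import Any


class DbtBase:  # stand-in for cosmos.operators.base.AbstractDbtBase (no BaseOperator parent)
    def __init__(self, project_dir: str, profile_args: dict | None = None, **kwargs: Any) -> None:
        self.project_dir = project_dir
        self.profile_args = profile_args or {}


class ProviderOperator:  # stand-in for a provider operator such as KubernetesPodOperator
    def __init__(self, task_id: str, image: str | None = None, **kwargs: Any) -> None:
        self.task_id = task_id
        self.image = image


class DbtProviderOperator(DbtBase, ProviderOperator):
    def __init__(self, **kwargs: Any) -> None:
        # Segregate kwargs per parent signature and initialize each parent explicitly,
        # mirroring the inspect.signature() filtering used in the operators in this diff.
        dbt_args = set(inspect.signature(DbtBase.__init__).parameters)
        provider_args = set(inspect.signature(ProviderOperator.__init__).parameters)
        DbtBase.__init__(self, **{k: v for k, v in kwargs.items() if k in dbt_args})
        ProviderOperator.__init__(self, **{k: v for k, v in kwargs.items() if k in provider_args})


op = DbtProviderOperator(task_id="run_model", project_dir="/tmp/jaffle_shop", image="dbt:latest")
assert op.task_id == "run_model" and op.project_dir == "/tmp/jaffle_shop"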
@@ -140,7 +141,6 @@ def __init__( self.cache_dir = cache_dir self.extra_context = extra_context or {} kwargs.pop("full_refresh", None) # usage of this param should be implemented in child classes - super().__init__(**kwargs) def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]]: """ @@ -191,6 +191,10 @@ def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]] return filtered_env + @property + def log(self) -> logging.Logger: + return get_logger(__name__) + def add_global_flags(self) -> list[str]: flags = [] for global_flag in self.global_flags: @@ -258,10 +262,16 @@ def build_cmd( return dbt_cmd, env @abstractmethod - def build_and_run_cmd(self, context: Context, cmd_flags: list[str]) -> Any: + def build_and_run_cmd( + self, + context: Context, + cmd_flags: list[str], + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, + ) -> Any: """Override this method for the operator to execute the dbt command""" - def execute(self, context: Context) -> Any | None: # type: ignore + def execute(self, context: Context, **kwargs) -> Any | None: # type: ignore if self.extra_context: context_merge(context, self.extra_context) diff --git a/cosmos/operators/docker.py b/cosmos/operators/docker.py index 8dc614cfc..879a8164c 100644 --- a/cosmos/operators/docker.py +++ b/cosmos/operators/docker.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from typing import Any, Callable, Sequence from airflow.utils.context import Context @@ -7,7 +8,7 @@ from cosmos.config import ProfileConfig from cosmos.exceptions import CosmosValueError from cosmos.operators.base import ( - AbstractDbtBaseOperator, + AbstractDbtBase, DbtBuildMixin, DbtCloneMixin, DbtLSMixin, @@ -29,15 +30,13 @@ ) -class DbtDockerBaseOperator(AbstractDbtBaseOperator, DockerOperator): # type: ignore +class DbtDockerBaseOperator(AbstractDbtBase, DockerOperator): # type: ignore """ Executes a dbt core cli command in a Docker container. """ - template_fields: Sequence[str] = tuple( - list(AbstractDbtBaseOperator.template_fields) + list(DockerOperator.template_fields) - ) + template_fields: Sequence[str] = tuple(list(AbstractDbtBase.template_fields) + list(DockerOperator.template_fields)) intercept_flag = False @@ -56,8 +55,28 @@ def __init__( ) super().__init__(image=image, **kwargs) - - def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any: + # In PR #1474, we refactored cosmos.operators.base.AbstractDbtBase to remove its inheritance from BaseOperator + # and eliminated the super().__init__() call. This change was made to resolve conflicts in parent class + # initializations while adding support for ExecutionMode.AIRFLOW_ASYNC. Operators under this mode inherit + # Airflow provider operators that enable deferrable SQL query execution. Since super().__init__() was removed + # from AbstractDbtBase and different parent classes require distinct initialization arguments, we explicitly + # initialize them (including the BaseOperator) here by segregating the required arguments for each parent class. 
+ kwargs["image"] = image + base_operator_args = set(inspect.signature(DockerOperator.__init__).parameters.keys()) + base_kwargs = {} + for arg_key, arg_value in kwargs.items(): + if arg_key in base_operator_args: + base_kwargs[arg_key] = arg_value + base_kwargs["task_id"] = kwargs["task_id"] + DockerOperator.__init__(self, **base_kwargs) + + def build_and_run_cmd( + self, + context: Context, + cmd_flags: list[str] | None = None, + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, + ) -> Any: self.build_command(context, cmd_flags) self.log.info(f"Running command: {self.command}") result = DockerOperator.execute(self, context) diff --git a/cosmos/operators/gcp_cloud_run_job.py b/cosmos/operators/gcp_cloud_run_job.py index ef47db2cc..e24191d6a 100644 --- a/cosmos/operators/gcp_cloud_run_job.py +++ b/cosmos/operators/gcp_cloud_run_job.py @@ -8,7 +8,7 @@ from cosmos.config import ProfileConfig from cosmos.log import get_logger from cosmos.operators.base import ( - AbstractDbtBaseOperator, + AbstractDbtBase, DbtBuildMixin, DbtCloneMixin, DbtLSMixin, @@ -41,14 +41,14 @@ ) -class DbtGcpCloudRunJobBaseOperator(AbstractDbtBaseOperator, CloudRunExecuteJobOperator): # type: ignore +class DbtGcpCloudRunJobBaseOperator(AbstractDbtBase, CloudRunExecuteJobOperator): # type: ignore """ Executes a dbt core cli command in a Cloud Run Job instance with dbt installed in it. """ template_fields: Sequence[str] = tuple( - list(AbstractDbtBaseOperator.template_fields) + list(CloudRunExecuteJobOperator.template_fields) + list(AbstractDbtBase.template_fields) + list(CloudRunExecuteJobOperator.template_fields) ) intercept_flag = False @@ -69,8 +69,36 @@ def __init__( self.command = command self.environment_variables = environment_variables or DEFAULT_ENVIRONMENT_VARIABLES super().__init__(project_id=project_id, region=region, job_name=job_name, **kwargs) - - def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any: + # In PR #1474, we refactored cosmos.operators.base.AbstractDbtBase to remove its inheritance from BaseOperator + # and eliminated the super().__init__() call. This change was made to resolve conflicts in parent class + # initializations while adding support for ExecutionMode.AIRFLOW_ASYNC. Operators under this mode inherit + # Airflow provider operators that enable deferrable SQL query execution. Since super().__init__() was removed + # from AbstractDbtBase and different parent classes require distinct initialization arguments, we explicitly + # initialize them (including the BaseOperator) here by segregating the required arguments for each parent class. 
+ kwargs.update( + { + "project_id": project_id, + "region": region, + "job_name": job_name, + "command": command, + "environment_variables": environment_variables, + } + ) + base_operator_args = set(inspect.signature(CloudRunExecuteJobOperator.__init__).parameters.keys()) + base_kwargs = {} + for arg_key, arg_value in kwargs.items(): + if arg_key in base_operator_args: + base_kwargs[arg_key] = arg_value + base_kwargs["task_id"] = kwargs["task_id"] + CloudRunExecuteJobOperator.__init__(self, **base_kwargs) + + def build_and_run_cmd( + self, + context: Context, + cmd_flags: list[str] | None = None, + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, + ) -> Any: self.build_command(context, cmd_flags) self.log.info(f"Running command: {self.command}") result = CloudRunExecuteJobOperator.execute(self, context) diff --git a/cosmos/operators/kubernetes.py b/cosmos/operators/kubernetes.py index b00e6380c..8cbc20e1c 100644 --- a/cosmos/operators/kubernetes.py +++ b/cosmos/operators/kubernetes.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from abc import ABC from os import PathLike from typing import Any, Callable, Sequence @@ -10,7 +11,7 @@ from cosmos.config import ProfileConfig from cosmos.dbt.parser.output import extract_log_issues from cosmos.operators.base import ( - AbstractDbtBaseOperator, + AbstractDbtBase, DbtBuildMixin, DbtCloneMixin, DbtLSMixin, @@ -43,14 +44,14 @@ ) -class DbtKubernetesBaseOperator(AbstractDbtBaseOperator, KubernetesPodOperator): # type: ignore +class DbtKubernetesBaseOperator(AbstractDbtBase, KubernetesPodOperator): # type: ignore """ Executes a dbt core cli command in a Kubernetes Pod. """ template_fields: Sequence[str] = tuple( - list(AbstractDbtBaseOperator.template_fields) + list(KubernetesPodOperator.template_fields) + list(AbstractDbtBase.template_fields) + list(KubernetesPodOperator.template_fields) ) intercept_flag = False @@ -58,6 +59,19 @@ class DbtKubernetesBaseOperator(AbstractDbtBaseOperator, KubernetesPodOperator): def __init__(self, profile_config: ProfileConfig | None = None, **kwargs: Any) -> None: self.profile_config = profile_config super().__init__(**kwargs) + # In PR #1474, we refactored cosmos.operators.base.AbstractDbtBase to remove its inheritance from BaseOperator + # and eliminated the super().__init__() call. This change was made to resolve conflicts in parent class + # initializations while adding support for ExecutionMode.AIRFLOW_ASYNC. Operators under this mode inherit + # Airflow provider operators that enable deferrable SQL query execution. Since super().__init__() was removed + # from AbstractDbtBase and different parent classes require distinct initialization arguments, we explicitly + # initialize them (including the BaseOperator) here by segregating the required arguments for each parent class. 
+ base_operator_args = set(inspect.signature(KubernetesPodOperator.__init__).parameters.keys()) + base_kwargs = {} + for arg_key, arg_value in kwargs.items(): + if arg_key in base_operator_args: + base_kwargs[arg_key] = arg_value + base_kwargs["task_id"] = kwargs["task_id"] + KubernetesPodOperator.__init__(self, **base_kwargs) def build_env_args(self, env: dict[str, str | bytes | PathLike[Any]]) -> None: env_vars_dict: dict[str, str] = dict() @@ -69,7 +83,13 @@ def build_env_args(self, env: dict[str, str | bytes | PathLike[Any]]) -> None: self.env_vars: list[Any] = convert_env_vars(env_vars_dict) - def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any: + def build_and_run_cmd( + self, + context: Context, + cmd_flags: list[str] | None = None, + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, + ) -> Any: self.build_kube_args(context, cmd_flags) self.log.info(f"Running command: {self.arguments}") result = KubernetesPodOperator.execute(self, context) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index e5dbcfd31..91b3dd314 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect import json import os import tempfile @@ -15,6 +16,7 @@ import jinja2 from airflow import DAG from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.models import BaseOperator from airflow.models.taskinstance import TaskInstance from airflow.utils.context import Context from airflow.utils.session import NEW_SESSION, create_session, provide_session @@ -66,13 +68,14 @@ parse_number_of_warnings_subprocess, ) from cosmos.dbt.project import create_symlinks +from cosmos.dbt_adapters import PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP, associate_async_operator_args from cosmos.hooks.subprocess import ( FullOutputSubprocessHook, FullOutputSubprocessResult, ) from cosmos.log import get_logger from cosmos.operators.base import ( - AbstractDbtBaseOperator, + AbstractDbtBase, DbtBuildMixin, DbtCloneMixin, DbtCompileMixin, @@ -112,7 +115,7 @@ class OperatorLineage: # type: ignore job_facets: dict[str, str] = dict() -class DbtLocalBaseOperator(AbstractDbtBaseOperator): +class AbstractDbtLocalBase(AbstractDbtBase): """ Executes a dbt core cli command locally. @@ -131,7 +134,7 @@ class DbtLocalBaseOperator(AbstractDbtBaseOperator): and does not inherit the current process environment. 
""" - template_fields: Sequence[str] = AbstractDbtBaseOperator.template_fields + ("compiled_sql", "freshness") # type: ignore[operator] + template_fields: Sequence[str] = AbstractDbtBase.template_fields + ("compiled_sql", "freshness") # type: ignore[operator] template_fields_renderers = { "compiled_sql": "sql", "freshness": "json", @@ -162,17 +165,6 @@ def __init__( self.invocation_mode = invocation_mode self._dbt_runner: dbtRunner | None = None - if kwargs.get("emit_datasets", True) and settings.enable_dataset_alias and AIRFLOW_VERSION >= Version("2.10"): - from airflow.datasets import DatasetAlias - - # ignoring the type because older versions of Airflow raise the follow error in mypy - # error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str") - dag_id = kwargs.get("dag") - task_group_id = kwargs.get("task_group") - kwargs["outlets"] = [ - DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, task_id)) - ] # type: ignore - super().__init__(task_id=task_id, **kwargs) # For local execution mode, we're consistent with the LoadMode.DBT_LS command in forwarding the environment @@ -271,7 +263,7 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se # delete the old records session.query(RenderedTaskInstanceFields).filter( - RenderedTaskInstanceFields.dag_id == self.dag_id, + RenderedTaskInstanceFields.dag_id == self.dag_id, # type: ignore[attr-defined] RenderedTaskInstanceFields.task_id == self.task_id, RenderedTaskInstanceFields.run_id == ti.run_id, ).delete() @@ -401,12 +393,97 @@ def _cache_package_lockfile(self, tmp_project_dir: Path) -> None: if latest_package_lockfile: _copy_cached_package_lockfile_to_project(latest_package_lockfile, tmp_project_dir) + def _read_run_sql_from_target_dir(self, tmp_project_dir: str, sql_context: dict[str, Any]) -> str: + sql_relative_path = sql_context["dbt_node_config"]["file_path"].split(str(self.project_dir))[-1].lstrip("/") + run_sql_path = Path(tmp_project_dir) / "target/run" / Path(self.project_dir).name / sql_relative_path + with run_sql_path.open("r") as sql_file: + sql_content: str = sql_file.read() + return sql_content + + def _clone_project(self, tmp_dir_path: Path) -> None: + self.log.info( + "Cloning project to writable temp directory %s from %s", + tmp_dir_path, + self.project_dir, + ) + create_symlinks(Path(self.project_dir), tmp_dir_path, self.install_deps) + + def _handle_partial_parse(self, tmp_dir_path: Path) -> None: + if self.cache_dir is None: + return + latest_partial_parse = cache._get_latest_partial_parse(Path(self.project_dir), self.cache_dir) + self.log.info("Partial parse is enabled and the latest partial parse file is %s", latest_partial_parse) + if latest_partial_parse is not None: + cache._copy_partial_parse_to_project(latest_partial_parse, tmp_dir_path) + + def _generate_dbt_flags(self, tmp_project_dir: str, profile_path: Path) -> list[str]: + return [ + "--project-dir", + str(tmp_project_dir), + "--profiles-dir", + str(profile_path.parent), + "--profile", + self.profile_config.profile_name, + "--target", + self.profile_config.target_name, + ] + + def _install_dependencies( + self, tmp_dir_path: Path, flags: list[str], env: dict[str, str | bytes | os.PathLike[Any]] + ) -> None: + self._cache_package_lockfile(tmp_dir_path) + deps_command = [self.dbt_executable_path, "deps"] + flags + self.invoke_dbt(command=deps_command, env=env, cwd=tmp_dir_path) + + @staticmethod + def _mock_dbt_adapter(async_context: dict[str, Any] | None) -> None: + if 
not async_context: + raise CosmosValueError("`async_context` is necessary for running the model asynchronously") + if "async_operator" not in async_context: + raise CosmosValueError("`async_operator` needs to be specified in `async_context` when running as async") + if "profile_type" not in async_context: + raise CosmosValueError("`profile_type` needs to be specified in `async_context` when running as async") + profile_type = async_context["profile_type"] + if profile_type not in PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP: + raise CosmosValueError(f"Mock adapter callable function not available for profile_type {profile_type}") + mock_adapter_callable = PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP[profile_type] + mock_adapter_callable() + + def _handle_datasets(self, context: Context) -> None: + inlets = self.get_datasets("inputs") + outlets = self.get_datasets("outputs") + self.log.info("Inlets: %s", inlets) + self.log.info("Outlets: %s", outlets) + self.register_dataset(inlets, outlets, context) + + def _update_partial_parse_cache(self, tmp_dir_path: Path) -> None: + if self.cache_dir is None: + return + partial_parse_file = get_partial_parse_path(tmp_dir_path) + if partial_parse_file.exists(): + cache._update_partial_parse_cache(partial_parse_file, self.cache_dir) + + def _handle_post_execution(self, tmp_project_dir: str, context: Context) -> None: + self.store_freshness_json(tmp_project_dir, context) + self.store_compiled_sql(tmp_project_dir, context) + self.upload_compiled_sql(tmp_project_dir, context) + if self.callback: + self.callback_args.update({"context": context}) + self.callback(tmp_project_dir, **self.callback_args) + + def _handle_async_execution(self, tmp_project_dir: str, context: Context, async_context: dict[str, Any]) -> None: + sql = self._read_run_sql_from_target_dir(tmp_project_dir, async_context) + associate_async_operator_args(self, async_context["profile_type"], sql=sql) + async_context["async_operator"].execute(self, context) + def run_command( self, cmd: list[str], env: dict[str, str | bytes | os.PathLike[Any]], context: Context, - ) -> FullOutputSubprocessResult | dbtRunnerResult: + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, + ) -> FullOutputSubprocessResult | dbtRunnerResult | str: """ Copies the dbt project to a temporary directory and runs the command. 
""" @@ -415,50 +492,27 @@ def run_command( with tempfile.TemporaryDirectory() as tmp_project_dir: - self.log.info( - "Cloning project to writable temp directory %s from %s", - tmp_project_dir, - self.project_dir, - ) tmp_dir_path = Path(tmp_project_dir) env = {k: str(v) for k, v in env.items()} - create_symlinks(Path(self.project_dir), tmp_dir_path, self.install_deps) + self._clone_project(tmp_dir_path) - if self.partial_parse and self.cache_dir is not None: - latest_partial_parse = cache._get_latest_partial_parse(Path(self.project_dir), self.cache_dir) - self.log.info("Partial parse is enabled and the latest partial parse file is %s", latest_partial_parse) - if latest_partial_parse is not None: - cache._copy_partial_parse_to_project(latest_partial_parse, tmp_dir_path) + if self.partial_parse: + self._handle_partial_parse(tmp_dir_path) with self.profile_config.ensure_profile() as profile_values: (profile_path, env_vars) = profile_values env.update(env_vars) + self.log.debug("Using environment variables keys: %s", env.keys()) - flags = [ - "--project-dir", - str(tmp_project_dir), - "--profiles-dir", - str(profile_path.parent), - "--profile", - self.profile_config.profile_name, - "--target", - self.profile_config.target_name, - ] + flags = self._generate_dbt_flags(tmp_project_dir, profile_path) if self.install_deps: - self._cache_package_lockfile(tmp_dir_path) - deps_command = [self.dbt_executable_path, "deps"] - deps_command.extend(flags) - self.invoke_dbt( - command=deps_command, - env=env, - cwd=tmp_project_dir, - ) + self._install_dependencies(tmp_dir_path, flags, env) - full_cmd = cmd + flags - - self.log.debug("Using environment variables keys: %s", env.keys()) + if run_as_async: + self._mock_dbt_adapter(async_context) + full_cmd = cmd + flags result = self.invoke_dbt( command=full_cmd, env=env, @@ -471,25 +525,17 @@ def run_command( ].openlineage_events_completes = self.openlineage_events_completes # type: ignore if self.emit_datasets: - inlets = self.get_datasets("inputs") - outlets = self.get_datasets("outputs") - self.log.info("Inlets: %s", inlets) - self.log.info("Outlets: %s", outlets) - self.register_dataset(inlets, outlets, context) - - if self.partial_parse and self.cache_dir: - partial_parse_file = get_partial_parse_path(tmp_dir_path) - if partial_parse_file.exists(): - cache._update_partial_parse_cache(partial_parse_file, self.cache_dir) - - self.store_freshness_json(tmp_project_dir, context) - self.store_compiled_sql(tmp_project_dir, context) - self.upload_compiled_sql(tmp_project_dir, context) - if self.callback: - self.callback_args.update({"context": context}) - self.callback(tmp_project_dir, **self.callback_args) + self._handle_datasets(context) + + if self.partial_parse: + self._update_partial_parse_cache(tmp_dir_path) + + self._handle_post_execution(tmp_project_dir, context) self.handle_exception(result) + if run_as_async and async_context: + self._handle_async_execution(tmp_project_dir, context, async_context) + return result def calculate_openlineage_events_completes( @@ -576,17 +622,17 @@ def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset] if AIRFLOW_VERSION < Version("2.10") or not settings.enable_dataset_alias: logger.info("Assigning inlets/outlets without DatasetAlias") with create_session() as session: - self.outlets.extend(new_outlets) - self.inlets.extend(new_inlets) - for task in self.dag.tasks: + self.outlets.extend(new_outlets) # type: ignore[attr-defined] + self.inlets.extend(new_inlets) # type: ignore[attr-defined] + for task 
in self.dag.tasks: # type: ignore[attr-defined] if task.task_id == self.task_id: task.outlets.extend(new_outlets) task.inlets.extend(new_inlets) - DAG.bulk_write_to_db([self.dag], session=session) + DAG.bulk_write_to_db([self.dag], session=session) # type: ignore[attr-defined] session.commit() else: logger.info("Assigning inlets/outlets with DatasetAlias") - dataset_alias_name = get_dataset_alias_name(self.dag, self.task_group, self.task_id) + dataset_alias_name = get_dataset_alias_name(self.dag, self.task_group, self.task_id) # type: ignore[attr-defined] for outlet in new_outlets: context["outlet_events"][dataset_alias_name].add(outlet) @@ -629,11 +675,17 @@ def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> Ope ) def build_and_run_cmd( - self, context: Context, cmd_flags: list[str] | None = None + self, + context: Context, + cmd_flags: list[str] | None = None, + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, ) -> FullOutputSubprocessResult | dbtRunnerResult: dbt_cmd, env = self.build_cmd(context=context, cmd_flags=cmd_flags) dbt_cmd = dbt_cmd or [] - result = self.run_command(cmd=dbt_cmd, env=env, context=context) + result = self.run_command( + cmd=dbt_cmd, env=env, context=context, run_as_async=run_as_async, async_context=async_context + ) return result def on_kill(self) -> None: @@ -644,6 +696,43 @@ def on_kill(self) -> None: self.subprocess_hook.send_sigterm() +class DbtLocalBaseOperator(AbstractDbtLocalBase, BaseOperator): + + template_fields: Sequence[str] = AbstractDbtLocalBase.template_fields # type: ignore[operator] + + def __init__(self, *args: Any, **kwargs: Any) -> None: + # In PR #1474, we refactored cosmos.operators.base.AbstractDbtBase to remove its inheritance from BaseOperator + # and eliminated the super().__init__() call. This change was made to resolve conflicts in parent class + # initializations while adding support for ExecutionMode.AIRFLOW_ASYNC. Operators under this mode inherit + # Airflow provider operators that enable deferrable SQL query execution. Since super().__init__() was removed + # from AbstractDbtBase and different parent classes require distinct initialization arguments, we explicitly + # initialize them (including the BaseOperator) here by segregating the required arguments for each parent class. 
+ abstract_dbt_local_base_kwargs = {} + base_operator_kwargs = {} + abstract_dbt_local_base_args_keys = ( + inspect.getfullargspec(AbstractDbtBase.__init__).args + + inspect.getfullargspec(AbstractDbtLocalBase.__init__).args + ) + base_operator_args = set(inspect.signature(BaseOperator.__init__).parameters.keys()) + for arg_key, arg_value in kwargs.items(): + if arg_key in abstract_dbt_local_base_args_keys: + abstract_dbt_local_base_kwargs[arg_key] = arg_value + if arg_key in base_operator_args: + base_operator_kwargs[arg_key] = arg_value + AbstractDbtLocalBase.__init__(self, **abstract_dbt_local_base_kwargs) + if kwargs.get("emit_datasets", True) and settings.enable_dataset_alias and AIRFLOW_VERSION >= Version("2.10"): + from airflow.datasets import DatasetAlias + + # ignoring the type because older versions of Airflow raise the follow error in mypy + # error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str") + dag_id = kwargs.get("dag") + task_group_id = kwargs.get("task_group") + base_operator_kwargs["outlets"] = [ + DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, self.task_id)) + ] # type: ignore + BaseOperator.__init__(self, **base_operator_kwargs) + + class DbtBuildLocalOperator(DbtBuildMixin, DbtLocalBaseOperator): """ Executes a dbt core build command. @@ -660,6 +749,8 @@ class DbtLSLocalOperator(DbtLSMixin, DbtLocalBaseOperator): Executes a dbt core ls command. """ + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -680,6 +771,8 @@ class DbtSnapshotLocalOperator(DbtSnapshotMixin, DbtLocalBaseOperator): Executes a dbt core snapshot command. """ + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -689,6 +782,8 @@ class DbtSourceLocalOperator(DbtSourceMixin, DbtLocalBaseOperator): Executes a dbt source freshness command. """ + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, on_warning_callback: Callable[..., Any] | None = None, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.on_warning_callback = on_warning_callback @@ -715,7 +810,7 @@ def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, self.on_warning_callback and self.on_warning_callback(warning_context) - def execute(self, context: Context) -> None: + def execute(self, context: Context, **kwargs: Any) -> None: result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) if self.on_warning_callback: self._handle_warnings(result, context) @@ -739,6 +834,8 @@ class DbtTestLocalOperator(DbtTestMixin, DbtLocalBaseOperator): and "test_results" of type `List`. Each index in "test_names" corresponds to the same index in "test_results". 
""" + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields # type: ignore[operator] + def __init__( self, on_warning_callback: Callable[..., Any] | None = None, @@ -774,7 +871,7 @@ def _set_test_result_parsing_methods(self) -> None: self.extract_issues = dbt_runner.extract_message_by_status self.parse_number_of_warnings = dbt_runner.parse_number_of_warnings - def execute(self, context: Context) -> None: + def execute(self, context: Context, **kwargs: Any) -> None: result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) self._set_test_result_parsing_methods() number_of_warnings = self.parse_number_of_warnings(result) # type: ignore @@ -803,6 +900,8 @@ class DbtDocsLocalOperator(DbtLocalBaseOperator): Use the `callback` parameter to specify a callback function to run after the command completes. """ + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields # type: ignore[operator] + ui_color = "#8194E0" required_files = ["index.html", "manifest.json", "catalog.json"] base_cmd = ["docs", "generate"] @@ -826,6 +925,8 @@ class DbtDocsCloudLocalOperator(DbtDocsLocalOperator, ABC): Abstract class for operators that upload the generated documentation to cloud storage. """ + template_fields: Sequence[str] = DbtDocsLocalOperator.template_fields # type: ignore[operator] + def __init__( self, connection_id: str, @@ -1021,6 +1122,8 @@ def __init__(self, **kwargs: str) -> None: class DbtCompileLocalOperator(DbtCompileMixin, DbtLocalBaseOperator): + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["should_upload_compiled_sql"] = True super().__init__(*args, **kwargs) @@ -1031,5 +1134,7 @@ class DbtCloneLocalOperator(DbtCloneMixin, DbtLocalBaseOperator): Executes a dbt core clone command. """ + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) diff --git a/cosmos/operators/virtualenv.py b/cosmos/operators/virtualenv.py index 3bd54da99..4026d3eb4 100644 --- a/cosmos/operators/virtualenv.py +++ b/cosmos/operators/virtualenv.py @@ -5,7 +5,7 @@ import time from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Sequence import psutil from airflow.utils.python_virtualenv import prepare_virtualenv @@ -96,6 +96,8 @@ def run_command( cmd: list[str], env: dict[str, str | bytes | os.PathLike[Any]], context: Context, + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, ) -> FullOutputSubprocessResult | dbtRunnerResult: # No virtualenv_dir set, so create a temporary virtualenv if self.virtualenv_dir is None or self.is_virtualenv_dir_temporary: @@ -128,7 +130,7 @@ def clean_dir_if_temporary(self) -> None: self.log.info(f"Deleting the Python virtualenv {self.virtualenv_dir}") shutil.rmtree(str(self.virtualenv_dir), ignore_errors=True) - def execute(self, context: Context) -> None: + def execute(self, context: Context, **kwargs: Any) -> None: try: output = super().execute(context) self.log.info(output) @@ -215,6 +217,8 @@ class DbtLSVirtualenvOperator(DbtVirtualenvBaseOperator, DbtLSLocalOperator): and deleted just after. 
""" + template_fields: Sequence[str] = DbtVirtualenvBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -235,6 +239,8 @@ class DbtSnapshotVirtualenvOperator(DbtVirtualenvBaseOperator, DbtSnapshotLocalO command and deleted just after. """ + template_fields: Sequence[str] = DbtVirtualenvBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -245,6 +251,8 @@ class DbtSourceVirtualenvOperator(DbtVirtualenvBaseOperator, DbtSourceLocalOpera command and deleted just after. """ + template_fields: Sequence[str] = DbtVirtualenvBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -265,6 +273,8 @@ class DbtTestVirtualenvOperator(DbtVirtualenvBaseOperator, DbtTestLocalOperator) and deleted just after. """ + template_fields: Sequence[str] = DbtVirtualenvBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -285,6 +295,8 @@ class DbtDocsVirtualenvOperator(DbtVirtualenvBaseOperator, DbtDocsLocalOperator) command and deleted just after. """ + template_fields: Sequence[str] = DbtVirtualenvBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -294,5 +306,7 @@ class DbtCloneVirtualenvOperator(DbtVirtualenvBaseOperator, DbtCloneLocalOperato Executes a dbt core clone command. """ + template_fields: Sequence[str] = DbtVirtualenvBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) diff --git a/dev/dags/simple_dag_async.py b/dev/dags/simple_dag_async.py index 1b2b67651..8fb8cb844 100644 --- a/dev/dags/simple_dag_async.py +++ b/dev/dags/simple_dag_async.py @@ -37,6 +37,6 @@ catchup=False, dag_id="simple_dag_async", tags=["simple"], - operator_args={"full_refresh": True, "location": "northamerica-northeast1"}, + operator_args={"location": "northamerica-northeast1"}, ) # [END airflow_async_execution_mode_example] diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index ccbd911be..d86abab74 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -1,7 +1,7 @@ import os from datetime import datetime from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from airflow import __version__ as airflow_version @@ -22,7 +22,6 @@ ) from cosmos.config import ProfileConfig, RenderConfig from cosmos.constants import ( - DBT_COMPILE_TASK_ID, DbtResourceType, ExecutionMode, SourceRenderingBehavior, @@ -31,7 +30,7 @@ ) from cosmos.converter import airflow_kwargs from cosmos.dbt.graph import DbtNode -from cosmos.profiles import GoogleCloudServiceAccountFileProfileMapping, PostgresUserPasswordProfileMapping +from cosmos.profiles import PostgresUserPasswordProfileMapping SAMPLE_PROJ_PATH = Path("/home/user/path/dbt-proj/") SOURCE_RENDERING_BEHAVIOR = SourceRenderingBehavior(os.getenv("SOURCE_RENDERING_BEHAVIOR", "none")) @@ -347,42 +346,6 @@ def test_build_airflow_graph_with_override_profile_config(): assert generated_parent_profile_config.profile_mapping.profile_args["schema"] == "public" -@pytest.mark.integration -@patch("airflow.hooks.base.BaseHook.get_connection", new=MagicMock()) -def 
test_build_airflow_graph_with_dbt_compile_task(): - bigquery_profile_config = ProfileConfig( - profile_name="my-bigquery-db", - target_name="dev", - profile_mapping=GoogleCloudServiceAccountFileProfileMapping( - conn_id="fake_conn", profile_args={"dataset": "release_17"} - ), - ) - with DAG("test-id-dbt-compile", start_date=datetime(2022, 1, 1)) as dag: - task_args = { - "project_dir": SAMPLE_PROJ_PATH, - "conn_id": "fake_conn", - "profile_config": bigquery_profile_config, - } - render_config = RenderConfig( - select=["tag:some"], - test_behavior=TestBehavior.AFTER_ALL, - source_rendering_behavior=SOURCE_RENDERING_BEHAVIOR, - ) - build_airflow_graph( - nodes=sample_nodes, - dag=dag, - execution_mode=ExecutionMode.AIRFLOW_ASYNC, - test_indirect_selection=TestIndirectSelection.EAGER, - task_args=task_args, - dbt_project_name="astro_shop", - render_config=render_config, - ) - - task_ids = [task.task_id for task in dag.tasks] - assert DBT_COMPILE_TASK_ID in task_ids - assert DBT_COMPILE_TASK_ID in dag.tasks[0].upstream_task_ids - - def test_calculate_operator_class(): class_module_import_path = calculate_operator_class(execution_mode=ExecutionMode.KUBERNETES, dbt_class="DbtSeed") assert class_module_import_path == "cosmos.operators.kubernetes.DbtSeedKubernetesOperator" diff --git a/tests/dbt_adapters/test_bigquery.py b/tests/dbt_adapters/test_bigquery.py new file mode 100644 index 000000000..d8921d059 --- /dev/null +++ b/tests/dbt_adapters/test_bigquery.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from unittest.mock import Mock + +import pytest + +from cosmos.dbt_adapters.bigquery import _associate_bigquery_async_op_args, _mock_bigquery_adapter +from cosmos.exceptions import CosmosValueError + + +@pytest.fixture +def async_operator_mock(): + """Fixture to create a mock async operator object.""" + return Mock() + + +@pytest.mark.integration +def test_mock_bigquery_adapter(): + """Test _mock_bigquery_adapter to verify it modifies BigQueryConnectionManager.execute.""" + from dbt.adapters.bigquery.connections import BigQueryConnectionManager + + _mock_bigquery_adapter() + + assert hasattr(BigQueryConnectionManager, "execute") + + response, table = BigQueryConnectionManager.execute(None, sql="SELECT 1") + assert response._message == "mock_bigquery_adapter_response" + assert table is not None + + +def test_associate_bigquery_async_op_args_valid(async_operator_mock): + """Test _associate_bigquery_async_op_args correctly configures the async operator.""" + sql_query = "SELECT * FROM test_table" + + result = _associate_bigquery_async_op_args(async_operator_mock, sql=sql_query) + + assert result == async_operator_mock + assert result.configuration["query"]["query"] == sql_query + assert result.configuration["query"]["useLegacySql"] is False + + +def test_associate_bigquery_async_op_args_missing_sql(async_operator_mock): + """Test _associate_bigquery_async_op_args raises CosmosValueError when 'sql' is missing.""" + with pytest.raises(CosmosValueError, match="Keyword argument 'sql' is required for BigQuery Async operator"): + _associate_bigquery_async_op_args(async_operator_mock) diff --git a/tests/dbt_adapters/test_init.py b/tests/dbt_adapters/test_init.py new file mode 100644 index 000000000..ce272e333 --- /dev/null +++ b/tests/dbt_adapters/test_init.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from unittest.mock import Mock + +import pytest + +from cosmos.dbt_adapters import associate_async_operator_args + + +def test_associate_async_operator_args_invalid_profile(): + 
"""Test associate_async_operator_args raises KeyError for an invalid profile type.""" + async_operator_mock = Mock() + + with pytest.raises(KeyError): + associate_async_operator_args(async_operator_mock, "invalid_profile") diff --git a/tests/operators/_asynchronous/test_base.py b/tests/operators/_asynchronous/test_base.py index c01bbd866..f3e49a621 100644 --- a/tests/operators/_asynchronous/test_base.py +++ b/tests/operators/_asynchronous/test_base.py @@ -1,12 +1,13 @@ -from unittest.mock import patch +from __future__ import annotations + +from unittest.mock import MagicMock, patch import pytest -from cosmos import ProfileConfig +from cosmos.config import ProfileConfig from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator, _create_async_operator_class from cosmos.operators._asynchronous.bigquery import DbtRunAirflowAsyncBigqueryOperator from cosmos.operators.local import DbtRunLocalOperator -from cosmos.profiles import get_automatic_profile_mapping @pytest.mark.parametrize( @@ -25,30 +26,45 @@ def test_create_async_operator_class_success(profile_type, dbt_class, expected_o assert operator_class == expected_operator_class -@patch("cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator.drop_table_sql") -@patch("cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator.get_remote_sql") -@patch("cosmos.operators._asynchronous.bigquery.BigQueryInsertJobOperator.execute") -def test_factory_async_class(mock_execute, get_remote_sql, drop_table_sql, mock_bigquery_conn): - profile_mapping = get_automatic_profile_mapping( - mock_bigquery_conn.conn_id, - profile_args={ - "dataset": "my_dataset", - }, - ) - bigquery_profile_config = ProfileConfig( - profile_name="my_profile", target_name="dev", profile_mapping=profile_mapping - ) - factory_class = DbtRunAirflowAsyncFactoryOperator( - task_id="run", - project_dir="/tmp", - profile_config=bigquery_profile_config, - full_refresh=True, - extra_context={"dbt_node_config": {"resource_name": "customer"}}, - ) +@pytest.fixture +def profile_config_mock(): + """Fixture to create a mock ProfileConfig.""" + mock_config = MagicMock(spec=ProfileConfig) + mock_config.get_profile_type.return_value = "bigquery" + return mock_config + + +def test_create_async_operator_class_valid(): + """Test _create_async_operator_class returns the correct async operator class if available.""" + with patch("cosmos.operators._asynchronous.base.importlib.import_module") as mock_import: + mock_class = MagicMock() + mock_import.return_value = MagicMock() + setattr(mock_import.return_value, "DbtRunAirflowAsyncBigqueryOperator", mock_class) + + result = _create_async_operator_class("bigquery", "DbtRun") + assert result == mock_class - async_operator = factory_class.create_async_operator() - assert async_operator == DbtRunAirflowAsyncBigqueryOperator - factory_class.execute(context={}) +def test_create_async_operator_class_fallback(): + """Test _create_async_operator_class falls back to DbtRunLocalOperator when import fails.""" + with patch("cosmos.operators._asynchronous.base.importlib.import_module", side_effect=ModuleNotFoundError): + result = _create_async_operator_class("bigquery", "DbtRun") + assert result == DbtRunLocalOperator + + +class MockAsyncOperator(DbtRunLocalOperator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +@patch("cosmos.operators._asynchronous.base._create_async_operator_class", return_value=MockAsyncOperator) +def 
test_dbt_run_airflow_async_factory_operator_init(mock_create_class, profile_config_mock): + + operator = DbtRunAirflowAsyncFactoryOperator( + task_id="test_task", + project_dir="some/path", + profile_config=profile_config_mock, + ) - mock_execute.assert_called_once_with({}) + assert operator is not None + assert isinstance(operator, MockAsyncOperator) diff --git a/tests/operators/_asynchronous/test_bigquery.py b/tests/operators/_asynchronous/test_bigquery.py index 6eb532107..34182784b 100644 --- a/tests/operators/_asynchronous/test_bigquery.py +++ b/tests/operators/_asynchronous/test_bigquery.py @@ -1,96 +1,71 @@ +from __future__ import annotations + from unittest.mock import MagicMock, patch import pytest -from airflow import __version__ as airflow_version -from packaging import version +from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator -from cosmos import ProfileConfig -from cosmos.exceptions import CosmosValueError +from cosmos.config import ProfileConfig from cosmos.operators._asynchronous.bigquery import DbtRunAirflowAsyncBigqueryOperator -from cosmos.profiles import get_automatic_profile_mapping -from cosmos.settings import AIRFLOW_IO_AVAILABLE - -def test_bigquery_without_refresh(mock_bigquery_conn): - profile_mapping = get_automatic_profile_mapping( - mock_bigquery_conn.conn_id, - profile_args={ - "dataset": "my_dataset", - }, - ) - bigquery_profile_config = ProfileConfig( - profile_name="my_profile", target_name="dev", profile_mapping=profile_mapping - ) - operator = DbtRunAirflowAsyncBigqueryOperator( - task_id="test_task", project_dir="/tmp", profile_config=bigquery_profile_config - ) - operator.extra_context = { - "dbt_node_config": {"file_path": "/some/path/to/file.sql"}, - "dbt_dag_task_group_identifier": "task_group_1", - } - with pytest.raises(CosmosValueError, match="The async execution only supported for full_refresh"): - operator.execute({}) +@pytest.fixture +def profile_config_mock(): + """Fixture to create a mock ProfileConfig.""" + mock_config = MagicMock(spec=ProfileConfig) + mock_config.get_profile_type.return_value = "bigquery" + mock_config.profile_mapping.conn_id = "google_cloud_default" + mock_config.profile_mapping.profile = {"project": "test_project", "dataset": "test_dataset"} + return mock_config -def test_get_remote_sql_airflow_io_unavailable(mock_bigquery_conn): - profile_mapping = get_automatic_profile_mapping( - mock_bigquery_conn.conn_id, - profile_args={ - "dataset": "my_dataset", - }, - ) - bigquery_profile_config = ProfileConfig( - profile_name="my_profile", target_name="dev", profile_mapping=profile_mapping - ) +def test_dbt_run_airflow_async_bigquery_operator_init(profile_config_mock): + """Test DbtRunAirflowAsyncBigqueryOperator initializes with correct attributes.""" operator = DbtRunAirflowAsyncBigqueryOperator( - task_id="test_task", project_dir="/tmp", profile_config=bigquery_profile_config + task_id="test_task", + project_dir="/path/to/project", + profile_config=profile_config_mock, + dbt_kwargs={"task_id": "test_task"}, ) - operator.extra_context = { - "dbt_node_config": {"file_path": "/some/path/to/file.sql"}, - "dbt_dag_task_group_identifier": "task_group_1", - } + assert isinstance(operator, DbtRunAirflowAsyncBigqueryOperator) + assert isinstance(operator, BigQueryInsertJobOperator) + assert operator.project_dir == "/path/to/project" + assert operator.profile_config == profile_config_mock + assert operator.gcp_conn_id == "google_cloud_default" + assert operator.gcp_project == "test_project" + assert 
operator.dataset == "test_dataset" - if not AIRFLOW_IO_AVAILABLE: - with pytest.raises( - CosmosValueError, match="Cosmos async support is only available starting in Airflow 2.8 or later." - ): - operator.get_remote_sql() - -@pytest.mark.skipif( - version.parse(airflow_version) < version.parse("2.8"), - reason="Airflow object storage supported 2.8 release", -) -def test_get_remote_sql_success(mock_bigquery_conn): - profile_mapping = get_automatic_profile_mapping( - mock_bigquery_conn.conn_id, - profile_args={ - "dataset": "my_dataset", - }, - ) - bigquery_profile_config = ProfileConfig( - profile_name="my_profile", target_name="dev", profile_mapping=profile_mapping - ) +def test_dbt_run_airflow_async_bigquery_operator_base_cmd(profile_config_mock): + """Test base_cmd property returns the correct dbt command.""" operator = DbtRunAirflowAsyncBigqueryOperator( - task_id="test_task", project_dir="/tmp", profile_config=bigquery_profile_config + task_id="test_task", + project_dir="/path/to/project", + profile_config=profile_config_mock, + dbt_kwargs={"task_id": "test_task"}, ) + assert operator.base_cmd == ["run"] - operator.extra_context = { - "dbt_node_config": {"file_path": "/some/path/to/file.sql"}, - "dbt_dag_task_group_identifier": "task_group_1", - } - operator.project_dir = "/tmp" - - mock_object_storage_path = MagicMock() - mock_file = MagicMock() - mock_file.read.return_value = "SELECT * FROM table" - mock_object_storage_path.open.return_value.__enter__.return_value = mock_file +@patch.object(DbtRunAirflowAsyncBigqueryOperator, "build_and_run_cmd") +def test_dbt_run_airflow_async_bigquery_operator_execute(mock_build_and_run_cmd, profile_config_mock): + """Test execute calls build_and_run_cmd with correct parameters.""" + operator = DbtRunAirflowAsyncBigqueryOperator( + task_id="test_task", + project_dir="/path/to/project", + profile_config=profile_config_mock, + dbt_kwargs={"task_id": "test_task"}, + ) - with patch("airflow.io.path.ObjectStoragePath", return_value=mock_object_storage_path): - remote_sql = operator.get_remote_sql() + mock_context = MagicMock() + operator.execute(mock_context) - assert remote_sql == "SELECT * FROM table" - mock_object_storage_path.open.assert_called_once() + mock_build_and_run_cmd.assert_called_once_with( + context=mock_context, + run_as_async=True, + async_context={ + "profile_type": "bigquery", + "async_operator": BigQueryInsertJobOperator, + }, + ) diff --git a/tests/operators/test_aws_eks.py b/tests/operators/test_aws_eks.py index bca007c4d..86f9409b2 100644 --- a/tests/operators/test_aws_eks.py +++ b/tests/operators/test_aws_eks.py @@ -38,7 +38,6 @@ def test_dbt_kubernetes_build_command(): Since we know that the KubernetesOperator is tested, we can just test that the command is built correctly and added to the "arguments" parameter. 
""" - result_map = { "ls": DbtLSAwsEksOperator(**base_kwargs), "run": DbtRunAwsEksOperator(**base_kwargs), diff --git a/tests/operators/test_base.py b/tests/operators/test_base.py index e97c2d396..7394a7df9 100644 --- a/tests/operators/test_base.py +++ b/tests/operators/test_base.py @@ -1,12 +1,14 @@ +import inspect import sys from datetime import datetime from unittest.mock import patch import pytest +from airflow.models import BaseOperator from airflow.utils.context import Context from cosmos.operators.base import ( - AbstractDbtBaseOperator, + AbstractDbtBase, DbtBuildMixin, DbtCompileMixin, DbtLSMixin, @@ -22,13 +24,13 @@ (sys.version_info.major, sys.version_info.minor) == (3, 12), reason="The error message for the abstract class instantiation seems to have changed between Python 3.11 and 3.12", ) -def test_dbt_base_operator_is_abstract(): +def test_dbt_base_is_abstract(): """Tests that the abstract base operator cannot be instantiated since the base_cmd is not defined.""" expected_error = ( - "Can't instantiate abstract class AbstractDbtBaseOperator with abstract methods base_cmd, build_and_run_cmd" + "Can't instantiate abstract class AbstractDbtBase with abstract methods base_cmd, build_and_run_cmd" ) with pytest.raises(TypeError, match=expected_error): - AbstractDbtBaseOperator() + AbstractDbtBase(project_dir="project_dir") @pytest.mark.skipif( @@ -38,21 +40,21 @@ def test_dbt_base_operator_is_abstract(): def test_dbt_base_operator_is_abstract_py12(): """Tests that the abstract base operator cannot be instantiated since the base_cmd is not defined.""" expected_error = ( - "Can't instantiate abstract class AbstractDbtBaseOperator without an implementation for abstract methods " + "Can't instantiate abstract class AbstractDbtBase without an implementation for abstract methods " "'base_cmd', 'build_and_run_cmd'" ) with pytest.raises(TypeError, match=expected_error): - AbstractDbtBaseOperator() + AbstractDbtBase(project_dir="project_dir") @pytest.mark.parametrize("cmd_flags", [["--some-flag"], []]) -@patch("cosmos.operators.base.AbstractDbtBaseOperator.build_and_run_cmd") +@patch("cosmos.operators.base.AbstractDbtBase.build_and_run_cmd") def test_dbt_base_operator_execute(mock_build_and_run_cmd, cmd_flags, monkeypatch): """Tests that the base operator execute method calls the build_and_run_cmd method with the expected arguments.""" - monkeypatch.setattr(AbstractDbtBaseOperator, "add_cmd_flags", lambda _: cmd_flags) - AbstractDbtBaseOperator.__abstractmethods__ = set() + monkeypatch.setattr(AbstractDbtBase, "add_cmd_flags", lambda _: cmd_flags) + AbstractDbtBase.__abstractmethods__ = set() - base_operator = AbstractDbtBaseOperator(task_id="fake_task", project_dir="fake_dir") + base_operator = AbstractDbtBase(task_id="fake_task", project_dir="fake_dir") base_operator.execute(context={}) mock_build_and_run_cmd.assert_called_once_with(context={}, cmd_flags=cmd_flags) @@ -61,7 +63,7 @@ def test_dbt_base_operator_execute(mock_build_and_run_cmd, cmd_flags, monkeypatc @patch("cosmos.operators.base.context_merge") def test_dbt_base_operator_context_merge_called(mock_context_merge): """Tests that the base operator execute method calls the context_merge method with the expected arguments.""" - base_operator = AbstractDbtBaseOperator( + base_operator = AbstractDbtBase( task_id="fake_task", project_dir="fake_dir", extra_context={"extra": "extra"}, @@ -125,7 +127,7 @@ def test_dbt_base_operator_context_merge( expected_context, ): """Tests that the base operator execute method calls and update 
context""" - base_operator = AbstractDbtBaseOperator( + base_operator = AbstractDbtBase( task_id="fake_task", project_dir="fake_dir", extra_context=extra_context, @@ -173,5 +175,21 @@ def test_dbt_mixin_add_cmd_flags_run_operator(args, expected_flags): def test_abstract_dbt_base_operator_append_env_is_false_by_default(): """Tests that the append_env attribute is set to False by default.""" - base_operator = AbstractDbtBaseOperator(task_id="fake_task", project_dir="fake_dir") + AbstractDbtBase.__abstractmethods__ = set() + base_operator = AbstractDbtBase(task_id="fake_task", project_dir="fake_dir") assert base_operator.append_env is False + + +def test_abstract_dbt_base_is_not_airflow_base_operator(): + AbstractDbtBase.__abstractmethods__ = set() + base_operator = AbstractDbtBase(task_id="fake_task", project_dir="fake_dir") + assert not isinstance(base_operator, BaseOperator) + + +def test_abstract_dbt_base_init_no_super(): + """Test that super().__init__ is not called in AbstractDbtBase""" + init_method = getattr(AbstractDbtBase, "__init__", None) + assert init_method is not None + + source = inspect.getsource(init_method) + assert "super().__init__" not in source diff --git a/tests/operators/test_kubernetes.py b/tests/operators/test_kubernetes.py index aee415f26..0562e28ce 100644 --- a/tests/operators/test_kubernetes.py +++ b/tests/operators/test_kubernetes.py @@ -191,20 +191,27 @@ def test_dbt_kubernetes_build_command(): not module_available, reason="Kubernetes module `airflow.providers.cncf.kubernetes.utils.pod_manager` not available" ) def test_dbt_test_kubernetes_operator_constructor(additional_kwargs, expected_results): + # TODO: Refactor this test so that the asserts test according to the input parameters. test_operator = DbtTestKubernetesOperator( on_warning_callback=(lambda **kwargs: None), **additional_kwargs, **base_kwargs ) - print(additional_kwargs, test_operator.__dict__) - assert isinstance(test_operator.on_success_callback, list) - assert isinstance(test_operator.on_failure_callback, list) - assert test_operator._handle_warnings in test_operator.on_success_callback - assert test_operator._cleanup_pod in test_operator.on_failure_callback - assert len(test_operator.on_success_callback) == expected_results[0] - assert len(test_operator.on_failure_callback) == expected_results[1] + assert isinstance(test_operator.on_success_callback, list) or test_operator.on_success_callback is None + assert isinstance(test_operator.on_failure_callback, list) or test_operator.on_failure_callback is None + + if test_operator.on_success_callback is not None: + assert test_operator._handle_warnings in test_operator.on_success_callback + assert len(test_operator.on_success_callback) == expected_results[0] + + if test_operator.on_failure_callback is not None: + assert test_operator._cleanup_pod in test_operator.on_failure_callback + assert len(test_operator.on_failure_callback) == expected_results[1] + assert test_operator.is_delete_operator_pod_original == expected_results[2] - assert test_operator.on_finish_action_original == OnFinishAction(expected_results[3]) + + expected_action = OnFinishAction(expected_results[3]) + assert test_operator.on_finish_action_original == expected_action @pytest.mark.parametrize( @@ -247,20 +254,28 @@ def test_dbt_test_kubernetes_operator_constructor(additional_kwargs, expected_re not module_available, reason="Kubernetes module `airflow.providers.cncf.kubernetes.utils.pod_manager` not available" ) def 
test_dbt_source_kubernetes_operator_constructor(additional_kwargs, expected_results): + # TODO: Refactor this test so that the asserts test according to the input parameters. source_operator = DbtSourceKubernetesOperator( on_warning_callback=(lambda **kwargs: None), **additional_kwargs, **base_kwargs ) print(additional_kwargs, source_operator.__dict__) - assert isinstance(source_operator.on_success_callback, list) - assert isinstance(source_operator.on_failure_callback, list) - assert source_operator._handle_warnings in source_operator.on_success_callback - assert source_operator._cleanup_pod in source_operator.on_failure_callback - assert len(source_operator.on_success_callback) == expected_results[0] - assert len(source_operator.on_failure_callback) == expected_results[1] + assert isinstance(source_operator.on_success_callback, list) or source_operator.on_success_callback is None + assert isinstance(source_operator.on_failure_callback, list) or source_operator.on_failure_callback is None + + if source_operator.on_success_callback is not None: + assert source_operator._handle_warnings in source_operator.on_success_callback + assert len(source_operator.on_success_callback) == expected_results[0] + + if source_operator.on_failure_callback is not None: + assert source_operator._cleanup_pod in source_operator.on_failure_callback + assert len(source_operator.on_failure_callback) == expected_results[1] + assert source_operator.is_delete_operator_pod_original == expected_results[2] - assert source_operator.on_finish_action_original == OnFinishAction(expected_results[3]) + + expected_action = OnFinishAction(expected_results[3]) + assert source_operator.on_finish_action_original == expected_action class FakePodManager: diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 69164a194..34c34d895 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -27,6 +27,7 @@ from cosmos.exceptions import CosmosDbtRunError, CosmosValueError from cosmos.hooks.subprocess import FullOutputSubprocessResult from cosmos.operators.local import ( + AbstractDbtLocalBase, DbtBuildLocalOperator, DbtCloneLocalOperator, DbtCompileLocalOperator, @@ -776,47 +777,89 @@ def test_store_compiled_sql() -> None: ( DbtSeedLocalOperator, {"full_refresh": True}, - {"context": {}, "env": {}, "cmd_flags": ["seed", "--full-refresh"]}, + { + "context": {}, + "env": {}, + "cmd_flags": ["seed", "--full-refresh"], + "run_as_async": False, + "async_context": None, + }, ), ( DbtBuildLocalOperator, {"full_refresh": True}, - {"context": {}, "env": {}, "cmd_flags": ["build", "--full-refresh"]}, + { + "context": {}, + "env": {}, + "cmd_flags": ["build", "--full-refresh"], + "run_as_async": False, + "async_context": None, + }, ), ( DbtRunLocalOperator, {"full_refresh": True}, - {"context": {}, "env": {}, "cmd_flags": ["run", "--full-refresh"]}, + { + "context": {}, + "env": {}, + "cmd_flags": ["run", "--full-refresh"], + "run_as_async": False, + "async_context": None, + }, ), ( DbtCloneLocalOperator, {"full_refresh": True}, - {"context": {}, "env": {}, "cmd_flags": ["clone", "--full-refresh"]}, + { + "context": {}, + "env": {}, + "cmd_flags": ["clone", "--full-refresh"], + "run_as_async": False, + "async_context": None, + }, ), ( DbtTestLocalOperator, {}, - {"context": {}, "env": {}, "cmd_flags": ["test"]}, + {"context": {}, "env": {}, "cmd_flags": ["test"], "run_as_async": False, "async_context": None}, ), ( DbtTestLocalOperator, {"select": []}, - {"context": {}, "env": {}, "cmd_flags": 
["test"]}, + {"context": {}, "env": {}, "cmd_flags": ["test"], "run_as_async": False, "async_context": None}, ), ( DbtTestLocalOperator, {"full_refresh": True, "select": ["tag:daily"], "exclude": ["tag:disabled"]}, - {"context": {}, "env": {}, "cmd_flags": ["test", "--select", "tag:daily", "--exclude", "tag:disabled"]}, + { + "context": {}, + "env": {}, + "cmd_flags": ["test", "--select", "tag:daily", "--exclude", "tag:disabled"], + "run_as_async": False, + "async_context": None, + }, ), ( DbtTestLocalOperator, {"full_refresh": True, "selector": "nightly_snowplow"}, - {"context": {}, "env": {}, "cmd_flags": ["test", "--selector", "nightly_snowplow"]}, + { + "context": {}, + "env": {}, + "cmd_flags": ["test", "--selector", "nightly_snowplow"], + "run_as_async": False, + "async_context": None, + }, ), ( DbtRunOperationLocalOperator, {"args": {"days": 7, "dry_run": True}, "macro_name": "bla"}, - {"context": {}, "env": {}, "cmd_flags": ["run-operation", "bla", "--args", "days: 7\ndry_run: true\n"]}, + { + "context": {}, + "env": {}, + "cmd_flags": ["run-operation", "bla", "--args", "days: 7\ndry_run: true\n"], + "run_as_async": False, + "async_context": None, + }, ), ], ) @@ -1317,3 +1360,92 @@ def test_upload_compiled_sql_should_upload(mock_configure_remote, mock_object_st expected_dest_path = f"mock_remote_path/test_dag/compiled/{rel_path.lstrip('/')}" mock_object_storage_path.assert_any_call(expected_dest_path, conn_id="mock_conn_id") mock_object_storage_path.return_value.copy.assert_any_call(mock_object_storage_path.return_value) + + +MOCK_ADAPTER_CALLABLE_MAP = { + "snowflake": MagicMock(), + "bigquery": MagicMock(), +} + + +@pytest.fixture +def mock_adapter_map(monkeypatch): + monkeypatch.setattr( + "cosmos.operators.local.PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP", + MOCK_ADAPTER_CALLABLE_MAP, + ) + + +def test_mock_dbt_adapter_valid_context(mock_adapter_map): + """ + Test that the _mock_dbt_adapter method calls the correct mock adapter function + when provided with a valid async_context. + """ + async_context = { + "async_operator": MagicMock(), + "profile_type": "bigquery", + } + AbstractDbtLocalBase.__abstractmethods__ = set() + operator = AbstractDbtLocalBase(task_id="test_task", project_dir="test_project", profile_config=MagicMock()) + operator._mock_dbt_adapter(async_context) + + MOCK_ADAPTER_CALLABLE_MAP["bigquery"].assert_called_once() + + +def test_mock_dbt_adapter_missing_async_context(): + """ + Test that the _mock_dbt_adapter method raises a CosmosValueError + when async_context is None. + """ + AbstractDbtLocalBase.__abstractmethods__ = set() + operator = AbstractDbtLocalBase(task_id="test_task", project_dir="test_project", profile_config=MagicMock()) + with pytest.raises(CosmosValueError, match="`async_context` is necessary for running the model asynchronously"): + operator._mock_dbt_adapter(None) + + +def test_mock_dbt_adapter_missing_async_operator(): + """ + Test that the _mock_dbt_adapter method raises a CosmosValueError + when async_operator is missing in async_context. 
+ """ + async_context = { + "profile_type": "snowflake", + } + AbstractDbtLocalBase.__abstractmethods__ = set() + operator = AbstractDbtLocalBase(task_id="test_task", project_dir="test_project", profile_config=MagicMock()) + with pytest.raises( + CosmosValueError, match="`async_operator` needs to be specified in `async_context` when running as async" + ): + operator._mock_dbt_adapter(async_context) + + +def test_mock_dbt_adapter_missing_profile_type(): + """ + Test that the _mock_dbt_adapter method raises a CosmosValueError + when profile_type is missing in async_context. + """ + async_context = { + "async_operator": MagicMock(), + } + AbstractDbtLocalBase.__abstractmethods__ = set() + operator = AbstractDbtLocalBase(task_id="test_task", project_dir="test_project", profile_config=MagicMock()) + with pytest.raises(CosmosValueError, match="`profile_type` needs to be specified in `async_context`"): + operator._mock_dbt_adapter(async_context) + + +def test_mock_dbt_adapter_unsupported_profile_type(mock_adapter_map): + """ + Test that the _mock_dbt_adapter method raises a CosmosValueError + when the profile_type is not supported. + """ + async_context = { + "async_operator": MagicMock(), + "profile_type": "unsupported_profile", + } + AbstractDbtLocalBase.__abstractmethods__ = set() + operator = AbstractDbtLocalBase(task_id="test_task", project_dir="test_project", profile_config=MagicMock()) + with pytest.raises( + CosmosValueError, + match="Mock adapter callable function not available for profile_type unsupported_profile", + ): + operator._mock_dbt_adapter(async_context)