diff --git a/cosmos/operators/_asynchronous/base.py b/cosmos/operators/_asynchronous/base.py
index ddb518d53..89c3a5f84 100644
--- a/cosmos/operators/_asynchronous/base.py
+++ b/cosmos/operators/_asynchronous/base.py
@@ -14,6 +14,16 @@
 def _create_async_operator_class(profile_type: str, dbt_class: str) -> Any:
+    """
+    Dynamically constructs and returns an asynchronous operator class for the given profile type and dbt class name.
+
+    The function builds the class path of the asynchronous operator from the provided `profile_type` and
+    `dbt_class`, then attempts to import that class dynamically and return it. If the class cannot be found,
+    it falls back to returning the `DbtRunLocalOperator` class.
+
+    :param profile_type: The dbt profile type.
+    :param dbt_class: The dbt class name, e.g. DbtRun or DbtTest.
+    """
     execution_mode = ExecutionMode.AIRFLOW_ASYNC.value
     class_path = f"cosmos.operators._asynchronous.{profile_type}.{dbt_class}{_snake_case_to_camelcase(execution_mode)}{profile_type.capitalize()}Operator"
     try:
diff --git a/cosmos/operators/_asynchronous/databricks.py b/cosmos/operators/_asynchronous/databricks.py
index 956881227..d49fd0be0 100644
--- a/cosmos/operators/_asynchronous/databricks.py
+++ b/cosmos/operators/_asynchronous/databricks.py
@@ -1,4 +1,3 @@
-# pragma: no cover
 # TODO: Implement it
 
 from typing import Any
diff --git a/tests/conftest.py b/tests/conftest.py
index e69de29bb..d553fb7b4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -0,0 +1,24 @@
+import json
+from unittest.mock import patch
+
+import pytest
+from airflow.models.connection import Connection
+
+
+@pytest.fixture()
+def mock_bigquery_conn():  # type: ignore
+    """
+    Mocks and returns an Airflow BigQuery connection.
+    """
+    extra = {
+        "project": "my_project",
+        "key_path": "my_key_path.json",
+    }
+    conn = Connection(
+        conn_id="my_bigquery_connection",
+        conn_type="google_cloud_platform",
+        extra=json.dumps(extra),
+    )
+
+    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
+        yield conn
diff --git a/tests/operators/_asynchronous/test_base.py b/tests/operators/_asynchronous/test_base.py
index ce3f07cca..c01bbd866 100644
--- a/tests/operators/_asynchronous/test_base.py
+++ b/tests/operators/_asynchronous/test_base.py
@@ -1,8 +1,6 @@
-import json
 from unittest.mock import patch
 
 import pytest
-from airflow.models.connection import Connection
 
 from cosmos import ProfileConfig
 from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator, _create_async_operator_class
@@ -11,25 +9,6 @@
 from cosmos.profiles import get_automatic_profile_mapping
 
 
-@pytest.fixture()
-def mock_bigquery_conn():  # type: ignore
-    """
-    Mocks and returns an Airflow BigQuery connection.
-    """
-    extra = {
-        "project": "my_project",
-        "key_path": "my_key_path.json",
-    }
-    conn = Connection(
-        conn_id="my_bigquery_connection",
-        conn_type="google_cloud_platform",
-        extra=json.dumps(extra),
-    )
-
-    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
-        yield conn
-
-
 @pytest.mark.parametrize(
     "profile_type, dbt_class, expected_operator_class",
     [
diff --git a/tests/operators/_asynchronous/test_bigquery.py b/tests/operators/_asynchronous/test_bigquery.py
new file mode 100644
index 000000000..45bc8ecdc
--- /dev/null
+++ b/tests/operators/_asynchronous/test_bigquery.py
@@ -0,0 +1,68 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from cosmos import ProfileConfig
+from cosmos.exceptions import CosmosValueError
+from cosmos.operators._asynchronous.bigquery import DbtRunAirflowAsyncBigqueryOperator
+from cosmos.profiles import get_automatic_profile_mapping
+from cosmos.settings import AIRFLOW_IO_AVAILABLE
+
+
+def test_get_remote_sql_airflow_io_unavailable(mock_bigquery_conn):
+    profile_mapping = get_automatic_profile_mapping(
+        mock_bigquery_conn.conn_id,
+        profile_args={
+            "dataset": "my_dataset",
+        },
+    )
+    bigquery_profile_config = ProfileConfig(
+        profile_name="my_profile", target_name="dev", profile_mapping=profile_mapping
+    )
+    operator = DbtRunAirflowAsyncBigqueryOperator(
+        task_id="test_task", project_dir="/tmp", profile_config=bigquery_profile_config
+    )
+
+    operator.extra_context = {
+        "dbt_node_config": {"file_path": "/some/path/to/file.sql"},
+        "dbt_dag_task_group_identifier": "task_group_1",
+    }
+
+    if not AIRFLOW_IO_AVAILABLE:
+        with pytest.raises(
+            CosmosValueError, match="Cosmos async support is only available starting in Airflow 2.8 or later."
+        ):
+            operator.get_remote_sql()
+
+
+def test_get_remote_sql_success(mock_bigquery_conn):
+    profile_mapping = get_automatic_profile_mapping(
+        mock_bigquery_conn.conn_id,
+        profile_args={
+            "dataset": "my_dataset",
+        },
+    )
+    bigquery_profile_config = ProfileConfig(
+        profile_name="my_profile", target_name="dev", profile_mapping=profile_mapping
+    )
+    operator = DbtRunAirflowAsyncBigqueryOperator(
+        task_id="test_task", project_dir="/tmp", profile_config=bigquery_profile_config
+    )
+
+    operator.extra_context = {
+        "dbt_node_config": {"file_path": "/some/path/to/file.sql"},
+        "dbt_dag_task_group_identifier": "task_group_1",
+    }
+    operator.project_dir = "/tmp"
+
+    mock_object_storage_path = MagicMock()
+    mock_file = MagicMock()
+    mock_file.read.return_value = "SELECT * FROM table"
+
+    mock_object_storage_path.open.return_value.__enter__.return_value = mock_file
+
+    with patch("airflow.io.path.ObjectStoragePath", return_value=mock_object_storage_path):
+        remote_sql = operator.get_remote_sql()
+
+    assert remote_sql == "SELECT * FROM table"
+    mock_object_storage_path.open.assert_called_once()
diff --git a/tests/operators/_asynchronous/test_databricks.py b/tests/operators/_asynchronous/test_databricks.py
new file mode 100644
index 000000000..01e673d9c
--- /dev/null
+++ b/tests/operators/_asynchronous/test_databricks.py
@@ -0,0 +1,9 @@
+import pytest
+
+from cosmos.operators._asynchronous.databricks import DbtRunAirflowAsyncDatabricksOperator
+
+
+def test_execute_should_raise_not_implemented_error():
+    operator = DbtRunAirflowAsyncDatabricksOperator(task_id="test_task")
+    with pytest.raises(NotImplementedError):
+        operator.execute(context={})