Skip to content

Commit

Permalink
Add tests
Browse files — browse the repository at this point in the history
  • Loading branch information
pankajastro committed Jan 24, 2025
1 parent bed60ca commit 8b4518f
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 22 deletions.
10 changes: 10 additions & 0 deletions cosmos/operators/_asynchronous/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@


def _create_async_operator_class(profile_type: str, dbt_class: str) -> Any:
"""
Dynamically constructs and returns an asynchronous operator class for the given profile type and DBT class name.
The function constructs a class path string for an asynchronous operator, based on the provided `profile_type` and
`dbt_class`. It attempts to import the corresponding class dynamically and return it. If the class cannot be found,
it falls back to returning the `DbtRunLocalOperator` class.
:param profile_type: The dbt profile type
:param dbt_class: The dbt class name. Example DbtRun, DbtTest.
"""
execution_mode = ExecutionMode.AIRFLOW_ASYNC.value
class_path = f"cosmos.operators._asynchronous.{profile_type}.{dbt_class}{_snake_case_to_camelcase(execution_mode)}{profile_type.capitalize()}Operator"
try:
Expand Down
1 change: 0 additions & 1 deletion cosmos/operators/_asynchronous/databricks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# pragma: no cover
# TODO: Implement it

from typing import Any
Expand Down
24 changes: 24 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection


@pytest.fixture()
def mock_bigquery_conn(): # type: ignore
    """
    Provide a fake Airflow BigQuery connection for tests.

    Builds an in-memory ``Connection`` and patches
    ``BaseHook.get_connection`` so any code under test that resolves
    ``my_bigquery_connection`` receives this object instead of querying
    the Airflow metadata database.
    """
    conn = Connection(
        conn_id="my_bigquery_connection",
        conn_type="google_cloud_platform",
        extra=json.dumps(
            {
                "project": "my_project",
                "key_path": "my_key_path.json",
            }
        ),
    )

    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
        yield conn
21 changes: 0 additions & 21 deletions tests/operators/_asynchronous/test_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos import ProfileConfig
from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator, _create_async_operator_class
Expand All @@ -11,25 +9,6 @@
from cosmos.profiles import get_automatic_profile_mapping


@pytest.fixture()
def mock_bigquery_conn(): # type: ignore
    """
    Mocks and returns an Airflow BigQuery connection.

    Patches ``BaseHook.get_connection`` so code under test resolving
    ``my_bigquery_connection`` gets this in-memory Connection rather
    than reading the Airflow metadata database.
    """
    # Connection "extra" payload: minimal fields a BigQuery profile needs.
    extra = {
        "project": "my_project",
        "key_path": "my_key_path.json",
    }
    conn = Connection(
        conn_id="my_bigquery_connection",
        conn_type="google_cloud_platform",
        extra=json.dumps(extra),
    )

    # Yield inside the patch so the mock stays active for the whole test.
    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
        yield conn


@pytest.mark.parametrize(
"profile_type, dbt_class, expected_operator_class",
[
Expand Down
68 changes: 68 additions & 0 deletions tests/operators/_asynchronous/test_bigquery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from unittest.mock import MagicMock, patch

import pytest

from cosmos import ProfileConfig
from cosmos.exceptions import CosmosValueError
from cosmos.operators._asynchronous.bigquery import DbtRunAirflowAsyncBigqueryOperator
from cosmos.profiles import get_automatic_profile_mapping
from cosmos.settings import AIRFLOW_IO_AVAILABLE


@pytest.mark.skipif(
    AIRFLOW_IO_AVAILABLE,
    reason="Only meaningful when airflow.io is unavailable (Airflow < 2.8)",
)
def test_get_remote_sql_airflow_io_unavailable(mock_bigquery_conn):
    """``get_remote_sql`` must fail fast with CosmosValueError when airflow.io is missing.

    The original version wrapped the assertion in ``if not AIRFLOW_IO_AVAILABLE:``,
    which made the test silently pass (assert nothing) on Airflow >= 2.8.
    Using ``skipif`` makes that environment dependency explicit and keeps the
    assertion unconditional.
    """
    profile_mapping = get_automatic_profile_mapping(
        mock_bigquery_conn.conn_id,
        profile_args={
            "dataset": "my_dataset",
        },
    )
    bigquery_profile_config = ProfileConfig(
        profile_name="my_profile", target_name="dev", profile_mapping=profile_mapping
    )
    operator = DbtRunAirflowAsyncBigqueryOperator(
        task_id="test_task", project_dir="/tmp", profile_config=bigquery_profile_config
    )

    # Minimal node context get_remote_sql reads before it hits the airflow.io guard.
    operator.extra_context = {
        "dbt_node_config": {"file_path": "/some/path/to/file.sql"},
        "dbt_dag_task_group_identifier": "task_group_1",
    }

    with pytest.raises(
        CosmosValueError, match="Cosmos async support is only available starting in Airflow 2.8 or later."
    ):
        operator.get_remote_sql()


def test_get_remote_sql_success(mock_bigquery_conn):
    """``get_remote_sql`` returns the file content read via ObjectStoragePath.

    Patches ``airflow.io.path.ObjectStoragePath`` so no real storage is
    touched; the mocked file handle yields a fixed SQL string, which the
    operator is expected to return unchanged.
    """
    profile_mapping = get_automatic_profile_mapping(
        mock_bigquery_conn.conn_id,
        profile_args={
            "dataset": "my_dataset",
        },
    )
    bigquery_profile_config = ProfileConfig(
        profile_name="my_profile", target_name="dev", profile_mapping=profile_mapping
    )
    operator = DbtRunAirflowAsyncBigqueryOperator(
        task_id="test_task", project_dir="/tmp", profile_config=bigquery_profile_config
    )

    # Node context used to build the remote SQL path. Note: project_dir is
    # already "/tmp" via the constructor, so no re-assignment is needed here
    # (the original test set it a second time redundantly).
    operator.extra_context = {
        "dbt_node_config": {"file_path": "/some/path/to/file.sql"},
        "dbt_dag_task_group_identifier": "task_group_1",
    }

    # Mock the storage path so `open()` yields a file whose read() returns SQL.
    mock_object_storage_path = MagicMock()
    mock_file = MagicMock()
    mock_file.read.return_value = "SELECT * FROM table"
    mock_object_storage_path.open.return_value.__enter__.return_value = mock_file

    with patch("airflow.io.path.ObjectStoragePath", return_value=mock_object_storage_path):
        remote_sql = operator.get_remote_sql()

    assert remote_sql == "SELECT * FROM table"
    mock_object_storage_path.open.assert_called_once()
    # The content must come from reading the opened file, not elsewhere.
    mock_file.read.assert_called_once()
9 changes: 9 additions & 0 deletions tests/operators/_asynchronous/test_databricks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest

from cosmos.operators._asynchronous.databricks import DbtRunAirflowAsyncDatabricksOperator


def test_execute_should_raise_not_implemented_error():
    """The Databricks async operator is a stub: execute() must raise NotImplementedError."""
    task = DbtRunAirflowAsyncDatabricksOperator(task_id="test_task")

    with pytest.raises(NotImplementedError):
        task.execute(context={})

0 comments on commit 8b4518f

Please sign in to comment.