diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index b182bacf7..392b4f78b 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -19,6 +19,7 @@ from .snowflake.user_pass import SnowflakeUserPasswordProfileMapping from .snowflake.user_privatekey import SnowflakePrivateKeyPemProfileMapping from .spark.thrift import SparkThriftProfileMapping +from .teradata.user_pass import TeradataUserPasswordProfileMapping from .trino.certificate import TrinoCertificateProfileMapping from .trino.jwt import TrinoJWTProfileMapping from .trino.ldap import TrinoLDAPProfileMapping @@ -39,6 +40,7 @@ SnowflakePrivateKeyPemProfileMapping, SparkThriftProfileMapping, ExasolUserPasswordProfileMapping, + TeradataUserPasswordProfileMapping, TrinoLDAPProfileMapping, TrinoCertificateProfileMapping, TrinoJWTProfileMapping, @@ -79,6 +81,7 @@ def get_automatic_profile_mapping( "SnowflakeEncryptedPrivateKeyFilePemProfileMapping", "SparkThriftProfileMapping", "ExasolUserPasswordProfileMapping", + "TeradataUserPasswordProfileMapping", "TrinoLDAPProfileMapping", "TrinoCertificateProfileMapping", "TrinoJWTProfileMapping", diff --git a/cosmos/profiles/teradata/__init__.py b/cosmos/profiles/teradata/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cosmos/profiles/teradata/user_pass.py b/cosmos/profiles/teradata/user_pass.py new file mode 100644 index 000000000..059e4a9f0 --- /dev/null +++ b/cosmos/profiles/teradata/user_pass.py @@ -0,0 +1,51 @@ +"""Maps Airflow Teradata connections to dbt profiles if they use a user/password.""" + +from __future__ import annotations + +from typing import Any + +from ..base import BaseProfileMapping + + +class TeradataUserPasswordProfileMapping(BaseProfileMapping): + """ + Maps Airflow Teradata connections using user + password authentication to dbt profiles. 
+ https://docs.getdbt.com/docs/core/connect-data-platform/teradata-setup + https://airflow.apache.org/docs/apache-airflow-providers-teradata/stable/connections/teradata.html + """ + + airflow_connection_type: str = "teradata" + dbt_profile_type: str = "teradata" + is_community = True + + required_fields = [ + "host", + "user", + "password", + ] + secret_fields = [ + "password", + ] + airflow_param_mapping = { + "host": "host", + "user": "login", + "password": "password", + "schema": "schema", + "tmode": "extra.tmode", + } + + @property + def profile(self) -> dict[str, Any]: + """Gets profile. The password is stored in an environment variable.""" + profile = { + **self.mapped_params, + **self.profile_args, + # password should always get set as env var + "password": self.get_env_var_format("password"), + } + # schema is not mandatory in teradata. In teradata the user itself is a database, so if schema is not set + # in either the airflow connection or profile_args, the user is treated as the schema. + if "schema" not in self.profile_args and self.mapped_params.get("schema") is None: + profile["schema"] = profile["user"] + + return self.filter_null(profile) diff --git a/pyproject.toml b/pyproject.toml index 6c518613b..5cbc93a98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dbt-all = [ "dbt-redshift", "dbt-snowflake", "dbt-spark", + "dbt-teradata", "dbt-vertica", ] dbt-athena = ["dbt-athena-community", "apache-airflow-providers-amazon>=8.0.0"] @@ -62,6 +63,7 @@ dbt-postgres = ["dbt-postgres"] dbt-redshift = ["dbt-redshift"] dbt-snowflake = ["dbt-snowflake"] dbt-spark = ["dbt-spark"] +dbt-teradata = ["dbt-teradata"] dbt-vertica = ["dbt-vertica<=1.5.4"] openlineage = ["openlineage-integration-common!=1.15.0", "openlineage-airflow"] all = ["astronomer-cosmos[dbt-all]", "astronomer-cosmos[openlineage]"] diff --git a/tests/profiles/teradata/__init__.py b/tests/profiles/teradata/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/tests/profiles/teradata/test_teradata_user_pass.py b/tests/profiles/teradata/test_teradata_user_pass.py new file mode 100644 index 000000000..ff28977fe --- /dev/null +++ b/tests/profiles/teradata/test_teradata_user_pass.py @@ -0,0 +1,176 @@ +"""Tests for the teradata profile.""" + +from unittest.mock import patch + +import pytest +from airflow.models.connection import Connection + +from cosmos.profiles import get_automatic_profile_mapping +from cosmos.profiles.teradata.user_pass import TeradataUserPasswordProfileMapping + + +@pytest.fixture() +def mock_teradata_conn(): # type: ignore + """ + Patches BaseHook.get_connection to return a Teradata connection. + """ + conn = Connection( + conn_id="my_teradata_connection", + conn_type="teradata", + host="my_host", + login="my_user", + password="my_password", + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +@pytest.fixture() +def mock_teradata_conn_custom_tmode(): # type: ignore + """ + Patches BaseHook.get_connection to return a Teradata connection with a schema and a custom tmode. + """ + conn = Connection( + conn_id="my_teradata_connection", + conn_type="teradata", + host="my_host", + login="my_user", + password="my_password", + schema="my_database", + extra='{"tmode": "TERA"}', + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +def test_connection_claiming() -> None: + """ + Tests that the teradata profile mapping claims the correct connection type. 
+ """ + # should only claim when: + # - conn_type == teradata + # and the following exist: + # - host + # - user + # - password + potential_values: dict[str, str] = { + "conn_type": "teradata", + "host": "my_host", + "login": "my_user", + "password": "my_password", + } + + # if we're missing any of the values, it shouldn't claim + for key in potential_values: + values = potential_values.copy() + del values[key] + conn = Connection(**values) # type: ignore + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = TeradataUserPasswordProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # Even if there is no schema, it should claim, since the user itself is used as the schema in teradata + conn = Connection(**{k: v for k, v in potential_values.items() if k != "schema"}) + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = TeradataUserPasswordProfileMapping(conn, {"schema": None}) + assert profile_mapping.can_claim_connection() + # if we have them all, it should claim + conn = Connection(**potential_values) # type: ignore + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = TeradataUserPasswordProfileMapping(conn, {"schema": "my_schema"}) + assert profile_mapping.can_claim_connection() + + +def test_profile_mapping_selected( + mock_teradata_conn: Connection, +) -> None: + """ + Tests that the correct profile mapping is selected. 
+ """ + profile_mapping = get_automatic_profile_mapping( + mock_teradata_conn.conn_id, + ) + assert isinstance(profile_mapping, TeradataUserPasswordProfileMapping) + + +def test_profile_mapping_keeps_port(mock_teradata_conn: Connection) -> None: + # port is not handled in airflow connection so adding it as profile_args + profile = TeradataUserPasswordProfileMapping(mock_teradata_conn.conn_id, profile_args={"port": 1025}) + assert profile.profile["port"] == 1025 + + +def test_profile_mapping_keeps_custom_tmode(mock_teradata_conn_custom_tmode: Connection) -> None: + profile = TeradataUserPasswordProfileMapping(mock_teradata_conn_custom_tmode.conn_id) + assert profile.profile["tmode"] == "TERA" + + +def test_profile_args( + mock_teradata_conn: Connection, +) -> None: + """ + Tests that the profile values get set correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_teradata_conn.conn_id, + profile_args={"schema": "my_database"}, + ) + assert profile_mapping.profile_args == { + "schema": "my_database", + } + + assert profile_mapping.profile == { + "type": mock_teradata_conn.conn_type, + "host": mock_teradata_conn.host, + "user": mock_teradata_conn.login, + "password": "{{ env_var('COSMOS_CONN_TERADATA_PASSWORD') }}", + "schema": "my_database", + } + + +def test_profile_args_overrides( + mock_teradata_conn: Connection, +) -> None: + """ + Tests that you can override the profile values. 
+ """ + profile_mapping = get_automatic_profile_mapping( + mock_teradata_conn.conn_id, + profile_args={"schema": "my_schema"}, + ) + assert profile_mapping.profile_args == { + "schema": "my_schema", + } + + assert profile_mapping.profile == { + "type": mock_teradata_conn.conn_type, + "host": mock_teradata_conn.host, + "user": mock_teradata_conn.login, + "password": "{{ env_var('COSMOS_CONN_TERADATA_PASSWORD') }}", + "schema": "my_schema", + } + + +def test_profile_env_vars( + mock_teradata_conn: Connection, +) -> None: + """ + Tests that the environment variables get set correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_teradata_conn.conn_id, + profile_args={"schema": "my_schema"}, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_TERADATA_PASSWORD": mock_teradata_conn.password, + } + + +def test_mock_profile() -> None: + """ + Tests that the mock profile host value gets set correctly. + """ + profile = TeradataUserPasswordProfileMapping("mock_conn_id") + assert profile.mock_profile.get("host") == "mock_value"