openlineage, redshift: do not call DB for schemas below Airflow 2.10
Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
mobuchowski committed Jun 12, 2024
1 parent d509abf commit 7e9622b
Showing 3 changed files with 140 additions and 98 deletions.
9 changes: 8 additions & 1 deletion airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -20,14 +20,19 @@
from typing import TYPE_CHECKING

import redshift_connector
+from packaging.version import Version
from redshift_connector import Connection as RedshiftConnection
from sqlalchemy import create_engine
from sqlalchemy.engine.url import URL

+from airflow import __version__ as AIRFLOW_VERSION
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook

+_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")
+
+
if TYPE_CHECKING:
from airflow.models.connection import Connection
from airflow.providers.openlineage.sqlparser import DatabaseInfo
@@ -257,4 +262,6 @@ def get_openlineage_database_dialect(self, connection: Connection) -> str:

def get_openlineage_default_schema(self) -> str | None:
"""Return current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
-    return self.get_first("SELECT CURRENT_SCHEMA();")[0]
+    if _IS_AIRFLOW_2_10_OR_HIGHER:
+        return self.get_first("SELECT CURRENT_SCHEMA();")[0]
+    return super().get_openlineage_default_schema()
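
The pattern in this hunk — compute a module-level flag once at import time, branch at call time — is worth spelling out. Below is a minimal runnable sketch of that idiom; `DemoHook`, the pinned version string, and the `"public"` fallback are hypothetical stand-ins, not code from this commit (the real hook falls back to `super().get_openlineage_default_schema()`):

```python
# Minimal sketch of the version-gating idiom above; DemoHook and the pinned
# version string are hypothetical stand-ins, not part of this commit.
from __future__ import annotations

from packaging.version import Version

AIRFLOW_VERSION = "2.9.2"  # the real code imports airflow.__version__

# base_version strips pre-release/dev suffixes, so "2.10.0.dev0" still
# compares as 2.10.0 and the gate opens on development builds too.
_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")


class DemoHook:
    def get_first(self, sql: str) -> list[str]:
        # Stand-in for DbApiHook.get_first: run the query, return the first row.
        return ["database.public"]

    def get_openlineage_default_schema(self) -> str | None:
        if _IS_AIRFLOW_2_10_OR_HIGHER:
            # On 2.10+ OpenLineage extraction runs in a separate process,
            # so issuing a live query here cannot disturb the task itself.
            return self.get_first("SELECT CURRENT_SCHEMA();")[0]
        # Below 2.10, avoid the DB round trip entirely; the real hook
        # defers to super().get_openlineage_default_schema() here.
        return "public"


print(DemoHook().get_openlineage_default_schema())  # -> "public" on 2.9.2
```

Computing the flag once at import time keeps the per-call cost to a boolean test and, as the test changes below show, makes the flag easy to monkeypatch.
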
9 changes: 7 additions & 2 deletions airflow/providers/openlineage/utils/utils.py
@@ -29,7 +29,9 @@
import attrs
from deprecated import deprecated
from openlineage.client.utils import RedactMixin
+from packaging.version import Version

+from airflow import __version__ as AIRFLOW_VERSION
from airflow.exceptions import AirflowProviderDeprecationWarning # TODO: move this maybe to Airflow's logic?
from airflow.models import DAG, BaseOperator, MappedOperator
from airflow.providers.openlineage import conf
@@ -57,6 +59,7 @@

log = logging.getLogger(__name__)
_NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
+_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")


def try_import_from_string(string: str) -> Any:
@@ -558,5 +561,7 @@ def normalize_sql(sql: str | Iterable[str]):


def should_use_external_connection(hook) -> bool:
-    # TODO: Add checking overrides
-    return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook"]
+    # On Airflow 2.10+, execution is process-isolated, so we can safely run those again.
+    if not _IS_AIRFLOW_2_10_OR_HIGHER:
+        return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook", "RedshiftSQLHook"]
+    return True
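
A hedged sketch of what the new gate does (the `RedshiftSQLHook` below is a dummy that only shares the real hook's class name, and the flag is hardwired to simulate Airflow 2.9): below 2.10 the listed hooks must not reuse the task's live connection, while on 2.10+ every hook may, because listener code runs in its own process.

```python
# Illustrative sketch of should_use_external_connection; the hook class is a
# dummy, and the flag is hardwired to simulate running on Airflow 2.9.
_IS_AIRFLOW_2_10_OR_HIGHER = False


def should_use_external_connection(hook) -> bool:
    if not _IS_AIRFLOW_2_10_OR_HIGHER:
        # Below 2.10 the listener shares the worker process, so reusing the
        # task's connection through these hooks is unsafe.
        return hook.__class__.__name__ not in [
            "SnowflakeHook",
            "SnowflakeSqlApiHook",
            "RedshiftSQLHook",
        ]
    return True


class RedshiftSQLHook:  # dummy sharing the real hook's class name
    pass


print(should_use_external_connection(RedshiftSQLHook()))  # -> False below 2.10
```
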
220 changes: 125 additions & 95 deletions tests/providers/amazon/aws/operators/test_redshift_sql.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

-from unittest.mock import MagicMock, call, patch
+from unittest.mock import MagicMock, PropertyMock, call, patch

import pytest
from openlineage.client.facet import (
@@ -31,7 +31,7 @@
from openlineage.client.run import Dataset

from airflow.models.connection import Connection
-from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
+from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook as OriginalRedshiftSQLHook
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

MOCK_REGION_NAME = "eu-north-1"
@@ -40,38 +40,64 @@
class TestRedshiftSQLOpenLineage:
@patch.dict("os.environ", AIRFLOW_CONN_AWS_DEFAULT=f"aws://?region_name={MOCK_REGION_NAME}")
@pytest.mark.parametrize(
"connection_host, connection_extra, expected_identity",
"connection_host, connection_extra, expected_identity, is_over_210, expected_schemaname",
[
# test without a connection host but with a cluster_identifier in connection extra
(
None,
{"iam": True, "cluster_identifier": "cluster_identifier_from_extra"},
f"cluster_identifier_from_extra.{MOCK_REGION_NAME}",
+True,
+"database.public",
),
# test with a connection host and without a cluster_identifier in connection extra
(
"cluster_identifier_from_host.id.my_region.redshift.amazonaws.com",
{"iam": True},
"cluster_identifier_from_host.my_region",
+True,
+"database.public",
),
# test with both connection host and cluster_identifier in connection extra
(
"cluster_identifier_from_host.x.y",
{"iam": True, "cluster_identifier": "cluster_identifier_from_extra"},
f"cluster_identifier_from_extra.{MOCK_REGION_NAME}",
+True,
+"database.public",
),
# test when hostname doesn't match pattern
("1.2.3.4", {}, "1.2.3.4", True, "database.public"),
# test with Airflow below 2.10 not using Hook connection
(
"1.2.3.4",
{},
"1.2.3.4",
"cluster_identifier_from_host.id.my_region.redshift.amazonaws.com",
{"iam": True},
"cluster_identifier_from_host.my_region",
False,
"public",
),
],
)
+@patch(
+    "airflow.providers.amazon.aws.hooks.redshift_sql._IS_AIRFLOW_2_10_OR_HIGHER",
+    new_callable=PropertyMock,
+)
+@patch("airflow.providers.openlineage.utils.utils._IS_AIRFLOW_2_10_OR_HIGHER", new_callable=PropertyMock)
@patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
def test_execute_openlineage_events(
-self, mock_aws_hook_conn, connection_host, connection_extra, expected_identity
+self,
+mock_aws_hook_conn,
+mock_ol_utils,
+mock_redshift_sql,
+connection_host,
+connection_extra,
+expected_identity,
+is_over_210,
+expected_schemaname,
+# self, mock_aws_hook_conn, connection_host, connection_extra, expected_identity, is_below_2_10, expected_schemaname
):
+mock_ol_utils.__bool__ = lambda x: is_over_210
+mock_redshift_sql.__bool__ = lambda x: is_over_210
DB_NAME = "database"
DB_SCHEMA_NAME = "public"

@@ -84,14 +110,15 @@ def test_execute_openlineage_events(
"DbUser": "IAM:user",
}

-class RedshiftSQLHookForTests(RedshiftSQLHook):
+class RedshiftSQLHook(OriginalRedshiftSQLHook):
get_conn = MagicMock(name="conn")
get_connection = MagicMock()

def get_first(self, *_):
+self.log.error("CALLING FIRST")
return [f"{DB_NAME}.{DB_SCHEMA_NAME}"]

-dbapi_hook = RedshiftSQLHookForTests()
+dbapi_hook = RedshiftSQLHook()

class RedshiftOperatorForTest(SQLExecuteQueryOperator):
def get_db_hook(self):
@@ -149,94 +176,97 @@ def get_db_hook(self):
dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = rows

lineage = op.get_openlineage_facets_on_start()
-assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == [
-    call(
-        "SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
-        "SVV_REDSHIFT_COLUMNS.table_name, "
-        "SVV_REDSHIFT_COLUMNS.column_name, "
-        "SVV_REDSHIFT_COLUMNS.ordinal_position, "
-        "SVV_REDSHIFT_COLUMNS.data_type, "
-        "SVV_REDSHIFT_COLUMNS.database_name \n"
-        "FROM SVV_REDSHIFT_COLUMNS \n"
-        "WHERE SVV_REDSHIFT_COLUMNS.schema_name = 'database.public' "
-        "AND SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') "
-        "OR SVV_REDSHIFT_COLUMNS.database_name = 'another_db' "
-        "AND SVV_REDSHIFT_COLUMNS.schema_name = 'another_schema' AND "
-        "SVV_REDSHIFT_COLUMNS.table_name IN ('popular_orders_day_of_week')"
-    ),
-    call(
-        "SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
-        "SVV_REDSHIFT_COLUMNS.table_name, "
-        "SVV_REDSHIFT_COLUMNS.column_name, "
-        "SVV_REDSHIFT_COLUMNS.ordinal_position, "
-        "SVV_REDSHIFT_COLUMNS.data_type, "
-        "SVV_REDSHIFT_COLUMNS.database_name \n"
-        "FROM SVV_REDSHIFT_COLUMNS \n"
-        "WHERE SVV_REDSHIFT_COLUMNS.schema_name = 'database.public' "
-        "AND SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')"
-    ),
-]
-
+if is_over_210:
+    assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == [
+        call(
+            "SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
+            "SVV_REDSHIFT_COLUMNS.table_name, "
+            "SVV_REDSHIFT_COLUMNS.column_name, "
+            "SVV_REDSHIFT_COLUMNS.ordinal_position, "
+            "SVV_REDSHIFT_COLUMNS.data_type, "
+            "SVV_REDSHIFT_COLUMNS.database_name \n"
+            "FROM SVV_REDSHIFT_COLUMNS \n"
+            f"WHERE SVV_REDSHIFT_COLUMNS.schema_name = '{expected_schemaname}' "
+            "AND SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') "
+            "OR SVV_REDSHIFT_COLUMNS.database_name = 'another_db' "
+            "AND SVV_REDSHIFT_COLUMNS.schema_name = 'another_schema' AND "
+            "SVV_REDSHIFT_COLUMNS.table_name IN ('popular_orders_day_of_week')"
+        ),
+        call(
+            "SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
+            "SVV_REDSHIFT_COLUMNS.table_name, "
+            "SVV_REDSHIFT_COLUMNS.column_name, "
+            "SVV_REDSHIFT_COLUMNS.ordinal_position, "
+            "SVV_REDSHIFT_COLUMNS.data_type, "
+            "SVV_REDSHIFT_COLUMNS.database_name \n"
+            "FROM SVV_REDSHIFT_COLUMNS \n"
+            f"WHERE SVV_REDSHIFT_COLUMNS.schema_name = '{expected_schemaname}' "
+            "AND SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')"
+        ),
+    ]
+else:
+    assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == []
expected_namespace = f"redshift://{expected_identity}:5439"

-assert lineage.inputs == [
-    Dataset(
-        namespace=expected_namespace,
-        name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.popular_orders_day_of_week",
-        facets={
-            "schema": SchemaDatasetFacet(
-                fields=[
-                    SchemaField(name="order_day_of_week", type="varchar"),
-                    SchemaField(name="order_placed_on", type="timestamp"),
-                    SchemaField(name="orders_placed", type="int4"),
-                ]
-            )
-        },
-    ),
-    Dataset(
-        namespace=expected_namespace,
-        name=f"{DB_NAME}.{DB_SCHEMA_NAME}.little_table",
-        facets={
-            "schema": SchemaDatasetFacet(
-                fields=[
-                    SchemaField(name="order_day_of_week", type="varchar"),
-                    SchemaField(name="additional_constant", type="varchar"),
-                ]
-            )
-        },
-    ),
-]
-assert lineage.outputs == [
-    Dataset(
-        namespace=expected_namespace,
-        name=f"{DB_NAME}.{DB_SCHEMA_NAME}.test_table",
-        facets={
-            "schema": SchemaDatasetFacet(
-                fields=[
-                    SchemaField(name="order_day_of_week", type="varchar"),
-                    SchemaField(name="order_placed_on", type="timestamp"),
-                    SchemaField(name="orders_placed", type="int4"),
-                    SchemaField(name="additional_constant", type="varchar"),
-                ]
-            ),
-            "columnLineage": ColumnLineageDatasetFacet(
-                fields={
-                    "additional_constant": ColumnLineageDatasetFacetFieldsAdditional(
-                        inputFields=[
-                            ColumnLineageDatasetFacetFieldsAdditionalInputFields(
-                                namespace=expected_namespace,
-                                name="database.public.little_table",
-                                field="additional_constant",
-                            )
-                        ],
-                        transformationDescription="",
-                        transformationType="",
-                    )
-                }
-            ),
-        },
-    )
-]
+if is_over_210:
+    assert lineage.inputs == [
+        Dataset(
+            namespace=expected_namespace,
+            name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.popular_orders_day_of_week",
+            facets={
+                "schema": SchemaDatasetFacet(
+                    fields=[
+                        SchemaField(name="order_day_of_week", type="varchar"),
+                        SchemaField(name="order_placed_on", type="timestamp"),
+                        SchemaField(name="orders_placed", type="int4"),
+                    ]
+                )
+            },
+        ),
+        Dataset(
+            namespace=expected_namespace,
+            name=f"{DB_NAME}.{DB_SCHEMA_NAME}.little_table",
+            facets={
+                "schema": SchemaDatasetFacet(
+                    fields=[
+                        SchemaField(name="order_day_of_week", type="varchar"),
+                        SchemaField(name="additional_constant", type="varchar"),
+                    ]
+                )
+            },
+        ),
+    ]
+    assert lineage.outputs == [
+        Dataset(
+            namespace=expected_namespace,
+            name=f"{DB_NAME}.{DB_SCHEMA_NAME}.test_table",
+            facets={
+                "schema": SchemaDatasetFacet(
+                    fields=[
+                        SchemaField(name="order_day_of_week", type="varchar"),
+                        SchemaField(name="order_placed_on", type="timestamp"),
+                        SchemaField(name="orders_placed", type="int4"),
+                        SchemaField(name="additional_constant", type="varchar"),
+                    ]
+                ),
+                "columnLineage": ColumnLineageDatasetFacet(
+                    fields={
+                        "additional_constant": ColumnLineageDatasetFacetFieldsAdditional(
+                            inputFields=[
+                                ColumnLineageDatasetFacetFieldsAdditionalInputFields(
+                                    namespace=expected_namespace,
+                                    name="database.public.little_table",
+                                    field="additional_constant",
+                                )
+                            ],
+                            transformationDescription="",
+                            transformationType="",
+                        )
+                    }
+                ),
+            },
+        )
+    ]

assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)}

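One non-obvious piece of test machinery above: the module-level flags are patched with `new_callable=PropertyMock`, and the test then assigns `__bool__` on the resulting mocks so that `if _IS_AIRFLOW_2_10_OR_HIGHER:` takes whichever branch the parametrize case asks for. A self-contained sketch of that trick, using a synthetic `flags` module in place of the two patched Airflow modules:

```python
# Self-contained sketch of the flag-patching trick used in the test above.
# The "flags" module is synthetic, standing in for the patched Airflow
# provider modules.
import sys
import types
from unittest.mock import PropertyMock, patch

flags = types.ModuleType("flags")
flags._IS_AIRFLOW_2_10_OR_HIGHER = True
sys.modules["flags"] = flags

with patch("flags._IS_AIRFLOW_2_10_OR_HIGHER", new_callable=PropertyMock) as flag:
    # Mocks accept magic-method assignment per instance; bool(flag) now
    # routes through this lambda, forcing the "below 2.10" branch.
    flag.__bool__ = lambda self: False
    assert not flags._IS_AIRFLOW_2_10_OR_HIGHER

assert flags._IS_AIRFLOW_2_10_OR_HIGHER is True  # patch restored the constant
```

Assigning the magic method works because unittest.mock installs it on the mock's own generated class, so only that single mock's truthiness changes; the test uses exactly this to exercise one parametrize case per branch.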