
openlineage, redshift: do not call DB for schemas below Airflow 2.10 #40197

Merged: 1 commit, Jun 14, 2024
9 changes: 8 additions & 1 deletion airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -20,14 +20,19 @@
from typing import TYPE_CHECKING

import redshift_connector
from packaging.version import Version
from redshift_connector import Connection as RedshiftConnection
from sqlalchemy import create_engine
from sqlalchemy.engine.url import URL

from airflow import __version__ as AIRFLOW_VERSION
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook

_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")


if TYPE_CHECKING:
from airflow.models.connection import Connection
from airflow.providers.openlineage.sqlparser import DatabaseInfo
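
A note on the _IS_AIRFLOW_2_10_OR_HIGHER gate above: the double Version(...) wrapping is deliberate, since base_version strips pre-release and dev suffixes before the comparison. A minimal sketch of the difference:

from packaging.version import Version

# Under PEP 440, pre-releases sort before the final release, so a naive
# comparison would classify 2.10 release candidates as "below 2.10".
Version("2.10.0rc1") >= Version("2.10.0")                        # False
# base_version drops the pre-release segment, giving the intended answer.
Version(Version("2.10.0rc1").base_version) >= Version("2.10.0")  # True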
@@ -257,4 +262,6 @@ def get_openlineage_database_dialect(self, connection: Connection) -> str:

def get_openlineage_default_schema(self) -> str | None:
"""Return current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
return self.get_first("SELECT CURRENT_SCHEMA();")[0]
if _IS_AIRFLOW_2_10_OR_HIGHER:
return self.get_first("SELECT CURRENT_SCHEMA();")[0]
return super().get_openlineage_default_schema()
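
For illustration, a minimal usage sketch of the gated behavior (it assumes a configured default Redshift connection; nothing here is prescribed by the PR itself):

hook = RedshiftSQLHook()  # uses the default redshift_default connection
# Airflow 2.10+: queries Redshift with "SELECT CURRENT_SCHEMA();" and returns
# the first column of the first row.
# Below 2.10: defers to the DbApiHook default, so no database call is made
# from the lineage code path, which is the point of this PR.
schema = hook.get_openlineage_default_schema()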
9 changes: 7 additions & 2 deletions airflow/providers/openlineage/utils/utils.py
@@ -29,7 +29,9 @@
import attrs
from deprecated import deprecated
from openlineage.client.utils import RedactMixin
from packaging.version import Version

from airflow import __version__ as AIRFLOW_VERSION
from airflow.exceptions import AirflowProviderDeprecationWarning # TODO: move this maybe to Airflow's logic?
from airflow.models import DAG, BaseOperator, MappedOperator
from airflow.providers.openlineage import conf
@@ -57,6 +59,7 @@

log = logging.getLogger(__name__)
_NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")


def try_import_from_string(string: str) -> Any:
@@ -558,5 +561,7 @@ def normalize_sql(sql: str | Iterable[str]):


def should_use_external_connection(hook) -> bool:
# TODO: Add checking overrides
return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook"]
# On Airflow 2.10+, lineage extraction is process-isolated, so it is safe to run these hooks' queries again.
if not _IS_AIRFLOW_2_10_OR_HIGHER:
return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook", "RedshiftSQLHook"]
return True
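
For context, a hedged sketch of how lineage code might consult this helper (collect_table_schemas and parse_table_schemas are hypothetical names, not part of this PR):

from airflow.providers.openlineage.utils.utils import should_use_external_connection

def collect_table_schemas(hook, sql_parser):
    # Pre-2.10, extraction runs inside the task process, so hooks that hold
    # live connections (Snowflake, Redshift) must not be reused for metadata
    # queries; on 2.10+ the listener is process-isolated, so any hook is safe.
    if should_use_external_connection(hook):
        return sql_parser.parse_table_schemas(hook)  # hypothetical call
    return None  # skip the database round-trip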
220 changes: 125 additions & 95 deletions tests/providers/amazon/aws/operators/test_redshift_sql.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from unittest.mock import MagicMock, call, patch
from unittest.mock import MagicMock, PropertyMock, call, patch

import pytest
from openlineage.client.facet import (
@@ -31,7 +31,7 @@
from openlineage.client.run import Dataset

from airflow.models.connection import Connection
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook as OriginalRedshiftSQLHook
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

MOCK_REGION_NAME = "eu-north-1"
@@ -40,38 +40,64 @@
class TestRedshiftSQLOpenLineage:
@patch.dict("os.environ", AIRFLOW_CONN_AWS_DEFAULT=f"aws://?region_name={MOCK_REGION_NAME}")
@pytest.mark.parametrize(
"connection_host, connection_extra, expected_identity",
"connection_host, connection_extra, expected_identity, is_over_210, expected_schemaname",
[
# test without a connection host but with a cluster_identifier in connection extra
(
None,
{"iam": True, "cluster_identifier": "cluster_identifier_from_extra"},
f"cluster_identifier_from_extra.{MOCK_REGION_NAME}",
True,
"database.public",
),
# test with a connection host and without a cluster_identifier in connection extra
(
"cluster_identifier_from_host.id.my_region.redshift.amazonaws.com",
{"iam": True},
"cluster_identifier_from_host.my_region",
True,
"database.public",
),
# test with both connection host and cluster_identifier in connection extra
(
"cluster_identifier_from_host.x.y",
{"iam": True, "cluster_identifier": "cluster_identifier_from_extra"},
f"cluster_identifier_from_extra.{MOCK_REGION_NAME}",
True,
"database.public",
),
# test when hostname doesn't match pattern
("1.2.3.4", {}, "1.2.3.4", True, "database.public"),
# test with Airflow below 2.10 not using Hook connection
(
"1.2.3.4",
{},
"1.2.3.4",
"cluster_identifier_from_host.id.my_region.redshift.amazonaws.com",
{"iam": True},
"cluster_identifier_from_host.my_region",
False,
"public",
),
],
)
@patch(
"airflow.providers.amazon.aws.hooks.redshift_sql._IS_AIRFLOW_2_10_OR_HIGHER",
new_callable=PropertyMock,
)
@patch("airflow.providers.openlineage.utils.utils._IS_AIRFLOW_2_10_OR_HIGHER", new_callable=PropertyMock)
@patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
def test_execute_openlineage_events(
self, mock_aws_hook_conn, connection_host, connection_extra, expected_identity
self,
mock_aws_hook_conn,
mock_ol_utils,
mock_redshift_sql,
connection_host,
connection_extra,
expected_identity,
is_over_210,
expected_schemaname,
):
mock_ol_utils.__bool__ = lambda x: is_over_210
mock_redshift_sql.__bool__ = lambda x: is_over_210
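# The module-level flags were replaced with PropertyMock objects above, and
# overriding __bool__ steers the truthiness checks in the production code
# (if _IS_AIRFLOW_2_10_OR_HIGHER:) per parametrized case, in both modules.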
DB_NAME = "database"
DB_SCHEMA_NAME = "public"

@@ -84,14 +110,15 @@ def test_execute_openlineage_events(
"DbUser": "IAM:user",
}

class RedshiftSQLHookForTests(RedshiftSQLHook):
class RedshiftSQLHook(OriginalRedshiftSQLHook):
get_conn = MagicMock(name="conn")
get_connection = MagicMock()

def get_first(self, *_):
self.log.error("CALLING FIRST")
return [f"{DB_NAME}.{DB_SCHEMA_NAME}"]

dbapi_hook = RedshiftSQLHookForTests()
dbapi_hook = RedshiftSQLHook()

class RedshiftOperatorForTest(SQLExecuteQueryOperator):
def get_db_hook(self):
@@ -149,94 +176,97 @@ def get_db_hook(self):
dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = rows

lineage = op.get_openlineage_facets_on_start()
assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == [
call(
"SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
"SVV_REDSHIFT_COLUMNS.table_name, "
"SVV_REDSHIFT_COLUMNS.column_name, "
"SVV_REDSHIFT_COLUMNS.ordinal_position, "
"SVV_REDSHIFT_COLUMNS.data_type, "
"SVV_REDSHIFT_COLUMNS.database_name \n"
"FROM SVV_REDSHIFT_COLUMNS \n"
"WHERE SVV_REDSHIFT_COLUMNS.schema_name = 'database.public' "
"AND SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') "
"OR SVV_REDSHIFT_COLUMNS.database_name = 'another_db' "
"AND SVV_REDSHIFT_COLUMNS.schema_name = 'another_schema' AND "
"SVV_REDSHIFT_COLUMNS.table_name IN ('popular_orders_day_of_week')"
),
call(
"SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
"SVV_REDSHIFT_COLUMNS.table_name, "
"SVV_REDSHIFT_COLUMNS.column_name, "
"SVV_REDSHIFT_COLUMNS.ordinal_position, "
"SVV_REDSHIFT_COLUMNS.data_type, "
"SVV_REDSHIFT_COLUMNS.database_name \n"
"FROM SVV_REDSHIFT_COLUMNS \n"
"WHERE SVV_REDSHIFT_COLUMNS.schema_name = 'database.public' "
"AND SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')"
),
]

if is_over_210:
assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == [
call(
"SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
"SVV_REDSHIFT_COLUMNS.table_name, "
"SVV_REDSHIFT_COLUMNS.column_name, "
"SVV_REDSHIFT_COLUMNS.ordinal_position, "
"SVV_REDSHIFT_COLUMNS.data_type, "
"SVV_REDSHIFT_COLUMNS.database_name \n"
"FROM SVV_REDSHIFT_COLUMNS \n"
f"WHERE SVV_REDSHIFT_COLUMNS.schema_name = '{expected_schemaname}' "
"AND SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') "
"OR SVV_REDSHIFT_COLUMNS.database_name = 'another_db' "
"AND SVV_REDSHIFT_COLUMNS.schema_name = 'another_schema' AND "
"SVV_REDSHIFT_COLUMNS.table_name IN ('popular_orders_day_of_week')"
),
call(
"SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
"SVV_REDSHIFT_COLUMNS.table_name, "
"SVV_REDSHIFT_COLUMNS.column_name, "
"SVV_REDSHIFT_COLUMNS.ordinal_position, "
"SVV_REDSHIFT_COLUMNS.data_type, "
"SVV_REDSHIFT_COLUMNS.database_name \n"
"FROM SVV_REDSHIFT_COLUMNS \n"
f"WHERE SVV_REDSHIFT_COLUMNS.schema_name = '{expected_schemaname}' "
"AND SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')"
),
]
else:
assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == []
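# Below 2.10 the hook's database must not be touched at all, hence zero
# execute() calls.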
expected_namespace = f"redshift://{expected_identity}:5439"

assert lineage.inputs == [
Dataset(
namespace=expected_namespace,
name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.popular_orders_day_of_week",
facets={
"schema": SchemaDatasetFacet(
fields=[
SchemaField(name="order_day_of_week", type="varchar"),
SchemaField(name="order_placed_on", type="timestamp"),
SchemaField(name="orders_placed", type="int4"),
]
)
},
),
Dataset(
namespace=expected_namespace,
name=f"{DB_NAME}.{DB_SCHEMA_NAME}.little_table",
facets={
"schema": SchemaDatasetFacet(
fields=[
SchemaField(name="order_day_of_week", type="varchar"),
SchemaField(name="additional_constant", type="varchar"),
]
)
},
),
]
assert lineage.outputs == [
Dataset(
namespace=expected_namespace,
name=f"{DB_NAME}.{DB_SCHEMA_NAME}.test_table",
facets={
"schema": SchemaDatasetFacet(
fields=[
SchemaField(name="order_day_of_week", type="varchar"),
SchemaField(name="order_placed_on", type="timestamp"),
SchemaField(name="orders_placed", type="int4"),
SchemaField(name="additional_constant", type="varchar"),
]
),
"columnLineage": ColumnLineageDatasetFacet(
fields={
"additional_constant": ColumnLineageDatasetFacetFieldsAdditional(
inputFields=[
ColumnLineageDatasetFacetFieldsAdditionalInputFields(
namespace=expected_namespace,
name="database.public.little_table",
field="additional_constant",
)
],
transformationDescription="",
transformationType="",
)
}
),
},
)
]
if is_over_210:
assert lineage.inputs == [
Dataset(
namespace=expected_namespace,
name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.popular_orders_day_of_week",
facets={
"schema": SchemaDatasetFacet(
fields=[
SchemaField(name="order_day_of_week", type="varchar"),
SchemaField(name="order_placed_on", type="timestamp"),
SchemaField(name="orders_placed", type="int4"),
]
)
},
),
Dataset(
namespace=expected_namespace,
name=f"{DB_NAME}.{DB_SCHEMA_NAME}.little_table",
facets={
"schema": SchemaDatasetFacet(
fields=[
SchemaField(name="order_day_of_week", type="varchar"),
SchemaField(name="additional_constant", type="varchar"),
]
)
},
),
]
assert lineage.outputs == [
Dataset(
namespace=expected_namespace,
name=f"{DB_NAME}.{DB_SCHEMA_NAME}.test_table",
facets={
"schema": SchemaDatasetFacet(
fields=[
SchemaField(name="order_day_of_week", type="varchar"),
SchemaField(name="order_placed_on", type="timestamp"),
SchemaField(name="orders_placed", type="int4"),
SchemaField(name="additional_constant", type="varchar"),
]
),
"columnLineage": ColumnLineageDatasetFacet(
fields={
"additional_constant": ColumnLineageDatasetFacetFieldsAdditional(
inputFields=[
ColumnLineageDatasetFacetFieldsAdditionalInputFields(
namespace=expected_namespace,
name="database.public.little_table",
field="additional_constant",
)
],
transformationDescription="",
transformationType="",
)
}
),
},
)
]

assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)}
