From be553780e56cf8c34a65aecf2c52a33b82e0e039 Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Tue, 22 Oct 2024 19:13:42 +0200 Subject: [PATCH] feat: add OpenLineage support for RedshiftToS3Operator (#41632) Signed-off-by: Kacper Muda --- .../amazon/aws/transfers/redshift_to_s3.py | 113 +++++- .../amazon/aws/transfers/s3_to_redshift.py | 2 +- .../aws/transfers/test_redshift_to_s3.py | 330 ++++++++++++++++++ 3 files changed, 437 insertions(+), 8 deletions(-) diff --git a/providers/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/providers/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py index 8538b1dfc313c..0ed59e3db7e8e 100644 --- a/providers/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py +++ b/providers/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py @@ -152,6 +152,10 @@ def default_select_query(self) -> str | None: table = self.table return f"SELECT * FROM {table}" + @property + def use_redshift_data(self): + return bool(self.redshift_data_api_kwargs) + def execute(self, context: Context) -> None: if self.table and self.table_as_file_name: self.s3_key = f"{self.s3_key}/{self.table}_" @@ -164,14 +168,13 @@ def execute(self, context: Context) -> None: if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]: self.unload_options = [*self.unload_options, "HEADER"] - redshift_hook: RedshiftDataHook | RedshiftSQLHook - if self.redshift_data_api_kwargs: - redshift_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id) + if self.use_redshift_data: + redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id) for arg in ["sql", "parameters"]: if arg in self.redshift_data_api_kwargs: raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs") else: - redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) + redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None if conn and conn.extra_dejson.get("role_arn", False): credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}" @@ -187,10 +190,106 @@ def execute(self, context: Context) -> None: ) self.log.info("Executing UNLOAD command...") - if isinstance(redshift_hook, RedshiftDataHook): - redshift_hook.execute_query( + if self.use_redshift_data: + redshift_data_hook.execute_query( sql=unload_query, parameters=self.parameters, **self.redshift_data_api_kwargs ) else: - redshift_hook.run(unload_query, self.autocommit, parameters=self.parameters) + redshift_sql_hook.run(unload_query, self.autocommit, parameters=self.parameters) self.log.info("UNLOAD command complete...") + + def get_openlineage_facets_on_complete(self, task_instance): + """Implement on_complete as we may query for table details.""" + from airflow.providers.amazon.aws.utils.openlineage import ( + get_facets_from_redshift_table, + get_identity_column_lineage_facet, + ) + from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + Error, + ExtractionErrorRunFacet, + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + output_dataset = Dataset( + namespace=f"s3://{self.s3_bucket}", + name=self.s3_key, + ) + + if self.use_redshift_data: + redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id) + database = self.redshift_data_api_kwargs.get("database") + identifier = self.redshift_data_api_kwargs.get( + "cluster_identifier", self.redshift_data_api_kwargs.get("workgroup_name") + ) 
+ port = self.redshift_data_api_kwargs.get("port", "5439") + authority = f"{identifier}.{redshift_data_hook.region_name}:{port}" + else: + redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id) + database = redshift_sql_hook.conn.schema + authority = redshift_sql_hook.get_openlineage_database_info(redshift_sql_hook.conn).authority + + if self.select_query == self.default_select_query: + if self.use_redshift_data: + input_dataset_facets = get_facets_from_redshift_table( + redshift_data_hook, self.table, self.redshift_data_api_kwargs, self.schema + ) + else: + input_dataset_facets = get_facets_from_redshift_table( + redshift_sql_hook, self.table, {}, self.schema + ) + + input_dataset = Dataset( + namespace=f"redshift://{authority}", + name=f"{database}.{self.schema}.{self.table}" if database else f"{self.schema}.{self.table}", + facets=input_dataset_facets, + ) + + # If default select query is used (SELECT *) output file matches the input table. + output_dataset.facets = { + "schema": input_dataset_facets["schema"], + "columnLineage": get_identity_column_lineage_facet( + field_names=[field.name for field in input_dataset_facets["schema"].fields], + input_datasets=[input_dataset], + ), + } + + return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset]) + + try: + from airflow.providers.openlineage.sqlparser import SQLParser, from_table_meta + except ImportError: + return OperatorLineage(outputs=[output_dataset]) + + run_facets = {} + parse_result = SQLParser(dialect="redshift", default_schema=self.schema).parse(self.select_query) + if parse_result.errors: + run_facets["extractionError"] = ExtractionErrorRunFacet( + totalTasks=1, + failedTasks=1, + errors=[ + Error( + errorMessage=error.message, + stackTrace=None, + task=error.origin_statement, + taskNumber=error.index, + ) + for error in parse_result.errors + ], + ) + + input_datasets = [] + for in_tb in parse_result.in_tables: + ds = from_table_meta(in_tb, database, f"redshift://{authority}", False) + schema, table = ds.name.split(".")[-2:] + if self.use_redshift_data: + input_dataset_facets = get_facets_from_redshift_table( + redshift_data_hook, table, self.redshift_data_api_kwargs, schema + ) + else: + input_dataset_facets = get_facets_from_redshift_table(redshift_sql_hook, table, {}, schema) + + ds.facets = input_dataset_facets + input_datasets.append(ds) + + return OperatorLineage(inputs=input_datasets, outputs=[output_dataset], run_facets=run_facets) diff --git a/providers/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/providers/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py index 792119bfebb55..df040b19eafac 100644 --- a/providers/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py +++ b/providers/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py @@ -241,7 +241,7 @@ def get_openlineage_facets_on_complete(self, task_instance): output_dataset = Dataset( namespace=f"redshift://{authority}", - name=f"{database}.{self.schema}.{self.table}", + name=f"{database}.{self.schema}.{self.table}" if database else f"{self.schema}.{self.table}", facets=output_dataset_facets, ) diff --git a/providers/tests/amazon/aws/transfers/test_redshift_to_s3.py b/providers/tests/amazon/aws/transfers/test_redshift_to_s3.py index 0bc4b83c3552c..30194e1b4819b 100644 --- a/providers/tests/amazon/aws/transfers/test_redshift_to_s3.py +++ b/providers/tests/amazon/aws/transfers/test_redshift_to_s3.py @@ -27,6 +27,14 @@ from airflow.models.connection import Connection from 
airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator from airflow.providers.amazon.aws.utils.redshift import build_credentials_block +from airflow.providers.common.compat.openlineage.facet import ( + ColumnLineageDatasetFacet, + DocumentationDatasetFacet, + Fields, + InputField, + SchemaDatasetFacet, + SchemaDatasetFacetFields, +) from tests_common.test_utils.asserts import assert_equal_ignore_multiple_spaces @@ -591,3 +599,325 @@ def test_table_unloading_using_redshift_data_api( ) # test sql arg assert_equal_ignore_multiple_spaces(mock_rs.execute_statement.call_args.kwargs["Sql"], unload_query) + + @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") + @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets") + @mock.patch("boto3.session.Session") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run") + @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table") + def test_get_openlineage_facets_on_complete_default( + self, mock_get_facets, mock_run, mock_session, mock_connection, mock_hook + ): + access_key = "aws_access_key_id" + secret_key = "aws_secret_access_key" + mock_session.return_value = Session(access_key, secret_key) + mock_session.return_value.access_key = access_key + mock_session.return_value.secret_key = secret_key + mock_session.return_value.token = None + + mock_connection.return_value = mock.MagicMock( + schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={} + ) + mock_facets = { + "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]), + "documentation": DocumentationDatasetFacet(description="mock_description"), + } + mock_get_facets.return_value = mock_facets + + schema = "schema" + table = "table" + s3_bucket = "bucket" + s3_key = "key" + + op = RedshiftToS3Operator( + schema=schema, + table=table, + s3_bucket=s3_bucket, + s3_key=s3_key, + redshift_conn_id="redshift_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + ) + op.execute(None) + + lineage = op.get_openlineage_facets_on_complete(None) + # Hook called only one time - on operator execution - we mocked querying to fetch schema + assert mock_run.call_count == 1 + + assert len(lineage.inputs) == 1 + assert len(lineage.outputs) == 1 + assert lineage.outputs[0].name == f"{s3_key}/{table}_" + assert lineage.outputs[0].namespace == f"s3://{s3_bucket}" + assert lineage.inputs[0].name == f"database.{schema}.{table}" + assert lineage.inputs[0].namespace == "redshift://cluster.region:5439" + + assert lineage.inputs[0].facets == mock_facets + assert lineage.outputs[0].facets == { + "columnLineage": ColumnLineageDatasetFacet( + fields={ + "col": Fields( + inputFields=[ + InputField( + namespace="redshift://cluster.region:5439", + name=f"database.{schema}.{table}", + field="col", + ) + ], + transformationType="IDENTITY", + transformationDescription="identical", + ) + } + ), + "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]), + } + + @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") + @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets") + @mock.patch("boto3.session.Session") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run") + @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table") + def 
test_get_openlineage_facets_on_complete_with_select_query( + self, mock_get_facets, mock_run, mock_session, mock_connection, mock_hook + ): + access_key = "aws_access_key_id" + secret_key = "aws_secret_access_key" + mock_session.return_value = Session(access_key, secret_key) + mock_session.return_value.access_key = access_key + mock_session.return_value.secret_key = secret_key + mock_session.return_value.token = None + + mock_connection.return_value = mock.MagicMock( + schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={} + ) + mock_facets = { + "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]), + "documentation": DocumentationDatasetFacet(description="mock_description"), + } + mock_get_facets.return_value = mock_facets + + schema = "schema" + table = "table" + s3_bucket = "bucket" + s3_key = "key" + query = """ + SELECT + c.customer_id, + c.first_name, + c.last_name, + o.order_id, + o.order_date, + o.total_amount + FROM + schema1.customers c + INNER JOIN + schema2.orders o + ON + c.customer_id = o.customer_id + ORDER BY + o.order_date DESC; + """ + op = RedshiftToS3Operator( + schema=schema, # should be ignored + table=table, # should be ignored + s3_bucket=s3_bucket, + s3_key=s3_key, + table_as_file_name=False, + select_query=query, + redshift_conn_id="redshift_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + ) + op.execute(None) + + lineage = op.get_openlineage_facets_on_complete(None) + + assert len(lineage.inputs) == 2 + assert len(lineage.outputs) == 1 + assert lineage.outputs[0].name == s3_key + assert lineage.outputs[0].namespace == f"s3://{s3_bucket}" + assert lineage.inputs[0].name == "database.schema1.customers" + assert lineage.inputs[0].namespace == "redshift://cluster.region:5439" + assert lineage.inputs[1].name == "database.schema2.orders" + assert lineage.inputs[1].namespace == "redshift://cluster.region:5439" + + assert lineage.inputs[0].facets == mock_facets + assert lineage.outputs[0].facets == {} + + @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") + @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets") + @mock.patch("boto3.session.Session") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + @mock.patch( + "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.region_name", + new_callable=mock.PropertyMock, + ) + @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table") + def test_get_openlineage_facets_on_complete_using_redshift_data_api( + self, mock_get_facets, mock_rs_region, mock_rs, mock_session, mock_connection, mock_hook + ): + """ + Using the Redshift Data API instead of the SQL-based connection + """ + access_key = "aws_access_key_id" + secret_key = "aws_secret_access_key" + mock_session.return_value = Session(access_key, secret_key) + mock_session.return_value.access_key = access_key + mock_session.return_value.secret_key = secret_key + mock_session.return_value.token = None + + mock_hook.return_value = Connection() + mock_rs.execute_statement.return_value = {"Id": "STATEMENT_ID"} + mock_rs.describe_statement.return_value = {"Status": "FINISHED"} + + mock_rs_region.return_value = "region" + mock_facets = { + "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]), + "documentation": DocumentationDatasetFacet(description="mock_description"), + } + mock_get_facets.return_value = mock_facets + + 
schema = "schema" + table = "table" + s3_bucket = "bucket" + s3_key = "key" + + # RS Data API params + database = "database" + cluster_identifier = "cluster" + db_user = "db_user" + secret_arn = "secret_arn" + statement_name = "statement_name" + + op = RedshiftToS3Operator( + schema=schema, + table=table, + s3_bucket=s3_bucket, + s3_key=s3_key, + redshift_conn_id="redshift_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + redshift_data_api_kwargs=dict( + database=database, + cluster_identifier=cluster_identifier, + db_user=db_user, + secret_arn=secret_arn, + statement_name=statement_name, + ), + ) + op.execute(None) + + lineage = op.get_openlineage_facets_on_complete(None) + + assert len(lineage.inputs) == 1 + assert len(lineage.outputs) == 1 + assert lineage.outputs[0].name == f"{s3_key}/{table}_" + assert lineage.outputs[0].namespace == f"s3://{s3_bucket}" + assert lineage.inputs[0].name == f"database.{schema}.{table}" + assert lineage.inputs[0].namespace == "redshift://cluster.region:5439" + + assert lineage.inputs[0].facets == mock_facets + assert lineage.outputs[0].facets == { + "columnLineage": ColumnLineageDatasetFacet( + fields={ + "col": Fields( + inputFields=[ + InputField( + namespace="redshift://cluster.region:5439", + name=f"database.{schema}.{table}", + field="col", + ) + ], + transformationType="IDENTITY", + transformationDescription="identical", + ) + } + ), + "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]), + } + + @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection") + @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets") + @mock.patch("boto3.session.Session") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + @mock.patch( + "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.region_name", + new_callable=mock.PropertyMock, + ) + @mock.patch("airflow.providers.amazon.aws.utils.openlineage.get_facets_from_redshift_table") + def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned( + self, mock_get_facets, mock_rs_region, mock_rs, mock_run, mock_session, mock_connection, mock_hook + ): + """ + Ensuring both supported hooks - RedshiftDataHook and RedshiftSQLHook return same lineage. 
+ """ + access_key = "aws_access_key_id" + secret_key = "aws_secret_access_key" + mock_session.return_value = Session(access_key, secret_key) + mock_session.return_value.access_key = access_key + mock_session.return_value.secret_key = secret_key + mock_session.return_value.token = None + + mock_connection.return_value = mock.MagicMock( + schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={} + ) + mock_hook.return_value = Connection() + mock_rs.execute_statement.return_value = {"Id": "STATEMENT_ID"} + mock_rs.describe_statement.return_value = {"Status": "FINISHED"} + + mock_rs_region.return_value = "region" + mock_facets = { + "schema": SchemaDatasetFacet(fields=[SchemaDatasetFacetFields(name="col", type="STRING")]), + "documentation": DocumentationDatasetFacet(description="mock_description"), + } + mock_get_facets.return_value = mock_facets + + schema = "schema" + table = "table" + s3_bucket = "bucket" + s3_key = "key" + + # RS Data API params + database = "database" + cluster_identifier = "cluster" + db_user = "db_user" + secret_arn = "secret_arn" + statement_name = "statement_name" + + op_rs_data = RedshiftToS3Operator( + schema=schema, + table=table, + s3_bucket=s3_bucket, + s3_key=s3_key, + redshift_conn_id="redshift_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + redshift_data_api_kwargs=dict( + database=database, + cluster_identifier=cluster_identifier, + db_user=db_user, + secret_arn=secret_arn, + statement_name=statement_name, + ), + ) + op_rs_data.execute(None) + rs_data_lineage = op_rs_data.get_openlineage_facets_on_complete(None) + + op_rs_sql = RedshiftToS3Operator( + schema=schema, + table=table, + s3_bucket=s3_bucket, + s3_key=s3_key, + redshift_conn_id="redshift_conn_id", + aws_conn_id="aws_conn_id", + task_id="task_id", + ) + op_rs_sql.execute(None) + rs_sql_lineage = op_rs_sql.get_openlineage_facets_on_complete(None) + + assert len(rs_sql_lineage.inputs) == 1 + assert len(rs_sql_lineage.outputs) == 1 + assert rs_sql_lineage.inputs == rs_data_lineage.inputs + assert rs_sql_lineage.outputs == rs_data_lineage.outputs + assert rs_sql_lineage.job_facets == rs_data_lineage.job_facets + assert rs_sql_lineage.run_facets == rs_data_lineage.run_facets