Skip to content

Commit

Permalink
Add links for BigQuery Data Transfer (#22280)
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 327eab3e26a3fb3e40a995facebb512cebb0fec2
  • Loading branch information
Wojciech Januszek authored and Cloud Composer Team committed Sep 12, 2024
1 parent 94f1d56 commit 29fb23f
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 4 deletions.
50 changes: 50 additions & 0 deletions airflow/providers/google/cloud/links/bigquery_dts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains Google BigQuery Data Transfer links."""
from typing import TYPE_CHECKING

from airflow.models import BaseOperator
from airflow.providers.google.cloud.links.base import BaseGoogleLink

if TYPE_CHECKING:
from airflow.utils.context import Context

BIGQUERY_BASE_LINK = "https://console.cloud.google.com/bigquery/transfers"
BIGQUERY_DTS_LINK = BIGQUERY_BASE_LINK + "/locations/{region}/configs/{config_id}/runs?project={project_id}"


class BigQueryDataTransferConfigLink(BaseGoogleLink):
"""Helper class for constructing BigQuery Data Transfer Config Link"""

name = "BigQuery Data Transfer Config"
key = "bigquery_dts_config"
format_str = BIGQUERY_DTS_LINK

@staticmethod
def persist(
context: "Context",
task_instance: BaseOperator,
region: str,
config_id: str,
project_id: str,
):
task_instance.xcom_push(
context,
key=BigQueryDataTransferConfigLink.key,
value={"project_id": project_id, "region": region, "config_id": config_id},
)
28 changes: 28 additions & 0 deletions airflow/providers/google/cloud/operators/bigquery_dts.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,17 @@

from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook, get_object_id
from airflow.providers.google.cloud.links.bigquery_dts import BigQueryDataTransferConfigLink

if TYPE_CHECKING:
from airflow.utils.context import Context


def _get_transfer_config_details(config_transfer_name: str):
config_details = config_transfer_name.split("/")
return {"project_id": config_details[1], "region": config_details[3], "config_id": config_details[5]}


class BigQueryCreateDataTransferOperator(BaseOperator):
"""
Creates a new data transfer configuration.
Expand Down Expand Up @@ -67,6 +73,7 @@ class BigQueryCreateDataTransferOperator(BaseOperator):
"gcp_conn_id",
"impersonation_chain",
)
operator_extra_links = (BigQueryDataTransferConfigLink(),)

def __init__(
self,
Expand Down Expand Up @@ -106,6 +113,16 @@ def execute(self, context: 'Context'):
timeout=self.timeout,
metadata=self.metadata,
)

transfer_config = _get_transfer_config_details(response.name)
BigQueryDataTransferConfigLink.persist(
context=context,
task_instance=self,
region=transfer_config["region"],
config_id=transfer_config["config_id"],
project_id=transfer_config["project_id"],
)

result = TransferConfig.to_dict(response)
self.log.info("Created DTS transfer config %s", get_object_id(result))
self.xcom_push(context, key="transfer_config_id", value=get_object_id(result))
Expand Down Expand Up @@ -231,6 +248,7 @@ class BigQueryDataTransferServiceStartTransferRunsOperator(BaseOperator):
"gcp_conn_id",
"impersonation_chain",
)
operator_extra_links = (BigQueryDataTransferConfigLink(),)

def __init__(
self,
Expand Down Expand Up @@ -273,6 +291,16 @@ def execute(self, context: 'Context'):
timeout=self.timeout,
metadata=self.metadata,
)

transfer_config = _get_transfer_config_details(response.runs[0].name)
BigQueryDataTransferConfigLink.persist(
context=context,
task_instance=self,
region=transfer_config["region"],
config_id=transfer_config["config_id"],
project_id=transfer_config["project_id"],
)

result = StartManualTransferRunsResponse.to_dict(response)
run_id = get_object_id(result['runs'][0])
self.xcom_push(context, key="run_id", value=run_id)
Expand Down
4 changes: 4 additions & 0 deletions airflow/providers/google/cloud/sensors/bigquery_dts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from google.api_core.retry import Retry
from google.cloud.bigquery_datatransfer_v1 import TransferState

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook
from airflow.sensors.base import BaseSensorOperator

Expand Down Expand Up @@ -130,4 +131,7 @@ def poke(self, context: 'Context') -> bool:
metadata=self.metadata,
)
self.log.info("Status of %s run: %s", self.run_id, str(run.state))

if run.state in (TransferState.FAILED, TransferState.CANCELLED):
raise AirflowException(f"Transfer {self.run_id} did not succeed")
return run.state in self.expected_statuses
1 change: 1 addition & 0 deletions airflow/providers/google/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,7 @@ extra-links:
- airflow.providers.google.cloud.operators.datafusion.DataFusionPipelinesLink
- airflow.providers.google.cloud.links.dataplex.DataplexTaskLink
- airflow.providers.google.cloud.links.dataplex.DataplexTasksLink
- airflow.providers.google.cloud.links.bigquery_dts.BigQueryDataTransferConfigLink
- airflow.providers.google.cloud.links.dataproc.DataprocLink
- airflow.providers.google.cloud.links.dataproc.DataprocListLink
- airflow.providers.google.cloud.operators.dataproc_metastore.DataprocMetastoreDetailedLink
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/google/cloud/operators/test_bigquery_dts.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_execute(self, mock_hook):
retry=None,
timeout=None,
)
ti.xcom_push.assert_called_once_with(execution_date=None, key='transfer_config_id', value='1a2b3c')
ti.xcom_push.assert_called_with(execution_date=None, key='transfer_config_id', value='1a2b3c')


class BigQueryDeleteDataTransferConfigOperatorTestCase(unittest.TestCase):
Expand Down Expand Up @@ -111,4 +111,4 @@ def test_execute(self, mock_hook):
retry=None,
timeout=None,
)
ti.xcom_push.assert_called_once_with(execution_date=None, key='run_id', value='123')
ti.xcom_push.assert_called_with(execution_date=None, key='run_id', value='123')
6 changes: 4 additions & 2 deletions tests/providers/google/cloud/sensors/test_bigquery_dts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
from unittest import mock
from unittest.mock import MagicMock as MM

import pytest
from google.cloud.bigquery_datatransfer_v1 import TransferState

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.sensors.bigquery_dts import BigQueryDataTransferServiceTransferRunSensor

TRANSFER_CONFIG_ID = "config_id"
Expand All @@ -42,9 +44,9 @@ def test_poke_returns_false(self, mock_hook):
project_id=PROJECT_ID,
expected_statuses={"SUCCEEDED"},
)
result = op.poke({})

assert result is False
with pytest.raises(AirflowException, match="Transfer"):
op.poke({})
mock_hook.return_value.get_transfer_run.assert_called_once_with(
transfer_config_id=TRANSFER_CONFIG_ID,
run_id=RUN_ID,
Expand Down

0 comments on commit 29fb23f

Please sign in to comment.