From 96fb5d78cfc67c19f299dcc0b4219df85dac513e Mon Sep 17 00:00:00 2001 From: tuancamtbtx Date: Tue, 2 Jul 2024 12:08:04 +0700 Subject: [PATCH] [add] sensors --- airlake/sensors/bq_sensor.py | 42 ++++++++++++- airlake/sensors/external_task_sensor.py | 84 +++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 2 deletions(-) create mode 100644 airlake/sensors/external_task_sensor.py diff --git a/airlake/sensors/bq_sensor.py b/airlake/sensors/bq_sensor.py index f6af16a..8da78fc 100644 --- a/airlake/sensors/bq_sensor.py +++ b/airlake/sensors/bq_sensor.py @@ -1,2 +1,40 @@ -class BigquerySensor: - pass \ No newline at end of file +import logging + +from airflow.sensors.base import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + +from airlake.hooks.gcs_hook import BigQueryNativeHook + + +class BigQuerySQLSensor(BaseSensorOperator): + template_fields = ("sql",) + ui_color = "#f0eee4" + + @apply_defaults + def __init__( + self, + sql: str, + bigquery_conn_id="bigquery_default_conn", + delegate_to=None, + *args, + **kwargs + ): + if not sql: + raise Exception("Must have sql") + super(BigQuerySQLSensor, self).__init__(*args, **kwargs) + self.sql = sql + self.bigquery_conn_id = bigquery_conn_id + self.delegate_to = delegate_to + + def poke(self, context): + self.log.info("Sensor checks sql:\n%s", self.sql) + hook = BigQueryNativeHook( + gcp_conn_id=self.bigquery_conn_id, + delegate_to=self.delegate_to, + ) + + try: + return hook.total_rows(self.sql) > 0 + except Exception as e: + logging.error("Execute error %s", e) + return False diff --git a/airlake/sensors/external_task_sensor.py b/airlake/sensors/external_task_sensor.py new file mode 100644 index 0000000..8955143 --- /dev/null +++ b/airlake/sensors/external_task_sensor.py @@ -0,0 +1,84 @@ +from datetime import timedelta + + +from airflow.sensors.external_task import ExternalTaskSensor as AExternalTaskSensor +from airflow.models import BaseOperatorLink, DagBag, DagModel, DagRun, TaskInstance +from sqlalchemy import func + + +class ExecutionDateFn: + def __init__(self, execution_delta_hour: int = None, + execution_delta_minutes: int = None): + self.execution_delta_hour = execution_delta_hour + self.execution_delta_minutes = execution_delta_minutes + + def __call__(self, execution_date, context=None): + execution_date = context['data_interval_start'] + if self.execution_delta_hour is not None: + return execution_date - timedelta(hours=self.execution_delta_hour) + elif self.execution_delta_minutes is not None: + return execution_date - timedelta(minutes=self.execution_delta_minutes) + return execution_date + +class ExternalTaskSensor(AExternalTaskSensor): + def __init__( + self, + execution_delta_hour: int = None, + execution_delta_minutes: int = None, + *args, + **kwargs): + kwargs.pop('execution_delta', None) + kwargs['execution_date_fn'] = ExecutionDateFn(execution_delta_hour, execution_delta_minutes) + kwargs['check_existence'] = True + super(ExternalTaskSensor, self).__init__(*args, **kwargs) + + def get_count(self, dttm_filter, session, states) -> int: + """ + Get the count of records against dttm filter and states + + :param dttm_filter: date time filter for execution date + :type dttm_filter: list + :param session: airflow session object + :type session: SASession + :param states: task or dag states + :type states: list + :return: count of record against the filters + """ + TI = TaskInstance + DR = DagRun + if not dttm_filter: + return 0 + if self.external_task_id: + dag_run = ( + session.query(DR) + .filter( + DR.dag_id == self.external_dag_id, + DR.data_interval_start == dttm_filter[0] + ) + .first() + ) + if not dag_run: + return 0 + count = ( + session.query(func.count()) # .count() is inefficient + .filter( + TI.dag_id == self.external_dag_id, + TI.task_id == self.external_task_id, + TI.state.in_(states), + TI.run_id == dag_run.run_id, + ) + .scalar() + ) + else: + count = ( + session.query(func.count()) + .filter( + DR.dag_id == self.external_dag_id, + DR.state.in_(states), + DR.data_interval_start.in_(dttm_filter), + ) + .scalar() + ) + if count > 1: + return 1 + return count