[add] sensors
tuancamtbtx committed Jul 2, 2024
1 parent: 1c0f8dd · commit: 96fb5d7
Showing 2 changed files with 124 additions and 2 deletions.
42 changes: 40 additions & 2 deletions airlake/sensors/bq_sensor.py
@@ -1,2 +1,40 @@
-class BigquerySensor:
-    pass

from airflow.sensors.base import BaseSensorOperator
from airflow.utils.decorators import apply_defaults

from airlake.hooks.gcs_hook import BigQueryNativeHook


class BigQuerySQLSensor(BaseSensorOperator):
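    """Waits until the configured SQL query returns at least one row in BigQuery."""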
template_fields = ("sql",)
ui_color = "#f0eee4"

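    # apply_defaults is a no-op shim on Airflow 2.x; kept for Airflow 1.x compatibility.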
@apply_defaults
def __init__(
self,
sql: str,
bigquery_conn_id="bigquery_default_conn",
delegate_to=None,
*args,
**kwargs
):
        if not sql:
            raise ValueError("BigQuerySQLSensor requires a non-empty sql string")
        super().__init__(*args, **kwargs)
self.sql = sql
self.bigquery_conn_id = bigquery_conn_id
self.delegate_to = delegate_to

def poke(self, context):
self.log.info("Sensor checks sql:\n%s", self.sql)
hook = BigQueryNativeHook(
gcp_conn_id=self.bigquery_conn_id,
delegate_to=self.delegate_to,
)

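        # Treat any query failure as "condition not met" so the sensor keeps poking.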
try:
return hook.total_rows(self.sql) > 0
except Exception as e:
            self.log.error("Query execution failed: %s", e)
return False
84 changes: 84 additions & 0 deletions airlake/sensors/external_task_sensor.py
@@ -0,0 +1,84 @@
from datetime import timedelta
from typing import Optional

from airflow.sensors.external_task import ExternalTaskSensor as AExternalTaskSensor
from airflow.models import DagRun, TaskInstance
from sqlalchemy import func


class ExecutionDateFn:
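    """Computes the external DAG's expected data_interval_start by shifting the
    sensing run's data_interval_start back a fixed number of hours or minutes."""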
    def __init__(self, execution_delta_hour: Optional[int] = None,
                 execution_delta_minutes: Optional[int] = None):
        self.execution_delta_hour = execution_delta_hour
        self.execution_delta_minutes = execution_delta_minutes

    def __call__(self, execution_date, context=None):
        # Airflow may invoke execution_date_fn without a context; only override
        # the passed-in date when one is available.
        if context is not None:
            execution_date = context['data_interval_start']
if self.execution_delta_hour is not None:
return execution_date - timedelta(hours=self.execution_delta_hour)
elif self.execution_delta_minutes is not None:
return execution_date - timedelta(minutes=self.execution_delta_minutes)
return execution_date

class ExternalTaskSensor(AExternalTaskSensor):
def __init__(
self,
            execution_delta_hour: Optional[int] = None,
            execution_delta_minutes: Optional[int] = None,
*args,
**kwargs):
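        # execution_delta and execution_date_fn are mutually exclusive in Airflow,
        # so drop any caller-supplied execution_delta in favour of the deltas above.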
kwargs.pop('execution_delta', None)
kwargs['execution_date_fn'] = ExecutionDateFn(execution_delta_hour, execution_delta_minutes)
kwargs['check_existence'] = True
        super().__init__(*args, **kwargs)

def get_count(self, dttm_filter, session, states) -> int:
"""
Get the count of records against dttm filter and states
:param dttm_filter: date time filter for execution date
:type dttm_filter: list
:param session: airflow session object
:type session: SASession
:param states: task or dag states
:type states: list
        :return: count of records matching the filters
"""
TI = TaskInstance
DR = DagRun
if not dttm_filter:
return 0
if self.external_task_id:
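            # Resolve the external DagRun whose data_interval_start matches the
            # shifted date, then count its task instances in the requested states.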
dag_run = (
session.query(DR)
.filter(
DR.dag_id == self.external_dag_id,
DR.data_interval_start == dttm_filter[0]
)
.first()
)
if not dag_run:
return 0
count = (
                session.query(func.count())  # func.count() avoids the extra subquery Query.count() emits
.filter(
TI.dag_id == self.external_dag_id,
TI.task_id == self.external_task_id,
TI.state.in_(states),
TI.run_id == dag_run.run_id,
)
.scalar()
)
else:
count = (
session.query(func.count())
.filter(
DR.dag_id == self.external_dag_id,
DR.state.in_(states),
DR.data_interval_start.in_(dttm_filter),
)
.scalar()
)
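        # Clamp to 1 so the parent poke() comparison `count == len(dttm_filter)`
        # still passes when more than one matching run or instance exists.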
if count > 1:
return 1
return count
