Commit 96fb5d7 (parent 1c0f8dd)
Showing 2 changed files with 124 additions and 2 deletions.
@@ -1,2 +1,40 @@
-class BigquerySensor:
-    pass
+import logging
+
+from airflow.sensors.base import BaseSensorOperator
+from airflow.utils.decorators import apply_defaults
+
+from airlake.hooks.gcs_hook import BigQueryNativeHook
+
+
+class BigQuerySQLSensor(BaseSensorOperator):
+    template_fields = ("sql",)
+    ui_color = "#f0eee4"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        bigquery_conn_id="bigquery_default_conn",
+        delegate_to=None,
+        *args,
+        **kwargs
+    ):
+        if not sql:
+            raise Exception("Must have sql")
+        super(BigQuerySQLSensor, self).__init__(*args, **kwargs)
+        self.sql = sql
+        self.bigquery_conn_id = bigquery_conn_id
+        self.delegate_to = delegate_to
+
+    def poke(self, context):
+        self.log.info("Sensor checks sql:\n%s", self.sql)
+        hook = BigQueryNativeHook(
+            gcp_conn_id=self.bigquery_conn_id,
+            delegate_to=self.delegate_to,
+        )
+
+        try:
+            return hook.total_rows(self.sql) > 0
+        except Exception as e:
+            logging.error("Execute error %s", e)
+            return False
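
For context, BigQuerySQLSensor succeeds once the configured query returns at least one row, so it can gate downstream tasks on data arriving in a table. The sketch below is a minimal, hypothetical usage example: the import path, DAG id, and SQL are assumptions (the diff does not show the sensor's module path), and it relies on BigQueryNativeHook.total_rows returning the query's row count, as the commit's poke() implies.

# Hypothetical usage sketch; import path, DAG id, and SQL are illustrative assumptions.
from datetime import datetime

from airflow import DAG

from airlake.sensors.bigquery import BigQuerySQLSensor  # assumed module path, not shown in the diff

with DAG(
    dag_id="example_bigquery_sql_sensor",
    start_date=datetime(2024, 1, 1),
    schedule_interval="@daily",
    catchup=False,
) as dag:
    wait_for_rows = BigQuerySQLSensor(
        task_id="wait_for_rows",
        sql="SELECT 1 FROM `project.dataset.events` WHERE ds = '{{ ds }}' LIMIT 1",
        bigquery_conn_id="bigquery_default_conn",
        poke_interval=300,    # re-run the query every 5 minutes
        timeout=6 * 60 * 60,  # give up after 6 hours without rows
    )

Because "sql" is declared in template_fields, the {{ ds }} macro is rendered at runtime before the query is executed.
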
@@ -0,0 +1,84 @@
+from datetime import timedelta
+
+
+from airflow.sensors.external_task import ExternalTaskSensor as AExternalTaskSensor
+from airflow.models import BaseOperatorLink, DagBag, DagModel, DagRun, TaskInstance
+from sqlalchemy import func
+
+
+class ExecutionDateFn:
+    def __init__(self, execution_delta_hour: int = None,
+                 execution_delta_minutes: int = None):
+        self.execution_delta_hour = execution_delta_hour
+        self.execution_delta_minutes = execution_delta_minutes
+
+    def __call__(self, execution_date, context=None):
+        execution_date = context['data_interval_start']
+        if self.execution_delta_hour is not None:
+            return execution_date - timedelta(hours=self.execution_delta_hour)
+        elif self.execution_delta_minutes is not None:
+            return execution_date - timedelta(minutes=self.execution_delta_minutes)
+        return execution_date
+
+
+class ExternalTaskSensor(AExternalTaskSensor):
+    def __init__(
+            self,
+            execution_delta_hour: int = None,
+            execution_delta_minutes: int = None,
+            *args,
+            **kwargs):
+        kwargs.pop('execution_delta', None)
+        kwargs['execution_date_fn'] = ExecutionDateFn(execution_delta_hour, execution_delta_minutes)
+        kwargs['check_existence'] = True
+        super(ExternalTaskSensor, self).__init__(*args, **kwargs)
+
+    def get_count(self, dttm_filter, session, states) -> int:
+        """
+        Get the count of records against dttm filter and states
+        :param dttm_filter: date time filter for execution date
+        :type dttm_filter: list
+        :param session: airflow session object
+        :type session: SASession
+        :param states: task or dag states
+        :type states: list
+        :return: count of record against the filters
+        """
+        TI = TaskInstance
+        DR = DagRun
+        if not dttm_filter:
+            return 0
+        if self.external_task_id:
+            dag_run = (
+                session.query(DR)
+                .filter(
+                    DR.dag_id == self.external_dag_id,
+                    DR.data_interval_start == dttm_filter[0]
+                )
+                .first()
+            )
+            if not dag_run:
+                return 0
+            count = (
+                session.query(func.count())  # .count() is inefficient
+                .filter(
+                    TI.dag_id == self.external_dag_id,
+                    TI.task_id == self.external_task_id,
+                    TI.state.in_(states),
+                    TI.run_id == dag_run.run_id,
+                )
+                .scalar()
+            )
+        else:
+            count = (
+                session.query(func.count())
+                .filter(
+                    DR.dag_id == self.external_dag_id,
+                    DR.state.in_(states),
+                    DR.data_interval_start.in_(dttm_filter),
+                )
+                .scalar()
+            )
+        if count > 1:
+            return 1
+        return count
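
This subclass swaps Airflow's execution_delta for simpler hour/minute offsets applied to data_interval_start, always sets check_existence=True, and matches the upstream DagRun by data_interval_start instead of execution date, clamping the returned count to 1. A minimal, hypothetical usage sketch follows; the import path and DAG/task ids are assumptions, since the diff does not show file names.

# Hypothetical usage sketch; import path and dag/task ids are illustrative assumptions.
from datetime import datetime

from airflow import DAG

from airlake.sensors.external_task import ExternalTaskSensor  # assumed module path

with DAG(
    dag_id="downstream_report",
    start_date=datetime(2024, 1, 1),
    schedule_interval="0 3 * * *",  # daily at 03:00
    catchup=False,
) as dag:
    # Wait for the upstream run whose data_interval_start is 3 hours
    # earlier than this DAG's data_interval_start (e.g. its 00:00 run).
    wait_upstream = ExternalTaskSensor(
        task_id="wait_upstream_load",
        external_dag_id="upstream_hourly_load",
        external_task_id="load_events",
        execution_delta_hour=3,
        poke_interval=120,
        timeout=2 * 60 * 60,
    )

Because get_count matches the upstream run via DagRun.data_interval_start, the upstream schedule must produce a run at exactly the shifted interval start; with check_existence=True the sensor also fails fast if the external dag or task id does not exist.
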