Skip to content

Commit

Permalink
Add on_skipped_callback in to BaseOperator (apache#36374)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Jens Scheffler <95105677+jscheffl@users.noreply.github.com>
  • Loading branch information
romsharon98 and jscheffl authored Jan 14, 2024
1 parent b241577 commit 3eed501
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 0 deletions.
1 change: 1 addition & 0 deletions airflow/example_dags/tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
# 'on_success_callback': some_other_function, # or list of functions
# 'on_retry_callback': another_function, # or list of functions
# 'sla_miss_callback': yet_another_function, # or list of functions
# 'on_skipped_callback': another_function, # or list of functions
# 'trigger_rule': 'all_success'
},
# [END default_args]
Expand Down
10 changes: 10 additions & 0 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def partial(
on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
run_as_user: str | None | ArgNotSet = NOTSET,
executor_config: dict | None | ArgNotSet = NOTSET,
inlets: Any | None | ArgNotSet = NOTSET,
Expand Down Expand Up @@ -310,6 +311,7 @@ def partial(
"on_failure_callback": on_failure_callback,
"on_retry_callback": on_retry_callback,
"on_success_callback": on_success_callback,
"on_skipped_callback": on_skipped_callback,
"run_as_user": run_as_user,
"executor_config": executor_config,
"inlets": inlets,
Expand Down Expand Up @@ -597,6 +599,11 @@ class derived from this one results in the creation of a task object,
that it is executed when retries occur.
:param on_success_callback: much like the ``on_failure_callback`` except
that it is executed when the task succeeds.
:param on_skipped_callback: much like the ``on_failure_callback`` except
that it is executed when the task is skipped; this callback is called only if ``AirflowSkipException`` is raised.
Explicitly, it is NOT called if the task never starts executing because of a preceding branching
decision in the DAG or a trigger rule which causes execution to be skipped, so that the task execution
is never scheduled.
:param pre_execute: a function to be called immediately before task
execution, receiving a context dictionary; raising an exception will
prevent the task from being executed.
Expand Down Expand Up @@ -700,6 +707,7 @@ class derived from this one results in the creation of a task object,
"on_failure_callback",
"on_success_callback",
"on_retry_callback",
"on_skipped_callback",
"do_xcom_push",
}

Expand Down Expand Up @@ -759,6 +767,7 @@ def __init__(
on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
pre_execute: TaskPreExecuteHook | None = None,
post_execute: TaskPostExecuteHook | None = None,
trigger_rule: str = DEFAULT_TRIGGER_RULE,
Expand Down Expand Up @@ -825,6 +834,7 @@ def __init__(
self.on_failure_callback = on_failure_callback
self.on_success_callback = on_success_callback
self.on_retry_callback = on_retry_callback
self.on_skipped_callback = on_skipped_callback
self._pre_execute_hook = pre_execute
self._post_execute_hook = post_execute

Expand Down
8 changes: 8 additions & 0 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,14 @@ def on_success_callback(self) -> None | TaskStateChangeCallback | list[TaskState
def on_success_callback(self, value: TaskStateChangeCallback | None) -> None:
self.partial_kwargs["on_success_callback"] = value

@property
def on_skipped_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
    """Return the value stored under ``partial_kwargs["on_skipped_callback"]``.

    :return: a single callback, a list of callbacks, or ``None`` when the key
        is absent from ``partial_kwargs``.
    """
    return self.partial_kwargs.get("on_skipped_callback")

@on_skipped_callback.setter
def on_skipped_callback(
    self, value: None | TaskStateChangeCallback | list[TaskStateChangeCallback]
) -> None:
    """Store ``value`` under ``partial_kwargs["on_skipped_callback"]``.

    The annotation mirrors the getter's return type so that a list of
    callbacks is accepted as well as a single callback or ``None``.

    :param value: a single callback, a list of callbacks, or ``None``.
    """
    self.partial_kwargs["on_skipped_callback"] = value

@property
def run_as_user(self) -> str | None:
return self.partial_kwargs.get("run_as_user")
Expand Down
2 changes: 2 additions & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2361,6 +2361,8 @@ def _run_raw_task(
self.log.info(e)
if not test_mode:
self.refresh_from_db(lock_for_update=True, session=session)
_run_finished_callback(callbacks=self.task.on_skipped_callback, context=context)
session.commit()
self.state = TaskInstanceState.SKIPPED
except AirflowRescheduleException as reschedule_exception:
self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ Name Description
``sla_miss_callback`` Invoked when a task misses its defined :ref:`SLA <concepts:slas>`
``on_retry_callback`` Invoked when the task is :ref:`up for retry <concepts:task-instances>`
``on_execute_callback`` Invoked right before the task begins executing.
``on_skipped_callback`` Invoked when the task is :ref:`running <concepts:task-instances>` and ``AirflowSkipException`` is raised.
Explicitly, it is NOT called if the task never starts executing because of a preceding branching
decision in the DAG or a trigger rule which causes execution to be skipped, so that the task execution
is never scheduled.
=========================================== ================================================================


Expand Down
20 changes: 20 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3172,6 +3172,26 @@ def test_clear_db_references(self, session, create_task_instance):

assert session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None

def test_skipped_task_call_on_skipped_callback(self, dag_maker):
    """Check that a task raising ``AirflowSkipException`` ends SKIPPED and triggers ``on_skipped_callback``."""
    skip_callback = mock.MagicMock()

    def raise_skip_exception():
        # The callable does nothing except request a skip.
        raise AirflowSkipException

    with dag_maker(dag_id="test_skipped_task"):
        skipping_task = PythonOperator(
            task_id="test_skipped_task",
            python_callable=raise_skip_exception,
            on_skipped_callback=skip_callback,
        )

    dag_run = dag_maker.create_dagrun(execution_date=timezone.utcnow())
    task_instance = dag_run.task_instances[0]
    # Attach the operator built above before running the task instance.
    task_instance.task = skipping_task
    task_instance.run()

    assert task_instance.state == State.SKIPPED
    assert skip_callback.called


@pytest.mark.parametrize("pool_override", [None, "test_pool2"])
def test_refresh_from_task(pool_override):
Expand Down
1 change: 1 addition & 0 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,7 @@ def test_no_new_fields_added_to_base_operator(self):
"on_execute_callback": None,
"on_failure_callback": None,
"on_retry_callback": None,
"on_skipped_callback": None,
"on_success_callback": None,
"outlets": [],
"owner": "airflow",
Expand Down

0 comments on commit 3eed501

Please sign in to comment.