Fix deferrable mode of BeamRunPythonPipelineOperator (#44386)
MaksYermak authored Nov 27, 2024
1 parent 90442e8 commit 43adccf
Showing 5 changed files with 58 additions and 31 deletions.
26 changes: 5 additions & 21 deletions providers/src/airflow/providers/apache/beam/operators/beam.py
@@ -19,8 +19,6 @@

from __future__ import annotations

import asyncio
import contextlib
import copy
import os
import stat
@@ -30,7 +28,7 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack
from functools import partial
from typing import IO, TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable

from airflow.configuration import conf
from airflow.exceptions import AirflowException
@@ -377,7 +375,7 @@ def execute(self, context: Context):
# Check deferrable parameter passed to the operator
# to determine type of run - asynchronous or synchronous
if self.deferrable:
asyncio.run(self.execute_async(context))
self.execute_async(context)
else:
return self.execute_sync(context)

@@ -425,23 +423,7 @@ def execute_sync(self, context: Context):
process_line_callback=self.process_line_callback,
)

async def execute_async(self, context: Context):
# Creating a new event loop to manage I/O operations asynchronously
loop = asyncio.get_event_loop()
if self.py_file.lower().startswith("gs://"):
gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)
# Running synchronous `enter_context()` method in a separate
# thread using the default executor `None`. The `run_in_executor()` function returns the
# file object, which is created using gcs function `provide_file()`, asynchronously.
# This means we can perform asynchronous operations with this file.
create_tmp_file_call = gcs_hook.provide_file(object_url=self.py_file)
tmp_gcs_file: IO[str] = await loop.run_in_executor(
None,
contextlib.ExitStack().enter_context, # type: ignore[arg-type]
create_tmp_file_call,
)
self.py_file = tmp_gcs_file.name

def execute_async(self, context: Context):
if self.is_dataflow and self.dataflow_hook:
DataflowJobLink.persist(
self,
@@ -460,6 +442,7 @@ async def execute_async(self, context: Context):
py_requirements=self.py_requirements,
py_system_site_packages=self.py_system_site_packages,
runner=self.runner,
gcp_conn_id=self.gcp_conn_id,
),
method_name="execute_complete",
)
@@ -473,6 +456,7 @@ async def execute_async(self, context: Context):
py_requirements=self.py_requirements,
py_system_site_packages=self.py_system_site_packages,
runner=self.runner,
gcp_conn_id=self.gcp_conn_id,
),
method_name="execute_complete",
)
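With this change `execute_async` becomes a plain synchronous method: instead of downloading the GCS pipeline file itself and wrapping the work in `asyncio.run()`, the operator only persists the Dataflow job link (when applicable) and defers to the trigger, now also passing along `gcp_conn_id`. Below is a minimal, self-contained sketch of the deferral pattern this relies on; the class names and the serialized module path are hypothetical stand-ins, not the provider's actual classes.

```python
from __future__ import annotations

from typing import Any, AsyncIterator

from airflow.models import BaseOperator
from airflow.triggers.base import BaseTrigger, TriggerEvent


class WaitTrigger(BaseTrigger):
    """Toy trigger; a real one would poll or download asynchronously before yielding."""

    def serialize(self) -> tuple[str, dict[str, Any]]:
        # The module path must be importable on the triggerer; this one is hypothetical.
        return ("example_module.WaitTrigger", {})

    async def run(self) -> AsyncIterator[TriggerEvent]:
        yield TriggerEvent({"status": "success", "message": "Pipeline has finished."})


class ToyDeferrableOperator(BaseOperator):
    def execute(self, context):
        # Synchronous: no asyncio.run() here -- hand the async work to the triggerer.
        self.defer(trigger=WaitTrigger(), method_name="execute_complete")

    def execute_complete(self, context, event: dict[str, Any]):
        # Called on the worker once the triggerer fires the event.
        if event.get("status") != "success":
            raise RuntimeError(event.get("message", "pipeline failed"))
        return event
```

Deferring this way requires the Airflow triggerer component to be running; the blocking I/O that used to live in the operator now runs inside the trigger's `run()` coroutine.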
20 changes: 20 additions & 0 deletions providers/src/airflow/providers/apache/beam/triggers/beam.py
@@ -65,6 +65,7 @@ class BeamPythonPipelineTrigger(BeamPipelineBaseTrigger):
Other possible options: DataflowRunner, SparkRunner, FlinkRunner, PortableRunner.
See: :class:`~providers.apache.beam.hooks.beam.BeamRunnerType`
See: https://beam.apache.org/documentation/runners/capability-matrix/
:param gcp_conn_id: Optional. The connection ID to use connecting to Google Cloud.
"""

def __init__(
@@ -76,6 +77,7 @@ def __init__(
py_requirements: list[str] | None = None,
py_system_site_packages: bool = False,
runner: str = "DirectRunner",
gcp_conn_id: str = "google_cloud_default",
):
super().__init__()
self.variables = variables
@@ -85,6 +87,7 @@ def __init__(
self.py_requirements = py_requirements
self.py_system_site_packages = py_system_site_packages
self.runner = runner
self.gcp_conn_id = gcp_conn_id

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize BeamPythonPipelineTrigger arguments and classpath."""
@@ -98,13 +101,30 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"py_requirements": self.py_requirements,
"py_system_site_packages": self.py_system_site_packages,
"runner": self.runner,
"gcp_conn_id": self.gcp_conn_id,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current pipeline status and yields a TriggerEvent."""
hook = self._get_async_hook(runner=self.runner)
try:
# Get the current running event loop to manage I/O operations asynchronously
loop = asyncio.get_running_loop()
if self.py_file.lower().startswith("gs://"):
gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)
# Running synchronous `enter_context()` method in a separate
# thread using the default executor `None`. The `run_in_executor()` function returns the
# file object, which is created using gcs function `provide_file()`, asynchronously.
# This means we can perform asynchronous operations with this file.
create_tmp_file_call = gcs_hook.provide_file(object_url=self.py_file)
tmp_gcs_file: IO[str] = await loop.run_in_executor(
None,
contextlib.ExitStack().enter_context, # type: ignore[arg-type]
create_tmp_file_call,
)
self.py_file = tmp_gcs_file.name

return_code = await hook.start_python_pipeline_async(
variables=self.variables,
py_file=self.py_file,
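The GCS download that used to happen in the operator now happens inside the trigger's `run()` coroutine: the synchronous `provide_file()` context manager is entered on the default thread-pool executor via `run_in_executor`, so the event loop is not blocked while the file is copied locally. The snippet below is a self-contained sketch of that pattern, with `tempfile.NamedTemporaryFile` standing in for `GCSHook.provide_file()`:

```python
import asyncio
import contextlib
import tempfile


async def fetch_local_copy() -> str:
    loop = asyncio.get_running_loop()
    stack = contextlib.ExitStack()
    # Stand-in for gcs_hook.provide_file(object_url=...): any synchronous
    # context manager that yields a file-like object works the same way.
    provide_file_cm = tempfile.NamedTemporaryFile(suffix=".py")
    # Enter the context manager in a worker thread so the coroutine keeps
    # control of the event loop while the (potentially slow) download runs.
    tmp_file = await loop.run_in_executor(None, stack.enter_context, provide_file_cm)
    local_path = tmp_file.name
    # Keep the stack open for as long as the temporary file is needed,
    # then close it to clean up deterministically.
    stack.close()
    return local_path


if __name__ == "__main__":
    print(asyncio.run(fetch_local_copy()))
```

In the sketch the `ExitStack` is held and closed explicitly so the temporary file's cleanup stays deterministic.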
10 changes: 2 additions & 8 deletions providers/tests/apache/beam/operators/test_beam.py
@@ -942,24 +942,20 @@ def test_async_execute_should_execute_successfully(self, gcs_hook, beam_hook_moc
), "Trigger is not a BeamPythonPipelineTrigger"

@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_async_execute_direct_runner(self, gcs_hook, beam_hook_mock):
def test_async_execute_direct_runner(self, beam_hook_mock):
"""
Test BeamHook is created and the right args are passed to
start_python_workflow when executing direct runner.
"""
gcs_provide_file = gcs_hook.return_value.provide_file
op = BeamRunPythonPipelineOperator(**self.default_op_kwargs)
with pytest.raises(TaskDeferred):
op.execute(context=mock.MagicMock())
beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
gcs_provide_file.assert_called_once_with(object_url=PY_FILE)

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
def test_exec_dataflow_runner(self, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
"""
Test DataflowHook is created and the right args are passed to
start_python_dataflow when executing Dataflow runner.
@@ -971,7 +967,6 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
dataflow_config=dataflow_config,
**self.default_op_kwargs,
)
gcs_provide_file = gcs_hook.return_value.provide_file
magic_mock = mock.MagicMock()
with pytest.raises(TaskDeferred):
op.execute(context=magic_mock)
@@ -994,7 +989,6 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
"region": "us-central1",
"impersonate_service_account": TEST_IMPERSONATION_ACCOUNT,
}
gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
persist_link_mock.assert_called_once_with(
op,
magic_mock,
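Because the download no longer happens in the operator, these tests drop the `GCSHook` patches and only assert that `execute()` defers. A hedged sketch of that assertion pattern (the `op` and `expected_trigger_cls` parameters are placeholders, not fixtures from the real test module):

```python
from unittest import mock

import pytest

from airflow.exceptions import TaskDeferred


def assert_operator_defers(op, expected_trigger_cls) -> None:
    # Deferrable operators signal the scheduler by raising TaskDeferred from execute().
    with pytest.raises(TaskDeferred) as exc_info:
        op.execute(context=mock.MagicMock())
    # The exception carries the trigger instance that was handed to self.defer().
    assert isinstance(exc_info.value.trigger, expected_trigger_cls)
```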
15 changes: 15 additions & 0 deletions providers/tests/apache/beam/triggers/test_beam.py
@@ -44,6 +44,7 @@
TEST_RUNNER = "DirectRunner"
TEST_JAR_FILE = "example.jar"
TEST_GCS_JAR_FILE = "gs://my-bucket/example/test.jar"
TEST_GCS_PY_FILE = "gs://my-bucket/my-object.py"
TEST_JOB_CLASS = "TestClass"
TEST_CHECK_IF_RUNNING = False
TEST_JOB_NAME = "test_job_name"
@@ -61,6 +62,7 @@ def python_trigger():
py_requirements=TEST_PY_REQUIREMENTS,
py_system_site_packages=TEST_PY_PACKAGES,
runner=TEST_RUNNER,
gcp_conn_id=TEST_GCP_CONN_ID,
)


@@ -98,6 +100,7 @@ def test_beam_trigger_serialization_should_execute_successfully(self, python_tri
"py_requirements": TEST_PY_REQUIREMENTS,
"py_system_site_packages": TEST_PY_PACKAGES,
"runner": TEST_RUNNER,
"gcp_conn_id": TEST_GCP_CONN_ID,
}

@pytest.mark.asyncio
@@ -139,6 +142,18 @@ async def test_beam_trigger_exception_should_execute_successfully(
actual = await generator.asend(None)
assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.beam.triggers.beam.GCSHook")
async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, gcs_hook, python_trigger):
"""
Test that BeamPythonPipelineTrigger downloads GCS provide file correct.
"""
gcs_provide_file = gcs_hook.return_value.provide_file
python_trigger.py_file = TEST_GCS_PY_FILE
generator = python_trigger.run()
await generator.asend(None)
gcs_provide_file.assert_called_once_with(object_url=TEST_GCS_PY_FILE)


class TestBeamJavaPipelineTrigger:
def test_beam_trigger_serialization_should_execute_successfully(self, java_trigger):
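The new test drives the trigger with `generator.asend(None)`, the same idiom the other trigger tests use to pull the first `TriggerEvent` out of the `run()` async generator. A self-contained illustration of that idiom, with a toy generator standing in for `BeamPythonPipelineTrigger.run()`:

```python
import asyncio
from typing import Any, AsyncIterator


async def toy_run() -> AsyncIterator[dict[str, Any]]:
    # Stand-in for a trigger's run(); a real trigger awaits hook calls here
    # before yielding its event.
    yield {"status": "success", "message": "Pipeline has finished."}


async def first_event() -> dict[str, Any]:
    generator = toy_run()
    # asend(None) starts the async generator and returns the first yielded value.
    return await generator.asend(None)


if __name__ == "__main__":
    print(asyncio.run(first_event()))
```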
@@ -87,6 +87,21 @@
py_system_site_packages=False,
)

start_python_deferrable = BeamRunPythonPipelineOperator(
runner=BeamRunnerType.DataflowRunner,
task_id="start_python_job_deferrable",
py_file=GCS_PYTHON_SCRIPT,
py_options=[],
pipeline_options={
"output": GCS_OUTPUT,
},
py_requirements=["apache-beam[gcp]==2.59.0"],
py_interpreter="python3",
py_system_site_packages=False,
dataflow_config={"location": LOCATION, "job_name": "start_python_deferrable"},
deferrable=True,
)

# [START howto_operator_stop_dataflow_job]
stop_dataflow_job = DataflowStopJobOperator(
task_id="stop_dataflow_job",
@@ -103,8 +118,7 @@
# TEST SETUP
create_bucket
# TEST BODY
>> start_python_job
>> start_python_job_local
>> [start_python_job, start_python_job_local, start_python_deferrable]
>> stop_dataflow_job
# TEST TEARDOWN
>> delete_bucket
