Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create CloudComposerRunAirflowCLICommandOperator operator #38965

Merged
merged 1 commit into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 250 additions & 2 deletions airflow/providers/google/cloud/hooks/cloud_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence
import asyncio
import time
from typing import TYPE_CHECKING, MutableSequence, Sequence

from google.api_core.client_options import ClientOptions
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.orchestration.airflow.service_v1 import (
EnvironmentsAsyncClient,
EnvironmentsClient,
ImageVersionsClient,
PollAirflowCommandResponse,
)

from airflow.exceptions import AirflowException
Expand All @@ -42,7 +45,10 @@
from google.cloud.orchestration.airflow.service_v1.services.image_versions.pagers import (
ListImageVersionsPager,
)
from google.cloud.orchestration.airflow.service_v1.types import Environment
from google.cloud.orchestration.airflow.service_v1.types import (
Environment,
ExecuteAirflowCommandResponse,
)
from google.protobuf.field_mask_pb2 import FieldMask


Expand Down Expand Up @@ -294,6 +300,127 @@ def list_image_versions(
)
return result

@GoogleBaseHook.fallback_to_default_project_id
def execute_airflow_command(
    self,
    project_id: str,
    region: str,
    environment_id: str,
    command: str,
    subcommand: str,
    parameters: MutableSequence[str],
    retry: Retry | _MethodDefault = DEFAULT,
    timeout: float | None = None,
    metadata: Sequence[tuple[str, str]] = (),
) -> ExecuteAirflowCommandResponse:
    """
    Submit an Airflow command for execution in a Composer environment.

    :param project_id: The ID of the Google Cloud project that the service belongs to.
    :param region: The ID of the Google Cloud region that the service belongs to.
    :param environment_id: The ID of the Google Cloud environment that the service belongs to.
    :param command: Airflow command.
    :param subcommand: Airflow subcommand.
    :param parameters: Arguments for the Airflow command/subcommand, given as a list. Entries may be
        positional arguments (``["my-dag-id"]``), key-value parameters (``["--foo=bar"]`` or
        ``["--foo","bar"]``), or other flags (``["-f"]``).
    :param retry: Designation of what errors, if any, should be retried.
    :param timeout: The timeout for this request.
    :param metadata: Strings which should be sent along with the request as metadata.
    :return: The response object returned by the environments client for this request.
    """
    environments_client = self.get_environment_client()
    # The environment is addressed by the name built from project, region and environment ID.
    request = {
        "environment": self.get_environment_name(project_id, region, environment_id),
        "command": command,
        "subcommand": subcommand,
        "parameters": parameters,
    }
    return environments_client.execute_airflow_command(
        request=request,
        retry=retry,
        timeout=timeout,
        metadata=metadata,
    )

@GoogleBaseHook.fallback_to_default_project_id
def poll_airflow_command(
    self,
    project_id: str,
    region: str,
    environment_id: str,
    execution_id: str,
    pod: str,
    pod_namespace: str,
    next_line_number: int,
    retry: Retry | _MethodDefault = DEFAULT,
    timeout: float | None = None,
    metadata: Sequence[tuple[str, str]] = (),
) -> PollAirflowCommandResponse:
    """
    Fetch the current execution result of an Airflow command from a Composer environment.

    :param project_id: The ID of the Google Cloud project that the service belongs to.
    :param region: The ID of the Google Cloud region that the service belongs to.
    :param environment_id: The ID of the Google Cloud environment that the service belongs to.
    :param execution_id: The unique ID of the command execution.
    :param pod: The name of the pod where the command is executed.
    :param pod_namespace: The namespace of the pod where the command is executed.
    :param next_line_number: Line number from which new logs should be fetched.
    :param retry: Designation of what errors, if any, should be retried.
    :param timeout: The timeout for this request.
    :param metadata: Strings which should be sent along with the request as metadata.
    :return: The response object returned by the environments client for this request.
    """
    environments_client = self.get_environment_client()
    # The execution is identified by its ID together with the pod (and namespace) running it.
    request = {
        "environment": self.get_environment_name(project_id, region, environment_id),
        "execution_id": execution_id,
        "pod": pod,
        "pod_namespace": pod_namespace,
        "next_line_number": next_line_number,
    }
    return environments_client.poll_airflow_command(
        request=request,
        retry=retry,
        timeout=timeout,
        metadata=metadata,
    )

def wait_command_execution_result(
    self,
    project_id: str,
    region: str,
    environment_id: str,
    execution_cmd_info: dict,
    retry: Retry | _MethodDefault = DEFAULT,
    timeout: float | None = None,
    metadata: Sequence[tuple[str, str]] = (),
    poll_interval: int = 10,
) -> dict:
    """
    Block until an Airflow command execution reports the end of its output.

    Repeatedly calls :meth:`poll_airflow_command` and sleeps ``poll_interval`` seconds between
    attempts. The loop has no overall deadline: it only ends when a response has a truthy
    ``output_end`` or when polling raises.

    :param project_id: The ID of the Google Cloud project that the service belongs to.
    :param region: The ID of the Google Cloud region that the service belongs to.
    :param environment_id: The ID of the Google Cloud environment that the service belongs to.
    :param execution_cmd_info: Mapping describing the execution to wait for. The keys
        ``"execution_id"``, ``"pod"`` and ``"pod_namespace"`` are read from it.
    :param retry: Designation of what errors, if any, should be retried.
    :param timeout: The timeout passed to each individual poll request.
    :param metadata: Strings which should be sent along with the request as metadata.
    :param poll_interval: Number of seconds to sleep between two poll requests.
    :return: The last poll response converted to a dict with ``PollAirflowCommandResponse.to_dict``.
    :raises AirflowException: If a poll request raises; the original exception is chained as the cause.
    """
    while True:
        try:
            result = self.poll_airflow_command(
                project_id=project_id,
                region=region,
                environment_id=environment_id,
                execution_id=execution_cmd_info["execution_id"],
                pod=execution_cmd_info["pod"],
                pod_namespace=execution_cmd_info["pod_namespace"],
                # Every poll asks for output starting at line 1 rather than resuming
                # from the previously seen line.
                next_line_number=1,
                retry=retry,
                timeout=timeout,
                metadata=metadata,
            )
        except Exception as ex:
            self.log.exception("Exception occurred while polling CMD result")
            # Chain explicitly so the traceback marks the polling error as the direct cause.
            raise AirflowException(ex) from ex

        result_dict = PollAirflowCommandResponse.to_dict(result)
        if result_dict["output_end"]:
            return result_dict

        self.log.info("Waiting for result...")
        time.sleep(poll_interval)


class CloudComposerAsyncHook(GoogleBaseHook):
"""Hook for Google Cloud Composer async APIs."""
Expand Down Expand Up @@ -421,3 +548,124 @@ async def update_environment(
timeout=timeout,
metadata=metadata,
)

@GoogleBaseHook.fallback_to_default_project_id
async def execute_airflow_command(
    self,
    project_id: str,
    region: str,
    environment_id: str,
    command: str,
    subcommand: str,
    parameters: MutableSequence[str],
    retry: AsyncRetry | _MethodDefault = DEFAULT,
    timeout: float | None = None,
    metadata: Sequence[tuple[str, str]] = (),
) -> ExecuteAirflowCommandResponse:
    """
    Execute Airflow command for provided Composer environment.

    :param project_id: The ID of the Google Cloud project that the service belongs to.
    :param region: The ID of the Google Cloud region that the service belongs to.
    :param environment_id: The ID of the Google Cloud environment that the service belongs to.
    :param command: Airflow command.
    :param subcommand: Airflow subcommand.
    :param parameters: Parameters for the Airflow command/subcommand as an array of arguments. It may
        contain positional arguments like ``["my-dag-id"]``, key-value parameters like ``["--foo=bar"]``
        or ``["--foo","bar"]``, or other flags like ``["-f"]``.
    :param retry: Designation of what errors, if any, should be retried.
    :param timeout: The timeout for this request.
    :param metadata: Strings which should be sent along with the request as metadata.
    :return: The awaited response of the environments client's ``execute_airflow_command`` call.
    """
    client = self.get_environment_client()

    # The call is awaited here, so callers receive the response itself, not an operation handle.
    return await client.execute_airflow_command(
        request={
            "environment": self.get_environment_name(project_id, region, environment_id),
            "command": command,
            "subcommand": subcommand,
            "parameters": parameters,
        },
        retry=retry,
        timeout=timeout,
        metadata=metadata,
    )

@GoogleBaseHook.fallback_to_default_project_id
async def poll_airflow_command(
    self,
    project_id: str,
    region: str,
    environment_id: str,
    execution_id: str,
    pod: str,
    pod_namespace: str,
    next_line_number: int,
    retry: AsyncRetry | _MethodDefault = DEFAULT,
    timeout: float | None = None,
    metadata: Sequence[tuple[str, str]] = (),
) -> PollAirflowCommandResponse:
    """
    Poll Airflow command execution result for provided Composer environment.

    :param project_id: The ID of the Google Cloud project that the service belongs to.
    :param region: The ID of the Google Cloud region that the service belongs to.
    :param environment_id: The ID of the Google Cloud environment that the service belongs to.
    :param execution_id: The unique ID of the command execution.
    :param pod: The name of the pod where the command is executed.
    :param pod_namespace: The namespace of the pod where the command is executed.
    :param next_line_number: Line number from which new logs should be fetched.
    :param retry: Designation of what errors, if any, should be retried.
    :param timeout: The timeout for this request.
    :param metadata: Strings which should be sent along with the request as metadata.
    :return: The awaited response of the environments client's ``poll_airflow_command`` call.
    """
    client = self.get_environment_client()

    # The call is awaited here, so callers receive the poll response itself.
    return await client.poll_airflow_command(
        request={
            "environment": self.get_environment_name(project_id, region, environment_id),
            "execution_id": execution_id,
            "pod": pod,
            "pod_namespace": pod_namespace,
            "next_line_number": next_line_number,
        },
        retry=retry,
        timeout=timeout,
        metadata=metadata,
    )

async def wait_command_execution_result(
    self,
    project_id: str,
    region: str,
    environment_id: str,
    execution_cmd_info: dict,
    retry: AsyncRetry | _MethodDefault = DEFAULT,
    timeout: float | None = None,
    metadata: Sequence[tuple[str, str]] = (),
    poll_interval: int = 10,
) -> dict:
    """
    Wait asynchronously until an Airflow command execution reports the end of its output.

    Repeatedly awaits :meth:`poll_airflow_command` and awaits ``asyncio.sleep(poll_interval)``
    between attempts. The loop has no overall deadline: it only ends when a response has a
    truthy ``output_end`` or when polling raises.

    :param project_id: The ID of the Google Cloud project that the service belongs to.
    :param region: The ID of the Google Cloud region that the service belongs to.
    :param environment_id: The ID of the Google Cloud environment that the service belongs to.
    :param execution_cmd_info: Mapping describing the execution to wait for. The keys
        ``"execution_id"``, ``"pod"`` and ``"pod_namespace"`` are read from it.
    :param retry: Designation of what errors, if any, should be retried.
    :param timeout: The timeout passed to each individual poll request.
    :param metadata: Strings which should be sent along with the request as metadata.
    :param poll_interval: Number of seconds to sleep between two poll requests.
    :return: The last poll response converted to a dict with ``PollAirflowCommandResponse.to_dict``.
    :raises AirflowException: If a poll request raises; the original exception is chained as the cause.
    """
    while True:
        try:
            result = await self.poll_airflow_command(
                project_id=project_id,
                region=region,
                environment_id=environment_id,
                execution_id=execution_cmd_info["execution_id"],
                pod=execution_cmd_info["pod"],
                pod_namespace=execution_cmd_info["pod_namespace"],
                # Every poll asks for output starting at line 1 rather than resuming
                # from the previously seen line.
                next_line_number=1,
                retry=retry,
                timeout=timeout,
                metadata=metadata,
            )
        except Exception as ex:
            self.log.exception("Exception occurred while polling CMD result")
            # Chain explicitly so the traceback marks the polling error as the direct cause.
            raise AirflowException(ex) from ex

        result_dict = PollAirflowCommandResponse.to_dict(result)
        if result_dict["output_end"]:
            return result_dict

        self.log.info("Sleeping for %s seconds.", poll_interval)
        await asyncio.sleep(poll_interval)
Loading