Merge branch 'main' into Resolve_internal_warnings_for_testLocalTaskJob
Owen-CH-Leung authored Apr 13, 2024
2 parents feaa49a + d03ba59 commit abb07d6
Showing 70 changed files with 1,719 additions and 292 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -335,7 +335,7 @@ repos:
types_or: [python, pyi]
args: [--fix]
require_serial: true
additional_dependencies: ["ruff==0.3.5"]
additional_dependencies: ["ruff==0.3.6"]
exclude: ^.*/.*_vendor/|^tests/dags/test_imports.py
- id: ruff-format
name: Run 'ruff format' for extremely fast Python formatting
@@ -345,7 +345,7 @@
types_or: [python, pyi]
args: []
require_serial: true
additional_dependencies: ["ruff==0.3.5"]
additional_dependencies: ["ruff==0.3.6"]
exclude: ^.*/.*_vendor/|^tests/dags/test_imports.py|^airflow/contrib/
- id: replace-bad-characters
name: Replace bad characters
8 changes: 3 additions & 5 deletions airflow/api_connexion/endpoints/import_error_endpoint.py
@@ -94,11 +94,9 @@ def get_import_errors(
if not can_read_all_dags:
# if the user doesn't have access to all DAGs, only display errors from visible DAGs
readable_dag_ids = security.get_readable_dags()
dagfiles_subq = (
select(DagModel.fileloc).distinct().where(DagModel.dag_id.in_(readable_dag_ids)).subquery()
)
query = query.where(ImportErrorModel.filename.in_(dagfiles_subq))
count_query = count_query.where(ImportErrorModel.filename.in_(dagfiles_subq))
dagfiles_stmt = select(DagModel.fileloc).distinct().where(DagModel.dag_id.in_(readable_dag_ids))
query = query.where(ImportErrorModel.filename.in_(dagfiles_stmt))
count_query = count_query.where(ImportErrorModel.filename.in_(dagfiles_stmt))

total_entries = session.scalars(count_query).one()
import_errors = session.scalars(query.offset(offset).limit(limit)).all()
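Note: the rewrite above passes the SELECT statement directly to in_() instead of wrapping it in .subquery(), which is the SQLAlchemy 2.0-style spelling and avoids the implicit-coercion warning the old form triggered. A minimal, self-contained sketch of the pattern (table and column names here are illustrative, not Airflow's models):

from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Dag(Base):
    __tablename__ = "dag"
    dag_id = Column(String, primary_key=True)
    fileloc = Column(String)

class ImportErrorRecord(Base):
    __tablename__ = "import_error"
    id = Column(Integer, primary_key=True)
    filename = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

readable_dag_ids = ["dag_a", "dag_b"]

# The SELECT can be used as the IN-list directly; no .subquery() call needed.
dagfiles_stmt = select(Dag.fileloc).distinct().where(Dag.dag_id.in_(readable_dag_ids))
query = select(ImportErrorRecord).where(ImportErrorRecord.filename.in_(dagfiles_stmt))

with Session(engine) as session:
    print(session.scalars(query).all())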
4 changes: 3 additions & 1 deletion airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -47,12 +47,14 @@ def _initialize_map() -> dict[str, Callable]:
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance
from airflow.secrets.metastore import MetastoreBackend
from airflow.utils.cli_action_loggers import _default_action_log_internal
from airflow.utils.log.file_task_handler import FileTaskHandler

functions: list[Callable] = [
_default_action_log_internal,
_get_template_context,
_update_rtif,
_get_ti_db_access,
_update_rtif,
DagFileProcessor.update_import_errors,
DagFileProcessor.manage_slas,
DagFileProcessorManager.deactivate_stale_dags,
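For context, _initialize_map keys each exposed callable so the internal API can dispatch an incoming RPC by name. A rough sketch of that dispatch pattern (simplified, not the endpoint's exact code; _noop_action_log is a stand-in for _default_action_log_internal):

from typing import Callable

def _noop_action_log(**kwargs):
    # Stand-in for airflow.utils.cli_action_loggers._default_action_log_internal in this sketch.
    print("action logged:", kwargs)

def initialize_map(functions: list[Callable]) -> dict[str, Callable]:
    # Key each callable by its fully qualified name.
    return {f"{func.__module__}.{func.__qualname__}": func for func in functions}

method_map = initialize_map([_noop_action_log])

def dispatch(method_name: str, params: dict):
    handler = method_map.get(method_name)
    if handler is None:
        raise ValueError(f"Unrecognized method: {method_name}")
    return handler(**params)

dispatch(f"{_noop_action_log.__module__}.{_noop_action_log.__qualname__}", {"event": "cli_task_run"})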
2 changes: 1 addition & 1 deletion airflow/config_templates/config.yml
@@ -528,7 +528,7 @@ database:
version_added: 2.3.0
type: string
sensitive: true
example: '{"arg1": True}'
example: '{"arg1": true}'
default: ~
sql_engine_encoding:
description: |
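The corrected example uses the JSON literal true; Python's True is not valid JSON, which is easy to confirm:

import json

print(json.loads('{"arg1": true}'))  # {'arg1': True}

try:
    json.loads('{"arg1": True}')  # Python-style boolean is rejected by the JSON parser
except json.JSONDecodeError as exc:
    print(f"invalid JSON: {exc}")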
7 changes: 6 additions & 1 deletion airflow/example_dags/plugins/event_listener.py
@@ -89,7 +89,9 @@ def on_task_instance_success(previous_state: TaskInstanceState, task_instance: T

# [START howto_listen_ti_failure_task]
@hookimpl
def on_task_instance_failed(previous_state: TaskInstanceState, task_instance: TaskInstance, session):
def on_task_instance_failed(
previous_state: TaskInstanceState, task_instance: TaskInstance, error: None | str | BaseException, session
):
"""
This method is called when task state changes to FAILED.
Through the callback, parameters like previous_task_state and the task_instance object can be accessed.
@@ -113,6 +115,8 @@ def on_task_instance_failed(previous_state: TaskInstanceState, task_instance: Ta

print(f"Task start:{start_date} end:{end_date} duration:{duration}")
print(f"Task:{task} dag:{dag} dagrun:{dagrun}")
if error:
print(f"Failure caused by {error}")


# [END howto_listen_ti_failure_task]
@@ -146,6 +150,7 @@ def on_dag_run_failed(dag_run: DagRun, msg: str):
external_trigger = dag_run.external_trigger

print(f"Dag information:{dag_id} Run id: {run_id} external trigger: {external_trigger}")
print(f"Failed with message: {msg}")


# [END howto_listen_dagrun_failure_task]
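As a reminder of how a listener module like this one is picked up, it is usually exposed through a plugin. A minimal sketch (the import path my_company.event_listener is a placeholder):

from airflow.plugins_manager import AirflowPlugin

# Placeholder module containing the @hookimpl functions shown above, including the
# updated on_task_instance_failed(previous_state, task_instance, error, session).
from my_company import event_listener

class EventListenerPlugin(AirflowPlugin):
    name = "event_listener_plugin"
    listeners = [event_listener]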
5 changes: 4 additions & 1 deletion airflow/listeners/spec/taskinstance.py
@@ -46,6 +46,9 @@ def on_task_instance_success(

@hookspec
def on_task_instance_failed(
previous_state: TaskInstanceState | None, task_instance: TaskInstance, session: Session | None
previous_state: TaskInstanceState | None,
task_instance: TaskInstance,
error: None | str | BaseException,
session: Session | None,
):
"""Execute when task state changes to FAIL. previous_state can be None."""
1 change: 0 additions & 1 deletion airflow/models/dagrun.py
@@ -650,7 +650,6 @@ def get_task_instance(
)

@staticmethod
@internal_api_call
@provide_session
def fetch_task_instance(
dag_id: str,
34 changes: 29 additions & 5 deletions airflow/models/taskinstance.py
@@ -539,7 +539,7 @@ def _refresh_from_db(
task_instance.end_date = ti.end_date
task_instance.duration = ti.duration
task_instance.state = ti.state
task_instance.try_number = ti._try_number # private attr to get value unaltered by accessor
task_instance.try_number = _get_private_try_number(task_instance=ti)
task_instance.max_tries = ti.max_tries
task_instance.hostname = ti.hostname
task_instance.unixname = ti.unixname
@@ -925,7 +925,7 @@ def _handle_failure(
TaskInstance.save_to_db(failure_context["ti"], session)


def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
def _get_try_number(*, task_instance: TaskInstance):
"""
Return the try number that the task will have when it is actually run.
@@ -943,6 +943,23 @@ def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
return task_instance._try_number + 1


def _get_private_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
"""
Opposite of _get_try_number.
Given the value returned by try_number, return the value of _try_number that
should produce the same result.
This is needed for setting _try_number on TaskInstance from the value on TaskInstancePydantic, which has no private attrs.
:param task_instance: the task instance
:meta private:
"""
if task_instance.state == TaskInstanceState.RUNNING:
return task_instance.try_number
return task_instance.try_number - 1


def _set_try_number(*, task_instance: TaskInstance | TaskInstancePydantic, value: int) -> None:
"""
Set a task try number.
@@ -952,7 +969,7 @@ def _set_try_number(*, task_instance: TaskInstance | TaskInstancePydantic, value
:meta private:
"""
task_instance._try_number = value
task_instance._try_number = value # type: ignore[union-attr]


def _refresh_from_task(
@@ -1413,6 +1430,7 @@ class TaskInstance(Base, LoggingMixin):
cascade="all, delete, delete-orphan",
)
note = association_proxy("task_instance_note", "content", creator=_creator_note)

task: Operator | None = None
test_mode: bool = False
is_trigger_log_context: bool = False
@@ -2934,7 +2952,7 @@ def fetch_handle_failure_context(
):
"""Handle Failure for the TaskInstance."""
get_listener_manager().hook.on_task_instance_failed(
previous_state=TaskInstanceState.RUNNING, task_instance=ti, session=session
previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error, session=session
)

if error:
@@ -2999,6 +3017,12 @@ def fetch_handle_failure_context(
_stop_remaining_tasks(task_instance=ti, session=session)
else:
if ti.state == TaskInstanceState.QUEUED:
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

if isinstance(ti, TaskInstancePydantic):
# todo: (AIP-44) we should probably "coalesce" `ti` to TaskInstance before here
# e.g. we could make refresh_from_db return a TI and replace ti with that
raise RuntimeError("Expected TaskInstance here. Further AIP-44 work required.")
# We increase the try_number to fail the task if it fails to start after sometime
ti._try_number += 1
ti.state = State.UP_FOR_RETRY
@@ -3539,7 +3563,7 @@ def _schedule_downstream_tasks(

except OperationalError as e:
# Any kind of DB error here is _non fatal_ as this block is just an optimisation.
cls.logger().info(
cls.logger().debug(
"Skipping mini scheduling run due to exception: %s",
e.statement,
exc_info=True,
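The new _get_private_try_number helper is the inverse of the public try_number accessor, so a value coming from a TaskInstancePydantic (which has no private attributes) can be written back to _try_number. A simplified sketch of that round trip with plain functions (not Airflow's classes):

RUNNING = "running"

def public_try_number(private_value: int, state: str) -> int:
    # Mirrors the accessor: while the task runs, the stored value is the current try;
    # otherwise the next attempt would be one higher.
    return private_value if state == RUNNING else private_value + 1

def private_try_number(public_value: int, state: str) -> int:
    # Inverse mapping, analogous to _get_private_try_number.
    return public_value if state == RUNNING else public_value - 1

for state in (RUNNING, "queued"):
    for stored in (0, 1, 2):
        assert private_try_number(public_try_number(stored, state), state) == stored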
7 changes: 6 additions & 1 deletion airflow/providers/airbyte/hooks/airbyte.py
@@ -71,7 +71,12 @@ def __init__(
async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str]:
"""Get Headers, tenants from the connection details."""
connection: Connection = await sync_to_async(self.get_connection)(self.http_conn_id)
base_url = connection.host
# schema defaults to HTTP
schema = connection.schema if connection.schema else "http"
base_url = f"{schema}://{connection.host}"

if connection.port:
base_url += f":{connection.port}"

if self.api_type == "config":
credentials = f"{connection.login}:{connection.password}"
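A small sketch of the URL handling the hook now performs, defaulting the scheme to http and appending the port only when the connection defines one (host and port values are illustrative):

from __future__ import annotations

def build_base_url(host: str, schema: str | None = None, port: int | None = None) -> str:
    # Default to http when the connection has no schema set, as the hook does.
    base_url = f"{schema or 'http'}://{host}"
    if port:
        base_url += f":{port}"
    return base_url

print(build_base_url("airbyte.example.com"))                 # http://airbyte.example.com
print(build_base_url("airbyte.example.com", "https", 8000))  # https://airbyte.example.com:8000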
103 changes: 102 additions & 1 deletion airflow/providers/amazon/aws/operators/bedrock.py
@@ -25,7 +25,10 @@
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger
from airflow.providers.amazon.aws.triggers.bedrock import (
BedrockCustomizeModelCompletedTrigger,
BedrockProvisionModelThroughputCompletedTrigger,
)
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.utils.helpers import prune_dict
@@ -250,3 +253,101 @@ def execute(self, context: Context) -> dict:
)

return response["jobArn"]


class BedrockCreateProvisionedModelThroughputOperator(AwsBaseOperator[BedrockHook]):
"""
Create a provisioned throughput for the specified model.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:BedrockCreateProvisionedModelThroughputOperator`
:param model_units: Number of model units to allocate. (templated)
:param provisioned_model_name: Unique name for this provisioned throughput. (templated)
:param model_id: Name or ARN of the model to associate with this provisioned throughput. (templated)
:param create_throughput_kwargs: Any optional parameters to pass to the API.
:param wait_for_completion: Whether to wait for the provisioned throughput to be ready. (default: True)
:param waiter_delay: Time in seconds to wait between status checks. (default: 60)
:param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20)
:param deferrable: If True, the operator will wait asynchronously for the provisioned throughput to be ready.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

aws_hook_class = BedrockHook
template_fields: Sequence[str] = aws_template_fields(
"model_units",
"provisioned_model_name",
"model_id",
)

def __init__(
self,
model_units: int,
provisioned_model_name: str,
model_id: str,
create_throughput_kwargs: dict[str, Any] | None = None,
wait_for_completion: bool = True,
waiter_delay: int = 60,
waiter_max_attempts: int = 20,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(**kwargs)
self.model_units = model_units
self.provisioned_model_name = provisioned_model_name
self.model_id = model_id
self.create_throughput_kwargs = create_throughput_kwargs or {}
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

def execute(self, context: Context) -> str:
provisioned_model_id = self.hook.conn.create_provisioned_model_throughput(
modelUnits=self.model_units,
provisionedModelName=self.provisioned_model_name,
modelId=self.model_id,
**self.create_throughput_kwargs,
)["provisionedModelArn"]

if self.deferrable:
self.log.info("Deferring for provisioned throughput.")
self.defer(
trigger=BedrockProvisionModelThroughputCompletedTrigger(
provisioned_model_id=provisioned_model_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
)
if self.wait_for_completion:
self.log.info("Waiting for provisioned throughput.")
self.hook.get_waiter("provisioned_model_throughput_complete").wait(
provisionedModelId=provisioned_model_id,
WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
)

return provisioned_model_id

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
event = validate_execute_complete_event(event)

if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")

self.log.info("Bedrock provisioned throughput job `%s` complete.", event["provisioned_model_id"])
return event["provisioned_model_id"]
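A hedged usage sketch of the new operator inside a DAG; the DAG id, model name, and model ARN below are placeholders, not values from this commit:

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.bedrock import (
    BedrockCreateProvisionedModelThroughputOperator,
)

with DAG(
    dag_id="example_bedrock_provisioned_throughput",  # placeholder DAG id
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    provision_throughput = BedrockCreateProvisionedModelThroughputOperator(
        task_id="provision_throughput",
        model_units=1,
        provisioned_model_name="my-provisioned-model",  # placeholder name
        model_id="arn:aws:bedrock:us-east-1:123456789012:custom-model/placeholder",
        wait_for_completion=True,
    )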