Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-84 List Mapped Task Instances #43642

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def get_mapped_task_instance(
return task_instance_schema.dump(task_instance)


@mark_fastapi_migration_done
@format_parameters(
{
"execution_date_gte": format_datetime,
Expand Down
151 changes: 142 additions & 9 deletions airflow/api_fastapi/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,24 @@

from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Generic, List, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Generic, List, Optional, TypeVar

from fastapi import Depends, HTTPException, Query
from pendulum.parsing.exceptions import ParserError
from pydantic import AfterValidator
from pydantic import AfterValidator, BaseModel
from sqlalchemy import Column, case, or_
from sqlalchemy.inspection import inspect
from typing_extensions import Annotated, Self

from airflow.api_connexion.endpoints.task_instance_endpoint import _convert_ti_states
from airflow.models import Base, Connection
from airflow.models.dag import DagModel, DagTag
from airflow.models.dagrun import DagRun
from airflow.models.dagwarning import DagWarning, DagWarningType
from airflow.models.errors import ParseImportError
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.state import DagRunState
from airflow.utils.state import DagRunState, TaskInstanceState

if TYPE_CHECKING:
from sqlalchemy.sql import ColumnElement, Select
Expand All @@ -45,8 +47,8 @@
class BaseParam(Generic[T], ABC):
"""Base class for filters."""

def __init__(self, skip_none: bool = True) -> None:
self.value: T | None = None
def __init__(self, value: T | None = None, skip_none: bool = True) -> None:
self.value = value
self.attribute: ColumnElement | None = None
self.skip_none = skip_none

Expand Down Expand Up @@ -128,7 +130,7 @@ class _SearchParam(BaseParam[str]):
"""Search on attribute."""

def __init__(self, attribute: ColumnElement, skip_none: bool = True) -> None:
super().__init__(skip_none)
super().__init__(skip_none=skip_none)
self.attribute: ColumnElement = attribute

def to_orm(self, select: Select) -> Select:
Expand Down Expand Up @@ -227,8 +229,8 @@ def get_primary_key_string(self) -> str:
def depends(self, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError("Use dynamic_depends, depends not implemented.")

def dynamic_depends(self) -> Callable:
def inner(order_by: str = self.get_primary_key_string()) -> SortParam:
def dynamic_depends(self, default: str | None = None) -> Callable:
def inner(order_by: str = default or self.get_primary_key_string()) -> SortParam:
return self.set_value(self.get_primary_key_string() if order_by == "" else order_by)

return inner
Expand Down Expand Up @@ -268,6 +270,75 @@ def depends(self, owners: list[str] = Query(default_factory=list)) -> _OwnersFil
return self.set_value(owners)


class _TIStateFilter(BaseParam[List[Optional[TaskInstanceState]]]):
"""Filter on task instance state."""

def to_orm(self, select: Select) -> Select:
if self.skip_none is False:
raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}")

if not self.value:
return select

conditions = [TaskInstance.state == state for state in self.value]
return select.where(or_(*conditions))

def depends(self, state: list[str] = Query(default_factory=list)) -> _TIStateFilter:
states = _convert_ti_states(state)
return self.set_value(states)


class _TIPoolFilter(BaseParam[List[str]]):
"""Filter on task instance pool."""

def to_orm(self, select: Select) -> Select:
if self.skip_none is False:
raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}")

if not self.value:
return select

conditions = [TaskInstance.pool == pool for pool in self.value]
return select.where(or_(*conditions))

def depends(self, pool: list[str] = Query(default_factory=list)) -> _TIPoolFilter:
return self.set_value(pool)


class _TIQueueFilter(BaseParam[List[str]]):
"""Filter on task instance queue."""

def to_orm(self, select: Select) -> Select:
if self.skip_none is False:
raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}")

if not self.value:
return select

conditions = [TaskInstance.queue == queue for queue in self.value]
return select.where(or_(*conditions))

def depends(self, queue: list[str] = Query(default_factory=list)) -> _TIQueueFilter:
return self.set_value(queue)


class _TIExecutorFilter(BaseParam[List[str]]):
"""Filter on task instance executor."""

def to_orm(self, select: Select) -> Select:
if self.skip_none is False:
raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}")

if not self.value:
return select

conditions = [TaskInstance.executor == executor for executor in self.value]
return select.where(or_(*conditions))

def depends(self, executor: list[str] = Query(default_factory=list)) -> _TIExecutorFilter:
return self.set_value(executor)


class _LastDagRunStateFilter(BaseParam[DagRunState]):
"""Filter on the state of the latest DagRun."""

Expand Down Expand Up @@ -323,7 +394,7 @@ class _DagIdFilter(BaseParam[str]):
"""Filter on dag_id."""

def __init__(self, attribute: ColumnElement, skip_none: bool = True) -> None:
super().__init__(skip_none)
super().__init__(skip_none=skip_none)
self.attribute = attribute

def to_orm(self, select: Select) -> Select:
Expand All @@ -335,6 +406,63 @@ def depends(self, dag_id: str | None = None) -> _DagIdFilter:
return self.set_value(dag_id)


class Range(BaseModel, Generic[T]):
"""Range with a lower and upper bound."""

lower_bound: T | None
upper_bound: T | None


class RangeFilter(BaseParam[Range]):
"""Filter on range in between the lower and upper bound."""

def __init__(self, value: Range | None, attribute: ColumnElement) -> None:
super().__init__(value)
self.attribute: ColumnElement = attribute

def to_orm(self, select: Select) -> Select:
if self.skip_none is False:
raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}")

if self.value and self.value.lower_bound:
select = select.where(self.attribute >= self.value.lower_bound)
if self.value and self.value.upper_bound:
select = select.where(self.attribute <= self.value.upper_bound)
return select

def depends(self, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError("Use the `range_filter_factory` function to create the dependency")


def datetime_range_filter_factory(
filter_name: str, model: Base, attribute_name: str | None = None
) -> Callable[[datetime | None, datetime | None], RangeFilter]:
def depends_datetime(
lower_bound: datetime | None = Query(alias=f"{filter_name}_gte", default=None),
upper_bound: datetime | None = Query(alias=f"{filter_name}_lte", default=None),
) -> RangeFilter:
return RangeFilter(
Range(lower_bound=lower_bound, upper_bound=upper_bound),
getattr(model, attribute_name or filter_name),
)

return depends_datetime


def float_range_filter_factory(
filter_name: str, model: Base
) -> Callable[[float | None, float | None], RangeFilter]:
def depends_float(
lower_bound: float | None = Query(alias=f"{filter_name}_gte", default=None),
upper_bound: float | None = Query(alias=f"{filter_name}_lte", default=None),
) -> RangeFilter:
return RangeFilter(
Range(lower_bound=lower_bound, upper_bound=upper_bound), getattr(model, filter_name)
)

return depends_float


# Common Safe DateTime
DateTimeQuery = Annotated[str, AfterValidator(_safe_parse_datetime)]

Expand Down Expand Up @@ -363,3 +491,8 @@ def depends(self, dag_id: str | None = None) -> _DagIdFilter:

# DAGTags
QueryDagTagPatternSearch = Annotated[_DagTagNamePatternSearch, Depends(_DagTagNamePatternSearch().depends)]
# TI
QueryTIStateFilter = Annotated[_TIStateFilter, Depends(_TIStateFilter().depends)]
QueryTIPoolFilter = Annotated[_TIPoolFilter, Depends(_TIPoolFilter().depends)]
QueryTIQueueFilter = Annotated[_TIQueueFilter, Depends(_TIQueueFilter().depends)]
QueryTIExecutorFilter = Annotated[_TIExecutorFilter, Depends(_TIExecutorFilter().depends)]
Loading