Skip to content

Commit

Permalink
refactor: remove FilterParam out of this PR
Browse files Browse the repository at this point in the history
  • Loading branch information
jason810496 committed Nov 1, 2024
1 parent 46a4c25 commit aeed230
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 64 deletions.
51 changes: 1 addition & 50 deletions airflow/api_fastapi/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Generic, List, Literal, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Generic, List, TypeVar

from fastapi import Depends, HTTPException, Query
from pendulum.parsing.exceptions import ParserError
Expand Down Expand Up @@ -219,55 +219,6 @@ def inner(order_by: str = self.get_primary_key_string()) -> SortParam:
return inner


_filter_options = Literal["in", "not_in", "eq", "ne", "lt", "le", "gt", "ge"]


class FilterParam(BaseParam[T]):
"""Filter on attribute."""

def __init__(
self,
attribute: ColumnElement,
value: T | None = None,
filter_option: _filter_options = "eq",
skip_none: bool = True,
) -> None:
super().__init__(skip_none)
self.attribute: ColumnElement = attribute
self.value: T | None = value
self.filter_option: _filter_options = filter_option

def to_orm(self, select: Select) -> Select:
if self.value is None and self.skip_none:
return select

if isinstance(self.value, list):
if self.filter_option == "in":
return select.where(self.attribute.in_(self.value))
if self.filter_option == "not_in":
return select.where(self.attribute.notin_(self.value))
raise ValueError(f"Invalid filter option {self.filter_option} for list value {self.value}")

if self.filter_option == "eq":
return select.where(self.attribute == self.value)
if self.filter_option == "ne":
return select.where(self.attribute != self.value)
if self.filter_option == "lt":
return select.where(self.attribute < self.value)
if self.filter_option == "le":
return select.where(self.attribute <= self.value)
if self.filter_option == "gt":
return select.where(self.attribute > self.value)
if self.filter_option == "ge":
return select.where(self.attribute >= self.value)
raise ValueError(f"Invalid filter option {self.filter_option} for value {self.value}")

def depends(self, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError(
"Construct FilterParam directly within the router handler, depends is not implemented."
)


class _TagsFilter(BaseParam[List[str]]):
"""Filter on tags."""

Expand Down
38 changes: 24 additions & 14 deletions airflow/api_fastapi/core_api/routes/public/event_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
paginated_select,
)
from airflow.api_fastapi.common.parameters import (
FilterParam,
QueryLimit,
QueryOffset,
SortParam,
Expand Down Expand Up @@ -99,21 +98,32 @@ async def get_event_logs(
) -> EventLogCollectionResponse:
"""Get all Event Logs."""
base_select = select(Log).group_by(Log.id)
# TODO: Refactor using the `FilterParam` class in commit `574b72e41cc5ed175a2bbf4356522589b836bb11`
if dag_id is not None:
base_select = base_select.where(Log.dag_id == dag_id)
if task_id is not None:
base_select = base_select.where(Log.task_id == task_id)
if run_id is not None:
base_select = base_select.where(Log.run_id == run_id)
if map_index is not None:
base_select = base_select.where(Log.map_index == map_index)
if try_number is not None:
base_select = base_select.where(Log.try_number == try_number)
if owner is not None:
base_select = base_select.where(Log.owner == owner)
if event is not None:
base_select = base_select.where(Log.event == event)
if excluded_events is not None:
base_select = base_select.where(Log.event.notin_(excluded_events))
if included_events is not None:
base_select = base_select.where(Log.event.in_(included_events))
if before is not None:
base_select = base_select.where(Log.dttm < before)
if after is not None:
base_select = base_select.where(Log.dttm > after)
event_logs_select, total_entries = paginated_select(
base_select,
[
FilterParam(Log.dag_id, dag_id),
FilterParam(Log.task_id, task_id),
FilterParam(Log.run_id, run_id),
FilterParam(Log.map_index, map_index),
FilterParam(Log.event, event),
FilterParam(Log.try_number, try_number),
FilterParam(Log.owner, owner),
FilterParam(Log.event, excluded_events, "not_in"),
FilterParam(Log.event, included_events, "in"),
FilterParam(Log.dttm, before, "lt"),
FilterParam(Log.dttm, after, "gt"),
],
[],
order_by,
offset,
limit,
Expand Down

0 comments on commit aeed230

Please sign in to comment.