[typing] prefect.utilities
This is a complete refactor of this module.

- Move functions into the `sqlalchemy.func` namespace so they don't need to be imported everywhere (see the sketch below the file summary)
- Re-use SQLAlchemy's PostgreSQL JSONB operators by providing SQLite equivalents (sketched after the orm_models.py diff)
- Provide a new function, `date_diff_seconds`, that calculates the difference between timestamps in seconds

This removes the need for many separate PostgreSQL vs SQLite queries.
mjpieters committed Dec 13, 2024
1 parent 3abe9d0 commit 93b85fa
Showing 13 changed files with 752 additions and 728 deletions.
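
[Editor's note] The registration mechanics live in prefect.server.utilities.database and are not shown on this page, so here is a minimal sketch of the underlying SQLAlchemy pattern: a GenericFunction subclass registers itself under sa.func, and @compiles supplies per-dialect SQL. The PostgreSQL and SQLite translations of date_diff_seconds below are plausible assumptions for illustration, not the commit's exact code.

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction


class date_diff_seconds(GenericFunction):
    """Difference between two timestamps, in seconds."""

    type = sa.Numeric()
    inherit_cache = True


@compiles(date_diff_seconds, "postgresql")
def _date_diff_seconds_postgresql(element, compiler, **kw):
    # PostgreSQL: subtract the epoch value of each timestamp
    t1, t2 = (compiler.process(c, **kw) for c in element.clauses)
    return f"(EXTRACT(EPOCH FROM {t1}) - EXTRACT(EPOCH FROM {t2}))"


@compiles(date_diff_seconds, "sqlite")
def _date_diff_seconds_sqlite(element, compiler, **kw):
    # SQLite: julianday() returns fractional days; scale to seconds
    t1, t2 = (compiler.process(c, **kw) for c in element.clauses)
    return f"((julianday({t1}) - julianday({t2})) * 86400.0)"

Because GenericFunction subclasses register by class name, importing the defining module once makes sa.func.date_diff_seconds(...) usable anywhere, which is what lets the call sites below drop their per-module imports.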
6 changes: 4 additions & 2 deletions src/prefect/server/api/run_history.py
@@ -102,12 +102,14 @@ async def run_history(
             # estimated run times only includes positive run times (to avoid any unexpected corner cases)
             "sum_estimated_run_time",
             sa.func.sum(
-                db.greatest(0, sa.extract("epoch", runs.c.estimated_run_time))
+                sa.func.greatest(
+                    0, sa.extract("epoch", runs.c.estimated_run_time)
+                )
             ),
             # estimated lateness is the sum of any positive start time deltas
             "sum_estimated_lateness",
             sa.func.sum(
-                db.greatest(
+                sa.func.greatest(
                     0, sa.extract("epoch", runs.c.estimated_start_time_delta)
                 )
             ),
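
[Editor's note] sa.func.greatest replaces the old db.greatest indirection here. PostgreSQL has a native GREATEST function, SQLite does not; the sketch below, adapted from SQLAlchemy's compiler-extension documentation, shows how a portable two-argument greatest can be registered once. It illustrates the technique and is not necessarily the commit's implementation.

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction


class greatest(GenericFunction):
    type = sa.Numeric()
    inherit_cache = True


@compiles(greatest)
def _greatest_default(element, compiler, **kw):
    # Dialects with a native GREATEST (e.g. PostgreSQL) render it directly
    return compiler.visit_function(element, **kw)


@compiles(greatest, "sqlite")
def _greatest_sqlite(element, compiler, **kw):
    # SQLite has no GREATEST; emit an equivalent CASE expression
    a, b = list(element.clauses)
    return compiler.process(sa.case((a > b, a), else_=b), **kw)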
43 changes: 5 additions & 38 deletions src/prefect/server/api/ui/task_runs.py
@@ -1,5 +1,4 @@
-import sys
-from datetime import datetime, timezone
+from datetime import datetime
 from typing import List, Optional, cast
 
 import pendulum
@@ -37,37 +36,6 @@ def ser_model(self) -> dict:
         }
 
 
-def _postgres_bucket_expression(
-    db: PrefectDBInterface, delta: pendulum.Duration, start_datetime: datetime
-):
-    # asyncpg under Python 3.7 doesn't support timezone-aware datetimes for the EXTRACT
-    # function, so we will send it as a naive datetime in UTC
-    if sys.version_info < (3, 8):
-        start_datetime = start_datetime.astimezone(timezone.utc).replace(tzinfo=None)
-
-    return sa.func.floor(
-        (
-            sa.func.extract("epoch", db.TaskRun.start_time)
-            - sa.func.extract("epoch", start_datetime)
-        )
-        / delta.total_seconds()
-    ).label("bucket")
-
-
-def _sqlite_bucket_expression(
-    db: PrefectDBInterface, delta: pendulum.Duration, start_datetime: datetime
-):
-    return sa.func.floor(
-        (
-            (
-                sa.func.strftime("%s", db.TaskRun.start_time)
-                - sa.func.strftime("%s", start_datetime)
-            )
-            / delta.total_seconds()
-        )
-    ).label("bucket")
-
-
 @router.post("/dashboard/counts")
 async def read_dashboard_task_run_counts(
     task_runs: schemas.filters.TaskRunFilter,
@@ -121,11 +89,10 @@ async def read_dashboard_task_run_counts(
         start_time.microsecond,
         start_time.timezone,
     )
-    bucket_expression = (
-        _sqlite_bucket_expression(db, delta, start_datetime)
-        if db.dialect.name == "sqlite"
-        else _postgres_bucket_expression(db, delta, start_datetime)
-    )
+    bucket_expression = sa.func.floor(
+        sa.func.date_diff_seconds(db.TaskRun.start_time, start_datetime)
+        / delta.total_seconds()
+    ).label("bucket")
 
     raw_counts = (
         (
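
[Editor's note] With date_diff_seconds registered (see the sketch near the top), one bucketing expression serves both dialects. A small usage sketch, with an illustrative stand-in for the real TaskRun model:

import pendulum
import sqlalchemy as sa

metadata = sa.MetaData()
# Illustrative stand-in for the real TaskRun model
task_run = sa.Table("task_run", metadata, sa.Column("start_time", sa.DateTime()))

delta = pendulum.duration(hours=1)
start_datetime = pendulum.datetime(2024, 12, 1, tz="UTC")

# Bucket index = whole number of `delta` intervals since `start_datetime`
bucket = sa.func.floor(
    sa.func.date_diff_seconds(task_run.c.start_time, start_datetime)
    / delta.total_seconds()
).label("bucket")

# The same statement now compiles for PostgreSQL and SQLite alike
stmt = sa.select(bucket, sa.func.count()).group_by(bucket).order_by(bucket)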
17 changes: 7 additions & 10 deletions src/prefect/server/database/configurations.py
@@ -8,16 +8,13 @@
 from typing import Dict, Hashable, Optional, Tuple
 
 import sqlalchemy as sa
-
-try:
-    from sqlalchemy import AdaptedConnection
-    from sqlalchemy.pool import ConnectionPoolEntry
-except ImportError:
-    # SQLAlchemy 1.4 equivalents
-    from sqlalchemy.pool import _ConnectionFairy as AdaptedConnection
-    from sqlalchemy.pool.base import _ConnectionRecord as ConnectionPoolEntry
-
-from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
+from sqlalchemy import AdaptedConnection
+from sqlalchemy.ext.asyncio import (
+    AsyncEngine,
+    AsyncSession,
+    create_async_engine,
+)
+from sqlalchemy.pool import ConnectionPoolEntry
 from typing_extensions import Literal
 
 from prefect.settings import (
3 changes: 0 additions & 3 deletions src/prefect/server/database/interface.py
@@ -359,9 +359,6 @@ def insert(self, model):
         """INSERTs a model into the database"""
         return self.queries.insert(model)
 
-    def greatest(self, *values):
-        return self.queries.greatest(*values)
-
     def make_timestamp_intervals(
         self,
         start_time: datetime.datetime,
57 changes: 29 additions & 28 deletions src/prefect/server/database/orm_models.py
@@ -1,22 +1,21 @@
 import datetime
 import uuid
 from abc import ABC, abstractmethod
+from collections.abc import Hashable, Iterable
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
     Any,
     ClassVar,
     Dict,
-    Hashable,
-    Iterable,
     Optional,
     Union,
-    cast,
 )
 
 import pendulum
 import sqlalchemy as sa
 from sqlalchemy import FetchedValue
+from sqlalchemy.dialects import postgresql
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import (
@@ -46,19 +45,21 @@
     WorkQueueStatus,
 )
 from prefect.server.utilities.database import (
+    CAMEL_TO_SNAKE,
     JSON,
     UUID,
     GenerateUUID,
     Pydantic,
     Timestamp,
-    camel_to_snake,
-    date_diff,
-    interval_add,
-    now,
 )
 from prefect.server.utilities.encryption import decrypt_fernet, encrypt_fernet
 from prefect.utilities.names import generate_slug
 
+# for 'plain JSON' columns, use the postgresql variant (which comes with
+# an extra operator) and fall back to the generic JSON variant for
+# SQLite
+sa_JSON = postgresql.JSON().with_variant(sa.JSON(), "sqlite")
+
 
 class Base(DeclarativeBase):
     """
@@ -117,7 +118,7 @@ def __tablename__(cls) -> str:
         into a snake-case table name. Override by providing
         an explicit `__tablename__` class property.
         """
-        return camel_to_snake.sub("_", cls.__name__).lower()
+        return CAMEL_TO_SNAKE.sub("_", cls.__name__).lower()
 
     id: Mapped[uuid.UUID] = mapped_column(
         primary_key=True,
@@ -126,17 +127,17 @@ def __tablename__(cls) -> str:
     )
 
     created: Mapped[pendulum.DateTime] = mapped_column(
-        server_default=now(), default=lambda: pendulum.now("UTC")
+        server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
    )
 
     # onupdate is only called when statements are actually issued
     # against the database. until COMMIT is issued, this column
     # will not be updated
     updated: Mapped[pendulum.DateTime] = mapped_column(
         index=True,
-        server_default=now(),
+        server_default=sa.func.now(),
         default=lambda: pendulum.now("UTC"),
-        onupdate=now(),
+        onupdate=sa.func.now(),
         server_onupdate=FetchedValue(),
     )
 
@@ -175,7 +176,7 @@ class FlowRunState(Base):
         sa.Enum(schemas.states.StateType, name="state_type"), index=True
     )
     timestamp: Mapped[pendulum.DateTime] = mapped_column(
-        server_default=now(), default=lambda: pendulum.now("UTC")
+        server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
     )
     name: Mapped[str] = mapped_column(index=True)
     message: Mapped[Optional[str]]
Expand Down Expand Up @@ -240,7 +241,7 @@ class TaskRunState(Base):
sa.Enum(schemas.states.StateType, name="state_type"), index=True
)
timestamp: Mapped[pendulum.DateTime] = mapped_column(
server_default=now(), default=lambda: pendulum.now("UTC")
server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
)
name: Mapped[str] = mapped_column(index=True)
message: Mapped[Optional[str]]
@@ -303,11 +304,11 @@ class Artifact(Base):
     flow_run_id: Mapped[Optional[uuid.UUID]] = mapped_column(index=True)
 
     type: Mapped[Optional[str]]
-    data: Mapped[Optional[Any]] = mapped_column(sa.JSON)
+    data: Mapped[Optional[Any]] = mapped_column(sa_JSON)
     description: Mapped[Optional[str]]
 
     # Suffixed with underscore as attribute name 'metadata' is reserved for the MetaData instance when using a declarative base class.
-    metadata_: Mapped[Optional[dict[str, str]]] = mapped_column(sa.JSON)
+    metadata_: Mapped[Optional[dict[str, str]]] = mapped_column(sa_JSON)
 
     @declared_attr.directive
     @classmethod
@@ -342,9 +343,9 @@ class ArtifactCollection(Base):
     flow_run_id: Mapped[Optional[uuid.UUID]]
 
     type: Mapped[Optional[str]]
-    data: Mapped[Optional[Any]] = mapped_column(sa.JSON)
+    data: Mapped[Optional[Any]] = mapped_column(sa_JSON)
     description: Mapped[Optional[str]]
-    metadata_: Mapped[Optional[dict[str, str]]] = mapped_column(sa.JSON)
+    metadata_: Mapped[Optional[dict[str, str]]] = mapped_column(sa_JSON)
 
     __table_args__: Any = (
         sa.UniqueConstraint("key"),
@@ -419,9 +420,9 @@ def _estimated_run_time_expression(cls) -> sa.Label[datetime.timedelta]:
             sa.case(
                 (
                     cls.state_type == schemas.states.StateType.RUNNING,
-                    interval_add(
+                    sa.func.interval_add(
                         cls.total_run_time,
-                        date_diff(now(), cls.state_timestamp),
+                        sa.func.date_diff(sa.func.now(), cls.state_timestamp),
                     ),
                 ),
                 else_=cls.total_run_time,
@@ -464,15 +465,15 @@ def _estimated_start_time_delta_expression(
         return sa.case(
             (
                 cls.start_time > cls.expected_start_time,
-                date_diff(cls.start_time, cls.expected_start_time),
+                sa.func.date_diff(cls.start_time, cls.expected_start_time),
             ),
             (
                 sa.and_(
                     cls.start_time.is_(None),
                     cls.state_type.not_in(schemas.states.TERMINAL_STATES),
-                    cls.expected_start_time < now(),
+                    cls.expected_start_time < sa.func.now(),
                 ),
-                date_diff(now(), cls.expected_start_time),
+                sa.func.date_diff(sa.func.now(), cls.expected_start_time),
             ),
             else_=datetime.timedelta(0),
         )
@@ -1165,7 +1166,7 @@ class Worker(Base):
 
     name: Mapped[str]
     last_heartbeat_time: Mapped[pendulum.DateTime] = mapped_column(
-        server_default=now(), default=lambda: pendulum.now("UTC")
+        server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
     )
     heartbeat_interval_seconds: Mapped[Optional[int]]
 
@@ -1195,7 +1196,7 @@ class Agent(Base):
     )
 
     last_activity_time: Mapped[pendulum.DateTime] = mapped_column(
-        server_default=now(), default=lambda: pendulum.now("UTC")
+        server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
    )
 
     __table_args__: Any = (sa.UniqueConstraint("name"),)
@@ -1277,11 +1278,11 @@ class Automation(Base):
     @classmethod
     def sort_expression(cls, value: AutomationSort) -> sa.ColumnExpressionArgument[Any]:
         """Return an expression used to sort Automations"""
-        sort_mapping = {
+        sort_mapping: dict[AutomationSort, sa.ColumnExpressionArgument[Any]] = {
             AutomationSort.CREATED_DESC: cls.created.desc(),
             AutomationSort.UPDATED_DESC: cls.updated.desc(),
-            AutomationSort.NAME_ASC: cast(sa.Column, cls.name).asc(),
-            AutomationSort.NAME_DESC: cast(sa.Column, cls.name).desc(),
+            AutomationSort.NAME_ASC: cls.name.asc(),
+            AutomationSort.NAME_DESC: cls.name.desc(),
         }
         return sort_mapping[value]
 
@@ -1439,7 +1440,7 @@ def __tablename__(cls) -> str:
     occurred: Mapped[pendulum.DateTime]
     resource_id: Mapped[str] = mapped_column(sa.Text())
     resource_role: Mapped[str] = mapped_column(sa.Text())
-    resource: Mapped[dict[str, Any]] = mapped_column(sa.JSON())
+    resource: Mapped[dict[str, Any]] = mapped_column(sa_JSON)
     event_id: Mapped[uuid.UUID]
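
[Editor's note] Two dialect-bridging pieces appear in this file: the sa_JSON variant type above, and, per the commit message, SQLite equivalents for PostgreSQL's JSONB operators. As one hedged illustration of the latter, key existence (PostgreSQL's ? operator) can be expressed for SQLite with the JSON1 json_each() table-valued function. The json_has_key helper named here is hypothetical, not the commit's API, and as written it only checks top-level keys.

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction


class json_has_key(GenericFunction):  # hypothetical helper, for illustration
    type = sa.Boolean()
    inherit_cache = True


@compiles(json_has_key, "postgresql")
def _json_has_key_postgresql(element, compiler, **kw):
    col, key = (compiler.process(c, **kw) for c in element.clauses)
    return f"({col} ? {key})"  # native JSONB key-existence operator


@compiles(json_has_key, "sqlite")
def _json_has_key_sqlite(element, compiler, **kw):
    col, key = (compiler.process(c, **kw) for c in element.clauses)
    # json_each() enumerates the top-level keys of a JSON object
    return f"EXISTS (SELECT 1 FROM json_each({col}) WHERE json_each.key = {key})"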