From c2c0fd65909bf8cb3a8f6d50c1c9c8aa0773cc3b Mon Sep 17 00:00:00 2001 From: nate nowack Date: Thu, 12 Dec 2024 09:15:10 -0600 Subject: [PATCH 1/8] Consolidate use of `DateTime` to `prefect.types` (#16356) --- docs/v3/tutorials/pipelines.mdx | 2 +- flows/worker.py | 3 +-- src/prefect/_internal/schemas/bases.py | 3 ++- src/prefect/_internal/schemas/validators.py | 2 +- src/prefect/blocks/system.py | 2 +- src/prefect/client/schemas/actions.py | 2 +- src/prefect/client/schemas/filters.py | 2 +- src/prefect/client/schemas/objects.py | 2 +- src/prefect/client/schemas/responses.py | 3 +-- src/prefect/client/schemas/schedules.py | 2 +- src/prefect/context.py | 2 +- src/prefect/events/filters.py | 2 +- src/prefect/events/related.py | 3 ++- src/prefect/events/schemas/events.py | 2 +- src/prefect/results.py | 2 +- src/prefect/server/api/deployments.py | 2 +- src/prefect/server/api/flow_runs.py | 2 +- src/prefect/server/api/run_history.py | 2 +- src/prefect/server/api/task_runs.py | 2 +- src/prefect/server/api/ui/flow_runs.py | 2 +- src/prefect/server/api/ui/flows.py | 2 +- src/prefect/server/api/ui/task_runs.py | 2 +- src/prefect/server/api/work_queues.py | 2 +- src/prefect/server/api/workers.py | 2 +- src/prefect/server/events/counting.py | 2 +- src/prefect/server/events/filters.py | 2 +- src/prefect/server/events/schemas/automations.py | 2 +- src/prefect/server/events/schemas/events.py | 2 +- src/prefect/server/events/triggers.py | 2 +- src/prefect/server/models/task_workers.py | 3 ++- src/prefect/server/schemas/actions.py | 2 +- src/prefect/server/schemas/core.py | 2 +- src/prefect/server/schemas/filters.py | 2 +- src/prefect/server/schemas/responses.py | 3 +-- src/prefect/server/schemas/schedules.py | 3 +-- src/prefect/server/schemas/states.py | 2 +- src/prefect/server/utilities/schemas/bases.py | 3 ++- src/prefect/types/__init__.py | 5 ++++- tests/blocks/test_system.py | 2 +- tests/cli/deployment/test_deployment_run.py | 2 +- tests/cli/test_flow_run.py | 2 +- 
tests/client/schemas/test_schedules.py | 2 +- tests/client/test_prefect_client.py | 2 +- tests/events/client/test_events_emit_event.py | 2 +- tests/events/client/test_events_schema.py | 2 +- tests/events/server/actions/test_actions_service.py | 2 +- tests/events/server/actions/test_calling_webhook.py | 2 +- tests/events/server/actions/test_jinja_templated_action.py | 2 +- .../server/actions/test_pausing_resuming_automation.py | 2 +- .../server/actions/test_pausing_resuming_deployment.py | 2 +- .../events/server/actions/test_pausing_resuming_work_pool.py | 2 +- .../server/actions/test_pausing_resuming_work_queue.py | 2 +- tests/events/server/conftest.py | 2 +- .../server/models/test_composite_trigger_child_firing.py | 2 +- tests/events/server/storage/test_event_persister.py | 2 +- tests/events/server/test_automations_api.py | 2 +- tests/events/server/test_clients.py | 2 +- tests/events/server/test_events_api.py | 2 +- tests/events/server/test_events_counts.py | 2 +- tests/events/server/test_events_schema.py | 2 +- tests/events/server/triggers/test_basics.py | 2 +- tests/events/server/triggers/test_composite_triggers.py | 2 +- tests/events/server/triggers/test_flow_run_slas.py | 2 +- tests/events/server/triggers/test_service.py | 2 +- tests/server/orchestration/api/ui/test_flows.py | 2 +- tests/server/orchestration/api/ui/test_task_runs.py | 2 +- tests/test_context.py | 2 +- tests/workers/test_process_worker.py | 2 +- 68 files changed, 75 insertions(+), 72 deletions(-) diff --git a/docs/v3/tutorials/pipelines.mdx b/docs/v3/tutorials/pipelines.mdx index 65ef4a58eb24..15be9115e354 100644 --- a/docs/v3/tutorials/pipelines.mdx +++ b/docs/v3/tutorials/pipelines.mdx @@ -418,4 +418,4 @@ You'll use error handling, pagination, and nested flows to scrape data from GitH Need help? 
[Book a meeting](https://calendly.com/prefect-experts/prefect-product-advocates?utm_campaign=prefect_docs_cloud&utm_content=prefect_docs&utm_medium=docs&utm_source=docs) with a Prefect Product Advocate to get your questions answered. - + \ No newline at end of file diff --git a/flows/worker.py b/flows/worker.py index 8a73f44029dd..9c8e109a7a0d 100644 --- a/flows/worker.py +++ b/flows/worker.py @@ -4,11 +4,10 @@ from threading import Thread from typing import List -from pydantic_extra_types.pendulum_dt import DateTime - from prefect.events import Event from prefect.events.clients import get_events_subscriber from prefect.events.filters import EventFilter, EventNameFilter, EventOccurredFilter +from prefect.types import DateTime async def watch_worker_events(events: List[Event]): diff --git a/src/prefect/_internal/schemas/bases.py b/src/prefect/_internal/schemas/bases.py index f85907f51af7..62804f8b478a 100644 --- a/src/prefect/_internal/schemas/bases.py +++ b/src/prefect/_internal/schemas/bases.py @@ -13,9 +13,10 @@ ConfigDict, Field, ) -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import Self +from prefect.types import DateTime + T = TypeVar("T") diff --git a/src/prefect/_internal/schemas/validators.py b/src/prefect/_internal/schemas/validators.py index 0b3269237214..52380cce951c 100644 --- a/src/prefect/_internal/schemas/validators.py +++ b/src/prefect/_internal/schemas/validators.py @@ -18,9 +18,9 @@ import jsonschema import pendulum import yaml -from pydantic_extra_types.pendulum_dt import DateTime from prefect.exceptions import InvalidRepositoryURLError +from prefect.types import DateTime from prefect.utilities.collections import isiterable from prefect.utilities.dockerutils import get_prefect_image_name from prefect.utilities.filesystem import relative_path_to_current_platform diff --git a/src/prefect/blocks/system.py b/src/prefect/blocks/system.py index 135e7f21e0f1..1bdbc4c80d86 100644 --- a/src/prefect/blocks/system.py +++ 
b/src/prefect/blocks/system.py @@ -9,10 +9,10 @@ field_validator, ) from pydantic import Secret as PydanticSecret -from pydantic_extra_types.pendulum_dt import DateTime as PydanticDateTime from prefect._internal.compatibility.deprecated import deprecated_class from prefect.blocks.core import Block +from prefect.types import DateTime as PydanticDateTime _SecretValueType = Union[ Annotated[StrictStr, Field(title="string")], diff --git a/src/prefect/client/schemas/actions.py b/src/prefect/client/schemas/actions.py index 659c5153d46a..28be60e39cc5 100644 --- a/src/prefect/client/schemas/actions.py +++ b/src/prefect/client/schemas/actions.py @@ -4,7 +4,6 @@ import jsonschema from pydantic import Field, field_validator, model_validator -from pydantic_extra_types.pendulum_dt import DateTime import prefect.client.schemas.objects as objects from prefect._internal.schemas.bases import ActionBaseModel @@ -27,6 +26,7 @@ from prefect.settings import PREFECT_DEPLOYMENT_SCHEDULE_MAX_SCHEDULED_RUNS from prefect.types import ( MAX_VARIABLE_NAME_LENGTH, + DateTime, KeyValueLabelsField, Name, NonEmptyishName, diff --git a/src/prefect/client/schemas/filters.py b/src/prefect/client/schemas/filters.py index 5a4726e75367..52bb7258e700 100644 --- a/src/prefect/client/schemas/filters.py +++ b/src/prefect/client/schemas/filters.py @@ -6,10 +6,10 @@ from uuid import UUID from pydantic import Field -from pydantic_extra_types.pendulum_dt import DateTime from prefect._internal.schemas.bases import PrefectBaseModel from prefect.client.schemas.objects import StateType +from prefect.types import DateTime from prefect.utilities.collections import AutoEnum diff --git a/src/prefect/client/schemas/objects.py b/src/prefect/client/schemas/objects.py index ef5eab667ad1..ece3042c02ce 100644 --- a/src/prefect/client/schemas/objects.py +++ b/src/prefect/client/schemas/objects.py @@ -69,7 +69,7 @@ DateTime = pendulum.DateTime else: - from pydantic_extra_types.pendulum_dt import DateTime + from prefect.types 
import DateTime R = TypeVar("R", default=Any) diff --git a/src/prefect/client/schemas/responses.py b/src/prefect/client/schemas/responses.py index cb27a6f55392..ce76537001f0 100644 --- a/src/prefect/client/schemas/responses.py +++ b/src/prefect/client/schemas/responses.py @@ -3,13 +3,12 @@ from uuid import UUID from pydantic import ConfigDict, Field -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import Literal import prefect.client.schemas.objects as objects from prefect._internal.schemas.bases import ObjectBaseModel, PrefectBaseModel from prefect._internal.schemas.fields import CreatedBy, UpdatedBy -from prefect.types import KeyValueLabelsField +from prefect.types import DateTime, KeyValueLabelsField from prefect.utilities.collections import AutoEnum from prefect.utilities.names import generate_slug diff --git a/src/prefect/client/schemas/schedules.py b/src/prefect/client/schemas/schedules.py index 4b9cf1b3cf5b..2437e194ad83 100644 --- a/src/prefect/client/schemas/schedules.py +++ b/src/prefect/client/schemas/schedules.py @@ -26,7 +26,7 @@ # together. DateTime = pendulum.DateTime else: - from pydantic_extra_types.pendulum_dt import DateTime + from prefect.types import DateTime MAX_ITERATIONS = 1000 # approx. 
1 years worth of RDATEs + buffer diff --git a/src/prefect/context.py b/src/prefect/context.py index b82d214d4aff..287b9b58e138 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -26,7 +26,6 @@ ) from pydantic import BaseModel, ConfigDict, Field, PrivateAttr -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import Self import prefect.logging @@ -48,6 +47,7 @@ ) from prefect.states import State from prefect.task_runners import TaskRunner +from prefect.types import DateTime from prefect.utilities.services import start_client_metrics_server T = TypeVar("T") diff --git a/src/prefect/events/filters.py b/src/prefect/events/filters.py index 9143c43a8689..f969e9ccb651 100644 --- a/src/prefect/events/filters.py +++ b/src/prefect/events/filters.py @@ -3,9 +3,9 @@ import pendulum from pydantic import Field, PrivateAttr -from pydantic_extra_types.pendulum_dt import DateTime from prefect._internal.schemas.bases import PrefectBaseModel +from prefect.types import DateTime from prefect.utilities.collections import AutoEnum from .schemas.events import Event, Resource, ResourceSpecification diff --git a/src/prefect/events/related.py b/src/prefect/events/related.py index c2218e92903b..ee36db860352 100644 --- a/src/prefect/events/related.py +++ b/src/prefect/events/related.py @@ -15,7 +15,8 @@ from uuid import UUID import pendulum -from pendulum.datetime import DateTime + +from prefect.types import DateTime from .schemas.events import RelatedResource diff --git a/src/prefect/events/schemas/events.py b/src/prefect/events/schemas/events.py index f143c959be5b..7e7ddc6b9c5a 100644 --- a/src/prefect/events/schemas/events.py +++ b/src/prefect/events/schemas/events.py @@ -20,7 +20,6 @@ RootModel, model_validator, ) -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import Annotated, Self from prefect._internal.schemas.bases import PrefectBaseModel @@ -28,6 +27,7 @@ from prefect.settings import ( 
PREFECT_EVENTS_MAXIMUM_LABELS_PER_RESOURCE, ) +from prefect.types import DateTime from .labelling import Labelled diff --git a/src/prefect/results.py b/src/prefect/results.py index 96569ea1f5b3..19206665c8bf 100644 --- a/src/prefect/results.py +++ b/src/prefect/results.py @@ -35,7 +35,6 @@ model_validator, ) from pydantic_core import PydanticUndefinedType -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import ParamSpec, Self import prefect @@ -57,6 +56,7 @@ from prefect.logging import get_logger from prefect.serializers import PickleSerializer, Serializer from prefect.settings.context import get_current_settings +from prefect.types import DateTime from prefect.utilities.annotations import NotSet from prefect.utilities.asyncutils import sync_compatible from prefect.utilities.pydantic import get_dispatch_key, lookup_type, register_base_type diff --git a/src/prefect/server/api/deployments.py b/src/prefect/server/api/deployments.py index a690b602322b..347b8f33a60a 100644 --- a/src/prefect/server/api/deployments.py +++ b/src/prefect/server/api/deployments.py @@ -9,7 +9,6 @@ import jsonschema.exceptions import pendulum from fastapi import Body, Depends, HTTPException, Path, Response, status -from pydantic_extra_types.pendulum_dt import DateTime from starlette.background import BackgroundTasks import prefect.server.api.dependencies as dependencies @@ -27,6 +26,7 @@ from prefect.server.models.workers import DEFAULT_AGENT_WORK_POOL_NAME from prefect.server.schemas.responses import DeploymentPaginationResponse from prefect.server.utilities.server import PrefectRouter +from prefect.types import DateTime from prefect.utilities.schema_tools.hydration import ( HydrationContext, HydrationError, diff --git a/src/prefect/server/api/flow_runs.py b/src/prefect/server/api/flow_runs.py index 864469640a14..bace90d0f8e4 100644 --- a/src/prefect/server/api/flow_runs.py +++ b/src/prefect/server/api/flow_runs.py @@ -21,7 +21,6 @@ status, ) from 
fastapi.responses import ORJSONResponse, PlainTextResponse, StreamingResponse -from pydantic_extra_types.pendulum_dt import DateTime from sqlalchemy.exc import IntegrityError import prefect.server.api.dependencies as dependencies @@ -45,6 +44,7 @@ OrchestrationResult, ) from prefect.server.utilities.server import PrefectRouter +from prefect.types import DateTime from prefect.utilities import schema_tools logger = get_logger("server.api") diff --git a/src/prefect/server/api/run_history.py b/src/prefect/server/api/run_history.py index fe7a1b8cfd97..f70976f4a1e9 100644 --- a/src/prefect/server/api/run_history.py +++ b/src/prefect/server/api/run_history.py @@ -8,7 +8,6 @@ import pydantic import sqlalchemy as sa -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import Literal import prefect.server.models as models @@ -16,6 +15,7 @@ from prefect.logging import get_logger from prefect.server.database.dependencies import db_injector from prefect.server.database.interface import PrefectDBInterface +from prefect.types import DateTime logger = get_logger("server.api") diff --git a/src/prefect/server/api/task_runs.py b/src/prefect/server/api/task_runs.py index b75b2abc7ad4..8912b3f5fdf0 100644 --- a/src/prefect/server/api/task_runs.py +++ b/src/prefect/server/api/task_runs.py @@ -17,7 +17,6 @@ WebSocket, status, ) -from pydantic_extra_types.pendulum_dt import DateTime from starlette.websockets import WebSocketDisconnect import prefect.server.api.dependencies as dependencies @@ -34,6 +33,7 @@ from prefect.server.task_queue import MultiQueue, TaskQueue from prefect.server.utilities import subscriptions from prefect.server.utilities.server import PrefectRouter +from prefect.types import DateTime logger = get_logger("server.api") diff --git a/src/prefect/server/api/ui/flow_runs.py b/src/prefect/server/api/ui/flow_runs.py index ee254c4b00b9..c76db8719dca 100644 --- a/src/prefect/server/api/ui/flow_runs.py +++ b/src/prefect/server/api/ui/flow_runs.py @@ 
-5,7 +5,6 @@ import sqlalchemy as sa from fastapi import Body, Depends from pydantic import Field -from pydantic_extra_types.pendulum_dt import DateTime import prefect.server.schemas as schemas from prefect._internal.schemas.bases import PrefectBaseModel @@ -15,6 +14,7 @@ from prefect.server.database.dependencies import provide_database_interface from prefect.server.database.interface import PrefectDBInterface from prefect.server.utilities.server import PrefectRouter +from prefect.types import DateTime logger = get_logger("server.api.ui.flow_runs") diff --git a/src/prefect/server/api/ui/flows.py b/src/prefect/server/api/ui/flows.py index 57261e459614..8128abd45ccf 100644 --- a/src/prefect/server/api/ui/flows.py +++ b/src/prefect/server/api/ui/flows.py @@ -6,7 +6,6 @@ import sqlalchemy as sa from fastapi import Body, Depends from pydantic import Field, field_validator -from pydantic_extra_types.pendulum_dt import DateTime from prefect.logging import get_logger from prefect.server.database import orm_models @@ -16,6 +15,7 @@ from prefect.server.utilities.database import UUID as UUIDTypeDecorator from prefect.server.utilities.schemas import PrefectBaseModel from prefect.server.utilities.server import PrefectRouter +from prefect.types import DateTime logger = get_logger() diff --git a/src/prefect/server/api/ui/task_runs.py b/src/prefect/server/api/ui/task_runs.py index 11d97a72e07d..b8f4bb778240 100644 --- a/src/prefect/server/api/ui/task_runs.py +++ b/src/prefect/server/api/ui/task_runs.py @@ -6,7 +6,6 @@ import sqlalchemy as sa from fastapi import Depends, HTTPException, status from pydantic import Field, model_serializer -from pydantic_extra_types.pendulum_dt import DateTime import prefect.server.schemas as schemas from prefect._internal.schemas.bases import PrefectBaseModel @@ -15,6 +14,7 @@ from prefect.server.database.dependencies import provide_database_interface from prefect.server.database.interface import PrefectDBInterface from 
prefect.server.utilities.server import PrefectRouter +from prefect.types import DateTime logger = get_logger("orion.api.ui.task_runs") diff --git a/src/prefect/server/api/work_queues.py b/src/prefect/server/api/work_queues.py index 7733691d521f..101b717c6081 100644 --- a/src/prefect/server/api/work_queues.py +++ b/src/prefect/server/api/work_queues.py @@ -15,7 +15,6 @@ Path, status, ) -from pydantic_extra_types.pendulum_dt import DateTime import prefect.server.api.dependencies as dependencies import prefect.server.models as models @@ -29,6 +28,7 @@ ) from prefect.server.schemas.statuses import WorkQueueStatus from prefect.server.utilities.server import PrefectRouter +from prefect.types import DateTime router = PrefectRouter(prefix="/work_queues", tags=["Work Queues"]) diff --git a/src/prefect/server/api/workers.py b/src/prefect/server/api/workers.py index cc85aa6f4df3..cc40dea8a9d8 100644 --- a/src/prefect/server/api/workers.py +++ b/src/prefect/server/api/workers.py @@ -15,7 +15,6 @@ Path, status, ) -from pydantic_extra_types.pendulum_dt import DateTime from sqlalchemy.ext.asyncio import AsyncSession import prefect.server.api.dependencies as dependencies @@ -32,6 +31,7 @@ from prefect.server.models.workers import emit_work_pool_status_event from prefect.server.schemas.statuses import WorkQueueStatus from prefect.server.utilities.server import PrefectRouter +from prefect.types import DateTime if TYPE_CHECKING: from prefect.server.database.orm_models import ORMWorkQueue diff --git a/src/prefect/server/events/counting.py b/src/prefect/server/events/counting.py index aa8d833c4bac..72d6051c0d19 100644 --- a/src/prefect/server/events/counting.py +++ b/src/prefect/server/events/counting.py @@ -4,12 +4,12 @@ import pendulum import sqlalchemy as sa -from pendulum.datetime import DateTime from sqlalchemy.sql.selectable import Select from prefect.server.database.dependencies import provide_database_interface from prefect.server.database.interface import PrefectDBInterface 
from prefect.server.utilities.database import json_extract +from prefect.types import DateTime from prefect.utilities.collections import AutoEnum if TYPE_CHECKING: diff --git a/src/prefect/server/events/filters.py b/src/prefect/server/events/filters.py index 2736b6290719..d4f2453b09e4 100644 --- a/src/prefect/server/events/filters.py +++ b/src/prefect/server/events/filters.py @@ -7,7 +7,6 @@ import pendulum import sqlalchemy as sa from pydantic import Field, PrivateAttr -from pydantic_extra_types.pendulum_dt import DateTime from sqlalchemy.sql import Select from prefect._internal.schemas.bases import PrefectBaseModel @@ -17,6 +16,7 @@ PrefectOperatorFilterBaseModel, ) from prefect.server.utilities.database import json_extract +from prefect.types import DateTime from prefect.utilities.collections import AutoEnum from .schemas.events import Event, Resource, ResourceSpecification diff --git a/src/prefect/server/events/schemas/automations.py b/src/prefect/server/events/schemas/automations.py index 548db796142a..8f54426990ab 100644 --- a/src/prefect/server/events/schemas/automations.py +++ b/src/prefect/server/events/schemas/automations.py @@ -25,7 +25,6 @@ field_validator, model_validator, ) -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import Self, TypeAlias from prefect.logging import get_logger @@ -39,6 +38,7 @@ ) from prefect.server.schemas.actions import ActionBaseModel from prefect.server.utilities.schemas import ORMBaseModel, PrefectBaseModel +from prefect.types import DateTime from prefect.utilities.collections import AutoEnum logger = get_logger(__name__) diff --git a/src/prefect/server/events/schemas/events.py b/src/prefect/server/events/schemas/events.py index e01e79ec68dd..d72073b6e42f 100644 --- a/src/prefect/server/events/schemas/events.py +++ b/src/prefect/server/events/schemas/events.py @@ -23,7 +23,6 @@ field_validator, model_validator, ) -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions 
import Annotated, Self from prefect.logging import get_logger @@ -33,6 +32,7 @@ PREFECT_EVENTS_MAXIMUM_LABELS_PER_RESOURCE, PREFECT_EVENTS_MAXIMUM_RELATED_RESOURCES, ) +from prefect.types import DateTime logger = get_logger(__name__) diff --git a/src/prefect/server/events/triggers.py b/src/prefect/server/events/triggers.py index 2c722901bab8..9a839bc5e66a 100644 --- a/src/prefect/server/events/triggers.py +++ b/src/prefect/server/events/triggers.py @@ -19,7 +19,6 @@ import pendulum import sqlalchemy as sa -from pendulum.datetime import DateTime from sqlalchemy.ext.asyncio import AsyncSession from typing_extensions import Literal, TypeAlias @@ -57,6 +56,7 @@ from prefect.server.events.schemas.events import ReceivedEvent from prefect.server.utilities.messaging import Message, MessageHandler from prefect.settings import PREFECT_EVENTS_EXPIRED_BUCKET_BUFFER +from prefect.types import DateTime if TYPE_CHECKING: from prefect.server.database.orm_models import ORMAutomationBucket diff --git a/src/prefect/server/models/task_workers.py b/src/prefect/server/models/task_workers.py index c1213ae41381..899eba98b4d9 100644 --- a/src/prefect/server/models/task_workers.py +++ b/src/prefect/server/models/task_workers.py @@ -3,9 +3,10 @@ from typing import Dict, List, Set from pydantic import BaseModel -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import TypeAlias +from prefect.types import DateTime + TaskKey: TypeAlias = str WorkerId: TypeAlias = str diff --git a/src/prefect/server/schemas/actions.py b/src/prefect/server/schemas/actions.py index e95efe7c741d..3b61e4f821a5 100644 --- a/src/prefect/server/schemas/actions.py +++ b/src/prefect/server/schemas/actions.py @@ -9,7 +9,6 @@ import pendulum from pydantic import ConfigDict, Field, field_validator, model_validator -from pydantic_extra_types.pendulum_dt import DateTime import prefect.server.schemas as schemas from prefect._internal.schemas.validators import ( @@ -31,6 +30,7 @@ from prefect.settings 
import PREFECT_DEPLOYMENT_SCHEDULE_MAX_SCHEDULED_RUNS from prefect.types import ( MAX_VARIABLE_NAME_LENGTH, + DateTime, Name, NonEmptyishName, NonNegativeFloat, diff --git a/src/prefect/server/schemas/core.py b/src/prefect/server/schemas/core.py index 496941340a3d..50cd8d0c40c6 100644 --- a/src/prefect/server/schemas/core.py +++ b/src/prefect/server/schemas/core.py @@ -17,7 +17,6 @@ field_validator, model_validator, ) -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import Literal, Self from prefect._internal.schemas.validators import ( @@ -43,6 +42,7 @@ from prefect.settings import PREFECT_DEPLOYMENT_SCHEDULE_MAX_SCHEDULED_RUNS from prefect.types import ( MAX_VARIABLE_NAME_LENGTH, + DateTime, LaxUrl, Name, NameOrEmpty, diff --git a/src/prefect/server/schemas/filters.py b/src/prefect/server/schemas/filters.py index 6c6d2892465d..e5e1112c40a9 100644 --- a/src/prefect/server/schemas/filters.py +++ b/src/prefect/server/schemas/filters.py @@ -8,11 +8,11 @@ from uuid import UUID from pydantic import ConfigDict, Field -from pydantic_extra_types.pendulum_dt import DateTime import prefect.server.schemas as schemas from prefect.server.database import orm_models from prefect.server.utilities.schemas.bases import PrefectBaseModel +from prefect.types import DateTime from prefect.utilities.collections import AutoEnum from prefect.utilities.importtools import lazy_import diff --git a/src/prefect/server/schemas/responses.py b/src/prefect/server/schemas/responses.py index 207631d9473e..1249d7c37573 100644 --- a/src/prefect/server/schemas/responses.py +++ b/src/prefect/server/schemas/responses.py @@ -8,7 +8,6 @@ import pendulum from pydantic import BaseModel, ConfigDict, Field, model_validator -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import Literal, Self import prefect.server.schemas as schemas @@ -19,7 +18,7 @@ WorkQueueStatusDetail, ) from prefect.server.utilities.schemas.bases import ORMBaseModel, PrefectBaseModel 
-from prefect.types import KeyValueLabelsField +from prefect.types import DateTime, KeyValueLabelsField from prefect.utilities.collections import AutoEnum from prefect.utilities.names import generate_slug diff --git a/src/prefect/server/schemas/schedules.py b/src/prefect/server/schemas/schedules.py index cfce1a320a33..ff35b0f219fe 100644 --- a/src/prefect/server/schemas/schedules.py +++ b/src/prefect/server/schemas/schedules.py @@ -11,7 +11,6 @@ import pytz from croniter import croniter from pydantic import AfterValidator, ConfigDict, Field, field_validator, model_validator -from pydantic_extra_types.pendulum_dt import DateTime from prefect._internal.schemas.validators import ( default_anchor_date, @@ -20,7 +19,7 @@ validate_rrule_string, ) from prefect.server.utilities.schemas.bases import PrefectBaseModel -from prefect.types import TimeZone +from prefect.types import DateTime, TimeZone MAX_ITERATIONS = 1000 diff --git a/src/prefect/server/schemas/states.py b/src/prefect/server/schemas/states.py index 69f6b058e3a1..da41805b1956 100644 --- a/src/prefect/server/schemas/states.py +++ b/src/prefect/server/schemas/states.py @@ -9,13 +9,13 @@ import pendulum from pydantic import ConfigDict, Field, field_validator, model_validator -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import Self from prefect.server.utilities.schemas.bases import ( IDBaseModel, PrefectBaseModel, ) +from prefect.types import DateTime from prefect.utilities.collections import AutoEnum if TYPE_CHECKING: diff --git a/src/prefect/server/utilities/schemas/bases.py b/src/prefect/server/utilities/schemas/bases.py index da871d38c862..50568a4e100f 100644 --- a/src/prefect/server/utilities/schemas/bases.py +++ b/src/prefect/server/utilities/schemas/bases.py @@ -18,9 +18,10 @@ ConfigDict, Field, ) -from pydantic_extra_types.pendulum_dt import DateTime from typing_extensions import Self +from prefect.types import DateTime + if TYPE_CHECKING: from pydantic.main import IncEx diff 
--git a/src/prefect/types/__init__.py b/src/prefect/types/__init__.py index 934af32441f8..f36622f5a3df 100644 --- a/src/prefect/types/__init__.py +++ b/src/prefect/types/__init__.py @@ -3,7 +3,8 @@ from typing_extensions import Literal, TypeAlias import orjson import pydantic - +from pydantic_extra_types.pendulum_dt import DateTime as PydanticDateTime +from pydantic_extra_types.pendulum_dt import Date as PydanticDate from pydantic import ( BeforeValidator, Field, @@ -34,6 +35,8 @@ ), ] +DateTime: TypeAlias = PydanticDateTime +Date: TypeAlias = PydanticDate BANNED_CHARACTERS = ["/", "%", "&", ">", "<"] diff --git a/tests/blocks/test_system.py b/tests/blocks/test_system.py index fbfa60a3c29b..57e332c803c0 100644 --- a/tests/blocks/test_system.py +++ b/tests/blocks/test_system.py @@ -2,9 +2,9 @@ import pytest from pydantic import Secret as PydanticSecret from pydantic import SecretStr -from pydantic_extra_types.pendulum_dt import DateTime as PydanticDateTime from prefect.blocks.system import DateTime, Secret +from prefect.types import DateTime as PydanticDateTime def test_datetime(ignore_prefect_deprecation_warnings): diff --git a/tests/cli/deployment/test_deployment_run.py b/tests/cli/deployment/test_deployment_run.py index 1152755fa708..f388b9b027c3 100644 --- a/tests/cli/deployment/test_deployment_run.py +++ b/tests/cli/deployment/test_deployment_run.py @@ -4,7 +4,6 @@ import pendulum import pytest -from pendulum.datetime import DateTime from pendulum.duration import Duration import prefect @@ -12,6 +11,7 @@ from prefect.exceptions import FlowRunWaitTimeout from prefect.states import Completed, Failed from prefect.testing.cli import invoke_and_assert +from prefect.types import DateTime from prefect.utilities.asyncutils import run_sync_in_worker_thread diff --git a/tests/cli/test_flow_run.py b/tests/cli/test_flow_run.py index 6da46bafba03..50eeb1125876 100644 --- a/tests/cli/test_flow_run.py +++ b/tests/cli/test_flow_run.py @@ -1,7 +1,6 @@ from uuid import UUID, 
uuid4 import pytest -from pydantic_extra_types.pendulum_dt import DateTime import prefect.exceptions from prefect import flow @@ -23,6 +22,7 @@ StateType, ) from prefect.testing.cli import invoke_and_assert +from prefect.types import DateTime from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible diff --git a/tests/client/schemas/test_schedules.py b/tests/client/schemas/test_schedules.py index cbded66d27f7..48680aeba94a 100644 --- a/tests/client/schemas/test_schedules.py +++ b/tests/client/schemas/test_schedules.py @@ -2,7 +2,6 @@ from itertools import combinations import pytest -from pydantic_extra_types.pendulum_dt import DateTime from prefect.client.schemas.schedules import ( CronSchedule, @@ -10,6 +9,7 @@ RRuleSchedule, construct_schedule, ) +from prefect.types import DateTime class TestConstructSchedule: diff --git a/tests/client/test_prefect_client.py b/tests/client/test_prefect_client.py index 92f0f43f96cc..4cf9b9550bb7 100644 --- a/tests/client/test_prefect_client.py +++ b/tests/client/test_prefect_client.py @@ -18,7 +18,6 @@ import respx from fastapi import Depends, FastAPI, status from fastapi.security import HTTPBearer -from pydantic_extra_types.pendulum_dt import DateTime import prefect.client.schemas as client_schemas import prefect.context @@ -90,6 +89,7 @@ from prefect.states import Completed, Pending, Running, Scheduled, State from prefect.tasks import task from prefect.testing.utilities import AsyncMock, exceptions_equal +from prefect.types import DateTime from prefect.utilities.pydantic import parse_obj_as diff --git a/tests/events/client/test_events_emit_event.py b/tests/events/client/test_events_emit_event.py index e31e21592fdc..50a0a4a23bae 100644 --- a/tests/events/client/test_events_emit_event.py +++ b/tests/events/client/test_events_emit_event.py @@ -3,7 +3,6 @@ from uuid import UUID import pendulum -from pydantic_extra_types.pendulum_dt import DateTime from prefect.events import emit_event from 
prefect.events.clients import AssertingEventsClient @@ -12,6 +11,7 @@ PREFECT_API_URL, temporary_settings, ) +from prefect.types import DateTime def test_emits_simple_event(asserting_events_worker: EventsWorker, reset_worker_events): diff --git a/tests/events/client/test_events_schema.py b/tests/events/client/test_events_schema.py index c5124ce27c6d..a29d58dae83c 100644 --- a/tests/events/client/test_events_schema.py +++ b/tests/events/client/test_events_schema.py @@ -3,9 +3,9 @@ from uuid import UUID, uuid4 import pytest -from pydantic_extra_types.pendulum_dt import DateTime from prefect.events import Event, RelatedResource, Resource +from prefect.types import DateTime def test_client_events_generate_an_id_by_default(): diff --git a/tests/events/server/actions/test_actions_service.py b/tests/events/server/actions/test_actions_service.py index 4f98488e44f6..ef250327d415 100644 --- a/tests/events/server/actions/test_actions_service.py +++ b/tests/events/server/actions/test_actions_service.py @@ -3,13 +3,13 @@ import pendulum import pytest -from pendulum.datetime import DateTime from prefect.server.events import actions from prefect.server.events.clients import AssertingEventsClient from prefect.server.events.schemas.automations import TriggeredAction from prefect.server.utilities.messaging import MessageHandler from prefect.server.utilities.messaging.memory import MemoryMessage +from prefect.types import DateTime @pytest.fixture diff --git a/tests/events/server/actions/test_calling_webhook.py b/tests/events/server/actions/test_calling_webhook.py index 2c4ea5f1cae6..bf75ddaec2af 100644 --- a/tests/events/server/actions/test_calling_webhook.py +++ b/tests/events/server/actions/test_calling_webhook.py @@ -8,7 +8,6 @@ import pendulum import pytest from httpx import Response -from pendulum.datetime import DateTime from pydantic import TypeAdapter from sqlalchemy.ext.asyncio import AsyncSession @@ -35,6 +34,7 @@ from prefect.server.models import deployments, flow_runs, 
flows, work_queues from prefect.server.schemas.actions import WorkQueueCreate from prefect.server.schemas.core import Deployment, Flow, FlowRun, WorkQueue +from prefect.types import DateTime @pytest.fixture diff --git a/tests/events/server/actions/test_jinja_templated_action.py b/tests/events/server/actions/test_jinja_templated_action.py index fca5279c9be1..793c91a8ebb8 100644 --- a/tests/events/server/actions/test_jinja_templated_action.py +++ b/tests/events/server/actions/test_jinja_templated_action.py @@ -6,7 +6,6 @@ import pendulum import pytest -from pendulum.datetime import DateTime from pydantic import Field, ValidationInfo, field_validator from sqlalchemy.ext.asyncio import AsyncSession @@ -48,6 +47,7 @@ from prefect.server.schemas.responses import FlowRunResponse from prefect.server.schemas.states import State, StateType from prefect.settings import PREFECT_UI_URL, temporary_settings +from prefect.types import DateTime @pytest.fixture(autouse=True) diff --git a/tests/events/server/actions/test_pausing_resuming_automation.py b/tests/events/server/actions/test_pausing_resuming_automation.py index 6a092be21e14..3c4b0af243a9 100644 --- a/tests/events/server/actions/test_pausing_resuming_automation.py +++ b/tests/events/server/actions/test_pausing_resuming_automation.py @@ -4,7 +4,6 @@ import pendulum import pytest -from pendulum.datetime import DateTime from pydantic import ValidationError from sqlalchemy.ext.asyncio import AsyncSession @@ -25,6 +24,7 @@ PREFECT_API_SERVICES_TRIGGERS_ENABLED, temporary_settings, ) +from prefect.types import DateTime from prefect.utilities.pydantic import parse_obj_as diff --git a/tests/events/server/actions/test_pausing_resuming_deployment.py b/tests/events/server/actions/test_pausing_resuming_deployment.py index fe5c669c6f71..7e0bb4d66fe2 100644 --- a/tests/events/server/actions/test_pausing_resuming_deployment.py +++ b/tests/events/server/actions/test_pausing_resuming_deployment.py @@ -4,7 +4,6 @@ import pendulum import 
pytest -from pendulum.datetime import DateTime from pydantic import ValidationError from sqlalchemy.ext.asyncio import AsyncSession @@ -23,6 +22,7 @@ from prefect.server.schemas.actions import DeploymentScheduleCreate from prefect.server.schemas.core import Deployment, Flow from prefect.server.schemas.schedules import IntervalSchedule +from prefect.types import DateTime from prefect.utilities.pydantic import parse_obj_as diff --git a/tests/events/server/actions/test_pausing_resuming_work_pool.py b/tests/events/server/actions/test_pausing_resuming_work_pool.py index 7f8ad8b162e3..d9cf20cc2471 100644 --- a/tests/events/server/actions/test_pausing_resuming_work_pool.py +++ b/tests/events/server/actions/test_pausing_resuming_work_pool.py @@ -3,7 +3,6 @@ from uuid import UUID, uuid4 import pytest -from pendulum.datetime import DateTime from pydantic import ValidationError from sqlalchemy.ext.asyncio import AsyncSession @@ -21,6 +20,7 @@ from prefect.server.events.schemas.events import ReceivedEvent, RelatedResource from prefect.server.models import workers from prefect.server.schemas.actions import WorkPoolCreate +from prefect.types import DateTime from prefect.utilities.pydantic import parse_obj_as if TYPE_CHECKING: diff --git a/tests/events/server/actions/test_pausing_resuming_work_queue.py b/tests/events/server/actions/test_pausing_resuming_work_queue.py index 395e0a146470..283ed8fe35be 100644 --- a/tests/events/server/actions/test_pausing_resuming_work_queue.py +++ b/tests/events/server/actions/test_pausing_resuming_work_queue.py @@ -4,7 +4,6 @@ import pendulum import pytest -from pendulum.datetime import DateTime from pydantic import ValidationError from sqlalchemy.ext.asyncio import AsyncSession @@ -22,6 +21,7 @@ from prefect.server.models import work_queues from prefect.server.schemas.actions import WorkQueueCreate, WorkQueueUpdate from prefect.server.schemas.core import WorkQueue +from prefect.types import DateTime from prefect.utilities.pydantic import 
parse_obj_as diff --git a/tests/events/server/conftest.py b/tests/events/server/conftest.py index 41759cce7072..e66541d695dd 100644 --- a/tests/events/server/conftest.py +++ b/tests/events/server/conftest.py @@ -6,7 +6,6 @@ import pendulum import pytest import sqlalchemy as sa -from pendulum.datetime import DateTime from sqlalchemy.ext.asyncio import AsyncSession from prefect.server.database.interface import PrefectDBInterface @@ -21,6 +20,7 @@ ) from prefect.server.events.schemas.events import ReceivedEvent from prefect.server.utilities.messaging import Message +from prefect.types import DateTime from prefect.utilities.pydantic import parse_obj_as diff --git a/tests/events/server/models/test_composite_trigger_child_firing.py b/tests/events/server/models/test_composite_trigger_child_firing.py index 5a592e035852..4a1d776d6fd5 100644 --- a/tests/events/server/models/test_composite_trigger_child_firing.py +++ b/tests/events/server/models/test_composite_trigger_child_firing.py @@ -2,7 +2,6 @@ from uuid import uuid4 import pytest -from pendulum.datetime import DateTime from sqlalchemy.ext.asyncio import AsyncSession from prefect.server.events import actions @@ -23,6 +22,7 @@ ) from prefect.server.events.schemas.events import ReceivedEvent from prefect.server.events.triggers import load_automation +from prefect.types import DateTime @pytest.fixture diff --git a/tests/events/server/storage/test_event_persister.py b/tests/events/server/storage/test_event_persister.py index 8d764300194e..a7ea13d018ba 100644 --- a/tests/events/server/storage/test_event_persister.py +++ b/tests/events/server/storage/test_event_persister.py @@ -6,7 +6,6 @@ import pytest import sqlalchemy as sa from pydantic import ValidationError -from pydantic_extra_types.pendulum_dt import DateTime from sqlalchemy.ext.asyncio import AsyncSession from prefect.server.database.dependencies import db_injector @@ -17,6 +16,7 @@ from prefect.server.events.storage.database import query_events, write_events from 
prefect.server.utilities.messaging import CapturedMessage, Message, MessageHandler from prefect.settings import PREFECT_EVENTS_RETENTION_PERIOD, temporary_settings +from prefect.types import DateTime if TYPE_CHECKING: from prefect.server.database.orm_models import ORMEventResource diff --git a/tests/events/server/test_automations_api.py b/tests/events/server/test_automations_api.py index befc3140ffc6..d02cf6df92a4 100644 --- a/tests/events/server/test_automations_api.py +++ b/tests/events/server/test_automations_api.py @@ -11,7 +11,6 @@ import sqlalchemy as sa from fastapi.applications import FastAPI from httpx import ASGITransport, AsyncClient -from pendulum.datetime import DateTime from sqlalchemy.ext.asyncio import AsyncSession from prefect.server import models as server_models @@ -40,6 +39,7 @@ PREFECT_API_SERVICES_TRIGGERS_ENABLED, temporary_settings, ) +from prefect.types import DateTime from prefect.utilities.pydantic import parse_obj_as diff --git a/tests/events/server/test_clients.py b/tests/events/server/test_clients.py index 841ce6cda6e0..1f71bddb3abc 100644 --- a/tests/events/server/test_clients.py +++ b/tests/events/server/test_clients.py @@ -5,7 +5,6 @@ import pendulum import pytest -from pendulum.datetime import DateTime from prefect.server.events.clients import ( AssertingEventsClient, @@ -14,6 +13,7 @@ ) from prefect.server.events.schemas.events import Event, ReceivedEvent, RelatedResource from prefect.server.utilities.messaging import CapturingPublisher +from prefect.types import DateTime @pytest.fixture diff --git a/tests/events/server/test_events_api.py b/tests/events/server/test_events_api.py index d068251441b4..f7dc36961abb 100644 --- a/tests/events/server/test_events_api.py +++ b/tests/events/server/test_events_api.py @@ -7,7 +7,6 @@ import pendulum import pytest from httpx import AsyncClient -from pendulum.datetime import DateTime from pydantic.networks import AnyHttpUrl from prefect.server.events.counting import Countable, TimeUnit @@ -23,6 
+22,7 @@ Resource, ) from prefect.server.events.storage import INTERACTIVE_PAGE_SIZE, InvalidTokenError +from prefect.types import DateTime from prefect.utilities.pydantic import parse_obj_as diff --git a/tests/events/server/test_events_counts.py b/tests/events/server/test_events_counts.py index 7e10960b9d50..7eb2c313f496 100644 --- a/tests/events/server/test_events_counts.py +++ b/tests/events/server/test_events_counts.py @@ -5,7 +5,6 @@ import pendulum import pytest -from pydantic_extra_types.pendulum_dt import Date, DateTime from sqlalchemy.ext.asyncio import AsyncSession from prefect.server.events.counting import PIVOT_DATETIME, Countable, TimeUnit @@ -18,6 +17,7 @@ count_events, write_events, ) +from prefect.types import Date, DateTime # Note: the counts in this module are sensitive to the number and shape of events # we produce in conftest.py and may need to be adjusted if we make changes. diff --git a/tests/events/server/test_events_schema.py b/tests/events/server/test_events_schema.py index 2d12c23fcfc4..a7f5ccf43ca7 100644 --- a/tests/events/server/test_events_schema.py +++ b/tests/events/server/test_events_schema.py @@ -4,7 +4,6 @@ import pendulum import pytest -from pendulum.datetime import DateTime from pydantic import ValidationError from prefect.server.events.schemas.events import ( @@ -13,6 +12,7 @@ RelatedResource, Resource, ) +from prefect.types import DateTime def test_client_events_do_not_have_defaults_for_the_fields_it_seems_they_should(): diff --git a/tests/events/server/triggers/test_basics.py b/tests/events/server/triggers/test_basics.py index 6c3fe9e7fefa..e5a9d9f83c45 100644 --- a/tests/events/server/triggers/test_basics.py +++ b/tests/events/server/triggers/test_basics.py @@ -4,7 +4,6 @@ from uuid import uuid4 import pytest -from pendulum.datetime import DateTime from sqlalchemy.ext.asyncio import AsyncSession from prefect.server.events import actions, triggers @@ -18,6 +17,7 @@ ) from prefect.server.events.schemas.events import 
ReceivedEvent, matches from prefect.settings import PREFECT_EVENTS_EXPIRED_BUCKET_BUFFER +from prefect.types import DateTime def test_triggers_have_identifiers(arachnophobia: Automation): diff --git a/tests/events/server/triggers/test_composite_triggers.py b/tests/events/server/triggers/test_composite_triggers.py index e0d79d99bb65..728bc1fa88f3 100644 --- a/tests/events/server/triggers/test_composite_triggers.py +++ b/tests/events/server/triggers/test_composite_triggers.py @@ -5,7 +5,6 @@ import pendulum import pytest -from pendulum.datetime import DateTime from sqlalchemy.ext.asyncio import AsyncSession from prefect.server.database.interface import PrefectDBInterface @@ -21,6 +20,7 @@ TriggerState, ) from prefect.server.events.schemas.events import ReceivedEvent +from prefect.types import DateTime @pytest.fixture diff --git a/tests/events/server/triggers/test_flow_run_slas.py b/tests/events/server/triggers/test_flow_run_slas.py index 50e9aaec6a01..e01ec83e31cb 100644 --- a/tests/events/server/triggers/test_flow_run_slas.py +++ b/tests/events/server/triggers/test_flow_run_slas.py @@ -5,7 +5,6 @@ import pendulum import pytest -from pendulum.datetime import DateTime from sqlalchemy.ext.asyncio import AsyncSession from prefect.server.events import actions, triggers @@ -18,6 +17,7 @@ TriggerState, ) from prefect.server.events.schemas.events import Event, ReceivedEvent +from prefect.types import DateTime @pytest.fixture diff --git a/tests/events/server/triggers/test_service.py b/tests/events/server/triggers/test_service.py index c1c4d9f49afb..171d683b382e 100644 --- a/tests/events/server/triggers/test_service.py +++ b/tests/events/server/triggers/test_service.py @@ -7,7 +7,6 @@ import pendulum import pytest -from pendulum.datetime import DateTime from sqlalchemy.ext.asyncio import AsyncSession from prefect.server.events import actions, triggers @@ -22,6 +21,7 @@ ) from prefect.server.utilities.messaging import MessageHandler from 
prefect.server.utilities.messaging.memory import MemoryMessage +from prefect.types import DateTime async def test_acting_publishes_an_action_message_from_a_reactive_event( diff --git a/tests/server/orchestration/api/ui/test_flows.py b/tests/server/orchestration/api/ui/test_flows.py index 68dd1adbfb73..1e57a694e594 100644 --- a/tests/server/orchestration/api/ui/test_flows.py +++ b/tests/server/orchestration/api/ui/test_flows.py @@ -1,10 +1,10 @@ import pendulum import pytest -from pydantic_extra_types.pendulum_dt import DateTime from prefect.server import models, schemas from prefect.server.api.ui.flows import SimpleNextFlowRun from prefect.server.database import orm_models +from prefect.types import DateTime @pytest.fixture diff --git a/tests/server/orchestration/api/ui/test_task_runs.py b/tests/server/orchestration/api/ui/test_task_runs.py index 0c736d5c6bf3..17c0a759449d 100644 --- a/tests/server/orchestration/api/ui/test_task_runs.py +++ b/tests/server/orchestration/api/ui/test_task_runs.py @@ -3,12 +3,12 @@ import pendulum import pytest from httpx import AsyncClient -from pydantic_extra_types.pendulum_dt import DateTime from sqlalchemy.ext.asyncio import AsyncSession from prefect.server import models from prefect.server.api.ui.task_runs import TaskRunCount from prefect.server.schemas import core, filters, states +from prefect.types import DateTime class TestReadDashboardTaskRunCounts: diff --git a/tests/test_context.py b/tests/test_context.py index 84309f1aba34..8d51fddbb104 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -5,7 +5,6 @@ from unittest.mock import MagicMock import pytest -from pendulum.datetime import DateTime import prefect.settings from prefect import flow, task @@ -40,6 +39,7 @@ ) from prefect.states import Running from prefect.task_runners import ThreadPoolTaskRunner +from prefect.types import DateTime class ExampleContext(ContextModel): diff --git a/tests/workers/test_process_worker.py b/tests/workers/test_process_worker.py 
index 975e6b64b9bd..ba0cfd270705 100644 --- a/tests/workers/test_process_worker.py +++ b/tests/workers/test_process_worker.py @@ -13,7 +13,6 @@ import pytest from exceptiongroup import ExceptionGroup, catch from pydantic import BaseModel -from pydantic_extra_types.pendulum_dt import DateTime from sqlalchemy.ext.asyncio import AsyncSession import prefect @@ -30,6 +29,7 @@ ) from prefect.states import Cancelled, Cancelling, Completed, Pending, Running, Scheduled from prefect.testing.utilities import AsyncMock, MagicMock +from prefect.types import DateTime from prefect.workers.process import ( ProcessWorker, ProcessWorkerResult, From b60fc0c3288befd1a052a3743a315374da766ecb Mon Sep 17 00:00:00 2001 From: Alexander Streed Date: Thu, 12 Dec 2024 09:55:19 -0600 Subject: [PATCH 2/8] Fix type analysis check for PRs from forks (#16359) --- .github/workflows/static-analysis.yaml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/static-analysis.yaml b/.github/workflows/static-analysis.yaml index 45eb836fadc7..c9f9ff18409d 100644 --- a/.github/workflows/static-analysis.yaml +++ b/.github/workflows/static-analysis.yaml @@ -102,10 +102,6 @@ jobs: BASE_SCORE=$(jq -r '.typeCompleteness.completenessScore' prefect-analysis-base.json) echo "base_score=$BASE_SCORE" >> $GITHUB_OUTPUT - - name: Checkout current branch - run: | - git checkout ${{ github.head_ref || github.ref_name }} - - name: Compare scores run: | CURRENT_SCORE=$(echo ${{ steps.calculate_current_score.outputs.current_score }}) From da5439d6a1146308862677cf9fcc62b10ce6ad86 Mon Sep 17 00:00:00 2001 From: Devin Villarosa <102188207+devinvillarosa@users.noreply.github.com> Date: Thu, 12 Dec 2024 08:06:22 -0800 Subject: [PATCH 3/8] [UI v2] test: Adds tests for global concurrency view (#16333) --- .../data-table/active-cell.tsx | 1 + .../data-table/data-table.test.tsx | 121 ++++++++++++++++++ .../data-table/data-table.tsx | 46 +++++-- .../create-or-edit-limit-dialog.test.tsx | 88 +++++++++++++ ...ex.tsx 
=> create-or-edit-limit-dialog.tsx} | 0 .../create-or-edit-limit-dialog/index.ts | 1 + .../dialog/delete-limit-dialog.test.tsx | 43 +++++++ ...bal-concurrency-limit-empty-state.test.tsx | 20 ++- .../global-concurrency-limits-header.test.tsx | 26 ++++ .../global-concurrency-limits-header.tsx | 0 .../global-concurrency-view/header/index.ts | 1 + .../global-concurrency-view/index.tsx | 2 +- .../data-table/active-task-runs-cell.tsx | 13 ++ .../data-table/data-table.test.tsx | 68 ++++++++++ .../data-table/data-table.tsx | 48 +++++-- .../dialogs/create-dialog.test.tsx | 48 +++++++ .../dialogs/delete-dialog.test.tsx | 41 ++++++ .../dialogs/reset-dialog.test.tsx | 45 +++++++ .../task-run-concurrenct-view/header/index.ts | 1 + .../task-run-concurrency-limit-header.tsx} | 0 .../task-run-conrrency-limits-header.test.tsx | 26 ++++ 21 files changed, 607 insertions(+), 32 deletions(-) create mode 100644 ui-v2/src/components/concurrency/global-concurrency-view/data-table/data-table.test.tsx create mode 100644 ui-v2/src/components/concurrency/global-concurrency-view/dialog/create-or-edit-limit-dialog/create-or-edit-limit-dialog.test.tsx rename ui-v2/src/components/concurrency/global-concurrency-view/dialog/create-or-edit-limit-dialog/{index.tsx => create-or-edit-limit-dialog.tsx} (100%) create mode 100644 ui-v2/src/components/concurrency/global-concurrency-view/dialog/create-or-edit-limit-dialog/index.ts create mode 100644 ui-v2/src/components/concurrency/global-concurrency-view/dialog/delete-limit-dialog.test.tsx create mode 100644 ui-v2/src/components/concurrency/global-concurrency-view/header/global-concurrency-limits-header.test.tsx rename ui-v2/src/components/concurrency/global-concurrency-view/{ => header}/global-concurrency-limits-header.tsx (100%) create mode 100644 ui-v2/src/components/concurrency/global-concurrency-view/header/index.ts create mode 100644 ui-v2/src/components/concurrency/task-run-concurrenct-view/data-table/active-task-runs-cell.tsx create mode 100644 
ui-v2/src/components/concurrency/task-run-concurrenct-view/data-table/data-table.test.tsx create mode 100644 ui-v2/src/components/concurrency/task-run-concurrenct-view/dialogs/create-dialog.test.tsx create mode 100644 ui-v2/src/components/concurrency/task-run-concurrenct-view/dialogs/delete-dialog.test.tsx create mode 100644 ui-v2/src/components/concurrency/task-run-concurrenct-view/dialogs/reset-dialog.test.tsx create mode 100644 ui-v2/src/components/concurrency/task-run-concurrenct-view/header/index.ts rename ui-v2/src/components/concurrency/task-run-concurrenct-view/{header.tsx => header/task-run-concurrency-limit-header.tsx} (100%) create mode 100644 ui-v2/src/components/concurrency/task-run-concurrenct-view/header/task-run-conrrency-limits-header.test.tsx diff --git a/ui-v2/src/components/concurrency/global-concurrency-view/data-table/active-cell.tsx b/ui-v2/src/components/concurrency/global-concurrency-view/data-table/active-cell.tsx index 17d3ad943165..ce0a452b9a4a 100644 --- a/ui-v2/src/components/concurrency/global-concurrency-view/data-table/active-cell.tsx +++ b/ui-v2/src/components/concurrency/global-concurrency-view/data-table/active-cell.tsx @@ -41,6 +41,7 @@ export const ActiveCell = ( return ( handleCheckedChange(checked, rowId)} /> diff --git a/ui-v2/src/components/concurrency/global-concurrency-view/data-table/data-table.test.tsx b/ui-v2/src/components/concurrency/global-concurrency-view/data-table/data-table.test.tsx new file mode 100644 index 000000000000..16e35f843441 --- /dev/null +++ b/ui-v2/src/components/concurrency/global-concurrency-view/data-table/data-table.test.tsx @@ -0,0 +1,121 @@ +import { Toaster } from "@/components/ui/toaster"; +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { createWrapper } from "@tests/utils"; +import { describe, expect, it, vi } from "vitest"; +import { Table } from "./data-table"; + +const MOCK_ROW = { + id: "0", + created: 
"2021-01-01T00:00:00Z", + updated: "2021-01-01T00:00:00Z", + active: true, + name: "global concurrency limit 0", + limit: 0, + active_slots: 0, + slot_decay_per_second: 0, +}; + +describe("GlobalConcurrencyLimitTable -- table", () => { + it("renders row data", () => { + render( + , + { wrapper: createWrapper() }, + ); + expect( + screen.getByRole("cell", { name: /global concurrency limit 0/i }), + ).toBeVisible(); + expect( + screen.getByRole("switch", { name: /toggle active/i }), + ).toBeChecked(); + }); + + it("calls onDelete upon clicking delete action menu item", async () => { + const user = userEvent.setup(); + + const mockFn = vi.fn(); + + render( +
, + { wrapper: createWrapper() }, + ); + await user.click( + screen.getByRole("button", { name: /open menu/i, hidden: true }), + ); + await user.click(screen.getByRole("menuitem", { name: /delete/i })); + expect(mockFn).toBeCalledWith(MOCK_ROW); + }); + it("calls onEdit upon clicking rest action menu item", async () => { + const user = userEvent.setup(); + const mockFn = vi.fn(); + + render( +
, + { wrapper: createWrapper() }, + ); + await user.click( + screen.getByRole("button", { name: /open menu/i, hidden: true }), + ); + await user.click(screen.getByRole("menuitem", { name: /edit/i })); + expect(mockFn).toHaveBeenCalledWith(MOCK_ROW); + }); + + it("toggles active switch", async () => { + const user = userEvent.setup(); + + const { rerender } = render( + <> + +
+ , + { wrapper: createWrapper() }, + ); + expect( + screen.getByRole("switch", { name: /toggle active/i }), + ).toBeChecked(); + + await user.click( + screen.getByRole("switch", { + name: /toggle active/i, + }), + ); + expect(screen.getByText("Concurrency limit updated")).toBeVisible(); + rerender( +
, + ); + + expect( + screen.getByRole("switch", { name: /toggle active/i }), + ).not.toBeChecked(); + }); +}); diff --git a/ui-v2/src/components/concurrency/global-concurrency-view/data-table/data-table.tsx b/ui-v2/src/components/concurrency/global-concurrency-view/data-table/data-table.tsx index 77ae3de0184d..d246a82a0885 100644 --- a/ui-v2/src/components/concurrency/global-concurrency-view/data-table/data-table.tsx +++ b/ui-v2/src/components/concurrency/global-concurrency-view/data-table/data-table.tsx @@ -69,9 +69,40 @@ export const GlobalConcurrencyDataTable = ({ ); }, [data, deferredSearch]); + return ( +
+ void navigate({ + to: ".", + search: (prev) => ({ ...prev, search: value }), + }) + } + /> + ); +}; + +type TableProps = { + data: Array; + onDeleteRow: (row: GlobalConcurrencyLimit) => void; + onEditRow: (row: GlobalConcurrencyLimit) => void; + onSearchChange: (value: string) => void; + searchValue: string | undefined; +}; + +export function Table({ + data, + onDeleteRow, + onEditRow, + onSearchChange, + searchValue, +}: TableProps) { const table = useReactTable({ - data: filteredData, - columns: createColumns({ onEditRow, onDeleteRow }), + data, + columns: createColumns({ onDeleteRow, onEditRow }), getCoreRowModel: getCoreRowModel(), getPaginationRowModel: getPaginationRowModel(), //load client-side pagination code }); @@ -80,15 +111,10 @@ export const GlobalConcurrencyDataTable = ({
- void navigate({ - to: ".", - search: (prev) => ({ ...prev, search: e.target.value }), - }) - } + value={searchValue} + onChange={(e) => onSearchChange(e.target.value)} />
); -}; +} diff --git a/ui-v2/src/components/concurrency/global-concurrency-view/dialog/create-or-edit-limit-dialog/create-or-edit-limit-dialog.test.tsx b/ui-v2/src/components/concurrency/global-concurrency-view/dialog/create-or-edit-limit-dialog/create-or-edit-limit-dialog.test.tsx new file mode 100644 index 000000000000..e8473af769a8 --- /dev/null +++ b/ui-v2/src/components/concurrency/global-concurrency-view/dialog/create-or-edit-limit-dialog/create-or-edit-limit-dialog.test.tsx @@ -0,0 +1,88 @@ +import { CreateOrEditLimitDialog } from "./create-or-edit-limit-dialog"; + +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { createWrapper } from "@tests/utils"; +import { beforeAll, describe, expect, it, vi } from "vitest"; + +const MOCK_DATA = { + id: "0", + created: "2021-01-01T00:00:00Z", + updated: "2021-01-01T00:00:00Z", + active: false, + name: "global concurrency limit 0", + limit: 0, + active_slots: 0, + slot_decay_per_second: 0, +}; + +describe("CreateOrEditLimitDialog", () => { + beforeAll(() => { + class ResizeObserverMock { + observe() {} + unobserve() {} + disconnect() {} + } + + global.ResizeObserver = ResizeObserverMock; + }); + + it("able to create a new limit", async () => { + const user = userEvent.setup(); + + // ------------ Setup + const mockOnSubmitFn = vi.fn(); + render( + , + { wrapper: createWrapper() }, + ); + // ------------ Act + + await user.type(screen.getByLabelText(/name/i), MOCK_DATA.name); + await user.type( + screen.getByLabelText("Concurrency Limit"), + MOCK_DATA.limit.toString(), + ); + await user.type( + screen.getByLabelText("Slot Decay Per Second"), + MOCK_DATA.slot_decay_per_second.toString(), + ); + await user.click(screen.getByRole("button", { name: /save/i })); + + // ------------ Assert + expect(mockOnSubmitFn).toHaveBeenCalledOnce(); + }); + + it("able to edit a limit", async () => { + const user = userEvent.setup(); + + // ------------ Setup + const 
mockOnSubmitFn = vi.fn(); + render( + , + { wrapper: createWrapper() }, + ); + // ------------ Act + + await user.type(screen.getByLabelText(/name/i), MOCK_DATA.name); + await user.type( + screen.getByLabelText("Concurrency Limit"), + MOCK_DATA.limit.toString(), + ); + await user.type( + screen.getByLabelText("Slot Decay Per Second"), + MOCK_DATA.slot_decay_per_second.toString(), + ); + await user.click(screen.getByRole("button", { name: /update/i })); + + // ------------ Assert + expect(mockOnSubmitFn).toHaveBeenCalledOnce(); + }); +}); diff --git a/ui-v2/src/components/concurrency/global-concurrency-view/dialog/create-or-edit-limit-dialog/index.tsx b/ui-v2/src/components/concurrency/global-concurrency-view/dialog/create-or-edit-limit-dialog/create-or-edit-limit-dialog.tsx similarity index 100% rename from ui-v2/src/components/concurrency/global-concurrency-view/dialog/create-or-edit-limit-dialog/index.tsx rename to ui-v2/src/components/concurrency/global-concurrency-view/dialog/create-or-edit-limit-dialog/create-or-edit-limit-dialog.tsx diff --git a/ui-v2/src/components/concurrency/global-concurrency-view/dialog/create-or-edit-limit-dialog/index.ts b/ui-v2/src/components/concurrency/global-concurrency-view/dialog/create-or-edit-limit-dialog/index.ts new file mode 100644 index 000000000000..3d271c515e53 --- /dev/null +++ b/ui-v2/src/components/concurrency/global-concurrency-view/dialog/create-or-edit-limit-dialog/index.ts @@ -0,0 +1 @@ +export { CreateOrEditLimitDialog } from "./create-or-edit-limit-dialog"; diff --git a/ui-v2/src/components/concurrency/global-concurrency-view/dialog/delete-limit-dialog.test.tsx b/ui-v2/src/components/concurrency/global-concurrency-view/dialog/delete-limit-dialog.test.tsx new file mode 100644 index 000000000000..3f9f227f8618 --- /dev/null +++ b/ui-v2/src/components/concurrency/global-concurrency-view/dialog/delete-limit-dialog.test.tsx @@ -0,0 +1,43 @@ +import { DeleteLimitDialog } from "./delete-limit-dialog"; + +import { render, 
screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { createWrapper } from "@tests/utils"; +import { expect, test, vi } from "vitest"; + +const MOCK_DATA = { + id: "0", + created: "2021-01-01T00:00:00Z", + updated: "2021-01-01T00:00:00Z", + active: false, + name: "global concurrency limit 0", + limit: 0, + active_slots: 0, + slot_decay_per_second: 0, +}; + +test("DeleteLimitDialog can successfully call delete", async () => { + const user = userEvent.setup(); + + // ------------ Setup + const mockOnDeleteFn = vi.fn(); + render( + , + { wrapper: createWrapper() }, + ); + + // ------------ Act + expect(screen.getByRole("heading", { name: /delete concurrency limit/i })); + await user.click( + screen.getByRole("button", { + name: /delete/i, + }), + ); + + // ------------ Assert + expect(mockOnDeleteFn).toHaveBeenCalledOnce(); +}); diff --git a/ui-v2/src/components/concurrency/global-concurrency-view/empty-state/global-concurrency-limit-empty-state.test.tsx b/ui-v2/src/components/concurrency/global-concurrency-view/empty-state/global-concurrency-limit-empty-state.test.tsx index 48cf3bb962a1..b2590e75f2dc 100644 --- a/ui-v2/src/components/concurrency/global-concurrency-view/empty-state/global-concurrency-limit-empty-state.test.tsx +++ b/ui-v2/src/components/concurrency/global-concurrency-view/empty-state/global-concurrency-limit-empty-state.test.tsx @@ -1,18 +1,16 @@ import { render, screen } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; -import { describe, expect, it, vi } from "vitest"; +import { expect, test, vi } from "vitest"; import { GlobalConcurrencyLimitEmptyState } from "./global-concurrency-limit-empty-state"; -describe("GlobalConcurrencyLimitEmptyState", () => { - it("when adding limit, callback gets fired", async () => { - const user = userEvent.setup(); +test("GlobalConcurrencyLimitEmptyState", async () => { + const user = userEvent.setup(); - const mockFn = vi.fn(); + 
const mockFn = vi.fn(); - render(); - await user.click( - screen.getByRole("button", { name: /Add Concurrency Limit/i }), - ); - expect(mockFn).toHaveBeenCalledOnce(); - }); + render(); + await user.click( + screen.getByRole("button", { name: /Add Concurrency Limit/i }), + ); + expect(mockFn).toHaveBeenCalledOnce(); }); diff --git a/ui-v2/src/components/concurrency/global-concurrency-view/header/global-concurrency-limits-header.test.tsx b/ui-v2/src/components/concurrency/global-concurrency-view/header/global-concurrency-limits-header.test.tsx new file mode 100644 index 000000000000..c067a7d29a42 --- /dev/null +++ b/ui-v2/src/components/concurrency/global-concurrency-view/header/global-concurrency-limits-header.test.tsx @@ -0,0 +1,26 @@ +import { GlobalConcurrencyLimitsHeader } from "./global-concurrency-limits-header"; + +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { expect, test, vi } from "vitest"; + +test("GlobalConcurrencyLimitsHeader can successfully call onAdd", async () => { + const user = userEvent.setup(); + + // ------------ Setup + const mockOnAddFn = vi.fn(); + render(); + + // ------------ Act + expect( + screen.getByRole("heading", { name: /global concurrency limits/i }), + ).toBeVisible(); + await user.click( + screen.getByRole("button", { + name: /add global concurrency limit/i, + }), + ); + + // ------------ Assert + expect(mockOnAddFn).toHaveBeenCalledOnce(); +}); diff --git a/ui-v2/src/components/concurrency/global-concurrency-view/global-concurrency-limits-header.tsx b/ui-v2/src/components/concurrency/global-concurrency-view/header/global-concurrency-limits-header.tsx similarity index 100% rename from ui-v2/src/components/concurrency/global-concurrency-view/global-concurrency-limits-header.tsx rename to ui-v2/src/components/concurrency/global-concurrency-view/header/global-concurrency-limits-header.tsx diff --git 
a/ui-v2/src/components/concurrency/global-concurrency-view/header/index.ts b/ui-v2/src/components/concurrency/global-concurrency-view/header/index.ts new file mode 100644 index 000000000000..421bd0174d83 --- /dev/null +++ b/ui-v2/src/components/concurrency/global-concurrency-view/header/index.ts @@ -0,0 +1 @@ +export { GlobalConcurrencyLimitsHeader } from "./global-concurrency-limits-header"; diff --git a/ui-v2/src/components/concurrency/global-concurrency-view/index.tsx b/ui-v2/src/components/concurrency/global-concurrency-view/index.tsx index 11d9d1a31a99..b66109159506 100644 --- a/ui-v2/src/components/concurrency/global-concurrency-view/index.tsx +++ b/ui-v2/src/components/concurrency/global-concurrency-view/index.tsx @@ -7,7 +7,7 @@ import { useState } from "react"; import { GlobalConcurrencyDataTable } from "./data-table"; import { type DialogState, DialogView } from "./dialog"; import { GlobalConcurrencyLimitEmptyState } from "./empty-state"; -import { GlobalConcurrencyLimitsHeader } from "./global-concurrency-limits-header"; +import { GlobalConcurrencyLimitsHeader } from "./header"; export const GlobalConcurrencyView = () => { const [openDialog, setOpenDialog] = useState({ diff --git a/ui-v2/src/components/concurrency/task-run-concurrenct-view/data-table/active-task-runs-cell.tsx b/ui-v2/src/components/concurrency/task-run-concurrenct-view/data-table/active-task-runs-cell.tsx new file mode 100644 index 000000000000..ac84b50c854a --- /dev/null +++ b/ui-v2/src/components/concurrency/task-run-concurrenct-view/data-table/active-task-runs-cell.tsx @@ -0,0 +1,13 @@ +import { type TaskRunConcurrencyLimit } from "@/hooks/task-run-concurrency-limits"; +import { CellContext } from "@tanstack/react-table"; + +type Props = CellContext>; + +export const ActiveTaskRunCells = (props: Props) => { + const activeTaskRuns = props.getValue(); + const numActiveTaskRuns = activeTaskRuns.length; + if (numActiveTaskRuns === 0) { + return "None"; + } + return numActiveTaskRuns; +}; 
diff --git a/ui-v2/src/components/concurrency/task-run-concurrenct-view/data-table/data-table.test.tsx b/ui-v2/src/components/concurrency/task-run-concurrenct-view/data-table/data-table.test.tsx new file mode 100644 index 000000000000..85961acb10e5 --- /dev/null +++ b/ui-v2/src/components/concurrency/task-run-concurrenct-view/data-table/data-table.test.tsx @@ -0,0 +1,68 @@ +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { describe, expect, it, vi } from "vitest"; +import { Table } from "./data-table"; + +const MOCK_ROW = { + id: "0", + created: "2021-01-01T00:00:00Z", + updated: "2021-01-01T00:00:00Z", + tag: "my tag 0", + concurrency_limit: 1, + active_slots: [] as Array, +}; + +describe("TaskRunDataTable -- table", () => { + it("renders row data", () => { + render( +
, + ); + expect(screen.getByRole("cell", { name: /my tag 0/i })).toBeVisible(); + expect(screen.getByRole("cell", { name: /1/i })).toBeVisible(); + }); + it("calls onDelete upon clicking delete action menu item", async () => { + const user = userEvent.setup(); + + const mockFn = vi.fn(); + + render( +
, + ); + await user.click( + screen.getByRole("button", { name: /open menu/i, hidden: true }), + ); + await user.click(screen.getByRole("menuitem", { name: /delete/i })); + expect(mockFn).toHaveBeenCalledWith(MOCK_ROW); + }); + it("calls onReset upon clicking rest action menu item", async () => { + const user = userEvent.setup(); + const mockFn = vi.fn(); + + render( +
, + ); + await user.click( + screen.getByRole("button", { name: /open menu/i, hidden: true }), + ); + await user.click(screen.getByRole("menuitem", { name: /reset/i })); + expect(mockFn).toHaveBeenCalledWith(MOCK_ROW); + }); +}); diff --git a/ui-v2/src/components/concurrency/task-run-concurrenct-view/data-table/data-table.tsx b/ui-v2/src/components/concurrency/task-run-concurrenct-view/data-table/data-table.tsx index 1fda29f44fd7..2069baa90ad5 100644 --- a/ui-v2/src/components/concurrency/task-run-concurrenct-view/data-table/data-table.tsx +++ b/ui-v2/src/components/concurrency/task-run-concurrenct-view/data-table/data-table.tsx @@ -11,6 +11,7 @@ import { import { SearchInput } from "@/components/ui/input"; import { useDeferredValue, useMemo } from "react"; import { ActionsCell } from "./actions-cell"; +import { ActiveTaskRunCells } from "./active-task-runs-cell"; const routeApi = getRouteApi("/concurrency-limits"); @@ -30,7 +31,8 @@ const createColumns = ({ header: "Slots", }), columnHelper.accessor("active_slots", { - header: "Active Task Runs", // TODO: Give this styling once knowing what it looks like + header: "Active Task Runs", + cell: ActiveTaskRunCells, }), columnHelper.display({ id: "actions", @@ -65,8 +67,39 @@ export const TaskRunConcurrencyDataTable = ({ ); }, [data, deferredSearch]); + return ( +
+ void navigate({ + to: ".", + search: (prev) => ({ ...prev, search: value }), + }) + } + /> + ); +}; + +type TableProps = { + data: Array; + onDeleteRow: (row: TaskRunConcurrencyLimit) => void; + onResetRow: (row: TaskRunConcurrencyLimit) => void; + onSearchChange: (value: string) => void; + searchValue: string | undefined; +}; + +export function Table({ + data, + onDeleteRow, + onResetRow, + onSearchChange, + searchValue, +}: TableProps) { const table = useReactTable({ - data: filteredData, + data, columns: createColumns({ onDeleteRow, onResetRow }), getCoreRowModel: getCoreRowModel(), getPaginationRowModel: getPaginationRowModel(), //load client-side pagination code @@ -76,15 +109,10 @@ export const TaskRunConcurrencyDataTable = ({
- void navigate({ - to: ".", - search: (prev) => ({ ...prev, search: e.target.value }), - }) - } + value={searchValue} + onChange={(e) => onSearchChange(e.target.value)} />
); -}; +} diff --git a/ui-v2/src/components/concurrency/task-run-concurrenct-view/dialogs/create-dialog.test.tsx b/ui-v2/src/components/concurrency/task-run-concurrenct-view/dialogs/create-dialog.test.tsx new file mode 100644 index 000000000000..65c17dc3ac09 --- /dev/null +++ b/ui-v2/src/components/concurrency/task-run-concurrenct-view/dialogs/create-dialog.test.tsx @@ -0,0 +1,48 @@ +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { createWrapper } from "@tests/utils"; +import { beforeAll, describe, expect, it, vi } from "vitest"; + +import { CreateLimitDialog } from "./create-dialog"; + +const MOCK_DATA = { + id: "0", + created: "2021-01-01T00:00:00Z", + updated: "2021-01-01T00:00:00Z", + tag: "my tag 0", + concurrency_limit: 1, + active_slots: [] as Array, +}; + +describe("CreateLimitDialog", () => { + beforeAll(() => { + class ResizeObserverMock { + observe() {} + unobserve() {} + disconnect() {} + } + global.ResizeObserver = ResizeObserverMock; + }); + it.skip("calls onSubmit upon entering form data", async () => { + const user = userEvent.setup(); + + // ------------ Setup + const mockOnSubmitFn = vi.fn(); + render( + , + { wrapper: createWrapper() }, + ); + + // ------------ Act + await user.type(screen.getByLabelText(/tag/i), MOCK_DATA.tag); + await user.type( + screen.getByLabelText("Concurrency Limit"), + String(MOCK_DATA.concurrency_limit), + ); + + await user.click(screen.getByRole("button", { name: /add/i })); + + // ------------ Assert + expect(mockOnSubmitFn).toHaveBeenCalledOnce(); + }); +}); diff --git a/ui-v2/src/components/concurrency/task-run-concurrenct-view/dialogs/delete-dialog.test.tsx b/ui-v2/src/components/concurrency/task-run-concurrenct-view/dialogs/delete-dialog.test.tsx new file mode 100644 index 000000000000..2eae080d669e --- /dev/null +++ b/ui-v2/src/components/concurrency/task-run-concurrenct-view/dialogs/delete-dialog.test.tsx @@ -0,0 +1,41 @@ +import { 
DeleteLimitDialog } from "./delete-dialog"; + +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { createWrapper } from "@tests/utils"; +import { expect, test, vi } from "vitest"; + +const MOCK_DATA = { + id: "0", + created: "2021-01-01T00:00:00Z", + updated: "2021-01-01T00:00:00Z", + tag: "my tag 0", + concurrency_limit: 1, + active_slots: [] as Array, +}; + +test("DeleteLimitDialog can successfully call delete", async () => { + const user = userEvent.setup(); + + // ------------ Setup + const mockOnDeleteFn = vi.fn(); + render( + , + { wrapper: createWrapper() }, + ); + + // ------------ Act + expect(screen.getByRole("heading", { name: /delete concurrency limit/i })); + await user.click( + screen.getByRole("button", { + name: /delete/i, + }), + ); + + // ------------ Assert + expect(mockOnDeleteFn).toHaveBeenCalledOnce(); +}); diff --git a/ui-v2/src/components/concurrency/task-run-concurrenct-view/dialogs/reset-dialog.test.tsx b/ui-v2/src/components/concurrency/task-run-concurrenct-view/dialogs/reset-dialog.test.tsx new file mode 100644 index 000000000000..fc789ac15a75 --- /dev/null +++ b/ui-v2/src/components/concurrency/task-run-concurrenct-view/dialogs/reset-dialog.test.tsx @@ -0,0 +1,45 @@ +import { ResetLimitDialog } from "./reset-dialog"; + +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { createWrapper } from "@tests/utils"; +import { expect, test, vi } from "vitest"; + +const MOCK_DATA = { + id: "0", + created: "2021-01-01T00:00:00Z", + updated: "2021-01-01T00:00:00Z", + tag: "my tag 0", + concurrency_limit: 1, + active_slots: [] as Array, +}; + +test("ResetLimitDialog can successfully call delete", async () => { + const user = userEvent.setup(); + + // ------------ Setup + const mockOnResetFn = vi.fn(); + render( + , + { wrapper: createWrapper() }, + ); + + // ------------ Act + expect( + screen.getByRole("heading", 
{ + name: /reset concurrency limit for tag my tag 0/i, + }), + ); + await user.click( + screen.getByRole("button", { + name: /reset/i, + }), + ); + + // ------------ Assert + expect(mockOnResetFn).toHaveBeenCalledOnce(); +}); diff --git a/ui-v2/src/components/concurrency/task-run-concurrenct-view/header/index.ts b/ui-v2/src/components/concurrency/task-run-concurrenct-view/header/index.ts new file mode 100644 index 000000000000..5c98b17a6b83 --- /dev/null +++ b/ui-v2/src/components/concurrency/task-run-concurrenct-view/header/index.ts @@ -0,0 +1 @@ +export { TaskRunConcurrencyLimitsHeader } from "./task-run-concurrency-limit-header"; diff --git a/ui-v2/src/components/concurrency/task-run-concurrenct-view/header.tsx b/ui-v2/src/components/concurrency/task-run-concurrenct-view/header/task-run-concurrency-limit-header.tsx similarity index 100% rename from ui-v2/src/components/concurrency/task-run-concurrenct-view/header.tsx rename to ui-v2/src/components/concurrency/task-run-concurrenct-view/header/task-run-concurrency-limit-header.tsx diff --git a/ui-v2/src/components/concurrency/task-run-concurrenct-view/header/task-run-conrrency-limits-header.test.tsx b/ui-v2/src/components/concurrency/task-run-concurrenct-view/header/task-run-conrrency-limits-header.test.tsx new file mode 100644 index 000000000000..6a8198620b2f --- /dev/null +++ b/ui-v2/src/components/concurrency/task-run-concurrenct-view/header/task-run-conrrency-limits-header.test.tsx @@ -0,0 +1,26 @@ +import { TaskRunConcurrencyLimitsHeader } from "./task-run-concurrency-limit-header"; + +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { expect, test, vi } from "vitest"; + +test("TaskRunConcurrencyLimitsHeader can successfully call onAdd", async () => { + const user = userEvent.setup(); + + // ------------ Setup + const mockOnAddFn = vi.fn(); + render(); + + // ------------ Act + expect( + screen.getByRole("heading", { name: /task run 
concurrency limits/i }), + ).toBeVisible(); + await user.click( + screen.getByRole("button", { + name: /add task run concurrency limit/i, + }), + ); + + // ------------ Assert + expect(mockOnAddFn).toHaveBeenCalledOnce(); +}); From a8e098098cff4c67714c432fd2c17f4a96f7a84d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ladislav=20G=C3=A1l?= <129292521+GalLadislav@users.noreply.github.com> Date: Thu, 12 Dec 2024 17:14:54 +0100 Subject: [PATCH 4/8] Fix missing terminal state timings for TimedOut tasks (#16328) --- src/prefect/task_engine.py | 2 ++ tests/test_task_engine.py | 52 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index a07f2108c95a..d6b834d6bca6 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -600,6 +600,7 @@ def handle_timeout(self, exc: TimeoutError) -> None: message=message, name="TimedOut", ) + self.record_terminal_state_timing(state) self.set_state(state) self._raised = exc @@ -1134,6 +1135,7 @@ async def handle_timeout(self, exc: TimeoutError) -> None: message=message, name="TimedOut", ) + self.record_terminal_state_timing(state) await self.set_state(state) self._raised = exc self._telemetry.end_span_on_failure(state.message) diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py index 185045fb9455..c4f1847b160c 100644 --- a/tests/test_task_engine.py +++ b/tests/test_task_engine.py @@ -1511,6 +1511,58 @@ async def foo(): assert run.end_time == failed.timestamp assert run.total_run_time == failed.timestamp - running.timestamp + async def test_sync_task_sets_end_time_on_failed_timedout( + self, prefect_client, events_pipeline + ): + ID = None + + @task + def foo(): + nonlocal ID + ID = TaskRunContext.get().task_run.id + raise TimeoutError + + with pytest.raises(TimeoutError): + run_task_sync(foo) + + await events_pipeline.process_events() + + run = await prefect_client.read_task_run(ID) + + states = await 
prefect_client.read_task_run_states(ID) + running = [state for state in states if state.type == StateType.RUNNING][0] + failed = [state for state in states if state.type == StateType.FAILED][0] + + assert failed.name == "TimedOut" + assert run.end_time + assert run.end_time == failed.timestamp + assert run.total_run_time == failed.timestamp - running.timestamp + + async def test_async_task_sets_end_time_on_failed_timedout( + self, prefect_client, events_pipeline + ): + ID = None + + @task + async def foo(): + nonlocal ID + ID = TaskRunContext.get().task_run.id + raise TimeoutError + + with pytest.raises(TimeoutError): + await run_task_async(foo) + + await events_pipeline.process_events() + run = await prefect_client.read_task_run(ID) + states = await prefect_client.read_task_run_states(ID) + running = [state for state in states if state.type == StateType.RUNNING][0] + failed = [state for state in states if state.type == StateType.FAILED][0] + + assert failed.name == "TimedOut" + assert run.end_time + assert run.end_time == failed.timestamp + assert run.total_run_time == failed.timestamp - running.timestamp + async def test_sync_task_sets_end_time_on_crashed( self, prefect_client, events_pipeline ): From dce8eca583a2322d96aef3acba5988a6e6642bf7 Mon Sep 17 00:00:00 2001 From: Jean Luciano Date: Thu, 12 Dec 2024 10:52:32 -0600 Subject: [PATCH 5/8] Update flow run instrumentation to use `RunTelemetry` class (#16233) Co-authored-by: Chris Pickett --- src/prefect/flow_engine.py | 146 +++++------ src/prefect/server/models/flow_runs.py | 3 + src/prefect/task_engine.py | 54 ++-- src/prefect/telemetry/run_telemetry.py | 95 +++++-- tests/telemetry/test_instrumentation.py | 253 +++++++++++++++++- tests/test_flow_engine.py | 328 +----------------------- 6 files changed, 413 insertions(+), 466 deletions(-) diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py index c37154a09cdf..45a18c35246f 100644 --- a/src/prefect/flow_engine.py +++ b/src/prefect/flow_engine.py 
@@ -24,10 +24,8 @@ from anyio import CancelScope from opentelemetry import propagate, trace -from opentelemetry.trace import Tracer, get_tracer from typing_extensions import ParamSpec -import prefect from prefect import Task from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client from prefect.client.schemas import FlowRun, TaskRun @@ -72,7 +70,12 @@ exception_to_failed_state, return_value_to_state, ) -from prefect.telemetry.run_telemetry import OTELSetter +from prefect.telemetry.run_telemetry import ( + LABELS_TRACEPARENT_KEY, + TRACEPARENT_KEY, + OTELSetter, + RunTelemetry, +) from prefect.types import KeyValueLabels from prefect.utilities._engine import get_hook_name, resolve_custom_flow_run_name from prefect.utilities.annotations import NotSet @@ -95,8 +98,6 @@ P = ParamSpec("P") R = TypeVar("R") -LABELS_TRACEPARENT_KEY = "__OTEL_TRACEPARENT" -TRACEPARENT_KEY = "traceparent" class FlowRunTimeoutError(TimeoutError): @@ -136,10 +137,7 @@ class BaseFlowRunEngine(Generic[P, R]): _is_started: bool = False short_circuit: bool = False _flow_run_name_set: bool = False - _tracer: Tracer = field( - default_factory=lambda: get_tracer("prefect", prefect.__version__) - ) - _span: Optional[trace.Span] = None + _telemetry: RunTelemetry = field(default_factory=RunTelemetry) def __post_init__(self): if self.flow is None and self.flow_run_id is None: @@ -152,21 +150,6 @@ def __post_init__(self): def state(self) -> State: return self.flow_run.state # type: ignore - def _end_span_on_success(self): - if not self._span: - return - self._span.set_status(trace.Status(trace.StatusCode.OK)) - self._span.end(time.time_ns()) - self._span = None - - def _end_span_on_error(self, exc: BaseException, description: Optional[str]): - if not self._span: - return - self._span.record_exception(exc) - self._span.set_status(trace.Status(trace.StatusCode.ERROR, description)) - self._span.end(time.time_ns()) - self._span = None - def is_running(self) -> bool: if 
getattr(self, "flow_run", None) is None: return False @@ -185,6 +168,7 @@ def _update_otel_labels( self, span: trace.Span, client: Union[SyncPrefectClient, PrefectClient] ): parent_flow_run_ctx = FlowRunContext.get() + if parent_flow_run_ctx and parent_flow_run_ctx.flow_run: if traceparent := parent_flow_run_ctx.flow_run.labels.get( LABELS_TRACEPARENT_KEY @@ -194,6 +178,7 @@ def _update_otel_labels( carrier={TRACEPARENT_KEY: traceparent}, setter=OTELSetter(), ) + else: carrier: KeyValueLabels = {} propagate.get_global_textmap().inject( @@ -315,16 +300,7 @@ def set_state(self, state: State, force: bool = False) -> State: self.flow_run.state_name = state.name # type: ignore self.flow_run.state_type = state.type # type: ignore - if self._span: - self._span.add_event( - state.name or state.type, - { - "prefect.state.message": state.message or "", - "prefect.state.type": state.type, - "prefect.state.name": state.name or state.type, - "prefect.state.id": str(state.id), - }, - ) + self._telemetry.update_state(state) return state def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": @@ -374,7 +350,7 @@ def handle_success(self, result: R) -> R: self.set_state(terminal_state) self._return_value = resolved_result - self._end_span_on_success() + self._telemetry.end_span_on_success() return result @@ -406,8 +382,8 @@ def handle_exception( ) state = self.set_state(Running()) self._raised = exc - - self._end_span_on_error(exc, state.message) + self._telemetry.record_exception(exc) + self._telemetry.end_span_on_failure(state.message) return state @@ -426,8 +402,8 @@ def handle_timeout(self, exc: TimeoutError) -> None: ) self.set_state(state) self._raised = exc - - self._end_span_on_error(exc, message) + self._telemetry.record_exception(exc) + self._telemetry.end_span_on_failure(message) def handle_crash(self, exc: BaseException) -> None: state = run_coro_as_sync(exception_to_crashed_state(exc)) @@ -435,8 +411,8 @@ def handle_crash(self, exc: BaseException) -> 
None: self.logger.debug("Crash details:", exc_info=exc) self.set_state(state, force=True) self._raised = exc - - self._end_span_on_error(exc, state.message if state else "") + self._telemetry.record_exception(exc) + self._telemetry.end_span_on_failure(state.message if state else None) def load_subflow_run( self, @@ -680,20 +656,23 @@ def initialize_run(self): flow_version=self.flow.version, empirical_policy=self.flow_run.empirical_policy, ) - - span = self._tracer.start_span( - name=self.flow_run.name, - attributes={ - **self.flow_run.labels, - "prefect.run.type": "flow", - "prefect.run.id": str(self.flow_run.id), - "prefect.tags": self.flow_run.tags, - "prefect.flow.name": self.flow.name, - }, + parent_flow_run = FlowRunContext.get() + parent_labels = {} + if parent_flow_run and parent_flow_run.flow_run: + parent_labels = parent_flow_run.flow_run.labels + + self._telemetry.start_span( + name=self.flow.name, + run=self.flow_run, + parameters=self.parameters, + parent_labels=parent_labels, ) - self._update_otel_labels(span, self.client) - - self._span = span + carrier = self._telemetry.propagate_traceparent() + if carrier: + self.client.update_flow_run_labels( + flow_run_id=self.flow_run.id, + labels={LABELS_TRACEPARENT_KEY: carrier[TRACEPARENT_KEY]}, + ) try: yield self @@ -736,7 +715,9 @@ def initialize_run(self): @contextmanager def start(self) -> Generator[None, None, None]: with self.initialize_run(): - with trace.use_span(self._span) if self._span else nullcontext(): + with trace.use_span( + self._telemetry.span + ) if self._telemetry.span else nullcontext(): self.begin_run() if self.state.is_running(): @@ -892,16 +873,7 @@ async def set_state(self, state: State, force: bool = False) -> State: self.flow_run.state_name = state.name # type: ignore self.flow_run.state_type = state.type # type: ignore - if self._span: - self._span.add_event( - state.name or state.type, - { - "prefect.state.message": state.message or "", - "prefect.state.type": state.type, - 
"prefect.state.name": state.name or state.type, - "prefect.state.id": str(state.id), - }, - ) + self._telemetry.update_state(state) return state async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": @@ -949,7 +921,7 @@ async def handle_success(self, result: R) -> R: await self.set_state(terminal_state) self._return_value = resolved_result - self._end_span_on_success() + self._telemetry.end_span_on_success() return result @@ -979,8 +951,8 @@ async def handle_exception( ) state = await self.set_state(Running()) self._raised = exc - - self._end_span_on_error(exc, state.message) + self._telemetry.record_exception(exc) + self._telemetry.end_span_on_failure(state.message) return state @@ -1000,7 +972,8 @@ async def handle_timeout(self, exc: TimeoutError) -> None: await self.set_state(state) self._raised = exc - self._end_span_on_error(exc, message) + self._telemetry.record_exception(exc) + self._telemetry.end_span_on_failure(message) async def handle_crash(self, exc: BaseException) -> None: # need to shield from asyncio cancellation to ensure we update the state @@ -1012,7 +985,8 @@ async def handle_crash(self, exc: BaseException) -> None: await self.set_state(state, force=True) self._raised = exc - self._end_span_on_error(exc, state.message) + self._telemetry.record_exception(exc) + self._telemetry.end_span_on_failure(state.message) async def load_subflow_run( self, @@ -1254,19 +1228,23 @@ async def initialize_run(self): flow_version=self.flow.version, empirical_policy=self.flow_run.empirical_policy, ) - - span = self._tracer.start_span( - name=self.flow_run.name, - attributes={ - **self.flow_run.labels, - "prefect.run.type": "flow", - "prefect.run.id": str(self.flow_run.id), - "prefect.tags": self.flow_run.tags, - "prefect.flow.name": self.flow.name, - }, + parent_flow_run = FlowRunContext.get() + parent_labels = {} + if parent_flow_run and parent_flow_run.flow_run: + parent_labels = parent_flow_run.flow_run.labels + + 
self._telemetry.start_span( + name=self.flow.name, + run=self.flow_run, + parameters=self.parameters, + parent_labels=parent_labels, ) - self._update_otel_labels(span, self.client) - self._span = span + carrier = self._telemetry.propagate_traceparent() + if carrier: + await self.client.update_flow_run_labels( + flow_run_id=self.flow_run.id, + labels={LABELS_TRACEPARENT_KEY: carrier[TRACEPARENT_KEY]}, + ) try: yield self @@ -1309,7 +1287,9 @@ async def initialize_run(self): @asynccontextmanager async def start(self) -> AsyncGenerator[None, None]: async with self.initialize_run(): - with trace.use_span(self._span) if self._span else nullcontext(): + with trace.use_span( + self._telemetry.span + ) if self._telemetry.span else nullcontext(): await self.begin_run() if self.state.is_running(): diff --git a/src/prefect/server/models/flow_runs.py b/src/prefect/server/models/flow_runs.py index 5db2ff750956..a454fff8f1d4 100644 --- a/src/prefect/server/models/flow_runs.py +++ b/src/prefect/server/models/flow_runs.py @@ -52,6 +52,9 @@ logger = get_logger("flow_runs") +logger = get_logger("flow_runs") + + T = TypeVar("T", bound=tuple) diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index d6b834d6bca6..15053d5016ae 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -4,7 +4,7 @@ import threading import time from asyncio import CancelledError -from contextlib import ExitStack, asynccontextmanager, contextmanager +from contextlib import ExitStack, asynccontextmanager, contextmanager, nullcontext from dataclasses import dataclass, field from functools import partial from textwrap import dedent @@ -523,7 +523,7 @@ def handle_success(self, result: R, transaction: Transaction) -> R: self.set_state(terminal_state) self._return_value = result - self._telemetry.end_span_on_success(terminal_state.message) + self._telemetry.end_span_on_success() return result def handle_retry(self, exc: Exception) -> bool: @@ -586,7 +586,7 @@ def 
handle_exception(self, exc: Exception) -> None: self.record_terminal_state_timing(state) self.set_state(state) self._raised = exc - self._telemetry.end_span_on_failure(state.message) + self._telemetry.end_span_on_failure(state.message if state else None) def handle_timeout(self, exc: TimeoutError) -> None: if not self.handle_retry(exc): @@ -612,7 +612,7 @@ def handle_crash(self, exc: BaseException) -> None: self.set_state(state, force=True) self._raised = exc self._telemetry.record_exception(exc) - self._telemetry.end_span_on_failure(state.message) + self._telemetry.end_span_on_failure(state.message if state else None) @contextmanager def setup_run_context(self, client: Optional[SyncPrefectClient] = None): @@ -670,7 +670,7 @@ def initialize_run( with SyncClientContext.get_or_create() as client_ctx: self._client = client_ctx.client self._is_started = True - flow_run_context = FlowRunContext.get() + parent_flow_run_context = FlowRunContext.get() parent_task_run_context = TaskRunContext.get() try: @@ -679,7 +679,7 @@ def initialize_run( self.task.create_local_run( id=task_run_id, parameters=self.parameters, - flow_run_context=flow_run_context, + flow_run_context=parent_flow_run_context, parent_task_run_context=parent_task_run_context, wait_for=self.wait_for, extra_task_inputs=dependencies, @@ -697,11 +697,16 @@ def initialize_run( self.logger.debug( f"Created task run {self.task_run.name!r} for task {self.task.name!r}" ) - labels = ( - flow_run_context.flow_run.labels if flow_run_context else {} - ) + + parent_labels = {} + if parent_flow_run_context and parent_flow_run_context.flow_run: + parent_labels = parent_flow_run_context.flow_run.labels + self._telemetry.start_span( - self.task_run, self.parameters, labels + run=self.task_run, + name=self.task.name, + parameters=self.parameters, + parent_labels=parent_labels, ) yield self @@ -755,7 +760,9 @@ def start( dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, ) -> Generator[None, None, None]: with 
self.initialize_run(task_run_id=task_run_id, dependencies=dependencies): - with trace.use_span(self._telemetry._span): + with trace.use_span( + self._telemetry.span + ) if self._telemetry.span else nullcontext(): self.begin_run() try: yield @@ -1058,7 +1065,7 @@ async def handle_success(self, result: R, transaction: Transaction) -> R: await self.set_state(terminal_state) self._return_value = result - self._telemetry.end_span_on_success(terminal_state.message) + self._telemetry.end_span_on_success() return result @@ -1206,15 +1213,16 @@ async def initialize_run( async with AsyncClientContext.get_or_create(): self._client = get_client() self._is_started = True - flow_run_context = FlowRunContext.get() + parent_flow_run_context = FlowRunContext.get() + parent_task_run_context = TaskRunContext.get() try: if not self.task_run: self.task_run = await self.task.create_local_run( id=task_run_id, parameters=self.parameters, - flow_run_context=flow_run_context, - parent_task_run_context=TaskRunContext.get(), + flow_run_context=parent_flow_run_context, + parent_task_run_context=parent_task_run_context, wait_for=self.wait_for, extra_task_inputs=dependencies, ) @@ -1231,11 +1239,15 @@ async def initialize_run( f"Created task run {self.task_run.name!r} for task {self.task.name!r}" ) - labels = ( - flow_run_context.flow_run.labels if flow_run_context else {} - ) + parent_labels = {} + if parent_flow_run_context and parent_flow_run_context.flow_run: + parent_labels = parent_flow_run_context.flow_run.labels + self._telemetry.start_span( - self.task_run, self.parameters, labels + run=self.task_run, + name=self.task.name, + parameters=self.parameters, + parent_labels=parent_labels, ) yield self @@ -1291,7 +1303,9 @@ async def start( async with self.initialize_run( task_run_id=task_run_id, dependencies=dependencies ): - with trace.use_span(self._telemetry._span): + with trace.use_span( + self._telemetry.span + ) if self._telemetry.span else nullcontext(): await self.begin_run() try: 
yield diff --git a/src/prefect/telemetry/run_telemetry.py b/src/prefect/telemetry/run_telemetry.py index 08de1a2ebd0b..bb7cc81de5f9 100644 --- a/src/prefect/telemetry/run_telemetry.py +++ b/src/prefect/telemetry/run_telemetry.py @@ -1,22 +1,28 @@ import time from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from opentelemetry import propagate, trace from opentelemetry.propagators.textmap import Setter from opentelemetry.trace import ( + Span, Status, StatusCode, get_tracer, ) import prefect -from prefect.client.schemas import TaskRun +from prefect.client.schemas import FlowRun, TaskRun from prefect.client.schemas.objects import State +from prefect.context import FlowRunContext from prefect.types import KeyValueLabels if TYPE_CHECKING: from opentelemetry.trace import Tracer +LABELS_TRACEPARENT_KEY = "__OTEL_TRACEPARENT" +TRACEPARENT_KEY = "traceparent" + class OTELSetter(Setter[KeyValueLabels]): """ @@ -36,67 +42,74 @@ class RunTelemetry: _tracer: "Tracer" = field( default_factory=lambda: get_tracer("prefect", prefect.__version__) ) - _span = None + span: Optional[Span] = None def start_span( self, - task_run: TaskRun, + run: Union[TaskRun, FlowRun], + name: Optional[str] = None, parameters: Optional[Dict[str, Any]] = None, - labels: Optional[Dict[str, Any]] = None, + parent_labels: Optional[Dict[str, Any]] = None, ): """ Start a span for a task run. 
""" if parameters is None: parameters = {} - if labels is None: - labels = {} + if parent_labels is None: + parent_labels = {} parameter_attributes = { f"prefect.run.parameter.{k}": type(v).__name__ for k, v in parameters.items() } - self._span = self._tracer.start_span( - name=task_run.name, + run_type = "task" if isinstance(run, TaskRun) else "flow" + + self.span = self._tracer.start_span( + name=name or run.name, attributes={ - "prefect.run.type": "task", - "prefect.run.id": str(task_run.id), - "prefect.tags": task_run.tags, + f"prefect.{run_type}.name": name or run.name, + "prefect.run.type": run_type, + "prefect.run.id": str(run.id), + "prefect.tags": run.tags, **parameter_attributes, - **labels, + **parent_labels, }, ) + return self.span - def end_span_on_success(self, terminal_message: str) -> None: + def end_span_on_success(self) -> None: """ End a span for a task run on success. """ - if self._span: - self._span.set_status(Status(StatusCode.OK), terminal_message) - self._span.end(time.time_ns()) - self._span = None + if self.span: + self.span.set_status(Status(StatusCode.OK)) + self.span.end(time.time_ns()) + self.span = None - def end_span_on_failure(self, terminal_message: str) -> None: + def end_span_on_failure(self, terminal_message: Optional[str] = None) -> None: """ End a span for a task run on failure. """ - if self._span: - self._span.set_status(Status(StatusCode.ERROR, terminal_message)) - self._span.end(time.time_ns()) - self._span = None + if self.span: + self.span.set_status( + Status(StatusCode.ERROR, terminal_message or "Run failed") + ) + self.span.end(time.time_ns()) + self.span = None - def record_exception(self, exc: Exception) -> None: + def record_exception(self, exc: BaseException) -> None: """ Record an exception on a span. """ - if self._span: - self._span.record_exception(exc) + if self.span: + self.span.record_exception(exc) def update_state(self, new_state: State) -> None: """ Update a span with the state of a task run. 
""" - if self._span: - self._span.add_event( + if self.span: + self.span.add_event( new_state.name or new_state.type, { "prefect.state.message": new_state.message or "", @@ -105,3 +118,29 @@ def update_state(self, new_state: State) -> None: "prefect.state.id": str(new_state.id), }, ) + + def propagate_traceparent(self) -> Optional[KeyValueLabels]: + """ + Propagate a traceparent to a span. + """ + parent_flow_run_ctx = FlowRunContext.get() + + if parent_flow_run_ctx and parent_flow_run_ctx.flow_run: + if traceparent := parent_flow_run_ctx.flow_run.labels.get( + LABELS_TRACEPARENT_KEY + ): + carrier: KeyValueLabels = {TRACEPARENT_KEY: traceparent} + propagate.get_global_textmap().inject( + carrier={TRACEPARENT_KEY: traceparent}, + setter=OTELSetter(), + ) + return carrier + else: + if self.span: + carrier: KeyValueLabels = {} + propagate.get_global_textmap().inject( + carrier, + context=trace.set_span_in_context(self.span), + setter=OTELSetter(), + ) + return carrier diff --git a/tests/telemetry/test_instrumentation.py b/tests/telemetry/test_instrumentation.py index 6fae35a96895..ecb7377be899 100644 --- a/tests/telemetry/test_instrumentation.py +++ b/tests/telemetry/test_instrumentation.py @@ -1,4 +1,5 @@ import os +from typing import Literal from uuid import UUID, uuid4 import pytest @@ -13,7 +14,10 @@ from opentelemetry.sdk.trace import TracerProvider from tests.telemetry.instrumentation_tester import InstrumentationTester +import prefect from prefect import flow, task +from prefect.client.orchestration import SyncPrefectClient +from prefect.context import FlowRunContext from prefect.task_engine import ( run_task_async, run_task_sync, @@ -170,9 +174,215 @@ def test_logger_provider( assert log_handler._logger_provider == logger_provider +class TestFlowRunInstrumentation: + @pytest.fixture(params=["async", "sync"]) + async def engine_type( + self, request: pytest.FixtureRequest + ) -> Literal["async", "sync"]: + return request.param + + async def 
test_flow_run_creates_and_stores_otel_traceparent( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + sync_prefect_client: SyncPrefectClient, + ): + """Test that when no parent traceparent exists, the flow run stores its own span's traceparent""" + + @flow(name="child-flow") + async def async_child_flow() -> str: + return "hello from child" + + @flow(name="child-flow") + def sync_child_flow() -> str: + return "hello from child" + + @flow(name="parent-flow") + async def async_parent_flow() -> str: + return await async_child_flow() + + @flow(name="parent-flow") + def sync_parent_flow() -> str: + return sync_child_flow() + + if engine_type == "async": + await async_parent_flow() + else: + sync_parent_flow() + + spans = instrumentation.get_finished_spans() + + next( + span + for span in spans + if span.attributes.get("prefect.flow.name") == "parent-flow" + ) + child_span = next( + span + for span in spans + if span.attributes.get("prefect.flow.name") == "child-flow" + ) + + # Get the child flow run + child_flow_run_id = child_span.attributes.get("prefect.run.id") + child_flow_run = sync_prefect_client.read_flow_run(UUID(child_flow_run_id)) + + # Verify the child flow run has its span's traceparent in its labels + assert "__OTEL_TRACEPARENT" in child_flow_run.labels + assert child_flow_run.labels["__OTEL_TRACEPARENT"].startswith("00-") + trace_id_hex = child_flow_run.labels["__OTEL_TRACEPARENT"].split("-")[1] + assert int(trace_id_hex, 16) == child_span.context.trace_id + + async def test_flow_run_propagates_otel_traceparent_to_subflow( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + ): + """Test that OTEL traceparent gets propagated from parent flow to child flow""" + + @flow(name="child-flow") + async def async_child_flow() -> str: + return "hello from child" + + @flow(name="child-flow") + def sync_child_flow() -> str: + return "hello from child" + + @flow(name="parent-flow") + 
async def async_parent_flow() -> str: + # Set OTEL context in the parent flow's labels + flow_run = FlowRunContext.get().flow_run + mock_traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + flow_run.labels["__OTEL_TRACEPARENT"] = mock_traceparent + return await async_child_flow() + + @flow(name="parent-flow") + def sync_parent_flow() -> str: + # Set OTEL context in the parent flow's labels + flow_run = FlowRunContext.get().flow_run + mock_traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + flow_run.labels["__OTEL_TRACEPARENT"] = mock_traceparent + return sync_child_flow() + + parent_flow = async_parent_flow if engine_type == "async" else sync_parent_flow + await parent_flow() if engine_type == "async" else parent_flow() + + spans = instrumentation.get_finished_spans() + + parent_span = next( + span + for span in spans + if span.attributes.get("prefect.flow.name") == "parent-flow" + ) + child_span = next( + span + for span in spans + if span.attributes.get("prefect.flow.name") == "child-flow" + ) + + assert parent_span is not None + assert child_span is not None + assert child_span.context.trace_id == parent_span.context.trace_id + + async def test_flow_run_instrumentation( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + ): + @flow(name="instrumented-flow") + async def async_flow() -> str: + return 42 + + @flow(name="instrumented-flow") + def sync_flow() -> str: + return 42 + + test_flow = async_flow if engine_type == "async" else sync_flow + await test_flow() if engine_type == "async" else test_flow() + + spans = instrumentation.get_finished_spans() + assert len(spans) == 1 + + span = spans[0] + assert span is not None + instrumentation.assert_span_instrumented_for(span, prefect) + + instrumentation.assert_has_attributes( + span, + { + "prefect.flow.name": "instrumented-flow", + "prefect.run.type": "flow", + }, + ) + + async def test_flow_run_inherits_parent_labels( + self, + 
engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + sync_prefect_client: SyncPrefectClient, + ): + """Test that parent flow labels get propagated to child flow spans""" + + @flow(name="child-flow") + async def async_child_flow() -> str: + return "hello from child" + + @flow(name="child-flow") + def sync_child_flow() -> str: + return "hello from child" + + @flow(name="parent-flow") + async def async_parent_flow() -> str: + # Set custom labels in parent flow + flow_run = FlowRunContext.get().flow_run + flow_run.labels.update( + {"test.label": "test-value", "environment": "testing"} + ) + return await async_child_flow() + + @flow(name="parent-flow") + def sync_parent_flow() -> str: + # Set custom labels in parent flow + flow_run = FlowRunContext.get().flow_run + flow_run.labels.update( + {"test.label": "test-value", "environment": "testing"} + ) + return sync_child_flow() + + if engine_type == "async": + state = await async_parent_flow(return_state=True) + else: + state = sync_parent_flow(return_state=True) + + spans = instrumentation.get_finished_spans() + child_spans = [ + span + for span in spans + if span.attributes.get("prefect.flow.name") == "child-flow" + ] + assert len(child_spans) == 1 + + # Get the parent flow run + parent_flow_run = sync_prefect_client.read_flow_run( + state.state_details.flow_run_id + ) + + # Verify the child span has the parent flow's labels + instrumentation.assert_has_attributes( + child_spans[0], + { + **parent_flow_run.labels, + "prefect.run.type": "flow", + "prefect.flow.name": "child-flow", + }, + ) + + class TestTaskRunInstrumentation: @pytest.fixture(params=["async", "sync"]) - async def engine_type(self, request): + async def engine_type( + self, request: pytest.FixtureRequest + ) -> Literal["async", "sync"]: return request.param async def run_task(self, task, task_run_id, parameters, engine_type): @@ -184,7 +394,9 @@ async def run_task(self, task, task_run_id, parameters, engine_type): return 
run_task_sync(task, task_run_id=task_run_id, parameters=parameters) async def test_span_creation( - self, engine_type, instrumentation: InstrumentationTester + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, ): @task async def async_task(x: int, y: int): @@ -213,7 +425,11 @@ def sync_task(x: int, y: int): ) assert spans[0].name == task_fn.name - async def test_span_attributes(self, engine_type, instrumentation): + async def test_span_attributes( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + ): @task async def async_task(x: int, y: int): return x + y @@ -245,7 +461,11 @@ def sync_task(x: int, y: int): ) assert spans[0].name == task_fn.__name__ - async def test_span_events(self, engine_type, instrumentation): + async def test_span_events( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + ): @task async def async_task(x: int, y: int): return x + y @@ -270,7 +490,11 @@ def sync_task(x: int, y: int): assert events[0].name == "Running" assert events[1].name == "Completed" - async def test_span_status_on_success(self, engine_type, instrumentation): + async def test_span_status_on_success( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + ): @task async def async_task(x: int, y: int): return x + y @@ -293,7 +517,11 @@ def sync_task(x: int, y: int): assert len(spans) == 1 assert spans[0].status.status_code == trace.StatusCode.OK - async def test_span_status_on_failure(self, engine_type, instrumentation): + async def test_span_status_on_failure( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + ): @task async def async_task(x: int, y: int): raise ValueError("Test error") @@ -318,7 +546,11 @@ def sync_task(x: int, y: int): assert spans[0].status.status_code == trace.StatusCode.ERROR assert "Test error" in spans[0].status.description - async def 
test_span_exception_recording(self, engine_type, instrumentation): + async def test_span_exception_recording( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + ): @task async def async_task(x: int, y: int): raise Exception("Test error") @@ -347,7 +579,12 @@ def sync_task(x: int, y: int): assert exception_event.attributes["exception.type"] == "Exception" assert exception_event.attributes["exception.message"] == "Test error" - async def test_flow_labels(self, engine_type, instrumentation, sync_prefect_client): + async def test_flow_labels( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + sync_prefect_client: SyncPrefectClient, + ): """Test that parent flow ID gets propagated to task spans""" @task diff --git a/tests/test_flow_engine.py b/tests/test_flow_engine.py index 9807f8c219cd..9ceeae8241bd 100644 --- a/tests/test_flow_engine.py +++ b/tests/test_flow_engine.py @@ -10,9 +10,7 @@ import anyio import pydantic import pytest -from opentelemetry import trace -import prefect from prefect import Flow, __development_base_path__, flow, task from prefect.client.orchestration import PrefectClient, SyncPrefectClient from prefect.client.schemas.filters import FlowFilter, FlowRunFilter @@ -49,8 +47,6 @@ from prefect.utilities.callables import get_call_parameters from prefect.utilities.filesystem import tmpchdir -from .telemetry.instrumentation_tester import InstrumentationTester - @flow async def foo(): @@ -615,7 +611,7 @@ def my_flow(): # after a flow run retry, the stale value will be pulled from the cache. 
async def test_flow_retry_with_no_error_in_flow_and_one_failed_child_flow( - self, sync_prefect_client: SyncPrefectClient + self, sync_prefect_client ): child_run_count = 0 flow_run_count = 0 @@ -1853,325 +1849,3 @@ async def expensive_flow(): concurrency_limit_v2.name ) assert response.active_slots == 0 - - -class TestFlowRunInstrumentation: - def test_flow_run_instrumentation(self, instrumentation: InstrumentationTester): - @flow - def instrumented_flow(): - from prefect.states import Completed - - return Completed(message="The flow is with you") - - instrumented_flow() - - spans = instrumentation.get_finished_spans() - assert len(spans) == 1 - span = spans[0] - assert span is not None - instrumentation.assert_span_instrumented_for(span, prefect) - - instrumentation.assert_has_attributes( - span, - { - "prefect.run.type": "flow", - "prefect.tags": (), - "prefect.flow.name": "instrumented-flow", - "prefect.run.id": mock.ANY, - }, - ) - assert span.status.status_code == trace.StatusCode.OK - - assert len(span.events) == 2 - assert span.events[0].name == "Running" - instrumentation.assert_has_attributes( - span.events[0], - { - "prefect.state.message": "", - "prefect.state.type": StateType.RUNNING, - "prefect.state.name": "Running", - "prefect.state.id": mock.ANY, - }, - ) - - assert span.events[1].name == "Completed" - instrumentation.assert_has_attributes( - span.events[1], - { - "prefect.state.message": "The flow is with you", - "prefect.state.type": StateType.COMPLETED, - "prefect.state.name": "Completed", - "prefect.state.id": mock.ANY, - }, - ) - - def test_flow_run_instrumentation_captures_tags( - self, - instrumentation: InstrumentationTester, - ): - from prefect import tags - - @flow - def instrumented_flow(): - pass - - with tags("foo", "bar"): - instrumented_flow() - - spans = instrumentation.get_finished_spans() - assert len(spans) == 1 - span = spans[0] - assert span is not None - instrumentation.assert_span_instrumented_for(span, prefect) - - 
instrumentation.assert_has_attributes( - span, - { - "prefect.run.type": "flow", - "prefect.flow.name": "instrumented-flow", - "prefect.run.id": mock.ANY, - }, - ) - # listy span attributes are serialized to tuples -- order seems nondeterministic so ignore rather than flake - assert set(span.attributes.get("prefect.tags")) == {"foo", "bar"} # type: ignore - assert span.status.status_code == trace.StatusCode.OK - - def test_flow_run_instrumentation_captures_labels( - self, - instrumentation: InstrumentationTester, - sync_prefect_client: SyncPrefectClient, - ): - @flow - def instrumented_flow(): - pass - - state = instrumented_flow(return_state=True) - - assert state.state_details.flow_run_id is not None - flow_run = sync_prefect_client.read_flow_run(state.state_details.flow_run_id) - - spans = instrumentation.get_finished_spans() - assert len(spans) == 1 - span = spans[0] - assert span is not None - instrumentation.assert_span_instrumented_for(span, prefect) - - instrumentation.assert_has_attributes( - span, - { - **flow_run.labels, - "prefect.run.type": "flow", - "prefect.flow.name": "instrumented-flow", - "prefect.run.id": mock.ANY, - }, - ) - - def test_flow_run_instrumentation_on_exception( - self, instrumentation: InstrumentationTester - ): - @flow - def a_broken_flow(): - raise Exception("This flow broke!") - - with pytest.raises(Exception): - a_broken_flow() - - spans = instrumentation.get_finished_spans() - assert len(spans) == 1 - span = spans[0] - assert span is not None - instrumentation.assert_span_instrumented_for(span, prefect) - - instrumentation.assert_has_attributes( - span, - { - "prefect.run.type": "flow", - "prefect.tags": (), - "prefect.flow.name": "a-broken-flow", - "prefect.run.id": mock.ANY, - }, - ) - - assert span.status.status_code == trace.StatusCode.ERROR - assert ( - span.status.description - == "Flow run encountered an exception: Exception: This flow broke!" 
- ) - - assert len(span.events) == 3 - assert span.events[0].name == "Running" - instrumentation.assert_has_attributes( - span.events[0], - { - "prefect.state.message": "", - "prefect.state.type": StateType.RUNNING, - "prefect.state.name": "Running", - "prefect.state.id": mock.ANY, - }, - ) - - assert span.events[1].name == "Failed" - instrumentation.assert_has_attributes( - span.events[1], - { - "prefect.state.message": "Flow run encountered an exception: Exception: This flow broke!", - "prefect.state.type": StateType.FAILED, - "prefect.state.name": "Failed", - "prefect.state.id": mock.ANY, - }, - ) - - assert span.events[2].name == "exception" - instrumentation.assert_has_attributes( - span.events[2], - { - "exception.type": "Exception", - "exception.message": "This flow broke!", - "exception.stacktrace": mock.ANY, - "exception.escaped": "False", - }, - ) - - def test_flow_run_instrumentation_on_timeout( - self, instrumentation: InstrumentationTester - ): - @flow(timeout_seconds=0.1) - def a_slow_flow(): - time.sleep(1) - - with pytest.raises(TimeoutError): - a_slow_flow() - - spans = instrumentation.get_finished_spans() - assert len(spans) == 1 - span = spans[0] - assert span is not None - instrumentation.assert_span_instrumented_for(span, prefect) - - instrumentation.assert_has_attributes( - span, - { - "prefect.run.type": "flow", - "prefect.tags": (), - "prefect.flow.name": "a-slow-flow", - "prefect.run.id": mock.ANY, - }, - ) - - assert span.status.status_code == trace.StatusCode.ERROR - assert span.status.description == "Flow run exceeded timeout of 0.1 second(s)" - - assert len(span.events) == 3 - assert span.events[0].name == "Running" - instrumentation.assert_has_attributes( - span.events[0], - { - "prefect.state.message": "", - "prefect.state.type": StateType.RUNNING, - "prefect.state.name": "Running", - "prefect.state.id": mock.ANY, - }, - ) - - assert span.events[1].name == "TimedOut" - instrumentation.assert_has_attributes( - span.events[1], - { - 
"prefect.state.message": "Flow run exceeded timeout of 0.1 second(s)", - "prefect.state.type": StateType.FAILED, - "prefect.state.name": "TimedOut", - "prefect.state.id": mock.ANY, - }, - ) - - assert span.events[2].name == "exception" - instrumentation.assert_has_attributes( - span.events[2], - { - "exception.type": "prefect.flow_engine.FlowRunTimeoutError", - "exception.message": "Scope timed out after 0.1 second(s).", - "exception.stacktrace": mock.ANY, - "exception.escaped": "False", - }, - ) - - async def test_flow_run_propagates_otel_traceparent_to_subflow( - self, instrumentation: InstrumentationTester - ): - """Test that OTEL traceparent gets propagated from parent flow to child flow""" - - @flow - def child_flow(): - return "hello from child" - - @flow - def parent_flow(): - flow_run_ctx = FlowRunContext.get() - assert flow_run_ctx - assert flow_run_ctx.flow_run - flow_run = flow_run_ctx.flow_run - mock_traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" - flow_run.labels["__OTEL_TRACEPARENT"] = mock_traceparent - - return child_flow() - - parent_flow() - - spans = instrumentation.get_finished_spans() - - parent_span = next( - span - for span in spans - if span.attributes - and span.attributes.get("prefect.flow.name") == "parent-flow" - ) - child_span = next( - span - for span in spans - if span.attributes - and span.attributes.get("prefect.flow.name") == "child-flow" - ) - - assert parent_span is not None - assert child_span is not None - assert child_span.context and parent_span.context - assert child_span.context.trace_id == parent_span.context.trace_id - - async def test_flow_run_creates_and_stores_otel_traceparent( - self, instrumentation: InstrumentationTester, sync_prefect_client - ): - """Test that when no parent traceparent exists, the flow run stores its own span's traceparent""" - - @flow - def child_flow(): - return "hello from child" - - @flow - def parent_flow(): - return child_flow() - - parent_flow() - - spans = 
instrumentation.get_finished_spans() - - next( - span - for span in spans - if span.attributes - and span.attributes.get("prefect.flow.name") == "parent-flow" - ) - child_span = next( - span - for span in spans - if span.attributes - and span.attributes.get("prefect.flow.name") == "child-flow" - ) - - child_flow_run_id = child_span.attributes.get("prefect.run.id") - assert child_flow_run_id - child_flow_run = sync_prefect_client.read_flow_run(UUID(child_flow_run_id)) - - assert "__OTEL_TRACEPARENT" in child_flow_run.labels - assert child_flow_run.labels["__OTEL_TRACEPARENT"].startswith("00-") - trace_id_hex = child_flow_run.labels["__OTEL_TRACEPARENT"].split("-")[1] - assert int(trace_id_hex, 16) == child_span.context.trace_id From f4f596365233c412201daca4b89b242802f8ff63 Mon Sep 17 00:00:00 2001 From: nate nowack Date: Thu, 12 Dec 2024 11:46:25 -0600 Subject: [PATCH 6/8] improve `Dockerfile` build time and add CI to catch future slow downs (#16348) --- .dockerignore | 9 +- .github/workflows/time-docker-build.yaml | 110 +++++++++++++++++++++++ Dockerfile | 20 +++-- 3 files changed, 130 insertions(+), 9 deletions(-) create mode 100644 .github/workflows/time-docker-build.yaml diff --git a/.dockerignore b/.dockerignore index 30b96ab92540..a5fb96ae22d3 100644 --- a/.dockerignore +++ b/.dockerignore @@ -30,12 +30,13 @@ env/ venv/ # Documentation artifacts -docs/api-ref/schema.json +docs/ site/ # UI artifacts src/prefect/server/ui/* ui/node_modules +ui-v2/ # Databases *.db @@ -49,3 +50,9 @@ dask-worker-space/ # Editors .idea/ .vscode/ + +# Other +tests/ +compat-tests/ +benches/ +build/ diff --git a/.github/workflows/time-docker-build.yaml b/.github/workflows/time-docker-build.yaml new file mode 100644 index 000000000000..7156ac170aa2 --- /dev/null +++ b/.github/workflows/time-docker-build.yaml @@ -0,0 +1,110 @@ +name: Docker Build Time Benchmark + +on: + push: + paths: + - "Dockerfile" + - ".dockerignore" + pull_request: + paths: + - "Dockerfile" + - ".dockerignore" + 
+jobs: + benchmark: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + # For PRs, checkout the base branch to compare against + - name: Checkout base branch + if: github.base_ref + uses: actions/checkout@v4 + with: + ref: ${{ github.base_ref }} + clean: true + + - name: Clean Docker system + run: | + docker system prune -af + docker builder prune -af + + - name: Build base branch image + if: github.base_ref + id: base_build_time + run: | + start_time=$(date +%s) + DOCKER_BUILDKIT=1 docker build . --no-cache --progress=plain + end_time=$(date +%s) + base_time=$((end_time - start_time)) + echo "base_time=$base_time" >> $GITHUB_OUTPUT + + # For PRs, checkout back to the PR's HEAD + - name: Checkout PR branch + if: github.base_ref + uses: actions/checkout@v4 + with: + ref: ${{ github.head_ref }} + clean: true + + - name: Clean Docker system again + run: | + docker system prune -af + docker builder prune -af + + - name: Build and time Docker image + id: build_time + run: | + start_time=$(date +%s) + DOCKER_BUILDKIT=1 docker build . 
--no-cache --progress=plain + end_time=$(date +%s) + build_time=$((end_time - start_time)) + echo "build_time=$build_time" >> $GITHUB_OUTPUT + + - name: Compare build times + run: | + CURRENT_TIME=${{ steps.build_time.outputs.build_time }} + + if [ "${{ github.base_ref }}" != "" ]; then + BASE_TIME=${{ steps.base_build_time.outputs.base_time }} + + echo "### 🏗️ Docker Build Time Comparison" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "| Branch | Build Time | Difference |" >> $GITHUB_STEP_SUMMARY + echo "|--------|------------|------------|" >> $GITHUB_STEP_SUMMARY + echo "| base (${{ github.base_ref }}) | ${BASE_TIME}s | - |" >> $GITHUB_STEP_SUMMARY + + DIFF=$((CURRENT_TIME - BASE_TIME)) + PERCENT=$(echo "scale=2; ($CURRENT_TIME - $BASE_TIME) * 100 / $BASE_TIME" | bc) + + if [ $DIFF -gt 0 ]; then + DIFF_TEXT="⬆️ +${DIFF}s (+${PERCENT}%)" + elif [ $DIFF -lt 0 ]; then + DIFF_TEXT="⬇️ ${DIFF}s (${PERCENT}%)" + else + DIFF_TEXT="✨ No change" + fi + + echo "| current (${{ github.head_ref }}) | ${CURRENT_TIME}s | $DIFF_TEXT |" >> $GITHUB_STEP_SUMMARY + + # Fail if build time increased by more than 5% + if (( $(echo "$PERCENT > 5" | bc -l) )); then + echo "" >> $GITHUB_STEP_SUMMARY + echo "❌ **Build time increased by more than 5%!**" >> $GITHUB_STEP_SUMMARY + echo "This change significantly increases the build time. Please review the Dockerfile changes." >> $GITHUB_STEP_SUMMARY + exit 1 + elif (( $(echo "$PERCENT < 0" | bc -l) )); then + echo "" >> $GITHUB_STEP_SUMMARY + echo "✅ **Build time decreased!**" >> $GITHUB_STEP_SUMMARY + echo "Great job optimizing the build!" 
>> $GITHUB_STEP_SUMMARY + fi + else + echo "### 🏗️ Docker Build Time" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Build completed in ${CURRENT_TIME} seconds" >> $GITHUB_STEP_SUMMARY + fi diff --git a/Dockerfile b/Dockerfile index 9534db5507a4..542296331793 100644 --- a/Dockerfile +++ b/Dockerfile @@ -81,6 +81,9 @@ FROM ${BASE_IMAGE} AS final ENV LC_ALL=C.UTF-8 ENV LANG=C.UTF-8 +ENV UV_LINK_MODE=copy +ENV UV_SYSTEM_PYTHON=1 + LABEL maintainer="help@prefect.io" \ io.prefect.python-version=${PYTHON_VERSION} \ org.label-schema.schema-version="1.0" \ @@ -95,32 +98,33 @@ RUN apt-get update && \ tini=0.19.* \ build-essential \ git=1:2.* \ - curl \ - ca-certificates \ && apt-get clean && rm -rf /var/lib/apt/lists/* -# Install UV from official image -COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ +# Install UV from official image - pin to specific version for build caching +COPY --from=ghcr.io/astral-sh/uv:0.5.8 /uv /uvx /bin/ # Install dependencies using a temporary mount for requirements files RUN --mount=type=bind,source=requirements-client.txt,target=/tmp/requirements-client.txt \ --mount=type=bind,source=requirements.txt,target=/tmp/requirements.txt \ - uv pip install --system -r /tmp/requirements.txt + --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r /tmp/requirements.txt -r /tmp/requirements-client.txt # Install prefect from the sdist COPY --from=python-builder /opt/prefect/dist ./dist # Extras to include during installation ARG PREFECT_EXTRAS -RUN uv pip install --system "./dist/prefect.tar.gz${PREFECT_EXTRAS:-""}" && \ +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install "./dist/prefect.tar.gz${PREFECT_EXTRAS:-""}" && \ rm -rf dist/ # Remove setuptools -RUN uv pip uninstall --system setuptools +RUN uv pip uninstall setuptools # Install any extra packages ARG EXTRA_PIP_PACKAGES -RUN [ -z "${EXTRA_PIP_PACKAGES:-""}" ] || uv pip install --system "${EXTRA_PIP_PACKAGES}" +RUN 
--mount=type=cache,target=/root/.cache/uv \ + [ -z "${EXTRA_PIP_PACKAGES:-""}" ] || uv pip install "${EXTRA_PIP_PACKAGES}" # Smoke test RUN prefect version From f57b48c47dd2ecb53d9394a740147b06b54dbc94 Mon Sep 17 00:00:00 2001 From: Andrew Brookins Date: Thu, 12 Dec 2024 10:35:59 -0800 Subject: [PATCH 7/8] Add experimental support for emitting lineage events (#16242) --- .gitignore | 6 +- docs/v3/develop/settings-ref.mdx | 12 + schemas/settings.schema.json | 9 + src/prefect/_experimental/__init__.py | 0 src/prefect/_experimental/lineage.py | 181 +++++++++++ src/prefect/events/utilities.py | 2 + src/prefect/events/worker.py | 8 + src/prefect/results.py | 51 ++-- src/prefect/settings/models/experiments.py | 5 + src/prefect/testing/fixtures.py | 8 + tests/events/client/test_events_worker.py | 28 ++ tests/experimental/test_lineage.py | 339 +++++++++++++++++++++ tests/results/test_result_store.py | 71 +++++ tests/test_settings.py | 1 + 14 files changed, 702 insertions(+), 19 deletions(-) create mode 100644 src/prefect/_experimental/__init__.py create mode 100644 src/prefect/_experimental/lineage.py create mode 100644 tests/experimental/test_lineage.py diff --git a/.gitignore b/.gitignore index 127e8e8c02a0..c247225328a5 100644 --- a/.gitignore +++ b/.gitignore @@ -52,7 +52,6 @@ src/prefect/server/ui_build/* # API artifacts - # MacOS .DS_Store @@ -76,4 +75,7 @@ libcairo.2.dylib # setuptools-scm generated files src/integrations/*/**/_version.py -*.log \ No newline at end of file +*.log + +# Pyright type analysis report +prefect-analysis.json diff --git a/docs/v3/develop/settings-ref.mdx b/docs/v3/develop/settings-ref.mdx index 62fb0d353300..1d85829fd752 100644 --- a/docs/v3/develop/settings-ref.mdx +++ b/docs/v3/develop/settings-ref.mdx @@ -467,6 +467,18 @@ Enables sending telemetry to Prefect Cloud. **Supported environment variables**: `PREFECT_EXPERIMENTS_TELEMETRY_ENABLED` +### `lineage_events_enabled` +If `True`, enables emitting lineage events. 
Set to `False` to disable lineage event emission. + +**Type**: `boolean` + +**Default**: `False` + +**TOML dotted key path**: `experiments.lineage_events_enabled` + +**Supported environment variables**: +`PREFECT_EXPERIMENTS_LINEAGE_EVENTS_ENABLED` + --- ## FlowsSettings Settings for controlling flow behavior diff --git a/schemas/settings.schema.json b/schemas/settings.schema.json index 686c847b0d2e..dafc013944e1 100644 --- a/schemas/settings.schema.json +++ b/schemas/settings.schema.json @@ -327,6 +327,15 @@ ], "title": "Telemetry Enabled", "type": "boolean" + }, + "lineage_events_enabled": { + "default": false, + "description": "If `True`, enables emitting lineage events. Set to `False` to disable lineage event emission.", + "supported_environment_variables": [ + "PREFECT_EXPERIMENTS_LINEAGE_EVENTS_ENABLED" + ], + "title": "Lineage Events Enabled", + "type": "boolean" } }, "title": "ExperimentsSettings", diff --git a/src/prefect/_experimental/__init__.py b/src/prefect/_experimental/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/prefect/_experimental/lineage.py b/src/prefect/_experimental/lineage.py new file mode 100644 index 000000000000..b26474efffdc --- /dev/null +++ b/src/prefect/_experimental/lineage.py @@ -0,0 +1,181 @@ +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Union + +from prefect.events.related import related_resources_from_run_context +from prefect.events.schemas.events import RelatedResource, Resource +from prefect.events.utilities import emit_event +from prefect.settings import get_current_settings + +if TYPE_CHECKING: + from prefect.results import ResultStore + +UpstreamResources = Sequence[Union[RelatedResource, dict[str, str]]] +DownstreamResources = Sequence[Union[Resource, dict[str, str]]] + +# Map block types to their URI schemes +STORAGE_URI_SCHEMES = { + "local-file-system": "file://{path}", + "s3-bucket": "s3://{storage.bucket_name}/{path}", + "gcs-bucket": 
"gs://{storage.bucket}/{path}", + "azure-blob-storage": "azure-blob://{storage.container_name}/{path}", +} + + +def get_result_resource_uri( + store: "ResultStore", + key: str, +) -> Optional[str]: + """ + Generate a URI for a result based on its storage backend. + + Args: + store: A `ResultStore` instance. + key: The key of the result to generate a URI for. + """ + storage = store.result_storage + if storage is None: + return + + path = store._resolved_key_path(key) + + block_type = storage.get_block_type_slug() + if block_type and block_type in STORAGE_URI_SCHEMES: + return STORAGE_URI_SCHEMES[block_type].format(storage=storage, path=path) + + # Generic fallback + return f"prefect://{block_type}/{path}" + + +async def emit_lineage_event( + event_name: str, + upstream_resources: Optional[UpstreamResources] = None, + downstream_resources: Optional[DownstreamResources] = None, + direction_of_run_from_event: Literal["upstream", "downstream"] = "downstream", +) -> None: + """Emit lineage events showing relationships between resources. + + Args: + event_name: The name of the event to emit + upstream_resources: Optional list of RelatedResources that were upstream of + the event + downstream_resources: Optional list of Resources that were downstream + of the event + direction_of_run_from_event: The direction of the current run from + the event. E.g., if we're in a flow run and + `direction_of_run_from_event` is "downstream", then the flow run is + considered downstream of the resource's event. 
+ """ + from prefect.client.orchestration import get_client # Avoid a circular import + + if not get_current_settings().experiments.lineage_events_enabled: + return + + upstream_resources = list(upstream_resources) if upstream_resources else [] + downstream_resources = list(downstream_resources) if downstream_resources else [] + + async with get_client() as client: + related_resources = await related_resources_from_run_context(client) + + # NOTE: We handle adding run-related resources to the event here instead of in + # the EventsWorker because not all run-related resources are upstream from + # every lineage event (they might be downstream). The EventsWorker only adds + # related resources to the "related" field in the event, which, for + # lineage-related events, tracks upstream resources only. For downstream + # resources, we need to emit an event for each downstream resource. + if direction_of_run_from_event == "downstream": + downstream_resources.extend(related_resources) + else: + upstream_resources.extend(related_resources) + + # Emit an event for each downstream resource. This is necessary because + # our event schema allows one primary resource and many related resources, + # and for the purposes of lineage, related resources can only represent + # upstream resources. + for resource in downstream_resources: + # Downstream lineage resources need to have the + # prefect.resource.lineage-group label. All upstram resources from a + # downstream resource with this label will be considered lineage-related + # resources. 
+ if "prefect.resource.lineage-group" not in resource: + resource["prefect.resource.lineage-group"] = "global" + + emit_kwargs: Dict[str, Any] = { + "event": event_name, + "resource": resource, + "related": upstream_resources, + } + + emit_event(**emit_kwargs) + + +async def emit_result_read_event( + store: "ResultStore", + result_key: str, + downstream_resources: Optional[DownstreamResources] = None, + cached: bool = False, +) -> None: + """ + Emit a lineage event showing a task or flow result was read. + + Args: + store: A `ResultStore` instance. + result_key: The key of the result to generate a URI for. + downstream_resources: List of resources that were + downstream of the event's resource. + """ + if not get_current_settings().experiments.lineage_events_enabled: + return + + result_resource_uri = get_result_resource_uri(store, result_key) + if result_resource_uri: + upstream_resources = [ + RelatedResource( + root={ + "prefect.resource.id": result_resource_uri, + "prefect.resource.role": "result", + } + ) + ] + event_name = "prefect.result.read" + if cached: + event_name += ".cached" + + await emit_lineage_event( + event_name=event_name, + upstream_resources=upstream_resources, + downstream_resources=downstream_resources, + direction_of_run_from_event="downstream", + ) + + +async def emit_result_write_event( + store: "ResultStore", + result_key: str, + upstream_resources: Optional[UpstreamResources] = None, +) -> None: + """ + Emit a lineage event showing a task or flow result was written. + + Args: + store: A `ResultStore` instance. + result_key: The key of the result to generate a URI for. + upstream_resources: Optional list of resources that were + upstream of the event's resource. 
+ """ + if not get_current_settings().experiments.lineage_events_enabled: + return + + result_resource_uri = get_result_resource_uri(store, result_key) + if result_resource_uri: + downstream_resources = [ + { + "prefect.resource.id": result_resource_uri, + "prefect.resource.role": "result", + "prefect.resource.lineage-group": "global", + } + ] + await emit_lineage_event( + event_name="prefect.result.write", + upstream_resources=upstream_resources, + downstream_resources=downstream_resources, + direction_of_run_from_event="upstream", + ) diff --git a/src/prefect/events/utilities.py b/src/prefect/events/utilities.py index 6995e96dced8..b1a04a96a725 100644 --- a/src/prefect/events/utilities.py +++ b/src/prefect/events/utilities.py @@ -24,6 +24,7 @@ def emit_event( payload: Optional[Dict[str, Any]] = None, id: Optional[UUID] = None, follows: Optional[Event] = None, + **kwargs: Optional[Dict[str, Any]], ) -> Optional[Event]: """ Send an event to Prefect Cloud. @@ -62,6 +63,7 @@ def emit_event( event_kwargs: Dict[str, Any] = { "event": event, "resource": resource, + **kwargs, } if occurred is None: diff --git a/src/prefect/events/worker.py b/src/prefect/events/worker.py index b1fa30baebf1..0adb06a549db 100644 --- a/src/prefect/events/worker.py +++ b/src/prefect/events/worker.py @@ -83,6 +83,14 @@ async def _handle(self, event: Event): await self._client.emit(event) async def attach_related_resources_from_context(self, event: Event): + if "prefect.resource.lineage-group" in event.resource: + # We attach related resources to lineage events in `emit_lineage_event`, + # instead of the worker, because not all run-related resources are + # upstream from every lineage event (they might be downstream). + # The "related" field in the event schema tracks upstream resources + # only. 
+ return + exclude = {resource.id for resource in event.involved_resources} event.related += await related_resources_from_run_context( client=self._orchestration_client, exclude=exclude diff --git a/src/prefect/results.py b/src/prefect/results.py index 19206665c8bf..82dd06573100 100644 --- a/src/prefect/results.py +++ b/src/prefect/results.py @@ -38,6 +38,10 @@ from typing_extensions import ParamSpec, Self import prefect +from prefect._experimental.lineage import ( + emit_result_read_event, + emit_result_write_event, +) from prefect._internal.compatibility import deprecated from prefect._internal.compatibility.deprecated import deprecated_field from prefect.blocks.core import Block @@ -232,6 +236,10 @@ def _format_user_supplied_storage_key(key: str) -> str: T = TypeVar("T") +def default_cache() -> LRUCache[str, "ResultRecord[Any]"]: + return LRUCache(maxsize=1000) + + def result_storage_discriminator(x: Any) -> str: if isinstance(x, dict): if "block_type_slug" in x: @@ -284,7 +292,7 @@ class ResultStore(BaseModel): cache_result_in_memory: bool = Field(default=True) serializer: Serializer = Field(default_factory=get_default_result_serializer) storage_key_fn: Callable[[], str] = Field(default=DEFAULT_STORAGE_KEY_FN) - cache: LRUCache = Field(default_factory=lambda: LRUCache(maxsize=1000)) + cache: LRUCache[str, "ResultRecord[Any]"] = Field(default_factory=default_cache) # Deprecated fields persist_result: Optional[bool] = Field(default=None) @@ -446,8 +454,15 @@ async def aexists(self, key: str) -> bool: """ return await self._exists(key=key, _sync=False) + def _resolved_key_path(self, key: str) -> str: + if self.result_storage_block_id is None and hasattr( + self.result_storage, "_resolve_path" + ): + return str(self.result_storage._resolve_path(key)) + return key + @sync_compatible - async def _read(self, key: str, holder: str) -> "ResultRecord": + async def _read(self, key: str, holder: str) -> "ResultRecord[Any]": """ Read a result record from storage. 
@@ -465,8 +480,12 @@ async def _read(self, key: str, holder: str) -> "ResultRecord": if self.lock_manager is not None and not self.is_lock_holder(key, holder): await self.await_for_lock(key) - if key in self.cache: - return self.cache[key] + resolved_key_path = self._resolved_key_path(key) + + if resolved_key_path in self.cache: + cached_result = self.cache[resolved_key_path] + await emit_result_read_event(self, resolved_key_path, cached=True) + return cached_result if self.result_storage is None: self.result_storage = await get_default_result_storage() @@ -478,31 +497,28 @@ async def _read(self, key: str, holder: str) -> "ResultRecord": metadata.storage_key is not None ), "Did not find storage key in metadata" result_content = await self.result_storage.read_path(metadata.storage_key) - result_record = ResultRecord.deserialize_from_result_and_metadata( + result_record: ResultRecord[ + Any + ] = ResultRecord.deserialize_from_result_and_metadata( result=result_content, metadata=metadata_content ) + await emit_result_read_event(self, resolved_key_path) else: content = await self.result_storage.read_path(key) - result_record = ResultRecord.deserialize( + result_record: ResultRecord[Any] = ResultRecord.deserialize( content, backup_serializer=self.serializer ) + await emit_result_read_event(self, resolved_key_path) if self.cache_result_in_memory: - if self.result_storage_block_id is None and hasattr( - self.result_storage, "_resolve_path" - ): - cache_key = str(self.result_storage._resolve_path(key)) - else: - cache_key = key - - self.cache[cache_key] = result_record + self.cache[resolved_key_path] = result_record return result_record def read( self, key: str, holder: Optional[str] = None, - ) -> "ResultRecord": + ) -> "ResultRecord[Any]": """ Read a result record from storage. @@ -520,7 +536,7 @@ async def aread( self, key: str, holder: Optional[str] = None, - ) -> "ResultRecord": + ) -> "ResultRecord[Any]": """ Read a result record from storage. 
@@ -663,12 +679,13 @@ async def _persist_result_record(self, result_record: "ResultRecord", holder: st base_key, content=result_record.serialize_metadata(), ) + await emit_result_write_event(self, result_record.metadata.storage_key) # Otherwise, write the result metadata and result together else: await self.result_storage.write_path( result_record.metadata.storage_key, content=result_record.serialize() ) - + await emit_result_write_event(self, result_record.metadata.storage_key) if self.cache_result_in_memory: self.cache[key] = result_record diff --git a/src/prefect/settings/models/experiments.py b/src/prefect/settings/models/experiments.py index 218128c3dcf1..1ff11c7a13e2 100644 --- a/src/prefect/settings/models/experiments.py +++ b/src/prefect/settings/models/experiments.py @@ -22,3 +22,8 @@ class ExperimentsSettings(PrefectBaseSettings): default=False, description="Enables sending telemetry to Prefect Cloud.", ) + + lineage_events_enabled: bool = Field( + default=False, + description="If `True`, enables emitting lineage events. 
Set to `False` to disable lineage event emission.", + ) diff --git a/src/prefect/testing/fixtures.py b/src/prefect/testing/fixtures.py index 545778427ac1..07352f872afc 100644 --- a/src/prefect/testing/fixtures.py +++ b/src/prefect/testing/fixtures.py @@ -27,6 +27,7 @@ from prefect.server.events.pipeline import EventsPipeline from prefect.settings import ( PREFECT_API_URL, + PREFECT_EXPERIMENTS_LINEAGE_EVENTS_ENABLED, PREFECT_SERVER_ALLOW_EPHEMERAL_MODE, PREFECT_SERVER_CSRF_PROTECTION_ENABLED, get_current_settings, @@ -452,3 +453,10 @@ def reset_worker_events(asserting_events_worker: EventsWorker): yield assert isinstance(asserting_events_worker._client, AssertingEventsClient) asserting_events_worker._client.events = [] + + +@pytest.fixture +def enable_lineage_events(): + """A fixture that ensures lineage events are enabled.""" + with temporary_settings(updates={PREFECT_EXPERIMENTS_LINEAGE_EVENTS_ENABLED: True}): + yield diff --git a/tests/events/client/test_events_worker.py b/tests/events/client/test_events_worker.py index b338c49dd1dd..5170f45a9b8c 100644 --- a/tests/events/client/test_events_worker.py +++ b/tests/events/client/test_events_worker.py @@ -9,6 +9,7 @@ AssertingEventsClient, PrefectEventsClient, ) +from prefect.events.utilities import emit_event from prefect.events.worker import EventsWorker from prefect.settings import ( PREFECT_API_URL, @@ -88,3 +89,30 @@ def emitting_flow(): assert event.related[1].id == f"prefect.flow.{db_flow.id}" assert event.related[1].role == "flow" assert event.related[1]["prefect.resource.name"] == db_flow.name + + +async def test_does_not_include_related_resources_from_run_context_for_lineage_events( + asserting_events_worker: EventsWorker, + reset_worker_events, + prefect_client, +): + @flow + def emitting_flow(): + emit_event( + event="s3.read", + resource={ + "prefect.resource.id": "s3://bucket-name/key-name", + "prefect.resource.role": "data-source", + "prefect.resource.lineage-group": "global", + }, + ) + + 
emitting_flow(return_state=True) + + await asserting_events_worker.drain() + + assert len(asserting_events_worker._client.events) == 1 + event = asserting_events_worker._client.events[0] + assert event.event == "s3.read" + assert event.resource.id == "s3://bucket-name/key-name" + assert len(event.related) == 0 diff --git a/tests/experimental/test_lineage.py b/tests/experimental/test_lineage.py new file mode 100644 index 000000000000..77a33e6da7d4 --- /dev/null +++ b/tests/experimental/test_lineage.py @@ -0,0 +1,339 @@ +from unittest.mock import patch + +import pytest + +from prefect._experimental.lineage import ( + emit_lineage_event, + emit_result_read_event, + emit_result_write_event, + get_result_resource_uri, +) +from prefect.events.schemas.events import RelatedResource +from prefect.filesystems import ( + LocalFileSystem, + WritableDeploymentStorage, + WritableFileSystem, +) +from prefect.results import ResultStore + + +@pytest.fixture +async def local_storage(tmp_path): + return LocalFileSystem(basepath=str(tmp_path)) + + +@pytest.fixture +def result_store(local_storage): + return ResultStore(result_storage=local_storage) + + +@pytest.fixture +def mock_emit_event(): + """Mock the emit_event function used by all lineage event emission.""" + with patch("prefect._experimental.lineage.emit_event") as mock: + yield mock + + +class CustomStorage(WritableFileSystem, WritableDeploymentStorage): + _block_type_slug = "custom-storage" + + def _resolve_path(self, path): + return path + + @classmethod + def get_block_type_slug(cls): + return "custom-storage" + + def get_directory(self, path: str) -> str: + raise NotImplementedError + + def put_directory(self, path: str, directory_path: str) -> None: + raise NotImplementedError + + def read_path(self, path: str) -> bytes: + raise NotImplementedError + + def write_path(self, path: str, contents: bytes) -> None: + raise NotImplementedError + + +async def test_get_result_resource_uri_with_local_storage(local_storage): + uri = 
get_result_resource_uri(ResultStore(result_storage=local_storage), "test-key") + assert uri is not None + assert uri.startswith("file://") + assert uri.endswith("/test-key") + + +async def test_get_resource_uri_with_none_storage(): + store = ResultStore(result_storage=None) + uri = get_result_resource_uri(store, "test-key") + assert uri is None + + +async def test_get_resource_uri_with_unknown_storage(): + store = ResultStore(result_storage=CustomStorage()) + uri = get_result_resource_uri(store, "test-key") + assert uri == "prefect://custom-storage/test-key" + + +@pytest.mark.parametrize( + "block_type,expected_prefix", + [ + ("local-file-system", "file://"), + ("s3-bucket", "s3://"), + ("gcs-bucket", "gs://"), + ("azure-blob-storage", "azure-blob://"), + ], +) +async def test_get_resource_uri_block_type_mapping(block_type, expected_prefix): + if block_type == "local-file-system": + cls = LocalFileSystem + else: + + class MockStorage(CustomStorage): + _block_type_slug = block_type + + def _resolve_path(self, path): + return path + + @classmethod + def get_block_type_slug(cls): + return block_type + + # Add required attributes based on block type + bucket_name: str = "test-bucket" + bucket: str = "test-bucket" + container_name: str = "test-container" + + cls = MockStorage + + store = ResultStore(result_storage=cls()) + uri = get_result_resource_uri(store, "test-key") + assert uri is not None + assert uri.startswith(expected_prefix), f"Failed for {block_type}" + + +class TestEmitLineageEvent: + async def test_emit_lineage_event_with_upstream_and_downstream( + self, enable_lineage_events, mock_emit_event + ): + await emit_lineage_event( + event_name="test.event", + upstream_resources=[ + { + "prefect.resource.id": "upstream1", + "prefect.resource.role": "some-purpose", + }, + { + "prefect.resource.id": "upstream2", + "prefect.resource.role": "some-purpose", + }, + ], + downstream_resources=[ + { + "prefect.resource.id": "downstream1", + 
"prefect.resource.lineage-group": "global", + }, + { + "prefect.resource.id": "downstream2", + "prefect.resource.lineage-group": "global", + }, + ], + ) + + assert mock_emit_event.call_count == 2 # One call per downstream resource + + # Check first downstream resource event + first_call = mock_emit_event.call_args_list[0] + assert first_call.kwargs["event"] == "test.event" + assert first_call.kwargs["resource"] == { + "prefect.resource.id": "downstream1", + "prefect.resource.lineage-group": "global", + } + assert first_call.kwargs["related"] == [ + { + "prefect.resource.id": "upstream1", + "prefect.resource.role": "some-purpose", + }, + { + "prefect.resource.id": "upstream2", + "prefect.resource.role": "some-purpose", + }, + ] + + # Check second downstream resource event + second_call = mock_emit_event.call_args_list[1] + assert second_call.kwargs["event"] == "test.event" + assert second_call.kwargs["resource"] == { + "prefect.resource.id": "downstream2", + "prefect.resource.lineage-group": "global", + } + assert second_call.kwargs["related"] == [ + { + "prefect.resource.id": "upstream1", + "prefect.resource.role": "some-purpose", + }, + { + "prefect.resource.id": "upstream2", + "prefect.resource.role": "some-purpose", + }, + ] + + async def test_emit_lineage_event_with_no_resources( + self, enable_lineage_events, mock_emit_event + ): + await emit_lineage_event(event_name="test.event") + mock_emit_event.assert_not_called() + + async def test_emit_lineage_event_disabled(self, mock_emit_event): + await emit_lineage_event( + event_name="test.event", + upstream_resources=[ + { + "prefect.resource.id": "upstream", + "prefect.resource.role": "some-purpose", + } + ], + downstream_resources=[ + { + "prefect.resource.id": "downstream", + "prefect.resource.lineage-group": "global", + "prefect.resource.role": "result", + } + ], + ) + mock_emit_event.assert_not_called() + + +class TestEmitResultEvents: + async def test_emit_result_read_event( + self, result_store, 
enable_lineage_events, mock_emit_event + ): + await emit_result_read_event( + result_store, + "test-key", + [ + { + "prefect.resource.id": "downstream", + "prefect.resource.role": "flow-run", + } + ], + ) + + mock_emit_event.assert_called_once() + call_args = mock_emit_event.call_args.kwargs + assert call_args["event"] == "prefect.result.read" + resource_uri = get_result_resource_uri(result_store, "test-key") + assert resource_uri is not None + assert call_args["resource"] == { + "prefect.resource.id": "downstream", + "prefect.resource.lineage-group": "global", + "prefect.resource.role": "flow-run", + } + assert call_args["related"] == [ + RelatedResource( + root={ + "prefect.resource.id": resource_uri, + "prefect.resource.role": "result", + } + ) + ] + + async def test_emit_result_write_event( + self, result_store, enable_lineage_events, mock_emit_event + ): + await emit_result_write_event(result_store, "test-key") + + mock_emit_event.assert_called_once() + call_args = mock_emit_event.call_args.kwargs + assert call_args["event"] == "prefect.result.write" + assert call_args["resource"] == { + "prefect.resource.id": get_result_resource_uri(result_store, "test-key"), + "prefect.resource.lineage-group": "global", + "prefect.resource.role": "result", + } + assert call_args["related"] == [] + + async def test_emit_result_read_event_with_none_uri( + self, enable_lineage_events, mock_emit_event + ): + store = ResultStore(result_storage=None) + await emit_result_read_event( + store, + "test-key", + [ + { + "prefect.resource.id": "downstream", + "prefect.resource.role": "flow-run", + } + ], + ) + mock_emit_event.assert_not_called() + + async def test_emit_result_write_event_with_none_uri( + self, enable_lineage_events, mock_emit_event + ): + store = ResultStore(result_storage=None) + await emit_result_write_event(store, "test-key") + mock_emit_event.assert_not_called() + + async def test_emit_result_read_event_with_downstream_resources( + self, result_store, 
enable_lineage_events, mock_emit_event + ): + await emit_result_read_event( + result_store, + "test-key", + downstream_resources=[ + {"prefect.resource.id": "downstream1"}, + {"prefect.resource.id": "downstream2"}, + ], + ) + + calls = mock_emit_event.call_args_list + assert len(calls) == 2 + + for i, call in enumerate(calls): + resource_uri = get_result_resource_uri(result_store, "test-key") + assert resource_uri is not None + assert call.kwargs["event"] == "prefect.result.read" + assert call.kwargs["resource"] == { + "prefect.resource.id": f"downstream{i+1}", + "prefect.resource.lineage-group": "global", + } + assert call.kwargs["related"] == [ + RelatedResource( + root={ + "prefect.resource.id": resource_uri, + "prefect.resource.role": "result", + } + ) + ] + + async def test_emit_result_write_event_with_upstream_resources( + self, result_store, enable_lineage_events, mock_emit_event + ): + await emit_result_write_event( + result_store, + "test-key", + upstream_resources=[ + { + "prefect.resource.id": "upstream", + "prefect.resource.role": "my-role", + } + ], + ) + + resolved_key_path = result_store._resolved_key_path("test-key") + resource_uri = get_result_resource_uri(result_store, resolved_key_path) + + mock_emit_event.assert_called_once_with( + event="prefect.result.write", + resource={ + "prefect.resource.id": resource_uri, + "prefect.resource.lineage-group": "global", + "prefect.resource.role": "result", + }, + related=[ + {"prefect.resource.id": "upstream", "prefect.resource.role": "my-role"}, + ], + ) diff --git a/tests/results/test_result_store.py b/tests/results/test_result_store.py index f516af249c9d..5ec34a2720e1 100644 --- a/tests/results/test_result_store.py +++ b/tests/results/test_result_store.py @@ -1,3 +1,5 @@ +from unittest import mock + import pytest import prefect.exceptions @@ -889,3 +891,72 @@ async def test_deprecation_warning_on_persist_result(): with pytest.warns(DeprecationWarning): ResultStore(persist_result=False) + + +class 
TestResultStoreEmitsEvents: + async def test_result_store_emits_write_event( + self, tmp_path, enable_lineage_events + ): + filesystem = LocalFileSystem(basepath=tmp_path) + result_store = ResultStore(result_storage=filesystem) + + with mock.patch("prefect.results.emit_result_write_event") as mock_emit: + await result_store.awrite(key="test", obj="test") + resolved_key_path = result_store._resolved_key_path("test") + mock_emit.assert_called_once_with(result_store, resolved_key_path) + + async def test_result_store_emits_read_event(self, tmp_path, enable_lineage_events): + filesystem = LocalFileSystem(basepath=tmp_path) + result_store = ResultStore(result_storage=filesystem) + await result_store.awrite(key="test", obj="test") + + # Reading from a different result store allows us to test the read + # without the store's in-memory cache. + other_result_store = ResultStore(result_storage=filesystem) + + with mock.patch("prefect.results.emit_result_read_event") as mock_emit: + await other_result_store.aread(key="test") + resolved_key_path = other_result_store._resolved_key_path("test") + mock_emit.assert_called_once_with(other_result_store, resolved_key_path) + + async def test_result_store_emits_cached_read_event( + self, tmp_path, enable_lineage_events + ): + result_store = ResultStore( + cache_result_in_memory=True, + ) + await result_store.awrite(key="test", obj="test") + + with mock.patch("prefect.results.emit_result_read_event") as mock_emit: + await result_store.aread(key="test") # cached read + resolved_key_path = result_store._resolved_key_path("test") + mock_emit.assert_called_once_with( + result_store, + resolved_key_path, + cached=True, + ) + + async def test_result_store_does_not_emit_lineage_write_events_when_disabled( + self, tmp_path + ): + filesystem = LocalFileSystem(basepath=tmp_path) + result_store = ResultStore(result_storage=filesystem) + + with mock.patch( + "prefect._experimental.lineage.emit_lineage_event" + ) as mock_emit: + await 
result_store.awrite(key="test", obj="test") + mock_emit.assert_not_called() + + async def test_result_store_does_not_emit_lineage_read_events_when_disabled( + self, tmp_path + ): + filesystem = LocalFileSystem(basepath=tmp_path) + result_store = ResultStore(result_storage=filesystem) + await result_store.awrite(key="test", obj="test") + + with mock.patch( + "prefect._experimental.lineage.emit_lineage_event" + ) as mock_emit: + await result_store.aread(key="test") + mock_emit.assert_not_called() diff --git a/tests/test_settings.py b/tests/test_settings.py index a42c32b0482d..c054e114e58d 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -231,6 +231,7 @@ "PREFECT_EXPERIMENTAL_WARN": {"test_value": True, "legacy": True}, "PREFECT_EXPERIMENTS_TELEMETRY_ENABLED": {"test_value": False}, "PREFECT_EXPERIMENTS_WARN": {"test_value": True}, + "PREFECT_EXPERIMENTS_LINEAGE_EVENTS_ENABLED": {"test_value": True}, "PREFECT_FLOW_DEFAULT_RETRIES": {"test_value": 10, "legacy": True}, "PREFECT_FLOWS_DEFAULT_RETRIES": {"test_value": 10}, "PREFECT_FLOW_DEFAULT_RETRY_DELAY_SECONDS": {"test_value": 10, "legacy": True}, From b2faff31873a571a7e598802ba6d27d1d8d4458d Mon Sep 17 00:00:00 2001 From: nate nowack Date: Thu, 12 Dec 2024 12:45:32 -0600 Subject: [PATCH 8/8] update tutorial for clarity (#16354) --- docs/v3/tutorials/pipelines.mdx | 346 +++++++++++++++----------------- 1 file changed, 167 insertions(+), 179 deletions(-) diff --git a/docs/v3/tutorials/pipelines.mdx b/docs/v3/tutorials/pipelines.mdx index 15be9115e354..5df61a230e53 100644 --- a/docs/v3/tutorials/pipelines.mdx +++ b/docs/v3/tutorials/pipelines.mdx @@ -22,14 +22,18 @@ The first improvement you can make is to add retries to your flow. Whenever an HTTP request fails, you can retry it a few times before giving up. 
```python +from typing import Any + +import httpx from prefect import task + @task(retries=3) -def fetch_stats(github_repo: str): +def fetch_stats(github_repo: str) -> dict[str, Any]: """Task 1: Fetch the statistics for a GitHub repo""" api_response = httpx.get(f"https://api.github.com/repos/{github_repo}") - api_response.raise_for_status() # Force a retry if you don't get a 2xx status code + api_response.raise_for_status() # Force a retry if not a 2xx status code return api_response.json() ``` @@ -37,10 +41,26 @@ def fetch_stats(github_repo: str): Run the following code to see retries in action: ```python -import httpx +from typing import Any +import httpx from prefect import flow, task # Prefect flow and task decorators +@task(retries=3) +def fetch_stats(github_repo: str) -> dict[str, Any]: + """Task 1: Fetch the statistics for a GitHub repo""" + + api_response = httpx.get(f"https://api.github.com/repos/{github_repo}") + api_response.raise_for_status() # Force a retry if not a 2xx status code + return api_response.json() + + +@task +def get_stars(repo_stats: dict[str, Any]) -> int: + """Task 2: Get the number of stars from GitHub repo statistics""" + + return repo_stats['stargazers_count'] + @flow(log_prints=True) def show_stars(github_repos: list[str]): @@ -57,21 +77,6 @@ def show_stars(github_repos: list[str]): print(f"{repo}: {stars} stars") -@task(retries=3) -def fetch_stats(github_repo: str): - """Task 1: Fetch the statistics for a GitHub repo""" - - api_response = httpx.get(f"https://api.github.com/repos/{github_repo}") - api_response.raise_for_status() # Force a retry if you don't get a 2xx status code - return api_response.json() - - -@task -def get_stars(repo_stats: dict): - """Task 2: Get the number of stars from GitHub repo statistics""" - - return repo_stats['stargazers_count'] - # Run the flow if __name__ == "__main__": @@ -86,82 +91,83 @@ if __name__ == "__main__": ## Concurrent execution of slow tasks If individual API requests are slow, you can 
speed them up in aggregate by making multiple requests concurrently. -When you call the `submit` method on a task, the task is submitted to a task runner for execution. +When you call the `map` method on a task, you submit a list of arguments to the task runner to run concurrently (alternatively, you could [`.submit()` each argument individually](/v3/develop/task-runners#access-results-from-submitted-tasks)). ```python from prefect import flow @flow(log_prints=True) -def show_stars(github_repos: list[str]): - """Flow: Show the number of stars that GitHub repos have""" +def show_stars(github_repos: list[str]) -> None: + """Flow: Show number of GitHub repo stars""" # Task 1: Make HTTP requests concurrently - repo_stats = [] - for repo in github_repos: - repo_stats.append({ - 'repo': repo, - 'task': fetch_stats.submit(repo) # Submit each task to a task runner - }) - - # Task 2: Once each concurrent task completes, show the results - for repo in repo_stats: - repo_name = repo['repo'] - stars = get_stars(repo['task'].result()) # Block until the task has completed - print(f"{repo_name}: {stars} stars") + stats_futures = fetch_stats.map(github_repos) + + # Task 2: Once each concurrent task completes, get the star counts + stars = get_stars.map(stats_futures).result() + + # Show the results + for repo, star_count in zip(github_repos, stars): + print(f"{repo}: {star_count} stars") ``` Run the following code to see concurrent tasks in action: ```python -import httpx - -from prefect import flow, task # Prefect flow and task decorators - +from typing import Any -@flow(log_prints=True) -def show_stars(github_repos: list[str]): - """Flow: Show the number of stars that GitHub repos have""" - - # Task 1: Make HTTP requests concurrently - repo_stats = [] - for repo in github_repos: - repo_stats.append({ - 'repo': repo, - 'task': fetch_stats.submit(repo) # Submit each task to a task runner - }) - - # Task 2: Once each concurrent task completes, show the results - for repo in 
repo_stats: - repo_name = repo['repo'] - stars = get_stars(repo['task'].result()) # Block until the task has completed - print(f"{repo_name}: {stars} stars") +import httpx +from prefect import flow, task -@task -def fetch_stats(github_repo: str): +@task(retries=3) +def fetch_stats(github_repo: str) -> dict[str, Any]: """Task 1: Fetch the statistics for a GitHub repo""" - return httpx.get(f"https://api.github.com/repos/{github_repo}").json() @task -def get_stars(repo_stats: dict): +def get_stars(repo_stats: dict[str, Any]) -> int: """Task 2: Get the number of stars from GitHub repo statistics""" + return repo_stats["stargazers_count"] - return repo_stats['stargazers_count'] + +@flow(log_prints=True) +def show_stars(github_repos: list[str]) -> None: + """Flow: Show number of GitHub repo stars""" + + # Task 1: Make HTTP requests concurrently + stats_futures = fetch_stats.map(github_repos) + + # Task 2: Once each concurrent task completes, get the star counts + stars = get_stars.map(stats_futures).result() + + # Show the results + for repo, star_count in zip(github_repos, stars): + print(f"{repo}: {star_count} stars") -# Run the flow if __name__ == "__main__": - show_stars([ - "PrefectHQ/prefect", - "pydantic/pydantic", - "huggingface/transformers" - ]) + # Run the flow + show_stars( + [ + "PrefectHQ/prefect", + "pydantic/pydantic", + "huggingface/transformers" + ] + ) + ``` + +Calling `.result()` on the list of futures returned by `.map()` will block until all tasks are complete. + +Read more in the [`.map()` documentation](/v3/develop/task-runners#mapping-over-iterables). + + + ## Avoid getting rate limited One consequence of running tasks concurrently is that you're more likely to hit the rate limits of whatever API you're using. @@ -172,74 +178,57 @@ To avoid this, use Prefect to set a global concurrency limit. 
prefect gcl create github-api --limit 60 --slot-decay-per-second 0.016 ``` -Now, you can use this global concurrency limit in your code: +Now, you can use this global concurrency limit in your code to rate-limit your API requests. ```python -from prefect import flow +from typing import Any +import httpx +from prefect import task from prefect.concurrency.sync import rate_limit -@flow(log_prints=True) -def show_stars(github_repos: list[str]): - """Flow: Show the number of stars that GitHub repos have""" - - repo_stats = [] - for repo in github_repos: - # Apply the concurrency limit to this loop - rate_limit("github-api") - - # Call Task 1 - repo_stats.append({ - 'repo': repo, - 'task': fetch_stats.submit(repo) - }) - - # ... +@task +def fetch_stats(github_repo: str) -> dict[str, Any]: + """Task 1: Fetch the statistics for a GitHub repo""" + rate_limit("github-api") + return httpx.get(f"https://api.github.com/repos/{github_repo}").json() ``` Run the following code to see concurrency limits in action: ```python -import httpx +from typing import Any -from prefect import flow, task # Prefect flow and task decorators +import httpx +from prefect import flow, task from prefect.concurrency.sync import rate_limit - -@flow(log_prints=True) -def show_stars(github_repos: list[str]): - """Flow: Show the number of stars that GitHub repos have""" - - repo_stats = [] - for repo in github_repos: - # Apply the concurrency limit to this loop - rate_limit("github-api") - - # Call Task 1 - repo_stats.append({ - 'repo': repo, - 'task': fetch_stats.submit(repo) - }) - - # Call Task 2 - stars = get_stars(repo_stats) - - # Print the result - print(f"{repo}: {stars} stars") - - -@task -def fetch_stats(github_repo: str): +@task(retries=3) +def fetch_stats(github_repo: str) -> dict[str, Any]: """Task 1: Fetch the statistics for a GitHub repo""" - + rate_limit("github-api") return httpx.get(f"https://api.github.com/repos/{github_repo}").json() @task -def get_stars(repo_stats: 
dict[str, Any]) -> int: """Task 2: Get the number of stars from GitHub repo statistics""" + return repo_stats["stargazers_count"] - return repo_stats['stargazers_count'] + +@flow(log_prints=True) +def show_stars(github_repos: list[str]) -> None: + """Flow: Show number of GitHub repo stars""" + + # Task 1: Make HTTP requests concurrently + stats_futures = fetch_stats.map(github_repos) + + # Task 2: Once each concurrent task completes, get the star counts + stars = get_stars.map(stats_futures).result() + + # Show the results + for repo, star_count in zip(github_repos, stars): + print(f"{repo}: {star_count} stars") # Run the flow @@ -258,13 +246,14 @@ For efficiency, you can skip tasks that have already run. For example, if you don't want to fetch the number of stars for a given repository more than once per day, you can cache those results for a day. ```python +from typing import Any from datetime import timedelta from prefect import task from prefect.cache_policies import INPUTS @task(cache_policy=INPUTS, cache_expiration=timedelta(days=1)) -def fetch_stats(github_repo: str): +def fetch_stats(github_repo: str) -> dict[str, Any]: """Task 1: Fetch the statistics for a GitHub repo""" # ... 
``` @@ -273,40 +262,44 @@ def fetch_stats(github_repo: str): Run the following code to see caching in action: ```python +from typing import Any from datetime import timedelta -import httpx -from prefect import flow, task # Prefect flow and task decorators +import httpx +from prefect import flow, task from prefect.cache_policies import INPUTS +from prefect.concurrency.sync import rate_limit - -@flow(log_prints=True) -def show_stars(github_repos: list[str]): - """Flow: Show the number of stars that GitHub repos have""" - - for repo in github_repos: - # Call Task 1 - repo_stats = fetch_stats(repo) - - # Call Task 2 - stars = get_stars(repo_stats) - - # Print the result - print(f"{repo}: {stars} stars") - - -@task(cache_policy=INPUTS, cache_expiration=timedelta(days=1)) -def fetch_stats(github_repo: str): +@task( + retries=3, + cache_policy=INPUTS, + cache_expiration=timedelta(days=1) +) +def fetch_stats(github_repo: str) -> dict[str, Any]: """Task 1: Fetch the statistics for a GitHub repo""" - + rate_limit("github-api") return httpx.get(f"https://api.github.com/repos/{github_repo}").json() @task -def get_stars(repo_stats: dict): +def get_stars(repo_stats: dict[str, Any]) -> int: """Task 2: Get the number of stars from GitHub repo statistics""" + return repo_stats["stargazers_count"] - return repo_stats['stargazers_count'] + +@flow(log_prints=True) +def show_stars(github_repos: list[str]) -> None: + """Flow: Show number of GitHub repo stars""" + + # Task 1: Make HTTP requests concurrently + stats_futures = fetch_stats.map(github_repos) + + # Task 2: Once each concurrent task completes, get the star counts + stars = get_stars.map(stats_futures).result() + + # Show the results + for repo, star_count in zip(github_repos, stars): + print(f"{repo}: {star_count} stars") # Run the flow @@ -324,48 +317,44 @@ if __name__ == "__main__": This is what your flow looks like after applying all of these improvements: ```python my_data_pipeline.py +from typing import Any from datetime 
import timedelta -import httpx +import httpx from prefect import flow, task from prefect.cache_policies import INPUTS from prefect.concurrency.sync import rate_limit +@task( + retries=3, + cache_policy=INPUTS, + cache_expiration=timedelta(days=1) +) +def fetch_stats(github_repo: str) -> dict[str, Any]: + """Task 1: Fetch the statistics for a GitHub repo""" + rate_limit("github-api") + return httpx.get(f"https://api.github.com/repos/{github_repo}").json() -@flow(log_prints=True) -def show_stars(github_repos: list[str]): - """Flow: Show the number of stars that GitHub repos have""" - - # Task 1: Make HTTP requests concurrently while respecting concurrency limits - repo_stats = [] - for repo in github_repos: - rate_limit("github-api") - repo_stats.append({ - 'repo': repo, - 'task': fetch_stats.submit(repo) # Submit each task to a task runner - }) - - # Task 2: Once each concurrent task completes, show the results - for repo in repo_stats: - repo_name = repo['repo'] - stars = get_stars(repo['task'].result()) # Block until the task has completed - print(f"{repo_name}: {stars} stars") +@task +def get_stars(repo_stats: dict[str, Any]) -> int: + """Task 2: Get the number of stars from GitHub repo statistics""" + return repo_stats["stargazers_count"] -@task(retries=3, cache_policy=INPUTS, cache_expiration=timedelta(days=1)) -def fetch_stats(github_repo: str): - """Task 1: Fetch the statistics for a GitHub repo""" - api_response = httpx.get(f"https://api.github.com/repos/{github_repo}") - api_response.raise_for_status() # Force a retry if you don't get a 2xx status code - return api_response.json() +@flow(log_prints=True) +def show_stars(github_repos: list[str]) -> None: + """Flow: Show number of GitHub repo stars""" + # Task 1: Make HTTP requests concurrently + stats_futures = fetch_stats.map(github_repos) -@task -def get_stars(repo_stats: dict): - """Task 2: Get the number of stars from GitHub repo statistics""" + # Task 2: Once each concurrent task completes, get the star 
counts + stars = get_stars.map(stats_futures).result() - return repo_stats['stargazers_count'] + # Show the results + for repo, star_count in zip(github_repos, stars): + print(f"{repo}: {star_count} stars") # Run the flow @@ -383,25 +372,24 @@ Run your flow twice: once to run the tasks and cache the result, again to retrie # Run the tasks and cache the results python my_data_pipeline.py -# Retrieve the cached results +# Run again (notice the cached results) python my_data_pipeline.py ``` The terminal output from the second flow run should look like this: ```bash -09:08:12.265 | INFO | prefect.engine - Created flow run 'laughing-nightingale' for flow 'show-stars' -09:08:12.266 | INFO | prefect.engine - View at http://127.0.0.1:4200/runs/flow-run/541864e8-12f7-4890-9397-b2ed361f6b20 -09:08:12.322 | INFO | Task run 'fetch_stats-0c9' - Finished in state Cached(type=COMPLETED) -09:08:12.359 | INFO | Task run 'fetch_stats-e89' - Finished in state Cached(type=COMPLETED) -09:08:12.360 | INFO | Task run 'get_stars-b51' - Finished in state Completed() -09:08:12.361 | INFO | Flow run 'laughing-nightingale' - PrefectHQ/prefect: 17320 stars -09:08:12.372 | INFO | Task run 'fetch_stats-8ef' - Finished in state Cached(type=COMPLETED) -09:08:12.374 | INFO | Task run 'get_stars-08d' - Finished in state Completed() -09:08:12.374 | INFO | Flow run 'laughing-nightingale' - pydantic/pydantic: 186319 stars -09:08:12.387 | INFO | Task run 'get_stars-2af' - Finished in state Completed() -09:08:12.387 | INFO | Flow run 'laughing-nightingale' - huggingface/transformers: 134849 stars -09:08:12.404 | INFO | Flow run 'laughing-nightingale' - Finished in state Completed() +20:03:04.398 | INFO | prefect.engine - Created flow run 'laughing-nightingale' for flow 'show-stars' +20:03:05.146 | INFO | Task run 'fetch_stats-90f' - Finished in state Cached(type=COMPLETED) +20:03:05.149 | INFO | Task run 'fetch_stats-258' - Finished in state Cached(type=COMPLETED) +20:03:05.153 | INFO | Task run 
'fetch_stats-924' - Finished in state Cached(type=COMPLETED) +20:03:05.159 | INFO | Task run 'get_stars-3a9' - Finished in state Completed() +20:03:05.159 | INFO | Task run 'get_stars-ed3' - Finished in state Completed() +20:03:05.161 | INFO | Task run 'get_stars-39c' - Finished in state Completed() +20:03:05.162 | INFO | Flow run 'laughing-nightingale' - PrefectHQ/prefect: 17756 stars +20:03:05.163 | INFO | Flow run 'laughing-nightingale' - pydantic/pydantic: 21613 stars +20:03:05.163 | INFO | Flow run 'laughing-nightingale' - huggingface/transformers: 136166 stars +20:03:05.339 | INFO | Flow run 'laughing-nightingale' - Finished in state Completed() ``` ## Next steps