From db49b6859e9e698e357384b8714909178ea6bd4b Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Fri, 6 Dec 2024 15:04:50 -0600 Subject: [PATCH 1/4] fruitful --- src/prefect/automations.py | 29 ++++++++++++------ src/prefect/cache_policies.py | 22 +++++++------- src/prefect/context.py | 49 +++++++++++++++++-------------- src/prefect/settings/constants.py | 4 +-- src/prefect/settings/legacy.py | 2 +- src/prefect/utilities/hashing.py | 13 ++++---- 6 files changed, 69 insertions(+), 50 deletions(-) diff --git a/src/prefect/automations.py b/src/prefect/automations.py index 86799ec59cd6..4f97f7dcac70 100644 --- a/src/prefect/automations.py +++ b/src/prefect/automations.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import TYPE_CHECKING, Optional, Type, cast from uuid import UUID from pydantic import Field @@ -41,6 +41,9 @@ from prefect.exceptions import PrefectHTTPStatusError from prefect.utilities.asyncutils import sync_compatible +if TYPE_CHECKING: + from prefect.client.orchestration import PrefectClient + __all__ = [ "AutomationCore", "EventTrigger", @@ -75,6 +78,11 @@ ] +def _get_or_create_and_cast_client() -> tuple["PrefectClient", bool]: + client, inferred = get_or_create_client() + return cast("PrefectClient", client), inferred + + class Automation(AutomationCore): id: Optional[UUID] = Field(default=None, description="The ID of this automation") @@ -99,7 +107,7 @@ async def create(self: Self) -> Self: ) created_automation = auto_to_create.create() """ - client, _ = get_or_create_client() + client, _ = _get_or_create_and_cast_client() automation = AutomationCore(**self.model_dump(exclude={"id"})) self.id = await client.create_automation(automation=automation) return self @@ -113,14 +121,14 @@ async def update(self: Self): auto.update() """ - client, _ = get_or_create_client() + client, _ = _get_or_create_and_cast_client() automation = AutomationCore(**self.model_dump(exclude={"id", "owner_resource"})) await client.update_automation(automation_id=self.id, automation=automation) @classmethod @sync_compatible async def read( - cls: Self, id: Optional[UUID] = None, name: Optional[str] = None + cls: Type[Self], id: Optional[UUID] = None, name: Optional[str] = None ) -> Self: """ Read an automation by ID or name. @@ -134,13 +142,16 @@ async def read( raise ValueError("Only one of id or name can be provided") if not id and not name: raise ValueError("One of id or name must be provided") - client, _ = get_or_create_client() + client, _ = _get_or_create_and_cast_client() if id: try: automation = await client.read_automation(automation_id=id) except PrefectHTTPStatusError as exc: if exc.response.status_code == 404: raise ValueError(f"Automation with ID {id!r} not found") + raise + if automation is None: + raise ValueError(f"Automation with ID {id!r} not found") return Automation(**automation.model_dump()) else: automation = await client.read_automations_by_name(name=name) @@ -156,7 +167,7 @@ async def delete(self: Self) -> bool: auto.delete() """ try: - client, _ = get_or_create_client() + client, _ = _get_or_create_and_cast_client() await client.delete_automation(self.id) return True except PrefectHTTPStatusError as exc: @@ -172,7 +183,7 @@ async def disable(self: Self) -> bool: auto.disable() """ try: - client, _ = get_or_create_client() + client, _ = _get_or_create_and_cast_client() await client.pause_automation(self.id) return True except PrefectHTTPStatusError as exc: @@ -188,8 +199,8 @@ async def enable(self: Self) -> bool: auto.enable() """ try: - client, _ = get_or_create_client() - await client.resume_automation("asd") + client, _ = _get_or_create_and_cast_client() + await client.resume_automation(self.id) return True except PrefectHTTPStatusError as exc: if exc.response.status_code == 404: diff --git a/src/prefect/cache_policies.py b/src/prefect/cache_policies.py index 50717e5ceaea..746f8561cdfa 100644 --- a/src/prefect/cache_policies.py +++ b/src/prefect/cache_policies.py @@ -75,12 +75,12 @@ def compute_key( task_ctx: TaskRunContext, inputs: Dict[str, Any], flow_parameters: Dict[str, Any], - **kwargs, + **kwargs: Any, ) -> Optional[str]: raise NotImplementedError def __sub__(self, other: str) -> "CachePolicy": - if not isinstance(other, str): + if not isinstance(other, str): # type: ignore[reportUnnecessaryIsInstance] raise TypeError("Can only subtract strings from key policies.") new = Inputs(exclude=[other]) return CompoundCachePolicy(policies=[self, new]) @@ -140,7 +140,7 @@ def compute_key( task_ctx: TaskRunContext, inputs: Dict[str, Any], flow_parameters: Dict[str, Any], - **kwargs, + **kwargs: Any, ) -> Optional[str]: if self.cache_key_fn: return self.cache_key_fn(task_ctx, inputs) @@ -162,9 +162,9 @@ def compute_key( task_ctx: TaskRunContext, inputs: Dict[str, Any], flow_parameters: Dict[str, Any], - **kwargs, + **kwargs: Any, ) -> Optional[str]: - keys = [] + keys: list[str] = [] for policy in self.policies: policy_key = policy.compute_key( task_ctx=task_ctx, @@ -191,7 +191,7 @@ def compute_key( task_ctx: TaskRunContext, inputs: Dict[str, Any], flow_parameters: Dict[str, Any], - **kwargs, + **kwargs: Any, ) -> Optional[str]: return None @@ -211,7 +211,7 @@ def compute_key( task_ctx: TaskRunContext, inputs: Optional[Dict[str, Any]], flow_parameters: Optional[Dict[str, Any]], - **kwargs, + **kwargs: Any, ) -> Optional[str]: if not task_ctx: return None @@ -238,7 +238,7 @@ def compute_key( task_ctx: TaskRunContext, inputs: Dict[str, Any], flow_parameters: Dict[str, Any], - **kwargs, + **kwargs: Any, ) -> Optional[str]: if not flow_parameters: return None @@ -257,7 +257,7 @@ def compute_key( task_ctx: TaskRunContext, inputs: Dict[str, Any], flow_parameters: Dict[str, Any], - **kwargs, + **kwargs: Any, ) -> Optional[str]: if not task_ctx: return None @@ -280,7 +280,7 @@ def compute_key( task_ctx: TaskRunContext, inputs: Dict[str, Any], flow_parameters: Dict[str, Any], - **kwargs, + **kwargs: Any, ) -> Optional[str]: hashed_inputs = {} inputs = inputs or {} @@ -307,7 +307,7 @@ def compute_key( raise ValueError(msg) from exc def __sub__(self, other: str) -> "CachePolicy": - if not isinstance(other, str): + if not isinstance(other, str): # type: ignore[reportUnnecessaryIsInstance] raise TypeError("Can only subtract strings from key policies.") return Inputs(exclude=self.exclude + [other]) diff --git a/src/prefect/context.py b/src/prefect/context.py index 675812de85c2..69c14ce4fdb0 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -43,7 +43,9 @@ get_default_persist_setting_for_tasks, ) from prefect.settings import Profile, Settings -from prefect.settings.legacy import _get_settings_fields +from prefect.settings.legacy import ( + _get_settings_fields, # type: ignore[reportPrivateUsage] +) from prefect.states import State from prefect.task_runners import TaskRunner from prefect.utilities.services import start_client_metrics_server @@ -200,14 +202,14 @@ class SyncClientContext(ContextModel): assert c1 is ctx.client """ - __var__ = ContextVar("sync-client-context") + __var__: ContextVar[Self] = ContextVar("sync-client-context") client: SyncPrefectClient _httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None) _context_stack: int = PrivateAttr(0) def __init__(self, httpx_settings: Optional[dict[str, Any]] = None): super().__init__( - client=get_client(sync_client=True, httpx_settings=httpx_settings), + client=get_client(sync_client=True, httpx_settings=httpx_settings), # type: ignore[reportCallIssue] ) self._httpx_settings = httpx_settings self._context_stack = 0 @@ -221,11 +223,11 @@ def __enter__(self): else: return self - def __exit__(self, *exc_info): + def __exit__(self, *exc_info: Any): self._context_stack -= 1 if self._context_stack == 0: - self.client.__exit__(*exc_info) - return super().__exit__(*exc_info) + self.client.__exit__(*exc_info) # type: ignore[reportUnknownMemberType] + return super().__exit__(*exc_info) # type: ignore[reportUnknownMemberType] @classmethod @contextmanager @@ -265,12 +267,12 @@ class AsyncClientContext(ContextModel): def __init__(self, httpx_settings: Optional[dict[str, Any]] = None): super().__init__( - client=get_client(sync_client=False, httpx_settings=httpx_settings), + client=get_client(sync_client=False, httpx_settings=httpx_settings), # type: ignore[reportCallIssue] ) self._httpx_settings = httpx_settings self._context_stack = 0 - async def __aenter__(self): + async def __aenter__(self: Self) -> Self: self._context_stack += 1 if self._context_stack == 1: await self.client.__aenter__() @@ -279,11 +281,11 @@ async def __aenter__(self): else: return self - async def __aexit__(self, *exc_info): + async def __aexit__(self: Self, *exc_info: Any) -> None: self._context_stack -= 1 if self._context_stack == 0: - await self.client.__aexit__(*exc_info) - return super().__exit__(*exc_info) + await self.client.__aexit__(*exc_info) # type: ignore[reportUnknownMemberType] + return super().__exit__(*exc_info) # type: ignore[reportUnknownMemberType] @classmethod @asynccontextmanager @@ -306,7 +308,7 @@ class RunContext(ContextModel): client: The Prefect client instance being used for API communication """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) start_client_metrics_server() @@ -315,10 +317,11 @@ def __init__(self, *args, **kwargs): input_keyset: Optional[Dict[str, Dict[str, str]]] = None client: Union[PrefectClient, SyncPrefectClient] - def serialize(self: Self) -> Dict[str, Any]: + def serialize(self: Self, include_secrets: bool = True) -> Dict[str, Any]: return self.model_dump( include={"start_time", "input_keyset"}, exclude_unset=True, + context={"include_secrets": include_secrets}, ) @@ -364,9 +367,9 @@ class EngineContext(RunContext): # Events worker to emit events events: Optional[EventsWorker] = None - __var__: ContextVar = ContextVar("flow_run") + __var__: ContextVar[Self] = ContextVar("flow_run") - def serialize(self): + def serialize(self: Self, include_secrets: bool = True) -> Dict[str, Any]: return self.model_dump( include={ "flow_run", @@ -380,6 +383,7 @@ def serialize(self): }, exclude_unset=True, serialize_as_any=True, + context={"include_secrets": include_secrets}, ) @@ -396,7 +400,7 @@ class TaskRunContext(RunContext): task_run: The API metadata for this task run """ - task: "Task" + task: "Task[Any, Any]" task_run: TaskRun log_prints: bool = False parameters: Dict[str, Any] @@ -407,7 +411,7 @@ class TaskRunContext(RunContext): __var__ = ContextVar("task_run") - def serialize(self): + def serialize(self: Self, include_secrets: bool = True) -> Dict[str, Any]: return self.model_dump( include={ "task_run", @@ -421,6 +425,7 @@ def serialize(self): }, exclude_unset=True, serialize_as_any=True, + context={"include_secrets": include_secrets}, ) @@ -439,7 +444,7 @@ def get(cls) -> "TagsContext": # Return an empty `TagsContext` instead of `None` if no context exists return cls.__var__.get(TagsContext()) - __var__: ContextVar = ContextVar("tags") + __var__: ContextVar[Self] = ContextVar("tags") class SettingsContext(ContextModel): @@ -456,9 +461,9 @@ class SettingsContext(ContextModel): profile: Profile settings: Settings - __var__: ContextVar = ContextVar("settings") + __var__: ContextVar[Self] = ContextVar("settings") - def __hash__(self) -> int: + def __hash__(self: Self) -> int: return hash(self.settings) @classmethod @@ -565,7 +570,7 @@ def tags(*new_tags: str) -> Generator[Set[str], None, None]: @contextmanager def use_profile( - profile: Union[Profile, str], + profile: Union[Profile, str, Any], override_environment_variables: bool = False, include_current_context: bool = True, ): @@ -665,7 +670,7 @@ def root_settings_context(): # an override in the `SettingsContext.get` method. -GLOBAL_SETTINGS_CONTEXT: SettingsContext = root_settings_context() +GLOBAL_SETTINGS_CONTEXT: SettingsContext = root_settings_context() # type: ignore[reportConstantRedefinition] # 2024-07-02: This surfaces an actionable error message for removed objects diff --git a/src/prefect/settings/constants.py b/src/prefect/settings/constants.py index ac7520492b61..70d00ccd9394 100644 --- a/src/prefect/settings/constants.py +++ b/src/prefect/settings/constants.py @@ -1,8 +1,8 @@ from pathlib import Path -from typing import Tuple, Type +from typing import Any, Tuple, Type from pydantic import Secret, SecretStr DEFAULT_PREFECT_HOME = Path.home() / ".prefect" DEFAULT_PROFILES_PATH = Path(__file__).parent.joinpath("profiles.toml") -_SECRET_TYPES: Tuple[Type, ...] = (Secret, SecretStr) +_SECRET_TYPES: Tuple[Type[Any], ...] = (Secret, SecretStr) diff --git a/src/prefect/settings/legacy.py b/src/prefect/settings/legacy.py index 17f76e3f1404..6bc496fe1aee 100644 --- a/src/prefect/settings/legacy.py +++ b/src/prefect/settings/legacy.py @@ -8,7 +8,7 @@ from typing_extensions import Self from prefect.settings.base import PrefectBaseSettings -from prefect.settings.constants import _SECRET_TYPES +from prefect.settings.constants import _SECRET_TYPES # type: ignore[reportPrivateUsage] from prefect.settings.context import get_current_settings from prefect.settings.models.root import Settings diff --git a/src/prefect/utilities/hashing.py b/src/prefect/utilities/hashing.py index 2724cb38c3f4..b31a60609164 100644 --- a/src/prefect/utilities/hashing.py +++ b/src/prefect/utilities/hashing.py @@ -2,7 +2,7 @@ import sys from functools import partial from pathlib import Path -from typing import Optional, Union +from typing import Any, Callable, Optional, Union import cloudpickle @@ -15,7 +15,7 @@ _md5 = hashlib.md5 -def stable_hash(*args: Union[str, bytes], hash_algo=_md5) -> str: +def stable_hash(*args: Union[str, bytes], hash_algo: Callable[..., Any] = _md5) -> str: """Given some arguments, produces a stable 64-bit hash of their contents. Supports bytes and strings. Strings will be UTF-8 encoded. @@ -35,7 +35,7 @@ def stable_hash(*args: Union[str, bytes], hash_algo=_md5) -> str: return h.hexdigest() -def file_hash(path: str, hash_algo=_md5) -> str: +def file_hash(path: str, hash_algo: Callable[..., Any] = _md5) -> str: """Given a path to a file, produces a stable hash of the file contents. Args: @@ -50,7 +50,10 @@ def file_hash(path: str, hash_algo=_md5) -> str: def hash_objects( - *args, hash_algo=_md5, raise_on_failure: bool = False, **kwargs + *args: Any, + hash_algo: Callable[..., Any] = _md5, + raise_on_failure: bool = False, + **kwargs: Any, ) -> Optional[str]: """ Attempt to hash objects by dumping to JSON or serializing with cloudpickle. @@ -77,7 +80,7 @@ def hash_objects( json_error = str(e) try: - return stable_hash(cloudpickle.dumps((args, kwargs)), hash_algo=hash_algo) + return stable_hash(cloudpickle.dumps((args, kwargs)), hash_algo=hash_algo) # type: ignore[reportUnknownMemberType] except Exception as e: pickle_error = str(e) From cdf3892bd5708cd7bda8aeaf6f97b43b949a6665 Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Fri, 6 Dec 2024 15:17:58 -0600 Subject: [PATCH 2/4] add test --- tests/test_automations.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_automations.py b/tests/test_automations.py index c9efa5e45eb2..96781cf99e18 100644 --- a/tests/test_automations.py +++ b/tests/test_automations.py @@ -183,3 +183,13 @@ async def test_nonexistent_id_raises_value_error(): async def test_nonexistent_name_raises_value_error(): with pytest.raises(ValueError): await Automation.read(name="nonexistent_name") + + +async def test_disabled_automation_can_be_enabled( + prefect_client, automation: Automation +): + await automation.disable() + await automation.enable() + + updated_automation = await Automation.read(id=automation.id) + assert updated_automation.enabled is True From 7d066ed9dbc170d8bb43f01a324ab12ea1f0c037 Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Fri, 6 Dec 2024 15:29:33 -0600 Subject: [PATCH 3/4] use client normally --- src/prefect/automations.py | 108 ++++++++++++++++++------------------- 1 file changed, 53 insertions(+), 55 deletions(-) diff --git a/src/prefect/automations.py b/src/prefect/automations.py index 4f97f7dcac70..0d629fdc8f7e 100644 --- a/src/prefect/automations.py +++ b/src/prefect/automations.py @@ -1,10 +1,10 @@ -from typing import TYPE_CHECKING, Optional, Type, cast +from typing import TYPE_CHECKING, Optional, Type from uuid import UUID from pydantic import Field from typing_extensions import Self -from prefect.client.utilities import get_or_create_client +from prefect.client.orchestration import get_client from prefect.events.actions import ( CallWebhook, CancelFlowRun, @@ -42,7 +42,7 @@ from prefect.utilities.asyncutils import sync_compatible if TYPE_CHECKING: - from prefect.client.orchestration import PrefectClient + pass __all__ = [ "AutomationCore", @@ -78,11 +78,6 @@ ] -def _get_or_create_and_cast_client() -> tuple["PrefectClient", bool]: - client, inferred = get_or_create_client() - return cast("PrefectClient", client), inferred - - class Automation(AutomationCore): id: Optional[UUID] = Field(default=None, description="The ID of this automation") @@ -107,10 +102,10 @@ async def create(self: Self) -> Self: ) created_automation = auto_to_create.create() """ - client, _ = _get_or_create_and_cast_client() - automation = AutomationCore(**self.model_dump(exclude={"id"})) - self.id = await client.create_automation(automation=automation) - return self + async with get_client() as client: + automation = AutomationCore(**self.model_dump(exclude={"id"})) + self.id = await client.create_automation(automation=automation) + return self @sync_compatible async def update(self: Self): @@ -120,10 +115,11 @@ async def update(self: Self): auto.name = "new name" auto.update() """ - - client, _ = _get_or_create_and_cast_client() - automation = AutomationCore(**self.model_dump(exclude={"id", "owner_resource"})) - await client.update_automation(automation_id=self.id, automation=automation) + async with get_client() as client: + automation = AutomationCore( + **self.model_dump(exclude={"id", "owner_resource"}) + ) + await client.update_automation(automation_id=self.id, automation=automation) @classmethod @sync_compatible @@ -142,23 +138,25 @@ async def read( raise ValueError("Only one of id or name can be provided") if not id and not name: raise ValueError("One of id or name must be provided") - client, _ = _get_or_create_and_cast_client() - if id: - try: - automation = await client.read_automation(automation_id=id) - except PrefectHTTPStatusError as exc: - if exc.response.status_code == 404: + async with get_client() as client: + if id: + try: + automation = await client.read_automation(automation_id=id) + except PrefectHTTPStatusError as exc: + if exc.response.status_code == 404: + raise ValueError(f"Automation with ID {id!r} not found") + raise + if automation is None: raise ValueError(f"Automation with ID {id!r} not found") - raise - if automation is None: - raise ValueError(f"Automation with ID {id!r} not found") - return Automation(**automation.model_dump()) - else: - automation = await client.read_automations_by_name(name=name) - if len(automation) > 0: - return Automation(**automation[0].model_dump()) if automation else None + return Automation(**automation.model_dump()) else: - raise ValueError(f"Automation with name {name!r} not found") + automation = await client.read_automations_by_name(name=name) + if len(automation) > 0: + return ( + Automation(**automation[0].model_dump()) if automation else None + ) + else: + raise ValueError(f"Automation with name {name!r} not found") @sync_compatible async def delete(self: Self) -> bool: @@ -166,14 +164,14 @@ async def delete(self: Self) -> bool: auto = Automation.read(id = 123) auto.delete() """ - try: - client, _ = _get_or_create_and_cast_client() - await client.delete_automation(self.id) - return True - except PrefectHTTPStatusError as exc: - if exc.response.status_code == 404: - return False - raise + async with get_client() as client: + try: + await client.delete_automation(self.id) + return True + except PrefectHTTPStatusError as exc: + if exc.response.status_code == 404: + return False + raise @sync_compatible async def disable(self: Self) -> bool: @@ -182,14 +180,14 @@ async def disable(self: Self) -> bool: auto = Automation.read(id = 123) auto.disable() """ - try: - client, _ = _get_or_create_and_cast_client() - await client.pause_automation(self.id) - return True - except PrefectHTTPStatusError as exc: - if exc.response.status_code == 404: - return False - raise + async with get_client() as client: + try: + await client.pause_automation(self.id) + return True + except PrefectHTTPStatusError as exc: + if exc.response.status_code == 404: + return False + raise @sync_compatible async def enable(self: Self) -> bool: @@ -198,11 +196,11 @@ async def enable(self: Self) -> bool: auto = Automation.read(id = 123) auto.enable() """ - try: - client, _ = _get_or_create_and_cast_client() - await client.resume_automation(self.id) - return True - except PrefectHTTPStatusError as exc: - if exc.response.status_code == 404: - return False - raise + async with get_client() as client: + try: + await client.resume_automation(self.id) + return True + except PrefectHTTPStatusError as exc: + if exc.response.status_code == 404: + return False + raise From 85bed05093219b584015b646aedaacfebb9043fd Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Fri, 6 Dec 2024 15:44:37 -0600 Subject: [PATCH 4/4] rm random TYPE_CHECKING block --- src/prefect/automations.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/prefect/automations.py b/src/prefect/automations.py index 0d629fdc8f7e..a37c5a3a45dd 100644 --- a/src/prefect/automations.py +++ b/src/prefect/automations.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Type +from typing import Optional, Type from uuid import UUID from pydantic import Field @@ -41,9 +41,6 @@ from prefect.exceptions import PrefectHTTPStatusError from prefect.utilities.asyncutils import sync_compatible -if TYPE_CHECKING: - pass - __all__ = [ "AutomationCore", "EventTrigger",