Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fruitful typing scouring #16255

Merged
merged 4 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 53 additions & 47 deletions src/prefect/automations.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Optional
from typing import Optional, Type
from uuid import UUID

from pydantic import Field
from typing_extensions import Self

from prefect.client.utilities import get_or_create_client
from prefect.client.orchestration import get_client
from prefect.events.actions import (
CallWebhook,
CancelFlowRun,
Expand Down Expand Up @@ -99,10 +99,10 @@ async def create(self: Self) -> Self:
)
created_automation = auto_to_create.create()
"""
client, _ = get_or_create_client()
automation = AutomationCore(**self.model_dump(exclude={"id"}))
self.id = await client.create_automation(automation=automation)
return self
async with get_client() as client:
automation = AutomationCore(**self.model_dump(exclude={"id"}))
self.id = await client.create_automation(automation=automation)
return self

@sync_compatible
async def update(self: Self):
Expand All @@ -112,15 +112,16 @@ async def update(self: Self):
auto.name = "new name"
auto.update()
"""

client, _ = get_or_create_client()
automation = AutomationCore(**self.model_dump(exclude={"id", "owner_resource"}))
await client.update_automation(automation_id=self.id, automation=automation)
async with get_client() as client:
automation = AutomationCore(
**self.model_dump(exclude={"id", "owner_resource"})
)
await client.update_automation(automation_id=self.id, automation=automation)

@classmethod
@sync_compatible
async def read(
cls: Self, id: Optional[UUID] = None, name: Optional[str] = None
cls: Type[Self], id: Optional[UUID] = None, name: Optional[str] = None
) -> Self:
"""
Read an automation by ID or name.
Expand All @@ -134,35 +135,40 @@ async def read(
raise ValueError("Only one of id or name can be provided")
if not id and not name:
raise ValueError("One of id or name must be provided")
client, _ = get_or_create_client()
if id:
try:
automation = await client.read_automation(automation_id=id)
except PrefectHTTPStatusError as exc:
if exc.response.status_code == 404:
async with get_client() as client:
if id:
try:
automation = await client.read_automation(automation_id=id)
except PrefectHTTPStatusError as exc:
if exc.response.status_code == 404:
raise ValueError(f"Automation with ID {id!r} not found")
raise
if automation is None:
raise ValueError(f"Automation with ID {id!r} not found")
return Automation(**automation.model_dump())
else:
automation = await client.read_automations_by_name(name=name)
if len(automation) > 0:
return Automation(**automation[0].model_dump()) if automation else None
return Automation(**automation.model_dump())
else:
raise ValueError(f"Automation with name {name!r} not found")
automation = await client.read_automations_by_name(name=name)
if len(automation) > 0:
return (
Automation(**automation[0].model_dump()) if automation else None
)
else:
raise ValueError(f"Automation with name {name!r} not found")

@sync_compatible
async def delete(self: Self) -> bool:
"""
auto = Automation.read(id = 123)
auto.delete()
"""
try:
client, _ = get_or_create_client()
await client.delete_automation(self.id)
return True
except PrefectHTTPStatusError as exc:
if exc.response.status_code == 404:
return False
raise
async with get_client() as client:
try:
await client.delete_automation(self.id)
return True
except PrefectHTTPStatusError as exc:
if exc.response.status_code == 404:
return False
raise

@sync_compatible
async def disable(self: Self) -> bool:
Expand All @@ -171,14 +177,14 @@ async def disable(self: Self) -> bool:
auto = Automation.read(id = 123)
auto.disable()
"""
try:
client, _ = get_or_create_client()
await client.pause_automation(self.id)
return True
except PrefectHTTPStatusError as exc:
if exc.response.status_code == 404:
return False
raise
async with get_client() as client:
try:
await client.pause_automation(self.id)
return True
except PrefectHTTPStatusError as exc:
if exc.response.status_code == 404:
return False
raise

@sync_compatible
async def enable(self: Self) -> bool:
Expand All @@ -187,11 +193,11 @@ async def enable(self: Self) -> bool:
auto = Automation.read(id = 123)
auto.enable()
"""
try:
client, _ = get_or_create_client()
await client.resume_automation("asd")
return True
except PrefectHTTPStatusError as exc:
if exc.response.status_code == 404:
return False
raise
async with get_client() as client:
try:
await client.resume_automation(self.id)
return True
except PrefectHTTPStatusError as exc:
if exc.response.status_code == 404:
return False
raise
22 changes: 11 additions & 11 deletions src/prefect/cache_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ def compute_key(
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
**kwargs,
**kwargs: Any,
) -> Optional[str]:
raise NotImplementedError

def __sub__(self, other: str) -> "CachePolicy":
if not isinstance(other, str):
if not isinstance(other, str): # type: ignore[reportUnnecessaryIsInstance]
raise TypeError("Can only subtract strings from key policies.")
new = Inputs(exclude=[other])
return CompoundCachePolicy(policies=[self, new])
Expand Down Expand Up @@ -140,7 +140,7 @@ def compute_key(
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
**kwargs,
**kwargs: Any,
) -> Optional[str]:
if self.cache_key_fn:
return self.cache_key_fn(task_ctx, inputs)
Expand All @@ -162,9 +162,9 @@ def compute_key(
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
**kwargs,
**kwargs: Any,
) -> Optional[str]:
keys = []
keys: list[str] = []
for policy in self.policies:
policy_key = policy.compute_key(
task_ctx=task_ctx,
Expand All @@ -191,7 +191,7 @@ def compute_key(
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
**kwargs,
**kwargs: Any,
) -> Optional[str]:
return None

Expand All @@ -211,7 +211,7 @@ def compute_key(
task_ctx: TaskRunContext,
inputs: Optional[Dict[str, Any]],
flow_parameters: Optional[Dict[str, Any]],
**kwargs,
**kwargs: Any,
) -> Optional[str]:
if not task_ctx:
return None
Expand All @@ -238,7 +238,7 @@ def compute_key(
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
**kwargs,
**kwargs: Any,
) -> Optional[str]:
if not flow_parameters:
return None
Expand All @@ -257,7 +257,7 @@ def compute_key(
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
**kwargs,
**kwargs: Any,
) -> Optional[str]:
if not task_ctx:
return None
Expand All @@ -280,7 +280,7 @@ def compute_key(
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
**kwargs,
**kwargs: Any,
) -> Optional[str]:
hashed_inputs = {}
inputs = inputs or {}
Expand All @@ -307,7 +307,7 @@ def compute_key(
raise ValueError(msg) from exc

def __sub__(self, other: str) -> "CachePolicy":
if not isinstance(other, str):
if not isinstance(other, str): # type: ignore[reportUnnecessaryIsInstance]
raise TypeError("Can only subtract strings from key policies.")
return Inputs(exclude=self.exclude + [other])

Expand Down
Loading
Loading