Skip to content

Commit

Permalink
use correct async method on storage blocks (#16445)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Dec 20, 2024
1 parent 6ee9561 commit c29969e
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 24 deletions.
95 changes: 76 additions & 19 deletions src/prefect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,29 @@ def _format_user_supplied_storage_key(key: str) -> str:
return key.format(**runtime_vars, parameters=prefect.runtime.task_run.parameters)


async def _call_explicitly_async_block_method(
    block: Union[WritableFileSystem, NullFileSystem],
    method: str,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Any:
    """
    Call `method` on a storage block from async code and return its result.

    The async implementation is resolved in this order:

    1. an explicit async sibling named `a<method>` on the block,
    2. the `.aio` attribute attached to `<method>` (sync_compatible methods),
    3. `<method>` itself, awaiting the return value only if it is awaitable.

    Args:
        block: The storage block to call the method on.
        method: Base name of the method, e.g. `"read_path"`.
        args: Positional arguments forwarded to the method.
        kwargs: Keyword arguments forwarded to the method.

    Returns:
        Whatever the resolved method returns (awaited when applicable).

    TODO: remove this once we have explicit async methods on all storage blocks

    see https://github.com/PrefectHQ/prefect/issues/15008
    """
    async_name = f"a{method}"
    if hasattr(block, async_name):  # explicit async method
        # Resolve on the block instance itself so we get a bound method.
        # (Looking it up on `block.__class__.__name__` would inspect a `str`
        # and raise AttributeError.)
        return await getattr(block, async_name)(*args, **kwargs)

    if hasattr(getattr(block, method, None), "aio"):  # sync_compatible
        # `.aio` is called with the block passed explicitly as first argument.
        return await getattr(block, method).aio(block, *args, **kwargs)

    # should not happen in prefect, but users can override impls
    maybe_coro = getattr(block, method)(*args, **kwargs)
    if inspect.isawaitable(maybe_coro):
        return await maybe_coro
    return maybe_coro


T = TypeVar("T")


Expand Down Expand Up @@ -405,7 +428,9 @@ async def _exists(self, key: str) -> bool:
# TODO: Add an `exists` method to commonly used storage blocks
# so the entire payload doesn't need to be read
try:
metadata_content = await self.metadata_storage.read_path(key)
metadata_content = await _call_explicitly_async_block_method(
self.metadata_storage, "read_path", (key,), {}
)
if metadata_content is None:
return False
metadata = ResultRecordMetadata.load_bytes(metadata_content)
Expand All @@ -414,7 +439,9 @@ async def _exists(self, key: str) -> bool:
return False
else:
try:
content = await self.result_storage.read_path(key)
content = await _call_explicitly_async_block_method(
self.result_storage, "read_path", (key,), {}
)
if content is None:
return False
record = ResultRecord.deserialize(content)
Expand Down Expand Up @@ -491,20 +518,35 @@ async def _read(self, key: str, holder: str) -> "ResultRecord[Any]":
self.result_storage = await get_default_result_storage()

if self.metadata_storage is not None:
metadata_content = await self.metadata_storage.read_path(key)
metadata_content = await _call_explicitly_async_block_method(
self.metadata_storage,
"read_path",
(key,),
{},
)
metadata = ResultRecordMetadata.load_bytes(metadata_content)
assert (
metadata.storage_key is not None
), "Did not find storage key in metadata"
result_content = await self.result_storage.read_path(metadata.storage_key)
result_content = await _call_explicitly_async_block_method(
self.result_storage,
"read_path",
(metadata.storage_key,),
{},
)
result_record: ResultRecord[
Any
] = ResultRecord.deserialize_from_result_and_metadata(
result=result_content, metadata=metadata_content
)
await emit_result_read_event(self, resolved_key_path)
else:
content = await self.result_storage.read_path(key)
content = await _call_explicitly_async_block_method(
self.result_storage,
"read_path",
(key,),
{},
)
result_record: ResultRecord[Any] = ResultRecord.deserialize(
content, backup_serializer=self.serializer
)
Expand Down Expand Up @@ -555,7 +597,7 @@ def create_result_record(
obj: Any,
key: Optional[str] = None,
expiration: Optional[DateTime] = None,
) -> "ResultRecord":
) -> "ResultRecord[Any]":
"""
Create a result record.
Expand Down Expand Up @@ -671,19 +713,26 @@ async def _persist_result_record(self, result_record: "ResultRecord", holder: st

# If metadata storage is configured, write result and metadata separately
if self.metadata_storage is not None:
await self.result_storage.write_path(
result_record.metadata.storage_key,
content=result_record.serialize_result(),
await _call_explicitly_async_block_method(
self.result_storage,
"write_path",
(result_record.metadata.storage_key,),
{"content": result_record.serialize_result()},
)
await self.metadata_storage.write_path(
base_key,
content=result_record.serialize_metadata(),
await _call_explicitly_async_block_method(
self.metadata_storage,
"write_path",
(base_key,),
{"content": result_record.serialize_metadata()},
)
await emit_result_write_event(self, result_record.metadata.storage_key)
# Otherwise, write the result metadata and result together
else:
await self.result_storage.write_path(
result_record.metadata.storage_key, content=result_record.serialize()
await _call_explicitly_async_block_method(
self.result_storage,
"write_path",
(result_record.metadata.storage_key,),
{"content": result_record.serialize()},
)
await emit_result_write_event(self, result_record.metadata.storage_key)
if self.cache_result_in_memory:
Expand Down Expand Up @@ -910,8 +959,11 @@ async def store_parameters(self, identifier: UUID, parameters: Dict[str, Any]):
serializer=self.serializer, storage_key=str(identifier)
),
)
await self.result_storage.write_path(
f"parameters/{identifier}", content=record.serialize()
await _call_explicitly_async_block_method(
self.result_storage,
"write_path",
(f"parameters/{identifier}",),
{"content": record.serialize()},
)

@sync_compatible
Expand All @@ -921,7 +973,12 @@ async def read_parameters(self, identifier: UUID) -> dict[str, Any]:
"Result store is not configured - must have a result storage block to read parameters"
)
record = ResultRecord.deserialize(
await self.result_storage.read_path(f"parameters/{identifier}")
await _call_explicitly_async_block_method(
self.result_storage,
"read_path",
(f"parameters/{identifier}",),
{},
)
)
return record.result

Expand Down Expand Up @@ -976,7 +1033,7 @@ def load_bytes(cls, data: bytes) -> "ResultRecordMetadata":
"""
return cls.model_validate_json(data)

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if not isinstance(other, ResultRecordMetadata):
return False
return (
Expand Down Expand Up @@ -1050,7 +1107,7 @@ def serialize_result(self) -> bytes:

@model_validator(mode="before")
@classmethod
def coerce_old_format(cls, value: Any):
def coerce_old_format(cls, value: Any) -> Any:
if isinstance(value, dict):
if "data" in value:
value["result"] = value.pop("data")
Expand Down
10 changes: 5 additions & 5 deletions tests/results/test_state_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ def store(storage_block: WritableFileSystem) -> ResultStore:


@pytest.fixture
async def a_real_result(store) -> ResultRecord:
async def a_real_result(store: ResultStore) -> ResultRecord[str]:
return store.create_result_record(
"test-graceful-retry",
)


@pytest.fixture
def completed_state(a_real_result: ResultRecord) -> State[str]:
def completed_state(a_real_result: ResultRecord[str]) -> State[str]:
return State(type=StateType.COMPLETED, data=a_real_result.metadata)


Expand Down Expand Up @@ -107,7 +107,7 @@ async def test_graceful_retries_reraise_last_error_while_retrieving_missing_resu
now = time.monotonic()
with pytest.raises(FileNotFoundError):
with mock.patch(
"prefect.filesystems.LocalFileSystem.read_path",
"prefect.filesystems.LocalFileSystem.read_path.aio",
new=mock.AsyncMock(
side_effect=[
OSError,
Expand All @@ -129,7 +129,7 @@ async def test_graceful_retries_reraise_last_error_while_retrieving_missing_resu

async def test_graceful_retries_eventually_succeed_while(
shorter_result_retries: None,
a_real_result: ResultRecord,
a_real_result: ResultRecord[str],
completed_state: State[str],
store: ResultStore,
):
Expand All @@ -147,7 +147,7 @@ async def test_graceful_retries_eventually_succeed_while(
# even if it misses a couple times, it will eventually return the data
now = time.monotonic()
with mock.patch(
"prefect.filesystems.LocalFileSystem.read_path",
"prefect.filesystems.LocalFileSystem.read_path.aio",
new=mock.AsyncMock(
side_effect=[
FileNotFoundError,
Expand Down

0 comments on commit c29969e

Please sign in to comment.