Skip to content

Commit

Permalink
fix impl
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Dec 20, 2024
1 parent c534ec0 commit 887ce53
Showing 1 changed file with 26 additions and 23 deletions.
49 changes: 26 additions & 23 deletions src/prefect/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,10 @@ def _format_user_supplied_storage_key(key: str) -> str:


async def _call_explicitly_async_block_method(
block: WritableFileSystem, method: str, *args: Any, **kwargs: Any
block: Union[WritableFileSystem, NullFileSystem],
method: str,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> Any:
"""
TODO: remove this once we have explicit async methods on all storage blocks
Expand All @@ -243,8 +246,8 @@ async def _call_explicitly_async_block_method(
"""
if hasattr(block, f"a{method}"): # explicit async method
return await getattr(block, f"a{method}")(*args, **kwargs)
elif hasattr(hasattr(block, method), "aio"): # sync_compatible
return await getattr(block, method).aio(*args, **kwargs)
elif hasattr(getattr(block, method, None), "aio"): # sync_compatible
return await getattr(block, method).aio(block, *args, **kwargs)
else: # should not happen in prefect, but users can override impls
maybe_coro = getattr(block, method)(*args, **kwargs)
if inspect.isawaitable(maybe_coro):
Expand Down Expand Up @@ -426,9 +429,7 @@ async def _exists(self, key: str) -> bool:
# so the entire payload doesn't need to be read
try:
metadata_content = await _call_explicitly_async_block_method(
self.metadata_storage,
"read_path",
key,
self.metadata_storage, "read_path", (key,), {}
)
if metadata_content is None:
return False
Expand All @@ -439,9 +440,7 @@ async def _exists(self, key: str) -> bool:
else:
try:
content = await _call_explicitly_async_block_method(
self.result_storage,
"read_path",
key,
self.result_storage, "read_path", (key,), {}
)
if content is None:
return False
Expand Down Expand Up @@ -522,7 +521,8 @@ async def _read(self, key: str, holder: str) -> "ResultRecord[Any]":
metadata_content = await _call_explicitly_async_block_method(
self.metadata_storage,
"read_path",
key,
(key,),
{},
)
metadata = ResultRecordMetadata.load_bytes(metadata_content)
assert (
Expand All @@ -531,7 +531,8 @@ async def _read(self, key: str, holder: str) -> "ResultRecord[Any]":
result_content = await _call_explicitly_async_block_method(
self.result_storage,
"read_path",
metadata.storage_key,
(metadata.storage_key,),
{},
)
result_record: ResultRecord[
Any
Expand All @@ -543,7 +544,8 @@ async def _read(self, key: str, holder: str) -> "ResultRecord[Any]":
content = await _call_explicitly_async_block_method(
self.result_storage,
"read_path",
key,
(key,),
{},
)
result_record: ResultRecord[Any] = ResultRecord.deserialize(
content, backup_serializer=self.serializer
Expand Down Expand Up @@ -714,23 +716,23 @@ async def _persist_result_record(self, result_record: "ResultRecord", holder: st
await _call_explicitly_async_block_method(
self.result_storage,
"write_path",
result_record.metadata.storage_key,
content=result_record.serialize_result(),
(result_record.metadata.storage_key,),
{"content": result_record.serialize_result()},
)
await _call_explicitly_async_block_method(
self.metadata_storage,
"write_path",
base_key,
content=result_record.serialize_metadata(),
(base_key,),
{"content": result_record.serialize_metadata()},
)
await emit_result_write_event(self, result_record.metadata.storage_key)
# Otherwise, write the result metadata and result together
else:
await _call_explicitly_async_block_method(
self.result_storage,
"write_path",
result_record.metadata.storage_key,
content=result_record.serialize(),
(result_record.metadata.storage_key,),
{"content": result_record.serialize()},
)
await emit_result_write_event(self, result_record.metadata.storage_key)
if self.cache_result_in_memory:
Expand Down Expand Up @@ -960,8 +962,8 @@ async def store_parameters(self, identifier: UUID, parameters: Dict[str, Any]):
await _call_explicitly_async_block_method(
self.result_storage,
"write_path",
f"parameters/{identifier}",
content=record.serialize(),
(f"parameters/{identifier}",),
{"content": record.serialize()},
)

@sync_compatible
Expand All @@ -974,7 +976,8 @@ async def read_parameters(self, identifier: UUID) -> dict[str, Any]:
await _call_explicitly_async_block_method(
self.result_storage,
"read_path",
f"parameters/{identifier}",
(f"parameters/{identifier}",),
{},
)
)
return record.result
Expand Down Expand Up @@ -1030,7 +1033,7 @@ def load_bytes(cls, data: bytes) -> "ResultRecordMetadata":
"""
return cls.model_validate_json(data)

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if not isinstance(other, ResultRecordMetadata):
return False
return (
Expand Down Expand Up @@ -1104,7 +1107,7 @@ def serialize_result(self) -> bytes:

@model_validator(mode="before")
@classmethod
def coerce_old_format(cls, value: Any):
def coerce_old_format(cls, value: Any) -> Any:
if isinstance(value, dict):
if "data" in value:
value["result"] = value.pop("data")
Expand Down

0 comments on commit 887ce53

Please sign in to comment.