feat(datasets): Add option to async load and save in PartitionedDatasets #696

Open · wants to merge 7 commits into main
2 changes: 2 additions & 0 deletions kedro-datasets/RELEASE.md
@@ -1,5 +1,6 @@
# Upcoming Release
## Major features and improvements
* Added async functionality for loading and saving data in `PartitionedDataset` via the `use_async` argument.

## Bug fixes and other changes
* Removed arbitrary upper bound for `s3fs`.
@@ -8,6 +9,7 @@
## Community contributions
Many thanks to the following Kedroids for contributing PRs to this release:
* [Charles Guan](https://github.com/charlesbmi)
* [Puneet Saini](https://github.com/puneeter)


# Release 3.0.0
5 changes: 1 addition & 4 deletions kedro-datasets/docs/source/conf.py
@@ -220,10 +220,7 @@
todo_include_todos = False

# -- Kedro specific configuration -----------------------------------------
KEDRO_MODULES = [
    "kedro_datasets",
    "kedro_datasets_experimental"
]
KEDRO_MODULES = ["kedro_datasets", "kedro_datasets_experimental"]


def get_classes(module):
66 changes: 66 additions & 0 deletions kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py
@@ -4,6 +4,7 @@

from __future__ import annotations

import asyncio
import operator
from copy import deepcopy
from pathlib import PurePosixPath
@@ -152,6 +153,7 @@ def __init__(  # noqa: PLR0913
        fs_args: dict[str, Any] | None = None,
        overwrite: bool = False,
        metadata: dict[str, Any] | None = None,
        use_async: bool = False,
    ) -> None:
        """Creates a new instance of ``PartitionedDataset``.

@@ -192,6 +194,8 @@ def __init__(  # noqa: PLR0913
            overwrite: If True, any existing partitions will be removed.
            metadata: Any arbitrary metadata.
                This is ignored by Kedro, but may be consumed by users or external plugins.
            use_async: If True, the dataset will be loaded and saved asynchronously.
                Defaults to False.

        Raises:
            DatasetError: If versioning is enabled for the underlying dataset.
@@ -206,6 +210,7 @@ def __init__(  # noqa: PLR0913
        self._protocol = infer_storage_options(self._path)["protocol"]
        self._partition_cache: Cache = Cache(maxsize=1)
        self.metadata = metadata
        self._use_async = use_async

        dataset = dataset if isinstance(dataset, dict) else {"type": dataset}
        self._dataset_type, self._dataset_config = parse_dataset_definition(dataset)
@@ -285,6 +290,12 @@ def _path_to_partition(self, path: str) -> str:
        return path

    def _load(self) -> dict[str, Callable[[], Any]]:
        if self._use_async:
            return asyncio.run(self._async_load())
        else:
            return self._sync_load()

    def _sync_load(self) -> dict[str, Callable[[], Any]]:
        partitions = {}

        for partition in self._list_partitions():
@@ -300,7 +311,32 @@ def _load(self) -> dict[str, Callable[[], Any]]:

        return partitions

    async def _async_load(self) -> dict[str, Callable[[], Any]]:
        partitions = {}

        async def load_partition(partition: str) -> None:
            kwargs = deepcopy(self._dataset_config)
            kwargs[self._filepath_arg] = self._join_protocol(partition)
            dataset = self._dataset_type(**kwargs)  # type: ignore
            partition_id = self._path_to_partition(partition)
            partitions[partition_id] = dataset.load

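        # Each gathered task only builds the per-partition dataset and stores its bound
        # ``load`` method; no I/O happens here, the read occurs when the caller invokes it.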
        await asyncio.gather(
            *[load_partition(partition) for partition in self._list_partitions()]
        )
Comment on lines +317 to +326

Member:

If I understand correctly, there's no actual I/O being performed here, right? Only the partitions dictionary is being populated.

I don't see the need for async helpers and asyncio.gather here.

If anything, as a user I'd expect to have the async loaders available in my node function so that I can await them (provided that my node is asynchronous), use asyncio.gather myself, or use an asyncio.TaskGroup.

my_partitioned_dataset:
  type: partitions.PartitionedDataset
  path: s3://my-bucket-name/path/to/folder
  ...
  use_async: True
async def concat_partitions(partitioned_input: dict[str, Awaitable]) -> pd.DataFrame:
    tasks = []
    async with asyncio.TaskGroup() as tg:
        for partition_key, partition_load_func in sorted(partitioned_input.items()):
            tasks.append(tg.create_task(partition_load_func()))

    result = pd.DataFrame()
    result = pd.concat([result] + [task.result() for task in tasks], ignore_index=True, sort=True)
    return result

(not that I find this a particularly friendly DX, but it's more or less a continuation of our current approach https://docs.kedro.org/en/stable/data/partitioned_and_incremental_datasets.html#partitioned-dataset-load)

What am I missing?

Author:

That makes sense. I am open to both options. Let me know if you want to revert the load method to the original definition. Happy to also update the documentation once we are aligned on the changes.

Contributor (@noklam, Jun 27, 2024):

IIRC the original question was about using the async option of the Runner, and we found that the partitioned dataset only does async at the whole-dataset level, which is not efficient.

I think we need to think about this separately for save and load.

For load, the logic is actually implemented in the node. Can we already do this today with the async node @astrojuanlu showed? If so, it seems we don't need to change anything for load in this PR.

Save is where we actually need changes for the partitioned dataset, especially lazy saving. I think it is reasonable to use async by default for save. This is not possible today because of how we list partitions and save them in a sync loop. We can only do async at the whole partitioned-dataset level, not for the underlying datasets (using the runner's is_async).
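
For concreteness, a minimal sketch of the lazy-saving pattern being discussed, with illustrative partition names: the node returns a dict of callables and `PartitionedDataset` evaluates each one only inside its save loop, which is exactly where per-partition async saving would apply.

from collections.abc import Callable

import pandas as pd


def create_partitions() -> dict[str, Callable[[], pd.DataFrame]]:
    # Nothing is materialised here; PartitionedDataset calls each value at save time.
    return {
        "part_a": lambda: pd.DataFrame({"x": [1, 2]}),
        "part_b": lambda: pd.DataFrame({"x": [3, 4]}),
    }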

Contributor:

Is there any reason why we prefer doing this at the dataset level rather than in the runner? It seems like a common approach at the layer above is needed anyway to make it efficient.

Member:

For now I think this is to achieve consistency with the synchronous PartitionedDataset. I'm not sure what you have in mind for runners, but maybe we should discuss that separately? Unless you still see issues with the proposed approach.


        if not partitions:
            raise DatasetError(f"No partitions found in '{self._path}'")

        return partitions

    def _save(self, data: dict[str, Any]) -> None:
        if self._use_async:
            asyncio.run(self._async_save(data))
        else:
            self._sync_save(data)

    def _sync_save(self, data: dict[str, Any]) -> None:
        if self._overwrite and self._filesystem.exists(self._normalized_path):
            self._filesystem.rm(self._normalized_path, recursive=True)

@@ -315,6 +351,36 @@ def _save(self, data: dict[str, Any]) -> None:
            dataset.save(partition_data)
        self._invalidate_caches()

    async def _async_save(self, data: dict[str, Any]) -> None:
        if self._overwrite and await self._filesystem_exists(self._normalized_path):
            await self._filesystem_rm(self._normalized_path, recursive=True)

        async def save_partition(partition_id: str, partition_data: Any) -> None:
            kwargs = deepcopy(self._dataset_config)
            partition = self._partition_to_path(partition_id)
            kwargs[self._filepath_arg] = self._join_protocol(partition)
            dataset = self._dataset_type(**kwargs)  # type: ignore
            if callable(partition_data):
                partition_data = partition_data()  # noqa: PLW2901
            await self._dataset_save(dataset, partition_data)

        await asyncio.gather(
            *[
                save_partition(partition_id, partition_data)
                for partition_id, partition_data in sorted(data.items())
            ]
        )
        self._invalidate_caches()

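    # Thin async wrappers around the synchronous fsspec and dataset calls so they can be
    # awaited from the async save path; the underlying operations themselves still block.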
    async def _filesystem_exists(self, path: str) -> bool:
        return self._filesystem.exists(path)

    async def _filesystem_rm(self, path: str, recursive: bool) -> None:
        self._filesystem.rm(path, recursive=recursive)

    async def _dataset_save(self, dataset: AbstractDataset, data: Any) -> None:
        dataset.save(data)

    def _describe(self) -> dict[str, Any]:
        clean_dataset_config = (
            {k: v for k, v in self._dataset_config.items() if k != CREDENTIALS_KEY}