diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md
index 3a41e9e33..6d6cf7bc9 100755
--- a/kedro-datasets/RELEASE.md
+++ b/kedro-datasets/RELEASE.md
@@ -1,5 +1,6 @@
 # Upcoming Release
 ## Major features and improvements
+* Added async functionality for loading and saving data in `PartitionedDataset` via `use_async` argument.
 
 ## Bug fixes and other changes
 * Removed arbitrary upper bound for `s3fs`.
@@ -8,6 +9,7 @@
 ## Community contributions
 Many thanks to the following Kedroids for contributing PRs to this release:
 * [Charles Guan](https://github.com/charlesbmi)
+* [Puneet Saini](https://github.com/puneeter)
 
 # Release 3.0.0
diff --git a/kedro-datasets/docs/source/conf.py b/kedro-datasets/docs/source/conf.py
index aaade1aca..39446a335 100644
--- a/kedro-datasets/docs/source/conf.py
+++ b/kedro-datasets/docs/source/conf.py
@@ -220,10 +220,7 @@ todo_include_todos = False
 
 # -- Kedro specific configuration -----------------------------------------
 
-KEDRO_MODULES = [
-    "kedro_datasets",
-    "kedro_datasets_experimental"
-]
+KEDRO_MODULES = ["kedro_datasets", "kedro_datasets_experimental"]
 
 
 def get_classes(module):
diff --git a/kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py b/kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py
index 818ee97b8..fb939b5ce 100644
--- a/kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py
+++ b/kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py
@@ -4,6 +4,7 @@
 
 from __future__ import annotations
 
+import asyncio
 import operator
 from copy import deepcopy
 from pathlib import PurePosixPath
@@ -152,6 +153,7 @@ def __init__(  # noqa: PLR0913
         fs_args: dict[str, Any] | None = None,
         overwrite: bool = False,
         metadata: dict[str, Any] | None = None,
+        use_async: bool = False,
     ) -> None:
         """Creates a new instance of ``PartitionedDataset``.
 
@@ -192,6 +194,8 @@ def __init__(  # noqa: PLR0913
             overwrite: If True, any existing partitions will be removed.
             metadata: Any arbitrary metadata. This is ignored by Kedro, but may be
                 consumed by users or external plugins.
+            use_async: If True, the dataset will be loaded and saved asynchronously.
+                Defaults to False.
 
         Raises:
             DatasetError: If versioning is enabled for the underlying dataset.
@@ -206,6 +210,7 @@ def __init__(  # noqa: PLR0913
         self._protocol = infer_storage_options(self._path)["protocol"]
         self._partition_cache: Cache = Cache(maxsize=1)
         self.metadata = metadata
+        self._use_async = use_async
 
         dataset = dataset if isinstance(dataset, dict) else {"type": dataset}
         self._dataset_type, self._dataset_config = parse_dataset_definition(dataset)
@@ -285,6 +290,12 @@ def _path_to_partition(self, path: str) -> str:
         return path
 
     def _load(self) -> dict[str, Callable[[], Any]]:
+        if self._use_async:
+            return asyncio.run(self._async_load())
+        else:
+            return self._sync_load()
+
+    def _sync_load(self) -> dict[str, Callable[[], Any]]:
         partitions = {}
 
         for partition in self._list_partitions():
@@ -300,7 +311,32 @@ def _load(self) -> dict[str, Callable[[], Any]]:
 
         return partitions
 
+    async def _async_load(self) -> dict[str, Callable[[], Any]]:
+        partitions = {}
+
+        async def load_partition(partition: str) -> None:
+            kwargs = deepcopy(self._dataset_config)
+            kwargs[self._filepath_arg] = self._join_protocol(partition)
+            dataset = self._dataset_type(**kwargs)  # type: ignore
+            partition_id = self._path_to_partition(partition)
+            partitions[partition_id] = dataset.load
+
+        await asyncio.gather(
+            *[load_partition(partition) for partition in self._list_partitions()]
+        )
+
+        if not partitions:
+            raise DatasetError(f"No partitions found in '{self._path}'")
+
+        return partitions
+
     def _save(self, data: dict[str, Any]) -> None:
+        if self._use_async:
+            asyncio.run(self._async_save(data))
+        else:
+            self._sync_save(data)
+
+    def _sync_save(self, data: dict[str, Any]) -> None:
         if self._overwrite and self._filesystem.exists(self._normalized_path):
             self._filesystem.rm(self._normalized_path, recursive=True)
 
@@ -315,6 +351,36 @@ def _save(self, data: dict[str, Any]) -> None:
             dataset.save(partition_data)
         self._invalidate_caches()
 
+    async def _async_save(self, data: dict[str, Any]) -> None:
+        if self._overwrite and await self._filesystem_exists(self._normalized_path):
+            await self._filesystem_rm(self._normalized_path, recursive=True)
+
+        async def save_partition(partition_id: str, partition_data: Any) -> None:
+            kwargs = deepcopy(self._dataset_config)
+            partition = self._partition_to_path(partition_id)
+            kwargs[self._filepath_arg] = self._join_protocol(partition)
+            dataset = self._dataset_type(**kwargs)  # type: ignore
+            if callable(partition_data):
+                partition_data = partition_data()  # noqa: PLW2901
+            await self._dataset_save(dataset, partition_data)
+
+        await asyncio.gather(
+            *[
+                save_partition(partition_id, partition_data)
+                for partition_id, partition_data in sorted(data.items())
+            ]
+        )
+        self._invalidate_caches()
+
+    async def _filesystem_exists(self, path: str) -> bool:
+        return self._filesystem.exists(path)
+
+    async def _filesystem_rm(self, path: str, recursive: bool) -> None:
+        self._filesystem.rm(path, recursive=recursive)
+
+    async def _dataset_save(self, dataset: AbstractDataset, data: Any) -> None:
+        dataset.save(data)
+
     def _describe(self) -> dict[str, Any]:
         clean_dataset_config = (
             {k: v for k, v in self._dataset_config.items() if k != CREDENTIALS_KEY}
diff --git a/kedro-datasets/tests/partitions/test_partitioned_dataset.py b/kedro-datasets/tests/partitions/test_partitioned_dataset.py
index 4dc70881a..4af124787 100644
--- a/kedro-datasets/tests/partitions/test_partitioned_dataset.py
+++ b/kedro-datasets/tests/partitions/test_partitioned_dataset.py
@@ -61,11 +61,21 @@ class TestPartitionedDatasetLocal:
     @pytest.mark.parametrize(
         "suffix,expected_num_parts", [("", 5), (".csv", 3), ("p4", 1)]
     )
+    @pytest.mark.parametrize("use_async", [True, False])
     def test_load(
-        self, dataset, local_csvs, partitioned_data_pandas, suffix, expected_num_parts
+        self,
+        dataset,
+        local_csvs,
+        partitioned_data_pandas,
+        suffix,
+        expected_num_parts,
+        use_async,
     ):
         pds = PartitionedDataset(
-            path=str(local_csvs), dataset=dataset, filename_suffix=suffix
+            path=str(local_csvs),
+            dataset=dataset,
+            filename_suffix=suffix,
+            use_async=use_async,
         )
         loaded_partitions = pds.load()
 
@@ -78,9 +88,13 @@ def test_load(
 
     @pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION)
     @pytest.mark.parametrize("suffix", ["", ".csv"])
-    def test_save(self, dataset, local_csvs, suffix):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_save(self, dataset, local_csvs, suffix, use_async):
         pds = PartitionedDataset(
-            path=str(local_csvs), dataset=dataset, filename_suffix=suffix
+            path=str(local_csvs),
+            dataset=dataset,
+            filename_suffix=suffix,
+            use_async=use_async,
         )
         original_data = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]})
         part_id = "new/data"
@@ -94,9 +108,13 @@ def test_save(self, dataset, local_csvs, suffix):
 
     @pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION)
     @pytest.mark.parametrize("suffix", ["", ".csv"])
-    def test_lazy_save(self, dataset, local_csvs, suffix):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_lazy_save(self, dataset, local_csvs, suffix, use_async):
         pds = PartitionedDataset(
-            path=str(local_csvs), dataset=dataset, filename_suffix=suffix
+            path=str(local_csvs),
+            dataset=dataset,
+            filename_suffix=suffix,
+            use_async=use_async,
         )
 
         def original_data():
@@ -111,9 +129,12 @@ def original_data():
         reloaded_data = loaded_partitions[part_id]()
         assert_frame_equal(reloaded_data, original_data())
 
-    def test_save_invalidates_cache(self, local_csvs, mocker):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_save_invalidates_cache(self, local_csvs, mocker, use_async):
         """Test that save calls invalidate partition cache"""
-        pds = PartitionedDataset(path=str(local_csvs), dataset="pandas.CSVDataset")
+        pds = PartitionedDataset(
+            path=str(local_csvs), dataset="pandas.CSVDataset", use_async=use_async
+        )
         mocked_fs_invalidate = mocker.patch.object(pds._filesystem, "invalidate_cache")
         first_load = pds.load()
         assert pds._partition_cache.currsize == 1
@@ -135,9 +156,13 @@ def test_save_invalidates_cache(self, local_csvs, mocker):
         assert new_partition in second_load
 
     @pytest.mark.parametrize("overwrite,expected_num_parts", [(False, 6), (True, 1)])
-    def test_overwrite(self, local_csvs, overwrite, expected_num_parts):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_overwrite(self, local_csvs, overwrite, expected_num_parts, use_async):
         pds = PartitionedDataset(
-            path=str(local_csvs), dataset="pandas.CSVDataset", overwrite=overwrite
+            path=str(local_csvs),
+            dataset="pandas.CSVDataset",
+            overwrite=overwrite,
+            use_async=use_async,
         )
         original_data = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]})
         part_id = "new/data"
@@ -147,11 +172,16 @@ def test_overwrite(self, local_csvs, overwrite, expected_num_parts):
         assert part_id in loaded_partitions
         assert len(loaded_partitions.keys()) == expected_num_parts
 
-    def test_release_instance_cache(self, local_csvs):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_release_instance_cache(self, local_csvs, use_async):
         """Test that cache invalidation does not affect other instances"""
-        ds_a = PartitionedDataset(path=str(local_csvs), dataset="pandas.CSVDataset")
+        ds_a = PartitionedDataset(
+            path=str(local_csvs), dataset="pandas.CSVDataset", use_async=use_async
+        )
         ds_a.load()
-        ds_b = PartitionedDataset(path=str(local_csvs), dataset="pandas.CSVDataset")
+        ds_b = PartitionedDataset(
+            path=str(local_csvs), dataset="pandas.CSVDataset", use_async=use_async
+        )
         ds_b.load()
 
         assert ds_a._partition_cache.currsize == 1
@@ -164,18 +194,28 @@ def test_release_instance_cache(self, local_csvs):
         assert ds_b._partition_cache.currsize == 1
 
     @pytest.mark.parametrize("dataset", ["pandas.CSVDataset", "pandas.ParquetDataset"])
-    def test_exists(self, local_csvs, dataset):
-        assert PartitionedDataset(path=str(local_csvs), dataset=dataset).exists()
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_exists(self, local_csvs, dataset, use_async):
+        assert PartitionedDataset(
+            path=str(local_csvs), dataset=dataset, use_async=use_async
+        ).exists()
 
         empty_folder = local_csvs / "empty" / "folder"
-        assert not PartitionedDataset(path=str(empty_folder), dataset=dataset).exists()
+        assert not PartitionedDataset(
+            path=str(empty_folder), dataset=dataset, use_async=use_async
+        ).exists()
 
         empty_folder.mkdir(parents=True)
-        assert not PartitionedDataset(path=str(empty_folder), dataset=dataset).exists()
+        assert not PartitionedDataset(
+            path=str(empty_folder), dataset=dataset, use_async=use_async
+        ).exists()
 
     @pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION)
-    def test_release(self, dataset, local_csvs):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_release(self, dataset, local_csvs, use_async):
         partition_to_remove = "p2.csv"
-        pds = PartitionedDataset(path=str(local_csvs), dataset=dataset)
+        pds = PartitionedDataset(
+            path=str(local_csvs), dataset=dataset, use_async=use_async
+        )
         initial_load = pds.load()
         assert partition_to_remove in initial_load
@@ -188,15 +228,17 @@ def test_invalid_dataset(self, dataset, local_csvs):
         assert initial_load.keys() ^ load_after_release.keys() == {partition_to_remove}
 
     @pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION)
-    def test_describe(self, dataset):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_describe(self, dataset, use_async):
         path = str(Path.cwd())
-        pds = PartitionedDataset(path=path, dataset=dataset)
+        pds = PartitionedDataset(path=path, dataset=dataset, use_async=use_async)
 
         assert f"path={path}" in str(pds)
         assert "dataset_type=CSVDataset" in str(pds)
         assert "dataset_config" in str(pds)
 
-    def test_load_args(self, mocker):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_load_args(self, mocker, use_async):
         fake_partition_name = "fake_partition"
         mocked_filesystem = mocker.patch("fsspec.filesystem")
         mocked_find = mocked_filesystem.return_value.find
@@ -205,7 +247,10 @@ def test_load_args(self, mocker):
         path = str(Path.cwd())
         load_args = {"maxdepth": 42, "withdirs": True}
         pds = PartitionedDataset(
-            path=path, dataset="pandas.CSVDataset", load_args=load_args
+            path=path,
+            dataset="pandas.CSVDataset",
+            load_args=load_args,
+            use_async=use_async,
         )
         mocker.patch.object(pds, "_path_to_partition", return_value=fake_partition_name)
 
@@ -216,13 +261,17 @@ def test_load_args(self, mocker):
         "credentials,expected_pds_creds,expected_dataset_creds",
         [({"cred": "common"}, {"cred": "common"}, {"cred": "common"}), (None, {}, {})],
     )
+    @pytest.mark.parametrize("use_async", [True, False])
     def test_credentials(
-        self, mocker, credentials, expected_pds_creds, expected_dataset_creds
+        self, mocker, credentials, expected_pds_creds, expected_dataset_creds, use_async
     ):
         mocked_filesystem = mocker.patch("fsspec.filesystem")
         path = str(Path.cwd())
         pds = PartitionedDataset(
-            path=path, dataset="pandas.CSVDataset", credentials=credentials
+            path=path,
+            dataset="pandas.CSVDataset",
+            credentials=credentials,
+            use_async=use_async,
         )
 
         assert mocked_filesystem.call_count == 2
@@ -244,13 +293,14 @@ def _assert_not_in_repr(value):
 
         _assert_not_in_repr(credentials)
 
-    def test_fs_args(self, mocker):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_fs_args(self, mocker, use_async):
         fs_args = {"foo": "bar"}
 
         mocked_filesystem = mocker.patch("fsspec.filesystem")
         path = str(Path.cwd())
         pds = PartitionedDataset(
-            path=path, dataset="pandas.CSVDataset", fs_args=fs_args
+            path=path, dataset="pandas.CSVDataset", fs_args=fs_args, use_async=use_async
         )
 
         assert mocked_filesystem.call_count == 2
@@ -258,8 +308,11 @@ def test_fs_args(self, mocker):
         assert pds._dataset_config["fs_args"] == fs_args
 
     @pytest.mark.parametrize("dataset", ["pandas.ParquetDataset", ParquetDataset])
-    def test_invalid_dataset(self, dataset, local_csvs):
-        pds = PartitionedDataset(path=str(local_csvs), dataset=dataset)
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_invalid_dataset(self, dataset, local_csvs, use_async):
+        pds = PartitionedDataset(
+            path=str(local_csvs), dataset=dataset, use_async=use_async
+        )
         loaded_partitions = pds.load()
 
         for partition, df_loader in loaded_partitions.items():
@@ -289,9 +342,12 @@ def test_invalid_dataset(self, dataset, local_csvs):
             ({}, "'type' is missing from dataset catalog configuration"),
         ],
     )
-    def test_invalid_dataset_config(self, dataset_config, error_pattern):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_invalid_dataset_config(self, dataset_config, error_pattern, use_async):
         with pytest.raises(DatasetError, match=error_pattern):
-            PartitionedDataset(path=str(Path.cwd()), dataset=dataset_config)
+            PartitionedDataset(
+                path=str(Path.cwd()), dataset=dataset_config, use_async=use_async
+            )
 
     @pytest.mark.parametrize(
         "dataset_config",
@@ -304,6 +360,7 @@ def test_invalid_dataset_config(self, dataset_config, error_pattern):
     @pytest.mark.parametrize(
         "suffix,expected_num_parts", [("", 5), (".csv", 3), ("p4", 1)]
     )
+    @pytest.mark.parametrize("use_async", [True, False])
     def test_versioned_dataset_save_and_load(
         self,
         mocker,
@@ -312,6 +369,7 @@ def test_versioned_dataset_save_and_load(
         suffix,
         expected_num_parts,
         partitioned_data_pandas,
+        use_async,
     ):
         """Test that saved and reloaded data matches the original one for
         the versioned dataset."""
@@ -319,13 +377,16 @@
         mock_ts = mocker.patch(
             "kedro.io.core.generate_timestamp", return_value=save_version
         )
-        PartitionedDataset(path=filepath_csvs, dataset=dataset_config).save(
-            partitioned_data_pandas
-        )
+        PartitionedDataset(
+            path=filepath_csvs, dataset=dataset_config, use_async=use_async
+        ).save(partitioned_data_pandas)
         mock_ts.assert_called_once()
 
         pds = PartitionedDataset(
-            path=filepath_csvs, dataset=dataset_config, filename_suffix=suffix
+            path=filepath_csvs,
+            dataset=dataset_config,
+            filename_suffix=suffix,
+            use_async=use_async,
         )
         loaded_partitions = pds.load()
 
@@ -343,7 +404,8 @@ def test_versioned_dataset_save_and_load(
         # all partitions were saved using the same version string
         assert actual_save_versions == {save_version}
 
-    def test_malformed_versioned_path(self, tmp_path):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_malformed_versioned_path(self, tmp_path, use_async):
         local_dir = tmp_path / "files"
"files" local_dir.mkdir() @@ -354,6 +416,7 @@ def test_malformed_versioned_path(self, tmp_path): pds = PartitionedDataset( path=str(local_dir / "path/to/folder"), dataset={"type": "pandas.CSVDataset", "versioned": True}, + use_async=use_async, ) pattern = re.escape( @@ -363,8 +426,11 @@ def test_malformed_versioned_path(self, tmp_path): with pytest.raises(DatasetError, match=pattern): pds.load() - def test_no_partitions(self, tmpdir): - pds = PartitionedDataset(path=str(tmpdir), dataset="pandas.CSVDataset") + @pytest.mark.parametrize("use_async", [True, False]) + def test_no_partitions(self, tmpdir, use_async): + pds = PartitionedDataset( + path=str(tmpdir), dataset="pandas.CSVDataset", use_async=use_async + ) pattern = re.escape(f"No partitions found in '{tmpdir}'") with pytest.raises(DatasetError, match=pattern): @@ -390,21 +456,24 @@ def test_no_partitions(self, tmpdir): ), ], ) - def test_filepath_arg_warning(self, pds_config, filepath_arg): + @pytest.mark.parametrize("use_async", [True, False]) + def test_filepath_arg_warning(self, pds_config, filepath_arg, use_async): pattern = ( f"'{filepath_arg}' key must not be specified in the dataset definition as it " f"will be overwritten by partition path" ) with pytest.warns(UserWarning, match=re.escape(pattern)): - PartitionedDataset(**pds_config) + PartitionedDataset(**pds_config, use_async=use_async) - def test_credentials_log_warning(self, caplog): + @pytest.mark.parametrize("use_async", [True, False]) + def test_credentials_log_warning(self, caplog, use_async): """Check that the warning is logged if the dataset credentials will overwrite the top-level ones""" pds = PartitionedDataset( path=str(Path.cwd()), dataset={"type": CSVDataset, "credentials": {"secret": "dataset"}}, credentials={"secret": "global"}, + use_async=use_async, ) log_message = KEY_PROPAGATION_WARNING % { "keys": "credentials", @@ -413,13 +482,15 @@ def test_credentials_log_warning(self, caplog): assert caplog.record_tuples == [("kedro.io.core", logging.WARNING, log_message)] assert pds._dataset_config["credentials"] == {"secret": "dataset"} - def test_fs_args_log_warning(self, caplog): + @pytest.mark.parametrize("use_async", [True, False]) + def test_fs_args_log_warning(self, caplog, use_async): """Check that the warning is logged if the dataset filesystem arguments will overwrite the top-level ones""" pds = PartitionedDataset( path=str(Path.cwd()), dataset={"type": CSVDataset, "fs_args": {"args": "dataset"}}, fs_args={"args": "dataset"}, + use_async=use_async, ) log_message = KEY_PROPAGATION_WARNING % { "keys": "filesystem arguments", @@ -467,9 +538,14 @@ def test_fs_args_log_warning(self, caplog): ), ], ) - def test_dataset_creds(self, pds_config, expected_ds_creds, global_creds): + @pytest.mark.parametrize("use_async", [True, False]) + def test_dataset_creds( + self, pds_config, expected_ds_creds, global_creds, use_async + ): """Check that global credentials do not interfere dataset credentials.""" - pds = PartitionedDataset(path=str(Path.cwd()), **pds_config) + pds = PartitionedDataset( + path=str(Path.cwd()), **pds_config, use_async=use_async + ) assert pds._dataset_config["credentials"] == expected_ds_creds assert pds._credentials == global_creds @@ -514,8 +590,11 @@ class TestPartitionedDatasetS3: os.environ["AWS_SECRET_ACCESS_KEY"] = "FAKE_SECRET_KEY" @pytest.mark.parametrize("dataset", S3_DATASET_DEFINITION) - def test_load(self, dataset, mocked_csvs_in_s3, partitioned_data_pandas): - pds = PartitionedDataset(path=mocked_csvs_in_s3, dataset=dataset) + 
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_load(self, dataset, mocked_csvs_in_s3, partitioned_data_pandas, use_async):
+        pds = PartitionedDataset(
+            path=mocked_csvs_in_s3, dataset=dataset, use_async=use_async
+        )
         loaded_partitions = pds.load()
 
         assert loaded_partitions.keys() == partitioned_data_pandas.keys()
@@ -523,12 +602,17 @@ def test_load(self, dataset, mocked_csvs_in_s3, partitioned_data_pandas):
             df = load_func()
             assert_frame_equal(df, partitioned_data_pandas[partition_id])
 
-    def test_load_s3a(self, mocked_csvs_in_s3, partitioned_data_pandas, mocker):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_load_s3a(
+        self, mocked_csvs_in_s3, partitioned_data_pandas, mocker, use_async
+    ):
         path = mocked_csvs_in_s3.split("://", 1)[1]
         s3a_path = f"s3a://{path}"
         # any type is fine as long as it passes isinstance check
         # since _dataset_type is mocked later anyways
-        pds = PartitionedDataset(path=s3a_path, dataset="pandas.CSVDataset")
+        pds = PartitionedDataset(
+            path=s3a_path, dataset="pandas.CSVDataset", use_async=use_async
+        )
         assert pds._protocol == "s3a"
 
         mocked_ds = mocker.patch.object(pds, "_dataset_type")
@@ -544,8 +628,11 @@ def test_load_s3a(self, mocked_csvs_in_s3, partitioned_data_pandas, mocker):
         mocked_ds.assert_has_calls(expected, any_order=True)
 
     @pytest.mark.parametrize("dataset", S3_DATASET_DEFINITION)
-    def test_save(self, dataset, mocked_csvs_in_s3):
-        pds = PartitionedDataset(path=mocked_csvs_in_s3, dataset=dataset)
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_save(self, dataset, mocked_csvs_in_s3, use_async):
+        pds = PartitionedDataset(
+            path=mocked_csvs_in_s3, dataset=dataset, use_async=use_async
+        )
         original_data = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]})
         part_id = "new/data.csv"
         pds.save({part_id: original_data})
@@ -558,14 +645,18 @@ def test_save(self, dataset, mocked_csvs_in_s3):
         reloaded_data = loaded_partitions[part_id]()
         assert_frame_equal(reloaded_data, original_data)
 
-    def test_save_s3a(self, mocked_csvs_in_s3, mocker):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_save_s3a(self, mocked_csvs_in_s3, mocker, use_async):
         """Test that save works in case of s3a protocol"""
         path = mocked_csvs_in_s3.split("://", 1)[1]
         s3a_path = f"s3a://{path}"
         # any type is fine as long as it passes isinstance check
         # since _dataset_type is mocked later anyways
         pds = PartitionedDataset(
-            path=s3a_path, dataset="pandas.CSVDataset", filename_suffix=".csv"
+            path=s3a_path,
+            dataset="pandas.CSVDataset",
+            filename_suffix=".csv",
+            use_async=use_async,
         )
         assert pds._protocol == "s3a"
 
@@ -579,19 +670,29 @@ def test_save_s3a(self, mocked_csvs_in_s3, mocker):
         mocked_ds.return_value.save.assert_called_once_with(data)
 
     @pytest.mark.parametrize("dataset", ["pandas.CSVDataset", "pandas.HDFDataset"])
-    def test_exists(self, dataset, mocked_csvs_in_s3):
-        assert PartitionedDataset(path=mocked_csvs_in_s3, dataset=dataset).exists()
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_exists(self, dataset, mocked_csvs_in_s3, use_async):
+        assert PartitionedDataset(
+            path=mocked_csvs_in_s3, dataset=dataset, use_async=use_async
+        ).exists()
 
         empty_folder = "/".join([mocked_csvs_in_s3, "empty", "folder"])
-        assert not PartitionedDataset(path=empty_folder, dataset=dataset).exists()
+        assert not PartitionedDataset(
+            path=empty_folder, dataset=dataset, use_async=use_async
+        ).exists()
 
         s3fs.S3FileSystem().mkdir(empty_folder)
-        assert not PartitionedDataset(path=empty_folder, dataset=dataset).exists()
+        assert not PartitionedDataset(
+            path=empty_folder, dataset=dataset, use_async=use_async
+        ).exists()
 
     @pytest.mark.parametrize("dataset", S3_DATASET_DEFINITION)
-    def test_release(self, dataset, mocked_csvs_in_s3):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_release(self, dataset, mocked_csvs_in_s3, use_async):
         partition_to_remove = "p2.csv"
-        pds = PartitionedDataset(path=mocked_csvs_in_s3, dataset=dataset)
+        pds = PartitionedDataset(
+            path=mocked_csvs_in_s3, dataset=dataset, use_async=use_async
+        )
         initial_load = pds.load()
         assert partition_to_remove in initial_load
@@ -605,9 +706,10 @@ def test_release(self, dataset, mocked_csvs_in_s3):
         assert initial_load.keys() ^ load_after_release.keys() == {partition_to_remove}
 
     @pytest.mark.parametrize("dataset", S3_DATASET_DEFINITION)
-    def test_describe(self, dataset):
+    @pytest.mark.parametrize("use_async", [True, False])
+    def test_describe(self, dataset, use_async):
         path = f"s3://{BUCKET_NAME}/foo/bar"
-        pds = PartitionedDataset(path=path, dataset=dataset)
+        pds = PartitionedDataset(path=path, dataset=dataset, use_async=use_async)
 
         assert f"path={path}" in str(pds)
         assert "dataset_type=CSVDataset" in str(pds)
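
The diff does not include a usage snippet, so here is a minimal sketch of how the new flag could be exercised, mirroring the constructor calls in the tests above. The import path and the local `data/01_raw/reviews` location are assumptions for illustration; `path`, `dataset`, `filename_suffix`, `use_async` and the lazy, callable-based loading all come from the changed code.

```python
import pandas as pd

from kedro_datasets.partitions import PartitionedDataset

# Assumed local target directory; any fsspec-compatible location should behave the same.
pds = PartitionedDataset(
    path="data/01_raw/reviews",
    dataset="pandas.CSVDataset",
    filename_suffix=".csv",
    use_async=True,  # new flag introduced by this change; defaults to False
)

# Each dict key becomes one partition; with use_async=True the per-partition
# saves are driven through asyncio.run()/asyncio.gather() under the hood.
pds.save(
    {
        "2023/q1": pd.DataFrame({"score": [1, 2]}),
        "2023/q2": pd.DataFrame({"score": [3, 4]}),
    }
)

# Loading stays lazy either way: values are callables that materialise a partition.
partitions = pds.load()
first = partitions["2023/q1"]()
```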
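For reference, the concurrency pattern used by the new `_async_save` amounts to fanning out one coroutine per partition and awaiting them together. A stripped-down, self-contained sketch of that shape, with a hypothetical `save_one` standing in for building the underlying dataset and calling its `save`, looks like this:

```python
import asyncio
from typing import Any


def save_one(partition_id: str, data: Any) -> None:
    # Hypothetical stand-in for the per-partition dataset save in the real code.
    print(f"saved partition {partition_id!r}: {data}")


async def save_all(data: dict[str, Any]) -> None:
    async def save_partition(partition_id: str, partition_data: Any) -> None:
        if callable(partition_data):  # lazy partitions are materialised first
            partition_data = partition_data()
        save_one(partition_id, partition_data)

    # One coroutine per partition, awaited together -- the same fan-out shape
    # as the gather() call in _async_save above.
    await asyncio.gather(
        *(save_partition(pid, pdata) for pid, pdata in sorted(data.items()))
    )


# _save() keeps a synchronous public API by wrapping the coroutine in asyncio.run().
asyncio.run(save_all({"p1": [1, 2], "p2": lambda: [3, 4]}))
```

Note that the per-partition callables here are ordinary synchronous functions, as in the diff; `gather` controls how the coroutines are scheduled on one event loop rather than running them in separate threads.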