From 074375a51e2c395e4a4fdb1264bdb16649271628 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Mon, 24 Apr 2023 12:41:18 +0200
Subject: [PATCH 1/5] Test infer module for unsupported data files

---
 tests/fixtures/files.py |  9 +++++++++
 tests/test_load.py      | 17 ++++++++++++++---
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py
index f38db20a68d..cc2b6f0d5e9 100644
--- a/tests/fixtures/files.py
+++ b/tests/fixtures/files.py
@@ -480,6 +480,15 @@ def zip_text_with_dir_path(text_path, text2_path, tmp_path_factory):
     return path
 
 
+@pytest.fixture(scope="session")
+def zip_unsupported_ext_path(text_path, text2_path, tmp_path_factory):
+    path = tmp_path_factory.mktemp("data") / "dataset.ext.zip"
+    with zipfile.ZipFile(path, "w") as f:
+        f.write(text_path, arcname=os.path.basename("unsupported.ext"))
+        f.write(text2_path, arcname=os.path.basename("unsupported_2.ext"))
+    return path
+
+
 @pytest.fixture(scope="session")
 def text_path_with_unicode_new_lines(tmp_path_factory):
     text = "\n".join(["First", "Second\u2029with Unicode new line", "Third"])
diff --git a/tests/test_load.py b/tests/test_load.py
index 7475e97853c..e553d3fa0da 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -209,6 +209,8 @@ def metric_loading_script_dir(tmp_path):
         (["train.jsonl"], "json", {}),
         (["train.parquet"], "parquet", {}),
         (["train.txt"], "text", {}),
+        (["unsupported.ext"], None, {}),
+        ([""], None, {}),
     ],
 )
 def test_infer_module_for_data_files(data_files, expected_module, expected_builder_kwargs):
@@ -217,9 +219,18 @@ def test_infer_module_for_data_files(data_files, expected_module, expected_build
     assert builder_kwargs == expected_builder_kwargs
 
 
-@pytest.mark.parametrize("data_file, expected_module", [("zip_csv_path", "csv"), ("zip_csv_with_dir_path", "csv")])
-def test_infer_module_for_data_files_in_archives(data_file, expected_module, zip_csv_path, zip_csv_with_dir_path):
-    data_file_paths = {"zip_csv_path": zip_csv_path, "zip_csv_with_dir_path": zip_csv_with_dir_path}
+@pytest.mark.parametrize(
+    "data_file, expected_module",
+    [("zip_csv_path", "csv"), ("zip_csv_with_dir_path", "csv"), ("zip_unsupported_ext_path", None)],
+)
+def test_infer_module_for_data_files_in_archives(
+    data_file, expected_module, zip_csv_path, zip_csv_with_dir_path, zip_unsupported_ext_path
+):
+    data_file_paths = {
+        "zip_csv_path": zip_csv_path,
+        "zip_csv_with_dir_path": zip_csv_with_dir_path,
+        "zip_unsupported_ext_path": zip_unsupported_ext_path,
+    }
     data_files = [str(data_file_paths[data_file])]
     inferred_module, _ = infer_module_for_data_files_in_archives(data_files, False)
     assert inferred_module == expected_module

From 244c4e59b4e37f8ddd94fd0d0be2cf849faf736a Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Mon, 24 Apr 2023 12:42:25 +0200
Subject: [PATCH 2/5] Fix infer module functions for unsupported files

---
 src/datasets/load.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/datasets/load.py b/src/datasets/load.py
index e648f8c34ff..b2e01e093de 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -368,6 +368,7 @@ def infer_module_for_data_files(
                 return _EXTENSION_TO_MODULE[ext]
             elif ext == "zip":
                 return infer_module_for_data_files_in_archives(data_files_list, use_auth_token=use_auth_token)
+    return None, {}
 
 
 def infer_module_for_data_files_in_archives(
@@ -404,6 +405,7 @@ def infer_module_for_data_files_in_archives(
         most_common = extensions_counter.most_common(1)[0][0]
         if most_common in _EXTENSION_TO_MODULE:
             return _EXTENSION_TO_MODULE[most_common]
+    return None, {}
 
 
 @dataclass
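[Series note] Taken together, patches 1 and 2 make both inference helpers return an explicit `(None, {})` tuple instead of implicitly returning `None` when no supported extension is found, so callers can always unpack `module_name, builder_kwargs`. A minimal, self-contained sketch of that control flow — the `_EXTENSION_TO_MODULE` table below is a trimmed stand-in for the library's real mapping, and the zip-archive branch is omitted:

```python
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Trimmed stand-in for datasets' table mapping an extension to (module_name, builder_kwargs).
_EXTENSION_TO_MODULE: Dict[str, Tuple[str, dict]] = {
    "csv": ("csv", {}),
    "jsonl": ("json", {}),
    "parquet": ("parquet", {}),
    "txt": ("text", {}),
}


def infer_module_sketch(data_files_list: List[str]) -> Tuple[Optional[str], dict]:
    # Count extensions across the listed files and try the most common ones first.
    extensions_counter = Counter(suffix[1:] for path in data_files_list for suffix in Path(path).suffixes)
    for ext, _ in extensions_counter.most_common():
        if ext in _EXTENSION_TO_MODULE:
            return _EXTENSION_TO_MODULE[ext]
    return None, {}  # the explicit fallback added by patch 2, exercised by the new tests in patch 1


assert infer_module_sketch(["train.csv"]) == ("csv", {})
assert infer_module_sketch(["unsupported.ext"]) == (None, {})
assert infer_module_sketch([""]) == (None, {})  # a file with no extension also falls through
```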
From 207268e6a33f120dc5658658c48ee6d70ee845f0 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Mon, 24 Apr 2023 12:43:40 +0200
Subject: [PATCH 3/5] Fix dataset module factories without script

---
 src/datasets/load.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/src/datasets/load.py b/src/datasets/load.py
index b2e01e093de..25919237d64 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -634,14 +634,14 @@ def get_module(self) -> DatasetModule:
             base_path=base_path,
             allowed_extensions=ALL_ALLOWED_EXTENSIONS,
         )
-        module_names = {
-            key: infer_module_for_data_files(data_files_list) for key, data_files_list in data_files.items()
+        split_modules = {
+            split: infer_module_for_data_files(data_files_list) for split, data_files_list in data_files.items()
         }
-        if len(set(list(zip(*module_names.values()))[0])) > 1:
-            raise ValueError(f"Couldn't infer the same data file format for all splits. Got {module_names}")
-        module_name, builder_kwargs = next(iter(module_names.values()))
+        module_name, builder_kwargs = next(iter(split_modules.values()))
+        if any((module_name, builder_kwargs) != split_module for split_module in split_modules.values()):
+            raise ValueError(f"Couldn't infer the same data file format for all splits. Got {split_modules}")
         if not module_name:
-            raise FileNotFoundError(f"No data files or dataset script found in {self.path}")
+            raise FileNotFoundError(f"No (supported) data files or dataset script found in {self.path}")
         # Collect metadata files if the module supports them
         if self.data_files is None and module_name in _MODULE_SUPPORTS_METADATA and patterns != DEFAULT_PATTERNS_ALL:
             try:
@@ -774,15 +774,15 @@ def get_module(self) -> DatasetModule:
             base_path=self.data_dir,
             allowed_extensions=ALL_ALLOWED_EXTENSIONS,
         )
-        module_names = {
-            key: infer_module_for_data_files(data_files_list, use_auth_token=self.download_config.use_auth_token)
-            for key, data_files_list in data_files.items()
+        split_modules = {
+            split: infer_module_for_data_files(data_files_list, use_auth_token=self.download_config.use_auth_token)
+            for split, data_files_list in data_files.items()
         }
-        if len(set(list(zip(*module_names.values()))[0])) > 1:
-            raise ValueError(f"Couldn't infer the same data file format for all splits. Got {module_names}")
-        module_name, builder_kwargs = next(iter(module_names.values()))
+        module_name, builder_kwargs = next(iter(split_modules.values()))
+        if any((module_name, builder_kwargs) != split_module for split_module in split_modules.values()):
+            raise ValueError(f"Couldn't infer the same data file format for all splits. Got {split_modules}")
         if not module_name:
-            raise FileNotFoundError(f"No data files or dataset script found in {self.name}")
+            raise FileNotFoundError(f"No (supported) data files or dataset script found in {self.name}")
         # Collect metadata files if the module supports them
         if self.data_files is None and module_name in _MODULE_SUPPORTS_METADATA and patterns != DEFAULT_PATTERNS_ALL:
            try:
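[Series note] The rewritten check in patch 3 compares each split's full `(module_name, builder_kwargs)` pair against the first split's pair, whereas the old `len(set(list(zip(*module_names.values()))[0])) > 1` expression compared module names only and ignored the builder kwargs. A small sketch of the new logic, using hypothetical split mappings:

```python
from typing import Dict, Optional, Tuple


def check_splits(split_modules: Dict[str, Tuple[Optional[str], dict]]) -> Tuple[Optional[str], dict]:
    # Same shape as the patched factories: take the first split's result ...
    module_name, builder_kwargs = next(iter(split_modules.values()))
    # ... and require every other split to agree on both the module name and the kwargs.
    if any((module_name, builder_kwargs) != split_module for split_module in split_modules.values()):
        raise ValueError(f"Couldn't infer the same data file format for all splits. Got {split_modules}")
    return module_name, builder_kwargs


assert check_splits({"train": ("csv", {}), "test": ("csv", {})}) == ("csv", {})

try:
    check_splits({"train": ("csv", {}), "test": ("json", {})})  # mixed formats are rejected
except ValueError:
    pass
```

Note that `next(iter(...))` assumes at least one split; an empty `split_modules` is the edge case patches 4 and 5 below go back and forth on.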
From 650141d21eb2947c73790d776f22a931c7c8f70e Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Thu, 27 Apr 2023 11:39:25 +0200
Subject: [PATCH 4/5] Make sure split_modules is not empty due to empty data_files

---
 src/datasets/load.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/datasets/load.py b/src/datasets/load.py
index 25919237d64..f2ca44d254d 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -626,9 +626,7 @@ def __init__(
 
     def get_module(self) -> DatasetModule:
         base_path = os.path.join(self.path, self.data_dir) if self.data_dir else self.path
-        patterns = (
-            sanitize_patterns(self.data_files) if self.data_files is not None else get_data_patterns_locally(base_path)
-        )
+        patterns = sanitize_patterns(self.data_files) if self.data_files else get_data_patterns_locally(base_path)
         data_files = DataFilesDict.from_local_or_remote(
             patterns,
             base_path=base_path,
@@ -765,7 +763,7 @@ def get_module(self) -> DatasetModule:
         )
         patterns = (
             sanitize_patterns(self.data_files)
-            if self.data_files is not None
+            if self.data_files
             else get_data_patterns_in_dataset_repository(hfh_dataset_info, self.data_dir)
         )
         data_files = DataFilesDict.from_hf_repo(

From 6e4997a5ca56f4f8361bcb82ac27a0e2b0969098 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Thu, 27 Apr 2023 14:18:56 +0200
Subject: [PATCH 5/5] Revert "Make sure split_modules is not empty due to empty data_files"

This reverts commit 650141d21eb2947c73790d776f22a931c7c8f70e.

As requested by reviewer.
---
 src/datasets/load.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/datasets/load.py b/src/datasets/load.py
index f2ca44d254d..25919237d64 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -626,7 +626,9 @@ def __init__(
 
     def get_module(self) -> DatasetModule:
         base_path = os.path.join(self.path, self.data_dir) if self.data_dir else self.path
-        patterns = sanitize_patterns(self.data_files) if self.data_files else get_data_patterns_locally(base_path)
+        patterns = (
+            sanitize_patterns(self.data_files) if self.data_files is not None else get_data_patterns_locally(base_path)
+        )
         data_files = DataFilesDict.from_local_or_remote(
             patterns,
             base_path=base_path,
@@ -763,7 +765,7 @@ def get_module(self) -> DatasetModule:
         )
         patterns = (
             sanitize_patterns(self.data_files)
-            if self.data_files
+            if self.data_files is not None
             else get_data_patterns_in_dataset_repository(hfh_dataset_info, self.data_dir)
         )
         data_files = DataFilesDict.from_hf_repo(
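[Series note] Patches 4 and 5 toggle a single guard: with `if self.data_files` (patch 4), an empty `data_files` argument such as `{}` or `[]` falls back to automatic pattern discovery, while the reverted `if self.data_files is not None` (patch 5) keeps the user-supplied empty value, which can later leave `split_modules` empty in `get_module`. A toy comparison of the two guards — `discover` here is a hypothetical stand-in for `get_data_patterns_locally` / `get_data_patterns_in_dataset_repository`:

```python
def choose_patterns(data_files, strict_none: bool):
    # Hypothetical stand-in for the library's pattern-discovery helpers.
    def discover():
        return {"train": ["**"]}

    # strict_none=True models patch 5 (`is not None`); False models patch 4 (truthiness).
    guard = (data_files is not None) if strict_none else bool(data_files)
    return data_files if guard else discover()


# Both guards agree when data_files is None or non-empty ...
assert choose_patterns(None, strict_none=True) == choose_patterns(None, strict_none=False) == {"train": ["**"]}
assert choose_patterns({"train": ["*.csv"]}, strict_none=True) == {"train": ["*.csv"]}

# ... and differ only for an empty-but-not-None value.
assert choose_patterns({}, strict_none=False) == {"train": ["**"]}  # patch 4: fall back to discovery
assert choose_patterns({}, strict_none=True) == {}                  # patch 5 (revert): keep the empty dict
```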