diff --git a/src/datasets/load.py b/src/datasets/load.py
index e648f8c34ff..25919237d64 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -368,6 +368,7 @@ def infer_module_for_data_files(
                 return _EXTENSION_TO_MODULE[ext]
             elif ext == "zip":
                 return infer_module_for_data_files_in_archives(data_files_list, use_auth_token=use_auth_token)
+    return None, {}
 
 
 def infer_module_for_data_files_in_archives(
@@ -404,6 +405,7 @@ def infer_module_for_data_files_in_archives(
         most_common = extensions_counter.most_common(1)[0][0]
         if most_common in _EXTENSION_TO_MODULE:
             return _EXTENSION_TO_MODULE[most_common]
+    return None, {}
 
 
 @dataclass
@@ -632,14 +634,14 @@ def get_module(self) -> DatasetModule:
             base_path=base_path,
             allowed_extensions=ALL_ALLOWED_EXTENSIONS,
         )
-        module_names = {
-            key: infer_module_for_data_files(data_files_list) for key, data_files_list in data_files.items()
+        split_modules = {
+            split: infer_module_for_data_files(data_files_list) for split, data_files_list in data_files.items()
         }
-        if len(set(list(zip(*module_names.values()))[0])) > 1:
-            raise ValueError(f"Couldn't infer the same data file format for all splits. Got {module_names}")
-        module_name, builder_kwargs = next(iter(module_names.values()))
+        module_name, builder_kwargs = next(iter(split_modules.values()))
+        if any((module_name, builder_kwargs) != split_module for split_module in split_modules.values()):
+            raise ValueError(f"Couldn't infer the same data file format for all splits. Got {split_modules}")
         if not module_name:
-            raise FileNotFoundError(f"No data files or dataset script found in {self.path}")
+            raise FileNotFoundError(f"No (supported) data files or dataset script found in {self.path}")
         # Collect metadata files if the module supports them
         if self.data_files is None and module_name in _MODULE_SUPPORTS_METADATA and patterns != DEFAULT_PATTERNS_ALL:
             try:
@@ -772,15 +774,15 @@ def get_module(self) -> DatasetModule:
             base_path=self.data_dir,
             allowed_extensions=ALL_ALLOWED_EXTENSIONS,
         )
-        module_names = {
-            key: infer_module_for_data_files(data_files_list, use_auth_token=self.download_config.use_auth_token)
-            for key, data_files_list in data_files.items()
+        split_modules = {
+            split: infer_module_for_data_files(data_files_list, use_auth_token=self.download_config.use_auth_token)
+            for split, data_files_list in data_files.items()
         }
-        if len(set(list(zip(*module_names.values()))[0])) > 1:
-            raise ValueError(f"Couldn't infer the same data file format for all splits. Got {module_names}")
-        module_name, builder_kwargs = next(iter(module_names.values()))
+        module_name, builder_kwargs = next(iter(split_modules.values()))
+        if any((module_name, builder_kwargs) != split_module for split_module in split_modules.values()):
+            raise ValueError(f"Couldn't infer the same data file format for all splits. Got {split_modules}")
         if not module_name:
-            raise FileNotFoundError(f"No data files or dataset script found in {self.name}")
+            raise FileNotFoundError(f"No (supported) data files or dataset script found in {self.name}")
         # Collect metadata files if the module supports them
         if self.data_files is None and module_name in _MODULE_SUPPORTS_METADATA and patterns != DEFAULT_PATTERNS_ALL:
             try:
diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py
index f38db20a68d..cc2b6f0d5e9 100644
--- a/tests/fixtures/files.py
+++ b/tests/fixtures/files.py
@@ -480,6 +480,15 @@ def zip_text_with_dir_path(text_path, text2_path, tmp_path_factory):
     return path
 
 
+@pytest.fixture(scope="session")
+def zip_unsupported_ext_path(text_path, text2_path, tmp_path_factory):
+    path = tmp_path_factory.mktemp("data") / "dataset.ext.zip"
+    with zipfile.ZipFile(path, "w") as f:
+        f.write(text_path, arcname=os.path.basename("unsupported.ext"))
+        f.write(text2_path, arcname=os.path.basename("unsupported_2.ext"))
+    return path
+
+
 @pytest.fixture(scope="session")
 def text_path_with_unicode_new_lines(tmp_path_factory):
     text = "\n".join(["First", "Second\u2029with Unicode new line", "Third"])
diff --git a/tests/test_load.py b/tests/test_load.py
index 7475e97853c..e553d3fa0da 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -209,6 +209,8 @@ def metric_loading_script_dir(tmp_path):
         (["train.jsonl"], "json", {}),
         (["train.parquet"], "parquet", {}),
         (["train.txt"], "text", {}),
+        (["unsupported.ext"], None, {}),
+        ([""], None, {}),
     ],
 )
 def test_infer_module_for_data_files(data_files, expected_module, expected_builder_kwargs):
@@ -217,9 +219,18 @@ def test_infer_module_for_data_files(data_files, expected_module, expected_build
     assert builder_kwargs == expected_builder_kwargs
 
 
-@pytest.mark.parametrize("data_file, expected_module", [("zip_csv_path", "csv"), ("zip_csv_with_dir_path", "csv")])
-def test_infer_module_for_data_files_in_archives(data_file, expected_module, zip_csv_path, zip_csv_with_dir_path):
-    data_file_paths = {"zip_csv_path": zip_csv_path, "zip_csv_with_dir_path": zip_csv_with_dir_path}
+@pytest.mark.parametrize(
+    "data_file, expected_module",
+    [("zip_csv_path", "csv"), ("zip_csv_with_dir_path", "csv"), ("zip_unsupported_ext_path", None)],
+)
+def test_infer_module_for_data_files_in_archives(
+    data_file, expected_module, zip_csv_path, zip_csv_with_dir_path, zip_unsupported_ext_path
+):
+    data_file_paths = {
+        "zip_csv_path": zip_csv_path,
+        "zip_csv_with_dir_path": zip_csv_with_dir_path,
+        "zip_unsupported_ext_path": zip_unsupported_ext_path,
+    }
     data_files = [str(data_file_paths[data_file])]
     inferred_module, _ = infer_module_for_data_files_in_archives(data_files, False)
     assert inferred_module == expected_module