Fix inferring module for unsupported data files #5787

Merged (5 commits) on Apr 27, 2023.
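In short: `infer_module_for_data_files` (and its archive counterpart) used to fall off the end and implicitly return a bare `None` when no supported extension was found, so callers that unpack `module_name, builder_kwargs = ...` failed with an unpacking `TypeError`. This PR makes both helpers always return a `(module_name, builder_kwargs)` tuple, returning `(None, {})` for unsupported data files, and turns the failure into a clear `FileNotFoundError`. A minimal sketch of the fixed behavior (simplified: the real table and inference logic in `src/datasets/load.py` are richer, counting extensions and looking inside zip archives):

```python
# Minimal sketch of the fixed fall-through behavior (simplified, not the
# actual datasets implementation).
from typing import Optional, Tuple

_EXTENSION_TO_MODULE = {"csv": ("csv", {}), "json": ("json", {}), "txt": ("text", {})}


def infer_module(data_files: list) -> Tuple[Optional[str], dict]:
    for filepath in data_files:
        ext = filepath.rsplit(".", 1)[-1]
        if ext in _EXTENSION_TO_MODULE:
            return _EXTENSION_TO_MODULE[ext]
    # Before this PR the function fell off the end here, implicitly
    # returning a bare None that broke tuple unpacking in callers.
    return None, {}


assert infer_module(["train.csv"]) == ("csv", {})
assert infer_module(["unsupported.ext"]) == (None, {})
```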
src/datasets/load.py (28 changes: 15 additions & 13 deletions)
@@ -368,6 +368,7 @@ def infer_module_for_data_files(
                 return _EXTENSION_TO_MODULE[ext]
             elif ext == "zip":
                 return infer_module_for_data_files_in_archives(data_files_list, use_auth_token=use_auth_token)
+    return None, {}


@@ -404,6 +405,7 @@ def infer_module_for_data_files_in_archives(
         most_common = extensions_counter.most_common(1)[0][0]
         if most_common in _EXTENSION_TO_MODULE:
             return _EXTENSION_TO_MODULE[most_common]
+    return None, {}


 @dataclass
@@ -632,14 +634,14 @@ def get_module(self) -> DatasetModule:
             base_path=base_path,
             allowed_extensions=ALL_ALLOWED_EXTENSIONS,
         )
-        module_names = {
-            key: infer_module_for_data_files(data_files_list) for key, data_files_list in data_files.items()
+        split_modules = {
+            split: infer_module_for_data_files(data_files_list) for split, data_files_list in data_files.items()
         }
-        if len(set(list(zip(*module_names.values()))[0])) > 1:
-            raise ValueError(f"Couldn't infer the same data file format for all splits. Got {module_names}")
-        module_name, builder_kwargs = next(iter(module_names.values()))
+        module_name, builder_kwargs = next(iter(split_modules.values()))
@lhoestq (Member):

Maybe check that split_modules is not empty?

Member Author:

Thanks for your review, @lhoestq.

I think it can only be empty if the user passes data_files={}; otherwise there are two options: either it is not empty, or an exception is raised:

  • split_modules is derived from data_files, which is built by DataFilesDict.from_local_or_remote from patterns
  • patterns is derived either from sanitize_patterns or from get_data_patterns_locally
    • sanitize_patterns can only return an empty dict if the user passes data_files={}
    • get_data_patterns_locally can only return a non-empty dict or raise an EmptyDatasetError

I think the validation of data_files={} should be elsewhere, though. What do you think?

Member Author:

Maybe change

sanitize_patterns(self.data_files) if self.data_files is not None else get_data_patterns_locally(base_path)

to

sanitize_patterns(self.data_files) if self.data_files else get_data_patterns_locally(base_path)

This way, we can be sure split_modules is never empty.
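For illustration, here is a tiny standalone sketch of the distinction being discussed (hypothetical names, not the datasets API): with `is not None`, an explicit data_files={} survives and yields an empty split_modules, while a truthiness check would silently fall back to default patterns.

```python
# Toy illustration of `is not None` vs. truthiness for data_files={}
# (hypothetical names, not the datasets API).
DEFAULT_PATTERNS = {"train": ["**"]}


def patterns_is_not_none(data_files):
    return data_files if data_files is not None else DEFAULT_PATTERNS


def patterns_truthy(data_files):
    return data_files if data_files else DEFAULT_PATTERNS


assert patterns_is_not_none({}) == {}           # empty dict kept -> empty split_modules
assert patterns_truthy({}) == DEFAULT_PATTERNS  # {} silently falls back to defaults
```

As the reply below notes, raising an error for data_files={} was preferred over silently falling back to files that were not requested.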

@lhoestq (Member):

> I think the validation of data_files={} should be elsewhere, though. What do you think?

Yeah, indeed: probably in load_dataset_builder?

> Maybe change ...?

I think it's better if it raises an error rather than trying to make it run with data files that were not requested.

@lhoestq (Member):

Feel free to merge then :)

+        if any((module_name, builder_kwargs) != split_module for split_module in split_modules.values()):
+            raise ValueError(f"Couldn't infer the same data file format for all splits. Got {split_modules}")
         if not module_name:
-            raise FileNotFoundError(f"No data files or dataset script found in {self.path}")
+            raise FileNotFoundError(f"No (supported) data files or dataset script found in {self.path}")
         # Collect metadata files if the module supports them
         if self.data_files is None and module_name in _MODULE_SUPPORTS_METADATA and patterns != DEFAULT_PATTERNS_ALL:
             try:
@@ -772,15 +774,15 @@ def get_module(self) -> DatasetModule:
             base_path=self.data_dir,
             allowed_extensions=ALL_ALLOWED_EXTENSIONS,
         )
-        module_names = {
-            key: infer_module_for_data_files(data_files_list, use_auth_token=self.download_config.use_auth_token)
-            for key, data_files_list in data_files.items()
+        split_modules = {
+            split: infer_module_for_data_files(data_files_list, use_auth_token=self.download_config.use_auth_token)
+            for split, data_files_list in data_files.items()
         }
-        if len(set(list(zip(*module_names.values()))[0])) > 1:
-            raise ValueError(f"Couldn't infer the same data file format for all splits. Got {module_names}")
-        module_name, builder_kwargs = next(iter(module_names.values()))
+        module_name, builder_kwargs = next(iter(split_modules.values()))
@lhoestq (Member):

Same here.

+        if any((module_name, builder_kwargs) != split_module for split_module in split_modules.values()):
+            raise ValueError(f"Couldn't infer the same data file format for all splits. Got {split_modules}")
         if not module_name:
-            raise FileNotFoundError(f"No data files or dataset script found in {self.name}")
+            raise FileNotFoundError(f"No (supported) data files or dataset script found in {self.name}")
         # Collect metadata files if the module supports them
         if self.data_files is None and module_name in _MODULE_SUPPORTS_METADATA and patterns != DEFAULT_PATTERNS_ALL:
             try:
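As a side note, here is a hedged, standalone sketch of the new per-split consistency check used in both hunks above (the values are made up; in the real code split_modules maps split names to the (module_name, builder_kwargs) pairs returned by infer_module_for_data_files):

```python
# Illustrative run of the new per-split consistency check (made-up values,
# not the actual datasets code path).
split_modules = {"train": ("csv", {}), "test": ("csv", {"sep": "\t"})}

module_name, builder_kwargs = next(iter(split_modules.values()))
try:
    if any((module_name, builder_kwargs) != split_module for split_module in split_modules.values()):
        raise ValueError(f"Couldn't infer the same data file format for all splits. Got {split_modules}")
    if not module_name:
        raise FileNotFoundError("No (supported) data files or dataset script found")
except ValueError as err:
    # Raised here: the splits agree on the module name ("csv") but differ in
    # builder kwargs, a case the old module-name-only check let through.
    print(err)
```

Unlike the old check, which compared only the module names via list(zip(*module_names.values()))[0], the new comparison covers the full (module_name, builder_kwargs) pair, so splits that agree on the format but disagree on builder kwargs are now caught as well.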
tests/fixtures/files.py (9 changes: 9 additions & 0 deletions)
@@ -480,6 +480,15 @@ def zip_text_with_dir_path(text_path, text2_path, tmp_path_factory):
     return path


+@pytest.fixture(scope="session")
+def zip_unsupported_ext_path(text_path, text2_path, tmp_path_factory):
+    path = tmp_path_factory.mktemp("data") / "dataset.ext.zip"
+    with zipfile.ZipFile(path, "w") as f:
+        f.write(text_path, arcname=os.path.basename("unsupported.ext"))
+        f.write(text2_path, arcname=os.path.basename("unsupported_2.ext"))
+    return path
+
+
 @pytest.fixture(scope="session")
 def text_path_with_unicode_new_lines(tmp_path_factory):
     text = "\n".join(["First", "Second\u2029with Unicode new line", "Third"])
tests/test_load.py (17 changes: 14 additions & 3 deletions)
@@ -209,6 +209,8 @@ def metric_loading_script_dir(tmp_path):
         (["train.jsonl"], "json", {}),
         (["train.parquet"], "parquet", {}),
         (["train.txt"], "text", {}),
+        (["unsupported.ext"], None, {}),
+        ([""], None, {}),
     ],
 )
 def test_infer_module_for_data_files(data_files, expected_module, expected_builder_kwargs):
@@ -217,9 +219,18 @@ def test_infer_module_for_data_files(data_files, expected_module, expected_builder_kwargs):
     assert builder_kwargs == expected_builder_kwargs


-@pytest.mark.parametrize("data_file, expected_module", [("zip_csv_path", "csv"), ("zip_csv_with_dir_path", "csv")])
-def test_infer_module_for_data_files_in_archives(data_file, expected_module, zip_csv_path, zip_csv_with_dir_path):
-    data_file_paths = {"zip_csv_path": zip_csv_path, "zip_csv_with_dir_path": zip_csv_with_dir_path}
+@pytest.mark.parametrize(
+    "data_file, expected_module",
+    [("zip_csv_path", "csv"), ("zip_csv_with_dir_path", "csv"), ("zip_unsupported_ext_path", None)],
+)
+def test_infer_module_for_data_files_in_archives(
+    data_file, expected_module, zip_csv_path, zip_csv_with_dir_path, zip_unsupported_ext_path
+):
+    data_file_paths = {
+        "zip_csv_path": zip_csv_path,
+        "zip_csv_with_dir_path": zip_csv_with_dir_path,
+        "zip_unsupported_ext_path": zip_unsupported_ext_path,
+    }
     data_files = [str(data_file_paths[data_file])]
     inferred_module, _ = infer_module_for_data_files_in_archives(data_files, False)
     assert inferred_module == expected_module
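Finally, a sketch of the user-visible effect under an assumed setup (the path is hypothetical, and the snippet assumes a datasets version that includes this fix): pointing load_dataset at a directory that contains only unsupported files now fails with the clearer FileNotFoundError above rather than a TypeError.

```python
# Hypothetical end-user repro (the path is made up; requires a datasets
# release that includes this fix).
from datasets import load_dataset

try:
    load_dataset("./folder_with_only_unsupported_files")  # e.g. only *.ext files
except FileNotFoundError as err:
    print(err)  # "No (supported) data files or dataset script found in ..."
```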