Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve split order in DataFilesDict #6198

Merged
merged 4 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions src/datasets/data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,17 +682,6 @@ def from_patterns(
)
return out

def __reduce__(self):
    """
    Tell pickle how to rebuild this object: as a `DataFilesDict` created from the items sorted by key.

    Sorting makes the pickled form, and therefore the hash, independent of the order in which keys were inserted:

        >>> from datasets.data_files import DataFilesDict
        >>> from datasets.fingerprint import Hasher
        >>> assert Hasher.hash(DataFilesDict(a=[], b=[])) == Hasher.hash(DataFilesDict(b=[], a=[]))

    As a consequence, the rebuilt dict has its keys in sorted order rather than in the original insertion order.
    """
    # Keys are unique, so sorting the (key, value) pairs orders them by key alone.
    items_by_key = sorted(self.items())
    return DataFilesDict, (dict(items_by_key),)

def filter_extensions(self, extensions: List[str]) -> "DataFilesDict":
out = type(self)()
for key, data_files_list in self.items():
Expand Down
4 changes: 0 additions & 4 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,10 +759,6 @@ def test_cache_dir_for_data_files(self):
cache_dir=tmp_dir, data_files={"train": [dummy_data1], "test": dummy_data2}
)
self.assertEqual(builder.cache_dir, other_builder.cache_dir)
other_builder = DummyGeneratorBasedBuilder(
cache_dir=tmp_dir, data_files={"test": dummy_data2, "train": dummy_data1}
)
self.assertEqual(builder.cache_dir, other_builder.cache_dir)
other_builder = DummyGeneratorBasedBuilder(
cache_dir=tmp_dir, data_files={"train": dummy_data1, "validation": dummy_data2}
)
Expand Down
8 changes: 8 additions & 0 deletions tests/test_data_files.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import os
from pathlib import Path, PurePath
from typing import List
Expand Down Expand Up @@ -385,6 +386,13 @@ def test_DataFilesList_from_patterns_raises_FileNotFoundError(complex_data_dir):
DataFilesList.from_patterns(["file_that_doesnt_exist.txt"], complex_data_dir)


class TestDataFilesDict:
    """Tests for `DataFilesDict`."""

    def test_key_order_after_copy(self):
        """A deep copy must list its splits in the same order as the original."""
        # "train" is inserted first but sorts after "test", so a copy that
        # reordered its keys alphabetically would be caught here.
        original = DataFilesDict({"train": "train.csv", "test": "test.csv"})
        duplicate = copy.deepcopy(original)

        # Compare as lists: plain dict equality ignores key order.
        expected_order = list(original.keys())
        actual_order = list(duplicate.keys())
        assert actual_order == expected_order


@pytest.mark.parametrize("pattern", _TEST_PATTERNS)
def test_DataFilesDict_from_patterns_in_dataset_repository(
hub_dataset_repo_path, hub_dataset_repo_patterns_results, pattern
Expand Down
Loading