Introduce experimental multithreaded interleave transformation.
We will support new file formats that do not allow random access. This transformation will make it possible to implement parallel reads and hierarchical shuffle for them. PiperOrigin-RevId: 697777276
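As a sketch of the intended use, adapted from the docstring and tests in this change (`make_source`, the shard names, and the `grain` import alias are illustrative assumptions, not part of the commit):

```
import grain.python as grain  # assumed public import path


def make_source(filename: str) -> grain.IterDataset:
  # Hypothetical reader for a sequential-only (non-random-access) format.
  ...


filenames = ["shard-00000", "shard-00001", "shard-00002", "shard-00003"]
# Hierarchical shuffle: shuffle the file order globally, then read from at
# most `cycle_length` files at a time and interleave their records.
ds = grain.MapDataset.source(filenames).shuffle(seed=42).map(make_source)
ds = grain.experimental.InterleavedIterDataset(ds, cycle_length=4)
```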
1 parent d881f15, commit 2a64618
Showing 5 changed files with 311 additions and 0 deletions.
grain/_src/python/dataset/transformations/interleave.py (168 additions, 0 deletions)

@@ -0,0 +1,168 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements dataset interleaving."""

from collections.abc import Sequence
from typing import TypeVar

from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import prefetch

T = TypeVar("T")


def _add_prefetch_and_make_iterator(
    ds: dataset.IterDataset[T] | dataset.MapDataset[T],
) -> dataset.DatasetIterator[T]:
  if isinstance(ds, dataset.MapDataset):
    # Prefetch is automatically added in `MapDataset.__iter__`.
    return ds.__iter__()
  return prefetch.ThreadPrefetchIterDataset(
      ds, prefetch_buffer_size=1
  ).__iter__()


class _InterleavedDatasetIterator(dataset.DatasetIterator[T]):
  """Iterates over the interleaved datasets."""

  def __init__(
      self,
      datasets: Sequence[dataset.IterDataset[T] | dataset.MapDataset[T]],
      cycle_length: int,
  ):
    # `datasets` is allowed to be a lazily evaluated `MapDataset`. We avoid
    # passing it as `parents` to not trigger evaluation early.
    super().__init__()
    self._datasets = datasets
    self._cycle_length: int = min(cycle_length, len(datasets))
    self._next_index_in_cycle: int = 0
    self._next_index_in_datasets: int = 0
    self._iterators_in_use_indices: list[int] = list(range(self._cycle_length))
    self._iterators_in_use: list[dataset.DatasetIterator[T] | None] = [
        None
    ] * self._cycle_length
  def __next__(self) -> T:
    while True:
      # If the current cycle slot holds a live iterator, pull from it and
      # advance to the next slot.
      if iterator_to_use := self._iterators_in_use[self._next_index_in_cycle]:
        try:
          result = iterator_to_use.__next__()
          self._next_index_in_cycle = (
              self._next_index_in_cycle + 1
          ) % self._cycle_length
          return result
        except StopIteration:
          # This dataset is exhausted; free the slot and refill it below.
          self._iterators_in_use[self._next_index_in_cycle] = None
          continue
      if self._next_index_in_datasets < len(self._datasets):
        # Fill the free slot with an iterator over the next unused dataset.
        self._iterators_in_use[self._next_index_in_cycle] = (
            _add_prefetch_and_make_iterator(
                self._datasets[self._next_index_in_datasets]
            )
        )
        self._iterators_in_use_indices[self._next_index_in_cycle] = (
            self._next_index_in_datasets
        )
        self._next_index_in_datasets += 1
      elif not any(self._iterators_in_use):
        # No datasets left to open and every slot is exhausted.
        raise StopIteration
      else:
        # This slot cannot be refilled; move on to the next slot.
        self._next_index_in_cycle = (
            self._next_index_in_cycle + 1
        ) % self._cycle_length
def get_state(self): | ||
return { | ||
"next_index_in_cycle": self._next_index_in_cycle, | ||
"next_index_in_datasets": self._next_index_in_datasets, | ||
"iterators_in_use_indices": self._iterators_in_use_indices.copy(), | ||
"iterators_in_use_states": [ | ||
(None if it is None else it.get_state()) | ||
for it in self._iterators_in_use | ||
], | ||
} | ||
|
||
def set_state(self, state): | ||
self._next_index_in_cycle = state["next_index_in_cycle"] | ||
self._next_index_in_datasets = state["next_index_in_datasets"] | ||
if not self._next_index_in_datasets and not self._next_index_in_cycle: | ||
return | ||
self._iterators_in_use_indices = state["iterators_in_use_indices"] | ||
for index_in_cycle, (index_in_datasets, it_state) in enumerate( | ||
zip(self._iterators_in_use_indices, state["iterators_in_use_states"]) | ||
): | ||
if it_state is None: | ||
self._iterators_in_use[index_in_cycle] = None | ||
else: | ||
iterator = self._datasets[index_in_datasets].__iter__() | ||
iterator.set_state(it_state) | ||
self._iterators_in_use[index_in_cycle] = iterator | ||
|
||
def __str__(self) -> str: | ||
return ( | ||
f"InterleavedDatasetIterator([{len(self._datasets)} datasets]," | ||
f" cycle_length={self._cycle_length})" | ||
) | ||
|
||
|
||
class InterleavedIterDataset(dataset.IterDataset[T]):
  """Interleaves the given sequence of datasets.

  The sequence can be a `MapDataset`.

  Creates at most `cycle_length` iterators at a time that are processed
  concurrently and interleaves their elements. If `cycle_length` is larger
  than the number of datasets, the behavior is similar to mixing the datasets
  with equal proportions. If `cycle_length` is 1, the datasets are chained.

  Can be used with `mp_prefetch` to parallelize reading from sources that do
  not support random access and are implemented as `IterDataset`:

  ```
  def make_source(filename: str) -> grain.IterDataset:
    ...

  ds = grain.MapDataset.source(filenames).shuffle(seed=42).map(make_source)
  ds = grain.experimental.InterleavedIterDataset(ds, cycle_length=4)
  ds = ...
  ds = ds.mp_prefetch(grain.MultiprocessingOptions(num_workers=2))
  for element in ds:
    ...
  ```
  """
|
||
def __init__( | ||
self, | ||
datasets: Sequence[dataset.IterDataset[T] | dataset.MapDataset[T]], | ||
*, | ||
cycle_length: int, | ||
): | ||
super().__init__() | ||
self._datasets = datasets | ||
self._cycle_length = cycle_length | ||
|
||
def __iter__(self) -> _InterleavedDatasetIterator[T]: | ||
return _InterleavedDatasetIterator( | ||
self._datasets, | ||
cycle_length=self._cycle_length, | ||
) | ||
|
||
def set_slice(self, sl: slice): | ||
self._datasets = self._datasets[sl] | ||
|
||
def __str__(self) -> str: | ||
return ( | ||
f"InterleavedIterDataset([{len(self._datasets)} datasets]," | ||
f" cycle_length={self._cycle_length})" | ||
) |
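The iterator is checkpointable through `get_state`/`set_state`; a minimal sketch of the save/rewind round trip, mirroring `test_checkpoint` below and using this commit's internal module paths:

```
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import interleave

datasets = [
    dataset.MapDataset.source(elems).to_iter_dataset()
    for elems in ([1], [2, 2], [3, 3, 3])
]
it = iter(interleave.InterleavedIterDataset(datasets, cycle_length=2))
first = next(it)           # -> 1
state = it.get_state()     # snapshot the interleave position
second = next(it)          # -> 2
it.set_state(state)        # rewind to the snapshot
assert next(it) == second  # iteration resumes deterministically
```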
grain/_src/python/dataset/transformations/interleave_test.py (121 additions, 0 deletions)

@@ -0,0 +1,121 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
from grain._src.python import options
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import interleave


_INTERLEAVE_TEST_CASES = (
    dict(
        testcase_name="cycle_length_1",
        to_mix=[[1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5]],
        cycle_length=1,
        expected=[1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5],
    ),
    dict(
        testcase_name="cycle_length_2",
        to_mix=[[1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5]],
        cycle_length=2,
        expected=[1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 4, 5, 5, 5],
    ),
    dict(
        testcase_name="cycle_length_3",
        to_mix=[[1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5]],
        cycle_length=3,
        expected=[1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 4, 5, 5, 5],
    ),
    dict(
        testcase_name="same_lengths",
        to_mix=[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]],
        cycle_length=3,
        expected=[1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 4, 4],
    ),
    dict(
        testcase_name="unsorted_lengths",
        to_mix=[[1, 1, 1], [2], [3, 3, 3, 3], [4, 4]],
        cycle_length=3,
        expected=[1, 2, 3, 1, 4, 3, 1, 4, 3, 3],
    ),
    dict(
        testcase_name="large_cycle_length",
        to_mix=[[1, 1, 1], [2], [3, 3, 3, 3], [4, 4]],
        cycle_length=10,
        expected=[1, 2, 3, 4, 1, 3, 4, 1, 3, 3],
    ),
)


class MixedIterDatasetTest(parameterized.TestCase):

  @parameterized.named_parameters(*_INTERLEAVE_TEST_CASES)
  def test_interleaved_mix(self, to_mix, cycle_length, expected):
    datasets = [
        dataset.MapDataset.source(elements).to_iter_dataset()
        for elements in to_mix
    ]
    ds = interleave.InterleavedIterDataset(datasets, cycle_length=cycle_length)
    self.assertEqual(list(ds), expected)
    # Sanity check.
    flat_inputs = []
    for ds in datasets:
      flat_inputs.extend(list(ds))
    self.assertCountEqual(flat_inputs, expected)
  @parameterized.named_parameters(*_INTERLEAVE_TEST_CASES)
  def test_checkpoint(self, to_mix, cycle_length, expected):
    datasets = [
        dataset.MapDataset.source(elements).to_iter_dataset()
        for elements in to_mix
    ]
    ds = interleave.InterleavedIterDataset(datasets, cycle_length=cycle_length)
    ds_iter = ds.__iter__()
    checkpoints = {}
    for i in range(len(expected)):
      checkpoints[i] = ds_iter.get_state()
      _ = next(ds_iter)
    for i, state in checkpoints.items():
      ds_iter.set_state(state)
      self.assertEqual(
          list(ds_iter), expected[i:], msg=f"Failed at checkpoint {i}."
      )
  def test_with_map_dataset_of_datasets(self):

    def make_dummy_source(filename):
      chars = list(filename)
      return dataset.MapDataset.source(chars)

    filenames = dataset.MapDataset.source(["11", "2345", "678", "9999"])
    sources = filenames.shuffle(seed=42).map(make_dummy_source)
    ds = interleave.InterleavedIterDataset(sources, cycle_length=2)
    self.assertEqual(
        list(ds),
        ["1", "2", "1", "3", "6", "4", "7", "5", "8", "9", "9", "9", "9"],
    )
  def test_with_mp_prefetch(self):
    ds = dataset.MapDataset.range(1, 6).map(
        lambda i: dataset.MapDataset.source([i]).repeat(i).to_iter_dataset()
    )
    ds = interleave.InterleavedIterDataset(ds, cycle_length=5)
    ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=3))
    self.assertEqual(list(ds), [1, 2, 3, 4, 5, 3, 4, 2, 3, 4, 5, 4, 5, 5, 5])


if __name__ == "__main__":
  absltest.main()