Skip to content

Commit

Permalink
Introduce experimental multithreaded interleave transformation.
Browse files Browse the repository at this point in the history
We will support new file formats that do not allow random access. This transformation makes it possible to implement parallel reads and hierarchical shuffle for them.

PiperOrigin-RevId: 697777276
  • Loading branch information
iindyk authored and copybara-github committed Nov 18, 2024
1 parent d881f15 commit 2a64618
Show file tree
Hide file tree
Showing 5 changed files with 311 additions and 0 deletions.
1 change: 1 addition & 0 deletions grain/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ py_library(
":python_lazy_dataset", # build_cleaner: keep
"//grain/_src/core:transforms", # build_cleaner: keep
"//grain/_src/python/dataset:visualize", # build_cleaner: keep
"//grain/_src/python/dataset/transformations:interleave", # build_cleaner: keep
"//grain/_src/python/experimental/example_packing:packing", # build_cleaner: keep
],
)
Expand Down
18 changes: 18 additions & 0 deletions grain/_src/python/dataset/transformations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,21 @@ py_test(
"//grain/_src/python/dataset",
],
)

# Library target for the interleave transformation (`interleave.py`).
py_library(
name = "interleave",
srcs = ["interleave.py"],
srcs_version = "PY3",
deps = ["//grain/_src/python/dataset"],
)

# Unit tests for the interleave transformation, including the `mp_prefetch`
# integration test (hence the dependency on `options`).
py_test(
name = "interleave_test",
srcs = ["interleave_test.py"],
srcs_version = "PY3",
deps = [
":interleave",
"//grain/_src/python:options",
"//grain/_src/python/dataset",
],
)
168 changes: 168 additions & 0 deletions grain/_src/python/dataset/transformations/interleave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements dataset interleaving."""

from collections.abc import Sequence
from typing import TypeVar

from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import prefetch

# Type of the elements produced by the datasets being interleaved.
T = TypeVar("T")


def _add_prefetch_and_make_iterator(
    ds: dataset.IterDataset[T] | dataset.MapDataset[T],
) -> dataset.DatasetIterator[T]:
  """Creates an iterator over `ds`, adding prefetching where it is missing.

  Args:
    ds: The dataset to iterate over.

  Returns:
    For a `MapDataset`, the result of `ds.__iter__()`. For any other dataset,
    an iterator over `ds` wrapped in a `ThreadPrefetchIterDataset` with a
    prefetch buffer of one element.
  """
  if not isinstance(ds, dataset.MapDataset):
    # `IterDataset`s get no implicit prefetch, so add a single-element
    # thread prefetch explicitly.
    prefetched = prefetch.ThreadPrefetchIterDataset(
        ds, prefetch_buffer_size=1
    )
    return prefetched.__iter__()
  # Prefetch is automatically added in `MapDataset.__iter__`.
  return ds.__iter__()


class _InterleavedDatasetIterator(dataset.DatasetIterator[T]):
  """Iterates over the interleaved datasets.

  Keeps up to `cycle_length` per-dataset iterators ("slots") open at a time
  and returns one element from each slot in round-robin order. When the
  iterator in a slot is exhausted, the slot is refilled with an iterator over
  the next dataset that has not been opened yet. Iteration ends once every
  dataset has been opened and every slot is empty.
  """

  def __init__(
      self,
      datasets: Sequence[dataset.IterDataset[T] | dataset.MapDataset[T]],
      cycle_length: int,
  ):
    """Initializes the iterator without opening any of the datasets.

    Args:
      datasets: Datasets to interleave. May be empty, in which case the
        iterator produces no elements.
      cycle_length: Maximum number of datasets iterated concurrently. Capped at
        `len(datasets)`.

    Raises:
      ValueError: If `cycle_length` is less than 1.
    """
    if cycle_length < 1:
      raise ValueError(
          f"`cycle_length` must be at least 1, got {cycle_length}."
      )
    # `datasets` is allowed to be a lazily evaluated `MapDataset`. We avoid
    # passing it as `parents` to not trigger evaluation early.
    super().__init__()
    self._datasets = datasets
    self._cycle_length: int = min(cycle_length, len(datasets))
    # Slot that the next element will be read from.
    self._next_index_in_cycle: int = 0
    # Index in `datasets` of the next dataset to open.
    self._next_index_in_datasets: int = 0
    # For every slot, the index in `datasets` of the dataset it was last
    # filled with. Needed to re-create the iterators in `set_state`.
    self._iterators_in_use_indices: list[int] = list(range(self._cycle_length))
    # For every slot, the open iterator or `None` if the slot is empty.
    self._iterators_in_use: list[dataset.DatasetIterator[T] | None] = [
        None
    ] * self._cycle_length

  def __next__(self) -> T:
    """Returns the next element in round-robin order over the open slots.

    Raises:
      StopIteration: When all datasets have been opened and exhausted, or when
        there are no datasets at all.
    """
    if not self._cycle_length:
      # No datasets (e.g. an empty slice): there are no slots to index into.
      raise StopIteration
    while True:
      slot = self._next_index_in_cycle
      iterator_to_use = self._iterators_in_use[slot]
      if iterator_to_use is not None:
        try:
          result = next(iterator_to_use)
        except StopIteration:
          # Free the slot and retry it: it is either refilled with the next
          # dataset below or skipped if there are no datasets left.
          self._iterators_in_use[slot] = None
          continue
        self._next_index_in_cycle = (slot + 1) % self._cycle_length
        return result
      if self._next_index_in_datasets < len(self._datasets):
        # Refill the empty slot; the next loop iteration reads from it.
        self._iterators_in_use[slot] = _add_prefetch_and_make_iterator(
            self._datasets[self._next_index_in_datasets]
        )
        self._iterators_in_use_indices[slot] = self._next_index_in_datasets
        self._next_index_in_datasets += 1
      elif all(it is None for it in self._iterators_in_use):
        raise StopIteration
      else:
        # Nothing left to open for this slot, but other slots still have data.
        self._next_index_in_cycle = (slot + 1) % self._cycle_length

  def get_state(self):
    """Returns a checkpoint of the cycle position and of every open slot.

    The state of an empty slot is recorded as `None`.
    """
    return {
        "next_index_in_cycle": self._next_index_in_cycle,
        "next_index_in_datasets": self._next_index_in_datasets,
        "iterators_in_use_indices": self._iterators_in_use_indices.copy(),
        "iterators_in_use_states": [
            (None if it is None else it.get_state())
            for it in self._iterators_in_use
        ],
    }

  def set_state(self, state):
    """Restores the iterator from a checkpoint produced by `get_state`.

    Args:
      state: A dictionary in the format returned by `get_state`. It is not
        modified, neither by this call nor by subsequent iteration.
    """
    self._next_index_in_cycle = state["next_index_in_cycle"]
    self._next_index_in_datasets = state["next_index_in_datasets"]
    if not self._next_index_in_datasets and not self._next_index_in_cycle:
      # Initial state: no dataset has been opened yet. Drop any iterators left
      # over from earlier iteration so that we truly restart from scratch, and
      # avoid touching (potentially lazily evaluated) `datasets`.
      self._iterators_in_use_indices = list(range(self._cycle_length))
      self._iterators_in_use = [None] * self._cycle_length
      return
    # Copy: `__next__` updates this list in place and must not modify the
    # checkpoint owned by the caller.
    self._iterators_in_use_indices = list(state["iterators_in_use_indices"])
    for index_in_cycle, (index_in_datasets, it_state) in enumerate(
        zip(self._iterators_in_use_indices, state["iterators_in_use_states"])
    ):
      if it_state is None:
        self._iterators_in_use[index_in_cycle] = None
      else:
        # Build the iterator the same way `__next__` does, so that the state
        # is restored into the same kind of iterator it was taken from and
        # the restored iterator keeps prefetching.
        iterator = _add_prefetch_and_make_iterator(
            self._datasets[index_in_datasets]
        )
        iterator.set_state(it_state)
        self._iterators_in_use[index_in_cycle] = iterator

  def __str__(self) -> str:
    return (
        f"InterleavedDatasetIterator([{len(self._datasets)} datasets],"
        f" cycle_length={self._cycle_length})"
    )


class InterleavedIterDataset(dataset.IterDataset[T]):
  """Interleaves the given sequence of datasets.

  The sequence can be a `MapDataset`.

  Creates at most `cycle_length` iterators at a time that are processed
  concurrently and interleaves their elements. If `cycle_length` is larger than
  the number of datasets, then the behavior is similar to mixing the datasets
  with equal proportions. If `cycle_length` is 1, the datasets are chained.

  Can be used with `mp_prefetch` to parallelize reading from sources that do not
  support random access and are implemented as `IterDataset`:

  ```
  def make_source(filename: str) -> grain.IterDataset:
    ...

  ds = grain.MapDataset.source(filenames).shuffle(seed=42).map(make_source)
  ds = grain.experimental.InterleavedIterDataset(ds, cycle_length=4)
  ds = ...
  ds = ds.mp_prefetch(grain.MultiprocessingOptions(num_workers=2))
  for element in ds:
    ...
  ```
  """

  def __init__(
      self,
      datasets: Sequence[dataset.IterDataset[T] | dataset.MapDataset[T]],
      *,
      cycle_length: int,
  ):
    """Initializes the dataset.

    Args:
      datasets: Datasets to interleave. Not passed to the parent constructor;
        the sequence is only read when iterating.
      cycle_length: Maximum number of datasets iterated concurrently.

    Raises:
      ValueError: If `cycle_length` is less than 1.
    """
    if cycle_length < 1:
      # Fail at construction time instead of with an obscure error on the
      # first `next()` call.
      raise ValueError(
          f"`cycle_length` must be at least 1, got {cycle_length}."
      )
    super().__init__()
    self._datasets = datasets
    self._cycle_length = cycle_length

  def __iter__(self) -> _InterleavedDatasetIterator[T]:
    """Returns a fresh iterator over the interleaved datasets."""
    return _InterleavedDatasetIterator(
        self._datasets,
        cycle_length=self._cycle_length,
    )

  def set_slice(self, sl: slice):
    """Restricts this dataset, in place, to the datasets selected by `sl`.

    NOTE(review): presumably called by `mp_prefetch` to shard the datasets
    across worker processes -- confirm against the caller.

    Args:
      sl: Slice applied to the sequence of datasets.
    """
    self._datasets = self._datasets[sl]

  def __str__(self) -> str:
    return (
        f"InterleavedIterDataset([{len(self._datasets)} datasets],"
        f" cycle_length={self._cycle_length})"
    )
121 changes: 121 additions & 0 deletions grain/_src/python/dataset/transformations/interleave_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
import multiprocessing as mp
from grain._src.python import options
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import interleave


# Shared parameterized cases for the interleave tests. Each case provides:
#   to_mix: the element lists; every list becomes one input dataset.
#   cycle_length: the `cycle_length` passed to `InterleavedIterDataset`.
#   expected: the exact order in which the interleaved dataset must produce
#     the elements. One element is taken from each open dataset in turn; an
#     exhausted dataset is replaced in place by the next unopened dataset.
_INTERLEAVE_TEST_CASES = (
dict(
testcase_name="cycle_length_1",
to_mix=[[1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5]],
cycle_length=1,
expected=[1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5],
),
dict(
testcase_name="cycle_length_2",
to_mix=[[1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5]],
cycle_length=2,
expected=[1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 4, 5, 5, 5],
),
dict(
testcase_name="cycle_length_3",
to_mix=[[1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5]],
cycle_length=3,
expected=[1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 4, 5, 5, 5],
),
dict(
testcase_name="same_lengths",
to_mix=[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]],
cycle_length=3,
expected=[1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 4, 4],
),
dict(
testcase_name="unsorted_lengths",
to_mix=[[1, 1, 1], [2], [3, 3, 3, 3], [4, 4]],
cycle_length=3,
expected=[1, 2, 3, 1, 4, 3, 1, 4, 3, 3],
),
# `cycle_length` exceeds the number of datasets, so all of them are open
# from the start.
dict(
testcase_name="large_cycle_length",
to_mix=[[1, 1, 1], [2], [3, 3, 3, 3], [4, 4]],
cycle_length=10,
expected=[1, 2, 3, 4, 1, 3, 4, 1, 3, 3],
),
)


class MixedIterDatasetTest(parameterized.TestCase):
  """Tests for `interleave.InterleavedIterDataset`."""

  @parameterized.named_parameters(*_INTERLEAVE_TEST_CASES)
  def test_interleaved_mix(self, to_mix, cycle_length, expected):
    """Elements are produced in the expected interleaved order."""
    datasets = []
    for elements in to_mix:
      datasets.append(dataset.MapDataset.source(elements).to_iter_dataset())
    interleaved = interleave.InterleavedIterDataset(
        datasets, cycle_length=cycle_length
    )
    self.assertEqual(list(interleaved), expected)
    # Sanity check: the expectation holds exactly the elements of the inputs.
    flat_inputs = [element for source in datasets for element in source]
    self.assertCountEqual(flat_inputs, expected)

  @parameterized.named_parameters(*_INTERLEAVE_TEST_CASES)
  def test_checkpoint(self, to_mix, cycle_length, expected):
    """Restoring any intermediate state reproduces the remaining elements."""
    datasets = []
    for elements in to_mix:
      datasets.append(dataset.MapDataset.source(elements).to_iter_dataset())
    interleaved = interleave.InterleavedIterDataset(
        datasets, cycle_length=cycle_length
    )
    ds_iter = interleaved.__iter__()
    # `saved_states[k]` is the state right before the k-th element is read.
    saved_states = []
    for _ in expected:
      saved_states.append(ds_iter.get_state())
      next(ds_iter)
    for step, saved in enumerate(saved_states):
      ds_iter.set_state(saved)
      remaining = list(ds_iter)
      self.assertEqual(
          remaining, expected[step:], msg=f"Failed at checkpoint {step}."
      )

  def test_with_map_dataset_of_datasets(self):
    """The datasets to interleave can be given as a lazy `MapDataset`."""

    def make_dummy_source(filename):
      # One element per character of the "file name".
      return dataset.MapDataset.source(list(filename))

    filenames = dataset.MapDataset.source(["11", "2345", "678", "9999"])
    shuffled = filenames.shuffle(seed=42)
    sources = shuffled.map(make_dummy_source)
    interleaved = interleave.InterleavedIterDataset(sources, cycle_length=2)
    expected = [
        "1", "2", "1", "3", "6", "4", "7", "5", "8", "9", "9", "9", "9"
    ]
    self.assertEqual(list(interleaved), expected)

  def test_with_mp_prefetch(self):
    """Interleaving composes with multiprocess prefetching."""
    # Dataset `i` consists of the value `i` repeated `i` times.
    sources = dataset.MapDataset.range(1, 6).map(
        lambda i: dataset.MapDataset.source([i]).repeat(i).to_iter_dataset()
    )
    interleaved = interleave.InterleavedIterDataset(sources, cycle_length=5)
    mp_options = options.MultiprocessingOptions(num_workers=3)
    prefetched = interleaved.mp_prefetch(mp_options)
    self.assertEqual(
        list(prefetched), [1, 2, 3, 4, 5, 3, 4, 2, 3, 4, 5, 4, 5, 5, 5]
    )


# Run the tests with absltest when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()
3 changes: 3 additions & 0 deletions grain/python_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
FlatMapMapDataset,
FlatMapIterDataset,
)
from ._src.python.dataset.transformations.interleave import (
InterleavedIterDataset,
)
from ._src.python.dataset.transformations.map import RngPool
from ._src.python.dataset.transformations.mix import ConcatenateMapDataset
from ._src.python.dataset.transformations.packing import FirstFitPackIterDataset
Expand Down

0 comments on commit 2a64618

Please sign in to comment.