Skip to content

Commit

Permalink
Rename InterleavedIterDataset->InterleaveIterDataset for consistency …
Browse files Browse the repository at this point in the history
…with other names.

PiperOrigin-RevId: 698065051
  • Loading branch information
iindyk authored and copybara-github committed Nov 19, 2024
1 parent 2a64618 commit 16d0983
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
14 changes: 7 additions & 7 deletions grain/_src/python/dataset/transformations/interleave.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _add_prefetch_and_make_iterator(
).__iter__()


class _InterleavedDatasetIterator(dataset.DatasetIterator[T]):
class _InterleaveDatasetIterator(dataset.DatasetIterator[T]):
"""Iterates over the interleaved datasets."""

def __init__(
Expand Down Expand Up @@ -111,12 +111,12 @@ def set_state(self, state):

def __str__(self) -> str:
return (
f"InterleavedDatasetIterator([{len(self._datasets)} datasets],"
f"InterleaveDatasetIterator([{len(self._datasets)} datasets],"
f" cycle_length={self._cycle_length})"
)


class InterleavedIterDataset(dataset.IterDataset[T]):
class InterleaveIterDataset(dataset.IterDataset[T]):
"""Interleaves the given sequence of datasets.
The sequence can be a `MapDataset`.
Expand All @@ -134,7 +134,7 @@ def make_source(filename: str) -> grain.IterDataset:
...
ds = grain.MapDataset.source(filenames).shuffle(seed=42).map(make_source)
ds = grain.experimental.InterleavedIterDataset(ds, cycle_length=4)
ds = grain.experimental.InterleaveIterDataset(ds, cycle_length=4)
ds = ...
ds = ds.mp_prefetch(ds, 2)
for element in ds:
Expand All @@ -152,8 +152,8 @@ def __init__(
self._datasets = datasets
self._cycle_length = cycle_length

def __iter__(self) -> _InterleavedDatasetIterator[T]:
return _InterleavedDatasetIterator(
def __iter__(self) -> _InterleaveDatasetIterator[T]:
return _InterleaveDatasetIterator(
self._datasets,
cycle_length=self._cycle_length,
)
Expand All @@ -163,6 +163,6 @@ def set_slice(self, sl: slice):

def __str__(self) -> str:
return (
f"InterleavedIterDataset([{len(self._datasets)} datasets],"
f"InterleaveIterDataset([{len(self._datasets)} datasets],"
f" cycle_length={self._cycle_length})"
)
8 changes: 4 additions & 4 deletions grain/_src/python/dataset/transformations/interleave_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_interleaved_mix(self, to_mix, cycle_length, expected):
dataset.MapDataset.source(elements).to_iter_dataset()
for elements in to_mix
]
ds = interleave.InterleavedIterDataset(datasets, cycle_length=cycle_length)
ds = interleave.InterleaveIterDataset(datasets, cycle_length=cycle_length)
self.assertEqual(list(ds), expected)
# Sanity check.
flat_inputs = []
Expand All @@ -82,7 +82,7 @@ def test_checkpoint(self, to_mix, cycle_length, expected):
dataset.MapDataset.source(elements).to_iter_dataset()
for elements in to_mix
]
ds = interleave.InterleavedIterDataset(datasets, cycle_length=cycle_length)
ds = interleave.InterleaveIterDataset(datasets, cycle_length=cycle_length)
ds_iter = ds.__iter__()
checkpoints = {}
for i in range(len(expected)):
Expand All @@ -102,7 +102,7 @@ def make_dummy_source(filename):

filenames = dataset.MapDataset.source(["11", "2345", "678", "9999"])
sources = filenames.shuffle(seed=42).map(make_dummy_source)
ds = interleave.InterleavedIterDataset(sources, cycle_length=2)
ds = interleave.InterleaveIterDataset(sources, cycle_length=2)
self.assertEqual(
list(ds),
["1", "2", "1", "3", "6", "4", "7", "5", "8", "9", "9", "9", "9"],
Expand All @@ -112,7 +112,7 @@ def test_with_mp_prefetch(self):
ds = dataset.MapDataset.range(1, 6).map(
lambda i: dataset.MapDataset.source([i]).repeat(i).to_iter_dataset()
)
ds = interleave.InterleavedIterDataset(ds, cycle_length=5)
ds = interleave.InterleaveIterDataset(ds, cycle_length=5)
ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=3))
self.assertEqual(list(ds), [1, 2, 3, 4, 5, 3, 4, 2, 3, 4, 5, 4, 5, 5, 5])

Expand Down
2 changes: 1 addition & 1 deletion grain/python_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
FlatMapIterDataset,
)
from ._src.python.dataset.transformations.interleave import (
InterleavedIterDataset,
InterleaveIterDataset,
)
from ._src.python.dataset.transformations.map import RngPool
from ._src.python.dataset.transformations.mix import ConcatenateMapDataset
Expand Down

0 comments on commit 16d0983

Please sign in to comment.