Skip to content

Commit

Permalink
Prohibit negative buffer size in ThreadPrefetch, disable prefetch wit…
Browse files Browse the repository at this point in the history
…h buffer size 0.

PiperOrigin-RevId: 698813206
  • Loading branch information
iindyk authored and copybara-github committed Nov 21, 2024
1 parent ba9f8b7 commit 74d043d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
31 changes: 22 additions & 9 deletions grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,8 @@ class ThreadPrefetchIterDataset(dataset.IterDataset[T]):
Attributes:
parent: The parent dataset to prefetch from.
prefetch_buffer_size: The size of the prefetch buffer.
prefetch_buffer_size: The size of the prefetch buffer. Must be greater than
or equal to 0. If 0, prefetching is disabled and this is a noop.
"""

def __init__(
Expand All @@ -523,6 +524,11 @@ def __init__(
prefetch_buffer_size: int,
):
super().__init__(parent)
if prefetch_buffer_size < 0:
raise ValueError(
"`prefetch_buffer_size` must be greater than or equal to 0, got "
f"{prefetch_buffer_size}."
)
self._prefetch_buffer_size = prefetch_buffer_size

def __str__(self) -> str:
Expand All @@ -531,9 +537,14 @@ def __str__(self) -> str:
f"prefetch_buffer_size={self._prefetch_buffer_size})"
)

def __iter__(self) -> ThreadPrefetchDatasetIterator[T]:
return ThreadPrefetchDatasetIterator(
self._parent, self._prefetch_buffer_size
def __iter__(self) -> dataset.DatasetIterator[T]:
parent_iter = self._parent.__iter__()
if self._prefetch_buffer_size == 0:
# Avoid raising a NotImplemented error and make a noop instead.
parent_iter.start_prefetch = lambda: None
return parent_iter
return _ThreadPrefetchDatasetIterator(
parent_iter, self._prefetch_buffer_size, str(self)
)


Expand All @@ -545,19 +556,21 @@ def __iter__(self) -> ThreadPrefetchDatasetIterator[T]:
_INITIAL_STATE_SENTINEL = object()


class ThreadPrefetchDatasetIterator(dataset.DatasetIterator[T]):
class _ThreadPrefetchDatasetIterator(dataset.DatasetIterator[T]):
"""Iterator that performs prefetching using a synchronized queue."""

_MUTATES_ELEMENT_SPEC = False

def __init__(
self,
parent: dataset.IterDataset[T],
parent: dataset.DatasetIterator[T],
prefetch_buffer_size: int,
parent_transform_name: str,
):
super().__init__(parent.__iter__())
self._iter_parent: dataset.IterDataset[T] = parent
super().__init__(parent)
assert prefetch_buffer_size > 0, prefetch_buffer_size
self._prefetch_buffer_size = prefetch_buffer_size
self._parent_transform_name = parent_transform_name
self._state: StateT | None = None

self._work_queue = queue.Queue[Callable[[], Any]]()
Expand Down Expand Up @@ -586,7 +599,7 @@ def _start_producer(self, initial_state: None):
self._work_thread = threading.Thread(
target=self._work_loop,
daemon=True,
name=f"Prefetch-{self._iter_parent}",
name=f"Prefetch-{self._parent_transform_name}",
)
self._work_thread.start()

Expand Down
16 changes: 16 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,16 @@ def setUp(self):
)

@parameterized.named_parameters(
dict(
testcase_name='no_prefetch',
prefetch_buffer_size=0,
warm_start=False,
),
dict(
testcase_name='no_prefetch_with_warm_start',
prefetch_buffer_size=0,
warm_start=True,
),
dict(
testcase_name='thread',
prefetch_buffer_size=1,
Expand Down Expand Up @@ -632,6 +642,12 @@ def test_checkpoint(self, warm_start: bool):
value = next(ds_iter)
self.assertEqual(value, values_without_interruption[i])

def test_fails_with_negative_prefetch_buffer_size(self):
with self.assertRaisesRegex(
ValueError, '`prefetch_buffer_size` must be greater than or equal to 0'
):
prefetch.ThreadPrefetchIterDataset(self.ds, prefetch_buffer_size=-1)


if __name__ == '__main__':
absltest.main()

0 comments on commit 74d043d

Please sign in to comment.