# Patch summary (text before the first "diff --git" header is commentary).
#
# python-spec/src/somacore/query/_eager_iter.py
#   * Renames the attribute `_future` to `_preload_future`.
#   * `__next__` first tries to cancel the pending preload. If the cancel
#     succeeds, the preload never started, so the next item is computed
#     inline with `next(self.iterator)`; otherwise it waits on the future.
#   * On StopIteration it calls `_cleanup()` and records `stopped = True`;
#     the `finally` clause submits a new preload unless `stopped` is set,
#     so a preload is also re-queued after any other exception.
#
# python-spec/testing/test_eager_iter.py (new file)
#   * Adds EagerIterTest with test_thread_starvation, test_nesting and
#     test_exceptions, all driving EagerIterator through a 1-thread pool.
#
# NOTE(review): this patch was stored with its line breaks collapsed. Line
# boundaries below were restored to match the hunk headers (16 -> 29 lines
# and 0 -> 64 lines); leading whitespace inside the hunks was reconstructed
# from the Python block structure -- verify against the target files
# before applying.
diff --git a/python-spec/src/somacore/query/_eager_iter.py b/python-spec/src/somacore/query/_eager_iter.py
index 84e601b9..c42e52c1 100644
--- a/python-spec/src/somacore/query/_eager_iter.py
+++ b/python-spec/src/somacore/query/_eager_iter.py
@@ -14,16 +14,29 @@ def __init__(
         self.iterator = iterator
         self._pool = pool or futures.ThreadPoolExecutor()
         self._own_pool = pool is None
-        self._future = self._pool.submit(self.iterator.__next__)
+        self._preload_future = self._pool.submit(self.iterator.__next__)
 
     def __next__(self) -> _T:
+        stopped = False
         try:
-            res = self._future.result()
-            self._future = self._pool.submit(self.iterator.__next__)
-            return res
+            if self._preload_future.cancel():
+                # If `.cancel` returns True, cancellation was successful.
+                # The self.iterator.__next__ call has not yet been started,
+                # and will never be started, so we can compute next ourselves.
+                # This prevents deadlocks if the thread pool is too small
+                # and we can never create a preload thread.
+                return next(self.iterator)
+            # `.cancel` returned false, so the preload is already running.
+            # Just wait for it.
+            return self._preload_future.result()
         except StopIteration:
             self._cleanup()
+            stopped = True
             raise
+        finally:
+            if not stopped:
+                # If we have more to do, go for the next thing.
+                self._preload_future = self._pool.submit(self.iterator.__next__)
 
     def _cleanup(self) -> None:
         if self._own_pool:
diff --git a/python-spec/testing/test_eager_iter.py b/python-spec/testing/test_eager_iter.py
new file mode 100644
index 00000000..87f74b9c
--- /dev/null
+++ b/python-spec/testing/test_eager_iter.py
@@ -0,0 +1,64 @@
+import threading
+import unittest
+from concurrent import futures
+from unittest import mock
+
+from somacore.query import _eager_iter
+
+
+class EagerIterTest(unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.kiddie_pool = futures.ThreadPoolExecutor(1)
+        """Tiny thread pool for testing."""
+        self.verify_pool = futures.ThreadPoolExecutor(1)
+        """Separate thread pool so verification is not blocked."""
+
+    def tearDown(self):
+        self.verify_pool.shutdown(wait=False)
+        self.kiddie_pool.shutdown(wait=False)
+        super().tearDown()
+
+    def test_thread_starvation(self):
+        sem = threading.Semaphore()
+        try:
+            # Monopolize the threadpool.
+            sem.acquire()
+            self.kiddie_pool.submit(sem.acquire)
+            eager = _eager_iter.EagerIterator(iter("abc"), pool=self.kiddie_pool)
+            got_a = self.verify_pool.submit(lambda: next(eager))
+            self.assertEqual("a", got_a.result(0.1))
+            got_b = self.verify_pool.submit(lambda: next(eager))
+            self.assertEqual("b", got_b.result(0.1))
+            got_c = self.verify_pool.submit(lambda: next(eager))
+            self.assertEqual("c", got_c.result(0.1))
+            with self.assertRaises(StopIteration):
+                self.verify_pool.submit(lambda: next(eager)).result(0.1)
+        finally:
+            sem.release()
+
+    def test_nesting(self):
+        inner = _eager_iter.EagerIterator(iter("abc"), pool=self.kiddie_pool)
+        outer = _eager_iter.EagerIterator(inner, pool=self.kiddie_pool)
+        self.assertEqual(
+            "a, b, c", self.verify_pool.submit(", ".join, outer).result(0.1)
+        )
+
+    def test_exceptions(self):
+        flaky = mock.MagicMock()
+        flaky.__next__.side_effect = [1, 2, ValueError(), 3, 4]
+
+        eager_flaky = _eager_iter.EagerIterator(flaky, pool=self.kiddie_pool)
+        got_1 = self.verify_pool.submit(lambda: next(eager_flaky))
+        self.assertEqual(1, got_1.result(0.1))
+        got_2 = self.verify_pool.submit(lambda: next(eager_flaky))
+        self.assertEqual(2, got_2.result(0.1))
+        with self.assertRaises(ValueError):
+            self.verify_pool.submit(lambda: next(eager_flaky)).result(0.1)
+        got_3 = self.verify_pool.submit(lambda: next(eager_flaky))
+        self.assertEqual(3, got_3.result(0.1))
+        got_4 = self.verify_pool.submit(lambda: next(eager_flaky))
+        self.assertEqual(4, got_4.result(0.1))
+        for _ in range(5):
+            with self.assertRaises(StopIteration):
+                self.verify_pool.submit(lambda: next(eager_flaky)).result(0.1)