Skip to content

Commit

Permalink
Prevent deadlocks in EagerIterators by making prefetch optional.
Browse files Browse the repository at this point in the history
Previously, if you provided a thread pool that was too small and an
EagerIterator could not create a new preloading thread, the iterator
would deadlock, since it would wait for the new thread to be created
forever and not try to just do the work itself. This change instead
uses preloading as an optional optimization, and if the preload
has not yet been completed, computes the next value itself.
  • Loading branch information
thetorpedodog committed Jan 10, 2024
1 parent a79f984 commit f5ce1d8
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 3 deletions.
21 changes: 18 additions & 3 deletions python-spec/src/somacore/query/_eager_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,31 @@ def __init__(
self.iterator = iterator
self._pool = pool or futures.ThreadPoolExecutor()
self._own_pool = pool is None
self._future = self._pool.submit(self.iterator.__next__)
self._preload_future = self._pool.submit(self.iterator.__next__)

def __next__(self) -> _T:
stopped = False
try:
res = self._future.result()
self._future = self._pool.submit(self.iterator.__next__)
if self._preload_future.cancel():
# If `.cancel` returns True, cancellation was successful.
# The self.iterator.__next__ call has not yet been started,
# and will never be started, so we can compute next ourselves.
# This prevents deadlocks if the thread pool is too small
# and we can never create a preload thread.
res = next(self.iterator)
else:
# `.cancel` returned false, so the preload is already running.
# Just wait for it.
res = self._preload_future.result()
return res
except StopIteration:
self._cleanup()
stopped = True
raise
finally:
if not stopped:
# If we have more to do, go for the next thing.
self._preload_future = self._pool.submit(self.iterator.__next__)

def _cleanup(self) -> None:
if self._own_pool:
Expand Down
64 changes: 64 additions & 0 deletions python-spec/testing/test_eager_iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import threading
import unittest
from concurrent import futures
from unittest import mock

from somacore.query import _eager_iter


class EagerIterTest(unittest.TestCase):
def setUp(self):
super().setUp()
self.kiddie_pool = futures.ThreadPoolExecutor(1)
"""Tiny thread pool for testing."""
self.verify_pool = futures.ThreadPoolExecutor(1)
"""Separate thread pool so verification is not blocked."""

def tearDown(self):
self.verify_pool.shutdown(wait=False)
self.kiddie_pool.shutdown(wait=False)
super().tearDown()

def test_thread_starvation(self):
sem = threading.Semaphore()
try:
# Monopolize the threadpool.
sem.acquire()
self.kiddie_pool.submit(sem.acquire)
eager = _eager_iter.EagerIterator(iter("abc"), pool=self.kiddie_pool)
got_a = self.verify_pool.submit(lambda: next(eager))
self.assertEqual("a", got_a.result(0.1))
got_b = self.verify_pool.submit(lambda: next(eager))
self.assertEqual("b", got_b.result(0.1))
got_c = self.verify_pool.submit(lambda: next(eager))
self.assertEqual("c", got_c.result(0.1))
with self.assertRaises(StopIteration):
self.verify_pool.submit(lambda: next(eager)).result(0.1)
finally:
sem.release()

def test_nesting(self):
inner = _eager_iter.EagerIterator(iter("abc"), pool=self.kiddie_pool)
outer = _eager_iter.EagerIterator(inner, pool=self.kiddie_pool)
self.assertEqual(
"a, b, c", self.verify_pool.submit(", ".join, outer).result(0.1)
)

def test_exceptions(self):
flaky = mock.MagicMock()
flaky.__next__.side_effect = [1, 2, ValueError(), 3, 4]

eager_flaky = _eager_iter.EagerIterator(flaky, pool=self.kiddie_pool)
got_1 = self.verify_pool.submit(lambda: next(eager_flaky))
self.assertEqual(1, got_1.result(0.1))
got_2 = self.verify_pool.submit(lambda: next(eager_flaky))
self.assertEqual(2, got_2.result(0.1))
with self.assertRaises(ValueError):
self.verify_pool.submit(lambda: next(eager_flaky)).result(0.1)
got_3 = self.verify_pool.submit(lambda: next(eager_flaky))
self.assertEqual(3, got_3.result(0.1))
got_4 = self.verify_pool.submit(lambda: next(eager_flaky))
self.assertEqual(4, got_4.result(0.1))
for _ in range(5):
with self.assertRaises(StopIteration):
self.verify_pool.submit(lambda: next(eager_flaky)).result(0.1)

0 comments on commit f5ce1d8

Please sign in to comment.