Skip to content

Commit

Permalink
Prevent deadlocks in EagerIterators by making prefetch optional. (#185)
Browse files Browse the repository at this point in the history
Previously, if you provided a thread pool that was too small and an
EagerIterator could not create a new preloading thread, the iterator
would deadlock: it would wait forever for the new thread to be created
instead of just doing the work itself. This change instead treats
preloading as an optional optimization: if the preload has not yet
started, the iterator cancels it and computes the next value itself;
if the preload is already running, the iterator waits for its result.
  • Loading branch information
thetorpedodog authored Jan 10, 2024
1 parent a79f984 commit 61f6cc9
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 4 deletions.
21 changes: 17 additions & 4 deletions python-spec/src/somacore/query/_eager_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,29 @@ def __init__(
self.iterator = iterator
self._pool = pool or futures.ThreadPoolExecutor()
self._own_pool = pool is None
self._future = self._pool.submit(self.iterator.__next__)
self._preload_future = self._pool.submit(self.iterator.__next__)

def __next__(self) -> _T:
    """Return the next value, preferring an already-running preload.

    The pending preload is cancelled if the pool has not started it yet,
    in which case the value is computed on the calling thread; otherwise
    the call waits for the preload's result. Unless iteration has ended,
    a new preload is submitted before returning or raising.

    Raises:
        StopIteration: when the wrapped iterator is exhausted;
            ``self._cleanup()`` is called first and no further preload
            is scheduled.
        Exception: any other error raised by the wrapped iterator is
            propagated, and the next preload is still scheduled.
    """
    stopped = False
    try:
        if self._preload_future.cancel():
            # If `.cancel` returns True, cancellation was successful.
            # The self.iterator.__next__ call has not yet been started,
            # and will never be started, so we can compute next ourselves.
            # This prevents deadlocks if the thread pool is too small
            # and we can never create a preload thread.
            return next(self.iterator)
        # `.cancel` returned False, so the preload is already running
        # (or has already finished). Just wait for it.
        return self._preload_future.result()
    except StopIteration:
        self._cleanup()
        stopped = True
        raise
    finally:
        # Runs on both the return and the error paths, so a non-terminal
        # exception does not stop prefetching.
        if not stopped:
            # If we have more to do, go for the next thing.
            self._preload_future = self._pool.submit(self.iterator.__next__)

def _cleanup(self) -> None:
if self._own_pool:
Expand Down
64 changes: 64 additions & 0 deletions python-spec/testing/test_eager_iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import threading
import unittest
from concurrent import futures
from unittest import mock

from somacore.query import _eager_iter


class EagerIterTest(unittest.TestCase):
    """Tests for ``EagerIterator`` when it shares a one-worker thread pool."""

    # Seconds to wait for any single off-thread ``next`` call to finish.
    _TIMEOUT = 0.1

    def setUp(self):
        super().setUp()
        # Tiny thread pool for testing.
        self.kiddie_pool = futures.ThreadPoolExecutor(1)
        # Separate thread pool so verification is not blocked.
        self.verify_pool = futures.ThreadPoolExecutor(1)

    def tearDown(self):
        for pool in (self.verify_pool, self.kiddie_pool):
            pool.shutdown(wait=False)
        super().tearDown()

    def _next_off_thread(self, iterator):
        """Run ``next(iterator)`` on the verify pool and wait for the outcome.

        Returns the produced value, or re-raises whatever ``next`` raised.
        The wait is bounded by ``_TIMEOUT`` so a hang fails the test.
        """
        pending = self.verify_pool.submit(next, iterator)
        return pending.result(self._TIMEOUT)

    def test_thread_starvation(self):
        """Iteration still completes when the pool's only worker is busy."""
        gate = threading.Semaphore()
        try:
            # Monopolize the threadpool: the worker blocks on the second
            # acquire until the semaphore is released below.
            gate.acquire()
            self.kiddie_pool.submit(gate.acquire)
            eager = _eager_iter.EagerIterator(iter("abc"), pool=self.kiddie_pool)
            for expected in ("a", "b", "c"):
                self.assertEqual(expected, self._next_off_thread(eager))
            with self.assertRaises(StopIteration):
                self._next_off_thread(eager)
        finally:
            gate.release()

    def test_nesting(self):
        """An eager iterator can wrap another one that uses the same pool."""
        inner = _eager_iter.EagerIterator(iter("abc"), pool=self.kiddie_pool)
        outer = _eager_iter.EagerIterator(inner, pool=self.kiddie_pool)
        joined = self.verify_pool.submit(", ".join, outer)
        self.assertEqual("a, b, c", joined.result(self._TIMEOUT))

    def test_exceptions(self):
        """An error from the wrapped iterator propagates without ending iteration."""
        flaky = mock.MagicMock()
        flaky.__next__.side_effect = [1, 2, ValueError(), 3, 4]
        eager_flaky = _eager_iter.EagerIterator(flaky, pool=self.kiddie_pool)

        for expected in (1, 2):
            self.assertEqual(expected, self._next_off_thread(eager_flaky))
        with self.assertRaises(ValueError):
            self._next_off_thread(eager_flaky)
        for expected in (3, 4):
            self.assertEqual(expected, self._next_off_thread(eager_flaky))
        # Once exhausted, every further call keeps raising StopIteration.
        for _ in range(5):
            with self.assertRaises(StopIteration):
                self._next_off_thread(eager_flaky)

0 comments on commit 61f6cc9

Please sign in to comment.