From 2004500adf5b59d9aca85c7a863a207fdec1c8ce Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 23 Dec 2024 14:24:06 +0000 Subject: [PATCH] Fix crash from concurrent nb::make_iterator<> under free-threading. The `State` class used by make_iterator is constructed lazily, but without locking it is possible for the caller to crash when the class type is only partially constructed. This PR adds an ft_mutex around the binding of the State type. My initial instinct was to use an nb_object_guard on `scope`. Unfortunately this doesn't work; I suspect another critical section is acquired by code during the class binding and that releases the outer critical section. --- include/nanobind/make_iterator.h | 43 +++++++++++++++++--------------- tests/common.py | 23 +++++++++++++++++ tests/test_make_iterator.py | 6 ++++- tests/test_thread.py | 25 +------------------ 4 files changed, 52 insertions(+), 45 deletions(-) diff --git a/include/nanobind/make_iterator.h b/include/nanobind/make_iterator.h index 22e125ff..db63a912 100644 --- a/include/nanobind/make_iterator.h +++ b/include/nanobind/make_iterator.h @@ -71,27 +71,30 @@ typed make_iterator_impl(handle scope, const char *name, "make_iterator_impl(): the generated __next__ would copy elements, so the " "element type must be copy-constructible"); - if (!type().is_valid()) { - class_(scope, name) - .def("__iter__", [](handle h) { return h; }) - .def("__next__", - [](State &s) -> ValueType { - if (!s.first_or_done) - ++s.it; - else - s.first_or_done = false; - - if (s.it == s.end) { - s.first_or_done = true; - throw stop_iteration(); - } - - return Access()(s.it); - }, - std::forward(extra)..., - Policy); + { + static ft_mutex mu; + ft_lock_guard lock(mu); + if (!type().is_valid()) { + class_(scope, name) + .def("__iter__", [](handle h) { return h; }) + .def("__next__", + [](State &s) -> ValueType { + if (!s.first_or_done) + ++s.it; + else + s.first_or_done = false; + + if (s.it == s.end) { + s.first_or_done = true; + throw stop_iteration(); + } + + return Access()(s.it); + }, + std::forward(extra)..., + Policy); + } } - return borrow>(cast(State{ std::forward(first), std::forward(last), true })); } diff --git a/tests/common.py b/tests/common.py index 2f8ccafa..a38b9f8f 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,6 +1,7 @@ import platform import gc import pytest +import threading is_pypy = platform.python_implementation() == 'PyPy' is_darwin = platform.system() == 'Darwin' @@ -17,3 +18,25 @@ def collect() -> None: xfail_on_pypy_darwin = pytest.mark.xfail( is_pypy and is_darwin, reason="This test for some reason fails on PyPy/Darwin") + + +# Helper function to parallelize execution of a function. We intentionally +# don't use the Python threads pools here to have threads shut down / start +# between test cases. +def parallelize(func, n_threads): + barrier = threading.Barrier(n_threads) + result = [None]*n_threads + + def wrapper(i): + barrier.wait() + result[i] = func() + + workers = [] + for i in range(n_threads): + t = threading.Thread(target=wrapper, args=(i,)) + t.start() + workers.append(t) + + for worker in workers: + worker.join() + return result \ No newline at end of file diff --git a/tests/test_make_iterator.py b/tests/test_make_iterator.py index ebda4dfc..2e736698 100644 --- a/tests/test_make_iterator.py +++ b/tests/test_make_iterator.py @@ -1,5 +1,5 @@ import test_make_iterator_ext as t -import pytest +from common import parallelize data = [ {}, @@ -30,6 +30,10 @@ def test03_items_iterator(): assert sorted(list(m.items_l())) == sorted(list(d.items())) +def test03_items_iterator_parallel(n_threads=8): + parallelize(test03_items_iterator, n_threads=n_threads) + + def test04_passthrough_iterator(): for d in data: m = t.StringMap(d) diff --git a/tests/test_thread.py b/tests/test_thread.py index f9fd85f1..832b2fb6 100644 --- a/tests/test_thread.py +++ b/tests/test_thread.py @@ -1,29 +1,6 @@ import test_thread_ext as t from test_thread_ext import Counter - -import threading - -# Helper function to parallelize execution of a function. We intentionally -# don't use the Python threads pools here to have threads shut down / start -# between test cases. -def parallelize(func, n_threads): - barrier = threading.Barrier(n_threads) - result = [None]*n_threads - - def wrapper(i): - barrier.wait() - result[i] = func() - - workers = [] - for i in range(n_threads): - t = threading.Thread(target=wrapper, args=(i,)) - t.start() - workers.append(t) - - for worker in workers: - worker.join() - return result - +from common import parallelize def test01_object_creation(n_threads=8): # This test hammers 'inst_c2p' from multiple threads, and