Skip to content

Commit

Permalink
Fix crash from concurrent nb::make_iterator<> under free-threading.
Browse files Browse the repository at this point in the history
The `State` class used by make_iterator is constructed lazily, but
without locking it is possible for the caller to crash when the class
type is only partially constructed. This PR adds an ft_mutex around the
binding of the State type.

My initial instinct was to use an nb_object_guard on `scope`.
Unfortunately this doesn't work; I suspect another critical section is
acquired by code during the class binding and that releases the outer
critical section.
  • Loading branch information
hawkinsp committed Dec 23, 2024
1 parent b7c4f1a commit 2004500
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 45 deletions.
43 changes: 23 additions & 20 deletions include/nanobind/make_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,27 +71,30 @@ typed<iterator, ValueType> make_iterator_impl(handle scope, const char *name,
"make_iterator_impl(): the generated __next__ would copy elements, so the "
"element type must be copy-constructible");

if (!type<State>().is_valid()) {
class_<State>(scope, name)
.def("__iter__", [](handle h) { return h; })
.def("__next__",
[](State &s) -> ValueType {
if (!s.first_or_done)
++s.it;
else
s.first_or_done = false;

if (s.it == s.end) {
s.first_or_done = true;
throw stop_iteration();
}

return Access()(s.it);
},
std::forward<Extra>(extra)...,
Policy);
{
static ft_mutex mu;
ft_lock_guard lock(mu);
if (!type<State>().is_valid()) {
class_<State>(scope, name)
.def("__iter__", [](handle h) { return h; })
.def("__next__",
[](State &s) -> ValueType {
if (!s.first_or_done)
++s.it;
else
s.first_or_done = false;

if (s.it == s.end) {
s.first_or_done = true;
throw stop_iteration();
}

return Access()(s.it);
},
std::forward<Extra>(extra)...,
Policy);
}
}

return borrow<typed<iterator, ValueType>>(cast(State{
std::forward<Iterator>(first), std::forward<Sentinel>(last), true }));
}
Expand Down
23 changes: 23 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import platform
import gc
import pytest
import threading

is_pypy = platform.python_implementation() == 'PyPy'
is_darwin = platform.system() == 'Darwin'
Expand All @@ -17,3 +18,25 @@ def collect() -> None:

xfail_on_pypy_darwin = pytest.mark.xfail(
is_pypy and is_darwin, reason="This test for some reason fails on PyPy/Darwin")


# Helper function to parallelize execution of a function. We intentionally
# don't use the Python threads pools here to have threads shut down / start
# between test cases.
def parallelize(func, n_threads):
barrier = threading.Barrier(n_threads)
result = [None]*n_threads

def wrapper(i):
barrier.wait()
result[i] = func()

workers = []
for i in range(n_threads):
t = threading.Thread(target=wrapper, args=(i,))
t.start()
workers.append(t)

for worker in workers:
worker.join()
return result
6 changes: 5 additions & 1 deletion tests/test_make_iterator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import test_make_iterator_ext as t
import pytest
from common import parallelize

data = [
{},
Expand Down Expand Up @@ -30,6 +30,10 @@ def test03_items_iterator():
assert sorted(list(m.items_l())) == sorted(list(d.items()))


def test03_items_iterator_parallel(n_threads=8):
parallelize(test03_items_iterator, n_threads=n_threads)


def test04_passthrough_iterator():
for d in data:
m = t.StringMap(d)
Expand Down
25 changes: 1 addition & 24 deletions tests/test_thread.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,6 @@
import test_thread_ext as t
from test_thread_ext import Counter

import threading

# Helper function to parallelize execution of a function. We intentionally
# don't use the Python threads pools here to have threads shut down / start
# between test cases.
def parallelize(func, n_threads):
barrier = threading.Barrier(n_threads)
result = [None]*n_threads

def wrapper(i):
barrier.wait()
result[i] = func()

workers = []
for i in range(n_threads):
t = threading.Thread(target=wrapper, args=(i,))
t.start()
workers.append(t)

for worker in workers:
worker.join()
return result

from common import parallelize

def test01_object_creation(n_threads=8):
# This test hammers 'inst_c2p' from multiple threads, and
Expand Down

0 comments on commit 2004500

Please sign in to comment.