Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix functional.h bug + introduce test to verify that it is fixed #4254

Merged
merged 32 commits into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
56cd516
Illustrate bug in functional.h
EthanSteinberg Oct 19, 2022
ba640fd
style: pre-commit fixes
pre-commit-ci[bot] Oct 19, 2022
e85ed9b
Make functional casting more robust / add workaround
Skylion007 Oct 19, 2022
7f3ff9b
Make function_record* casting even more robust
Skylion007 Oct 19, 2022
fbdf4aa
See if this fixes PyPy issue
Skylion007 Oct 19, 2022
9deba24
It still fails on PyPy sadly
Skylion007 Oct 19, 2022
4c6c768
Do not make new CTOR just yet
Skylion007 Oct 20, 2022
c36f3a8
Fix test
Skylion007 Oct 20, 2022
1aaf98b
Add name to ensure correctness
EthanSteinberg Oct 20, 2022
871273c
style: pre-commit fixes
pre-commit-ci[bot] Oct 20, 2022
4962a5d
Clean up tests + remove ifdef guards
EthanSteinberg Oct 20, 2022
3a8f0f3
Add comments
EthanSteinberg Oct 20, 2022
8dce8b3
Improve comments, error handling, and safety
Skylion007 Oct 21, 2022
eca0cbf
Fix compile error
Skylion007 Oct 21, 2022
a0d5be3
Fix magic logic
EthanSteinberg Oct 21, 2022
184534e
Extract helper function
Skylion007 Oct 21, 2022
f92cb03
Fix func signature
Skylion007 Oct 21, 2022
17b1f91
move to local internals
EthanSteinberg Oct 21, 2022
3e04c07
style: pre-commit fixes
pre-commit-ci[bot] Oct 21, 2022
13b5fea
Switch to simpler design
EthanSteinberg Oct 21, 2022
71f97d1
style: pre-commit fixes
pre-commit-ci[bot] Oct 21, 2022
5c07f91
Move to function_record
EthanSteinberg Oct 21, 2022
aae5968
style: pre-commit fixes
pre-commit-ci[bot] Oct 21, 2022
8edaae7
Switch to internals, update tests and docs
EthanSteinberg Oct 22, 2022
8f8a873
Fix lint
EthanSteinberg Oct 22, 2022
2d054d3
Oops, forgot to resolve last comment
EthanSteinberg Oct 23, 2022
8e4cece
Fix typo
EthanSteinberg Oct 23, 2022
aa89200
Update in response to comments
EthanSteinberg Oct 29, 2022
f184c54
Implement suggestion to improve test
EthanSteinberg Nov 1, 2022
14b7e33
Update comment
EthanSteinberg Nov 1, 2022
d54fa89
Simple fixes
EthanSteinberg Nov 1, 2022
13cb394
Merge remote-tracking branch 'upstream/master' into functional_bug
EthanSteinberg Nov 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions include/pybind11/detail/internals.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ using ExceptionTranslator = void (*)(std::exception_ptr);

PYBIND11_NAMESPACE_BEGIN(detail)

constexpr const char *internals_function_record_capsule_name = "pybind11_function_record_capsule";

// Forward declarations
inline PyTypeObject *make_static_property_type();
inline PyTypeObject *make_default_metaclass();
Expand Down Expand Up @@ -182,6 +184,16 @@ struct internals {
# endif // PYBIND11_INTERNALS_VERSION > 4
// Unused if PYBIND11_SIMPLE_GIL_MANAGEMENT is defined:
PyInterpreterState *istate = nullptr;

# if PYBIND11_INTERNALS_VERSION > 4
// Note that we have to use a std::string to allocate memory to ensure a unique address
// We want unique addresses since we use pointer equality to compare function records
std::string function_record_capsule_name = internals_function_record_capsule_name;
EthanSteinberg marked this conversation as resolved.
Show resolved Hide resolved
# endif

internals() = default;
internals(const internals &other) = delete;
internals &operator=(const internals &other) = delete;
~internals() {
# if PYBIND11_INTERNALS_VERSION > 4
PYBIND11_TLS_FREE(loader_life_support_tls_key);
Expand Down Expand Up @@ -548,6 +560,25 @@ const char *c_str(Args &&...args) {
return strings.front().c_str();
}

inline const char *get_function_record_capsule_name() {
#if PYBIND11_INTERNALS_VERSION > 4
return get_internals().function_record_capsule_name.c_str();
#else
return nullptr;
#endif
}

// Determine whether or not the following capsule contains a pybind11 function record.
// Note that we use `internals` to make sure that only ABI compatible records are touched.
//
// This check is currently used in two places:
// - An important optimization in functional.h to avoid overhead in C++ -> Python -> C++
// - The sibling feature of cpp_function to allow overloads
inline bool is_function_record_capsule(const capsule &cap) {
// Pointer equality as we rely on internals() to ensure unique pointers
return cap.name() == get_function_record_capsule_name();
}

PYBIND11_NAMESPACE_END(detail)

/// Returns a named pointer that is shared among all extension modules (using the same
Expand Down
11 changes: 9 additions & 2 deletions include/pybind11/functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,16 @@ struct type_caster<std::function<Return(Args...)>> {
*/
if (auto cfunc = func.cpp_function()) {
auto *cfunc_self = PyCFunction_GET_SELF(cfunc.ptr());
EthanSteinberg marked this conversation as resolved.
Show resolved Hide resolved
if (isinstance<capsule>(cfunc_self)) {
if (cfunc_self == nullptr) {
PyErr_Clear();
} else if (isinstance<capsule>(cfunc_self)) {
auto c = reinterpret_borrow<capsule>(cfunc_self);
auto *rec = (function_record *) c;

function_record *rec = nullptr;
// Check that we can safely reinterpret the capsule into a function_record
if (detail::is_function_record_capsule(c)) {
rec = c.get_pointer<function_record>();
}

while (rec != nullptr) {
if (rec->is_stateless
Expand Down
44 changes: 34 additions & 10 deletions include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -468,13 +468,20 @@ class cpp_function : public function {
if (rec->sibling) {
if (PyCFunction_Check(rec->sibling.ptr())) {
auto *self = PyCFunction_GET_SELF(rec->sibling.ptr());
capsule rec_capsule = isinstance<capsule>(self) ? reinterpret_borrow<capsule>(self)
: capsule(self);
chain = (detail::function_record *) rec_capsule;
/* Never append a method to an overload chain of a parent class;
instead, hide the parent's overloads in this case */
if (!chain->scope.is(rec->scope)) {
if (!isinstance<capsule>(self)) {
chain = nullptr;
} else {
auto rec_capsule = reinterpret_borrow<capsule>(self);
if (detail::is_function_record_capsule(rec_capsule)) {
chain = rec_capsule.get_pointer<detail::function_record>();
/* Never append a method to an overload chain of a parent class;
instead, hide the parent's overloads in this case */
if (!chain->scope.is(rec->scope)) {
chain = nullptr;
}
} else {
chain = nullptr;
}
}
}
// Don't trigger for things like the default __init__, which are wrapper_descriptors
Expand All @@ -496,6 +503,7 @@ class cpp_function : public function {

capsule rec_capsule(unique_rec.release(),
[](void *ptr) { destruct((detail::function_record *) ptr); });
rec_capsule.set_name(detail::get_function_record_capsule_name());
guarded_strdup.release();

object scope_module;
Expand Down Expand Up @@ -661,10 +669,13 @@ class cpp_function : public function {
/// Main dispatch logic for calls to functions bound using pybind11
static PyObject *dispatcher(PyObject *self, PyObject *args_in, PyObject *kwargs_in) {
using namespace detail;
assert(isinstance<capsule>(self));

/* Iterator over the list of potentially admissible overloads */
const function_record *overloads = (function_record *) PyCapsule_GetPointer(self, nullptr),
const function_record *overloads = reinterpret_cast<function_record *>(
PyCapsule_GetPointer(self, get_function_record_capsule_name())),
*it = overloads;
assert(overloads != nullptr);

/* Need to know how many arguments + keyword arguments there are to pick the right
overload */
Expand Down Expand Up @@ -1871,9 +1882,22 @@ class class_ : public detail::generic_type {

static detail::function_record *get_function_record(handle h) {
h = detail::get_function(h);
return h ? (detail::function_record *) reinterpret_borrow<capsule>(
PyCFunction_GET_SELF(h.ptr()))
: nullptr;
if (!h) {
return nullptr;
}

handle func_self = PyCFunction_GET_SELF(h.ptr());
if (!func_self) {
throw error_already_set();
}
if (!isinstance<capsule>(func_self)) {
return nullptr;
}
auto cap = reinterpret_borrow<capsule>(func_self);
if (!detail::is_function_record_capsule(cap)) {
return nullptr;
}
return cap.get_pointer<detail::function_record>();
}
};

Expand Down
37 changes: 37 additions & 0 deletions tests/test_callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,4 +240,41 @@ TEST_SUBMODULE(callbacks, m) {
f();
}
});

auto *custom_def = []() {
static PyMethodDef def;
def.ml_name = "example_name";
def.ml_doc = "Example doc";
def.ml_meth = [](PyObject *, PyObject *args) -> PyObject * {
if (PyTuple_Size(args) != 1) {
throw std::runtime_error("Invalid number of arguments for example_name");
}
PyObject *first = PyTuple_GetItem(args, 0);
if (!PyLong_Check(first)) {
throw std::runtime_error("Invalid argument to example_name");
}
auto result = py::cast(PyLong_AsLong(first) * 9);
return result.release().ptr();
};
def.ml_flags = METH_VARARGS;
return &def;
}();

// rec_capsule with name that has the same value (but not pointer) as our internal one
// This capsule should be detected by our code as foreign and not inspected as the pointers
// shouldn't match
constexpr const char *rec_capsule_name
= pybind11::detail::internals_function_record_capsule_name;
py::capsule rec_capsule(std::malloc(1), [](void *data) { std::free(data); });
rec_capsule.set_name(rec_capsule_name);
m.add_object("custom_function", PyCFunction_New(custom_def, rec_capsule.ptr()));

// This test requires a new ABI version to pass
#if PYBIND11_INTERNALS_VERSION > 4
// rec_capsule with nullptr name
py::capsule rec_capsule2(std::malloc(1), [](void *data) { std::free(data); });
m.add_object("custom_function2", PyCFunction_New(custom_def, rec_capsule2.ptr()));
#else
m.add_object("custom_function2", py::none());
#endif
}
13 changes: 13 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,16 @@ def test_callback_num_times():
if len(rates) > 1:
print("Min Mean Max")
print(f"{min(rates):6.3f} {sum(rates) / len(rates):6.3f} {max(rates):6.3f}")


def test_custom_func():
assert m.custom_function(4) == 36
assert m.roundtrip(m.custom_function)(4) == 36

EthanSteinberg marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.skipif(
m.custom_function2 is None, reason="Current PYBIND11_INTERNALS_VERSION too low"
)
def test_custom_func2():
assert m.custom_function2(3) == 27
assert m.roundtrip(m.custom_function2)(3) == 27