Skip to content

Commit

Permalink
Cast member pointer arguments of class_<T>::def() to member-of-T (#471)
Browse files Browse the repository at this point in the history
  • Loading branch information
oremanj authored Mar 12, 2024
1 parent 026aea0 commit 754e1b2
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 25 deletions.
7 changes: 7 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ noteworthy:
indicated using the new :cpp:struct:`nb::kw_only() <kw_only>` function
annotation. (PR `#448 <https://github.com/wjakob/nanobind/pull/448>`__).

* When binding methods on a class ``T``, nanobind will now produce a Python
function that expects a self argument of type ``T``. Previously, it would
use the type of the member pointer to determine the Python function
signature, which could be a base of ``T``, which would create problems
if nanobind did not know about that base.
(PR `#471 <https://github.com/wjakob/nanobind/pull/471>`__).

* ABI version 14.

.. rubric:: Footnote
Expand Down
15 changes: 8 additions & 7 deletions include/nanobind/nb_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -491,8 +491,8 @@ class class_ : public object {

template <typename Func, typename... Extra>
NB_INLINE class_ &def(const char *name_, Func &&f, const Extra &... extra) {
cpp_function_def((detail::forward_t<Func>) f, scope(*this), name(name_),
is_method(), extra...);
cpp_function_def<T>((detail::forward_t<Func>) f, scope(*this),
name(name_), is_method(), extra...);
return *this;
}

Expand Down Expand Up @@ -525,13 +525,14 @@ class class_ : public object {
object get_p, set_p;

if constexpr (!std::is_same_v<Getter, std::nullptr_t>)
get_p = cpp_function((detail::forward_t<Getter>) getter,
is_method(), is_getter(),
rv_policy::reference_internal, detail::filter_getter(extra)...);
get_p = cpp_function<T>((detail::forward_t<Getter>) getter,
is_method(), is_getter(),
rv_policy::reference_internal,
detail::filter_getter(extra)...);

if constexpr (!std::is_same_v<Setter, std::nullptr_t>)
set_p = cpp_function((detail::forward_t<Setter>) setter,
is_method(), detail::filter_setter(extra)...);
set_p = cpp_function<T>((detail::forward_t<Setter>) setter,
is_method(), detail::filter_setter(extra)...);

detail::property_install(m_ptr, name_, get_p.ptr(), set_p.ptr());
return *this;
Expand Down
45 changes: 29 additions & 16 deletions include/nanobind/nb_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,21 +244,26 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...),

NAMESPACE_END(detail)

template <typename Return, typename... Args, typename... Extra>
// The initial template parameter to cpp_function/cpp_function_def is
// used by class_ to ensure that member pointers are treated as members
// of the class being defined; other users can safely leave it at its
// default of void.

template <typename = void, typename Return, typename... Args, typename... Extra>
NB_INLINE object cpp_function(Return (*f)(Args...), const Extra&... extra) {
return steal(detail::func_create<true, true>(
f, f, std::make_index_sequence<sizeof...(Args)>(), extra...));
}

template <typename Return, typename... Args, typename... Extra>
template <typename = void, typename Return, typename... Args, typename... Extra>
NB_INLINE void cpp_function_def(Return (*f)(Args...), const Extra&... extra) {
detail::func_create<false, true>(
f, f, std::make_index_sequence<sizeof...(Args)>(), extra...);
}

/// Construct a cpp_function from a lambda function (pot. with internal state)
template <
typename Func, typename... Extra,
typename = void, typename Func, typename... Extra,
detail::enable_if_t<detail::is_lambda_v<std::remove_reference_t<Func>>> = 0>
NB_INLINE object cpp_function(Func &&f, const Extra &...extra) {
using am = detail::analyze_method<decltype(&std::remove_reference_t<Func>::operator())>;
Expand All @@ -268,7 +273,7 @@ NB_INLINE object cpp_function(Func &&f, const Extra &...extra) {
}

template <
typename Func, typename... Extra,
typename = void, typename Func, typename... Extra,
detail::enable_if_t<detail::is_lambda_v<std::remove_reference_t<Func>>> = 0>
NB_INLINE void cpp_function_def(Func &&f, const Extra &...extra) {
using am = detail::analyze_method<decltype(&std::remove_reference_t<Func>::operator())>;
Expand All @@ -278,44 +283,52 @@ NB_INLINE void cpp_function_def(Func &&f, const Extra &...extra) {
}

/// Construct a cpp_function from a class method (non-const)
template <typename Return, typename Class, typename... Args, typename... Extra>
template <typename Target = void,
typename Return, typename Class, typename... Args, typename... Extra>
NB_INLINE object cpp_function(Return (Class::*f)(Args...), const Extra &...extra) {
using T = std::conditional_t<std::is_void_v<Target>, Class, Target>;
return steal(detail::func_create<true, true>(
[f](Class *c, Args... args) NB_INLINE_LAMBDA -> Return {
[f](T *c, Args... args) NB_INLINE_LAMBDA -> Return {
return (c->*f)((detail::forward_t<Args>) args...);
},
(Return(*)(Class *, Args...)) nullptr,
(Return(*)(T *, Args...)) nullptr,
std::make_index_sequence<sizeof...(Args) + 1>(), extra...));
}

template <typename Return, typename Class, typename... Args, typename... Extra>
template <typename Target = void,
typename Return, typename Class, typename... Args, typename... Extra>
NB_INLINE void cpp_function_def(Return (Class::*f)(Args...), const Extra &...extra) {
using T = std::conditional_t<std::is_void_v<Target>, Class, Target>;
detail::func_create<false, true>(
[f](Class *c, Args... args) NB_INLINE_LAMBDA -> Return {
[f](T *c, Args... args) NB_INLINE_LAMBDA -> Return {
return (c->*f)((detail::forward_t<Args>) args...);
},
(Return(*)(Class *, Args...)) nullptr,
(Return(*)(T *, Args...)) nullptr,
std::make_index_sequence<sizeof...(Args) + 1>(), extra...);
}

/// Construct a cpp_function from a class method (const)
template <typename Return, typename Class, typename... Args, typename... Extra>
template <typename Target = void,
typename Return, typename Class, typename... Args, typename... Extra>
NB_INLINE object cpp_function(Return (Class::*f)(Args...) const, const Extra &...extra) {
using T = std::conditional_t<std::is_void_v<Target>, Class, Target>;
return steal(detail::func_create<true, true>(
[f](const Class *c, Args... args) NB_INLINE_LAMBDA -> Return {
[f](const T *c, Args... args) NB_INLINE_LAMBDA -> Return {
return (c->*f)((detail::forward_t<Args>) args...);
},
(Return(*)(const Class *, Args...)) nullptr,
(Return(*)(const T *, Args...)) nullptr,
std::make_index_sequence<sizeof...(Args) + 1>(), extra...));
}

template <typename Return, typename Class, typename... Args, typename... Extra>
template <typename Target = void,
typename Return, typename Class, typename... Args, typename... Extra>
NB_INLINE void cpp_function_def(Return (Class::*f)(Args...) const, const Extra &...extra) {
using T = std::conditional_t<std::is_void_v<Target>, Class, Target>;
detail::func_create<false, true>(
[f](const Class *c, Args... args) NB_INLINE_LAMBDA -> Return {
[f](const T *c, Args... args) NB_INLINE_LAMBDA -> Return {
return (c->*f)((detail::forward_t<Args>) args...);
},
(Return(*)(const Class *, Args...)) nullptr,
(Return(*)(const T *, Args...)) nullptr,
std::make_index_sequence<sizeof...(Args) + 1>(), extra...);
}

Expand Down
16 changes: 16 additions & 0 deletions tests/test_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,4 +575,20 @@ NB_MODULE(test_classes_ext, m) {
.def(nb::init<>())
.def_rw("i", &Union::i)
.def_rw("f", &Union::f);

struct HiddenBase {
int value = 10;
int vget() const { return value; }
void vset(int v) { value = v; }
int get_answer() const { return value * 10; }
};
struct BoundDerived : HiddenBase {
virtual int polymorphic() { return value; }
};
nb::class_<BoundDerived>(m, "BoundDerived")
.def(nb::init<>())
.def_rw("value", &BoundDerived::value)
.def_prop_rw("prop", &BoundDerived::vget, &BoundDerived::vset)
.def("get_answer", &BoundDerived::get_answer)
.def("polymorphic", &BoundDerived::polymorphic);
}
14 changes: 12 additions & 2 deletions tests/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,17 +801,27 @@ def test42_weak_references():


def test43_union():
import struct

u = t.Union()
u.i = 42
assert u.i == 42

u.f = 2.125
assert u.f == 2.125


def test44_dynamic_attr_has_dict():
s = t.StructWithAttr(5)
assert s.__dict__ == {}
s.a_dynamic_attr = 101
assert s.__dict__ == {"a_dynamic_attr": 101}


def test45_hidden_base():
s = t.BoundDerived()
assert s.value == 10
s.value = 5
assert s.prop == 5
s.prop = 20
assert s.value == 20
assert s.get_answer() == 200
assert s.polymorphic() == 20
19 changes: 19 additions & 0 deletions tests/test_classes_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@ class Big:
class BigAligned:
def __init__(self) -> None: ...

class BoundDerived:
def __init__(self) -> None: ...

@property
def value(self) -> int: ...

@value.setter
def value(self, arg: int, /) -> None: ...

@property
def prop(self) -> int: ...

@prop.setter
def prop(self, arg: int, /) -> None: ...

def get_answer(self) -> int: ...

def polymorphic(self) -> int: ...

class C:
def __init__(self, arg: int, /) -> None: ...

Expand Down

0 comments on commit 754e1b2

Please sign in to comment.