Skip to content

Commit

Permalink
Cast member pointer arguments of class_<T>::def() to member-of-T
Browse files Browse the repository at this point in the history
  • Loading branch information
oremanj committed Mar 12, 2024
1 parent e80edb1 commit 5f632d6
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 10 deletions.
45 changes: 37 additions & 8 deletions include/nanobind/nb_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,37 @@ struct is_copy_constructible : std::is_copy_constructible<T> { };
template <typename T>
constexpr bool is_copy_constructible_v = is_copy_constructible<T>::value;

// Like is_base_of, but also requires that the base type is accessible (i.e. that a Derived
// pointer can be converted to a Base pointer). For unions, `is_base_of<T, T>::value` is False, so
// we need to check `is_same` as well.
template <typename Base, typename Derived>
constexpr bool is_accessible_base_of_v =
(std::is_same_v<Base, Derived> || std::is_base_of_v<Base, Derived>)
&& std::is_convertible_v<Derived *, Base *>;

// Given a pointer to a member function, cast it to its `Derived` version.
// Forward everything else unchanged. Analogous to method_adaptor in pybind11.
template <typename /*Derived*/, typename F>
auto with_self(F &&f) -> decltype(std::forward<F>(f)) {
return std::forward<F>(f);
}

template <typename Derived, typename Return, typename Class, typename... Args>
auto with_self(Return (Class::*pmf)(Args...)) -> Return (Derived::*)(Args...) {
static_assert(is_accessible_base_of_v<Class, Derived>,
"Cannot bind an inaccessible base class method; "
"use a lambda definition instead");
return pmf;
}

template <typename Derived, typename Return, typename Class, typename... Args>
auto with_self(Return (Class::*pmf)(Args...) const) -> Return (Derived::*)(Args...) const {
static_assert(is_accessible_base_of_v<Class, Derived>,
"Cannot bind an inaccessible base class method; "
"use a lambda definition instead");
return pmf;
}

NAMESPACE_END(detail)

// Low level access to nanobind type objects
Expand Down Expand Up @@ -491,8 +522,8 @@ class class_ : public object {

template <typename Func, typename... Extra>
NB_INLINE class_ &def(const char *name_, Func &&f, const Extra &... extra) {
cpp_function_def((detail::forward_t<Func>) f, scope(*this), name(name_),
is_method(), extra...);
cpp_function_def(detail::with_self<T>((detail::forward_t<Func>) f),
scope(*this), name(name_), is_method(), extra...);
return *this;
}

Expand Down Expand Up @@ -525,12 +556,12 @@ class class_ : public object {
object get_p, set_p;

if constexpr (!std::is_same_v<Getter, std::nullptr_t>)
get_p = cpp_function((detail::forward_t<Getter>) getter,
get_p = cpp_function(detail::with_self<T>((detail::forward_t<Getter>) getter),
is_method(), is_getter(),
rv_policy::reference_internal, detail::filter_getter(extra)...);

if constexpr (!std::is_same_v<Setter, std::nullptr_t>)
set_p = cpp_function((detail::forward_t<Setter>) setter,
set_p = cpp_function(detail::with_self<T>((detail::forward_t<Setter>) setter),
is_method(), detail::filter_setter(extra)...);

detail::property_install(m_ptr, name_, get_p.ptr(), set_p.ptr());
Expand Down Expand Up @@ -572,8 +603,7 @@ class class_ : public object {
template <typename C, typename D, typename... Extra>
NB_INLINE class_ &def_rw(const char *name, D C::*p,
const Extra &...extra) {
// Unions never satisfy is_base_of, thus the is_same alternative
static_assert(std::is_base_of_v<C, T> || std::is_same_v<C, T>,
static_assert(detail::is_accessible_base_of_v<C, T>,
"def_rw() requires a (base) class member!");

using Q =
Expand Down Expand Up @@ -605,8 +635,7 @@ class class_ : public object {
template <typename C, typename D, typename... Extra>
NB_INLINE class_ &def_ro(const char *name, D C::*p,
const Extra &...extra) {
// Unions never satisfy is_base_of, thus the is_same alternative
static_assert(std::is_base_of_v<C, T> || std::is_same_v<C, T>,
static_assert(detail::is_accessible_base_of_v<C, T>,
"def_ro() requires a (base) class member!");

def_prop_ro(name,
Expand Down
16 changes: 16 additions & 0 deletions tests/test_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,4 +575,20 @@ NB_MODULE(test_classes_ext, m) {
.def(nb::init<>())
.def_rw("i", &Union::i)
.def_rw("f", &Union::f);

struct HiddenBase {
int value = 10;
int vget() const { return value; }
void vset(int v) { value = v; }
int get_answer() const { return value * 10; }
};
struct BoundDerived : HiddenBase {
virtual int polymorphic() { return value; }
};
nb::class_<BoundDerived>(m, "BoundDerived")
.def(nb::init<>())
.def_rw("value", &BoundDerived::value)
.def_prop_rw("prop", &BoundDerived::vget, &BoundDerived::vset)
.def("get_answer", &BoundDerived::get_answer)
.def("polymorphic", &BoundDerived::polymorphic);
}
14 changes: 12 additions & 2 deletions tests/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,17 +801,27 @@ def test42_weak_references():


def test43_union():
import struct

u = t.Union()
u.i = 42
assert u.i == 42

u.f = 2.125
assert u.f == 2.125


def test44_dynamic_attr_has_dict():
s = t.StructWithAttr(5)
assert s.__dict__ == {}
s.a_dynamic_attr = 101
assert s.__dict__ == {"a_dynamic_attr": 101}


def test45_hidden_base():
s = t.BoundDerived()
assert s.value == 10
s.value = 5
assert s.prop == 5
s.prop = 20
assert s.value == 20
assert s.get_answer() == 200
assert s.polymorphic() == 20
19 changes: 19 additions & 0 deletions tests/test_classes_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@ class Big:
class BigAligned:
def __init__(self) -> None: ...

class BoundDerived:
def __init__(self) -> None: ...

@property
def value(self) -> int: ...

@value.setter
def value(self, arg: int, /) -> None: ...

@property
def prop(self) -> int: ...

@prop.setter
def prop(self, arg: int, /) -> None: ...

def get_answer(self) -> int: ...

def polymorphic(self) -> int: ...

class C:
def __init__(self, arg: int, /) -> None: ...

Expand Down

0 comments on commit 5f632d6

Please sign in to comment.