Skip to content

Commit

Permalink
added a few missing methods to container wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Apr 2, 2024
1 parent 369e79a commit 4148e83
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 2 deletions.
21 changes: 20 additions & 1 deletion docs/api_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,14 @@ Wrapper classes
negative). When `T` does not already represent a wrapped Python object,
the function performs a cast.

.. cpp:function:: void clear()

Clear the list entries.

.. cpp:function:: void extend(handle h)

Analogous to the ``.extend(h)`` method of ``list`` in Python.

.. cpp:function:: template <typename T, enable_if_t<std::is_arithmetic_v<T>> = 1> detail::accessor<num_item_list> operator[](T key) const

Analogous to ``self[key]`` in Python, where ``key`` is an arithmetic
Expand Down Expand Up @@ -794,6 +802,10 @@ Wrapper classes

Clear the contents of the dictionary.

.. cpp:function:: void update(handle h)

Analogous to the ``.update(h)`` method of ``dict`` in Python.

.. cpp:class:: set: public object

Wrapper class representing Python ``set`` instances.
Expand All @@ -818,7 +830,14 @@ Wrapper classes

.. cpp:function:: void clear()

Clear the contents of the set
Clear the contents of the set.

.. cpp:function:: template <typename T> bool discard(T&& key)

Analogous to the ``.discard(h)`` method of the ``set`` type in Python.
Returns ``true`` if the item was deleted successfully, and ``false`` if
the value was not present. When `T` does not already represent a wrapped
Python object, the function performs a cast.

.. cpp:class:: module_: public object

Expand Down
8 changes: 7 additions & 1 deletion include/nanobind/nb_cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -565,14 +565,20 @@ template <typename T> bool set::contains(T&& key) const {
return rv == 1;
}


template <typename T> void set::add(T&& key) {
object o = nanobind::cast((detail::forward_t<T>) key);
int rv = PySet_Add(m_ptr, o.ptr());
if (rv == -1)
raise_python_error();
}

template <typename T> bool set::discard(T &&value) {
object o = nanobind::cast((detail::forward_t<T>) value);
int rv = PySet_Discard(m_ptr, o.ptr());
if (rv < 0)
raise_python_error();
return rv == 1;
}

template <typename T> bool mapping::contains(T&& key) const {
object o = nanobind::cast((detail::forward_t<T>) key);
Expand Down
15 changes: 15 additions & 0 deletions include/nanobind/nb_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,16 @@ class list : public object {
template <typename T, detail::enable_if_t<std::is_arithmetic_v<T>> = 1>
detail::accessor<detail::num_item_list> operator[](T key) const;

void clear() {
if (PyList_SetSlice(m_ptr, 0, PY_SSIZE_T_MAX, nullptr))
raise_python_error();
}

void extend(handle h) {
if (PyList_SetSlice(m_ptr, PY_SSIZE_T_MAX, PY_SSIZE_T_MAX, h.ptr()))
raise_python_error();
}

#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION)
detail::fast_iterator begin() const;
detail::fast_iterator end() const;
Expand All @@ -488,6 +498,10 @@ class dict : public object {
list items() const { return steal<list>(detail::obj_op_1(m_ptr, PyDict_Items)); }
template <typename T> bool contains(T&& key) const;
void clear() { PyDict_Clear(m_ptr); }
void update(handle h) {
if (PyDict_Update(m_ptr, h.ptr()))
raise_python_error();
}
};


Expand All @@ -501,6 +515,7 @@ class set : public object {
if (PySet_Clear(m_ptr))
raise_python_error();
}
template <typename T> bool discard(T &&value);
};

class sequence : public object {
Expand Down
44 changes: 44 additions & 0 deletions tests/test_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,4 +314,48 @@ NB_MODULE(test_functions_ext, m) {
"i"_a = 1, nb::kw_only(), "j"_a = 2);

m.def("test_any", [](nb::any a) { return a; } );

m.def("test_wrappers_list", []{
nb::list l1, l2;
l1.append(1);
l2.append(2);
l1.extend(l2);

bool b = nb::len(l1) == 2 && nb::len(l2) == 1 &&
l1[0].equal(nb::int_(1)) && l1[1].equal(nb::int_(2));

l1.clear();
return b && nb::len(l1) == 0;
});

m.def("test_wrappers_dict", []{
nb::dict d1, d2;
d1["a"] = 1;
d2["b"] = 2;
d1.update(d2);

bool b = nb::len(d1) == 2 && nb::len(d2) == 1 &&
d1["a"].equal(nb::int_(1)) &&
d1["b"].equal(nb::int_(2));

d1.clear();
return b && nb::len(d1) == 0;
});

m.def("test_wrappers_set", []{
nb::set s;
s.add("a");
s.add("b");

bool b = nb::len(s) == 2 && s.contains("a") && s.contains("b");

b &= s.discard("a");
b &= !s.discard("q");

b &= !s.contains("a") && s.contains("b");
s.clear();
b &= s.size() == 0;

return b;
});
}
9 changes: 9 additions & 0 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,12 @@ def test41_any():
s = "hello"
assert t.test_any(s) is s
assert t.test_any.__doc__ == "test_any(arg: typing.Any, /) -> typing.Any"

def test42_wrappers_list():
assert t.test_wrappers_list()

def test43_wrappers_dict():
assert t.test_wrappers_dict()

def test43_wrappers_set():
assert t.test_wrappers_set()
6 changes: 6 additions & 0 deletions tests/test_functions_ext.pyi.ref
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,9 @@ def test_tuple() -> tuple: ...

@overload
def test_tuple(arg: tuple, /) -> int: ...

def test_wrappers_dict() -> bool: ...

def test_wrappers_list() -> bool: ...

def test_wrappers_set() -> bool: ...

0 comments on commit 4148e83

Please sign in to comment.