diff --git a/docs/api_core.rst b/docs/api_core.rst index 6af938d41..c41df4392 100644 --- a/docs/api_core.rst +++ b/docs/api_core.rst @@ -670,7 +670,7 @@ Wrapper classes Return the number of list elements. - .. cpp:function:: template void append(T &&value) + .. cpp:function:: template void append(T&& value) Append an element to the list. When `T` does not already represent a wrapped Python object, the function performs a cast. @@ -707,6 +707,11 @@ Wrapper classes Return the number of dictionary elements. + .. cpp:function:: template bool contains(T&& key) const + + Check whether the dictionary contains a particular key. When `T` does not + already represent a wrapped Python object, the function performs a cast. + .. cpp:function:: detail::dict_iterator begin() const Return an item iterator that returns ``std::pair`` @@ -913,6 +918,11 @@ Wrapper classes Wrapper class representing arbitrary Python mapping types. + .. cpp:function:: template bool contains(T&& key) const + + Check whether the map contains a particular key. When `T` does not + already represent a wrapped Python object, the function performs a cast. + .. cpp:function:: list keys() const Return a list containing all of the map's keys. diff --git a/include/nanobind/nb_cast.h b/include/nanobind/nb_cast.h index 1e947d81c..536bb8de4 100644 --- a/include/nanobind/nb_cast.h +++ b/include/nanobind/nb_cast.h @@ -453,4 +453,20 @@ template void list::append(T &&value) { detail::raise_python_error(); } +template bool dict::contains(T&& key) const { + object o = nanobind::cast((detail::forward_t) key); + int rv = PyDict_Contains(m_ptr, o.ptr()); + if (rv == -1) + detail::raise_python_error(); + return rv == 1; +} + +template bool mapping::contains(T&& key) const { + object o = nanobind::cast((detail::forward_t) key); + int rv = PyMapping_HasKey(m_ptr, o.ptr()); + if (rv == -1) + detail::raise_python_error(); + return rv == 1; +} + NAMESPACE_END(NB_NAMESPACE) diff --git a/include/nanobind/nb_types.h b/include/nanobind/nb_types.h index 95af912b5..5a397e997 100644 --- a/include/nanobind/nb_types.h +++ b/include/nanobind/nb_types.h @@ -463,6 +463,7 @@ class dict : public object { list keys() const { return steal(detail::obj_op_1(m_ptr, PyDict_Keys)); } list values() const { return steal(detail::obj_op_1(m_ptr, PyDict_Values)); } list items() const { return steal(detail::obj_op_1(m_ptr, PyDict_Items)); } + template bool contains(T&& key) const; }; class sequence : public object { @@ -474,6 +475,7 @@ class mapping : public object { list keys() const { return steal(detail::obj_op_1(m_ptr, PyMapping_Keys)); } list values() const { return steal(detail::obj_op_1(m_ptr, PyMapping_Values)); } list items() const { return steal(detail::obj_op_1(m_ptr, PyMapping_Items)); } + template bool contains(T&& key) const; }; class args : public tuple { diff --git a/tests/test_functions.cpp b/tests/test_functions.cpp index 9d29c068a..bf9029e77 100644 --- a/tests/test_functions.cpp +++ b/tests/test_functions.cpp @@ -138,6 +138,10 @@ NB_MODULE(test_functions_ext, m) { return result; }); + m.def("test_10_contains", [](nb::dict d) { + return d.contains(nb::str("foo")); + }); + // Test implicit conversion of various types m.def("test_11_sl", [](signed long x) { return x; }); m.def("test_11_ul", [](unsigned long x) { return x; });