From 4b65e7c93de7f86090a83be3154adf09040441d4 Mon Sep 17 00:00:00 2001 From: Marcin Wojdyr Date: Tue, 13 Aug 2024 16:52:30 +0200 Subject: [PATCH] support None for const char* arg, with minimal overhead Since pybind11 also accepts None for const char* as an exception among built-in types, remove a paragraph about differences in porting.rst. --- docs/api_core.rst | 8 +++++--- docs/porting.rst | 6 ------ include/nanobind/nb_cast.h | 3 ++- tests/test_functions.cpp | 1 + tests/test_functions.py | 2 ++ tests/test_functions_ext.pyi.ref | 2 ++ 6 files changed, 12 insertions(+), 10 deletions(-) diff --git a/docs/api_core.rst b/docs/api_core.rst index b90191a4..ae6d9911 100644 --- a/docs/api_core.rst +++ b/docs/api_core.rst @@ -1622,9 +1622,11 @@ parameter of :cpp:func:`module_::def`, :cpp:func:`class_::def`, Set a flag noting that the function argument accepts ``None``. Can only be used for python wrapper types (e.g. :cpp:class:`handle`, - :cpp:class:`int_`) and types that have been bound using - :cpp:class:`class_`. You cannot use this to implement functions that - accept null pointers to builtin C++ types like ``int *i = nullptr``. + :cpp:class:`int_`), for types that have been bound using + :cpp:class:`class_`, for :cpp:class:`ndarray`, ``std::optional``, + ``std::variant``, and ``const char*``. + You cannot use this to implement functions that + accept null pointers to other builtin C++ types like ``int *i = nullptr``. .. cpp:function:: arg &noconvert(bool value = true) diff --git a/docs/porting.rst b/docs/porting.rst index 832b7ae1..45365705 100644 --- a/docs/porting.rst +++ b/docs/porting.rst @@ -86,12 +86,6 @@ It is also possible to set a ``None`` default value, in which case the m.def("func", &func, "arg"_a = nb::none()); -``None``-valued arguments are only supported by two of the three parameter -passing styles described in the section on :ref:`information exchange -`. In particular, they are supported by :ref:`bindings ` -and :ref:`wrappers `, *but not* by :ref:`type casters -`. - Shared pointers and holders --------------------------- diff --git a/include/nanobind/nb_cast.h b/include/nanobind/nb_cast.h index 8fcbad67..4a6a1d7d 100644 --- a/include/nanobind/nb_cast.h +++ b/include/nanobind/nb_cast.h @@ -286,7 +286,8 @@ template <> struct type_caster { value = PyUnicode_AsUTF8AndSize(src.ptr(), &size); if (!value) { PyErr_Clear(); - return false; + // optimize for src being a string, check for None afterwards + return src.is_none(); } return true; } diff --git a/tests/test_functions.cpp b/tests/test_functions.cpp index cee371cd..23c95d55 100644 --- a/tests/test_functions.cpp +++ b/tests/test_functions.cpp @@ -161,6 +161,7 @@ NB_MODULE(test_functions_ext, m) { // Test string caster m.def("test_12", [](const char *c) { return nb::str(c); }); + m.def("test_12_n", [](const char *c) { return nb::str(c ? c : "n/a"); }, nb::arg().none()); m.def("test_13", []() -> const char * { return "test"; }); m.def("test_14", [](nb::object o) -> const char * { return nb::cast(o); }); diff --git a/tests/test_functions.py b/tests/test_functions.py index 2217e7d5..4c127fcc 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -236,6 +236,8 @@ def test21_numpy_overloads(): def test22_string_return(): assert t.test_12("hello") == "hello" + assert t.test_12_n("hello") == "hello" + assert t.test_12_n(None) == "n/a" assert t.test_13() == "test" assert t.test_14("abc") == "abc" diff --git a/tests/test_functions_ext.pyi.ref b/tests/test_functions_ext.pyi.ref index dcaaca0c..a730a91c 100644 --- a/tests/test_functions_ext.pyi.ref +++ b/tests/test_functions_ext.pyi.ref @@ -84,6 +84,8 @@ def test_11_ull(arg: int, /) -> int: ... def test_12(arg: str, /) -> str: ... +def test_12_n(arg: str | None) -> str: ... + def test_13() -> str: ... def test_14(arg: object, /) -> str: ...