diff --git a/include/gridtools/storage/adapter/nanobind_adapter.hpp b/include/gridtools/storage/adapter/nanobind_adapter.hpp index 5a190b4c1..4ce1a9fe3 100644 --- a/include/gridtools/storage/adapter/nanobind_adapter.hpp +++ b/include/gridtools/storage/adapter/nanobind_adapter.hpp @@ -78,7 +78,7 @@ namespace gridtools { class Strides = fully_dynamic_strides, class StridesKind = sid::unknown_kind> auto as_sid(nanobind::ndarray, Args...> ndarray, - Strides stride_spec_ = {}, + Strides stride_spec = {}, StridesKind = {}) { using sid::property; const auto ptr = ndarray.data(); @@ -86,13 +86,13 @@ namespace gridtools { assert(ndim == ndarray.ndim()); gridtools::array shape; std::copy_n(ndarray.shape_ptr(), ndim, shape.begin()); - gridtools::array strides_; - std::copy_n(ndarray.stride_ptr(), ndim, strides_.begin()); - const auto strides = select_static_strides(stride_spec_, strides_.data()); + gridtools::array strides; + std::copy_n(ndarray.stride_ptr(), ndim, strides.begin()); + const auto static_strides = select_static_strides(stride_spec, strides.data()); return sid::synthetic() - .template set(sid::host_device::simple_ptr_holder{ptr}) - .template set(strides) + .template set(sid::host_device::simple_ptr_holder{ptr}) + .template set(static_strides) .template set() .template set(gridtools::array, ndim>()) .template set(shape); diff --git a/tests/unit_tests/storage/adapter/CMakeLists.txt b/tests/unit_tests/storage/adapter/CMakeLists.txt index ac7c8e3a2..5f00fd329 100644 --- a/tests/unit_tests/storage/adapter/CMakeLists.txt +++ b/tests/unit_tests/storage/adapter/CMakeLists.txt @@ -21,7 +21,7 @@ if (${GT_TESTS_ENABLE_PYTHON_TESTS}) FetchContent_Declare( nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git - GIT_TAG v2.0.0 + GIT_TAG v2.1.0 ) FetchContent_MakeAvailable(nanobind) nanobind_build_library(nanobind-static) diff --git a/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp b/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp index 552b0b0e5..786d178b4 100644 --- a/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp +++ b/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp @@ -42,6 +42,26 @@ TEST_F(python_init_fixture, NanobindAdapterDataDynStrides) { EXPECT_EQ(strides[1], gridtools::get<1>(s_strides)); } +TEST_F(python_init_fixture, NanobindAdapterReadOnly) { + const auto data = reinterpret_cast(0xDEADBEEF); + constexpr int ndim = 2; + constexpr std::array shape = {3, 4}; + constexpr std::array strides = {1, 3}; + nb::ndarray, nb::ro> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; + + const auto sid = gridtools::nanobind::as_sid(ndarray); + using element_t = gridtools::sid::element_type; + static_assert(std::is_same_v); + + const auto s_origin = sid_get_origin(sid); + const auto s_strides = sid_get_strides(sid); + const auto s_ptr = s_origin(); + + EXPECT_EQ(s_ptr, data); + EXPECT_EQ(strides[0], gridtools::get<0>(s_strides)); + EXPECT_EQ(strides[1], gridtools::get<1>(s_strides)); +} + TEST_F(python_init_fixture, NanobindAdapterStaticStridesMatch) { const auto data = reinterpret_cast(0xDEADBEEF); constexpr int ndim = 2;