Skip to content

Commit

Permalink
Separate const of ndarray from const of its data. (wjakob#491)
Browse files Browse the repository at this point in the history
A static assertion ensures that the ``order`` and ``framework``
cannot be changed once set.  For example, a template argument to
``view()`` can only change one of these if it has not been set.
The scalar element type and the ``shape`` can be set multiple times,
and the rightmost one takes effect.  Thus, these can be changed by
specifying template arguments to ``view()``.
The read-only boolean, ``is_ro``, can be set either by specifying
``nb::ro`` or by specifying a const-qualified scalar element type.
It is sticky.  Once set, it cannot be cleared.  So, a read-only
view of writable data can be obtained, but not the reverse.
  • Loading branch information
hpkfft authored and TaylorP committed Aug 9, 2024
1 parent 594c7fc commit 97b1e37
Show file tree
Hide file tree
Showing 6 changed files with 278 additions and 49 deletions.
29 changes: 16 additions & 13 deletions docs/api_extra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,12 @@ section <ndarrays>`.

.. cpp:class:: template <typename... Args> ndarray

.. cpp:var:: ReadOnly

A constant static boolean that is true if the array's data is read-only.
This is determined by the class template arguments, not by any dynamic
properties of the referenced array.

.. cpp:function:: ndarray() = default

Create an invalid array.
Expand Down Expand Up @@ -793,14 +799,19 @@ section <ndarrays>`.
In a multi-device/GPU setup, this function returns the ID of the device
storing the array.

.. cpp:function:: const Scalar * data() const
.. cpp:function:: Scalar * data() const

Return a pointer to the array data.
If :cpp:var:`ReadOnly` is true, a pointer-to-const is returned.

Return a const pointer to the array data.
.. cpp:function:: template <typename... Ts> auto& operator()(Ts... indices)

.. cpp:function:: Scalar * data()
Return a reference to the element stored at the provided index/indices.
If :cpp:var:`ReadOnly` is true, a reference-to-const is returned.
Note that ``sizeof(Ts)`` must match :cpp:func:`ndim()`.

Return a mutable pointer to the array data. Only enabled when `Scalar` is
not itself ``const``.
This accessor is only available when the scalar type and array dimension
were specified as template parameters.

.. cpp:function:: template <typename... Extra> auto view()

Expand All @@ -813,14 +824,6 @@ section <ndarrays>`.
``shape()``, ``stride()``, and ``operator()`` following the conventions
of the `ndarray` type.

.. cpp:function:: template <typename... Ts> auto& operator()(Ts... indices)

Return a mutable reference to the element at stored at the provided
index/indices. ``sizeof(Ts)`` must match :cpp:func:`ndim()`.

This accessor is only available when the scalar type and array dimension
were specified as template parameters.

Data types
^^^^^^^^^^

Expand Down
4 changes: 2 additions & 2 deletions docs/ndarray.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ Binding functions that take arrays as input
-------------------------------------------

A function that accepts a :cpp:class:`nb::ndarray\<\> <ndarray>`-typed parameter
(i.e., *without* template parameters) can be called with *any* array
(i.e., *without* template parameters) can be called with *any* writable array
from any framework regardless of the device on which it is stored. The
following example binding declaration uses this functionality to inspect the
properties of an arbitrary input array:

.. code-block:: cpp
m.def("inspect", [](nb::ndarray<> a) {
m.def("inspect", [](const nb::ndarray<>& a) {
printf("Array data pointer : %p\n", a.data());
printf("Array dimension : %zu\n", a.ndim());
for (size_t i = 0; i < a.ndim(); ++i) {
Expand Down
86 changes: 55 additions & 31 deletions include/nanobind/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ template <size_t N> using ndim = typename detail::ndim_shape<std::make_index_seq
template <typename T> constexpr dlpack::dtype dtype() {
static_assert(
detail::is_ndarray_scalar_v<T>,
"nanobind::dtype<T>: T must be a floating point or integer variable!"
"nanobind::dtype<T>: T must be a floating point or integer type!"
);

dlpack::dtype result;
Expand Down Expand Up @@ -266,46 +266,76 @@ template <typename T> struct ndarray_arg<T, enable_if_t<T::is_device>> {
template <typename... Ts> struct ndarray_info {
using scalar_type = void;
using shape_type = void;
constexpr static bool ReadOnly = false;
constexpr static auto name = const_name("ndarray");
constexpr static ndarray_framework framework = ndarray_framework::none;
constexpr static char order = '\0';
};

template <typename T, typename... Ts> struct ndarray_info<T, Ts...> : ndarray_info<Ts...> {
using scalar_type =
std::conditional_t<ndarray_traits<T>::is_float || ndarray_traits<T>::is_int ||
ndarray_traits<T>::is_bool || ndarray_traits<T>::is_complex,
T, typename ndarray_info<Ts...>::scalar_type>;
std::conditional_t<
detail::is_ndarray_scalar_v<T> &&
std::is_void_v<typename ndarray_info<Ts...>::scalar_type>,
T, typename ndarray_info<Ts...>::scalar_type>;

constexpr static bool ReadOnly = ndarray_info<Ts...>::ReadOnly ||
(detail::is_ndarray_scalar_v<T> && std::is_const_v<T>);
};

template <typename... Ts> struct ndarray_info<ro, Ts...> : ndarray_info<Ts...> {
constexpr static bool ReadOnly = true;
};

template <ssize_t... Is, typename... Ts> struct ndarray_info<shape<Is...>, Ts...> : ndarray_info<Ts...> {
using shape_type = shape<Is...>;
using shape_type =
std::conditional_t<
std::is_void_v<typename ndarray_info<Ts...>::shape_type>,
shape<Is...>, typename ndarray_info<Ts...>::shape_type>;
};

template <typename... Ts> struct ndarray_info<c_contig, Ts...> : ndarray_info<Ts...> {
static_assert(ndarray_info<Ts...>::order == '\0'
|| ndarray_info<Ts...>::order == 'C',
"The order can only be set once.");
constexpr static char order = 'C';
};

template <typename... Ts> struct ndarray_info<f_contig, Ts...> : ndarray_info<Ts...> {
static_assert(ndarray_info<Ts...>::order == '\0'
|| ndarray_info<Ts...>::order == 'F',
"The order can only be set once.");
constexpr static char order = 'F';
};

template <typename... Ts> struct ndarray_info<numpy, Ts...> : ndarray_info<Ts...> {
static_assert(ndarray_info<Ts...>::framework == ndarray_framework::none
|| ndarray_info<Ts...>::framework == ndarray_framework::numpy,
"The framework can only be set once.");
constexpr static auto name = const_name("numpy.ndarray");
constexpr static ndarray_framework framework = ndarray_framework::numpy;
};

template <typename... Ts> struct ndarray_info<pytorch, Ts...> : ndarray_info<Ts...> {
static_assert(ndarray_info<Ts...>::framework == ndarray_framework::none
|| ndarray_info<Ts...>::framework == ndarray_framework::pytorch,
"The framework can only be set once.");
constexpr static auto name = const_name("torch.Tensor");
constexpr static ndarray_framework framework = ndarray_framework::pytorch;
};

template <typename... Ts> struct ndarray_info<tensorflow, Ts...> : ndarray_info<Ts...> {
static_assert(ndarray_info<Ts...>::framework == ndarray_framework::none
|| ndarray_info<Ts...>::framework == ndarray_framework::tensorflow,
"The framework can only be set once.");
constexpr static auto name = const_name("tensorflow.python.framework.ops.EagerTensor");
constexpr static ndarray_framework framework = ndarray_framework::tensorflow;
};

template <typename... Ts> struct ndarray_info<jax, Ts...> : ndarray_info<Ts...> {
static_assert(ndarray_info<Ts...>::framework == ndarray_framework::none
|| ndarray_info<Ts...>::framework == ndarray_framework::jax,
"The framework can only be set once.");
constexpr static auto name = const_name("jaxlib.xla_extension.DeviceArray");
constexpr static ndarray_framework framework = ndarray_framework::jax;
};
Expand All @@ -314,9 +344,9 @@ template <typename... Ts> struct ndarray_info<cupy, Ts...> : ndarray_info<Ts...>
constexpr static ndarray_framework framework = ndarray_framework::cupy;
};


NAMESPACE_END(detail)


template <typename Scalar, typename Shape, char Order> struct ndarray_view {
static constexpr size_t Dim = Shape::size;

Expand Down Expand Up @@ -380,7 +410,10 @@ template <typename... Args> class ndarray {
template <typename...> friend class ndarray;

using Info = detail::ndarray_info<Args...>;
using Scalar = typename Info::scalar_type;
static constexpr bool ReadOnly = Info::ReadOnly;
using Scalar = std::conditional_t<ReadOnly,
std::add_const_t<typename Info::scalar_type>,
typename Info::scalar_type>;

ndarray() = default;

Expand All @@ -392,7 +425,7 @@ template <typename... Args> class ndarray {
template <typename... Args2>
explicit ndarray(const ndarray<Args2...> &other) : ndarray(other.m_handle) { }

ndarray(std::conditional_t<std::is_const_v<Scalar>, const void *, void *> data,
ndarray(std::conditional_t<ReadOnly, const void *, void *> data,
size_t ndim,
const size_t *shape,
handle owner,
Expand All @@ -402,11 +435,11 @@ template <typename... Args> class ndarray {
int32_t device_id = 0) {
m_handle = detail::ndarray_create(
(void *) data, ndim, shape, owner.ptr(), strides, &dtype,
std::is_const_v<Scalar>, device_type, device_id);
ReadOnly, device_type, device_id);
m_dltensor = *detail::ndarray_inc_ref(m_handle);
}

ndarray(std::conditional_t<std::is_const_v<Scalar>, const void *, void *> data,
ndarray(std::conditional_t<ReadOnly, const void *, void *> data,
std::initializer_list<size_t> shape,
handle owner,
std::initializer_list<int64_t> strides = { },
Expand All @@ -420,7 +453,7 @@ template <typename... Args> class ndarray {
m_handle = detail::ndarray_create(
(void *) data, shape.size(), shape.begin(), owner.ptr(),
(strides.size() == 0) ? nullptr : strides.begin(), &dtype,
std::is_const_v<Scalar>, device_type, device_id);
ReadOnly, device_type, device_id);

m_dltensor = *detail::ndarray_inc_ref(m_handle);
}
Expand Down Expand Up @@ -476,35 +509,26 @@ template <typename... Args> class ndarray {
size_t itemsize() const { return ((size_t) dtype().bits + 7) / 8; }
size_t nbytes() const { return ((size_t) dtype().bits * size() + 7) / 8; }

const Scalar *data() const {
return (const Scalar *)((const uint8_t *) m_dltensor.data + m_dltensor.byte_offset);
}

template <typename T = Scalar, std::enable_if_t<!std::is_const_v<T>, int> = 1>
Scalar *data() {
Scalar *data() const {
return (Scalar *) ((uint8_t *) m_dltensor.data +
m_dltensor.byte_offset);
}

template <typename T = Scalar,
std::enable_if_t<!std::is_const_v<T>, int> = 1, typename... Ts>
NB_INLINE auto &operator()(Ts... indices) {
template <typename... Ts>
NB_INLINE auto& operator()(Ts... indices) const {
return *(Scalar *) ((uint8_t *) m_dltensor.data +
byte_offset(indices...));
}

template <typename... Ts> NB_INLINE const auto & operator()(Ts... indices) const {
return *(const Scalar *) ((const uint8_t *) m_dltensor.data +
byte_offset(indices...));
}

template <typename... Extra> NB_INLINE auto view() const {
using Info2 = typename ndarray<Args..., Extra...>::Info;
using Scalar2 = typename Info2::scalar_type;
using Scalar2 = std::conditional_t<Info2::ReadOnly,
std::add_const_t<typename Info2::scalar_type>,
typename Info2::scalar_type>;
using Shape2 = typename Info2::shape_type;

constexpr bool has_scalar = !std::is_same_v<Scalar2, void>,
has_shape = !std::is_same_v<Shape2, void>;
constexpr bool has_scalar = !std::is_void_v<Scalar2>,
has_shape = !std::is_void_v<Shape2>;

static_assert(has_scalar,
"To use the ndarray::view<..>() method, you must add a scalar type "
Expand All @@ -528,8 +552,8 @@ template <typename... Args> class ndarray {
private:
template <typename... Ts>
NB_INLINE int64_t byte_offset(Ts... indices) const {
constexpr bool has_scalar = !std::is_same_v<Scalar, void>,
has_shape = !std::is_same_v<typename Info::shape_type, void>;
constexpr bool has_scalar = !std::is_void_v<Scalar>,
has_shape = !std::is_void_v<typename Info::shape_type>;

static_assert(has_scalar,
"To use ndarray::operator(), you must add a scalar type "
Expand All @@ -547,7 +571,7 @@ template <typename... Args> class ndarray {
int64_t index = 0;
((index += int64_t(indices) * m_dltensor.strides[counter++]), ...);

return (int64_t) m_dltensor.byte_offset + index * sizeof(typename Info::scalar_type);
return (int64_t) m_dltensor.byte_offset + index * sizeof(Scalar);
} else {
return 0;
}
Expand Down
Loading

0 comments on commit 97b1e37

Please sign in to comment.