Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate const of ndarray from const of its data. #491

Merged
merged 7 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions docs/api_extra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,12 @@ section <ndarrays>`.

.. cpp:class:: template <typename... Args> ndarray

.. cpp:var:: ReadOnly

A constant static boolean that is true if the array's data is read-only.
This is determined by the class template arguments, not by any dynamic
properties of the referenced array.

.. cpp:function:: ndarray() = default

Create an invalid array.
Expand Down Expand Up @@ -793,14 +799,19 @@ section <ndarrays>`.
In a multi-device/GPU setup, this function returns the ID of the device
storing the array.

.. cpp:function:: const Scalar * data() const
.. cpp:function:: Scalar * data() const

Return a pointer to the array data.
If :cpp:var:`ReadOnly` is true, a pointer-to-const is returned.

Return a const pointer to the array data.
.. cpp:function:: template <typename... Ts> auto& operator()(Ts... indices)

.. cpp:function:: Scalar * data()
Return a reference to the element stored at the provided index/indices.
If :cpp:var:`ReadOnly` is true, a reference-to-const is returned.
Note that ``sizeof(Ts)`` must match :cpp:func:`ndim()`.

Return a mutable pointer to the array data. Only enabled when `Scalar` is
not itself ``const``.
This accessor is only available when the scalar type and array dimension
were specified as template parameters.

.. cpp:function:: template <typename... Extra> auto view()

Expand All @@ -813,14 +824,6 @@ section <ndarrays>`.
``shape()``, ``stride()``, and ``operator()`` following the conventions
of the `ndarray` type.

.. cpp:function:: template <typename... Ts> auto& operator()(Ts... indices)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason for this removal? AFAIK this operator is still there, and it should be documented.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not removed; it's on lines 692-696. The diff algorithm got confused (from a human perspective) and spliced things together weirdly.
There had been two data() functions documented, with and without const-qualified this. Now there is only one as the return value is const-qualified based on the writability of the actual data and not on the const qualification of the ndarray object itself.

I did move the documentation of operator() above that of view(). I thought it would be helpful for readers to see operator() next to data() so they could see both and decide what better suits their purpose. Also, view() documents that it provides operator() following the same conventions as ndarray, so it seemed nice to document all of ndarray's functions before discussing view().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, lines 807-814 if you're looking at diff after the merge commit....


Return a mutable reference to the element at stored at the provided
index/indices. ``sizeof(Ts)`` must match :cpp:func:`ndim()`.

This accessor is only available when the scalar type and array dimension
were specified as template parameters.

Data types
^^^^^^^^^^

Expand Down
4 changes: 2 additions & 2 deletions docs/ndarray.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ Binding functions that take arrays as input
-------------------------------------------

A function that accepts a :cpp:class:`nb::ndarray\<\> <ndarray>`-typed parameter
(i.e., *without* template parameters) can be called with *any* array
(i.e., *without* template parameters) can be called with *any* writable array
from any framework regardless of the device on which it is stored. The
following example binding declaration uses this functionality to inspect the
properties of an arbitrary input array:

.. code-block:: cpp

m.def("inspect", [](nb::ndarray<> a) {
m.def("inspect", [](const nb::ndarray<>& a) {
printf("Array data pointer : %p\n", a.data());
printf("Array dimension : %zu\n", a.ndim());
for (size_t i = 0; i < a.ndim(); ++i) {
Expand Down
86 changes: 55 additions & 31 deletions include/nanobind/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ template <size_t N> using ndim = typename detail::ndim_shape<std::make_index_seq
template <typename T> constexpr dlpack::dtype dtype() {
static_assert(
detail::is_ndarray_scalar_v<T>,
"nanobind::dtype<T>: T must be a floating point or integer variable!"
"nanobind::dtype<T>: T must be a floating point or integer type!"
);

dlpack::dtype result;
Expand Down Expand Up @@ -266,46 +266,76 @@ template <typename T> struct ndarray_arg<T, enable_if_t<T::is_device>> {
template <typename... Ts> struct ndarray_info {
using scalar_type = void;
using shape_type = void;
constexpr static bool ReadOnly = false;
constexpr static auto name = const_name("ndarray");
constexpr static ndarray_framework framework = ndarray_framework::none;
constexpr static char order = '\0';
};

template <typename T, typename... Ts> struct ndarray_info<T, Ts...> : ndarray_info<Ts...> {
using scalar_type =
std::conditional_t<ndarray_traits<T>::is_float || ndarray_traits<T>::is_int ||
ndarray_traits<T>::is_bool || ndarray_traits<T>::is_complex,
T, typename ndarray_info<Ts...>::scalar_type>;
std::conditional_t<
detail::is_ndarray_scalar_v<T> &&
std::is_void_v<typename ndarray_info<Ts...>::scalar_type>,
T, typename ndarray_info<Ts...>::scalar_type>;

constexpr static bool ReadOnly = ndarray_info<Ts...>::ReadOnly ||
(detail::is_ndarray_scalar_v<T> && std::is_const_v<T>);
};

template <typename... Ts> struct ndarray_info<ro, Ts...> : ndarray_info<Ts...> {
constexpr static bool ReadOnly = true;
};

template <ssize_t... Is, typename... Ts> struct ndarray_info<shape<Is...>, Ts...> : ndarray_info<Ts...> {
using shape_type = shape<Is...>;
using shape_type =
std::conditional_t<
std::is_void_v<typename ndarray_info<Ts...>::shape_type>,
shape<Is...>, typename ndarray_info<Ts...>::shape_type>;
};

template <typename... Ts> struct ndarray_info<c_contig, Ts...> : ndarray_info<Ts...> {
static_assert(ndarray_info<Ts...>::order == '\0'
|| ndarray_info<Ts...>::order == 'C',
"The order can only be set once.");
constexpr static char order = 'C';
};

template <typename... Ts> struct ndarray_info<f_contig, Ts...> : ndarray_info<Ts...> {
static_assert(ndarray_info<Ts...>::order == '\0'
|| ndarray_info<Ts...>::order == 'F',
"The order can only be set once.");
constexpr static char order = 'F';
};

template <typename... Ts> struct ndarray_info<numpy, Ts...> : ndarray_info<Ts...> {
static_assert(ndarray_info<Ts...>::framework == ndarray_framework::none
|| ndarray_info<Ts...>::framework == ndarray_framework::numpy,
"The framework can only be set once.");
constexpr static auto name = const_name("numpy.ndarray");
constexpr static ndarray_framework framework = ndarray_framework::numpy;
};

template <typename... Ts> struct ndarray_info<pytorch, Ts...> : ndarray_info<Ts...> {
static_assert(ndarray_info<Ts...>::framework == ndarray_framework::none
|| ndarray_info<Ts...>::framework == ndarray_framework::pytorch,
"The framework can only be set once.");
constexpr static auto name = const_name("torch.Tensor");
constexpr static ndarray_framework framework = ndarray_framework::pytorch;
};

template <typename... Ts> struct ndarray_info<tensorflow, Ts...> : ndarray_info<Ts...> {
static_assert(ndarray_info<Ts...>::framework == ndarray_framework::none
|| ndarray_info<Ts...>::framework == ndarray_framework::tensorflow,
"The framework can only be set once.");
constexpr static auto name = const_name("tensorflow.python.framework.ops.EagerTensor");
constexpr static ndarray_framework framework = ndarray_framework::tensorflow;
};

template <typename... Ts> struct ndarray_info<jax, Ts...> : ndarray_info<Ts...> {
static_assert(ndarray_info<Ts...>::framework == ndarray_framework::none
|| ndarray_info<Ts...>::framework == ndarray_framework::jax,
"The framework can only be set once.");
constexpr static auto name = const_name("jaxlib.xla_extension.DeviceArray");
constexpr static ndarray_framework framework = ndarray_framework::jax;
};
Expand All @@ -314,9 +344,9 @@ template <typename... Ts> struct ndarray_info<cupy, Ts...> : ndarray_info<Ts...>
constexpr static ndarray_framework framework = ndarray_framework::cupy;
};


NAMESPACE_END(detail)


template <typename Scalar, typename Shape, char Order> struct ndarray_view {
static constexpr size_t Dim = Shape::size;

Expand Down Expand Up @@ -380,7 +410,10 @@ template <typename... Args> class ndarray {
template <typename...> friend class ndarray;

using Info = detail::ndarray_info<Args...>;
using Scalar = typename Info::scalar_type;
static constexpr bool ReadOnly = Info::ReadOnly;
using Scalar = std::conditional_t<ReadOnly,
std::add_const_t<typename Info::scalar_type>,
typename Info::scalar_type>;

ndarray() = default;

Expand All @@ -392,7 +425,7 @@ template <typename... Args> class ndarray {
template <typename... Args2>
explicit ndarray(const ndarray<Args2...> &other) : ndarray(other.m_handle) { }

ndarray(std::conditional_t<std::is_const_v<Scalar>, const void *, void *> data,
ndarray(std::conditional_t<ReadOnly, const void *, void *> data,
size_t ndim,
const size_t *shape,
handle owner,
Expand All @@ -402,11 +435,11 @@ template <typename... Args> class ndarray {
int32_t device_id = 0) {
m_handle = detail::ndarray_create(
(void *) data, ndim, shape, owner.ptr(), strides, &dtype,
std::is_const_v<Scalar>, device_type, device_id);
ReadOnly, device_type, device_id);
m_dltensor = *detail::ndarray_inc_ref(m_handle);
}

ndarray(std::conditional_t<std::is_const_v<Scalar>, const void *, void *> data,
ndarray(std::conditional_t<ReadOnly, const void *, void *> data,
std::initializer_list<size_t> shape,
handle owner,
std::initializer_list<int64_t> strides = { },
Expand All @@ -420,7 +453,7 @@ template <typename... Args> class ndarray {
m_handle = detail::ndarray_create(
(void *) data, shape.size(), shape.begin(), owner.ptr(),
(strides.size() == 0) ? nullptr : strides.begin(), &dtype,
std::is_const_v<Scalar>, device_type, device_id);
ReadOnly, device_type, device_id);

m_dltensor = *detail::ndarray_inc_ref(m_handle);
}
Expand Down Expand Up @@ -476,35 +509,26 @@ template <typename... Args> class ndarray {
size_t itemsize() const { return ((size_t) dtype().bits + 7) / 8; }
size_t nbytes() const { return ((size_t) dtype().bits * size() + 7) / 8; }

const Scalar *data() const {
return (const Scalar *)((const uint8_t *) m_dltensor.data + m_dltensor.byte_offset);
}

template <typename T = Scalar, std::enable_if_t<!std::is_const_v<T>, int> = 1>
Scalar *data() {
Scalar *data() const {
return (Scalar *) ((uint8_t *) m_dltensor.data +
m_dltensor.byte_offset);
}

template <typename T = Scalar,
std::enable_if_t<!std::is_const_v<T>, int> = 1, typename... Ts>
NB_INLINE auto &operator()(Ts... indices) {
template <typename... Ts>
NB_INLINE auto& operator()(Ts... indices) const {
return *(Scalar *) ((uint8_t *) m_dltensor.data +
byte_offset(indices...));
}

template <typename... Ts> NB_INLINE const auto & operator()(Ts... indices) const {
return *(const Scalar *) ((const uint8_t *) m_dltensor.data +
byte_offset(indices...));
}

template <typename... Extra> NB_INLINE auto view() const {
using Info2 = typename ndarray<Args..., Extra...>::Info;
using Scalar2 = typename Info2::scalar_type;
using Scalar2 = std::conditional_t<Info2::ReadOnly,
std::add_const_t<typename Info2::scalar_type>,
typename Info2::scalar_type>;
using Shape2 = typename Info2::shape_type;

constexpr bool has_scalar = !std::is_same_v<Scalar2, void>,
has_shape = !std::is_same_v<Shape2, void>;
constexpr bool has_scalar = !std::is_void_v<Scalar2>,
has_shape = !std::is_void_v<Shape2>;

static_assert(has_scalar,
"To use the ndarray::view<..>() method, you must add a scalar type "
Expand All @@ -528,8 +552,8 @@ template <typename... Args> class ndarray {
private:
template <typename... Ts>
NB_INLINE int64_t byte_offset(Ts... indices) const {
constexpr bool has_scalar = !std::is_same_v<Scalar, void>,
has_shape = !std::is_same_v<typename Info::shape_type, void>;
constexpr bool has_scalar = !std::is_void_v<Scalar>,
has_shape = !std::is_void_v<typename Info::shape_type>;

static_assert(has_scalar,
"To use ndarray::operator(), you must add a scalar type "
Expand All @@ -547,7 +571,7 @@ template <typename... Args> class ndarray {
int64_t index = 0;
((index += int64_t(indices) * m_dltensor.strides[counter++]), ...);

return (int64_t) m_dltensor.byte_offset + index * sizeof(typename Info::scalar_type);
return (int64_t) m_dltensor.byte_offset + index * sizeof(Scalar);
} else {
return 0;
}
Expand Down
Loading
Loading