Skip to content

Commit

Permalink
Minor documentation updates.
Browse files Browse the repository at this point in the history
  • Loading branch information
hpkfft committed Mar 29, 2024
1 parent 42fe541 commit b2d4873
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docs/ndarray.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ Binding functions that take arrays as input
-------------------------------------------

A function that accepts a :cpp:class:`nb::ndarray\<\> <ndarray>`-typed parameter
(i.e., *without* template parameters) can be called with *any* array
(i.e., *without* template parameters) can be called with *any* writable array
from any framework regardless of the device on which it is stored. The
following example binding declaration uses this functionality to inspect the
properties of an arbitrary input array:

.. code-block:: cpp
m.def("inspect", [](nb::ndarray<> a) {
m.def("inspect", [](const nb::ndarray<>& a) {
printf("Array data pointer : %p\n", a.data());
printf("Array dimension : %zu\n", a.ndim());
for (size_t i = 0; i < a.ndim(); ++i) {
Expand Down
4 changes: 2 additions & 2 deletions include/nanobind/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,15 +326,15 @@ template <typename... Ts> struct ndarray_info<pytorch, Ts...> : ndarray_info<Ts.
template <typename... Ts> struct ndarray_info<tensorflow, Ts...> : ndarray_info<Ts...> {
static_assert(ndarray_info<Ts...>::framework == ndarray_framework::none
|| ndarray_info<Ts...>::framework == ndarray_framework::tensorflow,
"The framework cannot be set more than once.");
"The framework can only be set once.");
constexpr static auto name = const_name("tensorflow.python.framework.ops.EagerTensor");
constexpr static ndarray_framework framework = ndarray_framework::tensorflow;
};

template <typename... Ts> struct ndarray_info<jax, Ts...> : ndarray_info<Ts...> {
static_assert(ndarray_info<Ts...>::framework == ndarray_framework::none
|| ndarray_info<Ts...>::framework == ndarray_framework::jax,
"The framework cannot be set more than once.");
"The framework can only be set once.");
constexpr static auto name = const_name("jaxlib.xla_extension.DeviceArray");
constexpr static ndarray_framework framework = ndarray_framework::jax;
};
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ndarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ NB_MODULE(test_ndarray_ext, m) {
[](nb::ndarray<float, nb::c_contig, nb::shape<2, 2>>) { return 0; },
"array"_a);

m.def("inspect_ndarray", [](nb::ndarray<> ndarray) {
m.def("inspect_ndarray", [](const nb::ndarray<>& ndarray) {
printf("Tensor data pointer : %p\n", ndarray.data());
printf("Tensor dimension : %zu\n", ndarray.ndim());
for (size_t i = 0; i < ndarray.ndim(); ++i) {
Expand Down

0 comments on commit b2d4873

Please sign in to comment.