Skip to content

Commit

Permalink
[Draft] Array4: __array_interface__
Browse files Browse the repository at this point in the history
Implement the `__array_interface__` in Array for.
  • Loading branch information
ax3l committed Feb 15, 2021
1 parent a629bb1 commit afa91ed
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 0 deletions.
160 changes: 160 additions & 0 deletions src/Base/Array4.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/* Copyright 2021 The AMReX Community
*
* Authors: Axel Huebl
* License: BSD-3-Clause-LBNL
*/
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>

#include <AMReX_Config.H>
#include <AMReX_Array4.H>
#include <AMReX_IntVect.H>

#include <sstream>
#include <type_traits>

namespace py = pybind11;
using namespace amrex;


template< typename T >
void make_Array4(py::module &m, std::string typestr)
{
// dispatch simpler via: py::format_descriptor<T>::format() naming
auto const array_name = std::string("Array4_").append(typestr);
py::class_< Array4<T> >(m, array_name.c_str(), py::buffer_protocol())
.def("__repr__",
[](Array4<T> const & a4) {
std::stringstream s;
s << a4.size();
return "<amrex.Array4 of size '" + s.str() + "'>";
}
)
#if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK)
.def("index_assert", &Array4<T>::index_assert)
#endif

.def_property_readonly("size", &Array4<T>::size)
.def_property_readonly("nComp", &Array4<T>::nComp)

.def(py::init< >())
.def(py::init< Array4<T> const & >())
.def(py::init< Array4<T> const &, int >())
.def(py::init< Array4<T> const &, int, int >())
//.def(py::init< T*, Dim3 const &, Dim3 const &, int >())

.def(py::init([](py::array_t<T> & arr) {
py::buffer_info buf = arr.request();

auto a4 = std::make_unique< Array4<T> >();
a4.get()->p = (T*)buf.ptr;
a4.get()->begin = Dim3{0, 0, 0};
// TODO: likely C->F index conversion here
// p[(i-begin.x)+(j-begin.y)*jstride+(k-begin.z)*kstride+n*nstride];
a4.get()->end.x = (int)buf.shape.at(0);
a4.get()->end.y = (int)buf.shape.at(1);
a4.get()->end.z = (int)buf.shape.at(2);
a4.get()->ncomp = 1;
// buffer protocol strides are in bytes, AMReX strides are elements
a4.get()->jstride = (int)buf.strides.at(0) / sizeof(T);
a4.get()->kstride = (int)buf.strides.at(1) / sizeof(T);
a4.get()->nstride = (int)buf.strides.at(2) * (int)buf.shape.at(2) / sizeof(T);
return a4;
}))


.def_property_readonly("__array_interface__", [](Array4<T> const & a4) {
auto d = py::dict();
auto const len = length(a4);
// TODO: likely F->C index conversion here
// p[(i-begin.x)+(j-begin.y)*jstride+(k-begin.z)*kstride+n*nstride];
auto shape = py::make_tuple( // Buffer dimensions
len.x < 0 ? 0 : len.x,
len.y < 0 ? 0 : len.y,
len.z < 0 ? 0 : len.z//, // zero-size shall not have negative dimension
//a4.ncomp
);
// buffer protocol strides are in bytes, AMReX strides are elements
auto const strides = py::make_tuple( // Strides (in bytes) for each index
sizeof(T) * a4.jstride,
sizeof(T) * a4.kstride,
sizeof(T)//,
//sizeof(T) * a4.nstride
);
d["data"] = py::make_tuple(long(a4.dataPtr()), false);
d["typestr"] = py::format_descriptor<T>::format();
d["shape"] = shape;
d["strides"] = strides;
// d["strides"] = py::none();
d["version"] = 3;
return d;
})

// not sure if useful to have this implemented on top
/*
.def_buffer([](Array4<T> & a4) -> py::buffer_info {
auto const len = length(a4);
// TODO: likely F->C index conversion here
// p[(i-begin.x)+(j-begin.y)*jstride+(k-begin.z)*kstride+n*nstride];
auto shape = { // Buffer dimensions
len.x < 0 ? 0 : len.x,
len.y < 0 ? 0 : len.y,
len.z < 0 ? 0 : len.z//, // zero-size shall not have negative dimension
//a4.ncomp
};
// buffer protocol strides are in bytes, AMReX strides are elements
auto const strides = { // Strides (in bytes) for each index
sizeof(T) * a4.jstride,
sizeof(T) * a4.kstride,
sizeof(T)//,
//sizeof(T) * a4.nstride
};
return py::buffer_info(
a4.dataPtr(),
shape,
strides
);
})
*/
;
}

void init_Array4(py::module &m) {
make_Array4< float >(m, "float");
make_Array4< double >(m, "double");
make_Array4< long double >(m, "longdouble");

make_Array4< short >(m, "short");
make_Array4< int >(m, "int");
make_Array4< long >(m, "long");
make_Array4< long long >(m, "longlong");

make_Array4< unsigned short >(m, "ushort");
make_Array4< unsigned int >(m, "uint");
make_Array4< unsigned long >(m, "ulong");
make_Array4< unsigned long long >(m, "ulonglong");

// std::complex< float|double|long double> ?

/*
py::class_< PolymorphicArray4, Array4 >(m, "PolymorphicArray4")
.def("__repr__",
[](PolymorphicArray4 const & pa4) {
std::stringstream s;
s << pa4.size();
return "<amrex.PolymorphicArray4 of size '" + s.str() + "'>";
}
)
;
*/

// free standing C++ functions:
/*
contains
lbound
ubound
length
makePolymorphic
*/
}
1 change: 1 addition & 0 deletions src/Base/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
target_sources(pyAMReX
PRIVATE
AMReX.cpp
Array4.cpp
Box.cpp
Dim3.cpp
IntVect.cpp
Expand Down
2 changes: 2 additions & 0 deletions src/pyAMReX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace py = pybind11;

// forward declarations of exposed classes
void init_AMReX(py::module&);
void init_Array4(py::module&);
void init_Box(py::module &);
void init_Dim3(py::module&);
void init_IntVect(py::module &);
Expand All @@ -37,6 +38,7 @@ PYBIND11_MODULE(amrex_pybind, m) {
init_AMReX(m);
init_Dim3(m);
init_IntVect(m);
init_Array4(m);
init_Box(m);

// API runtime version
Expand Down

0 comments on commit afa91ed

Please sign in to comment.