Skip to content

Commit

Permalink
SmallMatrix: Matrix class with compile time size
Browse files Browse the repository at this point in the history
Add amrex::SmallMatrix class with compile time size.
  • Loading branch information
WeiqunZhang committed Oct 1, 2024
1 parent 6d9c25b commit 42d9ba9
Show file tree
Hide file tree
Showing 11 changed files with 344 additions and 29 deletions.
5 changes: 1 addition & 4 deletions Src/Base/AMReX_Array.H
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <AMReX_REAL.H>
#include <AMReX_Algorithm.H>
#include <AMReX_Dim3.H>
#include <AMReX_SmallMatrix.H>

#include <array>
#include <memory>
Expand Down Expand Up @@ -148,10 +149,6 @@ namespace amrex {
* order (last index moving fastest). If not specified, Fortran order is
* assumed.
*/
// Empty tag types used to select an array storage order at compile time.
namespace Order {
// Tag for C order (per the comment above: last index moving fastest).
struct C {};
// Tag for Fortran order, the order assumed when none is specified.
struct F {};
}

/**
* A GPU-compatible one-dimensional array.
Expand Down
38 changes: 38 additions & 0 deletions Src/Base/AMReX_ConstexprFor.H
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#ifndef AMREX_CONSTEXPR_FOR_H_
#define AMREX_CONSTEXPR_FOR_H_
#include <AMReX_Config.H>

#include <AMReX_GpuQualifiers.H>
#include <AMReX_Extension.H>

#include <type_traits>

namespace amrex {

// Compile-time "for" loop, following
// https://artificial-mind.net/blog/2020/10/31/constexpr-for
//
// constexpr_for<I,N>(f) approximates the fully unrolled form of
//     for (auto i = I; i < N; ++i) {
//         f(i);
//     }
//
// Each invocation of f receives its index as a default-constructed
// std::integral_constant<decltype(I), i>, so the index can be used as a
// compile-time constant inside f. A callable whose parameter is a plain
// integer also works, through integral_constant's implicit conversion.
//
// The unrolling is done by recursion: the body is run for the current
// index I and the template is then instantiated again for I+1. Once
// I < N no longer holds, `if constexpr` discards the body and the
// recursion ends.

template<auto I, auto N, class F>
AMREX_GPU_HOST_DEVICE AMREX_INLINE
constexpr void constexpr_for (F const& f)
{
    if constexpr (I < N) {
        using index_constant = std::integral_constant<decltype(I), I>;
        f(index_constant{});
        constexpr_for<I+1, N>(f);
    }
}

}

#endif
25 changes: 1 addition & 24 deletions Src/Base/AMReX_Loop.H
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <AMReX_Config.H>

#include <AMReX_Box.H>
#include <AMReX_ConstexprFor.H>
#include <AMReX_Extension.H>

namespace amrex {
Expand Down Expand Up @@ -567,30 +568,6 @@ void LoopConcurrentOnCpu (BoxND<dim> const& bx, int ncomp, F const& f) noexcept
}
}

// Compile-time "for" loop, following
// https://artificial-mind.net/blog/2020/10/31/constexpr-for
//
// constexpr_for<I,N>(f) approximates the fully unrolled form of
//     for (auto i = I; i < N; ++i) {
//         f(i);
//     }
//
// Each invocation of f receives its index as a default-constructed
// std::integral_constant<decltype(I), i>, so the index can be used as a
// compile-time constant inside f. A callable whose parameter is a plain
// integer also works, through integral_constant's implicit conversion.
//
// The unrolling is done by recursion: the body is run for the current
// index I and the template is then instantiated again for I+1. Once
// I < N no longer holds, `if constexpr` discards the body and the
// recursion ends.

template<auto I, auto N, class F>
AMREX_GPU_HOST_DEVICE AMREX_INLINE
constexpr void constexpr_for (F const& f)
{
    if constexpr (I < N) {
        using index_constant = std::integral_constant<decltype(I), I>;
        f(index_constant{});
        constexpr_for<I+1, N>(f);
    }
}

#include <AMReX_Loop.nolint.H>

}
Expand Down
179 changes: 179 additions & 0 deletions Src/Base/AMReX_SmallMatrix.H
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
#ifndef AMREX_SMALL_MATRIX_H_
#define AMREX_SMALL_MATRIX_H_
#include <AMReX_Config.H>

#include <AMReX_Algorithm.H>
#include <AMReX_BLassert.H>
#include <AMReX_Extension.H>
#include <AMReX_GpuQualifiers.H>
#include <AMReX_ConstexprFor.H>

#include <type_traits>

namespace amrex {

// Empty tag types that select the element layout of SmallMatrix via its
// ORDER template parameter.
namespace Order {
// C order: SmallMatrix stores element (i,j) at arr[j+i*N] (rows contiguous).
struct C {};
// Fortran order (the SmallMatrix default): element (i,j) is stored at
// arr[i+j*M] (columns contiguous).
struct F {};
}

/**
 * Matrix whose dimensions are fixed at compile time, stored in a flat
 * member array of M*N elements.
 *
 * \tparam T     element type
 * \tparam M     number of rows
 * \tparam N     number of columns
 * \tparam ORDER storage layout tag. With Order::F (the default) element
 *               (i,j) lives at arr[i+j*M]; with Order::C it lives at
 *               arr[j+i*N].
 *
 * All indices are zero-based. Index validity is checked with AMREX_ASSERT
 * only; there is no other bounds checking.
 */
template <class T, int M, int N, class ORDER=Order::F>
struct SmallMatrix
{
    using value_type = T;
    using reference_type = T&;
    static constexpr int row_size = M;
    static constexpr int column_size = N;

    /** Read access to element (i,j); Order::F layout. */
    template <typename O=ORDER,
              std::enable_if_t<std::is_same_v<O,Order::F>,int> = 0>
    [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    const T& operator() (int i, int j) const noexcept {
        // Negative indices are as invalid as too-large ones.
        AMREX_ASSERT(i >= 0 && i < M && j >= 0 && j < N);
        return arr[i+j*M];
    }

    /** Write access to element (i,j); Order::F layout. */
    template <typename O=ORDER,
              std::enable_if_t<std::is_same_v<O,Order::F>,int> = 0>
    [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    T& operator() (int i, int j) noexcept {
        AMREX_ASSERT(i >= 0 && i < M && j >= 0 && j < N);
        return arr[i+j*M];
    }

    /** Read access to element (i,j); Order::C layout. */
    template <typename O=ORDER,
              std::enable_if_t<std::is_same_v<O,Order::C>,int> = 0>
    [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    const T& operator() (int i, int j) const noexcept {
        AMREX_ASSERT(i >= 0 && i < M && j >= 0 && j < N);
        return arr[j+i*N];
    }

    /** Write access to element (i,j); Order::C layout. */
    template <typename O=ORDER,
              std::enable_if_t<std::is_same_v<O,Order::C>,int> = 0>
    [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    T& operator() (int i, int j) noexcept {
        AMREX_ASSERT(i >= 0 && i < M && j >= 0 && j < N);
        return arr[j+i*N];
    }

    /**
     * Read access to the i-th element of a vector. Only available when the
     * matrix has a single row or a single column.
     */
    template <int MM=M, int NN=N, std::enable_if_t<(MM==1 || NN==1), int> = 0>
    [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    const T& operator() (int i) const noexcept {
        AMREX_ASSERT(i >= 0 && ((M==1 && i<N) || (N==1 && i<M)));
        return arr[i];
    }

    /**
     * Write access to the i-th element of a vector. Only available when the
     * matrix has a single row or a single column.
     */
    template <int MM=M, int NN=N, std::enable_if_t<(MM==1 || NN==1), int> = 0>
    [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    T& operator() (int i) noexcept {
        AMREX_ASSERT(i >= 0 && ((M==1 && i<N) || (N==1 && i<M)));
        return arr[i];
    }

    /** Same as the single-index operator() const, in subscript form. */
    template <int MM=M, int NN=N, std::enable_if_t<(MM==1 || NN==1), int> = 0>
    [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    const T& operator[] (int i) const noexcept {
        AMREX_ASSERT(i >= 0 && ((M==1 && i<N) || (N==1 && i<M)));
        return arr[i];
    }

    /** Same as the single-index operator(), in subscript form. */
    template <int MM=M, int NN=N, std::enable_if_t<(MM==1 || NN==1), int> = 0>
    [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    T& operator[] (int i) noexcept {
        AMREX_ASSERT(i >= 0 && ((M==1 && i<N) || (N==1 && i<M)));
        return arr[i];
    }

    /**
     * Identity matrix: T(1) on the diagonal, value-initialized elements
     * elsewhere. Only available for square matrices.
     */
    template <int MM=M, int NN=N, std::enable_if_t<MM==NN, int> = 0>
    [[nodiscard]] static constexpr
    AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    SmallMatrix<T,M,N,ORDER>
    Identity () noexcept {
        SmallMatrix<T,M,N,ORDER> I{};
        // For a square matrix the diagonal element (i,i) sits at flat index
        // i + i*M in both storage orders. Writing through arr directly,
        // rather than through the non-constexpr operator(), keeps this
        // function usable in constant expressions.
        for (int i = 0; i < M; ++i) {
            I.arr[i + i*M] = T(1);
        }
        return I;
    }

    /** Matrix whose elements are all value-initialized. */
    [[nodiscard]] static constexpr
    AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    SmallMatrix<T,M,N,ORDER>
    Zero () noexcept {
        return SmallMatrix<T,M,N,ORDER>{};
    }

    /** Returns a new N x M matrix that is the transpose of this one. */
    [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    SmallMatrix<T,N,M,ORDER>
    transpose () const
    {
        // Every element of r is assigned below, so it needs no initializer.
        SmallMatrix<T,N,M,ORDER> r;
        for (int j = 0; j < M; ++j) {
            for (int i = 0; i < N; ++i) {
                r(i,j) = (*this)(j,i);
            }
        }
        return r;
    }

    /**
     * Transposes this square matrix in place and returns *this.
     *
     * Deliberately not [[nodiscard]]: the call is normally made for its
     * side effect, and the returned reference only enables chaining.
     */
    template <int MM=M, int NN=N, std::enable_if_t<MM==NN,int> = 0>
    AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    SmallMatrix<T,M,N,ORDER>&
    transposeInPlace ()
    {
        // Swap each element above the diagonal with its mirror image.
        for (int j = 1; j < N; ++j) {
            for (int i = 0; i < j; ++i) {
                amrex::Swap((*this)(i,j), (*this)(j,i));
            }
        }
        return *this;
    }

    /** Matrix-matrix product; defined after the class. */
    template <class U, int N1, int N2, int N3, class Ord>
    [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    friend SmallMatrix<U,N1,N3,Ord>
    operator* (SmallMatrix<U,N1,N2,Ord> const& lhs,
               SmallMatrix<U,N2,N3,Ord> const& rhs);

    // Flat element storage; the layout is determined by ORDER.
    T arr[M*N];
};

/**
 * Product of an N1 x N2 matrix and an N2 x N3 matrix with the same element
 * type and storage order.
 *
 * The loop nest is arranged per storage order so that the unrolled
 * innermost loop runs along the contiguous index of the result: down a
 * column for Order::F, along a row for Order::C. In both cases each result
 * element starts at U(0) and accumulates its terms in increasing order of
 * the contracted index.
 *
 * \param lhs left operand
 * \param rhs right operand
 * \return the N1 x N3 product
 */
template <class U, int N1, int N2, int N3, class Ord>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
SmallMatrix<U,N1,N3,Ord>
operator* (SmallMatrix<U,N1,N2,Ord> const& lhs,
           SmallMatrix<U,N2,N3,Ord> const& rhs)
{
    constexpr bool column_major = std::is_same_v<Ord,Order::F>;
    constexpr bool row_major = std::is_same_v<Ord,Order::C>;
    static_assert(column_major || row_major);

    SmallMatrix<U,N1,N3,Ord> product;
    if constexpr (row_major) {
        // Build the result one row at a time.
        for (int row = 0; row < N1; ++row) {
            constexpr_for<0,N3>([&] (int col) { product(row,col) = U(0); });
            for (int inner = 0; inner < N2; ++inner) {
                auto left = lhs(row,inner);
                constexpr_for<0,N3>([&] (int col)
                {
                    product(row,col) += left * rhs(inner,col);
                });
            }
        }
    } else {
        // Build the result one column at a time.
        for (int col = 0; col < N3; ++col) {
            constexpr_for<0,N1>([&] (int row) { product(row,col) = U(0); });
            for (int inner = 0; inner < N2; ++inner) {
                auto right = rhs(inner,col);
                constexpr_for<0,N1>([&] (int row)
                {
                    product(row,col) += lhs(row,inner) * right;
                });
            }
        }
    }
    return product;
}

// Column vector: an N x 1 SmallMatrix with the default storage order.
template <class T, int N>
using SmallVector = SmallMatrix<T,N,1>;

// Row vector: a 1 x N SmallMatrix with the default storage order.
template <class T, int N>
using SmallRowVector = SmallMatrix<T,1,N>;
}

#endif
2 changes: 2 additions & 0 deletions Src/Base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ foreach(D IN LISTS AMReX_SPACEDIM)
AMReX_BlockMutex.cpp
AMReX_Enum.H
AMReX_GpuComplex.H
AMReX_SmallMatrix.H
AMReX_ConstexprFor.H
AMReX_Vector.H
AMReX_TableData.H
AMReX_Tuple.H
Expand Down
1 change: 1 addition & 0 deletions Src/Base/Make.package
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ AMREX_BASE=EXE
C$(AMREX_BASE)_headers += AMReX_ccse-mpi.H AMReX_Algorithm.H AMReX_Any.H AMReX_Array.H
C$(AMREX_BASE)_headers += AMReX_Enum.H
C$(AMREX_BASE)_headers += AMReX_Vector.H AMReX_TableData.H AMReX_Tuple.H AMReX_Math.H
C$(AMREX_BASE)_headers += AMReX_SmallMatrix.H AMReX_ConstexprFor.H

C$(AMREX_BASE)_headers += AMReX_TypeList.H

Expand Down
2 changes: 1 addition & 1 deletion Tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ else()
#
set( AMREX_TESTS_SUBDIRS Amr AsyncOut CLZ CTOParFor DeviceGlobal Enum
MultiBlock MultiPeriod ParmParse Parser Parser2 Reinit
RoundoffDomain)
RoundoffDomain SmallMatrix)

if (AMReX_PARTICLES)
list(APPEND AMREX_TESTS_SUBDIRS Particles)
Expand Down
9 changes: 9 additions & 0 deletions Tests/SmallMatrix/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Set up the SmallMatrix test once for every entry D of AMReX_SPACEDIM.
foreach(D IN LISTS AMReX_SPACEDIM)
# The test consists of a single source file and has no input files.
set(_sources main.cpp)
set(_input_files)

# setup_test is defined elsewhere; presumably it creates the test for
# dimension D from the two named list variables -- confirm against its definition.
setup_test(${D} _sources _input_files)

# Clear the helper variables before the next iteration.
unset(_sources)
unset(_input_files)
endforeach()
24 changes: 24 additions & 0 deletions Tests/SmallMatrix/GNUmakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# GNU make build settings for the SmallMatrix test.
# Comments are kept on their own lines so that no trailing whitespace
# ends up in a variable value.

# Relative path to the AMReX tree that provides Tools/ and Src/.
AMREX_HOME := ../..

# Debug build is switched off.
DEBUG = FALSE

# Spatial dimension used for this build.
DIM = 3

# Compiler selection.
COMP = gcc

# MPI, OpenMP and the CUDA, HIP and SYCL backends are all disabled.
USE_MPI = FALSE
USE_OMP = FALSE
USE_CUDA = FALSE
USE_HIP = FALSE
USE_SYCL = FALSE

# Presumably tells the build system that no Fortran is needed -- confirm in Make.defs.
BL_NO_FORT = TRUE

# TINY_PROFILE is switched off.
TINY_PROFILE = FALSE

# Shared definitions, included after the settings above.
include $(AMREX_HOME)/Tools/GNUMake/Make.defs

# Source lists: this test's own package file, then the AMReX Base package.
include ./Make.package
include $(AMREX_HOME)/Src/Base/Make.package

# Shared build rules.
include $(AMREX_HOME)/Tools/GNUMake/Make.rules
1 change: 1 addition & 0 deletions Tests/SmallMatrix/Make.package
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Add the test's only source file to the CEXE_sources list.
CEXE_sources += main.cpp
Loading

0 comments on commit 42d9ba9

Please sign in to comment.