
Commit 5a1e34f
Merge pull request #275 from bd4/pr/cuda-solver-spsv
Pr/cuda solver spsv
bd4 authored Sep 1, 2023
2 parents f5f86af + 96acc4f commit 5a1e34f
Showing 6 changed files with 460 additions and 75 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -38,6 +38,7 @@ option(GTENSOR_ENABLE_BLAS "Enable gtblas" OFF)
 option(GTENSOR_ENABLE_FFT "Enable gtfft" OFF)
 option(GTENSOR_ENABLE_FORTRAN "Enable Fortran interoperability" OFF)
 option(GTENSOR_ENABLE_SOLVER "Enable high level solver library" OFF)
+option(GTENSOR_SOLVER_HIP_SPARSE_GENERIC "Use rocSPARSE generic API for sparse solver" OFF)
 
 option(GTENSOR_PER_DIM_KERNELS
        "Enable per dim kernels (may break for large arrays)" OFF)
@@ -459,6 +460,10 @@ if (GTENSOR_ENABLE_BLAS)
     target_link_libraries(gtsolver CUDA::cusparse)
   elseif (${GTENSOR_DEVICE} STREQUAL "hip")
     target_link_libraries(gtsolver rocsparse)
+    if (GTENSOR_SOLVER_HIP_SPARSE_GENERIC)
+      target_compile_definitions(gtsolver INTERFACE
+                                 GTENSOR_SOLVER_HIP_SPARSE_GENERIC)
+    endif()
   endif()
   list(APPEND GTENSOR_TARGETS gtsolver)
   add_library(gtensor::gtsolver ALIAS gtsolver)
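Note: the new option only attaches a compile definition to the gtsolver interface target; HIP backend code can then branch on the macro. A minimal sketch of how such gating might look (the header paths and backend class names below are hypothetical, invented for illustration only; the real selection logic lives in the gt-solver HIP headers):

// Hypothetical sketch: selecting a solver backend based on the compile
// definition added above. Header paths and class names are illustrative,
// not gtensor's actual ones.
#ifdef GTENSOR_SOLVER_HIP_SPARSE_GENERIC
#include "gt-solver/backend/hip-generic.h" // assumed generic-API backend
template <typename T>
using hip_csr_lu_t = csr_matrix_lu_hip_generic<T>;
#else
#include "gt-solver/backend/hip-legacy.h" // assumed legacy-API backend
template <typename T>
using hip_csr_lu_t = csr_matrix_lu_hip_legacy<T>;
#endif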
15 changes: 10 additions & 5 deletions include/gt-solver/backend/cuda-bsrsm2.h
@@ -135,7 +135,7 @@ class csr_matrix_lu_cuda_bsrsm2
 public:
   using value_type = T;
   using space_type = gt::space::device;
-  static constexpr bool inplace = false;
+  static constexpr bool inplace = true;
 
   csr_matrix_lu_cuda_bsrsm2(gt::sparse::csr_matrix<T, space_type>& csr_mat,
                             const T alpha, int nrhs,
@@ -204,23 +204,28 @@

   void solve(T* rhs, T* result)
   {
-    // first solve in place into rhs
+    // solve in place in the result vector
+    if (result == nullptr) {
+      result = rhs;
+    } else if (rhs != result) {
+      gt::copy_n(gt::device_pointer_cast(rhs), csr_mat_.shape(0) * nrhs_,
+                 gt::device_pointer_cast(result));
+    }
     gtSparseCheck(FN::solve(
       h_.get_backend_handle(), CUSPARSE_DIRECTION_COLUMN,
       CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE,
       csr_mat_.shape(0), nrhs_, csr_mat_.nnz(), FN::cast_pointer(&alpha_),
       l_desc_, FN::cast_pointer(csr_mat_.values_data()),
       csr_mat_.row_ptr_data(), csr_mat_.col_ind_data(), 1, l_info_,
-      FN::cast_pointer(rhs), csr_mat_.shape(0), FN::cast_pointer(rhs),
+      FN::cast_pointer(result), csr_mat_.shape(0), FN::cast_pointer(result),
       csr_mat_.shape(0), policy_, FN::cast_pointer(l_buf_.data())));
-    // second solve uses solution of first as input in rhs, result as output
     gtSparseCheck(FN::solve(
       h_.get_backend_handle(), CUSPARSE_DIRECTION_COLUMN,
       CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE,
       csr_mat_.shape(0), nrhs_, csr_mat_.nnz(), FN::cast_pointer(&alpha_),
       u_desc_, FN::cast_pointer(csr_mat_.values_data()),
       csr_mat_.row_ptr_data(), csr_mat_.col_ind_data(), 1, u_info_,
-      FN::cast_pointer(rhs), csr_mat_.shape(0), FN::cast_pointer(result),
+      FN::cast_pointer(result), csr_mat_.shape(0), FN::cast_pointer(result),
       csr_mat_.shape(0), policy_, FN::cast_pointer(u_buf_.data())));
   }
 
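Note: with inplace now true, the solve(rhs, result) contract shown above is: result == nullptr requests a pure in-place solve overwriting rhs; a distinct result pointer first receives a copy of rhs, and both triangular solves then run in place there. A hedged usage sketch (not from the library's tests; the .data().get() accessor follows the pattern used elsewhere in this diff):

// Usage sketch for the in-place solve convention; `lu` is assumed to be an
// already-constructed csr_matrix_lu_cuda_bsrsm2<double> for an n x n CSR
// matrix factored as L*U, with nrhs right-hand sides.
void demo_solve(csr_matrix_lu_cuda_bsrsm2<double>& lu,
                gt::gtensor_device<double, 2>& rhs,    // shape (n, nrhs)
                gt::gtensor_device<double, 2>& result) // shape (n, nrhs)
{
  // Variant 1: in place, rhs is overwritten with the solution.
  lu.solve(rhs.data().get(), nullptr);

  // Variant 2: rhs is copied into result, which is then solved in place,
  // leaving rhs itself unmodified (independent of variant 1 above).
  lu.solve(rhs.data().get(), result.data().get());
}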
184 changes: 116 additions & 68 deletions include/gt-solver/backend/cuda-generic.h
@@ -127,21 +127,12 @@ class csr_matrix_lu_cuda_generic
 public:
   using value_type = T;
   using space_type = gt::space::device;
-  // has it's own internal buffer and does not need wrapping, because of
-  // dense vector descriptors.
-  // TODO: is it expensive to change the descriptors on each call to solve? If
-  // not, we can remove the buffers here to be more like the other backends,
-  // which don't buffer at all by default and rely on wrapper class when needed.
   static constexpr bool inplace = true;
 
   csr_matrix_lu_cuda_generic(gt::sparse::csr_matrix<T, space_type>& csr_mat,
                              const T alpha, int nrhs,
                              gt::stream_view sview = gt::stream_view{})
-    : csr_mat_(csr_mat),
-      alpha_(alpha),
-      nrhs_(nrhs),
-      rhs_tmp_(gt::shape(csr_mat.shape(0), nrhs)),
-      result_tmp_(gt::shape(csr_mat.shape(0), nrhs))
+    : csr_mat_(csr_mat), alpha_(alpha), nrhs_(nrhs)
   {
     gtSparseCheck(
       cusparseSetStream(h_.get_backend_handle(), sview.get_backend_stream()));
@@ -170,73 +161,129 @@ class csr_matrix_lu_cuda_generic
     gtSparseCheck(cusparseSpMatSetAttribute(u_desc_, CUSPARSE_SPMAT_DIAG_TYPE,
                                             &u_diag_type, sizeof(u_diag_type)));
 
-    gtSparseCheck(cusparseSpSM_createDescr(&l_spsm_desc_));
-    gtSparseCheck(cusparseSpSM_createDescr(&u_spsm_desc_));
-
-    gtSparseCheck(cusparseCreateDnMat(&rhs_desc_, csr_mat_.shape(0), nrhs_,
-                                      csr_mat_.shape(0), rhs_tmp_.data().get(),
-                                      FN::dtype, CUSPARSE_ORDER_COL));
-    gtSparseCheck(cusparseCreateDnMat(
-      &result_desc_, csr_mat_.shape(0), nrhs_, csr_mat_.shape(0),
-      result_tmp_.data().get(), FN::dtype, CUSPARSE_ORDER_COL));
-
-    // analyze
-    std::size_t l_buf_size, u_buf_size;
-    gtSparseCheck(cusparseSpSM_bufferSize(
-      h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
-      CUSPARSE_OPERATION_NON_TRANSPOSE, FN::cast_pointer(&alpha_), l_desc_,
-      rhs_desc_, result_desc_, FN::dtype, algo_, l_spsm_desc_, &l_buf_size));
-
-    gtSparseCheck(cusparseSpSM_bufferSize(
-      h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
-      CUSPARSE_OPERATION_NON_TRANSPOSE, FN::cast_pointer(&alpha_), u_desc_,
-      rhs_desc_, result_desc_, FN::dtype, algo_, u_spsm_desc_, &u_buf_size));
-
-    l_buf_.resize(gt::shape(l_buf_size));
-    u_buf_.resize(gt::shape(u_buf_size));
-
-    gtSparseCheck(cusparseSpSM_analysis(
-      h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
-      CUSPARSE_OPERATION_NON_TRANSPOSE, FN::cast_pointer(&alpha_), l_desc_,
-      result_desc_, rhs_desc_, FN::dtype, algo_, l_spsm_desc_,
-      FN::cast_pointer(l_buf_.data())));
-
-    gtSparseCheck(cusparseSpSM_analysis(
-      h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
-      CUSPARSE_OPERATION_NON_TRANSPOSE, FN::cast_pointer(&alpha_), u_desc_,
-      rhs_desc_, result_desc_, FN::dtype, algo_, u_spsm_desc_,
-      FN::cast_pointer(u_buf_.data())));
+    if (nrhs_ > 1) {
+      gtSparseCheck(cusparseSpSM_createDescr(&l_spsm_desc_));
+      gtSparseCheck(cusparseSpSM_createDescr(&u_spsm_desc_));
+
+      gt::gtensor<T, 2, space_type> rhs_tmp(gt::shape(csr_mat.shape(0), nrhs_));
+      gtSparseCheck(cusparseCreateDnMat(
+        &rhs_mat_desc_, csr_mat_.shape(0), nrhs_, csr_mat_.shape(0),
+        rhs_tmp.data().get(), FN::dtype, CUSPARSE_ORDER_COL));
+
+      // analyze
+      std::size_t l_buf_size, u_buf_size;
+      gtSparseCheck(cusparseSpSM_bufferSize(
+        h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
+        CUSPARSE_OPERATION_NON_TRANSPOSE, FN::cast_pointer(&alpha_), l_desc_,
+        rhs_mat_desc_, rhs_mat_desc_, FN::dtype, algo_sm_, l_spsm_desc_,
+        &l_buf_size));
+
+      gtSparseCheck(cusparseSpSM_bufferSize(
+        h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
+        CUSPARSE_OPERATION_NON_TRANSPOSE, FN::cast_pointer(&alpha_), u_desc_,
+        rhs_mat_desc_, rhs_mat_desc_, FN::dtype, algo_sm_, u_spsm_desc_,
+        &u_buf_size));
+
+      l_buf_.resize(gt::shape(l_buf_size));
+      u_buf_.resize(gt::shape(u_buf_size));
+
+      gtSparseCheck(cusparseSpSM_analysis(
+        h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
+        CUSPARSE_OPERATION_NON_TRANSPOSE, FN::cast_pointer(&alpha_), l_desc_,
+        rhs_mat_desc_, rhs_mat_desc_, FN::dtype, algo_sm_, l_spsm_desc_,
+        FN::cast_pointer(l_buf_.data())));
+
+      gtSparseCheck(cusparseSpSM_analysis(
+        h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
+        CUSPARSE_OPERATION_NON_TRANSPOSE, FN::cast_pointer(&alpha_), u_desc_,
+        rhs_mat_desc_, rhs_mat_desc_, FN::dtype, algo_sm_, u_spsm_desc_,
+        FN::cast_pointer(u_buf_.data())));
+    } else {
+      // Note: SpSV APIs have better performance when nrhs=1, at least in 12.2
+      gtSparseCheck(cusparseSpSV_createDescr(&l_spsv_desc_));
+      gtSparseCheck(cusparseSpSV_createDescr(&u_spsv_desc_));
+
+      gt::gtensor<T, 1, space_type> rhs_tmp(gt::shape(csr_mat.shape(0)));
+      gtSparseCheck(cusparseCreateDnVec(&rhs_vec_desc_, csr_mat_.shape(0),
+                                        rhs_tmp.data().get(), FN::dtype));
+
+      // analyze
+      std::size_t l_buf_size, u_buf_size;
+      gtSparseCheck(cusparseSpSV_bufferSize(
+        h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
+        FN::cast_pointer(&alpha_), l_desc_, rhs_vec_desc_, rhs_vec_desc_,
+        FN::dtype, algo_sv_, l_spsv_desc_, &l_buf_size));
+
+      gtSparseCheck(cusparseSpSV_bufferSize(
+        h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
+        FN::cast_pointer(&alpha_), u_desc_, rhs_vec_desc_, rhs_vec_desc_,
+        FN::dtype, algo_sv_, u_spsv_desc_, &u_buf_size));
+
+      l_buf_.resize(gt::shape(l_buf_size));
+      u_buf_.resize(gt::shape(u_buf_size));
+
+      gtSparseCheck(cusparseSpSV_analysis(
+        h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
+        FN::cast_pointer(&alpha_), l_desc_, rhs_vec_desc_, rhs_vec_desc_,
+        FN::dtype, algo_sv_, l_spsv_desc_, FN::cast_pointer(l_buf_.data())));
+
+      gtSparseCheck(cusparseSpSV_analysis(
+        h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
+        FN::cast_pointer(&alpha_), u_desc_, rhs_vec_desc_, rhs_vec_desc_,
+        FN::dtype, algo_sv_, u_spsv_desc_, FN::cast_pointer(u_buf_.data())));
+    }
   }

   ~csr_matrix_lu_cuda_generic()
   {
-    gtSparseCheck(cusparseSpSM_destroyDescr(l_spsm_desc_));
-    gtSparseCheck(cusparseSpSM_destroyDescr(u_spsm_desc_));
+    if (nrhs_ > 1) {
+      gtSparseCheck(cusparseSpSM_destroyDescr(l_spsm_desc_));
+      gtSparseCheck(cusparseSpSM_destroyDescr(u_spsm_desc_));
+      gtSparseCheck(cusparseDestroyDnMat(rhs_mat_desc_));
+    } else {
+      gtSparseCheck(cusparseSpSV_destroyDescr(l_spsv_desc_));
+      gtSparseCheck(cusparseSpSV_destroyDescr(u_spsv_desc_));
+      gtSparseCheck(cusparseDestroyDnVec(rhs_vec_desc_));
+    }
     gtSparseCheck(cusparseDestroySpMat(l_desc_));
     gtSparseCheck(cusparseDestroySpMat(u_desc_));
-    gtSparseCheck(cusparseDestroyDnMat(rhs_desc_));
-    gtSparseCheck(cusparseDestroyDnMat(result_desc_));
   }

   void solve(T* rhs, T* result)
   {
-    gt::copy_n(gt::device_pointer_cast(rhs), result_tmp_.size(),
-               result_tmp_.data());
-    gtSparseCheck(cusparseSpSM_solve(
-      h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
-      CUSPARSE_OPERATION_NON_TRANSPOSE, FN::cast_pointer(&alpha_), l_desc_,
-      result_desc_, rhs_desc_, FN::dtype, algo_, l_spsm_desc_));
-    gtSparseCheck(cusparseSpSM_solve(
-      h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
-      CUSPARSE_OPERATION_NON_TRANSPOSE, FN::cast_pointer(&alpha_), u_desc_,
-      rhs_desc_, result_desc_, FN::dtype, algo_, u_spsm_desc_));
-    gt::copy_n(result_tmp_.data(), result_tmp_.size(),
-               gt::device_pointer_cast(result));
+    // solve in place in the result vector
+    if (result == nullptr) {
+      result = rhs;
+    } else if (rhs != result) {
+      gt::copy_n(gt::device_pointer_cast(rhs), csr_mat_.shape(0) * nrhs_,
+                 gt::device_pointer_cast(result));
+    }
+    if (nrhs_ > 1) {
+      gtSparseCheck(cusparseDnMatSetValues(rhs_mat_desc_, result));
+      gtSparseCheck(cusparseSpSM_solve(
+        h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
+        CUSPARSE_OPERATION_NON_TRANSPOSE, FN::cast_pointer(&alpha_), l_desc_,
+        rhs_mat_desc_, rhs_mat_desc_, FN::dtype, algo_sm_, l_spsm_desc_));
+      gtSparseCheck(cusparseSpSM_solve(
+        h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
+        CUSPARSE_OPERATION_NON_TRANSPOSE, FN::cast_pointer(&alpha_), u_desc_,
+        rhs_mat_desc_, rhs_mat_desc_, FN::dtype, algo_sm_, u_spsm_desc_));
+    } else {
+      gtSparseCheck(cusparseDnVecSetValues(rhs_vec_desc_, result));
+      gtSparseCheck(cusparseSpSV_solve(
+        h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
+        FN::cast_pointer(&alpha_), l_desc_, rhs_vec_desc_, rhs_vec_desc_,
+        FN::dtype, algo_sv_, l_spsv_desc_));
+      gtSparseCheck(cusparseSpSV_solve(
+        h_.get_backend_handle(), CUSPARSE_OPERATION_NON_TRANSPOSE,
+        FN::cast_pointer(&alpha_), u_desc_, rhs_vec_desc_, rhs_vec_desc_,
+        FN::dtype, algo_sv_, u_spsv_desc_));
+    }
   }

   std::size_t get_device_memory_usage()
   {
-    size_t nelements = csr_mat_.nnz() + rhs_tmp_.size() + result_tmp_.size();
+    size_t nelements = csr_mat_.nnz();
     size_t nbuf = l_buf_.size() + u_buf_.size();
     size_t nint = csr_mat_.nnz() + csr_mat_.shape(0) + 1;
     return nelements * sizeof(T) + nint * sizeof(int) + nbuf;
@@ -246,20 +293,21 @@ class csr_matrix_lu_cuda_generic
   gt::sparse::csr_matrix<T, space_type>& csr_mat_;
   const T alpha_;
   int nrhs_;
-  gt::gtensor<T, 2, space_type> rhs_tmp_;
-  gt::gtensor<T, 2, space_type> result_tmp_;
 
   sparse_handle_t h_;
   cusparseSpMatDescr_t l_desc_;
   cusparseSpMatDescr_t u_desc_;
-  cusparseDnMatDescr_t rhs_desc_;
-  cusparseDnMatDescr_t result_desc_;
   cusparseSpSMDescr_t l_spsm_desc_;
   cusparseSpSMDescr_t u_spsm_desc_;
+  cusparseDnMatDescr_t rhs_mat_desc_;
+  cusparseSpSVDescr_t l_spsv_desc_;
+  cusparseSpSVDescr_t u_spsv_desc_;
+  cusparseDnVecDescr_t rhs_vec_desc_;
   gt::gtensor_device<uint8_t, 1> l_buf_;
   gt::gtensor_device<uint8_t, 1> u_buf_;
 
-  const cusparseSpSMAlg_t algo_ = CUSPARSE_SPSM_ALG_DEFAULT;
+  const cusparseSpSMAlg_t algo_sm_ = CUSPARSE_SPSM_ALG_DEFAULT;
+  const cusparseSpSVAlg_t algo_sv_ = CUSPARSE_SPSV_ALG_DEFAULT;
 
   using FN = detail::csrsm_functions<T>;
 };
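Note: the nrhs_ == 1 branch above follows the standard cuSPARSE SpSV sequence: query a workspace, run analysis once at construction, then repoint the dense-vector descriptor and solve in place on every call. A condensed standalone sketch of that call order for a single lower-triangular solve (error checking omitted; matL is assumed to already carry the lower-fill / unit-diagonal attributes that the constructor above sets):

#include <cusparse.h>
#include <cuda_runtime.h>

// Sketch of the SpSV sequence mirrored by the nrhs == 1 path above, solving
// L*x = b in place: d_x holds b on entry and x on exit.
void spsv_lower_inplace(cusparseHandle_t handle, cusparseSpMatDescr_t matL,
                        double* d_x, int64_t n)
{
  const double alpha = 1.0;
  cusparseSpSVDescr_t spsv_desc;
  cusparseSpSV_createDescr(&spsv_desc);

  cusparseDnVecDescr_t x_desc;
  cusparseCreateDnVec(&x_desc, n, d_x, CUDA_R_64F);

  // One-time setup: workspace query, then the analysis phase.
  size_t buf_size = 0;
  cusparseSpSV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
                          matL, x_desc, x_desc, CUDA_R_64F,
                          CUSPARSE_SPSV_ALG_DEFAULT, spsv_desc, &buf_size);
  void* d_buf = nullptr;
  cudaMalloc(&d_buf, buf_size);
  cusparseSpSV_analysis(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
                        matL, x_desc, x_desc, CUDA_R_64F,
                        CUSPARSE_SPSV_ALG_DEFAULT, spsv_desc, d_buf);

  // Per solve: point the descriptor at the current data, then solve in
  // place (vecX == vecY), as solve() above does with result.
  cusparseDnVecSetValues(x_desc, d_x);
  cusparseSpSV_solve(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
                     matL, x_desc, x_desc, CUDA_R_64F,
                     CUSPARSE_SPSV_ALG_DEFAULT, spsv_desc);

  cudaFree(d_buf);
  cusparseDestroyDnVec(x_desc);
  cusparseSpSV_destroyDescr(spsv_desc);
}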
File renamed without changes.
