From 90a3bbb06d5cd79122361ccb849920d9d878228d Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Thu, 26 Aug 2021 15:02:16 +0200 Subject: [PATCH] Add CUDA generic API triangular solver --- core/device_hooks/common_kernels.inc.cpp | 6 - core/solver/lower_trs.cpp | 11 +- core/solver/lower_trs_kernels.hpp | 14 +- core/solver/upper_trs.cpp | 11 +- core/solver/upper_trs_kernels.hpp | 14 +- cuda/base/cusparse_bindings.hpp | 167 +++++++-- cuda/base/pointer_mode_guard.hpp | 12 +- cuda/solver/common_trs_kernels.cuh | 447 +++++++++++++---------- cuda/solver/lower_trs_kernels.cu | 10 +- cuda/solver/upper_trs_kernels.cu | 10 +- dpcpp/solver/lower_trs_kernels.dp.cpp | 20 +- dpcpp/solver/upper_trs_kernels.dp.cpp | 16 +- hip/solver/lower_trs_kernels.hip.cpp | 12 +- hip/solver/upper_trs_kernels.hip.cpp | 15 +- include/ginkgo/core/solver/lower_trs.hpp | 3 - include/ginkgo/core/solver/upper_trs.hpp | 3 - omp/solver/lower_trs_kernels.cpp | 11 +- omp/solver/upper_trs_kernels.cpp | 11 +- omp/test/solver/lower_trs_kernels.cpp | 15 +- omp/test/solver/upper_trs_kernels.cpp | 15 +- reference/solver/lower_trs_kernels.cpp | 11 +- reference/solver/upper_trs_kernels.cpp | 11 +- 22 files changed, 438 insertions(+), 407 deletions(-) diff --git a/core/device_hooks/common_kernels.inc.cpp b/core/device_hooks/common_kernels.inc.cpp index d04d4831089..51b9436fdc6 100644 --- a/core/device_hooks/common_kernels.inc.cpp +++ b/core/device_hooks/common_kernels.inc.cpp @@ -422,9 +422,6 @@ namespace lower_trs { GKO_DECLARE_LOWER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL() GKO_NOT_COMPILED(GKO_HOOK_MODULE); -GKO_DECLARE_LOWER_TRS_INIT_STRUCT_KERNEL() -GKO_NOT_COMPILED(GKO_HOOK_MODULE); - template GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(ValueType, IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); @@ -447,9 +444,6 @@ namespace upper_trs { GKO_DECLARE_UPPER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL() GKO_NOT_COMPILED(GKO_HOOK_MODULE); -GKO_DECLARE_UPPER_TRS_INIT_STRUCT_KERNEL() -GKO_NOT_COMPILED(GKO_HOOK_MODULE); - template GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL(ValueType, IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); diff --git a/core/solver/lower_trs.cpp b/core/solver/lower_trs.cpp index cbe5f332f9a..bc5ebe4095c 100644 --- a/core/solver/lower_trs.cpp +++ b/core/solver/lower_trs.cpp @@ -55,7 +55,6 @@ namespace { GKO_REGISTER_OPERATION(generate, lower_trs::generate); -GKO_REGISTER_OPERATION(init_struct, lower_trs::init_struct); GKO_REGISTER_OPERATION(should_perform_transpose, lower_trs::should_perform_transpose); GKO_REGISTER_OPERATION(solve, lower_trs::solve); @@ -85,19 +84,11 @@ std::unique_ptr LowerTrs::conj_transpose() const } -template -void LowerTrs::init_trs_solve_struct() -{ - this->get_executor()->run(lower_trs::make_init_struct(this->solve_struct_)); -} - - template void LowerTrs::generate() { this->get_executor()->run(lower_trs::make_generate( - gko::lend(system_matrix_), gko::lend(this->solve_struct_), - parameters_.num_rhs)); + gko::lend(system_matrix_), this->solve_struct_, parameters_.num_rhs)); } diff --git a/core/solver/lower_trs_kernels.hpp b/core/solver/lower_trs_kernels.hpp index 5f9f272417a..778b4a13152 100644 --- a/core/solver/lower_trs_kernels.hpp +++ b/core/solver/lower_trs_kernels.hpp @@ -55,15 +55,10 @@ namespace lower_trs { bool &do_transpose) -#define GKO_DECLARE_LOWER_TRS_INIT_STRUCT_KERNEL() \ - void init_struct(std::shared_ptr exec, \ - std::shared_ptr &solve_struct) - - -#define GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(_vtype, _itype) \ - void generate(std::shared_ptr exec, \ - const matrix::Csr<_vtype, _itype> *matrix, \ - solver::SolveStruct *solve_struct, \ +#define GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(_vtype, _itype) \ + void generate(std::shared_ptr exec, \ + const matrix::Csr<_vtype, _itype> *matrix, \ + std::shared_ptr &solve_struct, \ const gko::size_type num_rhs) @@ -77,7 +72,6 @@ namespace lower_trs { #define GKO_DECLARE_ALL_AS_TEMPLATES \ GKO_DECLARE_LOWER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL(); \ - GKO_DECLARE_LOWER_TRS_INIT_STRUCT_KERNEL(); \ template \ GKO_DECLARE_LOWER_TRS_SOLVE_KERNEL(ValueType, IndexType); \ template \ diff --git a/core/solver/upper_trs.cpp b/core/solver/upper_trs.cpp index bae00182fc8..585c57f659b 100644 --- a/core/solver/upper_trs.cpp +++ b/core/solver/upper_trs.cpp @@ -55,7 +55,6 @@ namespace { GKO_REGISTER_OPERATION(generate, upper_trs::generate); -GKO_REGISTER_OPERATION(init_struct, upper_trs::init_struct); GKO_REGISTER_OPERATION(should_perform_transpose, upper_trs::should_perform_transpose); GKO_REGISTER_OPERATION(solve, upper_trs::solve); @@ -85,19 +84,11 @@ std::unique_ptr UpperTrs::conj_transpose() const } -template -void UpperTrs::init_trs_solve_struct() -{ - this->get_executor()->run(upper_trs::make_init_struct(this->solve_struct_)); -} - - template void UpperTrs::generate() { this->get_executor()->run(upper_trs::make_generate( - gko::lend(system_matrix_), gko::lend(this->solve_struct_), - parameters_.num_rhs)); + gko::lend(system_matrix_), this->solve_struct_, parameters_.num_rhs)); } diff --git a/core/solver/upper_trs_kernels.hpp b/core/solver/upper_trs_kernels.hpp index bdbc4a9b1d7..d61db2b1b62 100644 --- a/core/solver/upper_trs_kernels.hpp +++ b/core/solver/upper_trs_kernels.hpp @@ -55,15 +55,10 @@ namespace upper_trs { bool &do_transpose) -#define GKO_DECLARE_UPPER_TRS_INIT_STRUCT_KERNEL() \ - void init_struct(std::shared_ptr exec, \ - std::shared_ptr &solve_struct) - - -#define GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL(_vtype, _itype) \ - void generate(std::shared_ptr exec, \ - const matrix::Csr<_vtype, _itype> *matrix, \ - solver::SolveStruct *solve_struct, \ +#define GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL(_vtype, _itype) \ + void generate(std::shared_ptr exec, \ + const matrix::Csr<_vtype, _itype> *matrix, \ + std::shared_ptr &solve_struct, \ const gko::size_type num_rhs) @@ -77,7 +72,6 @@ namespace upper_trs { #define GKO_DECLARE_ALL_AS_TEMPLATES \ GKO_DECLARE_UPPER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL(); \ - GKO_DECLARE_UPPER_TRS_INIT_STRUCT_KERNEL(); \ template \ GKO_DECLARE_UPPER_TRS_SOLVE_KERNEL(ValueType, IndexType); \ template \ diff --git a/cuda/base/cusparse_bindings.hpp b/cuda/base/cusparse_bindings.hpp index 8a3de85293b..37612c7006c 100644 --- a/cuda/base/cusparse_bindings.hpp +++ b/cuda/base/cusparse_bindings.hpp @@ -745,10 +745,23 @@ inline cusparseMatDescr_t create_mat_descr() { cusparseMatDescr_t descr{}; GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseCreateMatDescr(&descr)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetMatDiagType(descr, CUSPARSE_DIAG_TYPE_NON_UNIT)); return descr; } +inline void set_mat_fill_mode(cusparseMatDescr_t descr, + cusparseFillMode_t fill_mode) +{ + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSetMatFillMode(descr, fill_mode)); +} + + inline void destroy(cusparseMatDescr_t descr) { GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseDestroyMatDescr(descr)); @@ -806,6 +819,24 @@ inline void destroy(cusparseDnVecDescr_t descr) } +template +inline cusparseDnMatDescr_t create_dnmat(int64_t rows, int64_t cols, + int64_t stride, ValueType *values) +{ + cusparseDnMatDescr_t descr{}; + constexpr auto value_type = cuda_data_type(); + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseCreateDnMat( + &descr, rows, cols, stride, values, value_type, CUSPARSE_ORDER_ROW)); + return descr; +} + + +inline void destroy(cusparseDnMatDescr_t descr) +{ + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseDestroyDnMat(descr)); +} + + template inline cusparseSpVecDescr_t create_spvec(int64_t size, int64_t nnz, IndexType *indices, ValueType *values) @@ -847,7 +878,37 @@ inline void destroy(cusparseSpMatDescr_t descr) GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseDestroySpMat(descr)); } -#endif + +#if (CUDA_VERSION >= 11030) + + +template +inline void set_attribute(cusparseSpMatDescr_t desc, + cusparseSpMatAttribute_t attr, AttribType val) +{ + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSpMatSetAttribute(desc, attr, &val, sizeof(val))); +} + + +inline cusparseSpSMDescr_t create_spsm_descr() +{ + cusparseSpSMDescr_t desc{}; + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSpSM_createDescr(&desc)); + return desc; +} + + +inline void destroy(cusparseSpSMDescr_t info) +{ + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSpSM_destroyDescr(info)); +} + + +#endif // CUDA_VERSION >= 11030 + + +#endif // defined(CUDA_VERSION) && (CUDA_VERSION >= 11000) // CUDA versions 9.2 and above have csrsm2. @@ -925,7 +986,7 @@ inline void destroy(csric02Info_t info) inline void buffer_size_ext( \ cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ - const ValueType *one, const cusparseMatDescr_t descr, \ + ValueType one, const cusparseMatDescr_t descr, \ const ValueType *csrVal, const int32 *csrRowPtr, \ const int32 *csrColInd, const ValueType *rhs, int32 sol_size, \ csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ @@ -933,7 +994,7 @@ inline void destroy(csric02Info_t info) { \ GKO_ASSERT_NO_CUSPARSE_ERRORS( \ CusparseName(handle, algo, trans1, trans2, m, n, nnz, \ - as_culibs_type(one), descr, as_culibs_type(csrVal), \ + as_culibs_type(&one), descr, as_culibs_type(csrVal), \ csrRowPtr, csrColInd, as_culibs_type(rhs), sol_size, \ factor_info, policy, factor_work_size)); \ } \ @@ -945,7 +1006,7 @@ inline void destroy(csric02Info_t info) inline void buffer_size_ext( \ cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ - const ValueType *one, const cusparseMatDescr_t descr, \ + ValueType one, const cusparseMatDescr_t descr, \ const ValueType *csrVal, const int64 *csrRowPtr, \ const int64 *csrColInd, const ValueType *rhs, int64 sol_size, \ csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ @@ -978,7 +1039,7 @@ GKO_BIND_CUSPARSE64_BUFFERSIZEEXT(ValueType, detail::not_implemented); inline void csrsm2_analysis( \ cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ - const ValueType *one, const cusparseMatDescr_t descr, \ + ValueType one, const cusparseMatDescr_t descr, \ const ValueType *csrVal, const int32 *csrRowPtr, \ const int32 *csrColInd, const ValueType *rhs, int32 sol_size, \ csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ @@ -986,7 +1047,7 @@ GKO_BIND_CUSPARSE64_BUFFERSIZEEXT(ValueType, detail::not_implemented); { \ GKO_ASSERT_NO_CUSPARSE_ERRORS( \ CusparseName(handle, algo, trans1, trans2, m, n, nnz, \ - as_culibs_type(one), descr, as_culibs_type(csrVal), \ + as_culibs_type(&one), descr, as_culibs_type(csrVal), \ csrRowPtr, csrColInd, as_culibs_type(rhs), sol_size, \ factor_info, policy, factor_work_vec)); \ } \ @@ -998,7 +1059,7 @@ GKO_BIND_CUSPARSE64_BUFFERSIZEEXT(ValueType, detail::not_implemented); inline void csrsm2_analysis( \ cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ - const ValueType *one, const cusparseMatDescr_t descr, \ + ValueType one, const cusparseMatDescr_t descr, \ const ValueType *csrVal, const int64 *csrRowPtr, \ const int64 *csrColInd, const ValueType *rhs, int64 sol_size, \ csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ @@ -1027,31 +1088,31 @@ GKO_BIND_CUSPARSE64_CSRSM2_ANALYSIS(ValueType, detail::not_implemented); #undef GKO_BIND_CUSPARSE64_CSRSM2_ANALYSIS -#define GKO_BIND_CUSPARSE32_CSRSM2_SOLVE(ValueType, CusparseName) \ - inline void csrsm2_solve( \ - cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ - cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ - const ValueType *one, const cusparseMatDescr_t descr, \ - const ValueType *csrVal, const int32 *csrRowPtr, \ - const int32 *csrColInd, ValueType *rhs, int32 sol_stride, \ - csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ - void *factor_work_vec) \ - { \ - GKO_ASSERT_NO_CUSPARSE_ERRORS( \ - CusparseName(handle, algo, trans1, trans2, m, n, nnz, \ - as_culibs_type(one), descr, as_culibs_type(csrVal), \ - csrRowPtr, csrColInd, as_culibs_type(rhs), \ - sol_stride, factor_info, policy, factor_work_vec)); \ - } \ - static_assert(true, \ - "This assert is used to counter the false positive extra " \ +#define GKO_BIND_CUSPARSE32_CSRSM2_SOLVE(ValueType, CusparseName) \ + inline void csrsm2_solve( \ + cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ + cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ + ValueType one, const cusparseMatDescr_t descr, \ + const ValueType *csrVal, const int32 *csrRowPtr, \ + const int32 *csrColInd, ValueType *rhs, int32 sol_stride, \ + csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ + void *factor_work_vec) \ + { \ + GKO_ASSERT_NO_CUSPARSE_ERRORS( \ + CusparseName(handle, algo, trans1, trans2, m, n, nnz, \ + as_culibs_type(&one), descr, as_culibs_type(csrVal), \ + csrRowPtr, csrColInd, as_culibs_type(rhs), \ + sol_stride, factor_info, policy, factor_work_vec)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ "semi-colon warnings") #define GKO_BIND_CUSPARSE64_CSRSM2_SOLVE(ValueType, CusparseName) \ inline void csrsm2_solve( \ cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ - const ValueType *one, const cusparseMatDescr_t descr, \ + ValueType one, const cusparseMatDescr_t descr, \ const ValueType *csrVal, const int64 *csrRowPtr, \ const int64 *csrColInd, ValueType *rhs, int64 sol_stride, \ csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ @@ -1129,14 +1190,14 @@ GKO_BIND_CUSPARSE64_CSRSM_ANALYSIS(ValueType, detail::not_implemented); #define GKO_BIND_CUSPARSE32_CSRSM_SOLVE(ValueType, CusparseName) \ inline void csrsm_solve( \ cusparseHandle_t handle, cusparseOperation_t trans, size_type m, \ - size_type n, const ValueType *one, const cusparseMatDescr_t descr, \ + size_type n, ValueType one, const cusparseMatDescr_t descr, \ const ValueType *csrVal, const int32 *csrRowPtr, \ const int32 *csrColInd, cusparseSolveAnalysisInfo_t factor_info, \ const ValueType *rhs, int32 rhs_stride, ValueType *sol, \ int32 sol_stride) \ { \ GKO_ASSERT_NO_CUSPARSE_ERRORS( \ - CusparseName(handle, trans, m, n, as_culibs_type(one), descr, \ + CusparseName(handle, trans, m, n, as_culibs_type(&one), descr, \ as_culibs_type(csrVal), csrRowPtr, csrColInd, \ factor_info, as_culibs_type(rhs), rhs_stride, \ as_culibs_type(sol), sol_stride)); \ @@ -1148,7 +1209,7 @@ GKO_BIND_CUSPARSE64_CSRSM_ANALYSIS(ValueType, detail::not_implemented); #define GKO_BIND_CUSPARSE64_CSRSM_SOLVE(ValueType, CusparseName) \ inline void csrsm_solve( \ cusparseHandle_t handle, cusparseOperation_t trans1, size_type m, \ - size_type n, const ValueType *one, const cusparseMatDescr_t descr, \ + size_type n, ValueType one, const cusparseMatDescr_t descr, \ const ValueType *csrVal, const int64 *csrRowPtr, \ const int64 *csrColInd, cusparseSolveAnalysisInfo_t factor_info, \ const ValueType *rhs, int64 rhs_stride, ValueType *sol, \ @@ -1176,6 +1237,54 @@ GKO_BIND_CUSPARSE64_CSRSM_SOLVE(ValueType, detail::not_implemented); #endif +#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11030)) + + +template +size_type spsm_buffer_size(cusparseHandle_t handle, cusparseOperation_t op_a, + cusparseOperation_t op_b, ValueType alpha, + cusparseSpMatDescr_t descr_a, + cusparseDnMatDescr_t descr_b, + cusparseDnMatDescr_t descr_c, cusparseSpSMAlg_t algo, + cusparseSpSMDescr_t spsm_descr) +{ + size_type work_size; + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSpSM_bufferSize( + handle, op_a, op_b, &alpha, descr_a, descr_b, descr_c, + cuda_data_type(), algo, spsm_descr, &work_size)); + return work_size; +} + + +template +void spsm_analysis(cusparseHandle_t handle, cusparseOperation_t op_a, + cusparseOperation_t op_b, ValueType alpha, + cusparseSpMatDescr_t descr_a, cusparseDnMatDescr_t descr_b, + cusparseDnMatDescr_t descr_c, cusparseSpSMAlg_t algo, + cusparseSpSMDescr_t spsm_descr, void *work) +{ + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSpSM_analysis( + handle, op_a, op_b, &alpha, descr_a, descr_b, descr_c, + cuda_data_type(), algo, spsm_descr, work)); +} + + +template +void spsm_solve(cusparseHandle_t handle, cusparseOperation_t op_a, + cusparseOperation_t op_b, ValueType alpha, + cusparseSpMatDescr_t descr_a, cusparseDnMatDescr_t descr_b, + cusparseDnMatDescr_t descr_c, cusparseSpSMAlg_t algo, + cusparseSpSMDescr_t spsm_descr) +{ + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSpSM_solve( + handle, op_a, op_b, &alpha, descr_a, descr_b, descr_c, + cuda_data_type(), algo, spsm_descr)); +} + + +#endif + + template void create_identity_permutation(cusparseHandle_t handle, IndexType size, IndexType *permutation) GKO_NOT_IMPLEMENTED; diff --git a/cuda/base/pointer_mode_guard.hpp b/cuda/base/pointer_mode_guard.hpp index 72ac4f372d7..3cd245ef2ca 100644 --- a/cuda/base/pointer_mode_guard.hpp +++ b/cuda/base/pointer_mode_guard.hpp @@ -108,9 +108,9 @@ namespace cusparse { */ class pointer_mode_guard { public: - pointer_mode_guard(cusparseHandle_t &handle) + pointer_mode_guard(cusparseHandle_t handle) { - l_handle = &handle; + l_handle = handle; GKO_ASSERT_NO_CUSPARSE_ERRORS( cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST)); } @@ -127,15 +127,15 @@ class pointer_mode_guard { { /* Ignore the error during stack unwinding for this call */ if (std::uncaught_exception()) { - cusparseSetPointerMode(*l_handle, CUSPARSE_POINTER_MODE_DEVICE); + cusparseSetPointerMode(l_handle, CUSPARSE_POINTER_MODE_DEVICE); } else { - GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSetPointerMode( - *l_handle, CUSPARSE_POINTER_MODE_DEVICE)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetPointerMode(l_handle, CUSPARSE_POINTER_MODE_DEVICE)); } } private: - cusparseHandle_t *l_handle; + cusparseHandle_t l_handle; }; diff --git a/cuda/solver/common_trs_kernels.cuh b/cuda/solver/common_trs_kernels.cuh index 53d6661c45f..fdb5885e633 100644 --- a/cuda/solver/common_trs_kernels.cuh +++ b/cuda/solver/common_trs_kernels.cuh @@ -35,6 +35,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include #include @@ -59,207 +60,318 @@ namespace solver { struct SolveStruct { - virtual void dummy() {} + virtual ~SolveStruct() = default; }; +} // namespace solver + + +namespace kernels { namespace cuda { +namespace { + + +#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11030)) + + +template +struct CudaSolveStruct : gko::solver::SolveStruct { + cusparseHandle_t handle; + cusparseSpSMDescr_t spsm_descr; + cusparseSpMatDescr_t descr_a; + Array work; + + CudaSolveStruct(std::shared_ptr exec, + const matrix::Csr *matrix, + size_type num_rhs, bool is_upper) + : handle{exec->get_cusparse_handle()}, + spsm_descr{}, + descr_a{}, + work{exec} + { + cusparse::pointer_mode_guard pm_guard(handle); + spsm_descr = cusparse::create_spsm_descr(); + descr_a = cusparse::create_csr( + matrix->get_size()[0], matrix->get_size()[1], + matrix->get_num_stored_elements(), + const_cast(matrix->get_const_row_ptrs()), + const_cast(matrix->get_const_col_idxs()), + const_cast(matrix->get_const_values())); + cusparse::set_attribute( + descr_a, CUSPARSE_SPMAT_FILL_MODE, + is_upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER); + cusparse::set_attribute( + descr_a, CUSPARSE_SPMAT_DIAG_TYPE, CUSPARSE_DIAG_TYPE_NON_UNIT); + + const auto rows = matrix->get_size()[0]; + // workaround suggested by NVIDIA engineers: for some reason + // cusparse needs non-nullptr input vectors even for analysis + auto descr_b = cusparse::create_dnmat( + matrix->get_size()[0], num_rhs, matrix->get_size()[1], + reinterpret_cast(0xFF)); + auto descr_c = cusparse::create_dnmat( + matrix->get_size()[0], num_rhs, matrix->get_size()[1], + reinterpret_cast(0xFF)); + + auto work_size = cusparse::spsm_buffer_size( + handle, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_NON_TRANSPOSE, one(), descr_a, + descr_b, descr_c, CUSPARSE_SPSM_ALG_DEFAULT, spsm_descr); + + work.resize_and_reset(work_size); + + cusparse::spsm_analysis(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_NON_TRANSPOSE, + one(), descr_a, descr_b, descr_c, + CUSPARSE_SPSM_ALG_DEFAULT, spsm_descr, + work.get_data()); + + cusparse::destroy(descr_b); + cusparse::destroy(descr_c); + } + void solve(const matrix::Csr *, + const matrix::Dense *input, + matrix::Dense *output, matrix::Dense *, + matrix::Dense *) const + { + cusparse::pointer_mode_guard pm_guard(handle); + auto descr_b = cusparse::create_dnmat( + input->get_size()[0], input->get_size()[1], input->get_stride(), + const_cast(input->get_const_values())); + auto descr_c = + cusparse::create_dnmat(output->get_size()[0], output->get_size()[1], + output->get_stride(), output->get_values()); + + cusparse::spsm_solve(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_NON_TRANSPOSE, one(), + descr_a, descr_b, descr_c, + CUSPARSE_SPSM_ALG_DEFAULT, spsm_descr); + + cusparse::destroy(descr_b); + cusparse::destroy(descr_c); + } + + ~CudaSolveStruct() + { + if (descr_a) { + cusparse::destroy(descr_a); + descr_a = nullptr; + } + if (spsm_descr) { + cusparse::destroy(spsm_descr); + spsm_descr = nullptr; + } + } + + CudaSolveStruct(const SolveStruct &) = delete; + + CudaSolveStruct(SolveStruct &&) = delete; + + CudaSolveStruct &operator=(const SolveStruct &) = delete; + + CudaSolveStruct &operator=(SolveStruct &&) = delete; +}; -#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020)) +#elif (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020)) -struct SolveStruct : gko::solver::SolveStruct { +template +struct CudaSolveStruct : gko::solver::SolveStruct { + std::shared_ptr exec; + cusparseHandle_t handle; int algorithm; csrsm2Info_t solve_info; cusparseSolvePolicy_t policy; cusparseMatDescr_t factor_descr; - size_t factor_work_size; - void *factor_work_vec; - SolveStruct() + mutable Array work; + + CudaSolveStruct(std::shared_ptr exec, + const matrix::Csr *matrix, + size_type num_rhs, bool is_upper) + : exec{exec}, + handle{exec->get_cusparse_handle()}, + algorithm{}, + solve_info{}, + policy{}, + factor_descr{}, + work{exec} { - factor_work_vec = nullptr; - GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseCreateMatDescr(&factor_descr)); - GKO_ASSERT_NO_CUSPARSE_ERRORS( - cusparseSetMatIndexBase(factor_descr, CUSPARSE_INDEX_BASE_ZERO)); - GKO_ASSERT_NO_CUSPARSE_ERRORS( - cusparseSetMatType(factor_descr, CUSPARSE_MATRIX_TYPE_GENERAL)); - GKO_ASSERT_NO_CUSPARSE_ERRORS( - cusparseSetMatDiagType(factor_descr, CUSPARSE_DIAG_TYPE_NON_UNIT)); - GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseCreateCsrsm2Info(&solve_info)); + cusparse::pointer_mode_guard pm_guard(handle); + factor_descr = cusparse::create_mat_descr(); + solve_info = cusparse::create_solve_info(); + cusparse::set_mat_fill_mode( + factor_descr, + is_upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER); algorithm = 0; policy = CUSPARSE_SOLVE_POLICY_USE_LEVEL; - } - - SolveStruct(const SolveStruct &) = delete; - SolveStruct(SolveStruct &&) = delete; - - SolveStruct &operator=(const SolveStruct &) = delete; + size_type work_size{}; + + cusparse::buffer_size_ext( + handle, algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], num_rhs, + matrix->get_num_stored_elements(), one(), factor_descr, + matrix->get_const_values(), matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), nullptr, num_rhs, solve_info, policy, + &work_size); + + // allocate workspace + work.resize_and_reset(work_size); + + cusparse::csrsm2_analysis( + handle, algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], num_rhs, + matrix->get_num_stored_elements(), one(), factor_descr, + matrix->get_const_values(), matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), nullptr, num_rhs, solve_info, policy, + work.get_data()); + } - SolveStruct &operator=(SolveStruct &&) = delete; + void solve(const matrix::Csr *matrix, + const matrix::Dense *input, + matrix::Dense *output, matrix::Dense *, + matrix::Dense *) const + { + cusparse::pointer_mode_guard pm_guard(handle); + dense::copy(exec, input, output); + cusparse::csrsm2_solve( + handle, algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], + output->get_stride(), matrix->get_num_stored_elements(), + one(), factor_descr, matrix->get_const_values(), + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + output->get_values(), output->get_stride(), solve_info, policy, + work.get_data()); + } - ~SolveStruct() + ~CudaSolveStruct() { - cusparseDestroyMatDescr(factor_descr); - if (solve_info) { - cusparseDestroyCsrsm2Info(solve_info); + if (factor_descr) { + cusparse::destroy(factor_descr); + factor_descr = nullptr; } - if (factor_work_vec != nullptr) { - cudaFree(factor_work_vec); - factor_work_vec = nullptr; + if (solve_info) { + cusparse::destroy(solve_info); + solve_info = nullptr; } } + + CudaSolveStruct(const CudaSolveStruct &) = delete; + + CudaSolveStruct(CudaSolveStruct &&) = delete; + + CudaSolveStruct &operator=(const CudaSolveStruct &) = delete; + + CudaSolveStruct &operator=(CudaSolveStruct &&) = delete; }; #elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020)) -struct SolveStruct : gko::solver::SolveStruct { +struct CudaSolveStruct : gko::solver::SolveStruct { cusparseSolveAnalysisInfo_t solve_info; cusparseMatDescr_t factor_descr; - SolveStruct() + CudaSolveStruct(std::shared_ptr exec, + const matrix::Csr *matrix, + size_type num_rhs, bool is_upper) { - GKO_ASSERT_NO_CUSPARSE_ERRORS( - cusparseCreateSolveAnalysisInfo(&solve_info)); - GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseCreateMatDescr(&factor_descr)); - GKO_ASSERT_NO_CUSPARSE_ERRORS( - cusparseSetMatIndexBase(factor_descr, CUSPARSE_INDEX_BASE_ZERO)); - GKO_ASSERT_NO_CUSPARSE_ERRORS( - cusparseSetMatType(factor_descr, CUSPARSE_MATRIX_TYPE_GENERAL)); - GKO_ASSERT_NO_CUSPARSE_ERRORS( - cusparseSetMatDiagType(factor_descr, CUSPARSE_DIAG_TYPE_NON_UNIT)); + cusparse::pointer_mode_guard pm_guard(handle); + solve_info = cusparse::create_solve_info(); + factor_descr = cusparse::create_mat_descr(); + cusparse::set_mat_fill_mode( + factor_descr, + is_upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER); + cusparse::csrsm_analysis( + handle, CUSPARSE_OPERATION_NON_TRANSPOSE, matrix->get_size()[0], + matrix->get_num_stored_elements(), factor_descr, + matrix->get_const_values(), matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), solve_info); } - SolveStruct(const SolveStruct &) = delete; - - SolveStruct(SolveStruct &&) = delete; - - SolveStruct &operator=(const SolveStruct &) = delete; - - SolveStruct &operator=(SolveStruct &&) = delete; + void solve(const matrix::Csr *matrix, + const matrix::Dense *input, + matrix::Dense *output, + matrix::Dense *trans_in, + matrix::Dense *trans_out) const + { + cusparse::pointer_mode_guard pm_guard(handle); + if (input->get_stride() == 1 && output->get_stride() == 1) { + cusparse::csrsm_solve( + handle, CUSPARSE_OPERATION_NON_TRANSPOSE, matrix->get_size()[0], + input->get_stride(), one(), factor_descr, + matrix->get_const_values(), matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), solve_info, + input->get_const_values(), input->get_size()[0], + output->get_values(), output->get_size()[0]); + } else { + dense::transpose(exec, input, trans_in); + cusparse::csrsm_solve( + handle, CUSPARSE_OPERATION_NON_TRANSPOSE, matrix->get_size()[0], + matrix->get_size()[1], one(), factor_descr, + matrix->get_const_values(), matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), solve_info, + trans_in->get_values(), trans_in->get_stride(), + trans_out->get_values(), trans_out->get_stride()); + dense::transpose(exec, trans_out, out); + } + } - ~SolveStruct() + ~CudaSolveStruct() { - cusparseDestroyMatDescr(factor_descr); - cusparseDestroySolveAnalysisInfo(solve_info); + if (factor_descr) { + cusparse::destroy(factor_descr); + factor_descr = nullptr; + } + if (solve_info) { + cusparse::destroy(solve_info); + solve_info = nullptr; + } } -}; + CudaSolveStruct(const CudaSolveStruct &) = delete; -#endif + CudaSolveStruct(CudaSolveStruct &&) = delete; + CudaSolveStruct &operator=(const CudaSolveStruct &) = delete; -} // namespace cuda -} // namespace solver + CudaSolveStruct &operator=(CudaSolveStruct &&) = delete; +}; -namespace kernels { -namespace cuda { -namespace { +#endif void should_perform_transpose_kernel(std::shared_ptr exec, bool &do_transpose) { -#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020)) - +#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11030) do_transpose = false; +#elif (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020)) -#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020)) + do_transpose = false; +#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020)) do_transpose = true; - #endif } -void init_struct_kernel(std::shared_ptr exec, - std::shared_ptr &solve_struct) -{ - solve_struct = std::make_shared(); -} - - template void generate_kernel(std::shared_ptr exec, const matrix::Csr *matrix, - solver::SolveStruct *solve_struct, + std::shared_ptr &solve_struct, const gko::size_type num_rhs, bool is_upper) { if (cusparse::is_supported::value) { - if (auto cuda_solve_struct = - dynamic_cast(solve_struct)) { - auto handle = exec->get_cusparse_handle(); - if (is_upper) { - GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSetMatFillMode( - cuda_solve_struct->factor_descr, CUSPARSE_FILL_MODE_UPPER)); - } - - -#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020)) - - - ValueType one = 1.0; - - { - cusparse::pointer_mode_guard pm_guard(handle); - cusparse::buffer_size_ext( - handle, cuda_solve_struct->algorithm, - CUSPARSE_OPERATION_NON_TRANSPOSE, - CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], - num_rhs, matrix->get_num_stored_elements(), &one, - cuda_solve_struct->factor_descr, matrix->get_const_values(), - matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), - nullptr, num_rhs, cuda_solve_struct->solve_info, - cuda_solve_struct->policy, - &cuda_solve_struct->factor_work_size); - - // allocate workspace - if (cuda_solve_struct->factor_work_vec != nullptr) { - exec->free(cuda_solve_struct->factor_work_vec); - } - cuda_solve_struct->factor_work_vec = - exec->alloc(cuda_solve_struct->factor_work_size); - - cusparse::csrsm2_analysis( - handle, cuda_solve_struct->algorithm, - CUSPARSE_OPERATION_NON_TRANSPOSE, - CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], - num_rhs, matrix->get_num_stored_elements(), &one, - cuda_solve_struct->factor_descr, matrix->get_const_values(), - matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), - nullptr, num_rhs, cuda_solve_struct->solve_info, - cuda_solve_struct->policy, - cuda_solve_struct->factor_work_vec); - } - - -#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020)) - - - { - cusparse::pointer_mode_guard pm_guard(handle); - cusparse::csrsm_analysis( - handle, CUSPARSE_OPERATION_NON_TRANSPOSE, - matrix->get_size()[0], matrix->get_num_stored_elements(), - cuda_solve_struct->factor_descr, matrix->get_const_values(), - matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), - cuda_solve_struct->solve_info); - } - - -#endif - - - } else { - GKO_NOT_SUPPORTED(solve_struct); - } + solve_struct = std::make_shared>( + exec, matrix, num_rhs, is_upper); } else { GKO_NOT_IMPLEMENTED; } @@ -279,66 +391,9 @@ void solve_kernel(std::shared_ptr exec, if (cusparse::is_supported::value) { if (auto cuda_solve_struct = - dynamic_cast(solve_struct)) { - ValueType one = 1.0; - auto handle = exec->get_cusparse_handle(); - - -#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020)) - - - x->copy_from(gko::lend(b)); - { - cusparse::pointer_mode_guard pm_guard(handle); - cusparse::csrsm2_solve( - handle, cuda_solve_struct->algorithm, - CUSPARSE_OPERATION_NON_TRANSPOSE, - CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], - b->get_stride(), matrix->get_num_stored_elements(), &one, - cuda_solve_struct->factor_descr, matrix->get_const_values(), - matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), - x->get_values(), b->get_stride(), - cuda_solve_struct->solve_info, cuda_solve_struct->policy, - cuda_solve_struct->factor_work_vec); - } - - -#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020)) - - - { - cusparse::pointer_mode_guard pm_guard(handle); - if (b->get_stride() == 1) { - cusparse::csrsm_solve( - handle, CUSPARSE_OPERATION_NON_TRANSPOSE, - matrix->get_size()[0], b->get_stride(), &one, - cuda_solve_struct->factor_descr, - matrix->get_const_values(), - matrix->get_const_row_ptrs(), - matrix->get_const_col_idxs(), - cuda_solve_struct->solve_info, b->get_const_values(), - b->get_size()[0], x->get_values(), x->get_size()[0]); - } else { - dense::transpose(exec, b, trans_b); - dense::transpose(exec, x, trans_x); - cusparse::csrsm_solve( - handle, CUSPARSE_OPERATION_NON_TRANSPOSE, - matrix->get_size()[0], trans_b->get_size()[0], &one, - cuda_solve_struct->factor_descr, - matrix->get_const_values(), - matrix->get_const_row_ptrs(), - matrix->get_const_col_idxs(), - cuda_solve_struct->solve_info, trans_b->get_values(), - trans_b->get_size()[1], trans_x->get_values(), - trans_x->get_size()[1]); - dense::transpose(exec, trans_x, x); - } - } - - -#endif - - + dynamic_cast *>( + solve_struct)) { + cuda_solve_struct->solve(matrix, b, x, trans_b, trans_x); } else { GKO_NOT_SUPPORTED(solve_struct); } diff --git a/cuda/solver/lower_trs_kernels.cu b/cuda/solver/lower_trs_kernels.cu index d83e7218d72..bf44ff7bf6a 100644 --- a/cuda/solver/lower_trs_kernels.cu +++ b/cuda/solver/lower_trs_kernels.cu @@ -255,17 +255,11 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr &solve_struct) -{ - init_struct_kernel(exec, solve_struct); -} - - template void generate(std::shared_ptr exec, const matrix::Csr *matrix, - solver::SolveStruct *solve_struct, const gko::size_type num_rhs) + std::shared_ptr &solve_struct, + const gko::size_type num_rhs) { if (matrix->get_strategy()->get_name() == "sparselib") { generate_kernel(exec, matrix, solve_struct, diff --git a/cuda/solver/upper_trs_kernels.cu b/cuda/solver/upper_trs_kernels.cu index 361a0738d0d..b67a621b7f4 100644 --- a/cuda/solver/upper_trs_kernels.cu +++ b/cuda/solver/upper_trs_kernels.cu @@ -69,17 +69,11 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr &solve_struct) -{ - init_struct_kernel(exec, solve_struct); -} - - template void generate(std::shared_ptr exec, const matrix::Csr *matrix, - solver::SolveStruct *solve_struct, const gko::size_type num_rhs) + std::shared_ptr &solve_struct, + const gko::size_type num_rhs) { generate_kernel(exec, matrix, solve_struct, num_rhs, true); diff --git a/dpcpp/solver/lower_trs_kernels.dp.cpp b/dpcpp/solver/lower_trs_kernels.dp.cpp index 7144108593f..eb8b4338f84 100644 --- a/dpcpp/solver/lower_trs_kernels.dp.cpp +++ b/dpcpp/solver/lower_trs_kernels.dp.cpp @@ -59,27 +59,11 @@ namespace dpcpp { namespace lower_trs { -void should_perform_transpose(std::shared_ptr exec, - bool &do_transpose) GKO_NOT_IMPLEMENTED; - - -void init_struct(std::shared_ptr exec, - std::shared_ptr &solve_struct) -{ - // This init kernel is here to allow initialization of the solve struct for - // a more sophisticated implementation as for other executors. -} - - template void generate(std::shared_ptr exec, const matrix::Csr *matrix, - solver::SolveStruct *solve_struct, const gko::size_type num_rhs) -{ - // This generate kernel is here to allow for a more sophisticated - // implementation as for other executors. This kernel would perform the - // "analysis" phase for the triangular matrix. -} + std::shared_ptr &solve_struct, + const gko::size_type num_rhs) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL); diff --git a/dpcpp/solver/upper_trs_kernels.dp.cpp b/dpcpp/solver/upper_trs_kernels.dp.cpp index cc1d40f711d..675776ce562 100644 --- a/dpcpp/solver/upper_trs_kernels.dp.cpp +++ b/dpcpp/solver/upper_trs_kernels.dp.cpp @@ -66,23 +66,11 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr &solve_struct) -{ - // This init kernel is here to allow initialization of the solve struct for - // a more sophisticated implementation as for other executors. -} - - template void generate(std::shared_ptr exec, const matrix::Csr *matrix, - solver::SolveStruct *solve_struct, const gko::size_type num_rhs) -{ - // This generate kernel is here to allow for a more sophisticated - // implementation as for other executors. This kernel would perform the - // "analysis" phase for the triangular matrix. -} + std::shared_ptr &solve_struct, + const gko::size_type num_rhs) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL); diff --git a/hip/solver/lower_trs_kernels.hip.cpp b/hip/solver/lower_trs_kernels.hip.cpp index 3eeb50185ac..6f638fbb9c6 100644 --- a/hip/solver/lower_trs_kernels.hip.cpp +++ b/hip/solver/lower_trs_kernels.hip.cpp @@ -71,18 +71,18 @@ void should_perform_transpose(std::shared_ptr exec, void init_struct(std::shared_ptr exec, std::shared_ptr &solve_struct) -{ - init_struct_kernel(exec, solve_struct); -} +{} template void generate(std::shared_ptr exec, const matrix::Csr *matrix, - solver::SolveStruct *solve_struct, const gko::size_type num_rhs) + std::shared_ptr &solve_struct, + const gko::size_type num_rhs) { - generate_kernel(exec, matrix, solve_struct, num_rhs, - false); + init_struct_kernel(exec, solve_struct); + generate_kernel(exec, matrix, solve_struct.get(), + num_rhs, false); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( diff --git a/hip/solver/upper_trs_kernels.hip.cpp b/hip/solver/upper_trs_kernels.hip.cpp index 835e2f3803c..74f643b69cf 100644 --- a/hip/solver/upper_trs_kernels.hip.cpp +++ b/hip/solver/upper_trs_kernels.hip.cpp @@ -69,20 +69,15 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr &solve_struct) -{ - init_struct_kernel(exec, solve_struct); -} - - template void generate(std::shared_ptr exec, const matrix::Csr *matrix, - solver::SolveStruct *solve_struct, const gko::size_type num_rhs) + std::shared_ptr &solve_struct, + const gko::size_type num_rhs) { - generate_kernel(exec, matrix, solve_struct, num_rhs, - true); + init_struct_kernel(exec, solve_struct); + generate_kernel(exec, matrix, solve_struct.get(), + num_rhs, true); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( diff --git a/include/ginkgo/core/solver/lower_trs.hpp b/include/ginkgo/core/solver/lower_trs.hpp index c3b95702fd0..00c160b6619 100644 --- a/include/ginkgo/core/solver/lower_trs.hpp +++ b/include/ginkgo/core/solver/lower_trs.hpp @@ -123,8 +123,6 @@ class LowerTrs : public EnableLinOp>, GKO_ENABLE_BUILD_METHOD(Factory); protected: - void init_trs_solve_struct(); - void apply_impl(const LinOp *b, LinOp *x) const override; void apply_impl(const LinOp *alpha, const LinOp *b, const LinOp *beta, @@ -159,7 +157,6 @@ class LowerTrs : public EnableLinOp>, system_matrix_ = copy_and_convert_to(exec, system_matrix); } - this->init_trs_solve_struct(); this->generate(); } diff --git a/include/ginkgo/core/solver/upper_trs.hpp b/include/ginkgo/core/solver/upper_trs.hpp index dd82b5df6e4..a275a60a8b1 100644 --- a/include/ginkgo/core/solver/upper_trs.hpp +++ b/include/ginkgo/core/solver/upper_trs.hpp @@ -123,8 +123,6 @@ class UpperTrs : public EnableLinOp>, GKO_ENABLE_BUILD_METHOD(Factory); protected: - void init_trs_solve_struct(); - void apply_impl(const LinOp *b, LinOp *x) const override; void apply_impl(const LinOp *alpha, const LinOp *b, const LinOp *beta, @@ -159,7 +157,6 @@ class UpperTrs : public EnableLinOp>, system_matrix_ = copy_and_convert_to(exec, system_matrix); } - this->init_trs_solve_struct(); this->generate(); } diff --git a/omp/solver/lower_trs_kernels.cpp b/omp/solver/lower_trs_kernels.cpp index b313f762d2f..2e2bbe72df2 100644 --- a/omp/solver/lower_trs_kernels.cpp +++ b/omp/solver/lower_trs_kernels.cpp @@ -66,18 +66,11 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr &solve_struct) -{ - // This init kernel is here to allow initialization of the solve struct for - // a more sophisticated implementation as for other executors. -} - - template void generate(std::shared_ptr exec, const matrix::Csr *matrix, - solver::SolveStruct *solve_struct, const gko::size_type num_rhs) + std::shared_ptr &solve_struct, + const gko::size_type num_rhs) { // This generate kernel is here to allow for a more sophisticated // implementation as for other executors. This kernel would perform the diff --git a/omp/solver/upper_trs_kernels.cpp b/omp/solver/upper_trs_kernels.cpp index 29ecfda6c8c..361cad0f348 100644 --- a/omp/solver/upper_trs_kernels.cpp +++ b/omp/solver/upper_trs_kernels.cpp @@ -66,18 +66,11 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr &solve_struct) -{ - // This init kernel is here to allow initialization of the solve struct for - // a more sophisticated implementation as for other executors. -} - - template void generate(std::shared_ptr exec, const matrix::Csr *matrix, - solver::SolveStruct *solve_struct, const gko::size_type num_rhs) + std::shared_ptr &solve_struct, + const gko::size_type num_rhs) { // This generate kernel is here to allow for a more sophisticated // implementation as for other executors. This kernel would perform the diff --git a/omp/test/solver/lower_trs_kernels.cpp b/omp/test/solver/lower_trs_kernels.cpp index 1a15fbe5647..eca1e1ea0bf 100644 --- a/omp/test/solver/lower_trs_kernels.cpp +++ b/omp/test/solver/lower_trs_kernels.cpp @@ -139,20 +139,13 @@ TEST_F(LowerTrs, OmpLowerTrsFlagCheckIsCorrect) } -TEST_F(LowerTrs, OmpLowerTrsSolveStructInitIsEquivalentToRef) -{ - gko::kernels::reference::lower_trs::init_struct(ref, solve_struct_ref); - gko::kernels::omp::lower_trs::init_struct(omp, solve_struct_omp); -} - - TEST_F(LowerTrs, OmpLowerTrsGenerateIsEquivalentToRef) { gko::size_type num_rhs = 1; - gko::kernels::reference::lower_trs::generate( - ref, csr_mat.get(), solve_struct_ref.get(), num_rhs); + gko::kernels::reference::lower_trs::generate(ref, csr_mat.get(), + solve_struct_ref, num_rhs); gko::kernels::omp::lower_trs::generate(omp, d_csr_mat.get(), - solve_struct_omp.get(), num_rhs); + solve_struct_omp, num_rhs); } @@ -160,8 +153,6 @@ TEST_F(LowerTrs, OmpLowerTrsSolveIsEquivalentToRef) { initialize_data(59, 43); - gko::kernels::reference::lower_trs::init_struct(ref, solve_struct_ref); - gko::kernels::omp::lower_trs::init_struct(omp, solve_struct_omp); gko::kernels::reference::lower_trs::solve(ref, csr_mat.get(), solve_struct_ref.get(), t_b.get(), t_x.get(), b.get(), x.get()); diff --git a/omp/test/solver/upper_trs_kernels.cpp b/omp/test/solver/upper_trs_kernels.cpp index d3381167f43..53d5822f617 100644 --- a/omp/test/solver/upper_trs_kernels.cpp +++ b/omp/test/solver/upper_trs_kernels.cpp @@ -138,20 +138,13 @@ TEST_F(UpperTrs, OmpUpperTrsFlagCheckIsCorrect) } -TEST_F(UpperTrs, OmpUpperTrsSolveStructInitIsEquivalentToRef) -{ - gko::kernels::reference::upper_trs::init_struct(ref, solve_struct_ref); - gko::kernels::omp::upper_trs::init_struct(omp, solve_struct_omp); -} - - TEST_F(UpperTrs, OmpUpperTrsGenerateIsEquivalentToRef) { gko::size_type num_rhs = 1; - gko::kernels::reference::upper_trs::generate( - ref, csr_mat.get(), solve_struct_ref.get(), num_rhs); + gko::kernels::reference::upper_trs::generate(ref, csr_mat.get(), + solve_struct_ref, num_rhs); gko::kernels::omp::upper_trs::generate(omp, d_csr_mat.get(), - solve_struct_omp.get(), num_rhs); + solve_struct_omp, num_rhs); } @@ -159,8 +152,6 @@ TEST_F(UpperTrs, OmpUpperTrsSolveIsEquivalentToRef) { initialize_data(59, 43); - gko::kernels::reference::upper_trs::init_struct(ref, solve_struct_ref); - gko::kernels::omp::upper_trs::init_struct(omp, solve_struct_omp); gko::kernels::reference::upper_trs::solve(ref, csr_mat.get(), solve_struct_ref.get(), t_b.get(), t_x.get(), b.get(), x.get()); diff --git a/reference/solver/lower_trs_kernels.cpp b/reference/solver/lower_trs_kernels.cpp index 8247397c4ba..c30f5391d4d 100644 --- a/reference/solver/lower_trs_kernels.cpp +++ b/reference/solver/lower_trs_kernels.cpp @@ -62,18 +62,11 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr &solve_struct) -{ - // This init kernel is here to allow initialization of the solve struct for - // a more sophisticated implementation as for other executors. -} - - template void generate(std::shared_ptr exec, const matrix::Csr *matrix, - solver::SolveStruct *solve_struct, const gko::size_type num_rhs) + std::shared_ptr &solve_struct, + const gko::size_type num_rhs) { // This generate kernel is here to allow for a more sophisticated // implementation as for other executors. This kernel would perform the diff --git a/reference/solver/upper_trs_kernels.cpp b/reference/solver/upper_trs_kernels.cpp index 7c938d505ab..a36e3d8ae78 100644 --- a/reference/solver/upper_trs_kernels.cpp +++ b/reference/solver/upper_trs_kernels.cpp @@ -62,18 +62,11 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr &solve_struct) -{ - // This init kernel is here to allow initialization of the solve struct for - // a more sophisticated implementation as for other executors. -} - - template void generate(std::shared_ptr exec, const matrix::Csr *matrix, - solver::SolveStruct *solve_struct, const gko::size_type num_rhs) + std::shared_ptr &solve_struct, + const gko::size_type num_rhs) { // This generate kernel is here to allow for a more sophisticated // implementation as for other executors. This kernel would perform the