diff --git a/core/device_hooks/common_kernels.inc.cpp b/core/device_hooks/common_kernels.inc.cpp index bbd9a43161a..a06eac4bab6 100644 --- a/core/device_hooks/common_kernels.inc.cpp +++ b/core/device_hooks/common_kernels.inc.cpp @@ -318,7 +318,6 @@ namespace lower_trs { GKO_STUB(GKO_DECLARE_LOWER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL); -GKO_STUB(GKO_DECLARE_LOWER_TRS_INIT_STRUCT_KERNEL); GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL); GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_LOWER_TRS_SOLVE_KERNEL); @@ -330,7 +329,6 @@ namespace upper_trs { GKO_STUB(GKO_DECLARE_UPPER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL); -GKO_STUB(GKO_DECLARE_UPPER_TRS_INIT_STRUCT_KERNEL); GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL); GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_UPPER_TRS_SOLVE_KERNEL); diff --git a/core/solver/lower_trs.cpp b/core/solver/lower_trs.cpp index 7ae9e1b4858..b17840477e5 100644 --- a/core/solver/lower_trs.cpp +++ b/core/solver/lower_trs.cpp @@ -55,7 +55,6 @@ namespace { GKO_REGISTER_OPERATION(generate, lower_trs::generate); -GKO_REGISTER_OPERATION(init_struct, lower_trs::init_struct); GKO_REGISTER_OPERATION(should_perform_transpose, lower_trs::should_perform_transpose); GKO_REGISTER_OPERATION(solve, lower_trs::solve); @@ -85,19 +84,11 @@ std::unique_ptr LowerTrs::conj_transpose() const } -template -void LowerTrs::init_trs_solve_struct() -{ - this->get_executor()->run(lower_trs::make_init_struct(this->solve_struct_)); -} - - template void LowerTrs::generate() { this->get_executor()->run(lower_trs::make_generate( - gko::lend(system_matrix_), gko::lend(this->solve_struct_), - parameters_.num_rhs)); + gko::lend(system_matrix_), this->solve_struct_, parameters_.num_rhs)); } diff --git a/core/solver/lower_trs_kernels.hpp b/core/solver/lower_trs_kernels.hpp index 1aeb09918ab..78406c274dd 100644 --- a/core/solver/lower_trs_kernels.hpp +++ b/core/solver/lower_trs_kernels.hpp @@ -58,15 +58,10 @@ namespace lower_trs { bool& do_transpose) -#define GKO_DECLARE_LOWER_TRS_INIT_STRUCT_KERNEL \ - void init_struct(std::shared_ptr exec, \ - std::shared_ptr& solve_struct) - - -#define GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(_vtype, _itype) \ - void generate(std::shared_ptr exec, \ - const matrix::Csr<_vtype, _itype>* matrix, \ - solver::SolveStruct* solve_struct, \ +#define GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(_vtype, _itype) \ + void generate(std::shared_ptr exec, \ + const matrix::Csr<_vtype, _itype>* matrix, \ + std::shared_ptr& solve_struct, \ const gko::size_type num_rhs) @@ -80,7 +75,6 @@ namespace lower_trs { #define GKO_DECLARE_ALL_AS_TEMPLATES \ GKO_DECLARE_LOWER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL; \ - GKO_DECLARE_LOWER_TRS_INIT_STRUCT_KERNEL; \ template \ GKO_DECLARE_LOWER_TRS_SOLVE_KERNEL(ValueType, IndexType); \ template \ diff --git a/core/solver/upper_trs.cpp b/core/solver/upper_trs.cpp index d109d2cbd65..7a26596b1c9 100644 --- a/core/solver/upper_trs.cpp +++ b/core/solver/upper_trs.cpp @@ -55,7 +55,6 @@ namespace { GKO_REGISTER_OPERATION(generate, upper_trs::generate); -GKO_REGISTER_OPERATION(init_struct, upper_trs::init_struct); GKO_REGISTER_OPERATION(should_perform_transpose, upper_trs::should_perform_transpose); GKO_REGISTER_OPERATION(solve, upper_trs::solve); @@ -85,19 +84,11 @@ std::unique_ptr UpperTrs::conj_transpose() const } -template -void UpperTrs::init_trs_solve_struct() -{ - this->get_executor()->run(upper_trs::make_init_struct(this->solve_struct_)); -} - - template void UpperTrs::generate() { this->get_executor()->run(upper_trs::make_generate( - gko::lend(system_matrix_), gko::lend(this->solve_struct_), - parameters_.num_rhs)); + gko::lend(system_matrix_), this->solve_struct_, parameters_.num_rhs)); } diff --git a/core/solver/upper_trs_kernels.hpp b/core/solver/upper_trs_kernels.hpp index ba35eb6373b..8dfa65e698c 100644 --- a/core/solver/upper_trs_kernels.hpp +++ b/core/solver/upper_trs_kernels.hpp @@ -58,15 +58,10 @@ namespace upper_trs { bool& do_transpose) -#define GKO_DECLARE_UPPER_TRS_INIT_STRUCT_KERNEL \ - void init_struct(std::shared_ptr exec, \ - std::shared_ptr& solve_struct) - - -#define GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL(_vtype, _itype) \ - void generate(std::shared_ptr exec, \ - const matrix::Csr<_vtype, _itype>* matrix, \ - solver::SolveStruct* solve_struct, \ +#define GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL(_vtype, _itype) \ + void generate(std::shared_ptr exec, \ + const matrix::Csr<_vtype, _itype>* matrix, \ + std::shared_ptr& solve_struct, \ const gko::size_type num_rhs) @@ -80,7 +75,6 @@ namespace upper_trs { #define GKO_DECLARE_ALL_AS_TEMPLATES \ GKO_DECLARE_UPPER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL; \ - GKO_DECLARE_UPPER_TRS_INIT_STRUCT_KERNEL; \ template \ GKO_DECLARE_UPPER_TRS_SOLVE_KERNEL(ValueType, IndexType); \ template \ diff --git a/cuda/base/cusparse_bindings.hpp b/cuda/base/cusparse_bindings.hpp index 5bfb8629240..cd43bccb60f 100644 --- a/cuda/base/cusparse_bindings.hpp +++ b/cuda/base/cusparse_bindings.hpp @@ -559,9 +559,9 @@ void spgemm_copy(cusparseHandle_t handle, const ValueType* alpha, inline size_type sparse_matrix_nnz(cusparseSpMatDescr_t descr) { - int64_t dummy1{}; - int64_t dummy2{}; - int64_t nnz{}; + int64 dummy1{}; + int64 dummy2{}; + int64 nnz{}; cusparseSpMatGetSize(descr, &dummy1, &dummy2, &nnz); return static_cast(nnz); } @@ -745,10 +745,23 @@ inline cusparseMatDescr_t create_mat_descr() { cusparseMatDescr_t descr{}; GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseCreateMatDescr(&descr)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetMatDiagType(descr, CUSPARSE_DIAG_TYPE_NON_UNIT)); return descr; } +inline void set_mat_fill_mode(cusparseMatDescr_t descr, + cusparseFillMode_t fill_mode) +{ + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSetMatFillMode(descr, fill_mode)); +} + + inline void destroy(cusparseMatDescr_t descr) { GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseDestroyMatDescr(descr)); @@ -790,7 +803,7 @@ inline void destroy(cusparseSpGEMMDescr_t info) template -inline cusparseDnVecDescr_t create_dnvec(int64_t size, ValueType* values) +inline cusparseDnVecDescr_t create_dnvec(int64 size, ValueType* values) { cusparseDnVecDescr_t descr{}; constexpr auto value_type = cuda_data_type(); @@ -806,8 +819,26 @@ inline void destroy(cusparseDnVecDescr_t descr) } +template +inline cusparseDnMatDescr_t create_dnmat(int64 rows, int64 cols, int64 stride, + ValueType* values) +{ + cusparseDnMatDescr_t descr{}; + constexpr auto value_type = cuda_data_type(); + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseCreateDnMat( + &descr, rows, cols, stride, values, value_type, CUSPARSE_ORDER_ROW)); + return descr; +} + + +inline void destroy(cusparseDnMatDescr_t descr) +{ + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseDestroyDnMat(descr)); +} + + template -inline cusparseSpVecDescr_t create_spvec(int64_t size, int64_t nnz, +inline cusparseSpVecDescr_t create_spvec(int64 size, int64 nnz, IndexType* indices, ValueType* values) { cusparseSpVecDescr_t descr{}; @@ -827,7 +858,7 @@ inline void destroy(cusparseSpVecDescr_t descr) template -inline cusparseSpMatDescr_t create_csr(int64_t rows, int64_t cols, int64_t nnz, +inline cusparseSpMatDescr_t create_csr(int64 rows, int64 cols, int64 nnz, IndexType* csrRowOffsets, IndexType* csrColInd, ValueType* csrValues) @@ -847,7 +878,37 @@ inline void destroy(cusparseSpMatDescr_t descr) GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseDestroySpMat(descr)); } -#endif + +#if (CUDA_VERSION >= 11030) + + +template +inline void set_attribute(cusparseSpMatDescr_t desc, + cusparseSpMatAttribute_t attr, AttribType val) +{ + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSpMatSetAttribute(desc, attr, &val, sizeof(val))); +} + + +inline cusparseSpSMDescr_t create_spsm_descr() +{ + cusparseSpSMDescr_t desc{}; + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSpSM_createDescr(&desc)); + return desc; +} + + +inline void destroy(cusparseSpSMDescr_t info) +{ + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSpSM_destroyDescr(info)); +} + + +#endif // CUDA_VERSION >= 11030 + + +#endif // defined(CUDA_VERSION) && (CUDA_VERSION >= 11000) inline csrsm2Info_t create_solve_info() @@ -896,7 +957,7 @@ inline void destroy(csric02Info_t info) inline void buffer_size_ext( \ cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ - const ValueType* one, const cusparseMatDescr_t descr, \ + ValueType one, const cusparseMatDescr_t descr, \ const ValueType* csrVal, const int32* csrRowPtr, \ const int32* csrColInd, const ValueType* rhs, int32 sol_size, \ csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ @@ -904,7 +965,7 @@ inline void destroy(csric02Info_t info) { \ GKO_ASSERT_NO_CUSPARSE_ERRORS( \ CusparseName(handle, algo, trans1, trans2, m, n, nnz, \ - as_culibs_type(one), descr, as_culibs_type(csrVal), \ + as_culibs_type(&one), descr, as_culibs_type(csrVal), \ csrRowPtr, csrColInd, as_culibs_type(rhs), sol_size, \ factor_info, policy, factor_work_size)); \ } \ @@ -916,7 +977,7 @@ inline void destroy(csric02Info_t info) inline void buffer_size_ext( \ cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ - const ValueType* one, const cusparseMatDescr_t descr, \ + ValueType one, const cusparseMatDescr_t descr, \ const ValueType* csrVal, const int64* csrRowPtr, \ const int64* csrColInd, const ValueType* rhs, int64 sol_size, \ csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ @@ -949,7 +1010,7 @@ GKO_BIND_CUSPARSE64_BUFFERSIZEEXT(ValueType, detail::not_implemented); inline void csrsm2_analysis( \ cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ - const ValueType* one, const cusparseMatDescr_t descr, \ + ValueType one, const cusparseMatDescr_t descr, \ const ValueType* csrVal, const int32* csrRowPtr, \ const int32* csrColInd, const ValueType* rhs, int32 sol_size, \ csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ @@ -957,7 +1018,7 @@ GKO_BIND_CUSPARSE64_BUFFERSIZEEXT(ValueType, detail::not_implemented); { \ GKO_ASSERT_NO_CUSPARSE_ERRORS( \ CusparseName(handle, algo, trans1, trans2, m, n, nnz, \ - as_culibs_type(one), descr, as_culibs_type(csrVal), \ + as_culibs_type(&one), descr, as_culibs_type(csrVal), \ csrRowPtr, csrColInd, as_culibs_type(rhs), sol_size, \ factor_info, policy, factor_work_vec)); \ } \ @@ -969,7 +1030,7 @@ GKO_BIND_CUSPARSE64_BUFFERSIZEEXT(ValueType, detail::not_implemented); inline void csrsm2_analysis( \ cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ - const ValueType* one, const cusparseMatDescr_t descr, \ + ValueType one, const cusparseMatDescr_t descr, \ const ValueType* csrVal, const int64* csrRowPtr, \ const int64* csrColInd, const ValueType* rhs, int64 sol_size, \ csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ @@ -998,31 +1059,31 @@ GKO_BIND_CUSPARSE64_CSRSM2_ANALYSIS(ValueType, detail::not_implemented); #undef GKO_BIND_CUSPARSE64_CSRSM2_ANALYSIS -#define GKO_BIND_CUSPARSE32_CSRSM2_SOLVE(ValueType, CusparseName) \ - inline void csrsm2_solve( \ - cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ - cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ - const ValueType* one, const cusparseMatDescr_t descr, \ - const ValueType* csrVal, const int32* csrRowPtr, \ - const int32* csrColInd, ValueType* rhs, int32 sol_stride, \ - csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ - void* factor_work_vec) \ - { \ - GKO_ASSERT_NO_CUSPARSE_ERRORS( \ - CusparseName(handle, algo, trans1, trans2, m, n, nnz, \ - as_culibs_type(one), descr, as_culibs_type(csrVal), \ - csrRowPtr, csrColInd, as_culibs_type(rhs), \ - sol_stride, factor_info, policy, factor_work_vec)); \ - } \ - static_assert(true, \ - "This assert is used to counter the false positive extra " \ +#define GKO_BIND_CUSPARSE32_CSRSM2_SOLVE(ValueType, CusparseName) \ + inline void csrsm2_solve( \ + cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ + cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ + ValueType one, const cusparseMatDescr_t descr, \ + const ValueType* csrVal, const int32* csrRowPtr, \ + const int32* csrColInd, ValueType* rhs, int32 sol_stride, \ + csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ + void* factor_work_vec) \ + { \ + GKO_ASSERT_NO_CUSPARSE_ERRORS( \ + CusparseName(handle, algo, trans1, trans2, m, n, nnz, \ + as_culibs_type(&one), descr, as_culibs_type(csrVal), \ + csrRowPtr, csrColInd, as_culibs_type(rhs), \ + sol_stride, factor_info, policy, factor_work_vec)); \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ "semi-colon warnings") #define GKO_BIND_CUSPARSE64_CSRSM2_SOLVE(ValueType, CusparseName) \ inline void csrsm2_solve( \ cusparseHandle_t handle, int algo, cusparseOperation_t trans1, \ cusparseOperation_t trans2, size_type m, size_type n, size_type nnz, \ - const ValueType* one, const cusparseMatDescr_t descr, \ + ValueType one, const cusparseMatDescr_t descr, \ const ValueType* csrVal, const int64* csrRowPtr, \ const int64* csrColInd, ValueType* rhs, int64 sol_stride, \ csrsm2Info_t factor_info, cusparseSolvePolicy_t policy, \ @@ -1047,6 +1108,54 @@ GKO_BIND_CUSPARSE64_CSRSM2_SOLVE(ValueType, detail::not_implemented); #undef GKO_BIND_CUSPARSE64_CSRSM2_SOLVE +#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11030)) + + +template +size_type spsm_buffer_size(cusparseHandle_t handle, cusparseOperation_t op_a, + cusparseOperation_t op_b, ValueType alpha, + cusparseSpMatDescr_t descr_a, + cusparseDnMatDescr_t descr_b, + cusparseDnMatDescr_t descr_c, cusparseSpSMAlg_t algo, + cusparseSpSMDescr_t spsm_descr) +{ + size_type work_size; + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSpSM_bufferSize( + handle, op_a, op_b, &alpha, descr_a, descr_b, descr_c, + cuda_data_type(), algo, spsm_descr, &work_size)); + return work_size; +} + + +template +void spsm_analysis(cusparseHandle_t handle, cusparseOperation_t op_a, + cusparseOperation_t op_b, ValueType alpha, + cusparseSpMatDescr_t descr_a, cusparseDnMatDescr_t descr_b, + cusparseDnMatDescr_t descr_c, cusparseSpSMAlg_t algo, + cusparseSpSMDescr_t spsm_descr, void* work) +{ + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSpSM_analysis( + handle, op_a, op_b, &alpha, descr_a, descr_b, descr_c, + cuda_data_type(), algo, spsm_descr, work)); +} + + +template +void spsm_solve(cusparseHandle_t handle, cusparseOperation_t op_a, + cusparseOperation_t op_b, ValueType alpha, + cusparseSpMatDescr_t descr_a, cusparseDnMatDescr_t descr_b, + cusparseDnMatDescr_t descr_c, cusparseSpSMAlg_t algo, + cusparseSpSMDescr_t spsm_descr) +{ + GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSpSM_solve( + handle, op_a, op_b, &alpha, descr_a, descr_b, descr_c, + cuda_data_type(), algo, spsm_descr)); +} + + +#endif // (defined(CUDA_VERSION) && (CUDA_VERSION >= 11030)) + + template void create_identity_permutation(cusparseHandle_t handle, IndexType size, IndexType* permutation) GKO_NOT_IMPLEMENTED; diff --git a/cuda/base/pointer_mode_guard.hpp b/cuda/base/pointer_mode_guard.hpp index 34923d21abf..f9ed73f4ba6 100644 --- a/cuda/base/pointer_mode_guard.hpp +++ b/cuda/base/pointer_mode_guard.hpp @@ -108,9 +108,9 @@ namespace cusparse { */ class pointer_mode_guard { public: - pointer_mode_guard(cusparseHandle_t& handle) + pointer_mode_guard(cusparseHandle_t handle) { - l_handle = &handle; + l_handle = handle; GKO_ASSERT_NO_CUSPARSE_ERRORS( cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST)); } @@ -127,15 +127,15 @@ class pointer_mode_guard { { /* Ignore the error during stack unwinding for this call */ if (std::uncaught_exception()) { - cusparseSetPointerMode(*l_handle, CUSPARSE_POINTER_MODE_DEVICE); + cusparseSetPointerMode(l_handle, CUSPARSE_POINTER_MODE_DEVICE); } else { - GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSetPointerMode( - *l_handle, CUSPARSE_POINTER_MODE_DEVICE)); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseSetPointerMode(l_handle, CUSPARSE_POINTER_MODE_DEVICE)); } } private: - cusparseHandle_t* l_handle; + cusparseHandle_t l_handle; }; diff --git a/cuda/solver/common_trs_kernels.cuh b/cuda/solver/common_trs_kernels.cuh index 2c01e6a76f5..8c2cb010d6c 100644 --- a/cuda/solver/common_trs_kernels.cuh +++ b/cuda/solver/common_trs_kernels.cuh @@ -35,6 +35,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include #include @@ -52,6 +53,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "cuda/base/math.hpp" #include "cuda/base/pointer_mode_guard.hpp" #include "cuda/base/types.hpp" +#include "cuda/components/atomic.cuh" +#include "cuda/components/thread_ids.cuh" +#include "cuda/components/uninitialized_array.hpp" namespace gko { @@ -59,64 +63,217 @@ namespace solver { struct SolveStruct { - virtual ~SolveStruct() {} + virtual ~SolveStruct() = default; }; +} // namespace solver + + +namespace kernels { namespace cuda { +namespace { + +#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11030)) -struct SolveStruct : gko::solver::SolveStruct { + +template +struct CudaSolveStruct : gko::solver::SolveStruct { + cusparseHandle_t handle; + cusparseSpSMDescr_t spsm_descr; + cusparseSpMatDescr_t descr_a; + + // Implicit parameter in spsm_solve, therefore stored here. + Array work; + + CudaSolveStruct(std::shared_ptr exec, + const matrix::Csr* matrix, + size_type num_rhs, bool is_upper) + : handle{exec->get_cusparse_handle()}, + spsm_descr{}, + descr_a{}, + work{exec} + { + cusparse::pointer_mode_guard pm_guard(handle); + spsm_descr = cusparse::create_spsm_descr(); + descr_a = cusparse::create_csr( + matrix->get_size()[0], matrix->get_size()[1], + matrix->get_num_stored_elements(), + const_cast(matrix->get_const_row_ptrs()), + const_cast(matrix->get_const_col_idxs()), + const_cast(matrix->get_const_values())); + cusparse::set_attribute( + descr_a, CUSPARSE_SPMAT_FILL_MODE, + is_upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER); + cusparse::set_attribute( + descr_a, CUSPARSE_SPMAT_DIAG_TYPE, CUSPARSE_DIAG_TYPE_NON_UNIT); + + const auto rows = matrix->get_size()[0]; + // workaround suggested by NVIDIA engineers: for some reason + // cusparse needs non-nullptr input vectors even for analysis + auto descr_b = cusparse::create_dnmat( + matrix->get_size()[0], num_rhs, matrix->get_size()[1], + reinterpret_cast(0xDEAD)); + auto descr_c = cusparse::create_dnmat( + matrix->get_size()[0], num_rhs, matrix->get_size()[1], + reinterpret_cast(0xDEAF)); + + auto work_size = cusparse::spsm_buffer_size( + handle, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_NON_TRANSPOSE, one(), descr_a, + descr_b, descr_c, CUSPARSE_SPSM_ALG_DEFAULT, spsm_descr); + + work.resize_and_reset(work_size); + + cusparse::spsm_analysis(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_NON_TRANSPOSE, + one(), descr_a, descr_b, descr_c, + CUSPARSE_SPSM_ALG_DEFAULT, spsm_descr, + work.get_data()); + + cusparse::destroy(descr_b); + cusparse::destroy(descr_c); + } + + void solve(const matrix::Csr*, + const matrix::Dense* input, + matrix::Dense* output, matrix::Dense*, + matrix::Dense*) const + { + cusparse::pointer_mode_guard pm_guard(handle); + auto descr_b = cusparse::create_dnmat( + input->get_size()[0], input->get_size()[1], input->get_stride(), + const_cast(input->get_const_values())); + auto descr_c = + cusparse::create_dnmat(output->get_size()[0], output->get_size()[1], + output->get_stride(), output->get_values()); + + cusparse::spsm_solve(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_NON_TRANSPOSE, one(), + descr_a, descr_b, descr_c, + CUSPARSE_SPSM_ALG_DEFAULT, spsm_descr); + + cusparse::destroy(descr_b); + cusparse::destroy(descr_c); + } + + ~CudaSolveStruct() + { + if (descr_a) { + cusparse::destroy(descr_a); + descr_a = nullptr; + } + if (spsm_descr) { + cusparse::destroy(spsm_descr); + spsm_descr = nullptr; + } + } + + CudaSolveStruct(const SolveStruct&) = delete; + + CudaSolveStruct(SolveStruct&&) = delete; + + CudaSolveStruct& operator=(const SolveStruct&) = delete; + + CudaSolveStruct& operator=(SolveStruct&&) = delete; +}; + + +#elif (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020)) + +template +struct CudaSolveStruct : gko::solver::SolveStruct { + std::shared_ptr exec; + cusparseHandle_t handle; int algorithm; csrsm2Info_t solve_info; cusparseSolvePolicy_t policy; cusparseMatDescr_t factor_descr; - size_t factor_work_size; - void* factor_work_vec; - SolveStruct() + mutable Array work; + + CudaSolveStruct(std::shared_ptr exec, + const matrix::Csr* matrix, + size_type num_rhs, bool is_upper) + : exec{exec}, + handle{exec->get_cusparse_handle()}, + algorithm{}, + solve_info{}, + policy{}, + factor_descr{}, + work{exec} { - factor_work_vec = nullptr; - GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseCreateMatDescr(&factor_descr)); - GKO_ASSERT_NO_CUSPARSE_ERRORS( - cusparseSetMatIndexBase(factor_descr, CUSPARSE_INDEX_BASE_ZERO)); - GKO_ASSERT_NO_CUSPARSE_ERRORS( - cusparseSetMatType(factor_descr, CUSPARSE_MATRIX_TYPE_GENERAL)); - GKO_ASSERT_NO_CUSPARSE_ERRORS( - cusparseSetMatDiagType(factor_descr, CUSPARSE_DIAG_TYPE_NON_UNIT)); - GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseCreateCsrsm2Info(&solve_info)); + cusparse::pointer_mode_guard pm_guard(handle); + factor_descr = cusparse::create_mat_descr(); + solve_info = cusparse::create_solve_info(); + cusparse::set_mat_fill_mode( + factor_descr, + is_upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER); algorithm = 0; policy = CUSPARSE_SOLVE_POLICY_USE_LEVEL; - } - - SolveStruct(const SolveStruct&) = delete; - SolveStruct(SolveStruct&&) = delete; - - SolveStruct& operator=(const SolveStruct&) = delete; + size_type work_size{}; + + cusparse::buffer_size_ext( + handle, algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], num_rhs, + matrix->get_num_stored_elements(), one(), factor_descr, + matrix->get_const_values(), matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), nullptr, num_rhs, solve_info, policy, + &work_size); + + // allocate workspace + work.resize_and_reset(work_size); + + cusparse::csrsm2_analysis( + handle, algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], num_rhs, + matrix->get_num_stored_elements(), one(), factor_descr, + matrix->get_const_values(), matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), nullptr, num_rhs, solve_info, policy, + work.get_data()); + } - SolveStruct& operator=(SolveStruct&&) = delete; + void solve(const matrix::Csr* matrix, + const matrix::Dense* input, + matrix::Dense* output, matrix::Dense*, + matrix::Dense*) const + { + cusparse::pointer_mode_guard pm_guard(handle); + dense::copy(exec, input, output); + cusparse::csrsm2_solve( + handle, algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], + output->get_stride(), matrix->get_num_stored_elements(), + one(), factor_descr, matrix->get_const_values(), + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + output->get_values(), output->get_stride(), solve_info, policy, + work.get_data()); + } - ~SolveStruct() + ~CudaSolveStruct() { - cusparseDestroyMatDescr(factor_descr); - if (solve_info) { - cusparseDestroyCsrsm2Info(solve_info); + if (factor_descr) { + cusparse::destroy(factor_descr); + factor_descr = nullptr; } - if (factor_work_vec != nullptr) { - cudaFree(factor_work_vec); - factor_work_vec = nullptr; + if (solve_info) { + cusparse::destroy(solve_info); + solve_info = nullptr; } } -}; + CudaSolveStruct(const CudaSolveStruct&) = delete; -} // namespace cuda -} // namespace solver + CudaSolveStruct(CudaSolveStruct&&) = delete; + CudaSolveStruct& operator=(const CudaSolveStruct&) = delete; -namespace kernels { -namespace cuda { -namespace { + CudaSolveStruct& operator=(CudaSolveStruct&&) = delete; +}; + + +#endif void should_perform_transpose_kernel(std::shared_ptr exec, @@ -126,64 +283,15 @@ void should_perform_transpose_kernel(std::shared_ptr exec, } -void init_struct_kernel(std::shared_ptr exec, - std::shared_ptr& solve_struct) -{ - solve_struct = std::make_shared(); -} - - template void generate_kernel(std::shared_ptr exec, const matrix::Csr* matrix, - solver::SolveStruct* solve_struct, + std::shared_ptr& solve_struct, const gko::size_type num_rhs, bool is_upper) { if (cusparse::is_supported::value) { - if (auto cuda_solve_struct = - dynamic_cast(solve_struct)) { - auto handle = exec->get_cusparse_handle(); - if (is_upper) { - GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSetMatFillMode( - cuda_solve_struct->factor_descr, CUSPARSE_FILL_MODE_UPPER)); - } - - ValueType one = 1.0; - - { - cusparse::pointer_mode_guard pm_guard(handle); - cusparse::buffer_size_ext( - handle, cuda_solve_struct->algorithm, - CUSPARSE_OPERATION_NON_TRANSPOSE, - CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], - num_rhs, matrix->get_num_stored_elements(), &one, - cuda_solve_struct->factor_descr, matrix->get_const_values(), - matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), - nullptr, num_rhs, cuda_solve_struct->solve_info, - cuda_solve_struct->policy, - &cuda_solve_struct->factor_work_size); - - // allocate workspace - if (cuda_solve_struct->factor_work_vec != nullptr) { - exec->free(cuda_solve_struct->factor_work_vec); - } - cuda_solve_struct->factor_work_vec = - exec->alloc(cuda_solve_struct->factor_work_size); - - cusparse::csrsm2_analysis( - handle, cuda_solve_struct->algorithm, - CUSPARSE_OPERATION_NON_TRANSPOSE, - CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], - num_rhs, matrix->get_num_stored_elements(), &one, - cuda_solve_struct->factor_descr, matrix->get_const_values(), - matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), - nullptr, num_rhs, cuda_solve_struct->solve_info, - cuda_solve_struct->policy, - cuda_solve_struct->factor_work_vec); - } - } else { - GKO_NOT_SUPPORTED(solve_struct); - } + solve_struct = std::make_shared>( + exec, matrix, num_rhs, is_upper); } else { GKO_NOT_IMPLEMENTED; } @@ -203,23 +311,9 @@ void solve_kernel(std::shared_ptr exec, if (cusparse::is_supported::value) { if (auto cuda_solve_struct = - dynamic_cast(solve_struct)) { - ValueType one = 1.0; - auto handle = exec->get_cusparse_handle(); - x->copy_from(gko::lend(b)); - { - cusparse::pointer_mode_guard pm_guard(handle); - cusparse::csrsm2_solve( - handle, cuda_solve_struct->algorithm, - CUSPARSE_OPERATION_NON_TRANSPOSE, - CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], - b->get_stride(), matrix->get_num_stored_elements(), &one, - cuda_solve_struct->factor_descr, matrix->get_const_values(), - matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), - x->get_values(), b->get_stride(), - cuda_solve_struct->solve_info, cuda_solve_struct->policy, - cuda_solve_struct->factor_work_vec); - } + dynamic_cast*>( + solve_struct)) { + cuda_solve_struct->solve(matrix, b, x, trans_b, trans_x); } else { GKO_NOT_SUPPORTED(solve_struct); } @@ -229,6 +323,240 @@ void solve_kernel(std::shared_ptr exec, } +constexpr int default_block_size = 512; +constexpr int fallback_block_size = 32; + + +template +__device__ __forceinline__ + std::enable_if_t::value, ValueType> + load(const ValueType* values, IndexType index) +{ + const volatile ValueType* val = values + index; + return *val; +} + +template +__device__ __forceinline__ std::enable_if_t< + std::is_floating_point::value, thrust::complex> +load(const thrust::complex* values, IndexType index) +{ + auto real = reinterpret_cast(values); + auto imag = real + 1; + return {load(real, 2 * index), load(imag, 2 * index)}; +} + +template +__device__ __forceinline__ void store( + ValueType* values, IndexType index, + std::enable_if_t::value, ValueType> value) +{ + volatile ValueType* val = values + index; + *val = value; +} + +template +__device__ __forceinline__ void store(thrust::complex* values, + IndexType index, + thrust::complex value) +{ + auto real = reinterpret_cast(values); + auto imag = real + 1; + store(real, 2 * index, value.real()); + store(imag, 2 * index, value.imag()); +} + + +template +__global__ void sptrsv_naive_caching_kernel( + const IndexType* const rowptrs, const IndexType* const colidxs, + const ValueType* const vals, const ValueType* const b, size_type b_stride, + ValueType* const x, size_type x_stride, const size_type n, + const size_type nrhs, bool* nan_produced, IndexType* atomic_counter) +{ + __shared__ UninitializedArray x_s_array; + __shared__ IndexType block_base_idx; + + if (threadIdx.x == 0) { + block_base_idx = + atomic_add(atomic_counter, IndexType{1}) * default_block_size; + } + __syncthreads(); + const auto full_gid = static_cast(threadIdx.x) + block_base_idx; + const auto rhs = full_gid % nrhs; + const auto gid = full_gid / nrhs; + const auto row = is_upper ? n - 1 - gid : gid; + + if (gid >= n) { + return; + } + + const auto self_shmem_id = full_gid / default_block_size; + const auto self_shid = full_gid % default_block_size; + + ValueType* x_s = x_s_array; + x_s[self_shid] = nan(); + + __syncthreads(); + + // lower tri matrix: start at beginning, run forward until last entry, + // (row_end - 1) which is the diagonal entry + // upper tri matrix: start at last entry (row_end - 1), run backward + // until first entry, which is the diagonal entry + const auto row_begin = is_upper ? rowptrs[row + 1] - 1 : rowptrs[row]; + const auto row_diag = is_upper ? rowptrs[row] : rowptrs[row + 1] - 1; + const int row_step = is_upper ? -1 : 1; + + ValueType sum = 0.0; + for (auto i = row_begin; i != row_diag; i += row_step) { + const auto dependency = colidxs[i]; + auto x_p = &x[dependency * x_stride + rhs]; + + const auto dependency_gid = is_upper ? (n - 1 - dependency) * nrhs + rhs + : dependency * nrhs + rhs; + const bool shmem_possible = + (dependency_gid / default_block_size) == self_shmem_id; + if (shmem_possible) { + const auto dependency_shid = dependency_gid % default_block_size; + x_p = &x_s[dependency_shid]; + } + + ValueType x = *x_p; + while (is_nan(x)) { + x = load(x_p, 0); + } + + sum += x * vals[i]; + } + + const auto r = (b[row * b_stride + rhs] - sum) / vals[row_diag]; + + store(x_s, self_shid, r); + x[row * x_stride + rhs] = r; + + // This check to ensure no infinte loops happen. + if (is_nan(r)) { + store(x_s, self_shid, zero()); + x[row * x_stride + rhs] = zero(); + *nan_produced = true; + } +} + + +template +__global__ void sptrsv_naive_legacy_kernel( + const IndexType* const rowptrs, const IndexType* const colidxs, + const ValueType* const vals, const ValueType* const b, size_type b_stride, + ValueType* const x, size_type x_stride, const size_type n, + const size_type nrhs, bool* nan_produced, IndexType* atomic_counter) +{ + __shared__ IndexType block_base_idx; + if (threadIdx.x == 0) { + block_base_idx = + atomic_add(atomic_counter, IndexType{1}) * fallback_block_size; + } + __syncthreads(); + const auto full_gid = static_cast(threadIdx.x) + block_base_idx; + const auto rhs = full_gid % nrhs; + const auto gid = full_gid / nrhs; + const auto row = is_upper ? n - 1 - gid : gid; + + if (gid >= n) { + return; + } + + // lower tri matrix: start at beginning, run forward until last entry, + // (row_end - 1) which is the diagonal entry + // upper tri matrix: start at last entry (row_end - 1), run backward + // until first entry, which is the diagonal entry + const auto row_begin = is_upper ? rowptrs[row + 1] - 1 : rowptrs[row]; + const auto row_diag = is_upper ? rowptrs[row] : rowptrs[row + 1] - 1; + const int row_step = is_upper ? -1 : 1; + + ValueType sum = 0.0; + auto j = row_begin; + while (j != row_diag + row_step) { + auto col = colidxs[j]; + auto x_val = load(x, col * x_stride + rhs); + while (!is_nan(x_val)) { + sum += vals[j] * x_val; + j += row_step; + col = colidxs[j]; + x_val = load(x, col * x_stride + rhs); + } + if (row == col) { + const auto r = (b[row * b_stride + rhs] - sum) / vals[row_diag]; + store(x, row * x_stride + rhs, r); + j += row_step; + if (is_nan(r)) { + store(x, row * x_stride + rhs, zero()); + *nan_produced = true; + } + } + } +} + + +template +__global__ void sptrsv_init_kernel(bool* const nan_produced, + IndexType* const atomic_counter) +{ + *nan_produced = false; + *atomic_counter = IndexType{}; +} + + +template +void sptrsv_naive_caching(std::shared_ptr exec, + const matrix::Csr* matrix, + const matrix::Dense* b, + matrix::Dense* x) +{ + // Pre-Volta GPUs may deadlock due to missing independent thread scheduling. + const auto is_fallback_required = exec->get_major_version() < 7; + + const auto n = matrix->get_size()[0]; + const auto nrhs = b->get_size()[1]; + + // Initialize x to all NaNs. + dense::fill(exec, x, nan()); + + Array nan_produced(exec, 1); + Array atomic_counter(exec, 1); + sptrsv_init_kernel<<<1, 1>>>(nan_produced.get_data(), + atomic_counter.get_data()); + + const dim3 block_size( + is_fallback_required ? fallback_block_size : default_block_size, 1, 1); + const dim3 grid_size(ceildiv(n * nrhs, block_size.x), 1, 1); + + if (is_fallback_required) { + sptrsv_naive_legacy_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), n, nrhs, + nan_produced.get_data(), atomic_counter.get_data()); + } else { + sptrsv_naive_caching_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), n, nrhs, + nan_produced.get_data(), atomic_counter.get_data()); + } + +#if GKO_VERBOSE_LEVEL >= 1 + if (exec->copy_val_to_host(nan_produced.get_const_data())) { + std::cerr + << "Error: triangular solve produced NaN, either not all diagonal " + "elements are nonzero, or the system is very ill-conditioned. " + "The NaN will be replaced with a zero.\n"; + } +#endif // GKO_VERBOSE_LEVEL >= 1 +} + + } // namespace } // namespace cuda } // namespace kernels diff --git a/cuda/solver/lower_trs_kernels.cu b/cuda/solver/lower_trs_kernels.cu index afbd545f1f1..5c16608a67d 100644 --- a/cuda/solver/lower_trs_kernels.cu +++ b/cuda/solver/lower_trs_kernels.cu @@ -69,20 +69,16 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr& solve_struct) -{ - init_struct_kernel(exec, solve_struct); -} - - template void generate(std::shared_ptr exec, const matrix::Csr* matrix, - solver::SolveStruct* solve_struct, const gko::size_type num_rhs) + std::shared_ptr& solve_struct, + const gko::size_type num_rhs) { - generate_kernel(exec, matrix, solve_struct, num_rhs, - false); + if (matrix->get_strategy()->get_name() == "sparselib") { + generate_kernel(exec, matrix, solve_struct, + num_rhs, false); + } } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -96,8 +92,12 @@ void solve(std::shared_ptr exec, matrix::Dense* trans_b, matrix::Dense* trans_x, const matrix::Dense* b, matrix::Dense* x) { - solve_kernel(exec, matrix, solve_struct, trans_b, - trans_x, b, x); + if (matrix->get_strategy()->get_name() == "sparselib") { + solve_kernel(exec, matrix, solve_struct, trans_b, + trans_x, b, x); + } else { + sptrsv_naive_caching(exec, matrix, b, x); + } } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( diff --git a/cuda/solver/upper_trs_kernels.cu b/cuda/solver/upper_trs_kernels.cu index 527dc3958a7..eeecdd24d02 100644 --- a/cuda/solver/upper_trs_kernels.cu +++ b/cuda/solver/upper_trs_kernels.cu @@ -69,20 +69,16 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr& solve_struct) -{ - init_struct_kernel(exec, solve_struct); -} - - template void generate(std::shared_ptr exec, const matrix::Csr* matrix, - solver::SolveStruct* solve_struct, const gko::size_type num_rhs) + std::shared_ptr& solve_struct, + const gko::size_type num_rhs) { - generate_kernel(exec, matrix, solve_struct, num_rhs, - true); + if (matrix->get_strategy()->get_name() == "sparselib") { + generate_kernel(exec, matrix, solve_struct, + num_rhs, true); + } } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -96,8 +92,12 @@ void solve(std::shared_ptr exec, matrix::Dense* trans_b, matrix::Dense* trans_x, const matrix::Dense* b, matrix::Dense* x) { - solve_kernel(exec, matrix, solve_struct, trans_b, - trans_x, b, x); + if (matrix->get_strategy()->get_name() == "sparselib") { + solve_kernel(exec, matrix, solve_struct, trans_b, + trans_x, b, x); + } else { + sptrsv_naive_caching(exec, matrix, b, x); + } } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( diff --git a/cuda/test/solver/lower_trs_kernels.cpp b/cuda/test/solver/lower_trs_kernels.cpp index 1cba5558312..ec8d02a9a57 100644 --- a/cuda/test/solver/lower_trs_kernels.cpp +++ b/cuda/test/solver/lower_trs_kernels.cpp @@ -134,6 +134,22 @@ TEST_F(LowerTrs, CudaLowerTrsFlagCheckIsCorrect) } +TEST_F(LowerTrs, CudaSingleRhsApplyClassicalIsEquivalentToRef) +{ + initialize_data(50, 1); + auto lower_trs_factory = gko::solver::LowerTrs<>::build().on(ref); + auto d_lower_trs_factory = gko::solver::LowerTrs<>::build().on(cuda); + d_csr_mtx->set_strategy(std::make_shared()); + auto solver = lower_trs_factory->generate(csr_mtx); + auto d_solver = d_lower_trs_factory->generate(d_csr_mtx); + + solver->apply(b2.get(), x.get()); + d_solver->apply(d_b2.get(), d_x.get()); + + GKO_ASSERT_MTX_NEAR(d_x, x, 1e-14); +} + + TEST_F(LowerTrs, CudaSingleRhsApplyIsEquivalentToRef) { initialize_data(50, 1); @@ -149,6 +165,27 @@ TEST_F(LowerTrs, CudaSingleRhsApplyIsEquivalentToRef) } +TEST_F(LowerTrs, CudaMultipleRhsApplyClassicalIsEquivalentToRef) +{ + initialize_data(50, 3); + auto lower_trs_factory = + gko::solver::LowerTrs<>::build().with_num_rhs(3u).on(ref); + auto d_lower_trs_factory = + gko::solver::LowerTrs<>::build().with_num_rhs(3u).on(cuda); + d_csr_mtx->set_strategy(std::make_shared()); + auto solver = lower_trs_factory->generate(csr_mtx); + auto d_solver = d_lower_trs_factory->generate(d_csr_mtx); + auto db2_strided = Mtx::create(cuda, b->get_size(), 4); + d_b2->convert_to(db2_strided.get()); + auto dx_strided = Mtx::create(cuda, x->get_size(), 5); + + solver->apply(b2.get(), x.get()); + d_solver->apply(db2_strided.get(), dx_strided.get()); + + GKO_ASSERT_MTX_NEAR(dx_strided, x, 1e-14); +} + + TEST_F(LowerTrs, CudaMultipleRhsApplyIsEquivalentToRef) { initialize_data(50, 3); @@ -158,11 +195,20 @@ TEST_F(LowerTrs, CudaMultipleRhsApplyIsEquivalentToRef) gko::solver::LowerTrs<>::build().with_num_rhs(3u).on(cuda); auto solver = lower_trs_factory->generate(csr_mtx); auto d_solver = d_lower_trs_factory->generate(d_csr_mtx); + auto db2_strided = Mtx::create(cuda, b->get_size(), 4); + d_b2->convert_to(db2_strided.get()); + // The cuSPARSE Generic SpSM implementation uses the wrong stride here + // so the input and output stride need to match +#if CUDA_VERSION >= 11030 + auto dx_strided = Mtx::create(cuda, x->get_size(), 4); +#else + auto dx_strided = Mtx::create(cuda, x->get_size(), 5); +#endif solver->apply(b2.get(), x.get()); - d_solver->apply(d_b2.get(), d_x.get()); + d_solver->apply(db2_strided.get(), dx_strided.get()); - GKO_ASSERT_MTX_NEAR(d_x, x, 1e-14); + GKO_ASSERT_MTX_NEAR(dx_strided, x, 1e-14); } diff --git a/cuda/test/solver/upper_trs_kernels.cpp b/cuda/test/solver/upper_trs_kernels.cpp index 78f1954ee3b..9e8918815f4 100644 --- a/cuda/test/solver/upper_trs_kernels.cpp +++ b/cuda/test/solver/upper_trs_kernels.cpp @@ -134,6 +134,22 @@ TEST_F(UpperTrs, CudaUpperTrsFlagCheckIsCorrect) } +TEST_F(UpperTrs, CudaSingleRhsApplyClassicalIsEquivalentToRef) +{ + initialize_data(50, 1); + auto upper_trs_factory = gko::solver::UpperTrs<>::build().on(ref); + auto d_upper_trs_factory = gko::solver::UpperTrs<>::build().on(cuda); + d_csr_mtx->set_strategy(std::make_shared()); + auto solver = upper_trs_factory->generate(csr_mtx); + auto d_solver = d_upper_trs_factory->generate(d_csr_mtx); + + solver->apply(b2.get(), x.get()); + d_solver->apply(d_b2.get(), d_x.get()); + + GKO_ASSERT_MTX_NEAR(d_x, x, 1e-14); +} + + TEST_F(UpperTrs, CudaSingleRhsApplyIsEquivalentToRef) { initialize_data(50, 1); @@ -149,6 +165,27 @@ TEST_F(UpperTrs, CudaSingleRhsApplyIsEquivalentToRef) } +TEST_F(UpperTrs, CudaMultipleRhsApplyClassicalIsEquivalentToRef) +{ + initialize_data(50, 3); + auto upper_trs_factory = + gko::solver::UpperTrs<>::build().with_num_rhs(3u).on(ref); + auto d_upper_trs_factory = + gko::solver::UpperTrs<>::build().with_num_rhs(3u).on(cuda); + d_csr_mtx->set_strategy(std::make_shared()); + auto solver = upper_trs_factory->generate(csr_mtx); + auto d_solver = d_upper_trs_factory->generate(d_csr_mtx); + auto db2_strided = Mtx::create(cuda, b->get_size(), 4); + d_b2->convert_to(db2_strided.get()); + auto dx_strided = Mtx::create(cuda, x->get_size(), 5); + + solver->apply(b2.get(), x.get()); + d_solver->apply(db2_strided.get(), dx_strided.get()); + + GKO_ASSERT_MTX_NEAR(dx_strided, x, 1e-14); +} + + TEST_F(UpperTrs, CudaMultipleRhsApplyIsEquivalentToRef) { initialize_data(50, 3); @@ -158,11 +195,20 @@ TEST_F(UpperTrs, CudaMultipleRhsApplyIsEquivalentToRef) gko::solver::UpperTrs<>::build().with_num_rhs(3u).on(cuda); auto solver = upper_trs_factory->generate(csr_mtx); auto d_solver = d_upper_trs_factory->generate(d_csr_mtx); + auto db2_strided = Mtx::create(cuda, b->get_size(), 4); + d_b2->convert_to(db2_strided.get()); + // The cuSPARSE Generic SpSM implementation uses the wrong stride here + // so the input and output stride need to match +#if CUDA_VERSION >= 11030 + auto dx_strided = Mtx::create(cuda, x->get_size(), 4); +#else + auto dx_strided = Mtx::create(cuda, x->get_size(), 5); +#endif solver->apply(b2.get(), x.get()); - d_solver->apply(d_b2.get(), d_x.get()); + d_solver->apply(db2_strided.get(), dx_strided.get()); - GKO_ASSERT_MTX_NEAR(d_x, x, 1e-14); + GKO_ASSERT_MTX_NEAR(dx_strided, x, 1e-14); } diff --git a/dpcpp/solver/lower_trs_kernels.dp.cpp b/dpcpp/solver/lower_trs_kernels.dp.cpp index 5f25d4c057b..cede13f1c20 100644 --- a/dpcpp/solver/lower_trs_kernels.dp.cpp +++ b/dpcpp/solver/lower_trs_kernels.dp.cpp @@ -60,26 +60,17 @@ namespace lower_trs { void should_perform_transpose(std::shared_ptr exec, - bool& do_transpose) GKO_NOT_IMPLEMENTED; - - -void init_struct(std::shared_ptr exec, - std::shared_ptr& solve_struct) + bool& do_transpose) { - // This init kernel is here to allow initialization of the solve struct for - // a more sophisticated implementation as for other executors. + do_transpose = false; } template void generate(std::shared_ptr exec, const matrix::Csr* matrix, - solver::SolveStruct* solve_struct, const gko::size_type num_rhs) -{ - // This generate kernel is here to allow for a more sophisticated - // implementation as for other executors. This kernel would perform the - // "analysis" phase for the triangular matrix. -} + std::shared_ptr& solve_struct, + const gko::size_type num_rhs) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL); diff --git a/dpcpp/solver/upper_trs_kernels.dp.cpp b/dpcpp/solver/upper_trs_kernels.dp.cpp index e46e7cb5195..e014095bb52 100644 --- a/dpcpp/solver/upper_trs_kernels.dp.cpp +++ b/dpcpp/solver/upper_trs_kernels.dp.cpp @@ -66,23 +66,11 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr& solve_struct) -{ - // This init kernel is here to allow initialization of the solve struct for - // a more sophisticated implementation as for other executors. -} - - template void generate(std::shared_ptr exec, const matrix::Csr* matrix, - solver::SolveStruct* solve_struct, const gko::size_type num_rhs) -{ - // This generate kernel is here to allow for a more sophisticated - // implementation as for other executors. This kernel would perform the - // "analysis" phase for the triangular matrix. -} + std::shared_ptr& solve_struct, + const gko::size_type num_rhs) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL); diff --git a/hip/solver/lower_trs_kernels.hip.cpp b/hip/solver/lower_trs_kernels.hip.cpp index 0f7751a039d..5a7cfde692b 100644 --- a/hip/solver/lower_trs_kernels.hip.cpp +++ b/hip/solver/lower_trs_kernels.hip.cpp @@ -71,18 +71,18 @@ void should_perform_transpose(std::shared_ptr exec, void init_struct(std::shared_ptr exec, std::shared_ptr& solve_struct) -{ - init_struct_kernel(exec, solve_struct); -} +{} template void generate(std::shared_ptr exec, const matrix::Csr* matrix, - solver::SolveStruct* solve_struct, const gko::size_type num_rhs) + std::shared_ptr& solve_struct, + const gko::size_type num_rhs) { - generate_kernel(exec, matrix, solve_struct, num_rhs, - false); + init_struct_kernel(exec, solve_struct); + generate_kernel(exec, matrix, solve_struct.get(), + num_rhs, false); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( diff --git a/hip/solver/upper_trs_kernels.hip.cpp b/hip/solver/upper_trs_kernels.hip.cpp index 0dd3d788d67..07bd0254531 100644 --- a/hip/solver/upper_trs_kernels.hip.cpp +++ b/hip/solver/upper_trs_kernels.hip.cpp @@ -69,20 +69,15 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr& solve_struct) -{ - init_struct_kernel(exec, solve_struct); -} - - template void generate(std::shared_ptr exec, const matrix::Csr* matrix, - solver::SolveStruct* solve_struct, const gko::size_type num_rhs) + std::shared_ptr& solve_struct, + const gko::size_type num_rhs) { - generate_kernel(exec, matrix, solve_struct, num_rhs, - true); + init_struct_kernel(exec, solve_struct); + generate_kernel(exec, matrix, solve_struct.get(), + num_rhs, true); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( diff --git a/include/ginkgo/core/base/math.hpp b/include/ginkgo/core/base/math.hpp index 17ffc9f8bad..915c4e5a262 100644 --- a/include/ginkgo/core/base/math.hpp +++ b/include/ginkgo/core/base/math.hpp @@ -1096,6 +1096,70 @@ GKO_INLINE GKO_ATTRIBUTES T safe_divide(T a, T b) } +/** + * Checks if a floating point number is NaN. + * + * @tparam T type of the value to check + * + * @param value value to check + * + * @return `true` if the value is NaN. + */ +template +GKO_INLINE GKO_ATTRIBUTES std::enable_if_t::value, bool> +is_nan(const T& value) +{ + return std::isnan(value); +} + + +/** + * Checks if any component of a complex value is NaN. + * + * @tparam T complex type of the value to check + * + * @param value complex value to check + * + * @return `true` if any component of the given value is NaN. + */ +template +GKO_INLINE GKO_ATTRIBUTES std::enable_if_t::value, bool> is_nan( + const T& value) +{ + return std::isnan(value.real()) || std::isnan(value.imag()); +} + + +/** + * Returns a quiet NaN of the given type. + * + * @tparam T the type of the object + * + * @return NaN. + */ +template +GKO_INLINE GKO_ATTRIBUTES constexpr std::enable_if_t::value, T> +nan() +{ + return std::numeric_limits::quiet_NaN(); +} + + +/** + * Returns a complex with both components quiet NaN. + * + * @tparam T the type of the object + * + * @return complex{NaN, NaN}. + */ +template +GKO_INLINE GKO_ATTRIBUTES constexpr std::enable_if_t::value, T> +nan() +{ + return T{nan>(), nan>()}; +} + + } // namespace gko diff --git a/include/ginkgo/core/solver/lower_trs.hpp b/include/ginkgo/core/solver/lower_trs.hpp index ab375c05406..94244ed1955 100644 --- a/include/ginkgo/core/solver/lower_trs.hpp +++ b/include/ginkgo/core/solver/lower_trs.hpp @@ -123,8 +123,6 @@ class LowerTrs : public EnableLinOp>, GKO_ENABLE_BUILD_METHOD(Factory); protected: - void init_trs_solve_struct(); - void apply_impl(const LinOp* b, LinOp* x) const override; void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta, @@ -159,7 +157,6 @@ class LowerTrs : public EnableLinOp>, system_matrix_ = copy_and_convert_to(exec, system_matrix); } - this->init_trs_solve_struct(); this->generate(); } diff --git a/include/ginkgo/core/solver/upper_trs.hpp b/include/ginkgo/core/solver/upper_trs.hpp index 66f388d4823..5a49c094dea 100644 --- a/include/ginkgo/core/solver/upper_trs.hpp +++ b/include/ginkgo/core/solver/upper_trs.hpp @@ -123,8 +123,6 @@ class UpperTrs : public EnableLinOp>, GKO_ENABLE_BUILD_METHOD(Factory); protected: - void init_trs_solve_struct(); - void apply_impl(const LinOp* b, LinOp* x) const override; void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta, @@ -159,7 +157,6 @@ class UpperTrs : public EnableLinOp>, system_matrix_ = copy_and_convert_to(exec, system_matrix); } - this->init_trs_solve_struct(); this->generate(); } diff --git a/omp/solver/lower_trs_kernels.cpp b/omp/solver/lower_trs_kernels.cpp index edc81943ee5..d8b296aeec6 100644 --- a/omp/solver/lower_trs_kernels.cpp +++ b/omp/solver/lower_trs_kernels.cpp @@ -66,18 +66,11 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr& solve_struct) -{ - // This init kernel is here to allow initialization of the solve struct for - // a more sophisticated implementation as for other executors. -} - - template void generate(std::shared_ptr exec, const matrix::Csr* matrix, - solver::SolveStruct* solve_struct, const gko::size_type num_rhs) + std::shared_ptr& solve_struct, + const gko::size_type num_rhs) { // This generate kernel is here to allow for a more sophisticated // implementation as for other executors. This kernel would perform the diff --git a/omp/solver/upper_trs_kernels.cpp b/omp/solver/upper_trs_kernels.cpp index 0952c339990..75b50cc65cb 100644 --- a/omp/solver/upper_trs_kernels.cpp +++ b/omp/solver/upper_trs_kernels.cpp @@ -66,18 +66,11 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr& solve_struct) -{ - // This init kernel is here to allow initialization of the solve struct for - // a more sophisticated implementation as for other executors. -} - - template void generate(std::shared_ptr exec, const matrix::Csr* matrix, - solver::SolveStruct* solve_struct, const gko::size_type num_rhs) + std::shared_ptr& solve_struct, + const gko::size_type num_rhs) { // This generate kernel is here to allow for a more sophisticated // implementation as for other executors. This kernel would perform the diff --git a/omp/test/solver/lower_trs_kernels.cpp b/omp/test/solver/lower_trs_kernels.cpp index 1a15fbe5647..eca1e1ea0bf 100644 --- a/omp/test/solver/lower_trs_kernels.cpp +++ b/omp/test/solver/lower_trs_kernels.cpp @@ -139,20 +139,13 @@ TEST_F(LowerTrs, OmpLowerTrsFlagCheckIsCorrect) } -TEST_F(LowerTrs, OmpLowerTrsSolveStructInitIsEquivalentToRef) -{ - gko::kernels::reference::lower_trs::init_struct(ref, solve_struct_ref); - gko::kernels::omp::lower_trs::init_struct(omp, solve_struct_omp); -} - - TEST_F(LowerTrs, OmpLowerTrsGenerateIsEquivalentToRef) { gko::size_type num_rhs = 1; - gko::kernels::reference::lower_trs::generate( - ref, csr_mat.get(), solve_struct_ref.get(), num_rhs); + gko::kernels::reference::lower_trs::generate(ref, csr_mat.get(), + solve_struct_ref, num_rhs); gko::kernels::omp::lower_trs::generate(omp, d_csr_mat.get(), - solve_struct_omp.get(), num_rhs); + solve_struct_omp, num_rhs); } @@ -160,8 +153,6 @@ TEST_F(LowerTrs, OmpLowerTrsSolveIsEquivalentToRef) { initialize_data(59, 43); - gko::kernels::reference::lower_trs::init_struct(ref, solve_struct_ref); - gko::kernels::omp::lower_trs::init_struct(omp, solve_struct_omp); gko::kernels::reference::lower_trs::solve(ref, csr_mat.get(), solve_struct_ref.get(), t_b.get(), t_x.get(), b.get(), x.get()); diff --git a/omp/test/solver/upper_trs_kernels.cpp b/omp/test/solver/upper_trs_kernels.cpp index d3381167f43..53d5822f617 100644 --- a/omp/test/solver/upper_trs_kernels.cpp +++ b/omp/test/solver/upper_trs_kernels.cpp @@ -138,20 +138,13 @@ TEST_F(UpperTrs, OmpUpperTrsFlagCheckIsCorrect) } -TEST_F(UpperTrs, OmpUpperTrsSolveStructInitIsEquivalentToRef) -{ - gko::kernels::reference::upper_trs::init_struct(ref, solve_struct_ref); - gko::kernels::omp::upper_trs::init_struct(omp, solve_struct_omp); -} - - TEST_F(UpperTrs, OmpUpperTrsGenerateIsEquivalentToRef) { gko::size_type num_rhs = 1; - gko::kernels::reference::upper_trs::generate( - ref, csr_mat.get(), solve_struct_ref.get(), num_rhs); + gko::kernels::reference::upper_trs::generate(ref, csr_mat.get(), + solve_struct_ref, num_rhs); gko::kernels::omp::upper_trs::generate(omp, d_csr_mat.get(), - solve_struct_omp.get(), num_rhs); + solve_struct_omp, num_rhs); } @@ -159,8 +152,6 @@ TEST_F(UpperTrs, OmpUpperTrsSolveIsEquivalentToRef) { initialize_data(59, 43); - gko::kernels::reference::upper_trs::init_struct(ref, solve_struct_ref); - gko::kernels::omp::upper_trs::init_struct(omp, solve_struct_omp); gko::kernels::reference::upper_trs::solve(ref, csr_mat.get(), solve_struct_ref.get(), t_b.get(), t_x.get(), b.get(), x.get()); diff --git a/reference/solver/lower_trs_kernels.cpp b/reference/solver/lower_trs_kernels.cpp index aca478c77da..a521d622d87 100644 --- a/reference/solver/lower_trs_kernels.cpp +++ b/reference/solver/lower_trs_kernels.cpp @@ -62,18 +62,11 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr& solve_struct) -{ - // This init kernel is here to allow initialization of the solve struct for - // a more sophisticated implementation as for other executors. -} - - template void generate(std::shared_ptr exec, const matrix::Csr* matrix, - solver::SolveStruct* solve_struct, const gko::size_type num_rhs) + std::shared_ptr& solve_struct, + const gko::size_type num_rhs) { // This generate kernel is here to allow for a more sophisticated // implementation as for other executors. This kernel would perform the diff --git a/reference/solver/upper_trs_kernels.cpp b/reference/solver/upper_trs_kernels.cpp index 46dbbd423a7..d08924798b0 100644 --- a/reference/solver/upper_trs_kernels.cpp +++ b/reference/solver/upper_trs_kernels.cpp @@ -62,18 +62,11 @@ void should_perform_transpose(std::shared_ptr exec, } -void init_struct(std::shared_ptr exec, - std::shared_ptr& solve_struct) -{ - // This init kernel is here to allow initialization of the solve struct for - // a more sophisticated implementation as for other executors. -} - - template void generate(std::shared_ptr exec, const matrix::Csr* matrix, - solver::SolveStruct* solve_struct, const gko::size_type num_rhs) + std::shared_ptr& solve_struct, + const gko::size_type num_rhs) { // This generate kernel is here to allow for a more sophisticated // implementation as for other executors. This kernel would perform the