Skip to content

Commit

Permalink
Add CUDA generic API triangular solver
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Aug 26, 2021
1 parent de2df07 commit 90a3bbb
Show file tree
Hide file tree
Showing 22 changed files with 438 additions and 407 deletions.
6 changes: 0 additions & 6 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,9 +422,6 @@ namespace lower_trs {
GKO_DECLARE_LOWER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL()
GKO_NOT_COMPILED(GKO_HOOK_MODULE);

GKO_DECLARE_LOWER_TRS_INIT_STRUCT_KERNEL()
GKO_NOT_COMPILED(GKO_HOOK_MODULE);

template <typename ValueType, typename IndexType>
GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(ValueType, IndexType)
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
Expand All @@ -447,9 +444,6 @@ namespace upper_trs {
GKO_DECLARE_UPPER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL()
GKO_NOT_COMPILED(GKO_HOOK_MODULE);

GKO_DECLARE_UPPER_TRS_INIT_STRUCT_KERNEL()
GKO_NOT_COMPILED(GKO_HOOK_MODULE);

template <typename ValueType, typename IndexType>
GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL(ValueType, IndexType)
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
Expand Down
11 changes: 1 addition & 10 deletions core/solver/lower_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ namespace {


GKO_REGISTER_OPERATION(generate, lower_trs::generate);
GKO_REGISTER_OPERATION(init_struct, lower_trs::init_struct);
GKO_REGISTER_OPERATION(should_perform_transpose,
lower_trs::should_perform_transpose);
GKO_REGISTER_OPERATION(solve, lower_trs::solve);
Expand Down Expand Up @@ -85,19 +84,11 @@ std::unique_ptr<LinOp> LowerTrs<ValueType, IndexType>::conj_transpose() const
}


template <typename ValueType, typename IndexType>
void LowerTrs<ValueType, IndexType>::init_trs_solve_struct()
{
this->get_executor()->run(lower_trs::make_init_struct(this->solve_struct_));
}


template <typename ValueType, typename IndexType>
void LowerTrs<ValueType, IndexType>::generate()
{
this->get_executor()->run(lower_trs::make_generate(
gko::lend(system_matrix_), gko::lend(this->solve_struct_),
parameters_.num_rhs));
gko::lend(system_matrix_), this->solve_struct_, parameters_.num_rhs));
}


Expand Down
14 changes: 4 additions & 10 deletions core/solver/lower_trs_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,10 @@ namespace lower_trs {
bool &do_transpose)


#define GKO_DECLARE_LOWER_TRS_INIT_STRUCT_KERNEL() \
void init_struct(std::shared_ptr<const DefaultExecutor> exec, \
std::shared_ptr<solver::SolveStruct> &solve_struct)


#define GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(_vtype, _itype) \
void generate(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<_vtype, _itype> *matrix, \
solver::SolveStruct *solve_struct, \
#define GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(_vtype, _itype) \
void generate(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<_vtype, _itype> *matrix, \
std::shared_ptr<solver::SolveStruct> &solve_struct, \
const gko::size_type num_rhs)


Expand All @@ -77,7 +72,6 @@ namespace lower_trs {

#define GKO_DECLARE_ALL_AS_TEMPLATES \
GKO_DECLARE_LOWER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL(); \
GKO_DECLARE_LOWER_TRS_INIT_STRUCT_KERNEL(); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_LOWER_TRS_SOLVE_KERNEL(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
Expand Down
11 changes: 1 addition & 10 deletions core/solver/upper_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ namespace {


GKO_REGISTER_OPERATION(generate, upper_trs::generate);
GKO_REGISTER_OPERATION(init_struct, upper_trs::init_struct);
GKO_REGISTER_OPERATION(should_perform_transpose,
upper_trs::should_perform_transpose);
GKO_REGISTER_OPERATION(solve, upper_trs::solve);
Expand Down Expand Up @@ -85,19 +84,11 @@ std::unique_ptr<LinOp> UpperTrs<ValueType, IndexType>::conj_transpose() const
}


template <typename ValueType, typename IndexType>
void UpperTrs<ValueType, IndexType>::init_trs_solve_struct()
{
this->get_executor()->run(upper_trs::make_init_struct(this->solve_struct_));
}


template <typename ValueType, typename IndexType>
void UpperTrs<ValueType, IndexType>::generate()
{
this->get_executor()->run(upper_trs::make_generate(
gko::lend(system_matrix_), gko::lend(this->solve_struct_),
parameters_.num_rhs));
gko::lend(system_matrix_), this->solve_struct_, parameters_.num_rhs));
}


Expand Down
14 changes: 4 additions & 10 deletions core/solver/upper_trs_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,10 @@ namespace upper_trs {
bool &do_transpose)


#define GKO_DECLARE_UPPER_TRS_INIT_STRUCT_KERNEL() \
void init_struct(std::shared_ptr<const DefaultExecutor> exec, \
std::shared_ptr<gko::solver::SolveStruct> &solve_struct)


#define GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL(_vtype, _itype) \
void generate(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<_vtype, _itype> *matrix, \
solver::SolveStruct *solve_struct, \
#define GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL(_vtype, _itype) \
void generate(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<_vtype, _itype> *matrix, \
std::shared_ptr<gko::solver::SolveStruct> &solve_struct, \
const gko::size_type num_rhs)


Expand All @@ -77,7 +72,6 @@ namespace upper_trs {

#define GKO_DECLARE_ALL_AS_TEMPLATES \
GKO_DECLARE_UPPER_TRS_SHOULD_PERFORM_TRANSPOSE_KERNEL(); \
GKO_DECLARE_UPPER_TRS_INIT_STRUCT_KERNEL(); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_UPPER_TRS_SOLVE_KERNEL(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
Expand Down
Loading

0 comments on commit 90a3bbb

Please sign in to comment.