Skip to content

Commit

Permalink
Merge Fix the algorithm of transposed triangular solver
Browse files Browse the repository at this point in the history
This PR fixes the algorithm usage when generate the transpose of triangular solver.
It uses the same algorithm as the original one not the default.

Related PR: ginkgo-project#1641
  • Loading branch information
yhmtsai authored and MarcelKoch committed Dec 2, 2024
2 parents 5934ef4 + a71eec7 commit 8c6a7b7
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 4 deletions.
2 changes: 2 additions & 0 deletions core/solver/lower_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ std::unique_ptr<LinOp> LowerTrs<ValueType, IndexType>::transpose() const
{
return transposed_type::build()
.with_num_rhs(this->parameters_.num_rhs)
.with_algorithm(this->parameters_.algorithm)
.on(this->get_executor())
->generate(share(this->get_system_matrix()->transpose()));
}
Expand All @@ -109,6 +110,7 @@ std::unique_ptr<LinOp> LowerTrs<ValueType, IndexType>::conj_transpose() const
{
return transposed_type::build()
.with_num_rhs(this->parameters_.num_rhs)
.with_algorithm(this->parameters_.algorithm)
.on(this->get_executor())
->generate(share(this->get_system_matrix()->conj_transpose()));
}
Expand Down
2 changes: 2 additions & 0 deletions core/solver/upper_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ std::unique_ptr<LinOp> UpperTrs<ValueType, IndexType>::transpose() const
{
return transposed_type::build()
.with_num_rhs(this->parameters_.num_rhs)
.with_algorithm(this->parameters_.algorithm)
.on(this->get_executor())
->generate(share(this->get_system_matrix()->transpose()));
}
Expand All @@ -109,6 +110,7 @@ std::unique_ptr<LinOp> UpperTrs<ValueType, IndexType>::conj_transpose() const
{
return transposed_type::build()
.with_num_rhs(this->parameters_.num_rhs)
.with_algorithm(this->parameters_.algorithm)
.on(this->get_executor())
->generate(share(this->get_system_matrix()->conj_transpose()));
}
Expand Down
69 changes: 67 additions & 2 deletions reference/test/solver/lower_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ class LowerTrs : public ::testing::Test {
{365.0, 97.0, -654.0, 8.0, 91.0}},
exec)),
lower_trs_factory(Solver::build().on(exec)),
lower_trs_syncfree_factory(
Solver::build()
.with_algorithm(gko::solver::trisolve_algorithm::syncfree)
.on(exec)),
lower_trs_factory_mrhs(Solver::build().with_num_rhs(2u).on(exec)),
lower_trs_factory_unit(
Solver::build().with_unit_diagonal(true).on(exec))
Expand All @@ -66,6 +70,7 @@ class LowerTrs : public ::testing::Test {
std::shared_ptr<Mtx> mtx_big_lower;
std::shared_ptr<Mtx> mtx_big_general;
std::unique_ptr<typename Solver::Factory> lower_trs_factory;
std::unique_ptr<typename Solver::Factory> lower_trs_syncfree_factory;
std::unique_ptr<typename Solver::Factory> lower_trs_factory_mrhs;
std::unique_ptr<typename Solver::Factory> lower_trs_factory_unit;
};
Expand Down Expand Up @@ -348,27 +353,87 @@ TYPED_TEST(LowerTrs, SolvesTransposedTriangularSystem)
{
using Mtx = typename TestFixture::Mtx;
using value_type = typename TestFixture::value_type;
using Solver = typename TestFixture::Solver;
std::shared_ptr<Mtx> b = gko::initialize<Mtx>({1.0, 2.0, 1.0}, this->exec);
auto x = gko::initialize<Mtx>({0.0, 0.0, 0.0}, this->exec);
auto solver = this->lower_trs_factory->generate(this->mtx);
auto transposed_solver =
gko::as<typename Solver::transposed_type>(solver->transpose());

solver->transpose()->apply(b, x);
transposed_solver->apply(b, x);

GKO_ASSERT_MTX_NEAR(x, l({0.0, 0.0, 1.0}), r<value_type>::value);
// Ensure that the other test with syncfree is not the default option
ASSERT_EQ(solver->get_parameters().algorithm,
gko::solver::trisolve_algorithm::sparselib);
ASSERT_EQ(transposed_solver->get_parameters().algorithm,
solver->get_parameters().algorithm);
}


TYPED_TEST(LowerTrs, SolvesConjTransposedTriangularSystem)
{
using Mtx = typename TestFixture::Mtx;
using value_type = typename TestFixture::value_type;
using Solver = typename TestFixture::Solver;
std::shared_ptr<Mtx> b = gko::initialize<Mtx>({1.0, 2.0, 1.0}, this->exec);
auto x = gko::initialize<Mtx>({0.0, 0.0, 0.0}, this->exec);
auto solver = this->lower_trs_factory->generate(this->mtx);
auto conj_transposed_solver =
gko::as<typename Solver::transposed_type>(solver->conj_transpose());

solver->conj_transpose()->apply(b, x);
conj_transposed_solver->apply(b, x);

GKO_ASSERT_MTX_NEAR(x, l({0.0, 0.0, 1.0}), r<value_type>::value);
// Ensure that the other test with syncfree is not the default option
ASSERT_EQ(solver->get_parameters().algorithm,
gko::solver::trisolve_algorithm::sparselib);
ASSERT_EQ(conj_transposed_solver->get_parameters().algorithm,
solver->get_parameters().algorithm);
}


TYPED_TEST(LowerTrs, SolvesTransposedTriangularSystemWithSyncFree)
{
using Mtx = typename TestFixture::Mtx;
using value_type = typename TestFixture::value_type;
using Solver = typename TestFixture::Solver;
std::shared_ptr<Mtx> b = gko::initialize<Mtx>({1.0, 2.0, 1.0}, this->exec);
auto x = gko::initialize<Mtx>({0.0, 0.0, 0.0}, this->exec);
auto solver = this->lower_trs_syncfree_factory->generate(this->mtx);
auto transposed_solver =
gko::as<typename Solver::transposed_type>(solver->transpose());

transposed_solver->apply(b, x);

GKO_ASSERT_MTX_NEAR(x, l({0.0, 0.0, 1.0}), r<value_type>::value);
// Ensure that this test uses syncfree
ASSERT_EQ(solver->get_parameters().algorithm,
gko::solver::trisolve_algorithm::syncfree);
ASSERT_EQ(transposed_solver->get_parameters().algorithm,
solver->get_parameters().algorithm);
}


TYPED_TEST(LowerTrs, SolvesConjTransposedTriangularSystemWithSyncFree)
{
using Mtx = typename TestFixture::Mtx;
using value_type = typename TestFixture::value_type;
using Solver = typename TestFixture::Solver;
std::shared_ptr<Mtx> b = gko::initialize<Mtx>({1.0, 2.0, 1.0}, this->exec);
auto x = gko::initialize<Mtx>({0.0, 0.0, 0.0}, this->exec);
auto solver = this->lower_trs_syncfree_factory->generate(this->mtx);
auto conj_transposed_solver =
gko::as<typename Solver::transposed_type>(solver->conj_transpose());

conj_transposed_solver->apply(b, x);

GKO_ASSERT_MTX_NEAR(x, l({0.0, 0.0, 1.0}), r<value_type>::value);
// Ensure that this test uses syncfree
ASSERT_EQ(solver->get_parameters().algorithm,
gko::solver::trisolve_algorithm::syncfree);
ASSERT_EQ(conj_transposed_solver->get_parameters().algorithm,
solver->get_parameters().algorithm);
}


Expand Down
69 changes: 67 additions & 2 deletions reference/test/solver/upper_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ class UpperTrs : public ::testing::Test {
{0.0, 2.0, 0.0, 4.0, 124.0}},
exec)),
upper_trs_factory(Solver::build().on(exec)),
upper_trs_syncfree_factory(
Solver::build()
.with_algorithm(gko::solver::trisolve_algorithm::syncfree)
.on(exec)),
upper_trs_factory_mrhs(Solver::build().with_num_rhs(2u).on(exec)),
upper_trs_factory_unit(
Solver::build().with_unit_diagonal(true).on(exec))
Expand All @@ -66,6 +70,7 @@ class UpperTrs : public ::testing::Test {
std::shared_ptr<Mtx> mtx_big_upper;
std::shared_ptr<Mtx> mtx_big_general;
std::unique_ptr<typename Solver::Factory> upper_trs_factory;
std::unique_ptr<typename Solver::Factory> upper_trs_syncfree_factory;
std::unique_ptr<typename Solver::Factory> upper_trs_factory_mrhs;
std::unique_ptr<typename Solver::Factory> upper_trs_factory_unit;
};
Expand Down Expand Up @@ -349,27 +354,87 @@ TYPED_TEST(UpperTrs, SolvesTransposedTriangularSystem)
{
using Mtx = typename TestFixture::Mtx;
using value_type = typename TestFixture::value_type;
using Solver = typename TestFixture::Solver;
std::shared_ptr<Mtx> b = gko::initialize<Mtx>({4.0, 2.0, 3.0}, this->exec);
auto x = gko::initialize<Mtx>({0.0, 0.0, 0.0}, this->exec);
auto solver = this->upper_trs_factory->generate(this->mtx);
auto transposed_solver =
gko::as<typename Solver::transposed_type>(solver->transpose());

solver->transpose()->apply(b, x);
transposed_solver->apply(b, x);

GKO_ASSERT_MTX_NEAR(x, l({4.0, -10.0, 19.0}), r<value_type>::value);
// Ensure that the other test with syncfree is not the default option
ASSERT_EQ(solver->get_parameters().algorithm,
gko::solver::trisolve_algorithm::sparselib);
ASSERT_EQ(transposed_solver->get_parameters().algorithm,
solver->get_parameters().algorithm);
}


TYPED_TEST(UpperTrs, SolvesConjTransposedTriangularSystem)
{
using Mtx = typename TestFixture::Mtx;
using value_type = typename TestFixture::value_type;
using Solver = typename TestFixture::Solver;
std::shared_ptr<Mtx> b = gko::initialize<Mtx>({4.0, 2.0, 3.0}, this->exec);
auto x = gko::initialize<Mtx>({0.0, 0.0, 0.0}, this->exec);
auto solver = this->upper_trs_factory->generate(this->mtx);
auto conj_transposed_solver =
gko::as<typename Solver::transposed_type>(solver->conj_transpose());

solver->conj_transpose()->apply(b, x);
conj_transposed_solver->apply(b, x);

GKO_ASSERT_MTX_NEAR(x, l({4.0, -10.0, 19.0}), r<value_type>::value);
// Ensure that the other test with syncfree is not the default option
ASSERT_EQ(solver->get_parameters().algorithm,
gko::solver::trisolve_algorithm::sparselib);
ASSERT_EQ(conj_transposed_solver->get_parameters().algorithm,
solver->get_parameters().algorithm);
}


TYPED_TEST(UpperTrs, SolvesTransposedTriangularSystemWithSyncFree)
{
using Mtx = typename TestFixture::Mtx;
using value_type = typename TestFixture::value_type;
using Solver = typename TestFixture::Solver;
std::shared_ptr<Mtx> b = gko::initialize<Mtx>({4.0, 2.0, 3.0}, this->exec);
auto x = gko::initialize<Mtx>({0.0, 0.0, 0.0}, this->exec);
auto solver = this->upper_trs_syncfree_factory->generate(this->mtx);
auto transposed_solver =
gko::as<typename Solver::transposed_type>(solver->transpose());

transposed_solver->apply(b, x);

GKO_ASSERT_MTX_NEAR(x, l({4.0, -10.0, 19.0}), r<value_type>::value);
// Ensure that this test uses syncfree
ASSERT_EQ(solver->get_parameters().algorithm,
gko::solver::trisolve_algorithm::syncfree);
ASSERT_EQ(transposed_solver->get_parameters().algorithm,
solver->get_parameters().algorithm);
}


TYPED_TEST(UpperTrs, SolvesConjTransposedTriangularSystemWithSyncFree)
{
using Mtx = typename TestFixture::Mtx;
using value_type = typename TestFixture::value_type;
using Solver = typename TestFixture::Solver;
std::shared_ptr<Mtx> b = gko::initialize<Mtx>({4.0, 2.0, 3.0}, this->exec);
auto x = gko::initialize<Mtx>({0.0, 0.0, 0.0}, this->exec);
auto solver = this->upper_trs_syncfree_factory->generate(this->mtx);
auto conj_transposed_solver =
gko::as<typename Solver::transposed_type>(solver->conj_transpose());

conj_transposed_solver->apply(b, x);

GKO_ASSERT_MTX_NEAR(x, l({4.0, -10.0, 19.0}), r<value_type>::value);
// Ensure that this test uses syncfree
ASSERT_EQ(solver->get_parameters().algorithm,
gko::solver::trisolve_algorithm::syncfree);
ASSERT_EQ(conj_transposed_solver->get_parameters().algorithm,
solver->get_parameters().algorithm);
}


Expand Down

0 comments on commit 8c6a7b7

Please sign in to comment.