Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Splines recover transpose #271

Merged
merged 8 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/ddc/kernels/splines/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ class Matrix
return bx;
}

virtual ddc::DSpan2D solve_multiple_transpose_inplace(ddc::DSpan2D const bx) const
{
assert(int(bx.extent(1)) == m_n);
int const info = solve_inplace_method(bx.data_handle(), 'T', bx.extent(0));

if (info < 0) {
std::cerr << -info << "-th argument had an illegal value" << std::endl;
// TODO: Add LOG_FATAL_ERROR
}
return bx;
}

template <class... Args>
Kokkos::View<double**, Args...> solve_batch_inplace(
Kokkos::View<double**, Args...> const bx) const
Expand Down
21 changes: 15 additions & 6 deletions include/ddc/kernels/splines/matrix_sparse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class Matrix_Sparse : public Matrix
std::shared_ptr<matrix_sparse_type> m_matrix_sparse;

std::shared_ptr<gko::solver::Bicgstab<double>> m_solver;
std::shared_ptr<gko::LinOp> m_solver_tr;

int m_cols_per_chunk; // Maximum number of columns of B to be passed to a Ginkgo solver

Expand Down Expand Up @@ -144,8 +145,8 @@ class Matrix_Sparse : public Matrix
m_matrix_dense.reset();
matrix_data.remove_zeros();
m_matrix_sparse->read(matrix_data);

std::shared_ptr const gko_exec = m_matrix_sparse->get_executor();

// Create the solver factory
std::shared_ptr const residual_criterion
= gko::stop::ResidualNorm<double>::build().with_reduction_factor(1e-19).on(
Expand All @@ -166,17 +167,14 @@ class Matrix_Sparse : public Matrix
.on(gko_exec);

m_solver = solver_factory->generate(m_matrix_sparse);
m_solver_tr = m_solver->transpose();
gko_exec->synchronize();

return 0;
}

virtual int solve_inplace_method(double* b, char transpose, int n_equations) const override
{
if (transpose != 'N') {
throw std::domain_error("transpose");
}

std::shared_ptr const gko_exec = m_solver->get_executor();

int const main_chunk_size = std::min(m_cols_per_chunk, n_equations);
Expand All @@ -199,7 +197,18 @@ class Matrix_Sparse : public Matrix

Kokkos::deep_copy(x_subview, b_subview);

m_solver->apply(to_gko_dense(gko_exec, b_subview), to_gko_dense(gko_exec, x_subview));
if (transpose == 'N') {
m_solver
->apply(to_gko_dense(gko_exec, b_subview),
to_gko_dense(gko_exec, x_subview));
} else if (transpose == 'T') {
m_solver_tr
->apply(to_gko_dense(gko_exec, b_subview),
to_gko_dense(gko_exec, x_subview));
} else {
throw std::domain_error("transpose option not recognized");
}


Kokkos::deep_copy(b_subview, x_subview);
}
Expand Down
30 changes: 29 additions & 1 deletion tests/splines/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@ void check_inverse(ddc::DSpan2D matrix, ddc::DSpan2D inv)
}
}
}

void check_inverse_transpose(ddc::DSpan2D matrix, ddc::DSpan2D inv)
{
double TOL = 1e-10;
std::size_t N = matrix.extent(0);

for (std::size_t i(0); i < N; ++i) {
for (std::size_t j(0); j < N; ++j) {
double id_val = 0.0;
for (std::size_t k(0); k < N; ++k) {
id_val += matrix(i, k) * inv(k, j);
}
EXPECT_NEAR(id_val, static_cast<double>(i == j), TOL);
}
}
}
} // namespace

class MatrixSizesFixture : public testing::TestWithParam<std::tuple<std::size_t, std::size_t>>
Expand Down Expand Up @@ -85,16 +101,28 @@ TEST_P(MatrixSizesFixture, Sparse)
}
// copy_matrix(val, matrix); // copy_matrix is not available for sparse matrix because of a limitation of Ginkgo API (get_element is not implemented). The workaround is to fill val directly in the loop

matrix->factorize();

Kokkos::DualView<double*> inv_ptr("inv_ptr", N * N);
ddc::DSpan2D inv(inv_ptr.h_view.data(), N, N);
fill_identity(inv);
inv_ptr.modify_host();
inv_ptr.sync_device();
matrix->factorize();
matrix->solve_multiple_inplace(ddc::DSpan2D(inv_ptr.d_view.data(), N, N));
inv_ptr.modify_device();
inv_ptr.sync_host();

Kokkos::DualView<double*> inv_tr_ptr("inv_tr_ptr", N * N);
ddc::DSpan2D inv_tr(inv_tr_ptr.h_view.data(), N, N);
fill_identity(inv_tr);
inv_tr_ptr.modify_host();
inv_tr_ptr.sync_device();
matrix->solve_multiple_transpose_inplace(ddc::DSpan2D(inv_tr_ptr.d_view.data(), N, N));
inv_tr_ptr.modify_device();
inv_tr_ptr.sync_host();

check_inverse(val, inv);
check_inverse_transpose(val, inv_tr);
}

INSTANTIATE_TEST_SUITE_P(
Expand Down
Loading