Skip to content

Commit

Permalink
Support complex numbers in parallel GMRES
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsvu committed Jul 19, 2024
1 parent ced43eb commit 96200bc
Show file tree
Hide file tree
Showing 14 changed files with 213 additions and 79 deletions.
27 changes: 19 additions & 8 deletions src/ParallelAlgorithms/LinearSolver/Gmres/ElementActions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "Utilities/PrettyType.hpp"
#include "Utilities/Requires.hpp"
#include "Utilities/TMPL.hpp"
#include "Utilities/TypeTraits/GetFundamentalType.hpp"

/// \cond
namespace tuples {
Expand Down Expand Up @@ -300,6 +301,8 @@ struct PerformStep {
db::add_tag_prefix<LinearSolver::Tags::Operand, fields_tag>;
using preconditioned_operand_tag =
db::add_tag_prefix<LinearSolver::Tags::Preconditioned, operand_tag>;
using ValueType =
tt::get_complex_or_fundamental_type_t<typename fields_tag::type>;

public:
using const_global_cache_tags =
Expand Down Expand Up @@ -374,7 +377,7 @@ struct PerformStep {
Parallel::ReductionData<
Parallel::ReductionDatum<size_t, funcl::AssertEqual<>>,
Parallel::ReductionDatum<size_t, funcl::AssertEqual<>>,
Parallel::ReductionDatum<double, funcl::Plus<>>>{
Parallel::ReductionDatum<ValueType, funcl::Plus<>>>{
get<Convergence::Tags::IterationId<OptionsGroup>>(box),
get<orthogonalization_iteration_id_tag>(box),
inner_product(get<basis_history_tag>(box)[0],
Expand All @@ -400,9 +403,12 @@ struct OrthogonalizeOperand {
Convergence::Tags::IterationId<OptionsGroup>>;
using basis_history_tag =
LinearSolver::Tags::KrylovSubspaceBasis<operand_tag>;
using ValueType =
tt::get_complex_or_fundamental_type_t<typename fields_tag::type>;

public:
using inbox_tags = tmpl::list<Tags::Orthogonalization<OptionsGroup>>;
using inbox_tags =
tmpl::list<Tags::Orthogonalization<OptionsGroup, ValueType>>;

template <typename DbTagsList, typename... InboxTags, typename Metavariables,
typename ArrayIndex, typename ActionList,
Expand All @@ -414,12 +420,13 @@ struct OrthogonalizeOperand {
const ParallelComponent* const /*meta*/) {
const size_t iteration_id =
db::get<Convergence::Tags::IterationId<OptionsGroup>>(box);
auto& inbox = get<Tags::Orthogonalization<OptionsGroup>>(inboxes);
auto& inbox =
get<Tags::Orthogonalization<OptionsGroup, ValueType>>(inboxes);
if (inbox.find(iteration_id) == inbox.end()) {
return {Parallel::AlgorithmExecution::Retry, std::nullopt};
}

const double orthogonalization =
const ValueType orthogonalization =
std::move(inbox.extract(iteration_id).mapped());

db::mutate<operand_tag, orthogonalization_iteration_id_tag>(
Expand All @@ -437,7 +444,7 @@ struct OrthogonalizeOperand {
get<orthogonalization_iteration_id_tag>(box);
const bool orthogonalization_complete =
next_orthogonalization_iteration_id == iteration_id;
const double local_orthogonalization =
const ValueType local_orthogonalization =
inner_product(orthogonalization_complete
? get<operand_tag>(box)
: gsl::at(get<basis_history_tag>(box),
Expand All @@ -451,7 +458,7 @@ struct OrthogonalizeOperand {
Parallel::ReductionData<
Parallel::ReductionDatum<size_t, funcl::AssertEqual<>>,
Parallel::ReductionDatum<size_t, funcl::AssertEqual<>>,
Parallel::ReductionDatum<double, funcl::Plus<>>>{
Parallel::ReductionDatum<ValueType, funcl::Plus<>>>{
iteration_id, next_orthogonalization_iteration_id,
local_orthogonalization},
Parallel::get_parallel_component<ParallelComponent>(cache)[array_index],
Expand Down Expand Up @@ -483,11 +490,14 @@ struct NormalizeOperandAndUpdateField {
using preconditioned_basis_history_tag =
LinearSolver::Tags::KrylovSubspaceBasis<std::conditional_t<
Preconditioned, preconditioned_operand_tag, operand_tag>>;
using ValueType =
tt::get_complex_or_fundamental_type_t<typename fields_tag::type>;

public:
using const_global_cache_tags =
tmpl::list<logging::Tags::Verbosity<OptionsGroup>>;
using inbox_tags = tmpl::list<Tags::FinalOrthogonalization<OptionsGroup>>;
using inbox_tags =
tmpl::list<Tags::FinalOrthogonalization<OptionsGroup, ValueType>>;

template <typename DbTagsList, typename... InboxTags, typename Metavariables,
typename ArrayIndex, typename ActionList,
Expand All @@ -499,7 +509,8 @@ struct NormalizeOperandAndUpdateField {
const ParallelComponent* const /*meta*/) {
const size_t iteration_id =
db::get<Convergence::Tags::IterationId<OptionsGroup>>(box);
auto& inbox = get<Tags::FinalOrthogonalization<OptionsGroup>>(inboxes);
auto& inbox =
get<Tags::FinalOrthogonalization<OptionsGroup, ValueType>>(inboxes);
if (inbox.find(iteration_id) == inbox.end()) {
return {Parallel::AlgorithmExecution::Retry, std::nullopt};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "IO/Logging/Tags.hpp"
#include "IO/Logging/Verbosity.hpp"
#include "NumericalAlgorithms/Convergence/Tags.hpp"
#include "NumericalAlgorithms/LinearSolver/InnerProduct.hpp"
#include "Parallel/GlobalCache.hpp"
#include "Parallel/Invoke.hpp"
#include "Parallel/Printf/Printf.hpp"
Expand Down Expand Up @@ -106,6 +107,8 @@ struct StoreOrthogonalization {
::Tags::Previous<residual_magnitude_tag>;
using orthogonalization_history_tag =
LinearSolver::Tags::OrthogonalizationHistory<fields_tag>;
using ValueType =
tt::get_complex_or_fundamental_type_t<typename fields_tag::type>;

public:
template <typename ParallelComponent, typename DbTagsList,
Expand All @@ -116,7 +119,7 @@ struct StoreOrthogonalization {
const ArrayIndex& /*array_index*/,
const size_t iteration_id,
const size_t orthogonalization_iteration_id,
const double orthogonalization) {
const ValueType orthogonalization) {
if (UNLIKELY(orthogonalization_iteration_id == 0)) {
// Append a row and a column to the orthogonalization history. Zero the
// entries that won't be set during the orthogonalization procedure below.
Expand All @@ -143,19 +146,21 @@ struct StoreOrthogonalization {
},
make_not_null(&box));

Parallel::receive_data<Tags::Orthogonalization<OptionsGroup>>(
Parallel::receive_data<Tags::Orthogonalization<OptionsGroup, ValueType>>(
Parallel::get_parallel_component<BroadcastTarget>(cache),
iteration_id, orthogonalization);
return;
}

// At this point, the orthogonalization procedure is complete.
ASSERT(equal_within_roundoff(imag(orthogonalization), 0.0),
"Normalization is not real: " << orthogonalization);
const double normalization = sqrt(real(orthogonalization));
db::mutate<orthogonalization_history_tag>(
[orthogonalization, iteration_id,
[normalization, iteration_id,
orthogonalization_iteration_id](const auto orthogonalization_history) {
(*orthogonalization_history)(orthogonalization_iteration_id,
iteration_id - 1) =
sqrt(orthogonalization);
iteration_id - 1) = normalization;
},
make_not_null(&box));

Expand All @@ -164,18 +169,19 @@ struct StoreOrthogonalization {
const auto& orthogonalization_history =
get<orthogonalization_history_tag>(box);
const auto num_rows = orthogonalization_iteration_id + 1;
blaze::DynamicMatrix<double> qr_Q;
blaze::DynamicMatrix<double> qr_R;
blaze::DynamicMatrix<ValueType> qr_Q;
blaze::DynamicMatrix<ValueType> qr_R;
blaze::qr(orthogonalization_history, qr_Q, qr_R);
// Compute the residual vector from the QR decomposition
blaze::DynamicVector<double> beta(num_rows, 0.);
const double initial_residual_magnitude =
get<initial_residual_magnitude_tag>(box);
beta[0] = initial_residual_magnitude;
blaze::DynamicVector<double> minres =
blaze::inv(qr_R) * blaze::trans(qr_Q) * beta;
const double residual_magnitude =
blaze::length(beta - orthogonalization_history * minres);
blaze::DynamicVector<ValueType> minres =
blaze::inv(qr_R) * blaze::ctrans(qr_Q) * beta;
blaze::DynamicVector<ValueType> res =
beta - orthogonalization_history * minres;
const double residual_magnitude = sqrt(magnitude_square(res));

// At this point, the iteration is complete. We proceed with observing,
// logging and checking convergence before broadcasting back to the
Expand Down Expand Up @@ -233,9 +239,10 @@ struct StoreOrthogonalization {
}
}

Parallel::receive_data<Tags::FinalOrthogonalization<OptionsGroup>>(
Parallel::receive_data<
Tags::FinalOrthogonalization<OptionsGroup, ValueType>>(
Parallel::get_parallel_component<BroadcastTarget>(cache), iteration_id,
std::make_tuple(sqrt(orthogonalization), std::move(minres),
std::make_tuple(normalization, std::move(minres),
// NOLINTNEXTLINE(performance-move-const-arg)
std::move(has_converged)));
}
Expand Down
15 changes: 8 additions & 7 deletions src/ParallelAlgorithms/LinearSolver/Gmres/Tags/InboxTags.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,20 @@ struct InitialOrthogonalization
std::map<temporal_id, std::tuple<double, Convergence::HasConverged>>;
};

template <typename OptionsGroup>
struct Orthogonalization
: Parallel::InboxInserters::Value<Orthogonalization<OptionsGroup>> {
template <typename OptionsGroup, typename ValueType>
struct Orthogonalization : Parallel::InboxInserters::Value<
Orthogonalization<OptionsGroup, ValueType>> {
using temporal_id = size_t;
using type = std::map<temporal_id, double>;
using type = std::map<temporal_id, ValueType>;
};

template <typename OptionsGroup>
template <typename OptionsGroup, typename ValueType>
struct FinalOrthogonalization
: Parallel::InboxInserters::Value<FinalOrthogonalization<OptionsGroup>> {
: Parallel::InboxInserters::Value<
FinalOrthogonalization<OptionsGroup, ValueType>> {
using temporal_id = size_t;
using type =
std::map<temporal_id, std::tuple<double, blaze::DynamicVector<double>,
std::map<temporal_id, std::tuple<double, blaze::DynamicVector<ValueType>,
Convergence::HasConverged>>;
};

Expand Down
4 changes: 3 additions & 1 deletion src/ParallelAlgorithms/LinearSolver/Tags.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "DataStructures/DataBox/TagName.hpp"
#include "DataStructures/DynamicMatrix.hpp"
#include "Utilities/Gsl.hpp"
#include "Utilities/TypeTraits/GetFundamentalType.hpp"

/*!
* \ingroup LinearSolverGroup
Expand Down Expand Up @@ -137,11 +138,12 @@ struct Orthogonalization : db::PrefixTag, db::SimpleTag {
*/
template <typename Tag>
struct OrthogonalizationHistory : db::PrefixTag, db::SimpleTag {
using ValueType = tt::get_complex_or_fundamental_type_t<typename Tag::type>;
static std::string name() {
// Add "Linear" prefix to abbreviate the namespace for uniqueness
return "LinearOrthogonalizationHistory(" + db::tag_name<Tag>() + ")";
}
using type = blaze::DynamicMatrix<double>;
using type = blaze::DynamicMatrix<ValueType>;
using tag = Tag;
};

Expand Down
Loading

0 comments on commit 96200bc

Please sign in to comment.