From 1ed76a7d5056d9505e4f6161828c586a65596d5a Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Tue, 8 Oct 2024 15:40:34 +0200 Subject: [PATCH] add workspace for reduction usage --- core/stop/residual_norm.cpp | 18 ++++++++++++------ include/ginkgo/core/stop/residual_norm.hpp | 2 ++ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/core/stop/residual_norm.cpp b/core/stop/residual_norm.cpp index adf7da3e2e6..4e73cc8d56a 100644 --- a/core/stop/residual_norm.cpp +++ b/core/stop/residual_norm.cpp @@ -98,7 +98,8 @@ ResidualNormBase::ResidualNormBase( system_matrix_{args.system_matrix}, b_{args.b}, one_{gko::initialize({1}, exec)}, - neg_one_{gko::initialize({-1}, exec)} + neg_one_{gko::initialize({-1}, exec)}, + reduction_tmp_{exec} { switch (baseline_) { case mode::initial_resnorm: { @@ -113,7 +114,8 @@ ResidualNormBase::ResidualNormBase( args.system_matrix->apply(neg_one_, args.x, one_, b_clone); norm_dispatch( [&](auto dense_r) { - dense_r->compute_norm2(this->starting_tau_); + dense_r->compute_norm2(this->starting_tau_, + reduction_tmp_); }, b_clone.get()); } @@ -122,7 +124,7 @@ ResidualNormBase::ResidualNormBase( exec, dim<2>{1, args.initial_residual->get_size()[1]}); norm_dispatch( [&](auto dense_r) { - dense_r->compute_norm2(this->starting_tau_); + dense_r->compute_norm2(this->starting_tau_, reduction_tmp_); }, args.initial_residual); } @@ -135,7 +137,9 @@ ResidualNormBase::ResidualNormBase( this->starting_tau_ = NormVector::create(exec, dim<2>{1, args.b->get_size()[1]}); norm_dispatch( - [&](auto dense_r) { dense_r->compute_norm2(this->starting_tau_); }, + [&](auto dense_r) { + dense_r->compute_norm2(this->starting_tau_, reduction_tmp_); + }, args.b.get()); break; } @@ -169,7 +173,9 @@ bool ResidualNormBase::check_impl( return false; } else if (updater.residual_ != nullptr) { norm_dispatch( - [&](auto dense_r) { dense_r->compute_norm2(u_dense_tau_); }, + [&](auto dense_r) { + dense_r->compute_norm2(u_dense_tau_, reduction_tmp_); + }, updater.residual_); dense_tau = u_dense_tau_.get(); } else if (updater.solution_ != nullptr && system_matrix_ != nullptr && @@ -179,7 +185,7 @@ bool ResidualNormBase::check_impl( [&](auto dense_b, auto dense_x) { auto dense_r = dense_b->clone(); system_matrix_->apply(neg_one_, dense_x, one_, dense_r); - dense_r->compute_norm2(u_dense_tau_); + dense_r->compute_norm2(u_dense_tau_, reduction_tmp_); }, b_.get(), updater.solution_); dense_tau = u_dense_tau_.get(); diff --git a/include/ginkgo/core/stop/residual_norm.hpp b/include/ginkgo/core/stop/residual_norm.hpp index 6ee3c843e6a..7ee020207d4 100644 --- a/include/ginkgo/core/stop/residual_norm.hpp +++ b/include/ginkgo/core/stop/residual_norm.hpp @@ -82,6 +82,8 @@ class ResidualNormBase /* one/neg_one for residual computation */ std::shared_ptr one_{}; std::shared_ptr neg_one_{}; + // workspace for reduction + mutable gko::array reduction_tmp_; };