Skip to content

Commit

Permalink
Merge #1687 Add workspace for reduction in residual norm
Browse files Browse the repository at this point in the history
Add a workspace for the reduction in the residual norm computation, which avoids repeated allocation/deallocation on repeated calls.

Related PR: #1687
  • Loading branch information
yhmtsai authored Oct 9, 2024
2 parents e518aee + 1ed76a7 commit 2dbe389
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
18 changes: 12 additions & 6 deletions core/stop/residual_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ ResidualNormBase<ValueType>::ResidualNormBase(
system_matrix_{args.system_matrix},
b_{args.b},
one_{gko::initialize<Vector>({1}, exec)},
neg_one_{gko::initialize<Vector>({-1}, exec)}
neg_one_{gko::initialize<Vector>({-1}, exec)},
reduction_tmp_{exec}
{
switch (baseline_) {
case mode::initial_resnorm: {
Expand All @@ -113,7 +114,8 @@ ResidualNormBase<ValueType>::ResidualNormBase(
args.system_matrix->apply(neg_one_, args.x, one_, b_clone);
norm_dispatch<ValueType>(
[&](auto dense_r) {
dense_r->compute_norm2(this->starting_tau_);
dense_r->compute_norm2(this->starting_tau_,
reduction_tmp_);
},
b_clone.get());
}
Expand All @@ -122,7 +124,7 @@ ResidualNormBase<ValueType>::ResidualNormBase(
exec, dim<2>{1, args.initial_residual->get_size()[1]});
norm_dispatch<ValueType>(
[&](auto dense_r) {
dense_r->compute_norm2(this->starting_tau_);
dense_r->compute_norm2(this->starting_tau_, reduction_tmp_);
},
args.initial_residual);
}
Expand All @@ -135,7 +137,9 @@ ResidualNormBase<ValueType>::ResidualNormBase(
this->starting_tau_ =
NormVector::create(exec, dim<2>{1, args.b->get_size()[1]});
norm_dispatch<ValueType>(
[&](auto dense_r) { dense_r->compute_norm2(this->starting_tau_); },
[&](auto dense_r) {
dense_r->compute_norm2(this->starting_tau_, reduction_tmp_);
},
args.b.get());
break;
}
Expand Down Expand Up @@ -169,7 +173,9 @@ bool ResidualNormBase<ValueType>::check_impl(
return false;
} else if (updater.residual_ != nullptr) {
norm_dispatch<ValueType>(
[&](auto dense_r) { dense_r->compute_norm2(u_dense_tau_); },
[&](auto dense_r) {
dense_r->compute_norm2(u_dense_tau_, reduction_tmp_);
},
updater.residual_);
dense_tau = u_dense_tau_.get();
} else if (updater.solution_ != nullptr && system_matrix_ != nullptr &&
Expand All @@ -179,7 +185,7 @@ bool ResidualNormBase<ValueType>::check_impl(
[&](auto dense_b, auto dense_x) {
auto dense_r = dense_b->clone();
system_matrix_->apply(neg_one_, dense_x, one_, dense_r);
dense_r->compute_norm2(u_dense_tau_);
dense_r->compute_norm2(u_dense_tau_, reduction_tmp_);
},
b_.get(), updater.solution_);
dense_tau = u_dense_tau_.get();
Expand Down
2 changes: 2 additions & 0 deletions include/ginkgo/core/stop/residual_norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class ResidualNormBase
/* one/neg_one for residual computation */
std::shared_ptr<const Vector> one_{};
std::shared_ptr<const Vector> neg_one_{};
// workspace for reduction
mutable gko::array<char> reduction_tmp_;
};


Expand Down

0 comments on commit 2dbe389

Please sign in to comment.