From 7a25596814a21869087c70f759cd7e21087cd492 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Fri, 28 Jun 2024 18:35:22 +0200 Subject: [PATCH] add config for chebyshev --- core/config/config_helper.hpp | 8 +++--- core/config/registry.cpp | 1 + core/config/solver_config.cpp | 2 ++ core/solver/chebyshev.cpp | 23 +++++++++++++++ core/test/config/solver.cpp | 36 ++++++++++++++++++++++++ include/ginkgo/core/solver/chebyshev.hpp | 22 ++++++++++++++- 6 files changed, 87 insertions(+), 5 deletions(-) diff --git a/core/config/config_helper.hpp b/core/config/config_helper.hpp index 555bb75c2a8..c5c07b78e54 100644 --- a/core/config/config_helper.hpp +++ b/core/config/config_helper.hpp @@ -24,10 +24,9 @@ namespace gko { namespace config { -#define GKO_INVALID_CONFIG_VALUE(_entry, _value) \ - GKO_INVALID_STATE(std::string("The value >" + _value + \ - "< is invalid for the entry >" + _entry + \ - "<")) +#define GKO_INVALID_CONFIG_VALUE(_entry, _value) \ + GKO_INVALID_STATE(std::string("The value >") + _value + \ + "< is invalid for the entry >" + _entry + "<") #define GKO_MISSING_CONFIG_ENTRY(_entry) \ @@ -52,6 +51,7 @@ enum class LinOpFactoryType : int { Direct, LowerTrs, UpperTrs, + Chebyshev, Factorization_Ic, Factorization_Ilu, Cholesky, diff --git a/core/config/registry.cpp b/core/config/registry.cpp index 188c34b35dd..2d477542a40 100644 --- a/core/config/registry.cpp +++ b/core/config/registry.cpp @@ -32,6 +32,7 @@ configuration_map generate_config_map() {"solver::Direct", parse}, {"solver::LowerTrs", parse}, {"solver::UpperTrs", parse}, + {"solver::Chebyshev", parse}, {"factorization::Ic", parse}, {"factorization::Ilu", parse}, {"factorization::Cholesky", parse}, diff --git a/core/config/solver_config.cpp b/core/config/solver_config.cpp index b35a639b8e7..aeef0f1356a 100644 --- a/core/config/solver_config.cpp +++ b/core/config/solver_config.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -43,6 +44,7 @@ GKO_PARSE_VALUE_TYPE(CbGmres, gko::solver::CbGmres); GKO_PARSE_VALUE_AND_INDEX_TYPE(Direct, gko::experimental::solver::Direct); GKO_PARSE_VALUE_AND_INDEX_TYPE(LowerTrs, gko::solver::LowerTrs); GKO_PARSE_VALUE_AND_INDEX_TYPE(UpperTrs, gko::solver::UpperTrs); +GKO_PARSE_VALUE_TYPE(Chebyshev, gko::solver::Chebyshev); template <> diff --git a/core/solver/chebyshev.cpp b/core/solver/chebyshev.cpp index 7354083bcc4..bcc887254b5 100644 --- a/core/solver/chebyshev.cpp +++ b/core/solver/chebyshev.cpp @@ -8,6 +8,7 @@ #include #include +#include "core/config/solver_config.hpp" #include "core/distributed/helpers.hpp" #include "core/solver/ir_kernels.hpp" #include "core/solver/solver_base.hpp" @@ -27,6 +28,28 @@ GKO_REGISTER_OPERATION(initialize, ir::initialize); } // anonymous namespace } // namespace chebyshev +template +typename Chebyshev::parameters_type Chebyshev::parse( + const config::pnode& config, const config::registry& context, + const config::type_descriptor& td_for_child) +{ + auto params = solver::Chebyshev::build(); + common_solver_parse(params, config, context, td_for_child); + if (auto& obj = config.get("foci")) { + auto arr = obj.get_array(); + if (arr.size() != 2) { + GKO_INVALID_CONFIG_VALUE("foci", "must contain two elements"); + } + params.with_foci(gko::config::get_value(arr.at(0)), + gko::config::get_value(arr.at(1))); + } + if (auto& obj = config.get("default_initial_guess")) { + params.with_default_initial_guess( + gko::config::get_value(obj)); + } + return params; +} + template Chebyshev::Chebyshev(const Factory* factory, diff --git a/core/test/config/solver.cpp b/core/test/config/solver.cpp index a170ebb1e04..18359d61559 100644 --- a/core/test/config/solver.cpp +++ b/core/test/config/solver.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -439,6 +440,41 @@ struct UpperTrs : TrsHelper { }; +struct Chebyshev : SolverConfigTest, + gko::solver::Chebyshev> { + static pnode::map_type setup_base() + { + return {{"type", pnode{"solver::Chebyshev"}}}; + } + + template + static void set(pnode::map_type& config_map, ParamType& param, registry reg, + std::shared_ptr exec) + { + solver_config_test::template set(config_map, param, reg, + exec); + using fvt = typename decltype(param.foci)::first_type; + config_map["foci"] = + pnode{pnode::array_type{pnode{fvt{0.5}}, pnode{fvt{1.5}}}}; + param.with_foci(fvt{0.5}, fvt{1.5}); + config_map["default_initial_guess"] = pnode{"zero"}; + param.with_default_initial_guess(gko::solver::initial_guess_mode::zero); + } + + template + static void validate(gko::LinOpFactory* result, AnswerType* answer) + { + auto res_param = gko::as(result)->get_parameters(); + auto ans_param = answer->get_parameters(); + + solver_config_test::template validate(result, answer); + ASSERT_EQ(res_param.foci, ans_param.foci); + ASSERT_EQ(res_param.default_initial_guess, + ans_param.default_initial_guess); + } +}; + + template class Solver : public ::testing::Test { protected: diff --git a/include/ginkgo/core/solver/chebyshev.hpp b/include/ginkgo/core/solver/chebyshev.hpp index 7be54b53088..db778f21e72 100644 --- a/include/ginkgo/core/solver/chebyshev.hpp +++ b/include/ginkgo/core/solver/chebyshev.hpp @@ -11,6 +11,8 @@ #include #include #include +#include +#include #include #include #include @@ -49,7 +51,7 @@ namespace solver { * @ingroup LinOp */ template -class Chebyshev +class Chebyshev final : public EnableLinOp>, public EnablePreconditionedIterativeSolver>, @@ -133,6 +135,24 @@ class Chebyshev GKO_ENABLE_LIN_OP_FACTORY(Chebyshev, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); + /** + * Create the parameters from the property_tree. + * Because this is directly tied to the specific type, the value/index type + * settings within config are ignored and type_descriptor is only used + * for children configs. + * + * @param config the property tree for setting + * @param context the registry + * @param td_for_child the type descriptor for children configs. The + * default uses the value type of this class. + * + * @return parameters + */ + static parameters_type parse(const config::pnode& config, + const config::registry& context, + const config::type_descriptor& td_for_child = + config::make_type_descriptor()); + protected: void apply_impl(const LinOp* b, LinOp* x) const override;