From 47f7bf0f191ceadaf3072342fafd36a746c72739 Mon Sep 17 00:00:00 2001 From: Peter Chang Date: Fri, 10 Nov 2023 19:43:35 -0500 Subject: [PATCH] refactor: pass model version in save_load --- vowpalwabbit/core/include/vw/core/learner.h | 11 ++++++++++- .../core/include/vw/core/reductions/active.h | 5 +---- .../core/reductions/cb/cb_explore_adf_common.h | 6 +++--- vowpalwabbit/core/src/learner.cc | 17 +++++++++++++++-- vowpalwabbit/core/src/parse_args.cc | 4 ++-- vowpalwabbit/core/src/parse_regressor.cc | 4 ++-- vowpalwabbit/core/src/reductions/active.cc | 10 +++++----- vowpalwabbit/core/src/reductions/automl.cc | 2 +- .../src/reductions/baseline_challenger_cb.cc | 2 +- vowpalwabbit/core/src/reductions/bfgs.cc | 2 +- vowpalwabbit/core/src/reductions/boosting.cc | 8 ++++---- vowpalwabbit/core/src/reductions/cb/cb_adf.cc | 5 ++--- .../core/src/reductions/cb/cb_explore.cc | 6 ++---- .../src/reductions/cb/cb_explore_adf_cover.cc | 14 ++++++-------- .../src/reductions/cb/cb_explore_adf_first.cc | 15 +++++++-------- .../cb/cb_explore_adf_graph_feedback.cc | 4 ++-- .../src/reductions/cb/cb_explore_adf_regcb.cc | 16 ++++++---------- .../reductions/cb/cb_explore_adf_squarecb.cc | 15 ++++++--------- .../reductions/cb/cb_explore_adf_synthcover.cc | 15 ++++++--------- vowpalwabbit/core/src/reductions/cbzo.cc | 2 +- .../reductions/conditional_contextual_bandit.cc | 6 ++---- .../core/src/reductions/eigen_memory_tree.cc | 2 +- .../core/src/reductions/epsilon_decay.cc | 2 +- vowpalwabbit/core/src/reductions/freegrad.cc | 2 +- vowpalwabbit/core/src/reductions/ftrl.cc | 2 +- vowpalwabbit/core/src/reductions/gd.cc | 2 +- vowpalwabbit/core/src/reductions/gd_mf.cc | 2 +- .../core/src/reductions/interaction_ground.cc | 2 +- vowpalwabbit/core/src/reductions/kernel_svm.cc | 6 +++--- vowpalwabbit/core/src/reductions/lda_core.cc | 4 ++-- vowpalwabbit/core/src/reductions/log_multi.cc | 2 +- vowpalwabbit/core/src/reductions/marginal.cc | 2 +- vowpalwabbit/core/src/reductions/memory_tree.cc | 2 +- vowpalwabbit/core/src/reductions/mwt.cc | 2 +- vowpalwabbit/core/src/reductions/oja_newton.cc | 2 +- vowpalwabbit/core/src/reductions/plt.cc | 9 +++------ vowpalwabbit/core/src/reductions/recall_tree.cc | 2 +- .../core/src/reductions/stagewise_poly.cc | 2 +- vowpalwabbit/core/src/reductions/svrg.cc | 2 +- 39 files changed, 108 insertions(+), 110 deletions(-) diff --git a/vowpalwabbit/core/include/vw/core/learner.h b/vowpalwabbit/core/include/vw/core/learner.h index d9ea5a8b92a..76e6d96247f 100644 --- a/vowpalwabbit/core/include/vw/core/learner.h +++ b/vowpalwabbit/core/include/vw/core/learner.h @@ -70,6 +70,7 @@ using multipredict_func = using sensitivity_func = std::function; using save_load_func = std::function; +using save_load_ver_func = std::function; using pre_save_load_func = std::function; using save_metric_func = std::function; @@ -182,7 +183,7 @@ class learner final : public std::enable_shared_from_this float sensitivity(example& ec, size_t i = 0); // Called anytime saving or loading needs to happen. Autorecursive. - void save_load(io_buf& io, const bool read, const bool text); + void save_load(io_buf& io, const bool read, const bool text, const VW::version_struct& model_version); // Called to edit the command-line from a learner. Autorecursive void pre_save_load(VW::workspace& all); @@ -287,6 +288,7 @@ class learner final : public std::enable_shared_from_this details::cleanup_example_func _cleanup_example_f; details::save_load_func _save_load_f; + details::save_load_ver_func _save_load_ver_f; details::void_func _end_pass_f; details::void_func _end_examples_f; details::pre_save_load_func _pre_save_load_f; @@ -417,12 +419,19 @@ class common_learner_builder learner_ptr->learn_returns_prediction = learn_returns_prediction; ) + // TODO: deprecate? LEARNER_BUILDER_DEFINE(set_save_load(void (*fn_ptr)(DataT&, io_buf&, bool, bool)), assert(fn_ptr != nullptr); DataT* data = this->learner_data.get(); this->learner_ptr->_save_load_f = [fn_ptr, data](io_buf& buf, bool read, bool text) { fn_ptr(*data, buf, read, text); }; ) + LEARNER_BUILDER_DEFINE(set_save_load(void (*fn_ptr)(DataT&, io_buf&, bool, bool, const VW::version_struct&)), + assert(fn_ptr != nullptr); + DataT* data = this->learner_data.get(); + this->learner_ptr->_save_load_ver_f = [fn_ptr, data](io_buf& buf, bool read, bool text, const VW::version_struct& ver) + { fn_ptr(*data, buf, read, text, ver); }; + ) LEARNER_BUILDER_DEFINE(set_finish(void (*fn_ptr)(DataT&)), assert(fn_ptr != nullptr); diff --git a/vowpalwabbit/core/include/vw/core/reductions/active.h b/vowpalwabbit/core/include/vw/core/reductions/active.h index 0bff31db2d0..8974237058c 100644 --- a/vowpalwabbit/core/include/vw/core/reductions/active.h +++ b/vowpalwabbit/core/include/vw/core/reductions/active.h @@ -16,12 +16,10 @@ namespace reductions class active { public: - active(float active_c0, std::shared_ptr shared_data, std::shared_ptr random_state, - VW::version_struct model_version) + active(float active_c0, std::shared_ptr shared_data, std::shared_ptr random_state) : active_c0(active_c0) , _shared_data(shared_data) , _random_state(std::move(random_state)) - , _model_version{std::move(model_version)} { } @@ -31,7 +29,6 @@ class active float _min_seen_label = 0.f; float _max_seen_label = 1.f; - VW::version_struct _model_version; }; std::shared_ptr active_setup(VW::setup_base_i& stack_builder); diff --git a/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_common.h b/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_common.h index d07cb0173c7..f4863d0c834 100644 --- a/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_common.h +++ b/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_common.h @@ -93,7 +93,7 @@ class cb_explore_adf_base if (with_metrics) { _metrics = VW::make_unique(); } } - static void save_load(cb_explore_adf_base& data, io_buf& io, bool read, bool text); + static void save_load(cb_explore_adf_base& data, io_buf& io, bool read, bool text, const VW::version_struct& ver); static void persist_metrics(cb_explore_adf_base& data, metric_sink& metrics); static void predict(cb_explore_adf_base& data, VW::LEARNER::learner& base, multi_ex& examples); static void learn(cb_explore_adf_base& data, VW::LEARNER::learner& base, multi_ex& examples); @@ -302,9 +302,9 @@ void cb_explore_adf_base::_print_update( template inline void cb_explore_adf_base::save_load( - cb_explore_adf_base& data, io_buf& io, bool read, bool text) + cb_explore_adf_base& data, io_buf& io, bool read, bool text, const VW::version_struct& ver) { - data.explore.save_load(io, read, text); + data.explore.save_load(io, read, text, ver); } template diff --git a/vowpalwabbit/core/src/learner.cc b/vowpalwabbit/core/src/learner.cc index c124abb01ab..38300a5ed35 100644 --- a/vowpalwabbit/core/src/learner.cc +++ b/vowpalwabbit/core/src/learner.cc @@ -442,7 +442,7 @@ float learner::sensitivity(example& ec, size_t i) return ret; } -void learner::save_load(io_buf& io, const bool read, const bool text) +void learner::save_load(io_buf& io, const bool read, const bool text, const VW::version_struct& model_version) { if (_save_load_f) { @@ -457,7 +457,20 @@ void learner::save_load(io_buf& io, const bool read, const bool text) throw VW::save_load_model_exception(vwex.filename(), vwex.line_number(), better_msg.str()); } } - if (_base_learner) { _base_learner->save_load(io, read, text); } + else if (_save_load_ver_f) + { + try + { + _save_load_ver_f(io, read, text, model_version); + } + catch (VW::vw_exception& vwex) + { + std::stringstream better_msg; + better_msg << "model " << std::string(read ? "load" : "save") << " failed. Error Details: " << vwex.what(); + throw VW::save_load_model_exception(vwex.filename(), vwex.line_number(), better_msg.str()); + } + } + if (_base_learner) { _base_learner->save_load(io, read, text, model_version); } } void learner::pre_save_load(VW::workspace& all) diff --git a/vowpalwabbit/core/src/parse_args.cc b/vowpalwabbit/core/src/parse_args.cc index 3d33bdecd4b..b28cc9dd3ae 100644 --- a/vowpalwabbit/core/src/parse_args.cc +++ b/vowpalwabbit/core/src/parse_args.cc @@ -1362,7 +1362,7 @@ void load_input_model(VW::workspace& all, VW::io_buf& io_temp) all.feature_mask == all.initial_weights_config.initial_regressors[0]) { // load rest of regressor - all.l->save_load(io_temp, true, false); + all.l->save_load(io_temp, true, false, all.runtime_state.model_file_ver); io_temp.close_file(); VW::details::parse_mask_regressor_args(all, all.feature_mask, all.initial_weights_config.initial_regressors); @@ -1372,7 +1372,7 @@ void load_input_model(VW::workspace& all, VW::io_buf& io_temp) VW::details::parse_mask_regressor_args(all, all.feature_mask, all.initial_weights_config.initial_regressors); // load rest of regressor - all.l->save_load(io_temp, true, false); + all.l->save_load(io_temp, true, false, all.runtime_state.model_file_ver); io_temp.close_file(); } } diff --git a/vowpalwabbit/core/src/parse_regressor.cc b/vowpalwabbit/core/src/parse_regressor.cc index 6465389e5b6..c94e3c90c12 100644 --- a/vowpalwabbit/core/src/parse_regressor.cc +++ b/vowpalwabbit/core/src/parse_regressor.cc @@ -489,7 +489,7 @@ void VW::details::dump_regressor(VW::workspace& all, VW::io_buf& buf, bool as_te std::string unused; if (all.l != nullptr) { all.l->pre_save_load(all); } VW::details::save_load_header(all, buf, false, as_text, unused, *all.options); - if (all.l != nullptr) { all.l->save_load(buf, false, as_text); } + if (all.l != nullptr) { all.l->save_load(buf, false, as_text, all.runtime_state.model_file_ver); } buf.flush(); // close_file() should do this for me ... buf.close_file(); @@ -580,7 +580,7 @@ void VW::details::parse_mask_regressor_args( io_temp_mask.add_file(VW::io::open_file_reader(feature_mask)); save_load_header(all, io_temp_mask, true, false, file_options, *all.options); - all.l->save_load(io_temp_mask, true, false); + all.l->save_load(io_temp_mask, true, false, all.runtime_state.model_file_ver); io_temp_mask.close_file(); // Deal with the over-written header from initial regressor diff --git a/vowpalwabbit/core/src/reductions/active.cc b/vowpalwabbit/core/src/reductions/active.cc index ea8b66c40e4..1b9a878da5b 100644 --- a/vowpalwabbit/core/src/reductions/active.cc +++ b/vowpalwabbit/core/src/reductions/active.cc @@ -127,16 +127,16 @@ void active_print_result( if (t != len) { logger.err_error("write error: {}", VW::io::strerror_to_string(errno)); } } -void save_load(active& a, VW::io_buf& io, bool read, bool text) +void save_load(active& a, VW::io_buf& io, bool read, bool text, const VW::version_struct& ver) { using namespace VW::version_definitions; if (io.num_files() == 0) { return; } // This code is valid if version is within // [VERSION_FILE_WITH_ACTIVE_SEEN_LABELS, VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) // or >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED - if ((a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS && - a._model_version < VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) || - a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED) + if ((ver >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS && + ver < VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) || + ver >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED) { if (read) { @@ -201,7 +201,7 @@ std::shared_ptr VW::reductions::active_setup(VW::setup_bas if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; } if (options.was_supplied("lda")) { THROW("lda cannot be combined with active learning") } - auto data = VW::make_unique(active_c0, all.sd, all.get_random_state(), all.runtime_state.model_file_ver); + auto data = VW::make_unique(active_c0, all.sd, all.get_random_state()); auto base = require_singleline(stack_builder.setup_base_learner()); using learn_pred_func_t = void (*)(active&, VW::LEARNER::learner&, VW::example&); diff --git a/vowpalwabbit/core/src/reductions/automl.cc b/vowpalwabbit/core/src/reductions/automl.cc index 6ef818179b6..92a0a2444a3 100644 --- a/vowpalwabbit/core/src/reductions/automl.cc +++ b/vowpalwabbit/core/src/reductions/automl.cc @@ -112,7 +112,7 @@ void pre_save_load_automl(VW::workspace& all, automl& data) } template -void save_load_automl(automl& aml, VW::io_buf& io, bool read, bool text) +void save_load_automl(automl& aml, VW::io_buf& io, bool read, bool text, const VW::version_struct&) { if (io.num_files() == 0) { return; } if (read) { VW::model_utils::read_model_field(io, aml); } diff --git a/vowpalwabbit/core/src/reductions/baseline_challenger_cb.cc b/vowpalwabbit/core/src/reductions/baseline_challenger_cb.cc index 9fb59cd59d5..3519cac3a7b 100644 --- a/vowpalwabbit/core/src/reductions/baseline_challenger_cb.cc +++ b/vowpalwabbit/core/src/reductions/baseline_challenger_cb.cc @@ -163,7 +163,7 @@ void learn_or_predict(baseline_challenger_data& data, learner& base, VW::multi_e data.learn_or_predict(base, examples); } -void save_load(baseline_challenger_data& data, VW::io_buf& io, bool read, bool text) +void save_load(baseline_challenger_data& data, VW::io_buf& io, bool read, bool text, const VW::version_struct&) { if (io.num_files() == 0) { return; } if (read) { VW::model_utils::read_model_field(io, data); } diff --git a/vowpalwabbit/core/src/reductions/bfgs.cc b/vowpalwabbit/core/src/reductions/bfgs.cc index 8b1e7ac2bad..6af5716e819 100644 --- a/vowpalwabbit/core/src/reductions/bfgs.cc +++ b/vowpalwabbit/core/src/reductions/bfgs.cc @@ -1048,7 +1048,7 @@ void save_load_regularizer(VW::workspace& all, bfgs& b, VW::io_buf& model_file, if (read) { regularizer_to_weight(all, b); } } -void save_load(bfgs& b, VW::io_buf& model_file, bool read, bool text) +void save_load(bfgs& b, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { VW::workspace* all = b.all; diff --git a/vowpalwabbit/core/src/reductions/boosting.cc b/vowpalwabbit/core/src/reductions/boosting.cc index b514a88d36b..0cdf372ab88 100644 --- a/vowpalwabbit/core/src/reductions/boosting.cc +++ b/vowpalwabbit/core/src/reductions/boosting.cc @@ -239,7 +239,7 @@ void predict_or_learn_adaptive(boosting& o, VW::LEARNER::learner& base, VW::exam ec.pred.scalar = final_prediction; } -void save_load_sampling(boosting& o, VW::io_buf& model_file, bool read, bool text) +void save_load_sampling(boosting& o, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { if (model_file.num_files() == 0) { return; } std::stringstream os; @@ -298,7 +298,7 @@ void save_load_sampling(boosting& o, VW::io_buf& model_file, bool read, bool tex o.logger.err_info("{}", fmt::to_string(buffer)); } -void save_load(boosting& o, VW::io_buf& model_file, bool read, bool text) +void save_load(boosting& o, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { if (model_file.num_files() == 0) { return; } std::stringstream os; @@ -340,7 +340,7 @@ void save_load(boosting& o, VW::io_buf& model_file, bool read, bool text) } } -void save_load_boosting_noop(boosting&, VW::io_buf&, bool, bool) {} +void save_load_boosting_noop(boosting&, VW::io_buf&, bool, bool, const VW::version_struct&) {} } // namespace std::shared_ptr VW::reductions::boosting_setup(VW::setup_base_i& stack_builder) @@ -383,7 +383,7 @@ std::shared_ptr VW::reductions::boosting_setup(VW::setup_b std::string name_addition; void (*learn_ptr)(boosting&, VW::LEARNER::learner&, VW::example&); void (*pred_ptr)(boosting&, VW::LEARNER::learner&, VW::example&); - void (*save_load_fn)(boosting&, io_buf&, bool, bool); + void (*save_load_fn)(boosting&, io_buf&, bool, bool, const VW::version_struct&); if (data->alg == "BBM") { diff --git a/vowpalwabbit/core/src/reductions/cb/cb_adf.cc b/vowpalwabbit/core/src/reductions/cb/cb_adf.cc index 4637daa0157..254f7c905e3 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_adf.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_adf.cc @@ -345,10 +345,9 @@ void print_update_cb_adf(VW::workspace& all, VW::shared_data& /* sd */, const VW else { VW::details::print_update_cb(all, !labeled_example, ec, &ec_seq, true, nullptr); } } -void save_load(VW::reductions::cb_adf& c, VW::io_buf& model_file, bool read, bool text) +void save_load(VW::reductions::cb_adf& c, VW::io_buf& model_file, bool read, bool text, const VW::version_struct& ver) { - if (c.get_model_file_ver() != nullptr && - *c.get_model_file_ver() < VW::version_definitions::VERSION_FILE_WITH_CB_ADF_SAVE) + if (ver < VW::version_definitions::VERSION_FILE_WITH_CB_ADF_SAVE) { return; } diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore.cc index af2d52260a2..6dc80876d2f 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore.cc @@ -54,7 +54,6 @@ class cb_explore float psi = 0.f; bool nounif = false; bool epsilon_decay = false; - VW::version_struct model_file_version; VW::io::logger logger; size_t counter = 0; @@ -270,11 +269,11 @@ float calc_loss(const cb_explore& data, const VW::example& ec, const VW::cb_labe return loss; } -void save_load(cb_explore& cb, VW::io_buf& io, bool read, bool text) +void save_load(cb_explore& cb, VW::io_buf& io, bool read, bool text, const VW::version_struct& ver) { if (io.num_files() == 0) { return; } - if (!read || cb.model_file_version >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG) + if (!read || ver >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG) { std::stringstream msg; if (!read) { msg << "cb cover storing VW::example counter: = " << cb.counter << "\n"; } @@ -372,7 +371,6 @@ std::shared_ptr VW::reductions::cb_explore_setup(VW::setup if (data->epsilon < 0.0 || data->epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); } data->cbcs.cb_type = VW::cb_type_t::DR; - data->model_file_version = all.runtime_state.model_file_ver; size_t params_per_weight = 1; if (options.was_supplied("cover")) { params_per_weight = data->cover_size + 1; } diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_cover.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_cover.cc index cbcec777762..ae6f4324f42 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_cover.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_cover.cc @@ -36,12 +36,12 @@ class cb_explore_adf_cover public: cb_explore_adf_cover(size_t cover_size, float psi, bool nounif, float epsilon, bool epsilon_decay, bool first_only, VW::LEARNER::learner* cs_ldf_learner, VW::LEARNER::learner* scorer, VW::cb_type_t cb_type, - VW::version_struct model_file_version, VW::io::logger logger); + VW::io::logger logger); // Should be called through cb_explore_adf_base for pre/post-processing void predict(VW::LEARNER::learner& base, VW::multi_ex& examples) { predict_or_learn_impl(base, examples); } void learn(VW::LEARNER::learner& base, VW::multi_ex& examples) { predict_or_learn_impl(base, examples); } - void save_load(VW::io_buf& io, bool read, bool text); + void save_load(VW::io_buf& io, bool read, bool text, const VW::version_struct&); private: size_t _cover_size; @@ -56,7 +56,6 @@ class cb_explore_adf_cover VW::details::cb_to_cs_adf_dr _gen_cs_dr; VW::cb_type_t _cb_type = VW::cb_type_t::DM; - VW::version_struct _model_file_version; VW::io::logger _logger; VW::v_array _action_probs; @@ -71,7 +70,7 @@ class cb_explore_adf_cover cb_explore_adf_cover::cb_explore_adf_cover(size_t cover_size, float psi, bool nounif, float epsilon, bool epsilon_decay, bool first_only, VW::LEARNER::learner* cs_ldf_learner, VW::LEARNER::learner* scorer, VW::cb_type_t cb_type, - VW::version_struct model_file_version, VW::io::logger logger) + VW::io::logger logger) : _cover_size(cover_size) , _psi(psi) , _nounif(nounif) @@ -81,7 +80,6 @@ cb_explore_adf_cover::cb_explore_adf_cover(size_t cover_size, float psi, bool no , _counter(0) , _cs_ldf_learner(cs_ldf_learner) , _cb_type(cb_type) - , _model_file_version(model_file_version) , _logger(std::move(logger)) { _gen_cs_dr.scorer = scorer; @@ -222,10 +220,10 @@ void cb_explore_adf_cover::predict_or_learn_impl(VW::LEARNER::learner& base, VW: if (is_learn) { ++_counter; } } -void cb_explore_adf_cover::save_load(VW::io_buf& io, bool read, bool text) +void cb_explore_adf_cover::save_load(VW::io_buf& io, bool read, bool text, const VW::version_struct& ver) { if (io.num_files() == 0) { return; } - if (!read || _model_file_version >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG) + if (!read || ver >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG) { std::stringstream msg; if (!read) { msg << "cb cover adf storing example counter: = " << _counter << "\n"; } @@ -326,7 +324,7 @@ std::shared_ptr VW::reductions::cb_explore_adf_cover_setup using explore_type = cb_explore_adf_base; auto data = VW::make_unique(all.output_runtime.global_metrics.are_metrics_enabled(), VW::cast_to_smaller_type(cover_size), psi, nounif, epsilon, epsilon_decay, first_only, cost_sensitive, - scorer, cb_type, all.runtime_state.model_file_ver, all.logger); + scorer, cb_type, all.logger); auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, stack_builder.get_setupfn_name(cb_explore_adf_cover_setup)) .set_input_label_type(VW::label_type_t::CB) diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_first.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_first.cc index 2a1fcafd6a3..5297c630498 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_first.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_first.cc @@ -31,25 +31,24 @@ namespace class cb_explore_adf_first { public: - cb_explore_adf_first(size_t tau, float epsilon, VW::version_struct model_file_version); + cb_explore_adf_first(size_t tau, float epsilon); ~cb_explore_adf_first() = default; // Should be called through cb_explore_adf_base for pre/post-processing void predict(learner& base, VW::multi_ex& examples) { predict_or_learn_impl(base, examples); } void learn(learner& base, VW::multi_ex& examples) { predict_or_learn_impl(base, examples); } - void save_load(VW::io_buf& io, bool read, bool text); + void save_load(VW::io_buf& io, bool read, bool text, const VW::version_struct&); private: size_t _tau; float _epsilon; - VW::version_struct _model_file_version; template void predict_or_learn_impl(learner& base, VW::multi_ex& examples); }; -cb_explore_adf_first::cb_explore_adf_first(size_t tau, float epsilon, VW::version_struct model_file_version) - : _tau(tau), _epsilon(epsilon), _model_file_version(model_file_version) +cb_explore_adf_first::cb_explore_adf_first(size_t tau, float epsilon) + : _tau(tau), _epsilon(epsilon) { } @@ -78,10 +77,10 @@ void cb_explore_adf_first::predict_or_learn_impl(learner& base, VW::multi_ex& ex VW::explore::enforce_minimum_probability(_epsilon, true, begin_scores(preds), end_scores(preds)); } -void cb_explore_adf_first::save_load(VW::io_buf& io, bool read, bool text) +void cb_explore_adf_first::save_load(VW::io_buf& io, bool read, bool text, const VW::version_struct& ver) { if (io.num_files() == 0) { return; } - if (!read || _model_file_version >= VW::version_definitions::VERSION_FILE_WITH_FIRST_SAVE_RESUME) + if (!read || ver >= VW::version_definitions::VERSION_FILE_WITH_FIRST_SAVE_RESUME) { std::stringstream msg; if (!read) { msg << "cb first adf storing example counter: = " << _tau << "\n"; } @@ -122,7 +121,7 @@ std::shared_ptr VW::reductions::cb_explore_adf_first_setup using explore_type = cb_explore_adf_base; auto data = VW::make_unique(all.output_runtime.global_metrics.are_metrics_enabled(), - VW::cast_to_smaller_type(tau), epsilon, all.runtime_state.model_file_ver); + VW::cast_to_smaller_type(tau), epsilon); if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); } auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_graph_feedback.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_graph_feedback.cc index 62f33a185aa..207dde4c2a4 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_graph_feedback.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_graph_feedback.cc @@ -78,7 +78,7 @@ class cb_explore_adf_graph_feedback // Should be called through cb_explore_adf_base for pre/post-processing void predict(VW::LEARNER::learner& base, multi_ex& examples); void learn(VW::LEARNER::learner& base, multi_ex& examples); - void save_load(io_buf& io, bool read, bool text); + void save_load(io_buf& io, bool read, bool text, const VW::version_struct&); size_t _counter = 0; float _gamma_scale; float _gamma_exponent; @@ -506,7 +506,7 @@ void cb_explore_adf_graph_feedback::learn(VW::LEARNER::learner& base, multi_ex& predict_or_learn_impl(base, examples); } -void cb_explore_adf_graph_feedback::save_load(VW::io_buf& io, bool read, bool text) +void cb_explore_adf_graph_feedback::save_load(VW::io_buf& io, bool read, bool text, const VW::version_struct&) { if (io.num_files() == 0) { return; } diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_regcb.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_regcb.cc index ddb3f3c927d..1b45fad8db6 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_regcb.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_regcb.cc @@ -37,14 +37,13 @@ namespace class cb_explore_adf_regcb { public: - cb_explore_adf_regcb(bool regcbopt, float c0, bool first_only, float min_cb_cost, float max_cb_cost, - VW::version_struct model_file_version); + cb_explore_adf_regcb(bool regcbopt, float c0, bool first_only, float min_cb_cost, float max_cb_cost); ~cb_explore_adf_regcb() = default; // Should be called through cb_explore_adf_base for pre/post-processing void predict(learner& base, VW::multi_ex& examples) { predict_impl(base, examples); } void learn(learner& base, VW::multi_ex& examples) { learn_impl(base, examples); } - void save_load(VW::io_buf& io, bool read, bool text); + void save_load(VW::io_buf& io, bool read, bool text, const VW::version_struct&); private: void predict_impl(learner& base, VW::multi_ex& examples); @@ -64,22 +63,19 @@ class cb_explore_adf_regcb std::vector _min_costs; std::vector _max_costs; - VW::version_struct _model_file_version; - // for backing up cb example data when computing sensitivities std::vector _ex_as; std::vector> _ex_costs; }; cb_explore_adf_regcb::cb_explore_adf_regcb(bool regcbopt, float c0, bool first_only, float min_cb_cost, - float max_cb_cost, VW::version_struct model_file_version) + float max_cb_cost) : _counter(0) , _regcbopt(regcbopt) , _c0(c0) , _first_only(first_only) , _min_cb_cost(min_cb_cost) , _max_cb_cost(max_cb_cost) - , _model_file_version(model_file_version) { } @@ -230,10 +226,10 @@ void cb_explore_adf_regcb::learn_impl(learner& base, VW::multi_ex& examples) examples[0]->pred.a_s = std::move(preds); } -void cb_explore_adf_regcb::save_load(VW::io_buf& io, bool read, bool text) +void cb_explore_adf_regcb::save_load(VW::io_buf& io, bool read, bool text, const VW::version_struct& ver) { if (io.num_files() == 0) { return; } - if (!read || _model_file_version >= VW::version_definitions::VERSION_FILE_WITH_REG_CB_SAVE_RESUME) + if (!read || ver >= VW::version_definitions::VERSION_FILE_WITH_REG_CB_SAVE_RESUME) { std::stringstream msg; if (!read) { msg << "cb squarecb adf storing example counter: = " << _counter << "\n"; } @@ -303,7 +299,7 @@ std::shared_ptr VW::reductions::cb_explore_adf_regcb_setup using explore_type = cb_explore_adf_base; auto data = VW::make_unique(all.output_runtime.global_metrics.are_metrics_enabled(), regcbopt, c0, - first_only, min_cb_cost, max_cb_cost, all.runtime_state.model_file_ver); + first_only, min_cb_cost, max_cb_cost); auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, stack_builder.get_setupfn_name(cb_explore_adf_regcb_setup)) .set_input_label_type(VW::label_type_t::CB) diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_squarecb.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_squarecb.cc index 7eba27a5bf6..c50a64281b9 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_squarecb.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_squarecb.cc @@ -43,13 +43,13 @@ class cb_explore_adf_squarecb { public: cb_explore_adf_squarecb(float gamma_scale, float gamma_exponent, bool elim, float c0, float min_cb_cost, - float max_cb_cost, VW::version_struct model_file_version, float epsilon, bool store_gamma_in_reduction_features); + float max_cb_cost, float epsilon, bool store_gamma_in_reduction_features); ~cb_explore_adf_squarecb() = default; // Should be called through cb_explore_adf_base for pre/post-processing void predict(learner& base, VW::multi_ex& examples); void learn(learner& base, VW::multi_ex& examples); - void save_load(VW::io_buf& io, bool read, bool text); + void save_load(VW::io_buf& io, bool read, bool text, const VW::version_struct&); private: size_t _counter; @@ -66,8 +66,6 @@ class cb_explore_adf_squarecb std::vector _min_costs; std::vector _max_costs; - VW::version_struct _model_file_version; - bool _store_gamma_in_reduction_features; // for backing up cb example data when computing sensitivities @@ -78,7 +76,7 @@ class cb_explore_adf_squarecb }; cb_explore_adf_squarecb::cb_explore_adf_squarecb(float gamma_scale, float gamma_exponent, bool elim, float c0, - float min_cb_cost, float max_cb_cost, VW::version_struct model_file_version, float epsilon, + float min_cb_cost, float max_cb_cost, float epsilon, bool store_gamma_in_reduction_features) : _counter(0) , _gamma_scale(gamma_scale) @@ -88,7 +86,6 @@ cb_explore_adf_squarecb::cb_explore_adf_squarecb(float gamma_scale, float gamma_ , _min_cb_cost(min_cb_cost) , _max_cb_cost(max_cb_cost) , _epsilon(epsilon) - , _model_file_version(model_file_version) , _store_gamma_in_reduction_features(store_gamma_in_reduction_features) { } @@ -296,10 +293,10 @@ void cb_explore_adf_squarecb::learn(learner& base, VW::multi_ex& examples) examples[0]->pred.a_s = std::move(preds); } -void cb_explore_adf_squarecb::save_load(VW::io_buf& io, bool read, bool text) +void cb_explore_adf_squarecb::save_load(VW::io_buf& io, bool read, bool text, const VW::version_struct& ver) { if (io.num_files() == 0) { return; } - if (!read || _model_file_version >= VW::version_definitions::VERSION_FILE_WITH_SQUARE_CB_SAVE_RESUME) + if (!read || ver >= VW::version_definitions::VERSION_FILE_WITH_SQUARE_CB_SAVE_RESUME) { std::stringstream msg; if (!read) { msg << "cb squarecb adf storing example counter: = " << _counter << "\n"; } @@ -396,7 +393,7 @@ std::shared_ptr VW::reductions::cb_explore_adf_squarecb_se using explore_type = cb_explore_adf_base; auto data = VW::make_unique(all.output_runtime.global_metrics.are_metrics_enabled(), gamma_scale, - gamma_exponent, elim, c0, min_cb_cost, max_cb_cost, all.runtime_state.model_file_ver, epsilon, + gamma_exponent, elim, c0, min_cb_cost, max_cb_cost, epsilon, store_gamma_in_reduction_features); auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, stack_builder.get_setupfn_name(cb_explore_adf_squarecb_setup)) diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_synthcover.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_synthcover.cc index a9757248256..4aea395b876 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_synthcover.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_synthcover.cc @@ -34,12 +34,12 @@ class cb_explore_adf_synthcover { public: cb_explore_adf_synthcover(float epsilon, float psi, size_t synthcoversize, - std::shared_ptr random_state, VW::version_struct model_file_version); + std::shared_ptr random_state); // Should be called through cb_explore_adf_base for pre/post-processing void predict(VW::LEARNER::learner& base, VW::multi_ex& examples) { predict_or_learn_impl(base, examples); } void learn(VW::LEARNER::learner& base, VW::multi_ex& examples) { predict_or_learn_impl(base, examples); } - void save_load(VW::io_buf& model_file, bool read, bool text); + void save_load(VW::io_buf& model_file, bool read, bool text, const VW::version_struct&); private: float _epsilon; @@ -47,8 +47,6 @@ class cb_explore_adf_synthcover size_t _synthcoversize; std::shared_ptr _random_state; - VW::version_struct _model_file_version; - VW::v_array _action_probs; float _min_cost; float _max_cost; @@ -57,12 +55,11 @@ class cb_explore_adf_synthcover }; cb_explore_adf_synthcover::cb_explore_adf_synthcover(float epsilon, float psi, size_t synthcoversize, - std::shared_ptr random_state, VW::version_struct model_file_version) + std::shared_ptr random_state) : _epsilon(epsilon) , _psi(psi) , _synthcoversize(synthcoversize) , _random_state(std::move(random_state)) - , _model_file_version(model_file_version) , _min_cost(0.0) , _max_cost(0.0) { @@ -127,10 +124,10 @@ void cb_explore_adf_synthcover::predict_or_learn_impl(VW::LEARNER::learner& base for (size_t i = 0; i < num_actions; i++) { preds[i] = _action_probs[i]; } } -void cb_explore_adf_synthcover::save_load(VW::io_buf& model_file, bool read, bool text) +void cb_explore_adf_synthcover::save_load(VW::io_buf& model_file, bool read, bool text, const VW::version_struct& ver) { if (model_file.num_files() == 0) { return; } - if (!read || _model_file_version >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG) + if (!read || ver >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG) { std::stringstream msg; if (!read) { msg << "_min_cost " << _min_cost << "\n"; } @@ -196,7 +193,7 @@ std::shared_ptr VW::reductions::cb_explore_adf_synthcover_ using explore_type = cb_explore_adf_base; auto data = VW::make_unique(all.output_runtime.global_metrics.are_metrics_enabled(), epsilon, psi, - VW::cast_to_smaller_type(synthcoversize), all.get_random_state(), all.runtime_state.model_file_ver); + VW::cast_to_smaller_type(synthcoversize), all.get_random_state()); auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, stack_builder.get_setupfn_name(cb_explore_adf_synthcover_setup)) .set_input_label_type(VW::label_type_t::CB) diff --git a/vowpalwabbit/core/src/reductions/cbzo.cc b/vowpalwabbit/core/src/reductions/cbzo.cc index 6f5b2052d9e..992ce702a97 100644 --- a/vowpalwabbit/core/src/reductions/cbzo.cc +++ b/vowpalwabbit/core/src/reductions/cbzo.cc @@ -230,7 +230,7 @@ inline void save_load_regressor(VW::workspace& all, VW::io_buf& model_file, bool VW::details::save_load_regressor_gd(all, model_file, read, text); } -void save_load(cbzo& data, VW::io_buf& model_file, bool read, bool text) +void save_load(cbzo& data, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { VW::workspace& all = *data.all; if (read) diff --git a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc index b332793ded8..928ea9d02ca 100644 --- a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc +++ b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc @@ -85,7 +85,6 @@ class ccb_data VW::vector_pool cb_label_pool; VW::v_array_pool action_score_pool; - VW::version_struct model_file_version; // If the reduction has not yet seen a multi slot example, it will behave the same as if it were CB. // This means the interactions aren't added and the slot feature is not added. bool has_seen_multi_slot_example = false; @@ -609,13 +608,13 @@ void cleanup_example_ccb(ccb_data& data, VW::multi_ex& ec_seq) } } -void save_load(ccb_data& sm, VW::io_buf& io, bool read, bool text) +void save_load(ccb_data& sm, VW::io_buf& io, bool read, bool text, const VW::version_struct& ver) { if (io.num_files() == 0) { return; } // We need to check if reading a model file after the version in which this was added. if (read && - (sm.model_file_version >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG && + (ver >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG && sm.is_ccb_input_model)) { VW::model_utils::read_model_field(io, sm.has_seen_multi_slot_example); @@ -698,7 +697,6 @@ std::shared_ptr VW::reductions::ccb_explore_adf_setup(VW:: // Extract from lower level reductions data->shared = nullptr; data->all = &all; - data->model_file_version = all.runtime_state.model_file_ver; data->id_namespace_str = "_id"; data->id_namespace_audit_str = "_ccb_slot_index"; diff --git a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc index 37c856644fa..6cad7af34ad 100644 --- a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc +++ b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc @@ -906,7 +906,7 @@ size_t write_model_field( namespace { -void emt_save_load_tree(VW::reductions::eigen_memory_tree::emt_tree& tree, VW::io_buf& io, bool read, bool text) +void emt_save_load_tree(VW::reductions::eigen_memory_tree::emt_tree& tree, VW::io_buf& io, bool read, bool text, const VW::version_struct&) { if (io.num_files() == 0) { return; } if (read) { VW::model_utils::read_model_field(io, tree); } diff --git a/vowpalwabbit/core/src/reductions/epsilon_decay.cc b/vowpalwabbit/core/src/reductions/epsilon_decay.cc index 4c631978d2c..f3bb04e6ecc 100644 --- a/vowpalwabbit/core/src/reductions/epsilon_decay.cc +++ b/vowpalwabbit/core/src/reductions/epsilon_decay.cc @@ -288,7 +288,7 @@ void learn(VW::reductions::epsilon_decay::epsilon_decay_data& data, VW::LEARNER: } void save_load_epsilon_decay( - VW::reductions::epsilon_decay::epsilon_decay_data& epsilon_decay, VW::io_buf& io, bool read, bool text) + VW::reductions::epsilon_decay::epsilon_decay_data& epsilon_decay, VW::io_buf& io, bool read, bool text, const VW::version_struct&) { if (io.num_files() == 0) { return; } if (read) { VW::model_utils::read_model_field(io, epsilon_decay); } diff --git a/vowpalwabbit/core/src/reductions/freegrad.cc b/vowpalwabbit/core/src/reductions/freegrad.cc index ae181a6eda5..5774abe3c1e 100644 --- a/vowpalwabbit/core/src/reductions/freegrad.cc +++ b/vowpalwabbit/core/src/reductions/freegrad.cc @@ -283,7 +283,7 @@ void learn_freegrad(freegrad& a, VW::example& ec) freegrad_update_after_prediction(a, ec); } -void save_load(freegrad& fg, VW::io_buf& model_file, bool read, bool text) +void save_load(freegrad& fg, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { VW::workspace* all = fg.all; if (read) { VW::details::initialize_regressor(*all); } diff --git a/vowpalwabbit/core/src/reductions/ftrl.cc b/vowpalwabbit/core/src/reductions/ftrl.cc index 129f69012ab..0715e270b1b 100644 --- a/vowpalwabbit/core/src/reductions/ftrl.cc +++ b/vowpalwabbit/core/src/reductions/ftrl.cc @@ -331,7 +331,7 @@ void learn_coin_betting(ftrl& a, VW::example& ec) coin_betting_update_after_prediction(a, ec); } -void save_load(ftrl& b, VW::io_buf& model_file, bool read, bool text) +void save_load(ftrl& b, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { VW::workspace* all = b.all; if (read) { VW::details::initialize_regressor(*all); } diff --git a/vowpalwabbit/core/src/reductions/gd.cc b/vowpalwabbit/core/src/reductions/gd.cc index 8ac8ac6a7f4..b77c0b93911 100644 --- a/vowpalwabbit/core/src/reductions/gd.cc +++ b/vowpalwabbit/core/src/reductions/gd.cc @@ -1284,7 +1284,7 @@ void VW::details::save_load_online_state_gd(VW::workspace& all, VW::io_buf& mode namespace { -void save_load(VW::reductions::gd& g, VW::io_buf& model_file, bool read, bool text) +void save_load(VW::reductions::gd& g, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { VW::workspace& all = *g.all; if (read) diff --git a/vowpalwabbit/core/src/reductions/gd_mf.cc b/vowpalwabbit/core/src/reductions/gd_mf.cc index a4b1b41baa8..b72c19d296e 100644 --- a/vowpalwabbit/core/src/reductions/gd_mf.cc +++ b/vowpalwabbit/core/src/reductions/gd_mf.cc @@ -256,7 +256,7 @@ void initialize_weights(VW::weight* weights, uint64_t index, uint32_t stride) } } -void save_load(gdmf& d, VW::io_buf& model_file, bool read, bool text) +void save_load(gdmf& d, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { VW::workspace& all = *d.all; uint64_t length = static_cast(1) << all.initial_weights_config.num_bits; diff --git a/vowpalwabbit/core/src/reductions/interaction_ground.cc b/vowpalwabbit/core/src/reductions/interaction_ground.cc index 7df7f69ef7d..40f1f753981 100644 --- a/vowpalwabbit/core/src/reductions/interaction_ground.cc +++ b/vowpalwabbit/core/src/reductions/interaction_ground.cc @@ -249,7 +249,7 @@ void learn(VW::reductions::igl::igl_data& igl, learner& base, VW::multi_ex& ec_s ec_seq[0]->pred.a_s = std::move(stashed_prediction); } -void save_load_igl(VW::reductions::igl::igl_data& igl, VW::io_buf& io, bool read, bool text) +void save_load_igl(VW::reductions::igl::igl_data& igl, VW::io_buf& io, bool read, bool text, const VW::version_struct&) { if (io.num_files() == 0) { return; } if (read) { VW::reductions::model_utils::read_model_field(io, igl); } diff --git a/vowpalwabbit/core/src/reductions/kernel_svm.cc b/vowpalwabbit/core/src/reductions/kernel_svm.cc index 7236dbf01e0..cc445b05143 100644 --- a/vowpalwabbit/core/src/reductions/kernel_svm.cc +++ b/vowpalwabbit/core/src/reductions/kernel_svm.cc @@ -343,15 +343,15 @@ void save_load_svm_model(svm_params& params, VW::io_buf& model_file, bool read, static_cast(model->num_support) * sizeof(float), read, msg, text); } -void save_load(svm_params& params, VW::io_buf& model_file, bool read, bool text) +void save_load(svm_params& params, VW::io_buf& model_file, bool read, bool text, const VW::version_struct& ver) { if (text) { *params.all->output_runtime.trace_message << "Not supporting readable model for kernel svm currently" << endl; return; } - else if (params.all->runtime_state.model_file_ver > VW::version_definitions::EMPTY_VERSION_FILE && - params.all->runtime_state.model_file_ver < VW::version_definitions::VERSION_FILE_WITH_FLAT_EXAMPLE_TAG_FIX) + else if (ver > VW::version_definitions::EMPTY_VERSION_FILE && + ver < VW::version_definitions::VERSION_FILE_WITH_FLAT_EXAMPLE_TAG_FIX) { THROW("Models using ksvm from before version 9.6 are not compatable with this version of VW.") } diff --git a/vowpalwabbit/core/src/reductions/lda_core.cc b/vowpalwabbit/core/src/reductions/lda_core.cc index efcfc0ef6c5..28262960248 100644 --- a/vowpalwabbit/core/src/reductions/lda_core.cc +++ b/vowpalwabbit/core/src/reductions/lda_core.cc @@ -762,7 +762,7 @@ class initial_weights uint64_t stride; }; -void save_load(lda& l, VW::io_buf& model_file, bool read, bool text) +void save_load(lda& l, VW::io_buf& model_file, bool read, bool text, const VW::version_struct& ver) { VW::workspace& all = *(l.all); uint64_t length = static_cast(1) << all.initial_weights_config.num_bits; @@ -800,7 +800,7 @@ void save_load(lda& l, VW::io_buf& model_file, bool read, bool text) size_t K = all.reduction_state.lda; // NOLINT if (!read && text) { msg << i << " "; } - if (!read || all.runtime_state.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_HEADER_ID) + if (!read || ver >= VW::version_definitions::VERSION_FILE_WITH_HEADER_ID) { brw += VW::details::bin_text_read_write_fixed(model_file, reinterpret_cast(&i), sizeof(i), read, msg, text); diff --git a/vowpalwabbit/core/src/reductions/log_multi.cc b/vowpalwabbit/core/src/reductions/log_multi.cc index 01092faa827..63efedb09b3 100644 --- a/vowpalwabbit/core/src/reductions/log_multi.cc +++ b/vowpalwabbit/core/src/reductions/log_multi.cc @@ -302,7 +302,7 @@ void learn(log_multi& b, learner& base, VW::example& ec) } } -void save_load_tree(log_multi& b, VW::io_buf& model_file, bool read, bool text) +void save_load_tree(log_multi& b, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { if (model_file.num_files() > 0) { diff --git a/vowpalwabbit/core/src/reductions/marginal.cc b/vowpalwabbit/core/src/reductions/marginal.cc index 44de9e6686e..d9e0e70008e 100644 --- a/vowpalwabbit/core/src/reductions/marginal.cc +++ b/vowpalwabbit/core/src/reductions/marginal.cc @@ -280,7 +280,7 @@ void predict_or_learn(data& sm, VW::LEARNER::learner& base, VW::example& ec) undo_marginal(sm, ec); } -void save_load(data& sm, VW::io_buf& io, bool read, bool text) +void save_load(data& sm, VW::io_buf& io, bool read, bool text, const VW::version_struct&) { const uint64_t stride_shift = sm.m_all->weights.stride_shift(); diff --git a/vowpalwabbit/core/src/reductions/memory_tree.cc b/vowpalwabbit/core/src/reductions/memory_tree.cc index 67e767a3811..7ab4a9c81fa 100644 --- a/vowpalwabbit/core/src/reductions/memory_tree.cc +++ b/vowpalwabbit/core/src/reductions/memory_tree.cc @@ -1132,7 +1132,7 @@ void save_load_node(node& cn, VW::io_buf& model_file, bool& read, bool& text, st for (uint32_t k = 0; k < leaf_n_examples; k++) DEPRECATED_WRITEIT(cn.examples_index[k], "example_location"); } -void save_load_memory_tree(memory_tree& b, VW::io_buf& model_file, bool read, bool text) +void save_load_memory_tree(memory_tree& b, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { std::stringstream msg; if (model_file.num_files() > 0) diff --git a/vowpalwabbit/core/src/reductions/mwt.cc b/vowpalwabbit/core/src/reductions/mwt.cc index 2f9d45bd546..f9638b1d838 100644 --- a/vowpalwabbit/core/src/reductions/mwt.cc +++ b/vowpalwabbit/core/src/reductions/mwt.cc @@ -211,7 +211,7 @@ void print_update_mwt( } } -void save_load(mwt& c, VW::io_buf& model_file, bool read, bool text) +void save_load(mwt& c, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { if (model_file.num_files() == 0) { return; } diff --git a/vowpalwabbit/core/src/reductions/oja_newton.cc b/vowpalwabbit/core/src/reductions/oja_newton.cc index 0945282084a..ce1501e2c49 100644 --- a/vowpalwabbit/core/src/reductions/oja_newton.cc +++ b/vowpalwabbit/core/src/reductions/oja_newton.cc @@ -465,7 +465,7 @@ void learn(OjaNewton& oja_newton_ptr, VW::example& ec) oja_newton_ptr.check(); } -void save_load(OjaNewton& oja_newton_ptr, VW::io_buf& model_file, bool read, bool text) +void save_load(OjaNewton& oja_newton_ptr, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { VW::workspace& all = *oja_newton_ptr.all; if (read) diff --git a/vowpalwabbit/core/src/reductions/plt.cc b/vowpalwabbit/core/src/reductions/plt.cc index 0b270764f5f..f6d7990c38d 100644 --- a/vowpalwabbit/core/src/reductions/plt.cc +++ b/vowpalwabbit/core/src/reductions/plt.cc @@ -70,7 +70,6 @@ class plt uint32_t fn = 0; // false negatives uint32_t ec_count = 0; // number of examples - VW::version_struct model_file_version; bool force_load_legacy_model = false; plt() @@ -376,11 +375,11 @@ void finish(plt& p) } } -void save_load_tree(plt& p, VW::io_buf& model_file, bool read, bool text) +void save_load_tree(plt& p, VW::io_buf& model_file, bool read, bool text, const VW::version_struct& ver) { if (model_file.num_files() == 0) { return; } - if (read && p.model_file_version < VW::version_definitions::VERSION_FILE_WITH_PLT_SAVE_LOAD_FIX && + if (read && ver < VW::version_definitions::VERSION_FILE_WITH_PLT_SAVE_LOAD_FIX && p.force_load_legacy_model) { bool resume = false; @@ -399,7 +398,7 @@ void save_load_tree(plt& p, VW::io_buf& model_file, bool read, bool text) return; } - if (read && p.model_file_version < VW::version_definitions::VERSION_FILE_WITH_PLT_SAVE_LOAD_FIX) + if (read && ver < VW::version_definitions::VERSION_FILE_WITH_PLT_SAVE_LOAD_FIX) { THROW( "PLT models before 9.7 had a bug which caused incorrect loading under certain conditions, so by default they " @@ -483,8 +482,6 @@ std::shared_ptr VW::reductions::plt_setup(VW::setup_base_i tree->r_at.resize(tree->top_k); } - tree->model_file_version = all.runtime_state.model_file_ver; - size_t feature_width = tree->t; std::string name_addition = ""; VW::prediction_type_t pred_type; diff --git a/vowpalwabbit/core/src/reductions/recall_tree.cc b/vowpalwabbit/core/src/reductions/recall_tree.cc index 73b9e48a1ef..633aa4b9855 100644 --- a/vowpalwabbit/core/src/reductions/recall_tree.cc +++ b/vowpalwabbit/core/src/reductions/recall_tree.cc @@ -439,7 +439,7 @@ void learn(recall_tree& b, learner& base, VW::example& ec) } } -void save_load_tree(recall_tree& b, VW::io_buf& model_file, bool read, bool text) +void save_load_tree(recall_tree& b, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { if (model_file.num_files() > 0) { diff --git a/vowpalwabbit/core/src/reductions/stagewise_poly.cc b/vowpalwabbit/core/src/reductions/stagewise_poly.cc index 61aba578f90..9d3edf2f1bd 100644 --- a/vowpalwabbit/core/src/reductions/stagewise_poly.cc +++ b/vowpalwabbit/core/src/reductions/stagewise_poly.cc @@ -620,7 +620,7 @@ void end_pass(stagewise_poly& poly) } } -void save_load(stagewise_poly& poly, VW::io_buf& model_file, bool read, bool text) +void save_load(stagewise_poly& poly, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { if (model_file.num_files() > 0) { diff --git a/vowpalwabbit/core/src/reductions/svrg.cc b/vowpalwabbit/core/src/reductions/svrg.cc index f9c5ec284eb..f897a8400f4 100644 --- a/vowpalwabbit/core/src/reductions/svrg.cc +++ b/vowpalwabbit/core/src/reductions/svrg.cc @@ -155,7 +155,7 @@ void learn(svrg& s, VW::example& ec) s.prev_pass = pass; } -void save_load(svrg& s, VW::io_buf& model_file, bool read, bool text) +void save_load(svrg& s, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&) { if (read) { VW::details::initialize_regressor(*s.all); }