Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: pass model version in save_load #4664

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion vowpalwabbit/core/include/vw/core/learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ using multipredict_func =

using sensitivity_func = std::function<float(example& ex)>;
using save_load_func = std::function<void(io_buf&, bool read, bool text)>;
using save_load_ver_func = std::function<void(io_buf&, bool read, bool text, const VW::version_struct&)>;
using pre_save_load_func = std::function<void(VW::workspace& all)>;
using save_metric_func = std::function<void(metric_sink& metrics)>;

Expand Down Expand Up @@ -182,7 +183,7 @@ class learner final : public std::enable_shared_from_this<learner>
float sensitivity(example& ec, size_t i = 0);

// Called anytime saving or loading needs to happen. Autorecursive.
void save_load(io_buf& io, const bool read, const bool text);
void save_load(io_buf& io, const bool read, const bool text, const VW::version_struct& model_version);

// Called to edit the command-line from a learner. Autorecursive
void pre_save_load(VW::workspace& all);
Expand Down Expand Up @@ -287,6 +288,7 @@ class learner final : public std::enable_shared_from_this<learner>
details::cleanup_example_func _cleanup_example_f;

details::save_load_func _save_load_f;
details::save_load_ver_func _save_load_ver_f;
details::void_func _end_pass_f;
details::void_func _end_examples_f;
details::pre_save_load_func _pre_save_load_f;
Expand Down Expand Up @@ -417,12 +419,19 @@ class common_learner_builder
learner_ptr->learn_returns_prediction = learn_returns_prediction;
)

// TODO: deprecate?
LEARNER_BUILDER_DEFINE(set_save_load(void (*fn_ptr)(DataT&, io_buf&, bool, bool)),
assert(fn_ptr != nullptr);
DataT* data = this->learner_data.get();
this->learner_ptr->_save_load_f = [fn_ptr, data](io_buf& buf, bool read, bool text)
{ fn_ptr(*data, buf, read, text); };
)
LEARNER_BUILDER_DEFINE(set_save_load(void (*fn_ptr)(DataT&, io_buf&, bool, bool, const VW::version_struct&)),
assert(fn_ptr != nullptr);
DataT* data = this->learner_data.get();
this->learner_ptr->_save_load_ver_f = [fn_ptr, data](io_buf& buf, bool read, bool text, const VW::version_struct& ver)
{ fn_ptr(*data, buf, read, text, ver); };
)

LEARNER_BUILDER_DEFINE(set_finish(void (*fn_ptr)(DataT&)),
assert(fn_ptr != nullptr);
Expand Down
5 changes: 1 addition & 4 deletions vowpalwabbit/core/include/vw/core/reductions/active.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@ namespace reductions
class active
{
public:
active(float active_c0, std::shared_ptr<shared_data> shared_data, std::shared_ptr<rand_state> random_state,
VW::version_struct model_version)
active(float active_c0, std::shared_ptr<shared_data> shared_data, std::shared_ptr<rand_state> random_state)
: active_c0(active_c0)
, _shared_data(shared_data)
, _random_state(std::move(random_state))
, _model_version{std::move(model_version)}
{
}

Expand All @@ -31,7 +29,6 @@ class active

float _min_seen_label = 0.f;
float _max_seen_label = 1.f;
VW::version_struct _model_version;
};

std::shared_ptr<VW::LEARNER::learner> active_setup(VW::setup_base_i& stack_builder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class cb_explore_adf_base
if (with_metrics) { _metrics = VW::make_unique<cb_explore_metrics>(); }
}

static void save_load(cb_explore_adf_base<ExploreType>& data, io_buf& io, bool read, bool text);
static void save_load(cb_explore_adf_base<ExploreType>& data, io_buf& io, bool read, bool text, const VW::version_struct& ver);
static void persist_metrics(cb_explore_adf_base<ExploreType>& data, metric_sink& metrics);
static void predict(cb_explore_adf_base<ExploreType>& data, VW::LEARNER::learner& base, multi_ex& examples);
static void learn(cb_explore_adf_base<ExploreType>& data, VW::LEARNER::learner& base, multi_ex& examples);
Expand Down Expand Up @@ -302,9 +302,9 @@ void cb_explore_adf_base<ExploreType>::_print_update(

template <typename ExploreType>
inline void cb_explore_adf_base<ExploreType>::save_load(
cb_explore_adf_base<ExploreType>& data, io_buf& io, bool read, bool text)
cb_explore_adf_base<ExploreType>& data, io_buf& io, bool read, bool text, const VW::version_struct& ver)
{
data.explore.save_load(io, read, text);
data.explore.save_load(io, read, text, ver);
}

template <typename ExploreType>
Expand Down
17 changes: 15 additions & 2 deletions vowpalwabbit/core/src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ float learner::sensitivity(example& ec, size_t i)
return ret;
}

void learner::save_load(io_buf& io, const bool read, const bool text)
void learner::save_load(io_buf& io, const bool read, const bool text, const VW::version_struct& model_version)
{
if (_save_load_f)
{
Expand All @@ -457,7 +457,20 @@ void learner::save_load(io_buf& io, const bool read, const bool text)
throw VW::save_load_model_exception(vwex.filename(), vwex.line_number(), better_msg.str());
}
}
if (_base_learner) { _base_learner->save_load(io, read, text); }
else if (_save_load_ver_f)
{
try
{
_save_load_ver_f(io, read, text, model_version);
}
catch (VW::vw_exception& vwex)
{
std::stringstream better_msg;
better_msg << "model " << std::string(read ? "load" : "save") << " failed. Error Details: " << vwex.what();
throw VW::save_load_model_exception(vwex.filename(), vwex.line_number(), better_msg.str());
}
}
if (_base_learner) { _base_learner->save_load(io, read, text, model_version); }
}

void learner::pre_save_load(VW::workspace& all)
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/core/src/parse_args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1362,7 +1362,7 @@ void load_input_model(VW::workspace& all, VW::io_buf& io_temp)
all.feature_mask == all.initial_weights_config.initial_regressors[0])
{
// load rest of regressor
all.l->save_load(io_temp, true, false);
all.l->save_load(io_temp, true, false, all.runtime_state.model_file_ver);
io_temp.close_file();

VW::details::parse_mask_regressor_args(all, all.feature_mask, all.initial_weights_config.initial_regressors);
Expand All @@ -1372,7 +1372,7 @@ void load_input_model(VW::workspace& all, VW::io_buf& io_temp)
VW::details::parse_mask_regressor_args(all, all.feature_mask, all.initial_weights_config.initial_regressors);

// load rest of regressor
all.l->save_load(io_temp, true, false);
all.l->save_load(io_temp, true, false, all.runtime_state.model_file_ver);
io_temp.close_file();
}
}
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/core/src/parse_regressor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ void VW::details::dump_regressor(VW::workspace& all, VW::io_buf& buf, bool as_te
std::string unused;
if (all.l != nullptr) { all.l->pre_save_load(all); }
VW::details::save_load_header(all, buf, false, as_text, unused, *all.options);
if (all.l != nullptr) { all.l->save_load(buf, false, as_text); }
if (all.l != nullptr) { all.l->save_load(buf, false, as_text, all.runtime_state.model_file_ver); }

buf.flush(); // close_file() should do this for me ...
buf.close_file();
Expand Down Expand Up @@ -577,7 +577,7 @@ void VW::details::parse_mask_regressor_args(
io_temp_mask.add_file(VW::io::open_file_reader(feature_mask));

save_load_header(all, io_temp_mask, true, false, file_options, *all.options);
all.l->save_load(io_temp_mask, true, false);
all.l->save_load(io_temp_mask, true, false, all.runtime_state.model_file_ver);
io_temp_mask.close_file();

// Deal with the over-written header from initial regressor
Expand Down
10 changes: 5 additions & 5 deletions vowpalwabbit/core/src/reductions/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,16 @@ void active_print_result(
if (t != len) { logger.err_error("write error: {}", VW::io::strerror_to_string(errno)); }
}

void save_load(active& a, VW::io_buf& io, bool read, bool text)
void save_load(active& a, VW::io_buf& io, bool read, bool text, const VW::version_struct& ver)
{
using namespace VW::version_definitions;
if (io.num_files() == 0) { return; }
// This code is valid if version is within
// [VERSION_FILE_WITH_ACTIVE_SEEN_LABELS, VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED)
// or >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED
if ((a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS &&
a._model_version < VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) ||
a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED)
if ((ver >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS &&
ver < VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) ||
ver >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED)
{
if (read)
{
Expand Down Expand Up @@ -258,7 +258,7 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::active_setup(VW::setup_bas
if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; }

if (options.was_supplied("lda")) { THROW("lda cannot be combined with active learning") }
auto data = VW::make_unique<active>(active_c0, all.sd, all.get_random_state(), all.runtime_state.model_file_ver);
auto data = VW::make_unique<active>(active_c0, all.sd, all.get_random_state());
auto base = require_singleline(stack_builder.setup_base_learner());

using learn_pred_func_t = void (*)(active&, VW::LEARNER::learner&, VW::example&);
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/core/src/reductions/automl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ void pre_save_load_automl(VW::workspace& all, automl<CMType>& data)
}

template <typename CMType>
void save_load_automl(automl<CMType>& aml, VW::io_buf& io, bool read, bool text)
void save_load_automl(automl<CMType>& aml, VW::io_buf& io, bool read, bool text, const VW::version_struct&)
{
if (io.num_files() == 0) { return; }
if (read) { VW::model_utils::read_model_field(io, aml); }
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/core/src/reductions/baseline_challenger_cb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ void learn_or_predict(baseline_challenger_data& data, learner& base, VW::multi_e
data.learn_or_predict<is_learn>(base, examples);
}

void save_load(baseline_challenger_data& data, VW::io_buf& io, bool read, bool text)
void save_load(baseline_challenger_data& data, VW::io_buf& io, bool read, bool text, const VW::version_struct&)
{
if (io.num_files() == 0) { return; }
if (read) { VW::model_utils::read_model_field(io, data); }
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/core/src/reductions/bfgs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,7 @@ void save_load_regularizer(VW::workspace& all, bfgs& b, VW::io_buf& model_file,
if (read) { regularizer_to_weight(all, b); }
}

void save_load(bfgs& b, VW::io_buf& model_file, bool read, bool text)
void save_load(bfgs& b, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&)
{
VW::workspace* all = b.all;

Expand Down
8 changes: 4 additions & 4 deletions vowpalwabbit/core/src/reductions/boosting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ void predict_or_learn_adaptive(boosting& o, VW::LEARNER::learner& base, VW::exam
ec.pred.scalar = final_prediction;
}

void save_load_sampling(boosting& o, VW::io_buf& model_file, bool read, bool text)
void save_load_sampling(boosting& o, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&)
{
if (model_file.num_files() == 0) { return; }
std::stringstream os;
Expand Down Expand Up @@ -298,7 +298,7 @@ void save_load_sampling(boosting& o, VW::io_buf& model_file, bool read, bool tex
o.logger.err_info("{}", fmt::to_string(buffer));
}

void save_load(boosting& o, VW::io_buf& model_file, bool read, bool text)
void save_load(boosting& o, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&)
{
if (model_file.num_files() == 0) { return; }
std::stringstream os;
Expand Down Expand Up @@ -340,7 +340,7 @@ void save_load(boosting& o, VW::io_buf& model_file, bool read, bool text)
}
}

void save_load_boosting_noop(boosting&, VW::io_buf&, bool, bool) {}
void save_load_boosting_noop(boosting&, VW::io_buf&, bool, bool, const VW::version_struct&) {}
} // namespace

std::shared_ptr<VW::LEARNER::learner> VW::reductions::boosting_setup(VW::setup_base_i& stack_builder)
Expand Down Expand Up @@ -383,7 +383,7 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::boosting_setup(VW::setup_b
std::string name_addition;
void (*learn_ptr)(boosting&, VW::LEARNER::learner&, VW::example&);
void (*pred_ptr)(boosting&, VW::LEARNER::learner&, VW::example&);
void (*save_load_fn)(boosting&, io_buf&, bool, bool);
void (*save_load_fn)(boosting&, io_buf&, bool, bool, const VW::version_struct&);

if (data->alg == "BBM")
{
Expand Down
5 changes: 2 additions & 3 deletions vowpalwabbit/core/src/reductions/cb/cb_adf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,9 @@ void print_update_cb_adf(VW::workspace& all, VW::shared_data& /* sd */, const VW
else { VW::details::print_update_cb(all, !labeled_example, ec, &ec_seq, true, nullptr); }
}

void save_load(VW::reductions::cb_adf& c, VW::io_buf& model_file, bool read, bool text)
void save_load(VW::reductions::cb_adf& c, VW::io_buf& model_file, bool read, bool text, const VW::version_struct& ver)
{
if (c.get_model_file_ver() != nullptr &&
*c.get_model_file_ver() < VW::version_definitions::VERSION_FILE_WITH_CB_ADF_SAVE)
if (ver < VW::version_definitions::VERSION_FILE_WITH_CB_ADF_SAVE)
{
return;
}
Expand Down
6 changes: 2 additions & 4 deletions vowpalwabbit/core/src/reductions/cb/cb_explore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ class cb_explore
float psi = 0.f;
bool nounif = false;
bool epsilon_decay = false;
VW::version_struct model_file_version;
VW::io::logger logger;

size_t counter = 0;
Expand Down Expand Up @@ -270,11 +269,11 @@ float calc_loss(const cb_explore& data, const VW::example& ec, const VW::cb_labe
return loss;
}

void save_load(cb_explore& cb, VW::io_buf& io, bool read, bool text)
void save_load(cb_explore& cb, VW::io_buf& io, bool read, bool text, const VW::version_struct& ver)
{
if (io.num_files() == 0) { return; }

if (!read || cb.model_file_version >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG)
if (!read || ver >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG)
{
std::stringstream msg;
if (!read) { msg << "cb cover storing VW::example counter: = " << cb.counter << "\n"; }
Expand Down Expand Up @@ -372,7 +371,6 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::cb_explore_setup(VW::setup
if (data->epsilon < 0.0 || data->epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }

data->cbcs.cb_type = VW::cb_type_t::DR;
data->model_file_version = all.runtime_state.model_file_ver;

size_t params_per_weight = 1;
if (options.was_supplied("cover")) { params_per_weight = data->cover_size + 1; }
Expand Down
14 changes: 6 additions & 8 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_cover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ class cb_explore_adf_cover
public:
cb_explore_adf_cover(size_t cover_size, float psi, bool nounif, float epsilon, bool epsilon_decay, bool first_only,
VW::LEARNER::learner* cs_ldf_learner, VW::LEARNER::learner* scorer, VW::cb_type_t cb_type,
VW::version_struct model_file_version, VW::io::logger logger);
VW::io::logger logger);

// Should be called through cb_explore_adf_base for pre/post-processing
void predict(VW::LEARNER::learner& base, VW::multi_ex& examples) { predict_or_learn_impl<false>(base, examples); }
void learn(VW::LEARNER::learner& base, VW::multi_ex& examples) { predict_or_learn_impl<true>(base, examples); }
void save_load(VW::io_buf& io, bool read, bool text);
void save_load(VW::io_buf& io, bool read, bool text, const VW::version_struct&);

private:
size_t _cover_size;
Expand All @@ -56,7 +56,6 @@ class cb_explore_adf_cover
VW::details::cb_to_cs_adf_dr _gen_cs_dr;
VW::cb_type_t _cb_type = VW::cb_type_t::DM;

VW::version_struct _model_file_version;
VW::io::logger _logger;

VW::v_array<VW::action_score> _action_probs;
Expand All @@ -71,7 +70,7 @@ class cb_explore_adf_cover

cb_explore_adf_cover::cb_explore_adf_cover(size_t cover_size, float psi, bool nounif, float epsilon, bool epsilon_decay,
bool first_only, VW::LEARNER::learner* cs_ldf_learner, VW::LEARNER::learner* scorer, VW::cb_type_t cb_type,
VW::version_struct model_file_version, VW::io::logger logger)
VW::io::logger logger)
: _cover_size(cover_size)
, _psi(psi)
, _nounif(nounif)
Expand All @@ -81,7 +80,6 @@ cb_explore_adf_cover::cb_explore_adf_cover(size_t cover_size, float psi, bool no
, _counter(0)
, _cs_ldf_learner(cs_ldf_learner)
, _cb_type(cb_type)
, _model_file_version(model_file_version)
, _logger(std::move(logger))
{
_gen_cs_dr.scorer = scorer;
Expand Down Expand Up @@ -222,10 +220,10 @@ void cb_explore_adf_cover::predict_or_learn_impl(VW::LEARNER::learner& base, VW:
if (is_learn) { ++_counter; }
}

void cb_explore_adf_cover::save_load(VW::io_buf& io, bool read, bool text)
void cb_explore_adf_cover::save_load(VW::io_buf& io, bool read, bool text, const VW::version_struct& ver)
{
if (io.num_files() == 0) { return; }
if (!read || _model_file_version >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG)
if (!read || ver >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG)
{
std::stringstream msg;
if (!read) { msg << "cb cover adf storing example counter: = " << _counter << "\n"; }
Expand Down Expand Up @@ -326,7 +324,7 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::cb_explore_adf_cover_setup
using explore_type = cb_explore_adf_base<cb_explore_adf_cover>;
auto data = VW::make_unique<explore_type>(all.output_runtime.global_metrics.are_metrics_enabled(),
VW::cast_to_smaller_type<size_t>(cover_size), psi, nounif, epsilon, epsilon_decay, first_only, cost_sensitive,
scorer, cb_type, all.runtime_state.model_file_ver, all.logger);
scorer, cb_type, all.logger);
auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(cb_explore_adf_cover_setup))
.set_input_label_type(VW::label_type_t::CB)
Expand Down
Loading
Loading