Skip to content

Commit

Permalink
refactor: pass model version in save_load
Browse files Browse the repository at this point in the history
  • Loading branch information
peterychang committed Nov 11, 2023
1 parent a541d85 commit 47f7bf0
Show file tree
Hide file tree
Showing 39 changed files with 108 additions and 110 deletions.
11 changes: 10 additions & 1 deletion vowpalwabbit/core/include/vw/core/learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ using multipredict_func =

using sensitivity_func = std::function<float(example& ex)>;
using save_load_func = std::function<void(io_buf&, bool read, bool text)>;
using save_load_ver_func = std::function<void(io_buf&, bool read, bool text, const VW::version_struct&)>;
using pre_save_load_func = std::function<void(VW::workspace& all)>;
using save_metric_func = std::function<void(metric_sink& metrics)>;

Expand Down Expand Up @@ -182,7 +183,7 @@ class learner final : public std::enable_shared_from_this<learner>
float sensitivity(example& ec, size_t i = 0);

// Called anytime saving or loading needs to happen. Autorecursive.
void save_load(io_buf& io, const bool read, const bool text);
void save_load(io_buf& io, const bool read, const bool text, const VW::version_struct& model_version);

// Called to edit the command-line from a learner. Autorecursive
void pre_save_load(VW::workspace& all);
Expand Down Expand Up @@ -287,6 +288,7 @@ class learner final : public std::enable_shared_from_this<learner>
details::cleanup_example_func _cleanup_example_f;

details::save_load_func _save_load_f;
details::save_load_ver_func _save_load_ver_f;
details::void_func _end_pass_f;
details::void_func _end_examples_f;
details::pre_save_load_func _pre_save_load_f;
Expand Down Expand Up @@ -417,12 +419,19 @@ class common_learner_builder
learner_ptr->learn_returns_prediction = learn_returns_prediction;
)

// TODO: deprecate?
LEARNER_BUILDER_DEFINE(set_save_load(void (*fn_ptr)(DataT&, io_buf&, bool, bool)),
assert(fn_ptr != nullptr);
DataT* data = this->learner_data.get();
this->learner_ptr->_save_load_f = [fn_ptr, data](io_buf& buf, bool read, bool text)
{ fn_ptr(*data, buf, read, text); };
)
LEARNER_BUILDER_DEFINE(set_save_load(void (*fn_ptr)(DataT&, io_buf&, bool, bool, const VW::version_struct&)),
assert(fn_ptr != nullptr);
DataT* data = this->learner_data.get();
this->learner_ptr->_save_load_ver_f = [fn_ptr, data](io_buf& buf, bool read, bool text, const VW::version_struct& ver)
{ fn_ptr(*data, buf, read, text, ver); };
)

LEARNER_BUILDER_DEFINE(set_finish(void (*fn_ptr)(DataT&)),
assert(fn_ptr != nullptr);
Expand Down
5 changes: 1 addition & 4 deletions vowpalwabbit/core/include/vw/core/reductions/active.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@ namespace reductions
class active
{
public:
active(float active_c0, std::shared_ptr<shared_data> shared_data, std::shared_ptr<rand_state> random_state,
VW::version_struct model_version)
active(float active_c0, std::shared_ptr<shared_data> shared_data, std::shared_ptr<rand_state> random_state)
: active_c0(active_c0)
, _shared_data(shared_data)
, _random_state(std::move(random_state))
, _model_version{std::move(model_version)}
{
}

Expand All @@ -31,7 +29,6 @@ class active

float _min_seen_label = 0.f;
float _max_seen_label = 1.f;
VW::version_struct _model_version;
};

std::shared_ptr<VW::LEARNER::learner> active_setup(VW::setup_base_i& stack_builder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class cb_explore_adf_base
if (with_metrics) { _metrics = VW::make_unique<cb_explore_metrics>(); }
}

static void save_load(cb_explore_adf_base<ExploreType>& data, io_buf& io, bool read, bool text);
static void save_load(cb_explore_adf_base<ExploreType>& data, io_buf& io, bool read, bool text, const VW::version_struct& ver);
static void persist_metrics(cb_explore_adf_base<ExploreType>& data, metric_sink& metrics);
static void predict(cb_explore_adf_base<ExploreType>& data, VW::LEARNER::learner& base, multi_ex& examples);
static void learn(cb_explore_adf_base<ExploreType>& data, VW::LEARNER::learner& base, multi_ex& examples);
Expand Down Expand Up @@ -302,9 +302,9 @@ void cb_explore_adf_base<ExploreType>::_print_update(

template <typename ExploreType>
inline void cb_explore_adf_base<ExploreType>::save_load(
cb_explore_adf_base<ExploreType>& data, io_buf& io, bool read, bool text)
cb_explore_adf_base<ExploreType>& data, io_buf& io, bool read, bool text, const VW::version_struct& ver)
{
data.explore.save_load(io, read, text);
data.explore.save_load(io, read, text, ver);
}

template <typename ExploreType>
Expand Down
17 changes: 15 additions & 2 deletions vowpalwabbit/core/src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ float learner::sensitivity(example& ec, size_t i)
return ret;
}

void learner::save_load(io_buf& io, const bool read, const bool text)
void learner::save_load(io_buf& io, const bool read, const bool text, const VW::version_struct& model_version)
{
if (_save_load_f)
{
Expand All @@ -457,7 +457,20 @@ void learner::save_load(io_buf& io, const bool read, const bool text)
throw VW::save_load_model_exception(vwex.filename(), vwex.line_number(), better_msg.str());
}
}
if (_base_learner) { _base_learner->save_load(io, read, text); }
else if (_save_load_ver_f)
{
try
{
_save_load_ver_f(io, read, text, model_version);
}
catch (VW::vw_exception& vwex)
{
std::stringstream better_msg;
better_msg << "model " << std::string(read ? "load" : "save") << " failed. Error Details: " << vwex.what();
throw VW::save_load_model_exception(vwex.filename(), vwex.line_number(), better_msg.str());
}
}
if (_base_learner) { _base_learner->save_load(io, read, text, model_version); }
}

void learner::pre_save_load(VW::workspace& all)
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/core/src/parse_args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1362,7 +1362,7 @@ void load_input_model(VW::workspace& all, VW::io_buf& io_temp)
all.feature_mask == all.initial_weights_config.initial_regressors[0])
{
// load rest of regressor
all.l->save_load(io_temp, true, false);
all.l->save_load(io_temp, true, false, all.runtime_state.model_file_ver);
io_temp.close_file();

VW::details::parse_mask_regressor_args(all, all.feature_mask, all.initial_weights_config.initial_regressors);
Expand All @@ -1372,7 +1372,7 @@ void load_input_model(VW::workspace& all, VW::io_buf& io_temp)
VW::details::parse_mask_regressor_args(all, all.feature_mask, all.initial_weights_config.initial_regressors);

// load rest of regressor
all.l->save_load(io_temp, true, false);
all.l->save_load(io_temp, true, false, all.runtime_state.model_file_ver);
io_temp.close_file();
}
}
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/core/src/parse_regressor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ void VW::details::dump_regressor(VW::workspace& all, VW::io_buf& buf, bool as_te
std::string unused;
if (all.l != nullptr) { all.l->pre_save_load(all); }
VW::details::save_load_header(all, buf, false, as_text, unused, *all.options);
if (all.l != nullptr) { all.l->save_load(buf, false, as_text); }
if (all.l != nullptr) { all.l->save_load(buf, false, as_text, all.runtime_state.model_file_ver); }

buf.flush(); // close_file() should do this for me ...
buf.close_file();
Expand Down Expand Up @@ -580,7 +580,7 @@ void VW::details::parse_mask_regressor_args(
io_temp_mask.add_file(VW::io::open_file_reader(feature_mask));

save_load_header(all, io_temp_mask, true, false, file_options, *all.options);
all.l->save_load(io_temp_mask, true, false);
all.l->save_load(io_temp_mask, true, false, all.runtime_state.model_file_ver);
io_temp_mask.close_file();

// Deal with the over-written header from initial regressor
Expand Down
10 changes: 5 additions & 5 deletions vowpalwabbit/core/src/reductions/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,16 @@ void active_print_result(
if (t != len) { logger.err_error("write error: {}", VW::io::strerror_to_string(errno)); }
}

void save_load(active& a, VW::io_buf& io, bool read, bool text)
void save_load(active& a, VW::io_buf& io, bool read, bool text, const VW::version_struct& ver)
{
using namespace VW::version_definitions;
if (io.num_files() == 0) { return; }
// This code is valid if version is within
// [VERSION_FILE_WITH_ACTIVE_SEEN_LABELS, VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED)
// or >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED
if ((a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS &&
a._model_version < VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) ||
a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED)
if ((ver >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS &&
ver < VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) ||
ver >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED)
{
if (read)
{
Expand Down Expand Up @@ -201,7 +201,7 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::active_setup(VW::setup_bas
if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; }

if (options.was_supplied("lda")) { THROW("lda cannot be combined with active learning") }
auto data = VW::make_unique<active>(active_c0, all.sd, all.get_random_state(), all.runtime_state.model_file_ver);
auto data = VW::make_unique<active>(active_c0, all.sd, all.get_random_state());
auto base = require_singleline(stack_builder.setup_base_learner());

using learn_pred_func_t = void (*)(active&, VW::LEARNER::learner&, VW::example&);
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/core/src/reductions/automl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ void pre_save_load_automl(VW::workspace& all, automl<CMType>& data)
}

template <typename CMType>
void save_load_automl(automl<CMType>& aml, VW::io_buf& io, bool read, bool text)
void save_load_automl(automl<CMType>& aml, VW::io_buf& io, bool read, bool text, const VW::version_struct&)
{
if (io.num_files() == 0) { return; }
if (read) { VW::model_utils::read_model_field(io, aml); }
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/core/src/reductions/baseline_challenger_cb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ void learn_or_predict(baseline_challenger_data& data, learner& base, VW::multi_e
data.learn_or_predict<is_learn>(base, examples);
}

void save_load(baseline_challenger_data& data, VW::io_buf& io, bool read, bool text)
void save_load(baseline_challenger_data& data, VW::io_buf& io, bool read, bool text, const VW::version_struct&)
{
if (io.num_files() == 0) { return; }
if (read) { VW::model_utils::read_model_field(io, data); }
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/core/src/reductions/bfgs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,7 @@ void save_load_regularizer(VW::workspace& all, bfgs& b, VW::io_buf& model_file,
if (read) { regularizer_to_weight(all, b); }
}

void save_load(bfgs& b, VW::io_buf& model_file, bool read, bool text)
void save_load(bfgs& b, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&)
{
VW::workspace* all = b.all;

Expand Down
8 changes: 4 additions & 4 deletions vowpalwabbit/core/src/reductions/boosting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ void predict_or_learn_adaptive(boosting& o, VW::LEARNER::learner& base, VW::exam
ec.pred.scalar = final_prediction;
}

void save_load_sampling(boosting& o, VW::io_buf& model_file, bool read, bool text)
void save_load_sampling(boosting& o, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&)
{
if (model_file.num_files() == 0) { return; }
std::stringstream os;
Expand Down Expand Up @@ -298,7 +298,7 @@ void save_load_sampling(boosting& o, VW::io_buf& model_file, bool read, bool tex
o.logger.err_info("{}", fmt::to_string(buffer));
}

void save_load(boosting& o, VW::io_buf& model_file, bool read, bool text)
void save_load(boosting& o, VW::io_buf& model_file, bool read, bool text, const VW::version_struct&)
{
if (model_file.num_files() == 0) { return; }
std::stringstream os;
Expand Down Expand Up @@ -340,7 +340,7 @@ void save_load(boosting& o, VW::io_buf& model_file, bool read, bool text)
}
}

void save_load_boosting_noop(boosting&, VW::io_buf&, bool, bool) {}
void save_load_boosting_noop(boosting&, VW::io_buf&, bool, bool, const VW::version_struct&) {}
} // namespace

std::shared_ptr<VW::LEARNER::learner> VW::reductions::boosting_setup(VW::setup_base_i& stack_builder)
Expand Down Expand Up @@ -383,7 +383,7 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::boosting_setup(VW::setup_b
std::string name_addition;
void (*learn_ptr)(boosting&, VW::LEARNER::learner&, VW::example&);
void (*pred_ptr)(boosting&, VW::LEARNER::learner&, VW::example&);
void (*save_load_fn)(boosting&, io_buf&, bool, bool);
void (*save_load_fn)(boosting&, io_buf&, bool, bool, const VW::version_struct&);

if (data->alg == "BBM")
{
Expand Down
5 changes: 2 additions & 3 deletions vowpalwabbit/core/src/reductions/cb/cb_adf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,9 @@ void print_update_cb_adf(VW::workspace& all, VW::shared_data& /* sd */, const VW
else { VW::details::print_update_cb(all, !labeled_example, ec, &ec_seq, true, nullptr); }
}

void save_load(VW::reductions::cb_adf& c, VW::io_buf& model_file, bool read, bool text)
void save_load(VW::reductions::cb_adf& c, VW::io_buf& model_file, bool read, bool text, const VW::version_struct& ver)
{
if (c.get_model_file_ver() != nullptr &&
*c.get_model_file_ver() < VW::version_definitions::VERSION_FILE_WITH_CB_ADF_SAVE)
if (ver < VW::version_definitions::VERSION_FILE_WITH_CB_ADF_SAVE)
{
return;
}
Expand Down
6 changes: 2 additions & 4 deletions vowpalwabbit/core/src/reductions/cb/cb_explore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ class cb_explore
float psi = 0.f;
bool nounif = false;
bool epsilon_decay = false;
VW::version_struct model_file_version;
VW::io::logger logger;

size_t counter = 0;
Expand Down Expand Up @@ -270,11 +269,11 @@ float calc_loss(const cb_explore& data, const VW::example& ec, const VW::cb_labe
return loss;
}

void save_load(cb_explore& cb, VW::io_buf& io, bool read, bool text)
void save_load(cb_explore& cb, VW::io_buf& io, bool read, bool text, const VW::version_struct& ver)
{
if (io.num_files() == 0) { return; }

if (!read || cb.model_file_version >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG)
if (!read || ver >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG)
{
std::stringstream msg;
if (!read) { msg << "cb cover storing VW::example counter: = " << cb.counter << "\n"; }
Expand Down Expand Up @@ -372,7 +371,6 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::cb_explore_setup(VW::setup
if (data->epsilon < 0.0 || data->epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }

data->cbcs.cb_type = VW::cb_type_t::DR;
data->model_file_version = all.runtime_state.model_file_ver;

size_t params_per_weight = 1;
if (options.was_supplied("cover")) { params_per_weight = data->cover_size + 1; }
Expand Down
14 changes: 6 additions & 8 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_cover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ class cb_explore_adf_cover
public:
cb_explore_adf_cover(size_t cover_size, float psi, bool nounif, float epsilon, bool epsilon_decay, bool first_only,
VW::LEARNER::learner* cs_ldf_learner, VW::LEARNER::learner* scorer, VW::cb_type_t cb_type,
VW::version_struct model_file_version, VW::io::logger logger);
VW::io::logger logger);

// Should be called through cb_explore_adf_base for pre/post-processing
void predict(VW::LEARNER::learner& base, VW::multi_ex& examples) { predict_or_learn_impl<false>(base, examples); }
void learn(VW::LEARNER::learner& base, VW::multi_ex& examples) { predict_or_learn_impl<true>(base, examples); }
void save_load(VW::io_buf& io, bool read, bool text);
void save_load(VW::io_buf& io, bool read, bool text, const VW::version_struct&);

private:
size_t _cover_size;
Expand All @@ -56,7 +56,6 @@ class cb_explore_adf_cover
VW::details::cb_to_cs_adf_dr _gen_cs_dr;
VW::cb_type_t _cb_type = VW::cb_type_t::DM;

VW::version_struct _model_file_version;
VW::io::logger _logger;

VW::v_array<VW::action_score> _action_probs;
Expand All @@ -71,7 +70,7 @@ class cb_explore_adf_cover

cb_explore_adf_cover::cb_explore_adf_cover(size_t cover_size, float psi, bool nounif, float epsilon, bool epsilon_decay,
bool first_only, VW::LEARNER::learner* cs_ldf_learner, VW::LEARNER::learner* scorer, VW::cb_type_t cb_type,
VW::version_struct model_file_version, VW::io::logger logger)
VW::io::logger logger)
: _cover_size(cover_size)
, _psi(psi)
, _nounif(nounif)
Expand All @@ -81,7 +80,6 @@ cb_explore_adf_cover::cb_explore_adf_cover(size_t cover_size, float psi, bool no
, _counter(0)
, _cs_ldf_learner(cs_ldf_learner)
, _cb_type(cb_type)
, _model_file_version(model_file_version)
, _logger(std::move(logger))
{
_gen_cs_dr.scorer = scorer;
Expand Down Expand Up @@ -222,10 +220,10 @@ void cb_explore_adf_cover::predict_or_learn_impl(VW::LEARNER::learner& base, VW:
if (is_learn) { ++_counter; }
}

void cb_explore_adf_cover::save_load(VW::io_buf& io, bool read, bool text)
void cb_explore_adf_cover::save_load(VW::io_buf& io, bool read, bool text, const VW::version_struct& ver)
{
if (io.num_files() == 0) { return; }
if (!read || _model_file_version >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG)
if (!read || ver >= VW::version_definitions::VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG)
{
std::stringstream msg;
if (!read) { msg << "cb cover adf storing example counter: = " << _counter << "\n"; }
Expand Down Expand Up @@ -326,7 +324,7 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::cb_explore_adf_cover_setup
using explore_type = cb_explore_adf_base<cb_explore_adf_cover>;
auto data = VW::make_unique<explore_type>(all.output_runtime.global_metrics.are_metrics_enabled(),
VW::cast_to_smaller_type<size_t>(cover_size), psi, nounif, epsilon, epsilon_decay, first_only, cost_sensitive,
scorer, cb_type, all.runtime_state.model_file_ver, all.logger);
scorer, cb_type, all.logger);
auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(cb_explore_adf_cover_setup))
.set_input_label_type(VW::label_type_t::CB)
Expand Down
Loading

0 comments on commit 47f7bf0

Please sign in to comment.