diff --git a/vowpalwabbit/core/include/vw/core/example.h b/vowpalwabbit/core/include/vw/core/example.h index 64c88edb263..cd2185a8b08 100644 --- a/vowpalwabbit/core/include/vw/core/example.h +++ b/vowpalwabbit/core/include/vw/core/example.h @@ -83,7 +83,7 @@ class polyprediction }; std::string to_string(const v_array& scalars, int decimal_precision = details::DEFAULT_FLOAT_PRECISION); -void move_pred_to(polyprediction& src, polyprediction& dest, prediction_type_t prediction_type); +void swap_prediction(polyprediction& a, polyprediction& b, prediction_type_t prediction_type); class example : public example_predict // core example datatype. { diff --git a/vowpalwabbit/core/src/example.cc b/vowpalwabbit/core/src/example.cc index fec2320787d..cb2c88be038 100644 --- a/vowpalwabbit/core/src/example.cc +++ b/vowpalwabbit/core/src/example.cc @@ -16,45 +16,45 @@ #include #include -void VW::move_pred_to(VW::polyprediction& src, VW::polyprediction& dest, VW::prediction_type_t prediction_type) +void VW::swap_prediction(VW::polyprediction& a, VW::polyprediction& b, VW::prediction_type_t prediction_type) { switch (prediction_type) { case VW::prediction_type_t::SCALAR: - dest.scalar = std::move(src.scalar); + std::swap(b.scalar, a.scalar); break; case VW::prediction_type_t::SCALARS: - dest.scalars = std::move(src.scalars); + std::swap(b.scalars, a.scalars); break; case VW::prediction_type_t::ACTION_SCORES: - dest.a_s = std::move(src.a_s); + std::swap(b.a_s, a.a_s); break; case VW::prediction_type_t::PDF: - dest.pdf = std::move(src.pdf); + std::swap(b.pdf, a.pdf); break; case VW::prediction_type_t::ACTION_PROBS: - dest.a_s = std::move(src.a_s); + std::swap(b.a_s, a.a_s); break; case VW::prediction_type_t::MULTICLASS: - dest.multiclass = std::move(src.multiclass); + std::swap(b.multiclass, a.multiclass); break; case VW::prediction_type_t::MULTILABELS: - dest.multilabels = std::move(src.multilabels); + std::swap(b.multilabels, a.multilabels); break; case VW::prediction_type_t::PROB: - dest.prob = std::move(src.prob); + std::swap(b.prob, a.prob); break; case VW::prediction_type_t::MULTICLASS_PROBS: - dest.scalars = std::move(src.scalars); + std::swap(b.scalars, a.scalars); break; case VW::prediction_type_t::DECISION_PROBS: - dest.decision_scores = std::move(src.decision_scores); + std::swap(b.decision_scores, a.decision_scores); break; case VW::prediction_type_t::ACTION_PDF_VALUE: - dest.pdf_value = std::move(src.pdf_value); + std::swap(b.pdf_value, a.pdf_value); break; case VW::prediction_type_t::ACTIVE_MULTICLASS: - dest.active_multiclass = std::move(src.active_multiclass); + std::swap(b.active_multiclass, a.active_multiclass); break; case VW::prediction_type_t::NOPRED: // Noop diff --git a/vowpalwabbit/core/src/global_data.cc b/vowpalwabbit/core/src/global_data.cc index c14b699f432..ff3833d0682 100644 --- a/vowpalwabbit/core/src/global_data.cc +++ b/vowpalwabbit/core/src/global_data.cc @@ -93,10 +93,9 @@ void workspace::learn(example& ec) { VW::LEARNER::require_singleline(l)->predict(ec); VW::polyprediction saved_prediction; - // VW::move_pred_to(ec.pred, saved_prediction, l->get_output_prediction_type()); - // new (&ec.pred) VW::polyprediction(); + VW::swap_prediction(ec.pred, saved_prediction, l->get_output_prediction_type()); VW::LEARNER::require_singleline(l)->learn(ec); - // VW::move_pred_to(saved_prediction, ec.pred, l->get_output_prediction_type()); + VW::swap_prediction(saved_prediction, ec.pred, l->get_output_prediction_type()); } } } @@ -113,10 +112,9 @@ void workspace::learn(multi_ex& ec) { VW::LEARNER::require_multiline(l)->predict(ec); VW::polyprediction saved_prediction; - // VW::move_pred_to(ec[0]->pred, saved_prediction, l->get_output_prediction_type()); - // new (&ec[0]->pred) VW::polyprediction(); + VW::swap_prediction(ec[0]->pred, saved_prediction, l->get_output_prediction_type()); VW::LEARNER::require_multiline(l)->learn(ec); - // VW::move_pred_to(saved_prediction, ec[0]->pred, l->get_output_prediction_type()); + VW::swap_prediction(saved_prediction, ec[0]->pred, l->get_output_prediction_type()); } } } diff --git a/vowpalwabbit/core/src/reductions/interaction_ground.cc b/vowpalwabbit/core/src/reductions/interaction_ground.cc index 426b186e1fe..7df7f69ef7d 100644 --- a/vowpalwabbit/core/src/reductions/interaction_ground.cc +++ b/vowpalwabbit/core/src/reductions/interaction_ground.cc @@ -122,8 +122,42 @@ void add_obs_features_to_ik_ex(VW::example& ik_ex, const VW::example& obs_ex) } } +void predict(VW::reductions::igl::igl_data& igl, learner& base, VW::multi_ex& ec_seq) +{ + VW::example* observation_ex = nullptr; + + if (ec_seq.size() > 0 && ec_seq.back()->l.cb_with_observations.is_observation) + { + observation_ex = ec_seq.back(); + ec_seq.pop_back(); + } + + std::swap(*igl.pi_ftrl.get(), *igl.ik_ftrl); + + for (auto& ex : ec_seq) + { + ex->l.cb = ex->l.cb_with_observations.event; + ex->l.cb_with_observations.event.reset_to_default(); + } + + base.predict(ec_seq, 1); + std::swap(*igl.pi_ftrl.get(), *igl.ik_ftrl); + + for (auto& ex : ec_seq) + { + ex->l.cb_with_observations.event = ex->l.cb; + ex->l.cb.reset_to_default(); + } + + if (observation_ex != nullptr) { ec_seq.push_back(observation_ex); } +} + void learn(VW::reductions::igl::igl_data& igl, learner& base, VW::multi_ex& ec_seq) { + predict(igl, base, ec_seq); + + auto stashed_prediction = ec_seq[0]->pred.a_s; + float p_unlabeled_prior = 0.5f; std::swap(igl.ik_ftrl->all->loss_config.loss, igl.ik_all->loss_config.loss); @@ -212,36 +246,7 @@ void learn(VW::reductions::igl::igl_data& igl, learner& base, VW::multi_ex& ec_s } ec_seq.push_back(observation_ex); -} - -void predict(VW::reductions::igl::igl_data& igl, learner& base, VW::multi_ex& ec_seq) -{ - VW::example* observation_ex = nullptr; - - if (ec_seq.size() > 0 && ec_seq.back()->l.cb_with_observations.is_observation) - { - observation_ex = ec_seq.back(); - ec_seq.pop_back(); - } - - std::swap(*igl.pi_ftrl.get(), *igl.ik_ftrl); - - for (auto& ex : ec_seq) - { - ex->l.cb = ex->l.cb_with_observations.event; - ex->l.cb_with_observations.event.reset_to_default(); - } - - base.predict(ec_seq, 1); - std::swap(*igl.pi_ftrl.get(), *igl.ik_ftrl); - - for (auto& ex : ec_seq) - { - ex->l.cb_with_observations.event = ex->l.cb; - ex->l.cb.reset_to_default(); - } - - if (observation_ex != nullptr) { ec_seq.push_back(observation_ex); } + ec_seq[0]->pred.a_s = std::move(stashed_prediction); } void save_load_igl(VW::reductions::igl::igl_data& igl, VW::io_buf& io, bool read, bool text) @@ -415,6 +420,7 @@ std::shared_ptr VW::reductions::interaction_ground_setup(V auto l = make_reduction_learner( std::move(ld), pi_learner, learn, predict, stack_builder.get_setupfn_name(interaction_ground_setup)) .set_feature_width(feature_width) + .set_learn_returns_prediction(true) .set_input_label_type(label_type_t::CB_WITH_OBSERVATIONS) .set_output_label_type(label_type_t::CB) .set_input_prediction_type(pred_type)