Skip to content

Commit

Permalink
go back to swap, fix igl's usage of pred in learn
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Aug 15, 2023
1 parent cae2640 commit cb27769
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 50 deletions.
2 changes: 1 addition & 1 deletion vowpalwabbit/core/include/vw/core/example.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class polyprediction
};

std::string to_string(const v_array<float>& scalars, int decimal_precision = details::DEFAULT_FLOAT_PRECISION);
void move_pred_to(polyprediction& src, polyprediction& dest, prediction_type_t prediction_type);
void swap_prediction(polyprediction& a, polyprediction& b, prediction_type_t prediction_type);

class example : public example_predict // core example datatype.
{
Expand Down
26 changes: 13 additions & 13 deletions vowpalwabbit/core/src/example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,45 @@
#include <climits>
#include <cstdint>

void VW::move_pred_to(VW::polyprediction& src, VW::polyprediction& dest, VW::prediction_type_t prediction_type)
void VW::swap_prediction(VW::polyprediction& a, VW::polyprediction& b, VW::prediction_type_t prediction_type)
{
switch (prediction_type)
{
case VW::prediction_type_t::SCALAR:
dest.scalar = std::move(src.scalar);
std::swap(b.scalar, a.scalar);
break;
case VW::prediction_type_t::SCALARS:
dest.scalars = std::move(src.scalars);
std::swap(b.scalars, a.scalars);
break;
case VW::prediction_type_t::ACTION_SCORES:
dest.a_s = std::move(src.a_s);
std::swap(b.a_s, a.a_s);
break;
case VW::prediction_type_t::PDF:
dest.pdf = std::move(src.pdf);
std::swap(b.pdf, a.pdf);
break;
case VW::prediction_type_t::ACTION_PROBS:
dest.a_s = std::move(src.a_s);
std::swap(b.a_s, a.a_s);
break;
case VW::prediction_type_t::MULTICLASS:
dest.multiclass = std::move(src.multiclass);
std::swap(b.multiclass, a.multiclass);
break;
case VW::prediction_type_t::MULTILABELS:
dest.multilabels = std::move(src.multilabels);
std::swap(b.multilabels, a.multilabels);
break;
case VW::prediction_type_t::PROB:
dest.prob = std::move(src.prob);
std::swap(b.prob, a.prob);
break;
case VW::prediction_type_t::MULTICLASS_PROBS:
dest.scalars = std::move(src.scalars);
std::swap(b.scalars, a.scalars);
break;
case VW::prediction_type_t::DECISION_PROBS:
dest.decision_scores = std::move(src.decision_scores);
std::swap(b.decision_scores, a.decision_scores);
break;
case VW::prediction_type_t::ACTION_PDF_VALUE:
dest.pdf_value = std::move(src.pdf_value);
std::swap(b.pdf_value, a.pdf_value);
break;
case VW::prediction_type_t::ACTIVE_MULTICLASS:
dest.active_multiclass = std::move(src.active_multiclass);
std::swap(b.active_multiclass, a.active_multiclass);
break;
case VW::prediction_type_t::NOPRED:
// Noop
Expand Down
10 changes: 4 additions & 6 deletions vowpalwabbit/core/src/global_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,9 @@ void workspace::learn(example& ec)
{
VW::LEARNER::require_singleline(l)->predict(ec);
VW::polyprediction saved_prediction;
// VW::move_pred_to(ec.pred, saved_prediction, l->get_output_prediction_type());
// new (&ec.pred) VW::polyprediction();
VW::swap_prediction(ec.pred, saved_prediction, l->get_output_prediction_type());
VW::LEARNER::require_singleline(l)->learn(ec);
// VW::move_pred_to(saved_prediction, ec.pred, l->get_output_prediction_type());
VW::swap_prediction(saved_prediction, ec.pred, l->get_output_prediction_type());
}
}
}
Expand All @@ -113,10 +112,9 @@ void workspace::learn(multi_ex& ec)
{
VW::LEARNER::require_multiline(l)->predict(ec);
VW::polyprediction saved_prediction;
// VW::move_pred_to(ec[0]->pred, saved_prediction, l->get_output_prediction_type());
// new (&ec[0]->pred) VW::polyprediction();
VW::swap_prediction(ec[0]->pred, saved_prediction, l->get_output_prediction_type());
VW::LEARNER::require_multiline(l)->learn(ec);
// VW::move_pred_to(saved_prediction, ec[0]->pred, l->get_output_prediction_type());
VW::swap_prediction(saved_prediction, ec[0]->pred, l->get_output_prediction_type());
}
}
}
Expand Down
66 changes: 36 additions & 30 deletions vowpalwabbit/core/src/reductions/interaction_ground.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,42 @@ void add_obs_features_to_ik_ex(VW::example& ik_ex, const VW::example& obs_ex)
}
}

void predict(VW::reductions::igl::igl_data& igl, learner& base, VW::multi_ex& ec_seq)
{
VW::example* observation_ex = nullptr;

if (ec_seq.size() > 0 && ec_seq.back()->l.cb_with_observations.is_observation)
{
observation_ex = ec_seq.back();
ec_seq.pop_back();
}

std::swap(*igl.pi_ftrl.get(), *igl.ik_ftrl);

for (auto& ex : ec_seq)
{
ex->l.cb = ex->l.cb_with_observations.event;
ex->l.cb_with_observations.event.reset_to_default();
}

base.predict(ec_seq, 1);
std::swap(*igl.pi_ftrl.get(), *igl.ik_ftrl);

for (auto& ex : ec_seq)
{
ex->l.cb_with_observations.event = ex->l.cb;
ex->l.cb.reset_to_default();
}

if (observation_ex != nullptr) { ec_seq.push_back(observation_ex); }
}

void learn(VW::reductions::igl::igl_data& igl, learner& base, VW::multi_ex& ec_seq)
{
predict(igl, base, ec_seq);

auto stashed_prediction = ec_seq[0]->pred.a_s;

float p_unlabeled_prior = 0.5f;

std::swap(igl.ik_ftrl->all->loss_config.loss, igl.ik_all->loss_config.loss);
Expand Down Expand Up @@ -212,36 +246,7 @@ void learn(VW::reductions::igl::igl_data& igl, learner& base, VW::multi_ex& ec_s
}

ec_seq.push_back(observation_ex);
}

void predict(VW::reductions::igl::igl_data& igl, learner& base, VW::multi_ex& ec_seq)
{
VW::example* observation_ex = nullptr;

if (ec_seq.size() > 0 && ec_seq.back()->l.cb_with_observations.is_observation)
{
observation_ex = ec_seq.back();
ec_seq.pop_back();
}

std::swap(*igl.pi_ftrl.get(), *igl.ik_ftrl);

for (auto& ex : ec_seq)
{
ex->l.cb = ex->l.cb_with_observations.event;
ex->l.cb_with_observations.event.reset_to_default();
}

base.predict(ec_seq, 1);
std::swap(*igl.pi_ftrl.get(), *igl.ik_ftrl);

for (auto& ex : ec_seq)
{
ex->l.cb_with_observations.event = ex->l.cb;
ex->l.cb.reset_to_default();
}

if (observation_ex != nullptr) { ec_seq.push_back(observation_ex); }
ec_seq[0]->pred.a_s = std::move(stashed_prediction);
}

void save_load_igl(VW::reductions::igl::igl_data& igl, VW::io_buf& io, bool read, bool text)
Expand Down Expand Up @@ -415,6 +420,7 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::interaction_ground_setup(V
auto l = make_reduction_learner(
std::move(ld), pi_learner, learn, predict, stack_builder.get_setupfn_name(interaction_ground_setup))
.set_feature_width(feature_width)
.set_learn_returns_prediction(true)
.set_input_label_type(label_type_t::CB_WITH_OBSERVATIONS)
.set_output_label_type(label_type_t::CB)
.set_input_prediction_type(pred_type)
Expand Down

0 comments on commit cb27769

Please sign in to comment.