Skip to content

Commit

Permalink
move dont swap
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Aug 15, 2023
1 parent b6ece99 commit 8a1a524
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 19 deletions.
2 changes: 1 addition & 1 deletion vowpalwabbit/core/include/vw/core/example.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class polyprediction
};

std::string to_string(const v_array<float>& scalars, int decimal_precision = details::DEFAULT_FLOAT_PRECISION);
void swap_prediction(polyprediction& a, polyprediction& b, prediction_type_t prediction_type);
void move_pred_to(polyprediction& src, polyprediction& dest, prediction_type_t prediction_type);

class example : public example_predict // core example datatype.
{
Expand Down
28 changes: 14 additions & 14 deletions vowpalwabbit/core/src/example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,45 @@
#include <climits>
#include <cstdint>

void VW::swap_prediction(VW::polyprediction& a, VW::polyprediction& b, VW::prediction_type_t prediction_type)
void VW::move_pred_to(VW::polyprediction& src, VW::polyprediction& dest, VW::prediction_type_t prediction_type)
{
switch(prediction_type)
switch (prediction_type)
{
case VW::prediction_type_t::SCALAR:
std::swap(a.scalar, b.scalar);
dest.scalar = std::move(src.scalar);
break;
case VW::prediction_type_t::SCALARS:
std::swap(a.scalars, b.scalars);
dest.scalars = std::move(src.scalars);
break;
case VW::prediction_type_t::ACTION_SCORES:
std::swap(a.a_s, b.a_s);
dest.a_s = std::move(src.a_s);
break;
case VW::prediction_type_t::PDF:
std::swap(a.pdf, b.pdf);
dest.pdf = std::move(src.pdf);
break;
case VW::prediction_type_t::ACTION_PROBS:
std::swap(a.a_s, b.a_s);
dest.a_s = std::move(src.a_s);
break;
case VW::prediction_type_t::MULTICLASS:
std::swap(a.multiclass, b.multiclass);
dest.multiclass = std::move(src.multiclass);
break;
case VW::prediction_type_t::MULTILABELS:
std::swap(a.multilabels, b.multilabels);
dest.multilabels = std::move(src.multilabels);
break;
case VW::prediction_type_t::PROB:
std::swap(a.prob, b.prob);
dest.prob = std::move(src.prob);
break;
case VW::prediction_type_t::MULTICLASS_PROBS:
std::swap(a.scalars, b.scalars);
dest.scalars = std::move(src.scalars);
break;
case VW::prediction_type_t::DECISION_PROBS:
std::swap(a.decision_scores, b.decision_scores);
dest.decision_scores = std::move(src.decision_scores);
break;
case VW::prediction_type_t::ACTION_PDF_VALUE:
std::swap(a.pdf_value, b.pdf_value);
dest.pdf_value = std::move(src.pdf_value);
break;
case VW::prediction_type_t::ACTIVE_MULTICLASS:
std::swap(a.active_multiclass, b.active_multiclass);
dest.active_multiclass = std::move(src.active_multiclass);
break;
case VW::prediction_type_t::NOPRED:
// Noop
Expand Down
10 changes: 6 additions & 4 deletions vowpalwabbit/core/src/global_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,10 @@ void workspace::learn(example& ec)
{
VW::LEARNER::require_singleline(l)->predict(ec);
VW::polyprediction saved_prediction;
VW::swap_prediction(saved_prediction, ec.pred, l->get_output_prediction_type());
VW::move_pred_to(ec.pred, saved_prediction, l->get_output_prediction_type());
new (&ec.pred) VW::polyprediction();
VW::LEARNER::require_singleline(l)->learn(ec);
VW::swap_prediction(saved_prediction, ec.pred, l->get_output_prediction_type());
VW::move_pred_to(saved_prediction, ec.pred, l->get_output_prediction_type());
}
}
}
Expand All @@ -112,9 +113,10 @@ void workspace::learn(multi_ex& ec)
{
VW::LEARNER::require_multiline(l)->predict(ec);
VW::polyprediction saved_prediction;
VW::swap_prediction(saved_prediction, ec[0]->pred, l->get_output_prediction_type());
VW::move_pred_to(ec[0]->pred, saved_prediction, l->get_output_prediction_type());
new (&ec[0]->pred) VW::polyprediction();
VW::LEARNER::require_multiline(l)->learn(ec);
VW::swap_prediction(saved_prediction, ec[0]->pred, l->get_output_prediction_type());
VW::move_pred_to(saved_prediction, ec[0]->pred, l->get_output_prediction_type());
}
}
}
Expand Down

0 comments on commit 8a1a524

Please sign in to comment.