From 8a1a524fbf334d5df10e1280e5511fb59ce2fa80 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Tue, 15 Aug 2023 14:33:25 -0400 Subject: [PATCH] move dont swap --- vowpalwabbit/core/include/vw/core/example.h | 2 +- vowpalwabbit/core/src/example.cc | 28 ++++++++++----------- vowpalwabbit/core/src/global_data.cc | 10 +++++--- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/vowpalwabbit/core/include/vw/core/example.h b/vowpalwabbit/core/include/vw/core/example.h index cd2185a8b08..64c88edb263 100644 --- a/vowpalwabbit/core/include/vw/core/example.h +++ b/vowpalwabbit/core/include/vw/core/example.h @@ -83,7 +83,7 @@ class polyprediction }; std::string to_string(const v_array& scalars, int decimal_precision = details::DEFAULT_FLOAT_PRECISION); -void swap_prediction(polyprediction& a, polyprediction& b, prediction_type_t prediction_type); +void move_pred_to(polyprediction& src, polyprediction& dest, prediction_type_t prediction_type); class example : public example_predict // core example datatype. { diff --git a/vowpalwabbit/core/src/example.cc b/vowpalwabbit/core/src/example.cc index ac925634119..fec2320787d 100644 --- a/vowpalwabbit/core/src/example.cc +++ b/vowpalwabbit/core/src/example.cc @@ -16,45 +16,45 @@ #include #include -void VW::swap_prediction(VW::polyprediction& a, VW::polyprediction& b, VW::prediction_type_t prediction_type) +void VW::move_pred_to(VW::polyprediction& src, VW::polyprediction& dest, VW::prediction_type_t prediction_type) { - switch(prediction_type) + switch (prediction_type) { case VW::prediction_type_t::SCALAR: - std::swap(a.scalar, b.scalar); + dest.scalar = std::move(src.scalar); break; case VW::prediction_type_t::SCALARS: - std::swap(a.scalars, b.scalars); + dest.scalars = std::move(src.scalars); break; case VW::prediction_type_t::ACTION_SCORES: - std::swap(a.a_s, b.a_s); + dest.a_s = std::move(src.a_s); break; case VW::prediction_type_t::PDF: - std::swap(a.pdf, b.pdf); + dest.pdf = std::move(src.pdf); break; case VW::prediction_type_t::ACTION_PROBS: - std::swap(a.a_s, b.a_s); + dest.a_s = std::move(src.a_s); break; case VW::prediction_type_t::MULTICLASS: - std::swap(a.multiclass, b.multiclass); + dest.multiclass = std::move(src.multiclass); break; case VW::prediction_type_t::MULTILABELS: - std::swap(a.multilabels, b.multilabels); + dest.multilabels = std::move(src.multilabels); break; case VW::prediction_type_t::PROB: - std::swap(a.prob, b.prob); + dest.prob = std::move(src.prob); break; case VW::prediction_type_t::MULTICLASS_PROBS: - std::swap(a.scalars, b.scalars); + dest.scalars = std::move(src.scalars); break; case VW::prediction_type_t::DECISION_PROBS: - std::swap(a.decision_scores, b.decision_scores); + dest.decision_scores = std::move(src.decision_scores); break; case VW::prediction_type_t::ACTION_PDF_VALUE: - std::swap(a.pdf_value, b.pdf_value); + dest.pdf_value = std::move(src.pdf_value); break; case VW::prediction_type_t::ACTIVE_MULTICLASS: - std::swap(a.active_multiclass, b.active_multiclass); + dest.active_multiclass = std::move(src.active_multiclass); break; case VW::prediction_type_t::NOPRED: // Noop diff --git a/vowpalwabbit/core/src/global_data.cc b/vowpalwabbit/core/src/global_data.cc index c476427c74a..1b8fada2a2f 100644 --- a/vowpalwabbit/core/src/global_data.cc +++ b/vowpalwabbit/core/src/global_data.cc @@ -93,9 +93,10 @@ void workspace::learn(example& ec) { VW::LEARNER::require_singleline(l)->predict(ec); VW::polyprediction saved_prediction; - VW::swap_prediction(saved_prediction, ec.pred, l->get_output_prediction_type()); + VW::move_pred_to(ec.pred, saved_prediction, l->get_output_prediction_type()); + new (&ec.pred) VW::polyprediction(); VW::LEARNER::require_singleline(l)->learn(ec); - VW::swap_prediction(saved_prediction, ec.pred, l->get_output_prediction_type()); + VW::move_pred_to(saved_prediction, ec.pred, l->get_output_prediction_type()); } } } @@ -112,9 +113,10 @@ void workspace::learn(multi_ex& ec) { VW::LEARNER::require_multiline(l)->predict(ec); VW::polyprediction saved_prediction; - VW::swap_prediction(saved_prediction, ec[0]->pred, l->get_output_prediction_type()); + VW::move_pred_to(ec[0]->pred, saved_prediction, l->get_output_prediction_type()); + new (&ec[0]->pred) VW::polyprediction(); VW::LEARNER::require_multiline(l)->learn(ec); - VW::swap_prediction(saved_prediction, ec[0]->pred, l->get_output_prediction_type()); + VW::move_pred_to(saved_prediction, ec[0]->pred, l->get_output_prediction_type()); } } }