From bea68162b1c7457af9506991c6a29a6e833f276b Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Thu, 31 Aug 2023 13:14:47 -0400 Subject: [PATCH] fix: update to reflect the fact that coin's learn does not return a prediction (#4621) * fix: coin's learn does not return a prediction, update to reflect The prediction it currently returns essentially represents what the prediction would have been post learn and therefore does not match the contract of learn_returns_prediction, * update test --------- Co-authored-by: Alexey Taymanov <41013086+ataymano@users.noreply.github.com> --- test/pred-sets/ref/ftrl_coin.predict | 22 +++++++++++----------- test/train-sets/ref/ftrl_coin.stderr | 14 +++++++------- vowpalwabbit/core/src/reductions/ftrl.cc | 4 ++-- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/test/pred-sets/ref/ftrl_coin.predict b/test/pred-sets/ref/ftrl_coin.predict index 5fcc48aa9bc..e91425cf8a9 100644 --- a/test/pred-sets/ref/ftrl_coin.predict +++ b/test/pred-sets/ref/ftrl_coin.predict @@ -1,12 +1,12 @@ 1 0.361325 -0.143347 +0.143348 0 0.067602 0.934070 0.160364 0 -0.576177 +0.576178 1 0.068731 0 @@ -38,17 +38,17 @@ 0 1 1 -0.118569 +0.118570 0.791537 0 0 0.001875 0 -0.179304 +0.179305 0.141783 0.575180 0.236666 -0.860866 +0.860867 1 0 0 @@ -95,7 +95,7 @@ 1 0 0.683733 -0.138811 +0.138812 0.555727 0.665974 0.212627 @@ -147,8 +147,8 @@ 0.616288 1 0 -0.636642 -0.694423 +0.636643 +0.694424 0.643842 0.647201 0.032352 @@ -163,13 +163,13 @@ 0 1 0.250740 -0.752841 +0.752842 0.756894 0.112263 0.912395 0.563831 0.576449 -0.142849 +0.142850 0 0.390992 0 @@ -188,7 +188,7 @@ 0.207041 0.094113 0.555100 -0.612930 +0.612931 0.217877 0.319611 0.119250 diff --git a/test/train-sets/ref/ftrl_coin.stderr b/test/train-sets/ref/ftrl_coin.stderr index f17886106be..028a86f39c0 100644 --- a/test/train-sets/ref/ftrl_coin.stderr +++ b/test/train-sets/ref/ftrl_coin.stderr @@ -16,13 +16,13 @@ Output pred = SCALAR average since example example current current current loss last counter weight label predict features 1.000000 1.000000 1 1.0 1.0000 0.0000 51 -0.503420 0.006841 2 2.0 0.0000 0.0827 104 -0.253545 0.003669 4 4.0 0.0000 0.0000 135 -0.250782 0.248019 8 8.0 0.0000 0.0259 146 -0.278089 0.305396 16 16.0 1.0000 0.1788 24 -0.292183 0.306278 32 32.0 0.0000 0.3479 32 -0.263581 0.234979 64 64.0 0.0000 0.0007 61 -0.233750 0.203920 128 128.0 1.0000 0.7720 106 +0.503420 0.006841 2 2.0 0.0000 0.0000 104 +0.253545 0.003669 4 4.0 0.0000 0.0291 135 +0.250782 0.248019 8 8.0 0.0000 0.0620 146 +0.278089 0.305396 16 16.0 1.0000 0.1084 24 +0.292183 0.306278 32 32.0 0.0000 0.1627 32 +0.263581 0.234979 64 64.0 0.0000 0.0626 61 +0.233750 0.203920 128 128.0 1.0000 0.7347 106 finished run number of examples = 200 diff --git a/vowpalwabbit/core/src/reductions/ftrl.cc b/vowpalwabbit/core/src/reductions/ftrl.cc index be1116fac10..129f69012ab 100644 --- a/vowpalwabbit/core/src/reductions/ftrl.cc +++ b/vowpalwabbit/core/src/reductions/ftrl.cc @@ -458,7 +458,7 @@ std::shared_ptr VW::reductions::ftrl_setup(VW::setup_base_ all.output_config.audit || all.output_config.hash_inv ? learn_coin_betting : learn_coin_betting; all.weights.stride_shift(3); // NOTE: for more parameter storage b->ftrl_size = 6; - learn_returns_prediction = true; + learn_returns_prediction = false; } b->data.ftrl_alpha = b->ftrl_alpha; @@ -498,4 +498,4 @@ std::shared_ptr VW::reductions::ftrl_setup(VW::setup_base_ .set_print_update(VW::details::print_update_simple_label) .build(); return l; -} \ No newline at end of file +}