Skip to content

Commit

Permalink
fix: Baseline should signal its enablement through an index and not a…
Browse files Browse the repository at this point in the history
… feature (#3069)

* fix: Baseline should signal its enablement through an index and not a feature

* formatting

* Update baseline_test.stderr
  • Loading branch information
jackgerrits authored Jun 14, 2021
1 parent a4f513c commit 9de7a00
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 29 deletions.
2 changes: 1 addition & 1 deletion test/train-sets/ref/baseline_test.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ weighted label sum = 91.000000
average loss = 0.190245
best constant = 0.455000
best constant's loss = 0.247975
total feature number = 15482
total feature number = 15482
34 changes: 7 additions & 27 deletions vowpalwabbit/baseline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.

#include "baseline.h"

#include <cfloat>
#include <cerrno>

Expand All @@ -15,47 +17,25 @@ using namespace VW::config;
namespace
{
const float max_multiplier = 1000.f;
const size_t baseline_enabled_idx = 1357; // feature index for enabling baseline
} // namespace

namespace BASELINE
{
void set_baseline_enabled(example* ec)
{
auto& fs = ec->feature_space[message_namespace];
for (auto& f : fs)
{
if (f.index() == baseline_enabled_idx)
{
f.value() = 1;
return;
}
}
// if not found, push new feature
fs.push_back(1, baseline_enabled_idx);
if (!baseline_enabled(ec)) { ec->indices.push_back(baseline_enabled_message_namespace); }
}

void reset_baseline_disabled(example* ec)
{
auto& fs = ec->feature_space[message_namespace];
for (auto& f : fs)
{
if (f.index() == baseline_enabled_idx)
{
f.value() = 0;
return;
}
}
auto it = std::find(ec->indices.begin(), ec->indices.end(), baseline_enabled_message_namespace);
if (it != ec->indices.end()) { ec->indices.erase(it); }
}

bool baseline_enabled(example* ec)
{
auto& fs = ec->feature_space[message_namespace];
for (auto& f : fs)
{
if (f.index() == baseline_enabled_idx) return f.value() == 1;
}
return false;
auto it = std::find(ec->indices.begin(), ec->indices.end(), baseline_enabled_message_namespace);
return it != ec->indices.end();
}
} // namespace BASELINE

Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/constant.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ constexpr unsigned char spelling_namespace = 133; // this is \x85
constexpr unsigned char conditioning_namespace = 134; // this is \x86
constexpr unsigned char dictionary_namespace = 135; // this is \x87
constexpr unsigned char node_id_namespace = 136; // this is \x88
constexpr unsigned char message_namespace = 137; // this is \x89
constexpr unsigned char baseline_enabled_message_namespace = 137; // this is \x89
constexpr unsigned char ccb_slot_namespace = 139;
constexpr unsigned char ccb_id_namespace = 140;

Expand Down

0 comments on commit 9de7a00

Please sign in to comment.