Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
cheng-tan committed Mar 1, 2023
1 parent 3d5b55e commit 9f0d953
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 14 deletions.
2 changes: 2 additions & 0 deletions test/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,8 @@ def convert_tests_for_flatbuffers(
"337",
"338",
"351",
"367",
"368",
"399",
"400",
"404",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ inline void cb_explore_adf_base<ExploreType>::predict(

if (label_example != nullptr)
{
// predict path, replace the label example with an empty one
data._action_label = std::move(label_example->l.cb);
label_example->l.cb = std::move(data._empty_label);
}
Expand All @@ -137,6 +138,7 @@ inline void cb_explore_adf_base<ExploreType>::predict(

if (label_example != nullptr)
{
// predict path, restore label
label_example->l.cb = std::move(data._action_label);
data._empty_label.costs.clear();
data._empty_label.weight = 1.f;
Expand Down Expand Up @@ -214,7 +216,7 @@ void cb_explore_adf_base<ExploreType>::_update_stats(

for (const auto& example : ec_seq)
{
if (VW::ec_is_example_header_cb(*example) || VW::ec_is_example_header_cb_with_observations(*example))
if (VW::ec_is_example_header_cb(*example))
{
num_features += (ec_seq.size() - 1) *
(example->get_num_features() - example->feature_space[VW::details::CONSTANT_NAMESPACE].size());
Expand Down
18 changes: 5 additions & 13 deletions vowpalwabbit/core/src/reductions/cb/cb_adf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ VW::cb_class VW::get_observed_cost_or_default_cb_adf(const VW::multi_ex& example

for (const auto* example_ptr : examples)
{
std::vector<VW::cb_class> costs;
costs = example_ptr->l.cb.costs;

for (const auto& cost : costs)
for (const auto& cost : example_ptr->l.cb.costs)
{
if (cost.has_observed_cost())
{
Expand Down Expand Up @@ -66,26 +63,21 @@ VW::example* VW::test_cb_adf_sequence(const VW::multi_ex& ec_seq)
VW::example* ret = nullptr;
for (auto* ec : ec_seq)
{
std::vector<cb_class> costs;
costs = ec->l.cb.costs;

// Check if there is more than one cost for this example.
if (costs.size() > 1)
if (ec->l.cb.costs.size() > 1)
{
auto message = fmt::format(
"cb_adf: badly formatted example, only one cost can be known but found {}. Example number={}, tag={}",
costs.size(), ec->example_counter, VW::string_view{ec->tag.data(), ec->tag.size()});
ec->l.cb.costs.size(), ec->example_counter, VW::string_view{ec->tag.data(), ec->tag.size()});
THROW(message);
}

// Check whether the cost was initialized to a value.
if (costs.size() == 1 && costs[0].cost != FLT_MAX)
if (ec->l.cb.costs.size() == 1 && ec->l.cb.costs[0].cost != FLT_MAX)
{
ret = ec;
count += 1;
if (count > 1) {
THROW("cb_adf: badly formatted example, only one line can have a cost");
}
if (count > 1) THROW("cb_adf: badly formatted example, only one line can have a cost");
}
}

Expand Down

0 comments on commit 9f0d953

Please sign in to comment.