Skip to content

Commit

Permalink
Fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
cheng-tan committed Apr 4, 2023
1 parent 097fecc commit 9dd38b3
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 55 deletions.
1 change: 1 addition & 0 deletions vowpalwabbit/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ vw_add_test_executable(
tests/cb_large_actions_test.cc
tests/cb_las_one_pass_svd_test.cc
tests/cb_las_spanner_test.cc
tests/cb_with_observations_parser_test.cc
tests/ccb_parser_test.cc
tests/ccb_test.cc
tests/chain_hashing.cc
Expand Down
14 changes: 14 additions & 0 deletions vowpalwabbit/core/include/vw/core/cb_with_observations_label.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#pragma once

#include "vw/core/cb.h"

namespace VW
Expand All @@ -8,9 +10,21 @@ class cb_with_observations_label
cb_label event;
bool is_observation = false;
bool is_definitely_bad = false;

VW_ATTR(nodiscard) bool is_test_label() const;
void reset_to_default();
};

// Returns true when the cb_with_observations label of `ec` holds exactly one cost entry
// and that entry's probability equals the sentinel -1.f; returns false in every other case.
bool ec_is_example_header_cb_with_observations(VW::example const& ec);

// Label parser instance for the CB_WITH_OBSERVATIONS label type
// (defined in cb_with_observations_label.cc).
extern VW::label_parser cb_with_observations_global;
} // namespace VW

namespace VW
{
namespace model_utils
{
// Reads a cb_with_observations_label from the buffer, field by field, in the order
// event, is_observation, is_definitely_bad. Returns the sum of the byte counts
// reported by the individual field reads.
size_t read_model_field(io_buf&, cb_with_observations_label&);
// Writes a cb_with_observations_label to the buffer in the same field order used by
// read_model_field. The string argument is the upstream name; each field is written
// under that name plus "_event", "_is_observation" or "_is_definitely_bad". The bool
// argument is forwarded to the per-field writers as their `text` flag. Returns the
// sum of the byte counts reported by the individual field writes.
size_t write_model_field(io_buf&, const cb_with_observations_label&, const std::string&, bool);
} // namespace model_utils
} // namespace VW
6 changes: 4 additions & 2 deletions vowpalwabbit/core/include/vw/core/reductions/ftrl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
#include "vw/core/reductions/gd.h"
#include "vw/core/vw_fwd.h"

#include <stdint.h>

#include <cstddef>
#include <cstdint>

namespace
{
class ftrl_update_data
{
public:
Expand All @@ -35,6 +36,7 @@ class ftrl
uint32_t ftrl_size = 0;
std::vector<VW::reductions::details::gd_per_model_state> gd_per_model_states;
};
} // namespace

#include <memory>

Expand Down
73 changes: 49 additions & 24 deletions vowpalwabbit/core/src/cb_with_observations_label.cc
Original file line number Diff line number Diff line change
@@ -1,54 +1,79 @@
#include "vw/core/example.h"
#include "vw/core/model_utils.h"

bool VW::ec_is_example_header_cb_with_observations(VW::example const& ec)
namespace
{
const auto& costs = ec.l.cb_with_observations.event.costs;
if (costs.size() != 1) { return false; }
if (costs[0].probability == -1.f) { return true; }
float cb_with_observations_weight(const VW::cb_with_observations_label& ld) { return ld.event.weight; }

return false;
void parse_label_cb_with_observations(VW::cb_with_observations_label& /*ld*/, VW::reduction_features& /*red_features*/,
VW::label_parser_reuse_mem& /*reuse_mem*/, const std::vector<VW::string_view>& /*words*/,
VW::io::logger& /*logger*/)
{
// TODO: implement text format parsing for cb with observations
}
} // namespace

namespace
namespace VW
{
namespace model_utils
{
size_t read_model_field(io_buf& io, VW::cb_with_observations_label& cb_with_obs)
{
float weight_cb_with_observations(const VW::cb_with_observations_label& ld) { return ld.event.weight; }
size_t bytes = 0;
bytes += read_model_field(io, cb_with_obs.event);
bytes += read_model_field(io, cb_with_obs.is_observation);
bytes += read_model_field(io, cb_with_obs.is_definitely_bad);
return bytes;
}

void default_label_cb_with_observations(VW::cb_with_observations_label& ld)
size_t write_model_field(io_buf& io, const VW::cb_with_observations_label& cb_with_obs,
const std::string& upstream_name, bool text)
{
ld.event.reset_to_default();
ld.is_observation = false;
ld.is_definitely_bad = false;
size_t bytes = 0;
bytes += VW::model_utils::write_model_field(io, cb_with_obs.event, upstream_name + "_event", text);
bytes += write_model_field(io, cb_with_obs.is_observation, upstream_name + "_is_observation", text);
bytes += write_model_field(io, cb_with_obs.is_definitely_bad, upstream_name + "_is_definitely_bad", text);
return bytes;
}
} // namespace model_utils

bool test_label_cb_with_observations(const VW::cb_with_observations_label& ld) { return ld.event.is_test_label(); }
// True only when the label of `ec` carries exactly one cost entry and that entry's
// probability is the sentinel -1.f.
bool ec_is_example_header_cb_with_observations(VW::example const& ec)
{
  const auto& cost_entries = ec.l.cb_with_observations.event.costs;
  return cost_entries.size() == 1 && cost_entries[0].probability == -1.f;
}

// Puts the label back into its default state: both flags cleared and the wrapped
// CB event reset through its own reset_to_default().
void cb_with_observations_label::reset_to_default()
{
  is_definitely_bad = false;
  is_observation = false;
  event.reset_to_default();
}
} // namespace

VW::label_parser VW::cb_with_observations_global = {
// Delegates to the wrapped CB event: the label is a test label exactly when the event is.
bool cb_with_observations_label::is_test_label() const
{
  return event.is_test_label();
}

// Label parser table for cb_with_observations labels. Every entry operates on the
// `cb_with_observations` member of the polylabel.
VW::label_parser cb_with_observations_global = {
    // default_label: reset the label in place
    [](VW::polylabel& pl) { pl.cb_with_observations.reset_to_default(); },
    // parse_label: forwards to the file-local text parser; the named-labels dictionary is unused
    [](VW::polylabel& pl, VW::reduction_features& features, VW::label_parser_reuse_mem& scratch,
        const VW::named_labels* /*ldict*/, const std::vector<VW::string_view>& tokens, VW::io::logger& logger)
    { parse_label_cb_with_observations(pl.cb_with_observations, features, scratch, tokens, logger); },
    // cache_label: serialize the full cb_with_observations label
    [](const VW::polylabel& pl, const VW::reduction_features& /*red_features*/, io_buf& cache,
        const std::string& upstream_name, bool text)
    { return VW::model_utils::write_model_field(cache, pl.cb_with_observations, upstream_name, text); },
    // read_cached_label: deserialize the full cb_with_observations label
    [](VW::polylabel& pl, VW::reduction_features& /*red_features*/, io_buf& cache)
    { return VW::model_utils::read_model_field(cache, pl.cb_with_observations); },
    // get_weight
    [](const VW::polylabel& pl, const VW::reduction_features& /*red_features*/)
    { return cb_with_observations_weight(pl.cb_with_observations); },
    // test_label
    [](const VW::polylabel& pl) { return pl.cb_with_observations.is_test_label(); },
    // Label type
    VW::label_type_t::CB_WITH_OBSERVATIONS};
} // namespace VW
2 changes: 0 additions & 2 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_greedy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,11 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::cb_explore_adf_greedy_setu
{
input_label_type = VW::label_type_t::CB_WITH_OBSERVATIONS;
output_label_type = VW::label_type_t::CB_WITH_OBSERVATIONS;
all.parser_runtime.example_parser->lbl_parser = VW::cb_with_observations_global;
}
else
{
input_label_type = VW::label_type_t::CB;
output_label_type = VW::label_type_t::CB;
all.parser_runtime.example_parser->lbl_parser = VW::cb_label_parser_global;
}

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }
Expand Down
43 changes: 21 additions & 22 deletions vowpalwabbit/core/src/reductions/interaction_ground.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ class ik_stack_builder : public VW::default_reduction_stack_setup
};

std::vector<std::vector<VW::namespace_index>> get_ik_interactions(
std::vector<std::vector<VW::namespace_index>> interactions, VW::example* observation_ex)
const std::vector<std::vector<VW::namespace_index>>& interactions, const VW::example& observation_ex)
{
std::vector<std::vector<VW::namespace_index>> new_interactions;
for (auto& interaction : interactions)
for (const auto& interaction : interactions)
{
for (auto& obs_ns : observation_ex->indices)
for (auto obs_ns : observation_ex.indices)
{
if (obs_ns == VW::details::DEFAULT_NAMESPACE) { obs_ns = VW::details::IGL_FEEDBACK_NAMESPACE; }

Expand All @@ -72,21 +72,21 @@ std::vector<std::vector<VW::namespace_index>> get_ik_interactions(
return new_interactions;
}

void add_obs_features_to_ik_ex(VW::example* ik_ex, VW::example* obs_ex)
void add_obs_features_to_ik_ex(VW::example& ik_ex, const VW::example& obs_ex)
{
for (auto& obs_ns : obs_ex->indices)
for (auto obs_ns : obs_ex.indices)
{
ik_ex->indices.push_back(obs_ns);
ik_ex.indices.push_back(obs_ns);

for (size_t i = 0; i < obs_ex->feature_space[obs_ns].indices.size(); i++)
for (size_t i = 0; i < obs_ex.feature_space[obs_ns].indices.size(); i++)
{
auto feature_hash = obs_ex->feature_space[obs_ns].indices[i];
auto feature_val = obs_ex->feature_space[obs_ns].values[i];
auto feature_hash = obs_ex.feature_space[obs_ns].indices[i];
auto feature_val = obs_ex.feature_space[obs_ns].values[i];

if (obs_ns == VW::details::DEFAULT_NAMESPACE) { obs_ns = VW::details::IGL_FEEDBACK_NAMESPACE; }

ik_ex->feature_space[obs_ns].indices.push_back(feature_hash);
ik_ex->feature_space[obs_ns].values.push_back(feature_val);
ik_ex.feature_space[obs_ns].indices.push_back(feature_hash);
ik_ex.feature_space[obs_ns].values.push_back(feature_val);
}
}
}
Expand Down Expand Up @@ -118,7 +118,7 @@ void learn(interaction_ground& igl, learner& base, VW::multi_ex& ec_seq)
size_t chosen_action_idx = 0;

const auto it = std::find_if(ec_seq.begin(), ec_seq.end(),
[](VW::example* item) { return !item->l.cb_with_observations.event.costs.empty(); });
[](const VW::example* item) { return !item->l.cb_with_observations.event.costs.empty(); });

if (it != ec_seq.end()) { chosen_action_idx = std::distance(ec_seq.begin(), it); }

Expand All @@ -130,8 +130,7 @@ void learn(interaction_ground& igl, learner& base, VW::multi_ex& ec_seq)
ec_seq.pop_back();
}

VW::action_scores action_scores = ec_seq[0]->pred.a_s;
std::vector<std::vector<VW::namespace_index>> ik_interactions = get_ik_interactions(igl.interactions, observation_ex);
std::vector<std::vector<VW::namespace_index>> ik_interactions = get_ik_interactions(igl.interactions, *observation_ex);

for (size_t i = 0; i < ec_seq.size(); i++)
{
Expand All @@ -140,12 +139,12 @@ void learn(interaction_ground& igl, learner& base, VW::multi_ex& ec_seq)
// TODO: Do we need constant feature here? If so, VW::add_constant_feature
VW::details::append_example_namespaces_from_example(igl.ik_ex, *action_ex);

add_obs_features_to_ik_ex(&igl.ik_ex, observation_ex);
add_obs_features_to_ik_ex(igl.ik_ex, *observation_ex);
// 1. set up ik ex
igl.ik_ex.l.simple.label = i == chosen_action_idx ? 1.f : -1.f;

auto action_score_iter = std::find_if(
action_scores.begin(), action_scores.end(), [&i](VW::action_score& element) { return element.action == i; });
auto action_score_iter = std::find_if(ec_seq[0]->pred.a_s.begin(), ec_seq[0]->pred.a_s.end(),
[i](const VW::action_score& element) { return element.action == i; });

float pa = action_score_iter->score;

Expand All @@ -172,7 +171,7 @@ void learn(interaction_ground& igl, learner& base, VW::multi_ex& ec_seq)
// 4. update multi line ex label
if (ik_pred * 2 > 1.f)
{
bool is_definitely_bad = observation_ex->l.cb_with_observations.is_definitely_bad;
int is_definitely_bad = static_cast<int>(observation_ex->l.cb_with_observations.is_definitely_bad);
predicted_cost = -1.f + is_definitely_bad * (1.f + 1.f / p_unlabeled_prior);
}

Expand Down Expand Up @@ -312,15 +311,15 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::interaction_ground_setup(V
auto ftrl_coin = pi_learner->get_learner_by_name_prefix("ftrl-Coin")->shared_from_this();

// 2. prepare args for ik stack
std::string ik_args = "--quiet --link=logistic --loss_function=logistic --coin";
std::vector<std::string> ik_args = {"--quiet", "--link=logistic", "--loss_function=logistic", "--coin"};
std::unique_ptr<options_i, options_deleter_type> ik_options(
new config::options_cli(VW::split_command_line(ik_args)), [](VW::config::options_i* ptr) { delete ptr; });
new config::options_cli(ik_args), [](VW::config::options_i* ptr) { delete ptr; });

assert(ik_options->was_supplied("cb_explore_adf") == false || ik_options->was_supplied("cb_adf") == false);
assert(ik_options->was_supplied("loss_function") == true);

ld->ik_all = VW::initialize_experimental(VW::make_unique<VW::config::options_cli>(VW::split_command_line(ik_args)));
all->parser_runtime.example_parser->lbl_parser = VW::get_label_parser(label_type_t::CB_WITH_OBSERVATIONS);
ld->ik_all = VW::initialize_experimental(
VW::make_unique<VW::config::options_cli>(ik_args), nullptr, nullptr, nullptr, &all->logger);

std::unique_ptr<ik_stack_builder> ik_builder = VW::make_unique<ik_stack_builder>(ftrl_coin);
ik_builder->delayed_state_attach(*ld->ik_all, *ik_options);
Expand Down
44 changes: 44 additions & 0 deletions vowpalwabbit/core/tests/cb_with_observations_parser_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.

#include "vw/common/string_view.h"
#include "vw/common/text_utils.h"
#include "vw/core/cb_with_observations_label.h"
#include "vw/core/memory.h"
#include "vw/core/parse_primitives.h"
#include "vw/core/parser.h"
#include "vw/io/logger.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <memory>
#include <vector>


// Round-trips a cb_with_observations label through write_model_field / read_model_field
// using an in-memory buffer, then checks that the event weight and both flags survive.
TEST(CbWithObservations, CacheLabel)
{
  // Label to serialize: non-default weight and both flags set.
  VW::cb_with_observations_label source_label;
  source_label.event.weight = 5.f;
  source_label.is_definitely_bad = true;
  source_label.is_observation = true;

  // Write the label into a byte vector (text flag = false).
  auto backing_vector = std::make_shared<std::vector<char>>();
  VW::io_buf io_writer;
  io_writer.add_file(VW::io::create_vector_writer(backing_vector));
  VW::model_utils::write_model_field(io_writer, source_label, "", false);
  io_writer.flush();

  // Read it back from a view over the same bytes into a freshly reset label.
  VW::io_buf io_reader;
  io_reader.add_file(VW::io::create_buffer_view(backing_vector->data(), backing_vector->size()));

  VW::cb_with_observations_label restored_label;
  restored_label.reset_to_default();
  VW::model_utils::read_model_field(io_reader, restored_label);

  EXPECT_FLOAT_EQ(restored_label.event.weight, 5.f);
  EXPECT_TRUE(restored_label.is_definitely_bad);
  EXPECT_TRUE(restored_label.is_observation);
}
5 changes: 0 additions & 5 deletions vowpalwabbit/json_parser/src/parse_example_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -603,11 +603,6 @@ class MultiState : public BaseState<audit>
{
VW::cb_with_observations_label* ld = &(*ctx.examples)[0]->l.cb_with_observations;
VW::cb_class f;

f.partial_prediction = 0.;
f.action = static_cast<uint32_t>(VW::uniform_hash("shared", 6, 0));
f.cost = FLT_MAX;
f.probability = -1.f;
ld->event.costs.push_back(f);
}
else if (ctx._label_parser.label_type == VW::label_type_t::CCB)
Expand Down

0 comments on commit 9dd38b3

Please sign in to comment.