diff --git a/test/RunTests b/test/RunTests index 16f8da876d1..52e86ab54d0 100755 --- a/test/RunTests +++ b/test/RunTests @@ -1617,23 +1617,26 @@ echo "1 | feature:1" | {VW} -a --initial_weight 0.1 --initial_t 0.3 {VW} -d train-sets/rcv1_multiclass.dat --cbify 2 --epsilon 0.05 train-sets/ref/rcv1_multiclass.stderr -# Test 170 cbify adf, epsilon-greedy +# Test 170: cbify adf, epsilon-greedy {VW} --cbify 10 --cb_explore_adf --epsilon 0.05 -d train-sets/multiclass train-sets/ref/cbify_epsilon_adf.stderr -# Test 171 cbify cs, epsilon-greedy +# Test 171: cbify cs, epsilon-greedy {VW} --cbify 3 --cbify_cs --epsilon 0.05 -d train-sets/cs_cb train-sets/ref/cbify_epsilon_cs.stderr -# Test 172 cbify adf cs, epsilon-greedy +# Test 172: cbify adf cs, epsilon-greedy {VW} --cbify 3 --cbify_cs --cb_explore_adf --epsilon 0.05 -d train-sets/cs_cb train-sets/ref/cbify_epsilon_cs_adf.stderr -# Test 173 cbify adf, regcb +# Test 173: cbify adf, regcb {VW} --cbify 10 --cb_explore_adf --cb_type mtr --regcb --mellowness 0.01 -d train-sets/multiclass train-sets/ref/cbify_regcb.stderr -# Test 174 cbify adf, regcbopt +# Test 174: cbify adf, regcbopt {VW} --cbify 10 --cb_explore_adf --cb_type mtr --regcbopt --mellowness 0.01 -d train-sets/multiclass train-sets/ref/cbify_regcbopt.stderr +# Test 175: cbify ldf, regcbopt +{VW} -d train-sets/cs_test.ldf --cbify_ldf --cb_type mtr --regcbopt --mellowness 0.01 + train-sets/ref/cbify_ldf_regcbopt.stderr diff --git a/test/train-sets/ref/cbify_ldf_regcbopt.stderr b/test/train-sets/ref/cbify_ldf_regcbopt.stderr new file mode 100644 index 00000000000..1391984c333 --- /dev/null +++ b/test/train-sets/ref/cbify_ldf_regcbopt.stderr @@ -0,0 +1,18 @@ +Num weight bits = 18 +learning rate = 0.5 +initial_t = 0 +power_t = 0.5 +using no cache +Reading datafile = train-sets/cs_test.ldf +num sources = 1 +average since example example current current current +loss last counter weight label predict features +0.000000 0.000000 1 1.0 unknown 0 12 +0.000000 0.000000 2 2.0 unknown 0 8 + +finished run +number of examples = 3 +weighted example sum = 3.000000 +weighted label sum = 0.000000 +average loss = 0.333333 +total feature number = 28 diff --git a/vowpalwabbit/cbify.cc b/vowpalwabbit/cbify.cc index 3b88d0fc46d..7159759b45f 100644 --- a/vowpalwabbit/cbify.cc +++ b/vowpalwabbit/cbify.cc @@ -12,6 +12,7 @@ using namespace LEARNER; using namespace exploration; using namespace ACTION_SCORE; +// using namespace COST_SENSITIVE; using namespace std; using namespace VW::config; @@ -35,6 +36,11 @@ struct cbify cbify_adf_data adf_data; float loss0; float loss1; + + // for ldf inputs + std::vector> cs_costs; + std::vector> cb_costs; + std::vector cb_as; }; float loss(cbify& data, uint32_t label, uint32_t final_prediction) @@ -59,6 +65,18 @@ float loss_cs(cbify& data, v_array& costs, uint32_t fina return data.loss0 + (data.loss1 - data.loss0) * cost; } +float loss_csldf(cbify& data, std::vector>& cs_costs, uint32_t final_prediction) +{ + float cost = 0.; + for (auto costs : cs_costs) + { if (costs[0].class_index == final_prediction) + { cost = costs[0].x; + break; + } + } + return data.loss0 + (data.loss1 - data.loss0) * cost; +} + template inline void delete_it(T* p) { @@ -225,6 +243,139 @@ void init_adf_data(cbify& data, const size_t num_actions) } } +template +void do_actual_learning_ldf(cbify& data, multi_learner& base, multi_ex& ec_seq) +{ + auto& cs_costs = data.cs_costs; + auto& cb_costs = data.cb_costs; + auto& cb_as = data.cb_as; + + // change label and pred data for cb + cs_costs.resize(ec_seq.size()); + cb_costs.resize(ec_seq.size()); + cb_as.resize(ec_seq.size()); + for (size_t i = 0; i < ec_seq.size(); ++i) + { + auto& ec = *ec_seq[i]; + cs_costs[i] = ec.l.cs.costs; + cb_costs[i].clear(); + cb_as[i].clear(); + ec.l.cb.costs = cb_costs[i]; + ec.pred.a_s = cb_as[i]; + } + + base.predict(ec_seq); + + auto& out_ec = *ec_seq[0]; + + uint32_t chosen_action; + if (sample_after_normalizing(data.app_seed + data.example_counter++, begin_scores(out_ec.pred.a_s), end_scores(out_ec.pred.a_s), chosen_action)) + THROW("Failed to sample from pdf"); + + CB::cb_class cl; + cl.action = out_ec.pred.a_s[chosen_action].action + 1; + cl.probability = out_ec.pred.a_s[chosen_action].score; + + if(!cl.action) + THROW("No action with non-zero probability found!"); + + cl.cost = loss_csldf(data, cs_costs, cl.action); + + // add cb label to chosen action + auto& lab = ec_seq[cl.action - 1]->l.cb; + lab.costs.push_back(cl); + + base.learn(ec_seq); + + // set cs prediction and reset cs costs + for (size_t i = 0; i < ec_seq.size(); ++i) + { + if (i == cl.action - 1) + { + ec_seq[i]->pred.multiclass = cl.action; + ec_seq[i]->l.cs.costs = cs_costs[cl.action - 1]; // only need this cost for eval + } + else + ec_seq[i]->pred.multiclass = 0; + } +} + +void output_example(vw& all, example& ec, bool& hit_loss, multi_ex* ec_seq) +{ + COST_SENSITIVE::label& ld = ec.l.cs; + v_array costs = ld.costs; + + if (example_is_newline(ec)) return; + if (COST_SENSITIVE::ec_is_example_header(ec)) return; + + all.sd->total_features += ec.num_features; + + float loss = 0.; + + uint32_t predicted_class = ec.pred.multiclass; + + if (!COST_SENSITIVE::cs_label.test_label(&ec.l)) + { + for (size_t j=0; jsum_loss += loss; + all.sd->sum_loss_since_last_dump += loss; + } + + for (int sink : all.final_prediction_sink) + all.print(sink, (float)ec.pred.multiclass, 0, ec.tag); + + if (all.raw_prediction > 0) + { + string outputString; + stringstream outputStringStream(outputString); + for (size_t i = 0; i < costs.size(); i++) + { + if (i > 0) outputStringStream << ' '; + outputStringStream << costs[i].class_index << ':' << costs[i].partial_prediction; + } + //outputStringStream << endl; + all.print_text(all.raw_prediction, outputStringStream.str(), ec.tag); + } + + COST_SENSITIVE::print_update(all, COST_SENSITIVE::cs_label.test_label(&ec.l), ec, ec_seq, false, predicted_class); +} + +void output_example_seq(vw& all, multi_ex& ec_seq) +{ + if (ec_seq.size() == 0) return; + all.sd->weighted_labeled_examples += ec_seq[0]->weight; + all.sd->example_number++; + + bool hit_loss = false; + for (example* ec : ec_seq) + output_example(all, *ec, hit_loss, &(ec_seq)); + + if (all.raw_prediction > 0) + { + v_array empty = { nullptr, nullptr, nullptr, 0 }; + all.print_text(all.raw_prediction, "", empty); + } +} + +void finish_multiline_example(vw& all, cbify& data, multi_ex& ec_seq) +{ + if (ec_seq.size() > 0) + { + output_example_seq(all, ec_seq); + // global_print_newline(all); + } + VW::clear_seq_and_finish_examples(all, ec_seq); +} + base_learner* cbify_setup(options_i& options, vw& all) { uint32_t num_actions = 0; @@ -298,3 +449,46 @@ base_learner* cbify_setup(options_i& options, vw& all) return make_base(*l); } + +base_learner* cbifyldf_setup(options_i& options, vw& all) +{ + auto data = scoped_calloc_or_throw(); + bool cbify_ldf_option = false; + + option_group_definition new_options("Make csoaa_ldf into Contextual Bandit"); + new_options + .add(make_option("cbify_ldf", cbify_ldf_option).keep().help("Convert csoaa_ldf into a contextual bandit problem")) + .add(make_option("loss0", data->loss0).default_value(0.f).help("loss for correct label")) + .add(make_option("loss1", data->loss1).default_value(1.f).help("loss for incorrect label")); + options.add_and_parse(new_options); + + if (!options.was_supplied("cbify_ldf")) + return nullptr; + + data->app_seed = uniform_hash("vw", 2, 0); + data->all = &all; + + if (!options.was_supplied("cb_explore_adf")) + { + options.insert("cb_explore_adf", ""); + } + options.insert("cb_min_cost", to_string(data->loss0)); + options.insert("cb_max_cost", to_string(data->loss1)); + + if (options.was_supplied("baseline")) + { + stringstream ss; + ss << max(abs(data->loss0), abs(data->loss1)) / (data->loss1 - data->loss0); + options.insert("lr_multiplier", ss.str()); + } + + multi_learner* base = as_multiline(setup_base(options, all)); + learner& l = init_learner(data, base, do_actual_learning_ldf, do_actual_learning_ldf, 1, prediction_type::multiclass); + + l.set_finish(finish); + l.set_finish_example(finish_multiline_example); + all.p->lp = COST_SENSITIVE::cs_label; + all.delete_prediction = nullptr; + + return make_base(l); +} diff --git a/vowpalwabbit/cbify.h b/vowpalwabbit/cbify.h index 9b29228bf4f..2ff12b9c2a1 100644 --- a/vowpalwabbit/cbify.h +++ b/vowpalwabbit/cbify.h @@ -3,4 +3,6 @@ Copyright (c) by respective owners including Yahoo!, Microsoft, and individual contributors. All rights reserved. Released under a BSD license as described in the file LICENSE. */ + LEARNER::base_learner* cbify_setup(VW::config::options_i& options, vw& all); +LEARNER::base_learner* cbifyldf_setup(VW::config::options_i& options, vw& all); diff --git a/vowpalwabbit/parse_args.cc b/vowpalwabbit/parse_args.cc index a1cf80f2f2e..2006413bc84 100644 --- a/vowpalwabbit/parse_args.cc +++ b/vowpalwabbit/parse_args.cc @@ -1269,6 +1269,7 @@ void parse_reductions(options_i& options, vw& all) all.reduction_stack.push_back(cb_explore_setup); all.reduction_stack.push_back(cb_explore_adf_setup); all.reduction_stack.push_back(cbify_setup); + all.reduction_stack.push_back(cbifyldf_setup); all.reduction_stack.push_back(explore_eval_setup); all.reduction_stack.push_back(ExpReplay::expreplay_setup<'c', COST_SENSITIVE::cs_label>); all.reduction_stack.push_back(Search::setup);