Skip to content

Commit

Permalink
Convert TopK reduction to be multiline example based (VowpalWabbit#1752)
Browse files Browse the repository at this point in the history
* Make topk a multiline learner

* Fix test for new format and rename B to K

* revert destructor usage

* remove all from data, move to output seq in finish

* Revert "remove all from data, move to output seq in finish"

This reverts commit 2c8f965.
  • Loading branch information
jackgerrits authored and Borislav Nikolov committed Mar 7, 2019
1 parent 1ef3cac commit 5edac80
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 37 deletions.
23 changes: 10 additions & 13 deletions test/train-sets/ref/topk-rec.stderr
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
creating quadratic features for pairs: MF
creating quadratic features for pairs: MF
predictions = topk-rec.predict
Num weight bits = 18
learning rate = 0.5
Expand All @@ -12,20 +12,17 @@ loss last counter weight label predict features
0.000003 0.000003 1 1.0 3.0000 2.9983 4
0.000004 0.000005 2 2.0 0.0000 0.0022 4
0.000003 0.000000 3 3.0 2.0000 2.0000 4
0.000003 n.a. 4 4.0 unknown 0.0000 1
0.000003 0.000003 5 5.0 0.0000 0.0018 4
0.000004 0.000011 6 6.0 3.0000 2.9968 4
0.000004 0.000001 7 7.0 1.0000 1.0007 4
0.000004 n.a. 8 8.0 unknown 0.0000 1
0.000003 0.000000 9 9.0 2.0000 2.0004 4
0.000003 0.000000 10 10.0 1.0000 1.0003 4
0.000002 0.000000 11 11.0 3.0000 2.9995 4
0.000002 n.a. 12 12.0 unknown 0.0000 1
0.000003 0.000003 4 4.0 0.0000 0.0018 4
0.000004 0.000011 5 5.0 3.0000 2.9968 4
0.000004 0.000001 6 6.0 1.0000 1.0007 4
0.000003 0.000000 7 7.0 2.0000 2.0004 4
0.000003 0.000000 8 8.0 1.0000 1.0003 4
0.000002 0.000000 9 9.0 3.0000 2.9995 4

finished run
number of examples = 12
weighted example sum = 12.000000
number of examples = 9
weighted example sum = 9.000000
weighted label sum = 15.000000
average loss = 0.000002
best constant = 1.666667
total feature number = 39
total feature number = 36
53 changes: 29 additions & 24 deletions vowpalwabbit/topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ license as described in the file LICENSE.
using namespace std;
using namespace VW::config;

// A prediction score paired with the originating example's tag bytes.
using scored_example = pair<float, v_array<char>>;

struct compare_scored_examples
{
Expand All @@ -22,8 +22,9 @@ struct compare_scored_examples

// Per-instance state for the top-K recommendation reduction.
struct topk
{
  uint32_t K;  // number of recommendations to keep (set by --top); renamed from B
  // Min-heap of (score, tag) pairs ordered by compare_scored_examples;
  // top() holds the lowest retained score, so admission is O(log K).
  priority_queue<scored_example, vector<scored_example>, compare_scored_examples> pr_queue;
  vw* all;  // back-pointer to the global vw instance, assigned in topk_setup
};

void print_result(int f, priority_queue<scored_example, vector<scored_example>, compare_scored_examples>& pr_queue)
Expand Down Expand Up @@ -56,45 +57,47 @@ void print_result(int f, priority_queue<scored_example, vector<scored_example>,
}
}

// Accumulate loss/weight statistics for one example and print the per-example
// progress line. Unlike the previous single-line version, flushing the top-K
// queue to the prediction sinks now happens once per sequence in
// finish_example(), so this function no longer needs the reduction data.
void output_example(vw& all, example& ec)
{
  label_data& ld = ec.l.simple;

  // Update shared running statistics (loss, weighted example count, features).
  all.sd->update(ec.test_only, ld.label != FLT_MAX, ec.loss, ec.weight, ec.num_features);
  if (ld.label != FLT_MAX)
    all.sd->weighted_labels += ((double)ld.label) * ec.weight;

  print_update(all, ec);
}

template <bool is_learn>
void predict_or_learn(topk& d, LEARNER::single_learner& base, example& ec)
void predict_or_learn(topk& d, LEARNER::single_learner& base, multi_ex& ec_seq)
{
if (example_is_newline(ec))
return; // do not predict newline
for (auto example : ec_seq)
{
auto ec = *example;

if (is_learn)
base.learn(ec);
else
base.predict(ec);
if (is_learn)
base.learn(ec);
else
base.predict(ec);

if (d.pr_queue.size() < d.B)
d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
if (d.pr_queue.size() < d.K)
d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
else if (d.pr_queue.top().first < ec.pred.scalar)
{
d.pr_queue.pop();
d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
}

else if (d.pr_queue.top().first < ec.pred.scalar)
{
d.pr_queue.pop();
d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
output_example(*d.all, ec);
}
}

void finish_example(vw& all, topk& d, example& ec)
void finish_example(vw& all, topk& d, multi_ex& ec_seq)
{
output_example(all, d, ec);
VW::finish_example(all, ec);
for (int sink : all.final_prediction_sink)
print_result(sink, d.pr_queue);

VW::clear_seq_and_finish_examples(all, ec_seq);
}

void finish(topk& d) { d.pr_queue = priority_queue<scored_example, vector<scored_example>, compare_scored_examples>(); }
Expand All @@ -104,13 +107,15 @@ LEARNER::base_learner* topk_setup(options_i& options, vw& all)
auto data = scoped_calloc_or_throw<topk>();

option_group_definition new_options("Top K");
new_options.add(make_option("top", data->B).keep().help("top k recommendation"));
new_options.add(make_option("top", data->K).keep().help("top k recommendation"));
options.add_and_parse(new_options);

if (!options.was_supplied("top"))
return nullptr;

LEARNER::learner<topk, example>& l =
data->all = &all;

LEARNER::learner<topk, multi_ex>& l =
init_learner(data, as_singleline(setup_base(options, all)), predict_or_learn<true>, predict_or_learn<false>);
l.set_finish_example(finish_example);
l.set_finish(finish);
Expand Down

0 comments on commit 5edac80

Please sign in to comment.