Skip to content

Commit

Permalink
remove all from data, move to output seq in finish
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Feb 8, 2019
1 parent d3787fa commit 2c8f965
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions vowpalwabbit/topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ struct topk
{
uint32_t K; // rec number
priority_queue<scored_example, vector<scored_example>, compare_scored_examples> pr_queue;
vw* all;
};

void print_result(int f, priority_queue<scored_example, vector<scored_example>, compare_scored_examples>& pr_queue)
Expand Down Expand Up @@ -57,15 +56,23 @@ void print_result(int f, priority_queue<scored_example, vector<scored_example>,
}
}

void output_example(vw& all, example& ec)
void output_example(vw& all, topk& d, multi_ex& ec_seq)
{
label_data& ld = ec.l.simple;
for (auto example : ec_seq)
{
auto ec = *example;

label_data& ld = ec.l.simple;

all.sd->update(ec.test_only, ld.label != FLT_MAX, ec.loss, ec.weight, ec.num_features);
if (ld.label != FLT_MAX)
all.sd->weighted_labels += ((double)ld.label) * ec.weight;
all.sd->update(ec.test_only, ld.label != FLT_MAX, ec.loss, ec.weight, ec.num_features);
if (ld.label != FLT_MAX)
all.sd->weighted_labels += ((double)ld.label) * ec.weight;

print_update(all, ec);
print_update(all, ec);
}

for (int sink : all.final_prediction_sink)
print_result(sink, d.pr_queue);
}

template <bool is_learn>
Expand All @@ -87,16 +94,12 @@ void predict_or_learn(topk& d, LEARNER::single_learner& base, multi_ex& ec_seq)
d.pr_queue.pop();
d.pr_queue.push(make_pair(ec.pred.scalar, ec.tag));
}

output_example(*d.all, ec);
}
}

void finish_example(vw& all, topk& d, multi_ex& ec_seq)
{
for (int sink : all.final_prediction_sink)
print_result(sink, d.pr_queue);

output_example(all, d, ec_seq);
VW::clear_seq_and_finish_examples(all, ec_seq);
}

Expand All @@ -113,8 +116,6 @@ LEARNER::base_learner* topk_setup(options_i& options, vw& all)
if (!options.was_supplied("top"))
return nullptr;

data->all = &all;

LEARNER::learner<topk, multi_ex>& l =
init_learner(data, as_singleline(setup_base(options, all)), predict_or_learn<true>, predict_or_learn<false>);
l.set_finish_example(finish_example);
Expand Down

0 comments on commit 2c8f965

Please sign in to comment.