Skip to content

Commit

Permalink
fix to make invert_hash work with bfgs (#1892)
Browse files Browse the repository at this point in the history
  • Loading branch information
agroh1 authored and jackgerrits committed May 28, 2019
1 parent d533941 commit d00d8fb
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions vowpalwabbit/bfgs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -943,12 +943,16 @@ void end_pass(bfgs& b)
}

// placeholder
template <bool audit>
void predict(bfgs& b, base_learner&, example& ec)
{
vw* all = b.all;
ec.pred.scalar = bfgs_predict(*all, ec);
if (audit)
GD::print_audit_features(*(b.all), ec);
}

template <bool audit>
void learn(bfgs& b, base_learner& base, example& ec)
{
vw* all = b.all;
Expand All @@ -957,7 +961,7 @@ void learn(bfgs& b, base_learner& base, example& ec)
if (b.current_pass <= b.final_pass)
{
if (test_example(ec))
predict(b, base, ec);
predict<audit>(b, base, ec);
else
process_example(*all, b, ec);
}
Expand Down Expand Up @@ -1147,11 +1151,22 @@ base_learner* bfgs_setup(options_i& options, vw& all)
all.bfgs = true;
all.weights.stride_shift(2);

learner<bfgs, example>& l = init_learner(b, learn, predict, all.weights.stride());
l.set_save_load(save_load);
l.set_init_driver(init_driver);
l.set_end_pass(end_pass);
l.set_finish(finish);
void (*learn_ptr)(bfgs&, base_learner&, example&) = nullptr;
if (all.audit)
learn_ptr = learn<true>;
else
learn_ptr = learn<false>;

learner<bfgs, example>* l;
if (all.audit || all.hash_inv)
l = &init_learner(b, learn_ptr, predict<true>, all.weights.stride());
else
l = &init_learner(b, learn_ptr, predict<false>, all.weights.stride());

l->set_save_load(save_load);
l->set_init_driver(init_driver);
l->set_end_pass(end_pass);
l->set_finish(finish);

return make_base(l);
return make_base(*l);
}

0 comments on commit d00d8fb

Please sign in to comment.