Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

do not output progressive validation loss for oaa with subsampling #1880

Merged
merged 8 commits into from
May 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions test/RunTests
Original file line number Diff line number Diff line change
Expand Up @@ -1706,3 +1706,8 @@ printf '3 |f a b c |e x y z\n2 |f a y c |e x\n' | {VW} --oaa 3 -q ef --audit
{VW} -d train-sets/b1848_dsjson_parser_regression.txt --dsjson --cb_explore_adf -P 1
train-sets/ref/b1848_dsjson_parser_regression.stderr

# Test 190: one-against-all with subsampling
{VW} -k --oaa 10 --oaa_subsample 5 -c --passes 10 -d train-sets/multiclass --holdout_off
train-sets/ref/oaa_subsample.stderr

# Do not delete this line or the empty line above it
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

25 changes: 25 additions & 0 deletions test/train-sets/ref/oaa_subsample.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
Num weight bits = 18
learning rate = 0.5
initial_t = 0
power_t = 0.5
decay_learning_rate = 1
creating cache_file = train-sets/multiclass.cache
Reading datafile = train-sets/multiclass
num sources = 1
average since example example current current current
loss last counter weight label predict features
n.a. n.a. 1 1.0 1 1 2
n.a. n.a. 2 2.0 2 1 2
n.a. n.a. 4 4.0 4 1 2
n.a. n.a. 8 8.0 8 1 2
n.a. n.a. 16 16.0 6 6 2
n.a. n.a. 32 32.0 2 2 2
n.a. n.a. 64 64.0 4 4 2

finished run
number of examples per pass = 10
passes used = 10
weighted example sum = 100.000000
weighted label sum = 0.000000
average loss = n.a.
total feature number = 200
4 changes: 2 additions & 2 deletions vowpalwabbit/multiclass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,13 @@ void print_update_with_probability(vw& all, example& ec, uint32_t pred)
}
void print_update_with_score(vw& all, example& ec, uint32_t pred) { print_update<print_score>(all, ec, pred); }

void finish_example(vw& all, example& ec)
void finish_example(vw& all, example& ec, bool update_loss)
{
float loss = 0;
if (ec.l.multi.label != (uint32_t)ec.pred.multiclass && ec.l.multi.label != (uint32_t)-1)
loss = ec.weight;

all.sd->update(ec.test_only, ec.l.multi.label != (uint32_t)-1, loss, ec.weight, ec.num_features);
all.sd->update(ec.test_only, update_loss && (ec.l.multi.label != (uint32_t)-1), loss, ec.weight, ec.num_features);

for (int sink : all.final_prediction_sink)
if (!all.sd->ldict)
Expand Down
10 changes: 8 additions & 2 deletions vowpalwabbit/multiclass.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,17 @@ extern label_parser mc_label;
void print_update_with_probability(vw& all, example& ec, uint32_t prediction);
void print_update_with_score(vw& all, example& ec, uint32_t prediction);

void finish_example(vw& all, example& ec);
void finish_example(vw& all, example& ec, bool update_loss);

template <class T>
void finish_example(vw& all, T&, example& ec)
{
finish_example(all, ec);
finish_example(all, ec, true);
}

template <class T>
void finish_example_without_loss(vw& all, T&, example& ec)
{
finish_example(all, ec, false);
}
} // namespace MULTICLASS
4 changes: 3 additions & 1 deletion vowpalwabbit/oaa.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,10 @@ LEARNER::base_learner* oaa_setup(options_i& options, vw& all)
l = &LEARNER::init_multiclass_learner(data, base, predict_or_learn<true, false, false, false>,
predict_or_learn<false, false, false, false>, all.p, data->k, prediction_type::multiclass);

if (data_ptr->num_subsample > 0)
if (data_ptr->num_subsample > 0) {
l->set_learn(learn_randomized);
l->set_finish_example(MULTICLASS::finish_example_without_loss<oaa>);
}
l->set_finish(finish);

return make_base(*l);
Expand Down