ksvm lambda fix (#1556)
Alekh Agarwal authored and JohnLangford committed Aug 2, 2018
1 parent 1f1c98b commit 0510b1b
27 changes: 17 additions & 10 deletions vowpalwabbit/kernel_svm.cc
@@ -406,7 +406,11 @@ void predict(svm_params& params, svm_example** ec_arr, float* scores, size_t n)
   for(size_t i = 0; i < n; i++)
   {
     ec_arr[i]->compute_kernels(params);
-    scores[i] = dense_dot(ec_arr[i]->krow.begin(), model->alpha, model->num_support)/params.lambda;
+    //cout<<"size of krow = "<<ec_arr[i]->krow.size()<<endl;
+    if(ec_arr[i]->krow.size() > 0)
+      scores[i] = dense_dot(ec_arr[i]->krow.begin(), model->alpha, model->num_support)/params.lambda;
+    else
+      scores[i] = 0;
   }
 }
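
The guard added in this hunk matters because, with an empty model (no support vectors yet), the kernel row krow is empty and the old line read through its begin() anyway. A minimal standalone sketch of the same pattern, assuming a simplified dense_dot; predict_one is an illustrative stand-in, not the actual VW routine:

#include <cstddef>
#include <vector>

// Simplified stand-in for VW's dense_dot: plain inner product over n entries.
static float dense_dot(const float* v1, const std::vector<float>& v2, size_t n)
{
  float dot = 0.f;
  for (size_t i = 0; i < n; i++)
    dot += v1[i] * v2[i];
  return dot;
}

// With no support vectors the kernel row is empty, so there is nothing valid
// to read through; the guard returns a neutral score of 0 instead, exactly as
// the patched loop does.
static float predict_one(const std::vector<float>& krow,
                         const std::vector<float>& alpha,
                         size_t num_support, float lambda)
{
  if (!krow.empty())
    return dense_dot(krow.data(), alpha, num_support) / lambda;
  return 0.f;
}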

@@ -517,7 +521,7 @@ bool update(svm_params& params, size_t pos)

   float proj = alphaKi*ld.label;
   float ai = (params.lambda - proj)/inprods[pos];
-  //params.all->opts_n_args.trace_message<<model->num_support<<" "<<pos<<" "<<proj<<" "<<alphaKi<<" "<<alpha_old<<" "<<ld.label<<" "<<model->delta[pos]<<" " << endl;
+  //cout<<model->num_support<<" "<<pos<<" "<<proj<<" "<<alphaKi<<" "<<alpha_old<<" "<<ld.label<<" "<<model->delta[pos]<<" " << ai<<" "<<params.lambda<<endl;

   if(ai > fec->ex.l.simple.weight)
     ai = fec->ex.l.simple.weight;
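
For orientation: the two lines above the swapped comment compute what reads as a box-clipped coordinate step on the SVM dual, with ai the unconstrained optimum (params.lambda - proj)/inprods[pos] subsequently capped at the example's weight. A hedged sketch of just that step; clipped_step is an illustrative helper, and reading proj as y_i*(K*alpha)_i is an assumption from the variable names, not something the diff states:

#include <algorithm>

// One coordinate step, as inferred from the surrounding code: the optimum of
// alpha_i along coordinate i is (lambda - proj) / kii, then capped above by
// the example's weight (the upper box clip visible in the diff).
static float clipped_step(float proj, float kii, float lambda, float weight)
{
  float ai = (lambda - proj) / kii;
  return std::min(ai, weight);
}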
@@ -731,13 +735,13 @@ void train(svm_params& params)
   if(model_pos >= 0)
   {
     bool overshoot = update(params, model_pos);
-    //params.all->opts_n_args.trace_message<<model_pos<<":alpha = "<<model->alpha[model_pos]<<endl;
+    //cout<<model_pos<<":alpha = "<<model->alpha[model_pos]<<endl;

     double* subopt = calloc_or_throw<double>(model->num_support);
     for(size_t j = 0; j < params.reprocess; j++)
     {
       if(model->num_support == 0) break;
-      //params.all->opts_n_args.trace_message<<"reprocess: ";
+      //cout<<"reprocess: ";
       int randi = 1;
       if (merand48(params.all->random_state) < 0.5)
         randi = 0;
...
       {
         if(!overshoot && max_pos == (size_t)model_pos && max_pos > 0 && j == 0)
           params.all->opts_n_args.trace_message<<"Shouldn't reprocess right after process!!!"<<endl;
-        //params.all->opts_n_args.trace_message<<max_pos<<" "<<subopt[max_pos]<<endl;
-        // params.all->opts_n_args.trace_message<<params.model->support_vec[0]->example_counter<<endl;
+        //cout<<max_pos<<" "<<subopt[max_pos]<<endl;
+        //cout<<params.model->support_vec[0]->example_counter<<endl;
         if(max_pos*model->num_support <= params.maxcache)
           make_hot_sv(params, max_pos);
         update(params, max_pos);
...
         update(params, rand_pos);
       }
     }
-    //params.all->opts_n_args.trace_message<<endl;
-    // params.all->opts_n_args.trace_message<<params.model->support_vec[0]->example_counter<<endl;
+    //cout<<endl;
+    //cout<<params.model->support_vec[0]->example_counter<<endl;
     free(subopt);
   }
 }
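
For context, each reprocess pass above flips a fair coin between two selection policies: take the support vector with the largest suboptimality (subopt), or take one uniformly at random. A sketch of that selection rule, assuming standard <random> in place of VW's merand48 and a non-empty subopt; pick_reprocess_pos is an illustrative name:

#include <cstddef>
#include <random>
#include <vector>

// With probability 1/2 return the index of the largest suboptimality score,
// otherwise a uniformly random index. Assumes subopt is non-empty.
static size_t pick_reprocess_pos(const std::vector<double>& subopt, std::mt19937& rng)
{
  std::uniform_real_distribution<double> coin(0.0, 1.0);
  if (coin(rng) < 0.5)
  {
    size_t max_pos = 0;
    for (size_t i = 1; i < subopt.size(); i++)
      if (subopt[i] > subopt[max_pos])
        max_pos = i;
    return max_pos;
  }
  std::uniform_int_distribution<size_t> pick(0, subopt.size() - 1);
  return pick(rng);
}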
@@ -792,6 +796,7 @@ void learn(svm_params& params, single_learner&, example& ec)
   float score = 0;
   predict(params, &sec, &score, 1);
   ec.pred.scalar = score;
+  //cout<<"Score = "<<score<<endl;
   ec.loss = max(0.f, 1.f - score*ec.l.simple.label);
   params.loss_sum += ec.loss;
   if(params.all->training && ec.example_counter % 100 == 0)
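
The loss computation in this hunk (unchanged by the commit) is the standard hinge loss for labels in {-1, +1}: a score of 0.3 on a +1 example costs 0.7, while any score at or beyond the margin (score*label >= 1) costs 0. A one-function restatement; hinge_loss is an illustrative name, not a VW symbol:

#include <algorithm>

// Hinge loss for a margin-based binary label in {-1, +1}.
static float hinge_loss(float score, float label)
{
  return std::max(0.f, 1.f - score * label);
}

// hinge_loss(0.3f,  1.f) == 0.7f   (inside the margin)
// hinge_loss(1.5f,  1.f) == 0.f    (correct with margin to spare)
// hinge_loss(-0.2f, 1.f) == 1.2f   (misclassified)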
@@ -864,8 +869,8 @@ LEARNER::base_learner* kernel_svm_setup(arguments& arg)
       ("subsample", params->subsample, (size_t)1, "number of items to subsample from the pool")
       .keep("kernel", kernel_type, (string)"linear", "type of kernel (rbf or linear (default))")
       .keep("bandwidth", bandwidth, 1.f, "bandwidth of rbf kernel")
-      .keep("degree", degree, 2, "degree of poly kernel")
-      .keep("lambda", params->lambda, "saving regularization for test time").missing())
+      .keep("degree", degree, 2, "degree of poly kernel").missing())
+      //.keep("lambda", params->lambda, "saving regularization for test time").missing())
     return nullptr;

   string loss_function = "hinge";
@@ -891,6 +896,8 @@ LEARNER::base_learner* kernel_svm_setup(arguments& arg)
     params->subsample = (size_t)ceil(params->pool_size / arg.all->all_reduce->total);

   params->lambda = arg.all->l2_lambda;
+  if(params->lambda == 0.)
+    params->lambda = 1.;
   params->all->opts_n_args.trace_message<<"Lambda = "<<params->lambda<<endl;
   params->all->opts_n_args.trace_message<<"Kernel = "<<kernel_type<<endl;

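This last hunk is the substance of the lambda fix: with the --lambda option commented out earlier in the diff, params->lambda is now always inherited from VW's global l2_lambda, which defaults to 0 when --l2 is not passed, and predict() divides every score by it. A one-function restatement of the fallback; effective_lambda is an illustrative name:

// predict() divides every score by lambda, so a zero inherited from the
// global --l2 default must be replaced; the commit pins it to 1.
static float effective_lambda(float l2_lambda)
{
  return (l2_lambda == 0.f) ? 1.f : l2_lambda;
}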
