Skip to content

Commit

Permalink
Merge branch 'master' into new-java-build
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-morra-zefr authored Feb 15, 2017
2 parents 819eea9 + 6d0e1e1 commit cd535bb
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 61 deletions.
15 changes: 9 additions & 6 deletions vowpalwabbit/cb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,15 @@ void print_update(vw& all, bool is_test, example& ec, v_array<example*>* ec_seq,
label_buf = " known";

if (action_scores)
{ std::ostringstream pred_buf;
pred_buf << std::setw(all.sd->col_current_predict) << std::right << std::setfill(' ')
<< ec.pred.a_s[0].action << ":" << ec.pred.a_s[0].score <<"...";
all.sd->print_update(all.holdout_set_off, all.current_pass, label_buf, pred_buf.str(),
num_features, all.progress_add, all.progress_arg);;
}
{ std::ostringstream pred_buf;
pred_buf << std::setw(all.sd->col_current_predict) << std::right << std::setfill(' ');
if (ec.pred.a_s.size() > 0)
pred_buf << ec.pred.a_s[0].action << ":" << ec.pred.a_s[0].score <<"...";
else
pred_buf << "no action";
all.sd->print_update(all.holdout_set_off, all.current_pass, label_buf, pred_buf.str(),
num_features, all.progress_add, all.progress_arg);;
}
else
all.sd->print_update(all.holdout_set_off, all.current_pass, label_buf, (uint32_t)pred,
num_features, all.progress_add, all.progress_arg);
Expand Down
5 changes: 5 additions & 0 deletions vowpalwabbit/cb_explore_adf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ void predict_or_learn_first(cb_explore_adf& data, base_learner& base, v_array<ex
uint32_t num_actions = (uint32_t)(examples.size() - 1);
if (CB::ec_is_example_header(*examples[0]))
num_actions--;
if (num_actions == 0)
{
preds.erase();
return;
}

data.action_probs.resize(num_actions);
data.action_probs.erase();
Expand Down
6 changes: 5 additions & 1 deletion vowpalwabbit/csoaa.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,11 @@ void make_single_prediction(ldf& data, base_learner& base, example& ec)
}

bool check_ldf_sequence(ldf& data, size_t start_K)
{ bool isTest = COST_SENSITIVE::example_is_test(*data.ec_seq[start_K]);
{ bool isTest;
if (start_K == data.ec_seq.size())
isTest = true;
else
isTest = COST_SENSITIVE::example_is_test(*data.ec_seq[start_K]);
for (size_t k=start_K; k<data.ec_seq.size(); k++)
{ example *ec = data.ec_seq[k];
// Each sub-example must have just one cost
Expand Down
152 changes: 98 additions & 54 deletions vowpalwabbit/marginal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,69 +12,109 @@ struct data
{ float initial_numerator;
float initial_denominator;
float decay;
bool update_before_learn;
bool unweighted_marginals;
bool id_features[256];
features temp[256];//temporary storage when reducing.
unordered_map<uint64_t, marginal > marginals;
vw* all;
};

template <bool is_learn>
void predict_or_learn(data& sm, LEARNER::base_learner& base, example& ec)
{
uint64_t mask = sm.all->weights.mask();
void make_marginal(data& sm, example& ec)
{
uint64_t mask = sm.all->weights.mask();
for (example::iterator i = ec.begin(); i!= ec.end(); ++i)
{ namespace_index n = i.index();
if (sm.id_features[n])
{
std::swap(sm.temp[n],*i);
features& f = *i;
f.erase();
for (features::iterator j = sm.temp[n].begin(); j != sm.temp[n].end(); ++j)
{ float first_value = j.value();
uint64_t first_index = j.index() & mask;
if (++j == sm.temp[n].end())
{ cout << "warning: id feature namespace has " << sm.temp[n].size() << " features. Should be a multiple of 2" << endl;
break;
}
float second_value = j.value();
uint64_t second_index = j.index() & mask;
if (first_value != 1. || second_value != 1.)
{ cout << "warning: bad id features, must have value 1." << endl;
continue;
}
uint64_t key = second_index + ec.ft_offset;
if (sm.marginals.find(key) == sm.marginals.end())//need to initialize things.
sm.marginals.insert(make_pair(key,make_pair(sm.initial_numerator, sm.initial_denominator)));
f.push_back((float)(sm.marginals[key].first / sm.marginals[key].second), first_index);
if (!sm.temp[n].space_names.empty())
f.space_names.push_back(sm.temp[n].space_names[2*(f.size()-1)]);
}
}
}
}

for (example::iterator i = ec.begin(); i!= ec.end(); ++i)
{ namespace_index n = i.index();
if (sm.id_features[n])
{ std::swap(sm.temp[n],*i);
features& f = *i;
f.erase();
for (features::iterator j = sm.temp[n].begin(); j != sm.temp[n].end(); ++j)
{ float first_value = j.value();
uint64_t first_index = j.index() & mask;
if (++j == sm.temp[n].end())
{ cout << "warning: id feature namespace has " << sm.temp[n].size() << " features. Should be a multiple of 2" << endl;
void undo_marginal(data& sm, example& ec)
{
for (example::iterator i = ec.begin(); i!= ec.end(); ++i)
{ namespace_index n = i.index();
if (sm.id_features[n])
std::swap(sm.temp[n],*i);
}
}

void update_marginal(data& sm, example& ec)
{
uint64_t mask = sm.all->weights.mask();
for (example::iterator i = ec.begin(); i!= ec.end(); ++i)
{ namespace_index n = i.index();
if (sm.id_features[n])
for (features::iterator j = sm.temp[n].begin(); j != sm.temp[n].end(); ++j)
{ if (++j == sm.temp[n].end())
break;
}
float second_value = j.value();
uint64_t second_index = j.index() & mask;
if (first_value != 1. || second_value != 1.)
{ cout << "warning: bad id features, must have value 1." << endl;
continue;
}
uint64_t key = second_index + ec.ft_offset;
if (sm.marginals.find(key) == sm.marginals.end())//need to initialize things.
sm.marginals.insert(make_pair(key,make_pair(sm.initial_numerator, sm.initial_denominator)));
f.push_back((float)(sm.marginals[key].first / sm.marginals[key].second), first_index);
if (!sm.temp[n].space_names.empty())
f.space_names.push_back(sm.temp[n].space_names[2*(f.size()-1)]);
}
uint64_t second_index = j.index() & mask;
uint64_t key = second_index + ec.ft_offset;
marginal& m = sm.marginals[key];
if (sm.unweighted_marginals)
{
m.first = m.first * (1. - sm.decay) + ec.l.simple.label;
m.second = m.second * (1. - sm.decay);
}
else
{
m.first = m.first * (1. - sm.decay) + ec.l.simple.label * ec.weight;
m.second = m.second * (1. - sm.decay) + ec.weight;
}
}
}
}

}

template <bool is_learn>
void predict_or_learn(data& sm, LEARNER::base_learner& base, example& ec)
{
make_marginal(sm, ec);
if (is_learn)
base.learn(ec);
if (sm.update_before_learn)
{
base.predict(ec);
float pred = ec.pred.scalar;
undo_marginal(sm, ec);
update_marginal(sm, ec);//update features before learning.
make_marginal(sm, ec);
base.learn(ec);
ec.pred.scalar = pred;
}
else
{
base.learn(ec);
update_marginal(sm,ec);
}
else
base.predict(ec);

for (example::iterator i = ec.begin(); i!= ec.end(); ++i)
{ namespace_index n = i.index();
if (sm.id_features[n])
{ if (is_learn)
for (features::iterator j = sm.temp[n].begin(); j != sm.temp[n].end(); ++j)
{ if (++j == sm.temp[n].end())
break;
uint64_t second_index = j.index() & mask;
uint64_t key = second_index + ec.ft_offset;
marginal& m = sm.marginals[key];
m.first = m.first * (1. - sm.decay) + ec.l.simple.label * ec.weight;
m.second = m.second * (1. - sm.decay) + ec.weight;
}
std::swap(sm.temp[n],*i);
}
}
//undo marginalization
undo_marginal(sm, ec);
}

void finish(data& sm)
{ sm.marginals.~unordered_map();
for (size_t i =0; i < 256; i++)
Expand Down Expand Up @@ -129,15 +169,19 @@ LEARNER::base_learner* marginal_setup(vw& all)
{ if (missing_option<string, true>(all, "marginal", "substitute marginal label estimates for ids"))
return nullptr;
new_options(all)
("initial_denominator", po::value<float>()->default_value(1.f), "initial denominator")
("initial_numerator", po::value<float>()->default_value(0.5f), "initial numerator")
("decay", po::value<float>()->default_value(0.f), "decay multiplier per event (1e-3 for example)");
("initial_denominator", po::value<float>()->default_value(1.f), "initial denominator")
("initial_numerator", po::value<float>()->default_value(0.5f), "initial numerator")
("update_before_learn",po::value<bool>()->default_value(false), "update marginal values before learning")
("unweighted_marginals",po::value<bool>()->default_value(false), "ignore importance weights when computing marginals")
("decay", po::value<float>()->default_value(0.f), "decay multiplier per event (1e-3 for example)");
add_options(all);

data& d = calloc_or_throw<data>();
d.initial_numerator = all.vm["initial_numerator"].as<float>();
d.initial_denominator = all.vm["initial_denominator"].as<float>();
d.decay = all.vm["decay"].as<float>();
d.update_before_learn = all.vm["update_before_learn"].as<bool>();
d.unweighted_marginals = all.vm["unweighted_marginals"].as<bool>();
d.all = &all;
string s = (string)all.vm["marginal"].as<string>();

Expand Down

0 comments on commit cd535bb

Please sign in to comment.