Skip to content

Commit

Permalink
update to marginal for unweighted and update before learn
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnLangford committed Feb 13, 2017
1 parent e16f2d8 commit 6d0e1e1
Showing 1 changed file with 98 additions and 54 deletions.
152 changes: 98 additions & 54 deletions vowpalwabbit/marginal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,69 +12,109 @@ struct data
{ float initial_numerator;
float initial_denominator;
float decay;
bool update_before_learn;
bool unweighted_marginals;
bool id_features[256];
features temp[256];//temporary storage when reducing.
unordered_map<uint64_t, marginal > marginals;
vw* all;
};

template <bool is_learn>
void predict_or_learn(data& sm, LEARNER::base_learner& base, example& ec)
{
uint64_t mask = sm.all->weights.mask();
void make_marginal(data& sm, example& ec)
{
uint64_t mask = sm.all->weights.mask();
for (example::iterator i = ec.begin(); i!= ec.end(); ++i)
{ namespace_index n = i.index();
if (sm.id_features[n])
{
std::swap(sm.temp[n],*i);
features& f = *i;
f.erase();
for (features::iterator j = sm.temp[n].begin(); j != sm.temp[n].end(); ++j)
{ float first_value = j.value();
uint64_t first_index = j.index() & mask;
if (++j == sm.temp[n].end())
{ cout << "warning: id feature namespace has " << sm.temp[n].size() << " features. Should be a multiple of 2" << endl;
break;
}
float second_value = j.value();
uint64_t second_index = j.index() & mask;
if (first_value != 1. || second_value != 1.)
{ cout << "warning: bad id features, must have value 1." << endl;
continue;
}
uint64_t key = second_index + ec.ft_offset;
if (sm.marginals.find(key) == sm.marginals.end())//need to initialize things.
sm.marginals.insert(make_pair(key,make_pair(sm.initial_numerator, sm.initial_denominator)));
f.push_back((float)(sm.marginals[key].first / sm.marginals[key].second), first_index);
if (!sm.temp[n].space_names.empty())
f.space_names.push_back(sm.temp[n].space_names[2*(f.size()-1)]);
}
}
}
}

for (example::iterator i = ec.begin(); i!= ec.end(); ++i)
{ namespace_index n = i.index();
if (sm.id_features[n])
{ std::swap(sm.temp[n],*i);
features& f = *i;
f.erase();
for (features::iterator j = sm.temp[n].begin(); j != sm.temp[n].end(); ++j)
{ float first_value = j.value();
uint64_t first_index = j.index() & mask;
if (++j == sm.temp[n].end())
{ cout << "warning: id feature namespace has " << sm.temp[n].size() << " features. Should be a multiple of 2" << endl;
void undo_marginal(data& sm, example& ec)
{
for (example::iterator i = ec.begin(); i!= ec.end(); ++i)
{ namespace_index n = i.index();
if (sm.id_features[n])
std::swap(sm.temp[n],*i);
}
}

void update_marginal(data& sm, example& ec)
{
uint64_t mask = sm.all->weights.mask();
for (example::iterator i = ec.begin(); i!= ec.end(); ++i)
{ namespace_index n = i.index();
if (sm.id_features[n])
for (features::iterator j = sm.temp[n].begin(); j != sm.temp[n].end(); ++j)
{ if (++j == sm.temp[n].end())
break;
}
float second_value = j.value();
uint64_t second_index = j.index() & mask;
if (first_value != 1. || second_value != 1.)
{ cout << "warning: bad id features, must have value 1." << endl;
continue;
}
uint64_t key = second_index + ec.ft_offset;
if (sm.marginals.find(key) == sm.marginals.end())//need to initialize things.
sm.marginals.insert(make_pair(key,make_pair(sm.initial_numerator, sm.initial_denominator)));
f.push_back((float)(sm.marginals[key].first / sm.marginals[key].second), first_index);
if (!sm.temp[n].space_names.empty())
f.space_names.push_back(sm.temp[n].space_names[2*(f.size()-1)]);
}
uint64_t second_index = j.index() & mask;
uint64_t key = second_index + ec.ft_offset;
marginal& m = sm.marginals[key];
if (sm.unweighted_marginals)
{
m.first = m.first * (1. - sm.decay) + ec.l.simple.label;
m.second = m.second * (1. - sm.decay);
}
else
{
m.first = m.first * (1. - sm.decay) + ec.l.simple.label * ec.weight;
m.second = m.second * (1. - sm.decay) + ec.weight;
}
}
}
}

}

template <bool is_learn>
void predict_or_learn(data& sm, LEARNER::base_learner& base, example& ec)
{
make_marginal(sm, ec);
if (is_learn)
base.learn(ec);
if (sm.update_before_learn)
{
base.predict(ec);
float pred = ec.pred.scalar;
undo_marginal(sm, ec);
update_marginal(sm, ec);//update features before learning.
make_marginal(sm, ec);
base.learn(ec);
ec.pred.scalar = pred;
}
else
{
base.learn(ec);
update_marginal(sm,ec);
}
else
base.predict(ec);

for (example::iterator i = ec.begin(); i!= ec.end(); ++i)
{ namespace_index n = i.index();
if (sm.id_features[n])
{ if (is_learn)
for (features::iterator j = sm.temp[n].begin(); j != sm.temp[n].end(); ++j)
{ if (++j == sm.temp[n].end())
break;
uint64_t second_index = j.index() & mask;
uint64_t key = second_index + ec.ft_offset;
marginal& m = sm.marginals[key];
m.first = m.first * (1. - sm.decay) + ec.l.simple.label * ec.weight;
m.second = m.second * (1. - sm.decay) + ec.weight;
}
std::swap(sm.temp[n],*i);
}
}
//undo marginalization
undo_marginal(sm, ec);
}

void finish(data& sm)
{ sm.marginals.~unordered_map();
for (size_t i =0; i < 256; i++)
Expand Down Expand Up @@ -129,15 +169,19 @@ LEARNER::base_learner* marginal_setup(vw& all)
{ if (missing_option<string, true>(all, "marginal", "substitute marginal label estimates for ids"))
return nullptr;
new_options(all)
("initial_denominator", po::value<float>()->default_value(1.f), "initial denominator")
("initial_numerator", po::value<float>()->default_value(0.5f), "initial numerator")
("decay", po::value<float>()->default_value(0.f), "decay multiplier per event (1e-3 for example)");
("initial_denominator", po::value<float>()->default_value(1.f), "initial denominator")
("initial_numerator", po::value<float>()->default_value(0.5f), "initial numerator")
("update_before_learn",po::value<bool>()->default_value(false), "update marginal values before learning")
("unweighted_marginals",po::value<bool>()->default_value(false), "ignore importance weights when computing marginals")
("decay", po::value<float>()->default_value(0.f), "decay multiplier per event (1e-3 for example)");
add_options(all);

data& d = calloc_or_throw<data>();
d.initial_numerator = all.vm["initial_numerator"].as<float>();
d.initial_denominator = all.vm["initial_denominator"].as<float>();
d.decay = all.vm["decay"].as<float>();
d.update_before_learn = all.vm["update_before_learn"].as<bool>();
d.unweighted_marginals = all.vm["unweighted_marginals"].as<bool>();
d.all = &all;
string s = (string)all.vm["marginal"].as<string>();

Expand Down

0 comments on commit 6d0e1e1

Please sign in to comment.