Fixed all the bugs of save_resume #1917

Closed · wants to merge 21 commits · Changes from 20 commits
5 changes: 3 additions & 2 deletions test/save_resume_test.py
@@ -153,8 +153,9 @@ def do_test(filename, args, verbose=None, repeat_args=None, known_failure=False)
 errors += do_test(filename, '--loss_function logistic --link logistic')
 errors += do_test(filename, '--nn 2')
 errors += do_test(filename, '--binary')
-errors += do_test(filename, '--ftrl', known_failure=True)
-errors += do_test(filename, '--pistol', known_failure=True)
+errors += do_test(filename, '--ftrl')
+errors += do_test(filename, '--pistol')
+errors += do_test(filename, '--coin')
 
 # this one also fails but pollutes output
 #errors += do_test(filename, '--ksvm', known_failure=True)
15 changes: 9 additions & 6 deletions vowpalwabbit/ftrl.cc
@@ -38,7 +38,7 @@ struct ftrl
   struct update_data data;
   size_t no_win_counter;
   size_t early_stop_thres;
-  double total_weight;
+  uint32_t ftrl_size;
 };
 
 struct uncertainty
@@ -151,7 +151,7 @@ void inner_update_pistol_state_and_predict(update_data& d, float x, float& wref)
 
   float squared_theta = w[W_ZT] * w[W_ZT];
   float tmp = 1.f / (d.ftrl_alpha * w[W_MX] * (w[W_G2] + w[W_MX]));
-  w[W_XT] = sqrt(w[W_G2]) * d.ftrl_beta * w[W_ZT] * correctedExp(squared_theta / 2 * tmp) * tmp;
+  w[W_XT] = sqrt(w[W_G2]) * d.ftrl_beta * w[W_ZT] * correctedExp(squared_theta / 2.f * tmp) * tmp;
 
   d.predict += w[W_XT] * x;
 }
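A note on the `2 -> 2.f` change (my analysis, not from the PR thread): since `squared_theta` is a `float`, the int literal 2 was already promoted and the division was floating-point, so this edit is a consistency fix rather than a behavioral one. A minimal check:

```cpp
#include <cassert>

int main()
{
  float squared_theta = 3.f;
  // The int literal is promoted to float, so both forms divide identically.
  assert(squared_theta / 2 == squared_theta / 2.f);
  // The pitfall an explicit float literal guards against in general:
  assert(3 / 2 == 1);  // integer division truncates
  return 0;
}
```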
@@ -228,9 +228,9 @@ void update_state_and_predict_cb(ftrl& b, single_learner&, example& ec)
   GD::foreach_feature<update_data, inner_update_cb_state_and_predict>(*b.all, ec, b.data);
 
   b.all->normalized_sum_norm_x += ((double)ec.weight) * b.data.normalized_squared_norm_x;
-  b.total_weight += ec.weight;
+  b.all->total_weight += ec.weight;
 
-  ec.partial_prediction = b.data.predict/((float)((b.all->normalized_sum_norm_x + 1e-6)/b.total_weight));
+  ec.partial_prediction = b.data.predict/((float)((b.all->normalized_sum_norm_x + 1e-6)/b.all->total_weight));
 
   ec.pred.scalar = GD::finalize_prediction(b.all->sd, ec.partial_prediction);
 }
@@ -315,7 +315,7 @@ void save_load(ftrl& b, io_buf& model_file, bool read, bool text)
   bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, msg, text);
 
   if (resume)
-    GD::save_load_online_state(*all, model_file, read, text);
+    GD::save_load_online_state(*all, model_file, read, text, nullptr, b.ftrl_size);
   else
     GD::save_load_regressor(*all, model_file, read, text);
 }
@@ -376,7 +376,7 @@ base_learner* ftrl_setup(options_i& options, vw& all)
   b->all = &all;
   b->no_win_counter = 0;
   b->all->normalized_sum_norm_x = 0;
-  b->total_weight = 0.;
+  b->all->total_weight = 0;
 
   void (*learn_ptr)(ftrl&, single_learner&, example&) = nullptr;
 
@@ -389,18 +389,21 @@ base_learner* ftrl_setup(options_i& options, vw& all)
     else
       learn_ptr = learn_proximal<false>;
     all.weights.stride_shift(2);  // NOTE: for more parameter storage
+    b->ftrl_size = 3;
   }
   else if (pistol)
   {
     algorithm_name = "PiSTOL";
     learn_ptr = learn_pistol;
     all.weights.stride_shift(2);  // NOTE: for more parameter storage
+    b->ftrl_size = 4;
   }
   else if (coin)
   {
     algorithm_name = "Coin Betting";
     learn_ptr = learn_cb;
     all.weights.stride_shift(3);  // NOTE: for more parameter storage
+    b->ftrl_size = 6;
   }
 
   b->data.ftrl_alpha = b->ftrl_alpha;
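For readers following the new `ftrl_size` field: it records how many of the per-feature weight slots carry live optimizer state and must therefore be serialized on `--save_resume`, while `stride_shift(k)` reserves 2^k slots per feature. A small sketch of the invariant (my reading of the diff, not code from it):

```cpp
#include <cassert>
#include <cstdint>

// stride_shift(shift) reserves 2^shift floats per feature.
constexpr uint32_t stride_for_shift(uint32_t shift) { return 1u << shift; }

int main()
{
  assert(stride_for_shift(2) >= 3);  // FTRL-Proximal: ftrl_size = 3, stride 4
  assert(stride_for_shift(2) >= 4);  // PiSTOL:        ftrl_size = 4, stride 4
  assert(stride_for_shift(3) >= 6);  // Coin Betting:  ftrl_size = 6, stride 8
  return 0;
}
```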
124 changes: 77 additions & 47 deletions vowpalwabbit/gd.cc
@@ -44,7 +44,7 @@ namespace GD
 struct gd
 {
   // double normalized_sum_norm_x;
-  double total_weight;
+  // double total_weight;
   size_t no_win_counter;
   size_t early_stop_thres;
   float initial_constant;
@@ -559,14 +559,14 @@ float get_pred_per_update(gd& g, example& ec)
   if (!stateless)
   {
     g.all->normalized_sum_norm_x += ((double)ec.weight) * nd.norm_x;
-    g.total_weight += ec.weight;
+    g.all->total_weight += ec.weight;
[Review thread on the line above]

Member: Why prefer a global? The general rule of thumb is to use variables which are as local as possible to minimize context.

Contributor (author): It is not clear to me how to solve this; I am open to suggestions. On one side, the struct vw already contains similar quantities (power_t, invariant_updates, normalized_sum_norm_x), and similar gd-specific ones still live in that global place. Also, the problem comes from using GD::save_load_online_state in ftrl: we don't have access to the ftrl data that way. We could duplicate and customize the entire save-state function in ftrl, which seems painful, or hack GD::save_load_online_state with even more optional inputs, which also seems a bad idea.

Member: Of the three options, an extra argument seems preferred to either a global or code duplication. Code duplication seems particularly bad---it's a recipe for non-maintainability. The global variable is moving in the wrong direction---we are working towards atomizing the reductions so they can be composed with other learning algorithms. The extra-arguments approach seems the best. In the long term, we'd probably want to adjust the arguments so they are semantic rather than algorithm-specific: instead of having ftrl, we'd have "the number of floats per weight to store", etc. But this is a minor refactoring consistent with the extra-arguments approach.
     g.update_multiplier = average_update<sqrt_rate, adaptive, normalized>(
-        (float)g.total_weight, (float)g.all->normalized_sum_norm_x, g.neg_norm_power);
+        (float)g.all->total_weight, (float)g.all->normalized_sum_norm_x, g.neg_norm_power);
   }
   else
   {
     float nsnx = ((float)g.all->normalized_sum_norm_x) + ec.weight * nd.norm_x;
-    float tw = (float)g.total_weight + ec.weight;
+    float tw = (float)g.all->total_weight + ec.weight;
     g.update_multiplier = average_update<sqrt_rate, adaptive, normalized>(tw, nsnx, g.neg_norm_power);
   }
   nd.pred_per_update *= g.update_multiplier;
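Why the counter moved onto `vw`: the normalized update is scaled by the running average of weighted squared feature norms, `normalized_sum_norm_x / total_weight`. Keeping both accumulators on the shared `vw` object lets gd and ftrl maintain one statistic and lets save_resume restore it in one place. A minimal sketch of the statistic itself (illustrative, simplified from the code above):

```cpp
#include <cstdio>

// After this PR both accumulators live on the shared vw object.
struct shared_stats
{
  double normalized_sum_norm_x = 0.0;  // sum of weighted squared feature norms
  double total_weight = 0.0;           // sum of example weights
};

double average_squared_norm(const shared_stats& s)
{
  return s.normalized_sum_norm_x / s.total_weight;
}

int main()
{
  shared_stats s;
  s.normalized_sum_norm_x += 1.0 * 4.0;  // example 1: weight 1, ||x||^2 = 4
  s.total_weight += 1.0;
  s.normalized_sum_norm_x += 1.0 * 2.0;  // example 2: weight 1, ||x||^2 = 2
  s.total_weight += 1.0;
  std::printf("average squared norm = %.1f\n", average_squared_norm(s));  // 3.0
  return 0;
}
```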
@@ -684,6 +684,23 @@ void sync_weights(vw& all)
   all.sd->contraction = 1.;
 }
 
+size_t write_index(io_buf& model_file, stringstream& msg, bool text, uint32_t num_bits, uint64_t i)
+{
+  size_t brw;
+  uint32_t old_i = 0;
+
+  msg << i;
+
+  if (num_bits < 31)
+  {
+    old_i = (uint32_t)i;
+    brw = bin_text_write_fixed(model_file, (char*)&old_i, sizeof(old_i), msg, text);
+  }
+  else
+    brw = bin_text_write_fixed(model_file, (char*)&i, sizeof(i), msg, text);
+
+  return brw;
+}
+
 template <class T>
 void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text, T& weights)
 {
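The new `write_index` helper factors out the index-encoding rule that every write site previously duplicated: models built with fewer than 31 bits store each feature index as 4 bytes, larger models as 8. A standalone illustration of the rule (simplified; the real helper also appends the index to the text-mode message):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Bytes used per serialized feature index, mirroring the branch in write_index.
size_t index_bytes(uint32_t num_bits)
{
  return num_bits < 31 ? sizeof(uint32_t) : sizeof(uint64_t);
}

int main()
{
  std::printf("18-bit model: %zu bytes per index\n", index_bytes(18));  // 4
  std::printf("32-bit model: %zu bytes per index\n", index_bytes(32));  // 8
  return 0;
}
```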
@@ -738,16 +755,8 @@ void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text, T& weights)
     {
       i = v.index() >> weights.stride_shift();
       stringstream msg;
-      msg << i;
-
-      if (all.num_bits < 31)
-      {
-        old_i = (uint32_t)i;
-        brw = bin_text_write_fixed(model_file, (char*)&old_i, sizeof(old_i), msg, text);
-      }
-      else
-        brw = bin_text_write_fixed(model_file, (char*)&i, sizeof(i), msg, text);
-
+      brw = write_index(model_file, msg, text, all.num_bits, i);
       msg << ":" << *v << "\n";
       brw += bin_text_write_fixed(model_file, (char*)&(*v), sizeof(*v), msg, text);
     }
@@ -762,7 +771,7 @@ void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text)
 }
 
 template <class T>
-void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, stringstream& msg, T& weights)
+void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, stringstream& msg, uint32_t ftrl_size, T& weights)
 {
   uint64_t length = (uint64_t)1 << all.num_bits;
 
@@ -786,8 +795,10 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, stringstream& msg, T& weights)
       if (i >= length)
         THROW("Model content is corrupted, weight vector index " << i << " must be less than total vector length "
                                                                  << length);
-      weight buff[4] = {0, 0, 0, 0};
-      if (g == NULL || (!g->adaptive && !g->normalized))
+      weight buff[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+      if (ftrl_size > 0)
+        brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]) * ftrl_size, "");
+      else if (g == NULL || (!g->adaptive && !g->normalized))
         brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]), "");
       else if ((g->adaptive && !g->normalized) || (!g->adaptive && g->normalized))
         brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]) * 2, "");
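The `buff[4]` to `buff[8]` change follows from the new first branch: with `--coin`, `ftrl_size` is 6, so the old 4-float scratch buffer would overflow on read. A compile-time statement of the invariant (illustrative only):

```cpp
using weight = float;

constexpr unsigned kMaxFtrlSize = 6;  // Coin Betting serializes the largest block

// The scratch buffer must hold the largest block any branch can read.
static_assert(sizeof(weight[8]) >= kMaxFtrlSize * sizeof(weight),
              "read buffer too small for the largest serialized weight block");

int main() { return 0; }
```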
@@ -799,40 +810,60 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, stringstream& msg, T& weights)
       }
     } while (brw > 0);
   else  // write binary or text
-    for (typename T::iterator v = weights.begin(); v != weights.end(); ++v)
-      if (*v != 0.)
-      {
-        i = v.index() >> weights.stride_shift();
-        msg << i;
-        if (all.num_bits < 31)
-        {
-          old_i = (uint32_t)i;
-          brw = bin_text_write_fixed(model_file, (char*)&old_i, sizeof(old_i), msg, text);
-        }
-        else
-          brw = bin_text_write_fixed(model_file, (char*)&i, sizeof(i), msg, text);
-
-        if (g == nullptr || (!g->adaptive && !g->normalized))
-        {
-          msg << ":" << *v << "\n";
-          brw += bin_text_write_fixed(model_file, (char*)&(*v), sizeof(*v), msg, text);
-        }
-        else if ((g->adaptive && !g->normalized) || (!g->adaptive && g->normalized))
-        {
-          // either adaptive or normalized
-          msg << ":" << *v << " " << (&(*v))[1] << "\n";
-          brw += bin_text_write_fixed(model_file, (char*)&(*v), 2 * sizeof(*v), msg, text);
-        }
-        else
-        {
-          // adaptive and normalized
-          msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << "\n";
-          brw += bin_text_write_fixed(model_file, (char*)&(*v), 3 * sizeof(*v), msg, text);
-        }
-      }
+    for (typename T::iterator v = weights.begin(); v != weights.end(); ++v) {
+      i = v.index() >> weights.stride_shift();
+
+      if (ftrl_size==3) {
+        if (*v != 0. || (&(*v))[1]!=0. || (&(*v))[2]!=0.) {
+          brw = write_index(model_file, msg, text, all.num_bits, i);
+          msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << "\n";
+          brw += bin_text_write_fixed(model_file, (char*)&(*v), 3 * sizeof(*v), msg, text);
+        }
+      }
+      else if (ftrl_size==4) {
+        if (*v != 0. || (&(*v))[1]!=0. || (&(*v))[2]!=0. || (&(*v))[3]!=0.) {
+          brw = write_index(model_file, msg, text, all.num_bits, i);
+          msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << "\n";
+          brw += bin_text_write_fixed(model_file, (char*)&(*v), 4 * sizeof(*v), msg, text);
+        }
+      }
+      else if (ftrl_size==6) {
+        if (*v != 0. || (&(*v))[1]!=0. || (&(*v))[2]!=0. || (&(*v))[3]!=0. || (&(*v))[4]!=0. || (&(*v))[5]!=0.) {
+          brw = write_index(model_file, msg, text, all.num_bits, i);
+          msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << " " << (&(*v))[4] << " " << (&(*v))[5] << "\n";
+          brw += bin_text_write_fixed(model_file, (char*)&(*v), 6 * sizeof(*v), msg, text);
+        }
+      }
+      else if (g == nullptr || (!g->adaptive && !g->normalized))
+      {
+        if (*v != 0.) {
+          brw = write_index(model_file, msg, text, all.num_bits, i);
+          msg << ":" << *v << "\n";
+          brw += bin_text_write_fixed(model_file, (char*)&(*v), sizeof(*v), msg, text);
+        }
+      }
+      else if ((g->adaptive && !g->normalized) || (!g->adaptive && g->normalized))
+      {
+        // either adaptive or normalized
+        if (*v != 0. || (&(*v))[1]!=0.) {
+          brw = write_index(model_file, msg, text, all.num_bits, i);
+          msg << ":" << *v << " " << (&(*v))[1] << "\n";
+          brw += bin_text_write_fixed(model_file, (char*)&(*v), 2 * sizeof(*v), msg, text);
+        }
+      }
+      else
+      {
+        // adaptive and normalized
+        if (*v != 0. || (&(*v))[1]!=0. || (&(*v))[2]!=0.) {
+          brw = write_index(model_file, msg, text, all.num_bits, i);
+          msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << "\n";
+          brw += bin_text_write_fixed(model_file, (char*)&(*v), 3 * sizeof(*v), msg, text);
+        }
+      }
+    }
 }
 
-void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g)
+void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, uint32_t ftrl_size)
 {
   // vw& all = *g.all;
   stringstream msg;
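The `ftrl_size == 3/4/6` branches above all apply one rule: serialize a feature's block only if at least one of its first `ftrl_size` slots is non-zero. This also points toward the reviewer's suggestion of a semantic parameter ("the number of floats per weight to store"), under which the repetition would collapse into a single loop. A hedged sketch of that shared predicate (illustrative, not part of the PR):

```cpp
#include <cstddef>
#include <cstdio>

using weight = float;

// True if any of the first k slots of a feature's weight block is non-zero,
// i.e. the block carries state worth writing.
bool any_nonzero(const weight* w, size_t k)
{
  for (size_t j = 0; j < k; ++j)
    if (w[j] != 0.f) return true;
  return false;
}

int main()
{
  weight block[6] = {0.f, 0.f, 0.25f, 0.f, 0.f, 0.f};
  std::printf("serialize? %s\n", any_nonzero(block, 6) ? "yes" : "no");  // yes
  return 0;
}
```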
@@ -892,12 +923,12 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g)
 
   // fix average loss
   double total_weight = 0.;  // value holder as g* may be null
-  if (!read && g != nullptr)
-    total_weight = g->total_weight;
-  msg << "gd::total_weight " << total_weight << "\n";
+  if (!read)
+    total_weight = all.total_weight;
+  msg << "total_weight " << total_weight << "\n";
   bin_text_read_write_fixed(model_file, (char*)&total_weight, sizeof(total_weight), "", read, msg, text);
-  if (read && g != nullptr)
-    g->total_weight = total_weight;
+  if (read)
+    all.total_weight = total_weight;
 
   // fix "loss since last" for first printed out example details
   msg << "sd::oec.weighted_labeled_examples " << all.sd->old_weighted_labeled_examples << "\n";
@@ -931,9 +962,9 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g)
     all.current_pass = 0;
   }
   if (all.weights.sparse)
-    save_load_online_state(all, model_file, read, text, g, msg, all.weights.sparse_weights);
+    save_load_online_state(all, model_file, read, text, g, msg, ftrl_size, all.weights.sparse_weights);
   else
-    save_load_online_state(all, model_file, read, text, g, msg, all.weights.dense_weights);
+    save_load_online_state(all, model_file, read, text, g, msg, ftrl_size, all.weights.dense_weights);
 }
 
 template <class T>
@@ -987,7 +1018,6 @@ void save_load(gd& g, io_buf& model_file, bool read, bool text)
           << "WARNING: --save_resume functionality is known to have inaccuracy in model files version less than "
           << VERSION_SAVE_RESUME_FIX << endl
           << endl;
-    // save_load_online_state(g, model_file, read, text);
     save_load_online_state(all, model_file, read, text, &g);
   }
   else
@@ -1105,7 +1135,7 @@ base_learner* setup(options_i& options, vw& all)
   g->all = &all;
   g->all->normalized_sum_norm_x = 0;
   g->no_win_counter = 0;
-  g->total_weight = 0.;
+  g->all->total_weight = 0.;
   g->neg_norm_power = (all.adaptive ? (all.power_t - 1.f) : -1.f);
   g->neg_power_t = -all.power_t;
   g->adaptive = all.adaptive;
@@ -1115,7 +1145,7 @@
   // seen (all.initial_t) previous fake datapoints all with norm 1
   {
     g->all->normalized_sum_norm_x = all.initial_t;
-    g->total_weight = all.initial_t;
+    g->all->total_weight = all.initial_t;
   }
 
   bool feature_mask_off = true;
3 changes: 2 additions & 1 deletion vowpalwabbit/gd.h
@@ -23,7 +23,8 @@ struct gd;
 float finalize_prediction(shared_data* sd, float ret);
 void print_audit_features(vw&, example& ec);
 void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text);
-void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, GD::gd* g = nullptr);
+void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, GD::gd* g = nullptr, uint32_t ftrl_size = 0);
+
 
 template <class T>
 struct multipredict_info
1 change: 1 addition & 0 deletions vowpalwabbit/global_data.h
@@ -500,6 +500,7 @@ struct vw
 
   version_struct model_file_ver;
   double normalized_sum_norm_x;
+  double total_weight;
   bool vw_is_main = false;  // true if vw is executable; false in library mode
 
   // error reporting