Skip to content

Commit

Permalink
Revert "Fix to save the state of FTRL models (#1912)" (#1916)
Browse files Browse the repository at this point in the history
This reverts commit 538e9bb.
  • Loading branch information
JohnLangford authored Jun 6, 2019
1 parent 71627e7 commit 2cbe61b
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 31 deletions.
3 changes: 1 addition & 2 deletions test/save_resume_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,8 @@ def do_test(filename, args, verbose=None, repeat_args=None, known_failure=False)
errors += do_test(filename, '--loss_function logistic --link logistic')
errors += do_test(filename, '--nn 2')
errors += do_test(filename, '--binary')
errors += do_test(filename, '--ftrl')
errors += do_test(filename, '--ftrl', known_failure=True)
errors += do_test(filename, '--pistol', known_failure=True)
errors += do_test(filename, '--coin', known_failure=True)

# this one also fails but pollutes output
#errors += do_test(filename, '--ksvm', known_failure=True)
Expand Down
8 changes: 2 additions & 6 deletions vowpalwabbit/ftrl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ struct ftrl
size_t no_win_counter;
size_t early_stop_thres;
double total_weight;
uint32_t ftrl_size;
};

struct uncertainty
Expand Down Expand Up @@ -152,7 +151,7 @@ void inner_update_pistol_state_and_predict(update_data& d, float x, float& wref)

float squared_theta = w[W_ZT] * w[W_ZT];
float tmp = 1.f / (d.ftrl_alpha * w[W_MX] * (w[W_G2] + w[W_MX]));
w[W_XT] = sqrt(w[W_G2]) * d.ftrl_beta * w[W_ZT] * correctedExp(squared_theta / 2.f * tmp) * tmp;
w[W_XT] = sqrt(w[W_G2]) * d.ftrl_beta * w[W_ZT] * correctedExp(squared_theta / 2 * tmp) * tmp;

d.predict += w[W_XT] * x;
}
Expand Down Expand Up @@ -316,7 +315,7 @@ void save_load(ftrl& b, io_buf& model_file, bool read, bool text)
bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, msg, text);

if (resume)
GD::save_load_online_state(*all, model_file, read, text, nullptr, b.ftrl_size);
GD::save_load_online_state(*all, model_file, read, text);
else
GD::save_load_regressor(*all, model_file, read, text);
}
Expand Down Expand Up @@ -390,21 +389,18 @@ base_learner* ftrl_setup(options_i& options, vw& all)
else
learn_ptr = learn_proximal<false>;
all.weights.stride_shift(2); // NOTE: for more parameter storage
b->ftrl_size = 3;
}
else if (pistol)
{
algorithm_name = "PiSTOL";
learn_ptr = learn_pistol;
all.weights.stride_shift(2); // NOTE: for more parameter storage
b->ftrl_size = 4;
}
else if (coin)
{
algorithm_name = "Coin Betting";
learn_ptr = learn_cb;
all.weights.stride_shift(3); // NOTE: for more parameter storage
b->ftrl_size = 6;
}

b->data.ftrl_alpha = b->ftrl_alpha;
Expand Down
28 changes: 7 additions & 21 deletions vowpalwabbit/gd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text)
}

template <class T>
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, stringstream& msg, uint32_t ftrl_size, T& weights)
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, stringstream& msg, T& weights)
{
uint64_t length = (uint64_t)1 << all.num_bits;

Expand All @@ -786,10 +786,8 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g
if (i >= length)
THROW("Model content is corrupted, weight vector index " << i << " must be less than total vector length "
<< length);
weight buff[8] = {0, 0, 0, 0, 0, 0, 0, 0};
if (ftrl_size>0)
brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]) * ftrl_size, "");
else if (g == NULL || (!g->adaptive && !g->normalized))
weight buff[4] = {0, 0, 0, 0};
if (g == NULL || (!g->adaptive && !g->normalized))
brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]), "");
else if ((g->adaptive && !g->normalized) || (!g->adaptive && g->normalized))
brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]) * 2, "");
Expand All @@ -814,19 +812,7 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g
else
brw = bin_text_write_fixed(model_file, (char*)&i, sizeof(i), msg, text);

if (ftrl_size==3) {
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), 3 * sizeof(*v), msg, text);
}
else if (ftrl_size==4) {
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), 4 * sizeof(*v), msg, text);
}
else if (ftrl_size==6) {
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << " " << (&(*v))[4] << " " << (&(*v))[5] << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), 6 * sizeof(*v), msg, text);
}
else if (g == nullptr || (!g->adaptive && !g->normalized))
if (g == nullptr || (!g->adaptive && !g->normalized))
{
msg << ":" << *v << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), sizeof(*v), msg, text);
Expand All @@ -846,7 +832,7 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g
}
}

void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, uint32_t ftrl_size)
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g)
{
// vw& all = *g.all;
stringstream msg;
Expand Down Expand Up @@ -945,9 +931,9 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g
all.current_pass = 0;
}
if (all.weights.sparse)
save_load_online_state(all, model_file, read, text, g, msg, ftrl_size, all.weights.sparse_weights);
save_load_online_state(all, model_file, read, text, g, msg, all.weights.sparse_weights);
else
save_load_online_state(all, model_file, read, text, g, msg, ftrl_size, all.weights.dense_weights);
save_load_online_state(all, model_file, read, text, g, msg, all.weights.dense_weights);
}

template <class T>
Expand Down
3 changes: 1 addition & 2 deletions vowpalwabbit/gd.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ struct gd;
float finalize_prediction(shared_data* sd, float ret);
void print_audit_features(vw&, example& ec);
void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text);
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, GD::gd* g = nullptr, uint32_t ftrl_size = 0);

void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, GD::gd* g = nullptr);

template <class T>
struct multipredict_info
Expand Down

0 comments on commit 2cbe61b

Please sign in to comment.