Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into aml_mm
Browse files Browse the repository at this point in the history
  • Loading branch information
bassmang committed Jan 25, 2023
2 parents f63d357 + 0c0a43f commit 7eace7c
Show file tree
Hide file tree
Showing 14 changed files with 76 additions and 110 deletions.
2 changes: 1 addition & 1 deletion python/docs/source/examples/epsilon_decay.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"id": "c06fbf85",
"metadata": {},
"source": [
"### [Experimental] Testing Basic Models with Varying Epsilon Values and Model Counts for Non-Stationary Epsilon Decay"
"# [Experimental] Testing Basic Models with Varying Epsilon Values and Model Counts for Non-Stationary Epsilon Decay"
]
},
{
Expand Down
1 change: 1 addition & 0 deletions python/docs/source/examples/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
basics
contextual_bandit
epsilon_decay
mini_vw
poisson_regression
predict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@
"metadata": {},
"source": [
"\n",
"We want to be able to visualize what is occurring, so we are going to plot the click through rate over each iteration of the simulation. If VW is showing actions the get rewards the ctr will be higher. Below is a little utility function to make showing the plot easier.\n"
"We want to be able to visualize what is occurring, so we are going to plot the click through rate over each iteration of the simulation. If VW is showing actions that get rewards the ctr will be higher. Below is a little utility function to make showing the plot easier.\n"
]
},
{
Expand Down
9 changes: 5 additions & 4 deletions python/tests/test_confidence_sequence_robust.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,14 @@ def test_cs_robust_ci():
csr.addobs((0.879389936, 1))
csr.addobs((0.880083546, 1))
csr.addobs((0.8807676, 1))
csr.addobs((43.91880613, 0.9))

alpha = 0.05 / 16.0
lb, ub = csr.getci(alpha)

# Compare bounds to confidence_sequence_robust_test.cc
np.testing.assert_almost_equal(lb, 0.19224909776696902, 6)
np.testing.assert_almost_equal(ub, 0.93637288977188682, 6)
np.testing.assert_almost_equal(lb, 0.30209151281846858, 4)
np.testing.assert_almost_equal(ub, 0.90219143188334106, 4)

# Test brentq separately
s = 139.8326745448
Expand All @@ -114,8 +115,8 @@ def test_cs_robust_ci():
)

# Compare to brentq in confidence_sequence_robust_test.cc
np.testing.assert_almost_equal(res.root, 0.31672263211621371, 6)
np.testing.assert_almost_equal(res.root, 0.30209143008131789, 4)

# Test that root of objective function is 0
test_root = csr.lower.logwealthmix(mu=res.root, s=s, thres=thres, memo=memo) - thres
np.testing.assert_almost_equal(test_root, 0.0)
np.testing.assert_almost_equal(test_root, 0.0, 2)
6 changes: 3 additions & 3 deletions test/train-sets/ref/coin_model_overflow.invert.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ sd::oec.weighted_labeled_examples 1
current_pass 1
l1_state 0
l2_state 1
34875:0 2e+10 2e+10 1e+10 0 2
116060:0 2 2 1 0 2
195802:0 2e+20 2e+20 1e+20 0 2
f:34875:0 2e+10 2e+10 1e+10 0 2
Constant:116060:0 2 2 1 0 2
f*f:195802:0 2e+20 2e+20 1e+20 0 2
4 changes: 2 additions & 2 deletions test/train-sets/ref/help.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ Feature Options:
only (type: list[str], keep)
--ignore_features_dsjson_experimental args...
Ignore specified features from namespace. To ignore a feature
arg should be namespace|feature To ignore a feature in the default
namespace, arg should be |feature (type: list[str], keep, experimental)
arg should be <namespace>|<feature>. <namespace> should be empty
for default (type: list[str], keep, experimental)
--keep args... Keep namespaces beginning with character <arg> (type: list[str],
keep)
--redefine args... Redefine namespaces beginning with characters of std::string
Expand Down
4 changes: 2 additions & 2 deletions test/train-sets/ref/help_cbadf.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ Feature Options:
only (type: list[str], keep)
--ignore_features_dsjson_experimental args...
Ignore specified features from namespace. To ignore a feature
arg should be namespace|feature To ignore a feature in the default
namespace, arg should be |feature (type: list[str], keep, experimental)
arg should be <namespace>|<feature>. <namespace> should be empty
for default (type: list[str], keep, experimental)
--keep args... Keep namespaces beginning with character <arg> (type: list[str],
keep)
--redefine args... Redefine namespaces beginning with characters of std::string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,15 @@ class g_tilde
class countable_discrete_base
{
public:
countable_discrete_base(double eta = 0.95f, double r = 2.0, double k = 1.5, double lambda_max = 0.5, double xi = 1.6);
countable_discrete_base(double eta = 0.95f, double k = 1.5, double lambda_max = 0.5, double xi = 1.6);
double get_ci(double alpha) const;
double get_lam_sqrt_tp1(double j) const;
double get_v_impl(std::map<uint64_t, double>& memo, uint64_t j) const;
double log_wealth_mix(double mu, double s, double thres, std::map<uint64_t, double>& memo) const;
double root_brentq(double s, double thres, std::map<uint64_t, double>& memo, double min_mu, double max_mu,
double toll_x = 1e-10, double toll_f = 1e-12) const;
double toll_x = 1e-4, double toll_f = 1e-20) const;
double log_sum_exp(const std::vector<double>& combined) const;
double lb_log_wealth(double alpha) const;
double polylog(double r, double eta) const;
double get_log_weight(double j) const;
double get_log_remaining_weight(double j) const;
double get_s() const;
Expand Down
33 changes: 5 additions & 28 deletions vowpalwabbit/core/src/confidence_sequence_robust.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,17 @@ void g_tilde::reset_stats()
t = 0;
}

countable_discrete_base::countable_discrete_base(double eta, double r, double k, double lambda_max, double xi)
countable_discrete_base::countable_discrete_base(double eta, double k, double lambda_max, double xi)
: log_xi(std::log1p(xi - 1))
, log_xi_m1(std::log1p(xi - 2.0))
, lambda_max(lambda_max)
, zeta_r(1.6449340668482264) // std::riemann_zeta(r) -- Assuming r=2.0 is constant
, scale_fac(0.5 * (1.0 + polylog(r, eta) / (eta * zeta_r)))
, zeta_r(1.6449340668482264) // std::riemann_zeta(r) -- Assuming r=2.0
, scale_fac(0.5 * (1.0 + 1.4406337969700393 / (eta * zeta_r))) // polylog(r, eta) -- Assuming eta=0.95, r=2.0
, log_scale_fac(std::log1p(scale_fac - 1.0))
, t(0)
, gt(k)
{
assert(0.0 < eta && eta < 1.0);
assert(r > 1.0);
assert(0.0 < scale_fac && scale_fac < 1.0);
assert(0.0 < lambda_max && lambda_max <= 1.0 + (-0.15859433956303937)); // sc.lambertw(-exp(-2)) in Python
assert(1.0 < xi);
Expand Down Expand Up @@ -213,15 +212,7 @@ double countable_discrete_base::root_brentq(
}
else { mflag = false; }

size_t memo_size = memo.size();
fs = f(s);
while (memo.size() > memo_size)
{
memo_size = memo.size();
fa = f(a);
fb = f(b);
fs = f(s);
}
d = c;
c = b;
fc = fb;
Expand All @@ -244,7 +235,8 @@ double countable_discrete_base::root_brentq(
}
}

return s;
// Returning lower estimate of location of root
return std::min(a, b);
}

double countable_discrete_base::lb_log_wealth(double alpha) const
Expand All @@ -265,21 +257,6 @@ double countable_discrete_base::lb_log_wealth(double alpha) const
return root_brentq(s, thres, memo, min_mu, max_mu);
}

double countable_discrete_base::polylog(double r, double eta) const
{
double ret_val = 0.0;
double min_thres = 1e-10;
double curr_val = std::numeric_limits<int>::max();
uint64_t k = 0;
while (curr_val > min_thres)
{
k += 1;
curr_val = std::pow(eta, k) / std::pow(k, r);
ret_val += curr_val;
}
return ret_val;
}

double countable_discrete_base::get_log_weight(double j) const { return log_scale_fac + log_xi_m1 - (1 + j) * log_xi; }

double countable_discrete_base::get_log_remaining_weight(double j) const { return log_scale_fac - j * log_xi; }
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/core/src/parse_args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,8 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti
.help("Ignore namespaces beginning with character <arg> for linear terms only"))
.add(make_option("ignore_features_dsjson_experimental", ignore_features_dsjson)
.keep()
.help("Ignore specified features from namespace. To ignore a feature arg should be namespace|feature "
"To ignore a feature in the default namespace, arg should be |feature")
.help("Ignore specified features from namespace. To ignore a feature arg should be "
"<namespace>|<feature>. <namespace> should be empty for default")
.experimental())
.add(make_option("keep", keeps).keep().help("Keep namespaces beginning with character <arg>"))
.add(make_option("redefine", redefines)
Expand Down
51 changes: 21 additions & 30 deletions vowpalwabbit/core/src/reductions/gd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -994,48 +994,39 @@ void save_load_online_state_weights(VW::workspace& all, VW::io_buf& model_file,
for (typename T::iterator v = weights.begin(); v != weights.end(); ++v)
{
i = v.index() >> weights.stride_shift();

bool gd_write = *v != 0.f;
bool ftrl3_write = ftrl_size == 3 && (*v != 0.f || (&(*v))[1] != 0.f || (&(*v))[2] != 0.f);
bool ftrl4_write = ftrl_size == 4 && (*v != 0.f || (&(*v))[1] != 0.f || (&(*v))[2] != 0.f || (&(*v))[3] != 0.f);
bool ftrl6_write = ftrl_size == 6 &&
(*v != 0.f || (&(*v))[1] != 0.f || (&(*v))[2] != 0.f || (&(*v))[3] != 0.f || (&(*v))[4] != 0.f ||
(&(*v))[5] != 0.f);
if (all.print_invert) // write readable model with feature names
{
if (*v != 0.f)
if (gd_write || ftrl3_write || ftrl4_write || ftrl6_write)
{
const auto map_it = all.index_name_map.find(i);
if (map_it != all.index_name_map.end())
{
msg << to_string(map_it->second) << ":";
VW::details::bin_text_write_fixed(model_file, nullptr /*unused*/, 0 /*unused*/, msg, true);
}
if (map_it != all.index_name_map.end()) { msg << to_string(map_it->second) << ":"; }
}
}

if (ftrl_size == 3)
if (ftrl3_write)
{
if (*v != 0. || (&(*v))[1] != 0. || (&(*v))[2] != 0.)
{
brw = write_index(model_file, msg, text, all.num_bits, i);
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << "\n";
brw += VW::details::bin_text_write_fixed(model_file, (char*)&(*v), 3 * sizeof(*v), msg, text);
}
brw = write_index(model_file, msg, text, all.num_bits, i);
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << "\n";
brw += VW::details::bin_text_write_fixed(model_file, (char*)&(*v), 3 * sizeof(*v), msg, text);
}
else if (ftrl_size == 4)
else if (ftrl4_write)
{
if (*v != 0. || (&(*v))[1] != 0. || (&(*v))[2] != 0. || (&(*v))[3] != 0.)
{
brw = write_index(model_file, msg, text, all.num_bits, i);
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << "\n";
brw += VW::details::bin_text_write_fixed(model_file, (char*)&(*v), 4 * sizeof(*v), msg, text);
}
brw = write_index(model_file, msg, text, all.num_bits, i);
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << "\n";
brw += VW::details::bin_text_write_fixed(model_file, (char*)&(*v), 4 * sizeof(*v), msg, text);
}
else if (ftrl_size == 6)
else if (ftrl6_write)
{
if (*v != 0. || (&(*v))[1] != 0. || (&(*v))[2] != 0. || (&(*v))[3] != 0. || (&(*v))[4] != 0. ||
(&(*v))[5] != 0.)
{
brw = write_index(model_file, msg, text, all.num_bits, i);
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << " " << (&(*v))[4] << " "
<< (&(*v))[5] << "\n";
brw += VW::details::bin_text_write_fixed(model_file, (char*)&(*v), 6 * sizeof(*v), msg, text);
}
brw = write_index(model_file, msg, text, all.num_bits, i);
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << " " << (&(*v))[4] << " "
<< (&(*v))[5] << "\n";
brw += VW::details::bin_text_write_fixed(model_file, (char*)&(*v), 6 * sizeof(*v), msg, text);
}
else if (g == nullptr || (!all.weights.adaptive && !all.weights.normalized))
{
Expand Down
18 changes: 8 additions & 10 deletions vowpalwabbit/core/tests/automl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -328,11 +328,9 @@ TEST(Automl, NamespaceSwitchWIterations)

auto champ_exclusions =
aml->cm->_config_oracle.configs[aml->cm->estimators[aml->cm->current_champ].first.config_index].elements;
EXPECT_EQ(champ_exclusions.size(), 1);
std::vector<VW::namespace_index> ans{'U', 'U'};
EXPECT_NE(champ_exclusions.find(ans), champ_exclusions.end());
EXPECT_EQ(champ_exclusions.size(), 0);
auto champ_interactions = aml->cm->estimators[aml->cm->current_champ].first.live_interactions;
EXPECT_EQ(champ_interactions.size(), 5);
EXPECT_EQ(champ_interactions.size(), 6);
return true;
});

Expand Down Expand Up @@ -361,9 +359,9 @@ TEST(Automl, ClearConfigsWIterations)
EXPECT_EQ(aml->cm->current_champ, 0);
EXPECT_EQ(aml->cm->_config_oracle.valid_config_size, 4);
EXPECT_EQ(clear_champ_switch - 1, aml->cm->total_learn_count);
EXPECT_EQ(aml->cm->estimators[0].first.live_interactions.size(), 2);
EXPECT_EQ(aml->cm->estimators[1].first.live_interactions.size(), 3);
EXPECT_EQ(aml->cm->estimators[2].first.live_interactions.size(), 1);
EXPECT_EQ(aml->cm->estimators[0].first.live_interactions.size(), 3);
EXPECT_EQ(aml->cm->estimators[1].first.live_interactions.size(), 2);
EXPECT_EQ(aml->cm->estimators[2].first.live_interactions.size(), 2);
EXPECT_EQ(aml->current_state, VW::reductions::automl::automl_state::Experimenting);
return true;
});
Expand Down Expand Up @@ -434,9 +432,9 @@ TEST(Automl, ClearConfigsOneDiffWIterations)
{
aml_test::aml_onediff* aml = aml_test::get_automl_data<VW::reductions::automl::one_diff_impl>(all);
EXPECT_EQ(aml->cm->estimators.size(), 3);
EXPECT_EQ(aml->cm->estimators[0].first.live_interactions.size(), 2);
EXPECT_EQ(aml->cm->estimators[1].first.live_interactions.size(), 3);
EXPECT_EQ(aml->cm->estimators[2].first.live_interactions.size(), 1);
EXPECT_EQ(aml->cm->estimators[0].first.live_interactions.size(), 3);
EXPECT_EQ(aml->cm->estimators[1].first.live_interactions.size(), 2);
EXPECT_EQ(aml->cm->estimators[2].first.live_interactions.size(), 2);
return true;
});

Expand Down
9 changes: 5 additions & 4 deletions vowpalwabbit/core/tests/confidence_sequence_robust_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,14 @@ TEST(ConfidenceSequenceRobust, PythonEquivalenceCI)
csr.update(0.879389936, 1);
csr.update(0.880083546, 1);
csr.update(0.8807676, 1);
csr.update(43.91880613, 0.9);

double lb = csr.lower_bound();
double ub = csr.upper_bound();

// Compare to test_confidence_sequence_robust.py
EXPECT_NEAR(lb, 0.19224909776696902, 1e-6);
EXPECT_NEAR(ub, 0.93637288977188682, 1e-6);
EXPECT_NEAR(lb, 0.30209151281846858, 1e-6);
EXPECT_NEAR(ub, 0.90219143188334106, 1e-6);

// Test brentq separately
double s = 139.8326745448;
Expand All @@ -118,9 +119,9 @@ TEST(ConfidenceSequenceRobust, PythonEquivalenceCI)
double root = csr.lower.root_brentq(s, thres, memo, min_mu, max_mu);

// Compare root to test_confidence_sequence_robust.py
EXPECT_NEAR(root, 0.31672263211621371, 1e-6);
EXPECT_NEAR(root, 0.30209143008131789, 1e-6);

// Test that root of objective function is 0
auto test_root = csr.lower.log_wealth_mix(root, s, thres, memo) - thres;
EXPECT_NEAR(test_root, 0.0, 1e-6);
EXPECT_NEAR(test_root, 0.0, 1e-2);
}
38 changes: 18 additions & 20 deletions vowpalwabbit/core/tests/epsilon_decay_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,29 +60,28 @@ TEST(EpsilonDecay, InitWIterations)
"--epsilon_decay", "--model_count=3", "--cb_explore_adf", "--quiet", "--epsilon=0.2", "--random_seed=5"});
}

#if !defined(__APPLE__) && !defined(_WIN32)
TEST(EpsilonDecay, ChampChangeWIterations)
{
const size_t num_iterations = 610;
const size_t seed = 4;
const size_t seed = 36;
const std::vector<uint64_t> swap_after = {500};
const size_t deterministic_champ_switch = 601;
const size_t deterministic_champ_switch = 600;
callback_map test_hooks;

test_hooks.emplace(deterministic_champ_switch - 1,
[&](cb_sim&, VW::workspace& all, VW::multi_ex&)
{
epsilon_decay_data* epsilon_decay = epsilon_decay_test::get_epsilon_decay_data(all);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[0][0].update_count, 88);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[1][0].update_count, 88);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[2][0].update_count, 88);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][0].update_count, 88);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[1][1].update_count, 93);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[2][1].update_count, 93);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][1].update_count, 93);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[2][2].update_count, 94);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][2].update_count, 94);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][3].update_count, 600);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[0][0].update_count, 87);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[1][0].update_count, 87);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[2][0].update_count, 87);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][0].update_count, 87);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[1][1].update_count, 92);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[2][1].update_count, 92);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][1].update_count, 92);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[2][2].update_count, 93);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][2].update_count, 93);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][3].update_count, 599);
return true;
});

Expand All @@ -94,12 +93,12 @@ TEST(EpsilonDecay, ChampChangeWIterations)
EXPECT_EQ(epsilon_decay->conf_seq_estimators[1][0].update_count, 0);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[2][0].update_count, 0);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][0].update_count, 0);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[1][1].update_count, 0);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[2][1].update_count, 0);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][1].update_count, 0);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[2][2].update_count, 0);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][2].update_count, 0);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][3].update_count, 89);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[1][1].update_count, 88);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[2][1].update_count, 88);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][1].update_count, 88);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[2][2].update_count, 93);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][2].update_count, 93);
EXPECT_EQ(epsilon_decay->conf_seq_estimators[3][3].update_count, 94);
return true;
});

Expand All @@ -110,7 +109,6 @@ TEST(EpsilonDecay, ChampChangeWIterations)

EXPECT_GT(ctr.back(), 0.4f);
}
#endif

TEST(EpsilonDecay, UpdateCountWIterations)
{
Expand Down

0 comments on commit 7eace7c

Please sign in to comment.