Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: [LAS] LAS not a cb adf common reduction, fixes metrics with LAS bug #4476

Merged
merged 3 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -112,20 +112,6 @@ bool _test_only_generate_A(VW::workspace* _all, const multi_ex& examples, std::v
return (_A.cols() != 0 && _A.rows() != 0);
}

template <typename randomized_svd_impl, typename spanner_impl>
void cb_explore_adf_large_action_space<randomized_svd_impl, spanner_impl>::predict(
VW::LEARNER::multi_learner& base, multi_ex& examples)
{
predict_or_learn_impl<false>(base, examples);
}

template <typename randomized_svd_impl, typename spanner_impl>
void cb_explore_adf_large_action_space<randomized_svd_impl, spanner_impl>::learn(
VW::LEARNER::multi_learner& base, multi_ex& examples)
{
predict_or_learn_impl<true>(base, examples);
}

template <typename randomized_svd_impl, typename spanner_impl>
void cb_explore_adf_large_action_space<randomized_svd_impl, spanner_impl>::save_load(io_buf& io, bool read, bool text)
{
Expand Down Expand Up @@ -208,21 +194,20 @@ void cb_explore_adf_large_action_space<randomized_svd_impl, spanner_impl>::updat
}

template <typename randomized_svd_impl, typename spanner_impl>
template <bool is_learn>
void cb_explore_adf_large_action_space<randomized_svd_impl, spanner_impl>::predict_or_learn_impl(
void cb_explore_adf_large_action_space<randomized_svd_impl, spanner_impl>::predict(
VW::LEARNER::multi_learner& base, multi_ex& examples)
{
if (is_learn)
{
base.learn(examples);
if (base.learn_returns_prediction) { update_example_prediction(examples); }
++_counter;
}
else
{
base.predict(examples);
update_example_prediction(examples);
}
base.predict(examples);
update_example_prediction(examples);
}

template <typename randomized_svd_impl, typename spanner_impl>
void cb_explore_adf_large_action_space<randomized_svd_impl, spanner_impl>::learn(
VW::LEARNER::multi_learner& base, multi_ex& examples)
{
base.learn(examples);
if (base.learn_returns_prediction) { update_example_prediction(examples); }
++_counter;
}

void generate_Z(const multi_ex& examples, Eigen::MatrixXf& Z, Eigen::MatrixXf& B, uint64_t d, uint64_t seed)
Expand Down Expand Up @@ -290,37 +275,58 @@ template class cb_explore_adf_large_action_space<two_pass_svd_impl, one_rank_spa
} // namespace cb_explore_adf
} // namespace VW

namespace
{
template <typename T, typename S>
VW::LEARNER::base_learner* make_las_with_impl(VW::setup_base_i& stack_builder, VW::LEARNER::multi_learner* base,
implementation_type& impl_type, VW::workspace& all, bool with_metrics, uint64_t d, float gamma_scale,
float gamma_exponent, float c, bool apply_shrink_factor, size_t thread_pool_size, size_t block_size,
bool use_explicit_simd)
void persist_metrics(cb_explore_adf_large_action_space<T, S>& data, VW::metric_sink& metrics)
{
using explore_type = cb_explore_adf_base<cb_explore_adf_large_action_space<T, S>>;
metrics.set_uint("cb_las_filtering_factor", data.number_of_non_degenerate_singular_values());
}

template <typename T, typename S>
void save_load(cb_explore_adf_large_action_space<T, S>& data, VW::io_buf& io, bool read, bool text)
{
data.save_load(io, read, text);
}

template <typename T, typename S>
void predict(cb_explore_adf_large_action_space<T, S>& data, VW::LEARNER::multi_learner& base, VW::multi_ex& examples)
{
data.predict(base, examples);
}

template <typename T, typename S>
void learn(cb_explore_adf_large_action_space<T, S>& data, VW::LEARNER::multi_learner& base, VW::multi_ex& examples)
{
data.learn(base, examples);
}

template <typename T, typename S>
VW::LEARNER::base_learner* make_las_with_impl(VW::setup_base_i& stack_builder, VW::LEARNER::multi_learner* base,
implementation_type& impl_type, VW::workspace& all, uint64_t d, float gamma_scale, float gamma_exponent, float c,
bool apply_shrink_factor, size_t thread_pool_size, size_t block_size, bool use_explicit_simd)
{
size_t problem_multiplier = 1;

float seed = (all.get_random_state()->get_random() + 1) * 10.f;

auto data = VW::make_unique<explore_type>(with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, &all,
seed, 1 << all.num_bits, thread_pool_size, block_size, use_explicit_simd, impl_type);
auto data = VW::make_unique<cb_explore_adf_large_action_space<T, S>>(d, gamma_scale, gamma_exponent, c,
apply_shrink_factor, &all, seed, 1 << all.num_bits, thread_pool_size, block_size, use_explicit_simd, impl_type);

auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
auto* l = make_reduction_learner(std::move(data), base, learn<T, S>, predict<T, S>,
stack_builder.get_setupfn_name(VW::reductions::cb_explore_adf_large_action_space_setup))
.set_input_label_type(VW::label_type_t::CB)
.set_output_label_type(VW::label_type_t::CB)
.set_input_prediction_type(VW::prediction_type_t::ACTION_SCORES)
.set_output_prediction_type(VW::prediction_type_t::ACTION_SCORES)
.set_params_per_weight(problem_multiplier)
.set_output_example_prediction(explore_type::output_example_prediction)
.set_update_stats(explore_type::update_stats)
.set_print_update(explore_type::print_update)
.set_persist_metrics(explore_type::persist_metrics)
.set_save_load(explore_type::save_load)
.set_persist_metrics(persist_metrics<T, S>)
.set_save_load(save_load<T, S>)
.set_learn_returns_prediction(base->learn_returns_prediction)
.build();
return make_base(*l);
return VW::LEARNER::make_base(*l);
}
} // namespace

VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_setup(VW::setup_base_i& stack_builder)
{
Expand Down Expand Up @@ -408,16 +414,15 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_set
if (use_two_pass_svd_impl)
{
auto impl_type = implementation_type::two_pass_svd;
return make_las_with_impl<two_pass_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all,
all.global_metrics.are_metrics_enabled(), d, gamma_scale, gamma_exponent, c, apply_shrink_factor,
thread_pool_size, block_size,
return make_las_with_impl<two_pass_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all, d,
gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size,
/*use_explicit_simd=*/false);
}
else
{
auto impl_type = implementation_type::one_pass_svd;
return make_las_with_impl<one_pass_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all,
all.global_metrics.are_metrics_enabled(), d, gamma_scale, gamma_exponent, c, apply_shrink_factor,
thread_pool_size, block_size, use_simd_in_one_pass_svd_impl);
return make_las_with_impl<one_pass_svd_impl, one_rank_spanner_state>(stack_builder, base, impl_type, all, d,
gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size,
use_simd_in_one_pass_svd_impl);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ class cb_explore_adf_large_action_space
~cb_explore_adf_large_action_space() = default;

void save_load(io_buf& io, bool read, bool text);
// Should be called through cb_explore_adf_base for pre/post-processing
void predict(VW::LEARNER::multi_learner& base, multi_ex& examples);
void learn(VW::LEARNER::multi_learner& base, multi_ex& examples);

Expand All @@ -183,8 +182,6 @@ class cb_explore_adf_large_action_space
}

private:
template <bool is_learn>
void predict_or_learn_impl(VW::LEARNER::multi_learner& base, multi_ex& examples);
void update_example_prediction(VW::multi_ex& examples);
};

Expand Down
30 changes: 30 additions & 0 deletions vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,36 @@ using internal_action_space_op =
VW::cb_explore_adf::cb_explore_adf_base<VW::cb_explore_adf::cb_explore_adf_large_action_space<
VW::cb_explore_adf::one_pass_svd_impl, VW::cb_explore_adf::one_rank_spanner_state>>;

TEST(Las, CheckMatricsWithLASRunsOK)
{
auto d = 3;
std::vector<std::string> args{"--cb_explore_adf", "--large_action_space", "--max_actions", std::to_string(d),
"--quiet", "--extra_metrics", "las_metrics.json"};
auto vw = VW::initialize(VW::make_unique<VW::config::options_cli>(args));

VW::multi_ex examples;

examples.push_back(VW::read_example(*vw, "1:1:0.1 | 1:0.1 2:0.12 3:0.13"));
examples.push_back(VW::read_example(*vw, "| a_1:0.5 a_2:0.65 a_3:0.12"));
examples.push_back(VW::read_example(*vw, "| a_4:0.8 a_5:0.32 a_6:0.15"));
examples.push_back(VW::read_example(*vw, "| a_7 a_8 a_9"));
examples.push_back(VW::read_example(*vw, "| a_10 a_11 a_12"));
examples.push_back(VW::read_example(*vw, "| a_13 a_14 a_15"));
examples.push_back(VW::read_example(*vw, "| a_16 a_17 a_18"));

vw->learn(examples);

auto num_actions = examples[0]->pred.a_s.size();

EXPECT_EQ(num_actions, 7);

vw->finish_example(examples);

auto metrics = vw->global_metrics.collect_metrics(vw->l);
EXPECT_EQ(metrics.get_uint("cbea_labeled_ex"), 1);
EXPECT_EQ(metrics.get_uint("cb_las_filtering_factor"), 5);
}

TEST(Las, CheckAOSameActionsSameRepresentation)
{
auto d = 3;
Expand Down