From 4b617e932616d1f68bfc4da2159ec2722f6fdb45 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Mon, 30 Jan 2023 17:40:55 -0500 Subject: [PATCH 1/3] add metrics + LAS test which is expected to fail --- .../core/tests/cb_las_one_pass_svd_test.cc | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc b/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc index cde7ee943ab..5e5a2eb9c23 100644 --- a/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc +++ b/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc @@ -21,6 +21,35 @@ using internal_action_space_op = VW::cb_explore_adf::cb_explore_adf_base>; +TEST(Las, CheckMatricsWithLASRunsOK) +{ + auto d = 3; + std::vector args{"--cb_explore_adf", "--large_action_space", "--max_actions", std::to_string(d), + "--quiet", "--extra_metrics", "las_metrics.json"}; + auto vw = VW::initialize(VW::make_unique(args)); + + VW::multi_ex examples; + + examples.push_back(VW::read_example(*vw, "1:1:0.1 | 1:0.1 2:0.12 3:0.13")); + examples.push_back(VW::read_example(*vw, "| a_1:0.5 a_2:0.65 a_3:0.12")); + examples.push_back(VW::read_example(*vw, "| a_4:0.8 a_5:0.32 a_6:0.15")); + examples.push_back(VW::read_example(*vw, "| a_7 a_8 a_9")); + examples.push_back(VW::read_example(*vw, "| a_10 a_11 a_12")); + examples.push_back(VW::read_example(*vw, "| a_13 a_14 a_15")); + examples.push_back(VW::read_example(*vw, "| a_16 a_17 a_18")); + + vw->learn(examples); + + auto num_actions = examples[0]->pred.a_s.size(); + + EXPECT_EQ(num_actions, 7); + + vw->finish_example(examples); + + auto metrics = vw->global_metrics.collect_metrics(vw->l); + EXPECT_EQ(metrics.get_uint("cbea_labeled_ex"), 1); +} + TEST(Las, CheckAOSameActionsSameRepresentation) { auto d = 3; From 7f8bb62cac5da3dbb394d1fae17836cb09381fdb Mon Sep 17 00:00:00 2001 From: olgavrou Date: Mon, 30 Jan 2023 17:46:08 -0500 Subject: [PATCH 2/3] add persist metrics func to LAS that does nothing --- .../src/reductions/cb/cb_explore_adf_large_action_space.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc index 1665f0f8ed0..dde9ad76e3e 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc @@ -249,6 +249,11 @@ void generate_Z(const multi_ex& examples, Eigen::MatrixXf& Z, Eigen::MatrixXf& B VW::gram_schmidt(Z); } +template +void persist_metrics(cb_explore_adf_base>&, VW::metric_sink&) +{ +} + template cb_explore_adf_large_action_space::cb_explore_adf_large_action_space(uint64_t d, float gamma_scale, float gamma_exponent, float c, bool apply_shrink_factor, VW::workspace* all, uint64_t seed, size_t total_size, @@ -315,7 +320,7 @@ VW::LEARNER::base_learner* make_las_with_impl(VW::setup_base_i& stack_builder, V .set_output_example_prediction(explore_type::output_example_prediction) .set_update_stats(explore_type::update_stats) .set_print_update(explore_type::print_update) - .set_persist_metrics(explore_type::persist_metrics) + .set_persist_metrics(VW::cb_explore_adf::persist_metrics) .set_save_load(explore_type::save_load) .set_learn_returns_prediction(base->learn_returns_prediction) .build(); From c32f494db5428d7f30a473911dd605d4ed223759 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Tue, 31 Jan 2023 01:26:37 -0500 Subject: [PATCH 3/3] las don't template cb adf common --- .../cb/cb_explore_adf_large_action_space.cc | 104 +++++++++--------- .../cb/details/large_action_space.h | 3 - .../core/tests/cb_las_one_pass_svd_test.cc | 1 + 3 files changed, 53 insertions(+), 55 deletions(-) diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc index dde9ad76e3e..400c4722981 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc @@ -112,20 +112,6 @@ bool _test_only_generate_A(VW::workspace* _all, const multi_ex& examples, std::v return (_A.cols() != 0 && _A.rows() != 0); } -template -void cb_explore_adf_large_action_space::predict( - VW::LEARNER::multi_learner& base, multi_ex& examples) -{ - predict_or_learn_impl(base, examples); -} - -template -void cb_explore_adf_large_action_space::learn( - VW::LEARNER::multi_learner& base, multi_ex& examples) -{ - predict_or_learn_impl(base, examples); -} - template void cb_explore_adf_large_action_space::save_load(io_buf& io, bool read, bool text) { @@ -208,21 +194,20 @@ void cb_explore_adf_large_action_space::updat } template -template -void cb_explore_adf_large_action_space::predict_or_learn_impl( +void cb_explore_adf_large_action_space::predict( VW::LEARNER::multi_learner& base, multi_ex& examples) { - if (is_learn) - { - base.learn(examples); - if (base.learn_returns_prediction) { update_example_prediction(examples); } - ++_counter; - } - else - { - base.predict(examples); - update_example_prediction(examples); - } + base.predict(examples); + update_example_prediction(examples); +} + +template +void cb_explore_adf_large_action_space::learn( + VW::LEARNER::multi_learner& base, multi_ex& examples) +{ + base.learn(examples); + if (base.learn_returns_prediction) { update_example_prediction(examples); } + ++_counter; } void generate_Z(const multi_ex& examples, Eigen::MatrixXf& Z, Eigen::MatrixXf& B, uint64_t d, uint64_t seed) @@ -249,11 +234,6 @@ void generate_Z(const multi_ex& examples, Eigen::MatrixXf& Z, Eigen::MatrixXf& B VW::gram_schmidt(Z); } -template -void persist_metrics(cb_explore_adf_base>&, VW::metric_sink&) -{ -} - template cb_explore_adf_large_action_space::cb_explore_adf_large_action_space(uint64_t d, float gamma_scale, float gamma_exponent, float c, bool apply_shrink_factor, VW::workspace* all, uint64_t seed, size_t total_size, @@ -295,37 +275,58 @@ template class cb_explore_adf_large_action_space +void persist_metrics(cb_explore_adf_large_action_space& data, VW::metric_sink& metrics) +{ + metrics.set_uint("cb_las_filtering_factor", data.number_of_non_degenerate_singular_values()); +} + template -VW::LEARNER::base_learner* make_las_with_impl(VW::setup_base_i& stack_builder, VW::LEARNER::multi_learner* base, - implementation_type& impl_type, VW::workspace& all, bool with_metrics, uint64_t d, float gamma_scale, - float gamma_exponent, float c, bool apply_shrink_factor, size_t thread_pool_size, size_t block_size, - bool use_explicit_simd) +void save_load(cb_explore_adf_large_action_space& data, VW::io_buf& io, bool read, bool text) { - using explore_type = cb_explore_adf_base>; + data.save_load(io, read, text); +} +template +void predict(cb_explore_adf_large_action_space& data, VW::LEARNER::multi_learner& base, VW::multi_ex& examples) +{ + data.predict(base, examples); +} + +template +void learn(cb_explore_adf_large_action_space& data, VW::LEARNER::multi_learner& base, VW::multi_ex& examples) +{ + data.learn(base, examples); +} + +template +VW::LEARNER::base_learner* make_las_with_impl(VW::setup_base_i& stack_builder, VW::LEARNER::multi_learner* base, + implementation_type& impl_type, VW::workspace& all, uint64_t d, float gamma_scale, float gamma_exponent, float c, + bool apply_shrink_factor, size_t thread_pool_size, size_t block_size, bool use_explicit_simd) +{ size_t problem_multiplier = 1; float seed = (all.get_random_state()->get_random() + 1) * 10.f; - auto data = VW::make_unique(with_metrics, d, gamma_scale, gamma_exponent, c, apply_shrink_factor, &all, - seed, 1 << all.num_bits, thread_pool_size, block_size, use_explicit_simd, impl_type); + auto data = VW::make_unique>(d, gamma_scale, gamma_exponent, c, + apply_shrink_factor, &all, seed, 1 << all.num_bits, thread_pool_size, block_size, use_explicit_simd, impl_type); - auto* l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, + auto* l = make_reduction_learner(std::move(data), base, learn, predict, stack_builder.get_setupfn_name(VW::reductions::cb_explore_adf_large_action_space_setup)) .set_input_label_type(VW::label_type_t::CB) .set_output_label_type(VW::label_type_t::CB) .set_input_prediction_type(VW::prediction_type_t::ACTION_SCORES) .set_output_prediction_type(VW::prediction_type_t::ACTION_SCORES) .set_params_per_weight(problem_multiplier) - .set_output_example_prediction(explore_type::output_example_prediction) - .set_update_stats(explore_type::update_stats) - .set_print_update(explore_type::print_update) - .set_persist_metrics(VW::cb_explore_adf::persist_metrics) - .set_save_load(explore_type::save_load) + .set_persist_metrics(persist_metrics) + .set_save_load(save_load) .set_learn_returns_prediction(base->learn_returns_prediction) .build(); - return make_base(*l); + return VW::LEARNER::make_base(*l); } +} // namespace VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_setup(VW::setup_base_i& stack_builder) { @@ -413,16 +414,15 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_set if (use_two_pass_svd_impl) { auto impl_type = implementation_type::two_pass_svd; - return make_las_with_impl(stack_builder, base, impl_type, all, - all.global_metrics.are_metrics_enabled(), d, gamma_scale, gamma_exponent, c, apply_shrink_factor, - thread_pool_size, block_size, + return make_las_with_impl(stack_builder, base, impl_type, all, d, + gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size, /*use_explicit_simd=*/false); } else { auto impl_type = implementation_type::one_pass_svd; - return make_las_with_impl(stack_builder, base, impl_type, all, - all.global_metrics.are_metrics_enabled(), d, gamma_scale, gamma_exponent, c, apply_shrink_factor, - thread_pool_size, block_size, use_simd_in_one_pass_svd_impl); + return make_las_with_impl(stack_builder, base, impl_type, all, d, + gamma_scale, gamma_exponent, c, apply_shrink_factor, thread_pool_size, block_size, + use_simd_in_one_pass_svd_impl); } } diff --git a/vowpalwabbit/core/src/reductions/cb/details/large_action_space.h b/vowpalwabbit/core/src/reductions/cb/details/large_action_space.h index bd43b081e31..aa465a5fe00 100644 --- a/vowpalwabbit/core/src/reductions/cb/details/large_action_space.h +++ b/vowpalwabbit/core/src/reductions/cb/details/large_action_space.h @@ -160,7 +160,6 @@ class cb_explore_adf_large_action_space ~cb_explore_adf_large_action_space() = default; void save_load(io_buf& io, bool read, bool text); - // Should be called through cb_explore_adf_base for pre/post-processing void predict(VW::LEARNER::multi_learner& base, multi_ex& examples); void learn(VW::LEARNER::multi_learner& base, multi_ex& examples); @@ -183,8 +182,6 @@ class cb_explore_adf_large_action_space } private: - template - void predict_or_learn_impl(VW::LEARNER::multi_learner& base, multi_ex& examples); void update_example_prediction(VW::multi_ex& examples); }; diff --git a/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc b/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc index 5e5a2eb9c23..dfb6b3477f7 100644 --- a/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc +++ b/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc @@ -48,6 +48,7 @@ TEST(Las, CheckMatricsWithLASRunsOK) auto metrics = vw->global_metrics.collect_metrics(vw->l); EXPECT_EQ(metrics.get_uint("cbea_labeled_ex"), 1); + EXPECT_EQ(metrics.get_uint("cb_las_filtering_factor"), 5); } TEST(Las, CheckAOSameActionsSameRepresentation)