diff --git a/test/unit_test/automl_test.cc b/test/unit_test/automl_test.cc index 824a96b8e70..86d1d875d2c 100644 --- a/test/unit_test/automl_test.cc +++ b/test/unit_test/automl_test.cc @@ -677,16 +677,16 @@ BOOST_AUTO_TEST_CASE(qbase_unittest_w_iterations) const set_ns_list_t excl_0{}; BOOST_CHECK_EQUAL_COLLECTIONS( configs[0].elements.begin(), configs[0].elements.end(), excl_0.begin(), excl_0.end()); - const set_ns_list_t excl_1{{'A', 'B', 'B'}}; + const set_ns_list_t excl_1{{'A', 'A', 'A'}}; BOOST_CHECK_EQUAL_COLLECTIONS( configs[1].elements.begin(), configs[1].elements.end(), excl_1.begin(), excl_1.end()); - const set_ns_list_t excl_2{{'A', 'A', 'A'}}; + const set_ns_list_t excl_2{{'A', 'B', 'C'}}; BOOST_CHECK_EQUAL_COLLECTIONS( configs[2].elements.begin(), configs[2].elements.end(), excl_2.begin(), excl_2.end()); - const set_ns_list_t excl_3{{'A', 'A', 'C'}}; + const set_ns_list_t excl_3{{'A', 'B', 'B'}}; BOOST_CHECK_EQUAL_COLLECTIONS( configs[3].elements.begin(), configs[3].elements.end(), excl_3.begin(), excl_3.end()); - const set_ns_list_t excl_9{{'C', 'C', 'C'}}; + const set_ns_list_t excl_9{{'A', 'A', 'C'}}; BOOST_CHECK_EQUAL_COLLECTIONS( configs[9].elements.begin(), configs[9].elements.end(), excl_9.begin(), excl_9.end()); @@ -717,10 +717,10 @@ BOOST_AUTO_TEST_CASE(qbase_unittest_w_iterations) BOOST_CHECK_EQUAL(oracle.valid_config_size, 11); BOOST_CHECK_EQUAL(configs.size(), 11); - const set_ns_list_t excl_4{{'A', 'B', 'B'}, {'C', 'C', 'C'}}; + const set_ns_list_t excl_4{{'A', 'A', 'A'}, {'A', 'A', 'C'}}; BOOST_CHECK_EQUAL_COLLECTIONS( configs[2].elements.begin(), configs[2].elements.end(), excl_4.begin(), excl_4.end()); - const set_ns_list_t excl_5{{'A', 'A', 'A'}, {'C', 'C', 'C'}}; + const set_ns_list_t excl_5{{'A', 'A', 'C'}, {'A', 'C', 'C'}}; BOOST_CHECK_EQUAL_COLLECTIONS( configs[3].elements.begin(), configs[3].elements.end(), excl_5.begin(), excl_5.end()); @@ -741,7 +741,7 @@ BOOST_AUTO_TEST_CASE(qbase_unittest_w_iterations) BOOST_CHECK_EQUAL(prio_queue.size(), 7); // excl_7 is now champ - const set_ns_list_t excl_7{{'A', 'C', 'C'}, {'C', 'C', 'C'}}; + const set_ns_list_t excl_7{{'A', 'A', 'C'}, {'A', 'B', 'B'}}; interaction_config_manager, VW::confidence_sequence>::apply_new_champ( oracle, 3, estimators, 0, ns_counter); @@ -749,7 +749,7 @@ BOOST_AUTO_TEST_CASE(qbase_unittest_w_iterations) configs[0].elements.begin(), configs[0].elements.end(), excl_7.begin(), excl_7.end()); BOOST_CHECK_EQUAL_COLLECTIONS( configs[1].elements.begin(), configs[1].elements.end(), excl_9.begin(), excl_9.end()); - const set_ns_list_t excl_6{{'A', 'B', 'B'}, {'A', 'C', 'C'}, {'C', 'C', 'C'}}; + const set_ns_list_t excl_6{{'A', 'A', 'A'}, {'A', 'A', 'C'}, {'A', 'B', 'C'}}; BOOST_CHECK_EQUAL_COLLECTIONS( configs[2].elements.begin(), configs[2].elements.end(), excl_6.begin(), excl_6.end()); diff --git a/vowpalwabbit/core/src/reductions/details/automl/automl_oracle.cc b/vowpalwabbit/core/src/reductions/details/automl/automl_oracle.cc index f9e9ac1df52..212bfdf760d 100644 --- a/vowpalwabbit/core/src/reductions/details/automl/automl_oracle.cc +++ b/vowpalwabbit/core/src/reductions/details/automl/automl_oracle.cc @@ -345,6 +345,19 @@ void config_oracle::gen_configs( } } +class RNGWrapper +{ +public: + RNGWrapper(VW::rand_state* random_state) : _random_state(random_state) {} + typedef size_t result_type; + static size_t min() { return 0; } + static size_t max() { return 42; } + size_t operator()() { return max() * _random_state->get_and_update_random(); } + +private: + VW::rand_state* _random_state; +}; + template <> void config_oracle::gen_configs( const interaction_vec_t&, const std::map& ns_counter) @@ -362,7 +375,7 @@ void config_oracle::gen_configs( for (size_t i = 0; i < _impl.total_space.size(); i++) { indexes.push_back(i); } - std::shuffle(indexes.begin(), indexes.end(), std::default_random_engine(_impl.random_state->get_current_state())); + std::shuffle(indexes.begin(), indexes.end(), RNGWrapper(_impl.random_state.get())); for (std::vector::iterator it = indexes.begin(); it != indexes.end(); ++it) {