Skip to content

Commit

Permalink
use rand_state for shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
lalo committed Dec 15, 2022
1 parent aa8e518 commit e9a476c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
16 changes: 8 additions & 8 deletions test/unit_test/automl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -677,16 +677,16 @@ BOOST_AUTO_TEST_CASE(qbase_unittest_w_iterations)
const set_ns_list_t excl_0{};
BOOST_CHECK_EQUAL_COLLECTIONS(
configs[0].elements.begin(), configs[0].elements.end(), excl_0.begin(), excl_0.end());
const set_ns_list_t excl_1{{'A', 'B', 'B'}};
const set_ns_list_t excl_1{{'A', 'A', 'A'}};
BOOST_CHECK_EQUAL_COLLECTIONS(
configs[1].elements.begin(), configs[1].elements.end(), excl_1.begin(), excl_1.end());
const set_ns_list_t excl_2{{'A', 'A', 'A'}};
const set_ns_list_t excl_2{{'A', 'B', 'C'}};
BOOST_CHECK_EQUAL_COLLECTIONS(
configs[2].elements.begin(), configs[2].elements.end(), excl_2.begin(), excl_2.end());
const set_ns_list_t excl_3{{'A', 'A', 'C'}};
const set_ns_list_t excl_3{{'A', 'B', 'B'}};
BOOST_CHECK_EQUAL_COLLECTIONS(
configs[3].elements.begin(), configs[3].elements.end(), excl_3.begin(), excl_3.end());
const set_ns_list_t excl_9{{'C', 'C', 'C'}};
const set_ns_list_t excl_9{{'A', 'A', 'C'}};
BOOST_CHECK_EQUAL_COLLECTIONS(
configs[9].elements.begin(), configs[9].elements.end(), excl_9.begin(), excl_9.end());

Expand Down Expand Up @@ -717,10 +717,10 @@ BOOST_AUTO_TEST_CASE(qbase_unittest_w_iterations)
BOOST_CHECK_EQUAL(oracle.valid_config_size, 11);
BOOST_CHECK_EQUAL(configs.size(), 11);

const set_ns_list_t excl_4{{'A', 'B', 'B'}, {'C', 'C', 'C'}};
const set_ns_list_t excl_4{{'A', 'A', 'A'}, {'A', 'A', 'C'}};
BOOST_CHECK_EQUAL_COLLECTIONS(
configs[2].elements.begin(), configs[2].elements.end(), excl_4.begin(), excl_4.end());
const set_ns_list_t excl_5{{'A', 'A', 'A'}, {'C', 'C', 'C'}};
const set_ns_list_t excl_5{{'A', 'A', 'C'}, {'A', 'C', 'C'}};
BOOST_CHECK_EQUAL_COLLECTIONS(
configs[3].elements.begin(), configs[3].elements.end(), excl_5.begin(), excl_5.end());

Expand All @@ -741,15 +741,15 @@ BOOST_AUTO_TEST_CASE(qbase_unittest_w_iterations)
BOOST_CHECK_EQUAL(prio_queue.size(), 7);

// excl_7 is now champ
const set_ns_list_t excl_7{{'A', 'C', 'C'}, {'C', 'C', 'C'}};
const set_ns_list_t excl_7{{'A', 'A', 'C'}, {'A', 'B', 'B'}};
interaction_config_manager<config_oracle<qbase_cubic>, VW::confidence_sequence>::apply_new_champ(
oracle, 3, estimators, 0, ns_counter);

BOOST_CHECK_EQUAL_COLLECTIONS(
configs[0].elements.begin(), configs[0].elements.end(), excl_7.begin(), excl_7.end());
BOOST_CHECK_EQUAL_COLLECTIONS(
configs[1].elements.begin(), configs[1].elements.end(), excl_9.begin(), excl_9.end());
const set_ns_list_t excl_6{{'A', 'B', 'B'}, {'A', 'C', 'C'}, {'C', 'C', 'C'}};
const set_ns_list_t excl_6{{'A', 'A', 'A'}, {'A', 'A', 'C'}, {'A', 'B', 'C'}};
BOOST_CHECK_EQUAL_COLLECTIONS(
configs[2].elements.begin(), configs[2].elements.end(), excl_6.begin(), excl_6.end());

Expand Down
15 changes: 14 additions & 1 deletion vowpalwabbit/core/src/reductions/details/automl/automl_oracle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,19 @@ void config_oracle<champdupe_impl>::gen_configs(
}
}

class RNGWrapper
{
public:
RNGWrapper(VW::rand_state* random_state) : _random_state(random_state) {}
typedef size_t result_type;
static size_t min() { return 0; }
static size_t max() { return 42; }
size_t operator()() { return max() * _random_state->get_and_update_random(); }

private:
VW::rand_state* _random_state;
};

template <>
void config_oracle<qbase_cubic>::gen_configs(
const interaction_vec_t&, const std::map<namespace_index, uint64_t>& ns_counter)
Expand All @@ -362,7 +375,7 @@ void config_oracle<qbase_cubic>::gen_configs(

for (size_t i = 0; i < _impl.total_space.size(); i++) { indexes.push_back(i); }

std::shuffle(indexes.begin(), indexes.end(), std::default_random_engine(_impl.random_state->get_current_state()));
std::shuffle(indexes.begin(), indexes.end(), RNGWrapper(_impl.random_state.get()));

for (std::vector<int>::iterator it = indexes.begin(); it != indexes.end(); ++it)
{
Expand Down

0 comments on commit e9a476c

Please sign in to comment.