diff --git a/vowpalwabbit/core/src/reductions/active.cc b/vowpalwabbit/core/src/reductions/active.cc index a7449affde2..b41717a60df 100644 --- a/vowpalwabbit/core/src/reductions/active.cc +++ b/vowpalwabbit/core/src/reductions/active.cc @@ -32,25 +32,37 @@ using namespace VW::reductions; namespace { float get_active_coin_bias(float example_count, float avg_loss, float alt_label_error_rate_diff, float mellowness) -{//implementation follows https://web.archive.org/web/20120525164352/http://books.nips.cc/papers/files/nips23/NIPS2010_0363.pdf - const float mellow_log_e_count_over_e_count = mellowness * (std::log(example_count + 1.f) + 0.0001f) / (example_count + 0.0001f); +{ // implementation follows + // https://web.archive.org/web/20120525164352/http://books.nips.cc/papers/files/nips23/NIPS2010_0363.pdf + const float mellow_log_e_count_over_e_count = + mellowness * (std::log(example_count + 1.f) + 0.0001f) / (example_count + 0.0001f); const float sqrt_mellow_lecoec = std::sqrt(mellow_log_e_count_over_e_count); // loss should be in [0,1] avg_loss = VW::math::clamp(avg_loss, 0.f, 1.f); - const float sqrt_avg_loss_plus_sqrt_alt_loss = std::min(1.f, //std::sqrt(avg_loss) + // commented out because two square roots appears to conservative. - std::sqrt(avg_loss + alt_label_error_rate_diff));//emperical variance deflater. - //std::cout << "example_count = " << example_count << " avg_loss = " << avg_loss << " alt_label_error_rate_diff = " << alt_label_error_rate_diff << " mellowness = " << mellowness << " mlecoc = " << mellow_log_e_count_over_e_count - // << " sqrt_mellow_lecoec = " << sqrt_mellow_lecoec << " double sqrt = " << sqrt_avg_loss_plus_sqrt_alt_loss << std::endl; - - if (alt_label_error_rate_diff <= sqrt_mellow_lecoec * sqrt_avg_loss_plus_sqrt_alt_loss//deflater in use. 
- + mellow_log_e_count_over_e_count) { return 1; } - //old equation - // const float rs = (sqrt_avg_loss_plus_sqrt_alt_loss + std::sqrt(sqrt_avg_loss_plus_sqrt_alt_loss * sqrt_avg_loss_plus_sqrt_alt_loss + 4 * alt_label_error_rate_diff)) / (2 * alt_label_error_rate_diff); - // return mellow_log_e_count_over_e_count * rs * rs; - const float sqrt_s = (sqrt_mellow_lecoec + std::sqrt(mellow_log_e_count_over_e_count+4*alt_label_error_rate_diff*mellow_log_e_count_over_e_count)) / 2*alt_label_error_rate_diff; + const float sqrt_avg_loss_plus_sqrt_alt_loss = + std::min(1.f, // std::sqrt(avg_loss) + // commented out because two square roots appears too conservative. + std::sqrt(avg_loss + alt_label_error_rate_diff)); // empirical variance deflater. + // std::cout << "example_count = " << example_count << " avg_loss = " << avg_loss << " alt_label_error_rate_diff = " + // << alt_label_error_rate_diff << " mellowness = " << mellowness << " mlecoc = " << mellow_log_e_count_over_e_count + // << " sqrt_mellow_lecoec = " << sqrt_mellow_lecoec << " double sqrt = " << sqrt_avg_loss_plus_sqrt_alt_loss + //<< std::endl; + + if (alt_label_error_rate_diff <= sqrt_mellow_lecoec * sqrt_avg_loss_plus_sqrt_alt_loss // deflater in use. 
+ + mellow_log_e_count_over_e_count) + { + return 1; + } + // old equation + // const float rs = (sqrt_avg_loss_plus_sqrt_alt_loss + std::sqrt(sqrt_avg_loss_plus_sqrt_alt_loss * + // sqrt_avg_loss_plus_sqrt_alt_loss + 4 * alt_label_error_rate_diff)) / (2 * alt_label_error_rate_diff); return + // mellow_log_e_count_over_e_count * rs * rs; + const float sqrt_s = (sqrt_mellow_lecoec + + std::sqrt(mellow_log_e_count_over_e_count + + 4 * alt_label_error_rate_diff * mellow_log_e_count_over_e_count)) / + 2 * alt_label_error_rate_diff; // std::cout << "sqrt_s = " << sqrt_s << std::endl; - return sqrt_s*sqrt_s; + return sqrt_s * sqrt_s; } float query_decision(const active& a, float updates_to_change_prediction, float example_count) @@ -61,8 +73,10 @@ float query_decision(const active& a, float updates_to_change_prediction, float { // const auto weighted_queries = static_cast(a._shared_data->weighted_labeled_examples); const float avg_loss = (static_cast(a._shared_data->sum_loss) / example_count); - //+ std::sqrt((1.f + 0.5f * std::log(example_count)) / (weighted_queries + 0.0001f)); Commented this out, not following why we need it from the theory. - // std::cout << "avg_loss = " << avg_loss << " weighted_queries = " << weighted_queries << " sum_loss = " << a._shared_data->sum_loss << " example_count = " << example_count << std::endl; + //+ std::sqrt((1.f + 0.5f * std::log(example_count)) / (weighted_queries + 0.0001f)); Commented this out, not + // following why we need it from the theory. 
+ // std::cout << "avg_loss = " << avg_loss << " weighted_queries = " << weighted_queries << " sum_loss = " << + // a._shared_data->sum_loss << " example_count = " << example_count << std::endl; bias = get_active_coin_bias(example_count, avg_loss, updates_to_change_prediction / example_count, a.active_c0); } // std::cout << "bias = " << bias << std::endl; @@ -122,14 +136,14 @@ void predict_or_learn_active(active& a, learner& base, VW::example& ec) template void predict_or_learn_active_direct(active& a, learner& base, VW::example& ec) -{ +{ if (is_learn) { base.learn(ec); } else { base.predict(ec); } - + if (ec.l.simple.label == FLT_MAX) { - if (std::string(ec.tag.begin(), ec.tag.begin()+6) == "query?") - { + if (std::string(ec.tag.begin(), ec.tag.begin() + 6) == "query?") + { const float threshold = (a._shared_data->max_label + a._shared_data->min_label) * 0.5f; // We want to understand the change in prediction if the label were to be // the opposite of what was predicted. 0 and 1 are used for the expected min @@ -137,16 +151,17 @@ void predict_or_learn_active_direct(active& a, learner& base, VW::example& ec) ec.l.simple.label = (ec.pred.scalar >= threshold) ? a._min_seen_label : a._max_seen_label; ec.confidence = std::abs(ec.pred.scalar - threshold) / base.sensitivity(ec); ec.l.simple.label = FLT_MAX; - ec.pred.scalar = query_decision(a, ec.confidence, static_cast(a._shared_data->weighted_unlabeled_examples)); + ec.pred.scalar = + query_decision(a, ec.confidence, static_cast(a._shared_data->weighted_unlabeled_examples)); } } else - { + { // Update seen labels based on the current example's label. 
a._min_seen_label = std::min(ec.l.simple.label, a._min_seen_label); a._max_seen_label = std::max(ec.l.simple.label, a._max_seen_label); - } -} + } +} void active_print_result( VW::io::writer* f, float res, float weight, const VW::v_array& tag, VW::io::logger& logger) @@ -232,7 +247,9 @@ std::shared_ptr VW::reductions::active_setup(VW::setup_bas option_group_definition new_options("[Reduction] Active Learning"); new_options.add(make_option("active", active_option).keep().necessary().help("Enable active learning")) .add(make_option("simulation", simulation).help("Active learning simulation mode")) - .add(make_option("direct", direct).help("Active learning via the tag and predictions interface. Tag should start with \"query?\" to get query decision. Returned prediction is either -1 for no or the importance weight for yes.")) + .add(make_option("direct", direct) + .help("Active learning via the tag and predictions interface. Tag should start with \"query?\" to get " + "query decision. Returned prediction is either -1 for no or the importance weight for yes.")) .add(make_option("mellowness", active_c0) .keep() .default_value(1.f)