diff --git a/vowpalwabbit/core/src/reductions/gd.cc b/vowpalwabbit/core/src/reductions/gd.cc index 57bcbb4e003..e3da801b088 100644 --- a/vowpalwabbit/core/src/reductions/gd.cc +++ b/vowpalwabbit/core/src/reductions/gd.cc @@ -386,7 +386,12 @@ template void multipredict(gd& g, base_learner&, VW::example& ec, size_t count, size_t step, VW::polyprediction* pred, bool finalize_predictions) { - assert(g.current_model_state != nullptr); + // tODO: do we care? use the first slot for multipredict for now + if (g.current_model_state == nullptr) + { + std::cerr << "multipredict case" << std::endl; + g.current_model_state = &(g.per_model_states[0]); + } VW::workspace& all = *g.all; for (size_t c = 0; c < count; c++) { @@ -445,6 +450,7 @@ void multipredict(gd& g, base_learner&, VW::example& ec, size_t count, size_t st } ec.ft_offset -= static_cast(step * count); } + g.current_model_state = nullptr; } struct power_data @@ -623,9 +629,15 @@ float get_scale(gd& g, VW::example& /* ec */, float weight) template float sensitivity(gd& g, base_learner& /* base */, VW::example& ec) { - assert(g.current_model_state != nullptr); + // tODO: do we care? use the first slot for multipredict for now + if (g.current_model_state == nullptr) + { + std::cerr << "gd sensitivity case" << std::endl; + g.current_model_state = &(g.per_model_states[0]); + } return get_scale(g, ec, 1.) * sensitivity(g, ec); + g.current_model_state = nullptr; } template void update(gd& g, base_learner&, VW::example& ec) { - assert(g.current_model_state != nullptr); + // tODO: do we care? use the first slot for multipredict for now + if (g.current_model_state == nullptr) + { + std::cerr << "gd update called directly" << std::endl; + g.current_model_state = &(g.per_model_states[0]); + } // invariant: not a test label, importance weight > 0 float update; if ((update = compute_update( @@ -707,7 +724,8 @@ void update(gd& g, base_learner&, VW::example& ec) { // updating weights now to avoid numerical instability sync_weights(*g.all); } -} // namespace GD + g.current_model_state = nullptr; +} template @@ -719,6 +737,7 @@ void learn(gd& g, base_learner& base, VW::example& ec) g.predict(g, base, ec); g.current_model_state = &(g.per_model_states[ec.ft_offset / g.all->weights.stride()]); update(g, base, ec); + assert(g.current_model_state == nullptr); // update clears this pointer g.current_model_state = nullptr; }