replace assert with status quo behaviour
lalo committed Jul 5, 2022
1 parent b89cf42 commit 1edd275
Showing 1 changed file with 23 additions and 4 deletions.
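
The pattern applied throughout this diff: instead of asserting that g.current_model_state has already been set by the caller, each entry point falls back to the first per-model state slot (the pre-multi-model "status quo"), logs which code path hit the fallback, and clears the pointer on exit. Below is a minimal standalone sketch of that guard; only the names per_model_states and current_model_state are taken from the diff, and the structs are simplified stand-ins for the real gd reduction state, not the actual VW types.

    #include <iostream>
    #include <vector>

    // Illustrative stand-in for the per-model state held by the gd reduction.
    struct per_model_state
    {
      double total_weight = 0.;
    };

    // Illustrative stand-in for the gd reduction struct.
    struct gd_sketch
    {
      std::vector<per_model_state> per_model_states{per_model_state{}};
      per_model_state* current_model_state = nullptr;
    };

    void predict_sketch(gd_sketch& g)
    {
      // Previously: assert(g.current_model_state != nullptr).
      // Status quo behaviour: fall back to slot 0 when the caller did not
      // set the pointer, and log which code path took the fallback.
      if (g.current_model_state == nullptr)
      {
        std::cerr << "predict called without a model state" << std::endl;
        g.current_model_state = &(g.per_model_states[0]);
      }

      // ... read and update *g.current_model_state here ...

      // Clear the pointer on exit so stale state cannot leak into the next call.
      g.current_model_state = nullptr;
    }

learn() still sets the pointer explicitly from ec.ft_offset before calling update(), and now asserts afterwards that update() has cleared it.
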
27 changes: 23 additions & 4 deletions vowpalwabbit/core/src/reductions/gd.cc
@@ -386,7 +386,12 @@ template <bool l1, bool audit>
 void multipredict(gd& g, base_learner&, VW::example& ec, size_t count, size_t step, VW::polyprediction* pred,
     bool finalize_predictions)
 {
-  assert(g.current_model_state != nullptr);
+  // TODO: do we care? use the first slot for multipredict for now
+  if (g.current_model_state == nullptr)
+  {
+    std::cerr << "multipredict case" << std::endl;
+    g.current_model_state = &(g.per_model_states[0]);
+  }
   VW::workspace& all = *g.all;
   for (size_t c = 0; c < count; c++)
   {
@@ -445,6 +450,7 @@ void multipredict(gd& g, base_learner&, VW::example& ec, size_t count, size_t st
   }
   ec.ft_offset -= static_cast<uint64_t>(step * count);
   }
+  g.current_model_state = nullptr;
 }

 struct power_data
@@ -623,9 +629,15 @@ float get_scale(gd& g, VW::example& /* ec */, float weight)
 template <bool sqrt_rate, bool feature_mask_off, bool adax, size_t adaptive, size_t normalized, size_t spare>
 float sensitivity(gd& g, base_learner& /* base */, VW::example& ec)
 {
-  assert(g.current_model_state != nullptr);
+  // TODO: do we care? use the first slot for multipredict for now
+  if (g.current_model_state == nullptr)
+  {
+    std::cerr << "gd sensitivity case" << std::endl;
+    g.current_model_state = &(g.per_model_states[0]);
+  }
   return get_scale<adaptive>(g, ec, 1.) *
       sensitivity<sqrt_rate, feature_mask_off, adax, adaptive, normalized, spare, true>(g, ec);
+  g.current_model_state = nullptr;
 }

 template <bool sparse_l2, bool invariant, bool sqrt_rate, bool feature_mask_off, bool adax, size_t adaptive,
@@ -675,7 +687,12 @@ template <bool sparse_l2, bool invariant, bool sqrt_rate, bool feature_mask_off,
     size_t normalized, size_t spare>
 void update(gd& g, base_learner&, VW::example& ec)
 {
-  assert(g.current_model_state != nullptr);
+  // TODO: do we care? use the first slot for multipredict for now
+  if (g.current_model_state == nullptr)
+  {
+    std::cerr << "gd update called directly" << std::endl;
+    g.current_model_state = &(g.per_model_states[0]);
+  }
   // invariant: not a test label, importance weight > 0
   float update;
   if ((update = compute_update<sparse_l2, invariant, sqrt_rate, feature_mask_off, adax, adaptive, normalized, spare>(
@@ -707,7 +724,8 @@ void update(gd& g, base_learner&, VW::example& ec)
   { // updating weights now to avoid numerical instability
     sync_weights(*g.all);
   }
-} // namespace GD
+  g.current_model_state = nullptr;
+}

 template <bool sparse_l2, bool invariant, bool sqrt_rate, bool feature_mask_off, bool adax, size_t adaptive,
     size_t normalized, size_t spare>
@@ -719,6 +737,7 @@ void learn(gd& g, base_learner& base, VW::example& ec)
   g.predict(g, base, ec);
   g.current_model_state = &(g.per_model_states[ec.ft_offset / g.all->weights.stride()]);
   update<sparse_l2, invariant, sqrt_rate, feature_mask_off, adax, adaptive, normalized, spare>(g, base, ec);
+  assert(g.current_model_state == nullptr);  // update clears this pointer
   g.current_model_state = nullptr;
 }
