diff --git a/vowpalwabbit/core/src/reductions/cbzo.cc b/vowpalwabbit/core/src/reductions/cbzo.cc index ace84e507dd..ea0fff5fe33 100644 --- a/vowpalwabbit/core/src/reductions/cbzo.cc +++ b/vowpalwabbit/core/src/reductions/cbzo.cc @@ -46,22 +46,22 @@ class linear_update_data }; // uint64_t index variant of VW::get_weight -inline float get_weight(VW::workspace& all, uint64_t index, uint32_t offset) +inline float get_weight(VW::workspace& all, uint64_t index) { - return (&all.weights[(index) << all.weights.stride_shift()])[offset]; + return all.weights[(index) << all.weights.stride_shift()]; } // uint64_t index variant of VW::set_weight -inline void set_weight(VW::workspace& all, uint64_t index, uint32_t offset, float value) +inline void set_weight(VW::workspace& all, uint64_t index, float value) { - (&all.weights[(index) << all.weights.stride_shift()])[offset] = value; + all.weights[(index) << all.weights.stride_shift()] = value; } float l1_grad(VW::workspace& all, uint64_t fi) { if (all.no_bias && fi == VW::details::CONSTANT) { return 0.0f; } - float fw = get_weight(all, fi, 0); + float fw = get_weight(all, fi); return fw >= 0.0f ? all.l1_lambda : -all.l1_lambda; } @@ -69,7 +69,7 @@ float l2_grad(VW::workspace& all, uint64_t fi) { if (all.no_bias && fi == VW::details::CONSTANT) { return 0.0f; } - float fw = get_weight(all, fi, 0); + float fw = get_weight(all, fi); return all.l2_lambda * fw; } @@ -77,7 +77,7 @@ inline void accumulate_dotprod(float& dotprod, float x, float& fw) { dotprod += inline float constant_inference(VW::workspace& all) { - float wt = get_weight(all, VW::details::CONSTANT, 0); + float wt = get_weight(all, VW::details::CONSTANT); return wt; } @@ -102,7 +102,7 @@ float inference(VW::workspace& all, VW::example& ec) template void constant_update(cbzo& data, VW::example& ec) { - float fw = get_weight(*data.all, VW::details::CONSTANT, 0); + float fw = get_weight(*data.all, VW::details::CONSTANT); if (feature_mask_off || fw != 0.0f) { float action_centroid = inference(*data.all, ec); @@ -110,19 +110,19 @@ void constant_update(cbzo& data, VW::example& ec) float update = -data.all->eta * (grad + l1_grad(*data.all, VW::details::CONSTANT) + l2_grad(*data.all, VW::details::CONSTANT)); - set_weight(*data.all, VW::details::CONSTANT, 0, fw + update); + set_weight(*data.all, VW::details::CONSTANT, fw + update); } } template void linear_per_feature_update(linear_update_data& upd_data, float x, uint64_t fi) { - float fw = get_weight(*upd_data.all, fi, 0); + float fw = get_weight(*upd_data.all, fi); if (feature_mask_off || fw != 0.0f) { float update = upd_data.mult * (upd_data.part_grad * x + (l1_grad(*upd_data.all, fi) + l2_grad(*upd_data.all, fi))); - set_weight(*upd_data.all, fi, 0, fw + update); + set_weight(*upd_data.all, fi, fw + update); } } @@ -236,7 +236,7 @@ void save_load(cbzo& data, VW::io_buf& model_file, bool read, bool text) if (read) { VW::details::initialize_regressor(all); - if (data.all->initial_constant != 0.0f) { set_weight(all, VW::details::CONSTANT, 0, data.all->initial_constant); } + if (data.all->initial_constant != 0.0f) { set_weight(all, VW::details::CONSTANT, data.all->initial_constant); } } if (model_file.num_files() > 0) { save_load_regressor(all, model_file, read, text); } } @@ -368,7 +368,6 @@ std::shared_ptr VW::reductions::cbzo_setup(VW::setup_base_ auto l = make_bottom_learner(std::move(data), get_learn(all, policy, feature_mask_off), get_predict(all, policy), stack_builder.get_setupfn_name(cbzo_setup), prediction_type_t::PDF, label_type_t::CONTINUOUS) - .set_params_per_weight(0) .set_save_load(save_load) .set_output_example_prediction(output_example_prediction_cbzo) .set_print_update(print_update_cbzo)