Skip to content

Commit

Permalink
fix error with get_loss_function and set_minmax function
Browse files Browse the repository at this point in the history
  • Loading branch information
byronxu99 committed Jan 19, 2023
1 parent 40ba14c commit 0a1a075
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions vowpalwabbit/core/src/loss_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <cfloat>
#include <cmath>
#include <cstdlib>
#include <type_traits>

namespace
{
Expand Down Expand Up @@ -482,14 +483,18 @@ namespace VW
std::unique_ptr<loss_function> get_loss_function(
VW::workspace& all, const std::string& funcName, float function_parameter_0, float function_parameter_1)
{
// For some loss functions, we want to know if all.set_minmax has been overwritten
// Check if the std::function object does not contain VW::details::noop_mm
using noop_mm_fn_ptr_type = typename std::add_pointer<decltype(VW::details::noop_mm)>::type;
auto* minmax_fptr = all.set_minmax.target<noop_mm_fn_ptr_type>();
bool set_minmax_is_not_noop = (minmax_fptr == nullptr || *minmax_fptr != VW::details::noop_mm);

if (funcName == "squared" || funcName == "Huber") { return VW::make_unique<squaredloss>(); }
else if (funcName == "classic") { return VW::make_unique<classic_squaredloss>(); }
else if (funcName == "hinge") { return VW::make_unique<hingeloss>(all.logger); }
else if (funcName == "logistic")
{
// Check if the std::function object contains exactly VW::details::noop_mm
auto minmax_fptr = all.set_minmax.target<decltype(VW::details::noop_mm)>();
if (minmax_fptr == VW::details::noop_mm)
if (set_minmax_is_not_noop)
{
all.sd->min_label = -50;
all.sd->max_label = 50;
Expand All @@ -503,9 +508,7 @@ std::unique_ptr<loss_function> get_loss_function(
else if (funcName == "expectile") { return VW::make_unique<expectileloss>(function_parameter_0); }
else if (funcName == "poisson")
{
// Check if the std::function object contains exactly VW::details::noop_mm
auto minmax_fptr = all.set_minmax.target<decltype(VW::details::noop_mm)>();
if (minmax_fptr == VW::details::noop_mm)
if (set_minmax_is_not_noop)
{
all.sd->min_label = -50;
all.sd->max_label = 50;
Expand Down

0 comments on commit 0a1a075

Please sign in to comment.