Skip to content

Commit

Permalink
improved gain purification during split search instead of once at the end
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbkoch committed Nov 16, 2024
1 parent d23e5a7 commit 5b9d1b6
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 219 deletions.
9 changes: 5 additions & 4 deletions shared/libebm/GenerateTermUpdate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ static ErrorEbm BoostMultiDimensional(BoosterShell* const pBoosterShell,
double* aWeights = nullptr;
double* pGradient = nullptr;
double* pHessian = nullptr;
if(0 != (TermBoostFlags_PurifyUpdate & flags)) {
if(0 != ((TermBoostFlags_PurifyUpdate | TermBoostFlags_PurifyGain) & flags)) {
// allocate the biggest tensor that is possible to split into

// TODO: cache this memory allocation so that we don't do it each time
Expand Down Expand Up @@ -383,7 +383,7 @@ static ErrorEbm BoostMultiDimensional(BoosterShell* const pBoosterShell,
EBM_ASSERT(!std::isnan(*pTotalGain));
EBM_ASSERT(0 <= *pTotalGain);

if(0 != (TermBoostFlags_PurifyUpdate & flags)) {
if(0 != (TermBoostFlags_PurifyUpdate & flags) && 0 == (TermBoostFlags_PurifyGain & flags)) {
Tensor* const pTensor = pBoosterShell->GetInnerTermUpdate();

size_t cDimensions = pTerm->GetCountDimensions();
Expand Down Expand Up @@ -461,8 +461,9 @@ static ErrorEbm BoostMultiDimensional(BoosterShell* const pBoosterShell,

if(/* NaN */ !(std::numeric_limits<double>::min() <= gain)) {
// Purification can push the updates to a point where they are detrimental to the purified gain
// in which case gain can end up slightly negative. If this happens disallow the cuts so that we
// never have negative gain.
// in which case gain can end up slightly negative. If this happens, disallow the cuts so that we
// never have negative gain. For purified updates, if we make no cuts, then the update is zero
// so we don't have to call CalcNegUpdate.

pTensor->Reset();
gain = std::isnan(gain) ? gain : double{0};
Expand Down
Loading

0 comments on commit 5b9d1b6

Please sign in to comment.