From 452e04913bc6bd123460a34b921590a6381136ae Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 26 Feb 2015 12:07:09 +0800 Subject: [PATCH] Address comment. --- .../spark/mllib/tree/GradientBoostedTrees.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 0149938fa719c..a9c93e181e3ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -226,10 +226,17 @@ object GradientBoostedTrees extends Logging { logDebug("error of gbt = " + loss.computeError(partialModel, input)) if (validate) { - // Record the best model if the reduction in error is more than the validationTol. + // Stop training early if + // 1. Reduction in error is less than the validationTol or + // 2. If the error increases, that is if the model is overfit. // We want the model returned corresponding to the best validation error. val currentValidateError = loss.computeError(partialModel, validationInput) - if (currentValidateError < bestValidateError - validationTol) { + if (bestValidateError - currentValidateError < validationTol) { + return new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, + baseLearners.slice(0, bestM), + baseLearnerWeights.slice(0, bestM)) + } else if (currentValidateError < bestValidateError) { bestValidateError = currentValidateError bestM = m + 1 }