diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index eec2f19939b6f..d44384bbf1cb0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -60,8 +60,16 @@ object LogLoss extends Loss { override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { data.map { case point => val prediction = model.predict(point.features) - // Use log1p since it is more stable than explicitly writing log(1 + exp()). - 2.0 * math.log1p(math.exp(-2.0 * point.label * prediction)) + val margin = 2.0 * point.label * prediction + // The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically + // stable. + if (margin >= 0) { + 2.0 * math.log1p(math.exp(-margin)) + //math.log1p(math.exp(w)) + } else { + //w + math.log1p(math.exp(-w)) + 2.0 * (-margin + math.log1p(math.exp(margin))) + } }.mean() } }