diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 7ce9fa6f86c42..e189f81fde878 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -63,10 +63,10 @@ object LogLoss extends Loss { val margin = 2.0 * point.label * prediction // The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically // stable. - if (margin >= 0) { + if (point.label > 0) { 2.0 * math.log1p(math.exp(-margin)) } else { - 2.0 * (-margin + math.log1p(math.exp(margin))) + 2.0 * math.log1p(math.exp(margin)) } }.mean() }