diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 5a419d1640292..aaacf3a8a29e7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -64,11 +64,17 @@ class LogisticGradient extends Gradient { val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label val gradient = data.copy scal(gradientMultiplier, gradient) + val minusYP = if (label > 0) margin else -margin + + // log1p is log(1+p) but more accurate for small p + // Following two equations are the same analytically but not numerically, e.g., + // math.log1p(math.exp(1000)) == Infinity + // 1000 + math.log1p(math.exp(-1000)) == 1000.0 val loss = - if (label > 0) { - math.log1p(math.exp(margin)) // log1p is log(1+p) but more accurate for small p + if (minusYP < 0) { + math.log1p(math.exp(minusYP)) } else { - math.log1p(math.exp(margin)) - margin + math.log1p(math.exp(-minusYP)) + minusYP } (gradient, loss)