diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
index aaacf3a8a29e7..1ca0f36c6ac34 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.optimization
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
+import org.apache.spark.mllib.util.MLUtils
 
 /**
  * :: DeveloperApi ::
@@ -64,17 +65,12 @@ class LogisticGradient extends Gradient {
     val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
     val gradient = data.copy
     scal(gradientMultiplier, gradient)
-    val minusYP = if (label > 0) margin else -margin
-
-    // log1p is log(1+p) but more accurate for small p
-    // Following two equations are the same analytically but not numerically, e.g.,
-    // math.log1p(math.exp(1000)) == Infinity
-    // 1000 + math.log1p(math.exp(-1000)) == 1000.0
     val loss =
-      if (minusYP < 0) {
-        math.log1p(math.exp(minusYP))
+      if (label > 0) {
+        // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
+        MLUtils.log1pExp(margin)
       } else {
-        math.log1p(math.exp(-minusYP)) + minusYP
+        MLUtils.log1pExp(margin) - margin
       }
 
     (gradient, loss)
@@ -89,9 +85,10 @@ class LogisticGradient extends Gradient {
     val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
     axpy(gradientMultiplier, data, cumGradient)
     if (label > 0) {
-      math.log1p(math.exp(margin))
+      // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
+      MLUtils.log1pExp(margin)
     } else {
-      math.log1p(math.exp(margin)) - margin
+      MLUtils.log1pExp(margin) - margin
     }
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 7ce9fa6f86c42..55213e695638c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree.loss
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.TreeEnsembleModel
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 
 /**
@@ -61,13 +62,8 @@ object LogLoss extends Loss {
     data.map { case point =>
       val prediction = model.predict(point.features)
       val margin = 2.0 * point.label * prediction
-      // The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically
-      // stable.
-      if (margin >= 0) {
-        2.0 * math.log1p(math.exp(-margin))
-      } else {
-        2.0 * (-margin + math.log1p(math.exp(margin)))
-      }
+      // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
+      2.0 * MLUtils.log1pExp(-margin)
     }.mean()
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index c7843464a7505..5d6ddd47f67d6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -322,4 +322,20 @@ object MLUtils {
     }
     sqDist
   }
+
+  /**
+   * When `x` is positive and large, computing `math.log(1 + math.exp(x))` will lead to arithmetic
+   * overflow. This will happen when `x > 709.78` which is not a very large number.
+   * It can be addressed by rewriting the formula into `x + math.log1p(math.exp(-x))` when `x > 0`.
+   *
+   * @param x a floating-point value as input.
+   * @return the result of `math.log(1 + math.exp(x))`.
+   */
+  private[mllib] def log1pExp(x: Double): Double = {
+    if (x > 0) {
+      x + math.log1p(math.exp(-x))
+    } else {
+      math.log1p(math.exp(x))
+    }
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 7778847f8b72a..668fc1d43c5d6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -20,18 +20,17 @@ package org.apache.spark.mllib.util
 import java.io.File
 
 import scala.io.Source
-import scala.math
 
 import org.scalatest.FunSuite
 
-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
-  squaredDistance => breezeSquaredDistance}
+import breeze.linalg.{squaredDistance => breezeSquaredDistance}
 import com.google.common.base.Charsets
 import com.google.common.io.Files
 
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils._
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.util.Utils
 
 class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
@@ -204,4 +203,12 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
     assert(points.collect().toSet === loaded.collect().toSet)
     Utils.deleteRecursively(tempDir)
   }
+
+  test("log1pExp") {
+    assert(log1pExp(76.3) ~== math.log1p(math.exp(76.3)) relTol 1E-10)
+    assert(log1pExp(87296763.234) ~== 87296763.234 relTol 1E-10)
+
+    assert(log1pExp(-13.8) ~== math.log1p(math.exp(-13.8)) absTol 1E-10)
+    assert(log1pExp(-238423789.865) ~== math.log1p(math.exp(-238423789.865)) absTol 1E-10)
+  }
 }