[SPARK-5128][MLLib] Add commonly used log1pExp API in MLUtils #3915

Closed · wants to merge 7 commits
Changes from all commits
mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala

@@ -20,6 +20,7 @@ package org.apache.spark.mllib.optimization
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
+import org.apache.spark.mllib.util.MLUtils

 /**
  * :: DeveloperApi ::
@@ -64,17 +65,12 @@ class LogisticGradient extends Gradient {
     val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
     val gradient = data.copy
     scal(gradientMultiplier, gradient)
-    val minusYP = if (label > 0) margin else -margin
-
-    // log1p is log(1+p) but more accurate for small p
-    // Following two equations are the same analytically but not numerically, e.g.,
-    // math.log1p(math.exp(1000)) == Infinity
-    // 1000 + math.log1p(math.exp(-1000)) == 1000.0
     val loss =
-      if (minusYP < 0) {
-        math.log1p(math.exp(minusYP))
+      if (label > 0) {
+        // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
+        MLUtils.log1pExp(margin)
       } else {
-        math.log1p(math.exp(-minusYP)) + minusYP
+        MLUtils.log1pExp(margin) - margin
       }

     (gradient, loss)
@@ -89,9 +85,10 @@ class LogisticGradient extends Gradient {
     val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
     axpy(gradientMultiplier, data, cumGradient)
     if (label > 0) {
-      math.log1p(math.exp(margin))
+      // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
+      MLUtils.log1pExp(margin)
     } else {
-      math.log1p(math.exp(margin)) - margin
+      MLUtils.log1pExp(margin) - margin
     }
   }
 }
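For readers skimming the diff: in the surrounding code `margin` is `-1.0 * dot(data, weights)` and `label` is `0.0` or `1.0`, so the two branches above compute the binary logistic loss `log(1 + exp(margin)) - (1 - label) * margin`, rewritten to avoid overflowing the exponential. A dependency-free sketch of that computation (the plain-array `dot` and the collapsed one-line loss are illustrative stand-ins of mine, not the PR's code):

```scala
object LogisticLossSketch {
  // Inlined copy of the rewrite used by the new MLUtils.log1pExp helper.
  def log1pExp(x: Double): Double =
    if (x > 0) x + math.log1p(math.exp(-x)) else math.log1p(math.exp(x))

  // Hypothetical plain-array stand-in for Spark's BLAS-backed dot(data, weights).
  def dot(a: Array[Double], b: Array[Double]): Double =
    a.zip(b).map { case (x, y) => x * y }.sum

  // label is 0.0 or 1.0; both branches of the patched code collapse to this.
  def logisticLoss(weights: Array[Double], data: Array[Double], label: Double): Double = {
    val margin = -1.0 * dot(data, weights)
    log1pExp(margin) - (1.0 - label) * margin
  }

  def main(args: Array[String]): Unit = {
    val w = Array(0.5, -0.25)
    val x = Array(1.0, 2.0)
    println(logisticLoss(w, x, 1.0)) // loss when the true label is 1
    println(logisticLoss(w, x, 0.0)) // loss when the true label is 0
  }
}
```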
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala

@@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree.loss
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.TreeEnsembleModel
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD

 /**
@@ -61,13 +62,8 @@ object LogLoss extends Loss {
     data.map { case point =>
       val prediction = model.predict(point.features)
       val margin = 2.0 * point.label * prediction
-      // The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically
-      // stable.
-      if (margin >= 0) {
-        2.0 * math.log1p(math.exp(-margin))
-      } else {
-        2.0 * (-margin + math.log1p(math.exp(margin)))
-      }
+      // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
+      2.0 * MLUtils.log1pExp(-margin)
     }.mean()
   }
 }
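Because `log1pExp` applies the `x + log1p(exp(-x))` shift only when its argument is positive, the single call `2.0 * MLUtils.log1pExp(-margin)` reproduces both of the deleted branches exactly. A stand-alone sanity check (plain Scala, with `log1pExp` inlined so the snippet has no Spark dependency; `oldLoss`/`newLoss` are names of mine, not the PR's):

```scala
object LogLossEquivalenceCheck {
  // Inlined copy of the new MLUtils.log1pExp helper.
  def log1pExp(x: Double): Double =
    if (x > 0) x + math.log1p(math.exp(-x)) else math.log1p(math.exp(x))

  // The two-branch expression this diff removes.
  def oldLoss(margin: Double): Double =
    if (margin >= 0) 2.0 * math.log1p(math.exp(-margin))
    else 2.0 * (-margin + math.log1p(math.exp(margin)))

  // Its one-line replacement.
  def newLoss(margin: Double): Double = 2.0 * log1pExp(-margin)

  def main(args: Array[String]): Unit = {
    // Agreement on both sides of zero, including margins that overflow math.exp.
    for (m <- Seq(-1000.0, -1.5, 0.0, 1.5, 1000.0)) {
      println(s"margin=$m old=${oldLoss(m)} new=${newLoss(m)}")
    }
  }
}
```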
16 changes: 16 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala

@@ -322,4 +322,20 @@ object MLUtils {
     }
     sqDist
   }
+
+  /**
+   * When `x` is positive and large, computing `math.log(1 + math.exp(x))` will lead to arithmetic
+   * overflow. This will happen when `x > 709.78`, which is not a very large number.
+   * It can be addressed by rewriting the formula into `x + math.log1p(math.exp(-x))` when `x > 0`.
+   *
+   * @param x a floating-point value as input.
+   * @return the result of `math.log(1 + math.exp(x))`.
+   */
+  private[mllib] def log1pExp(x: Double): Double = {
+    if (x > 0) {
+      x + math.log1p(math.exp(-x))
+    } else {
+      math.log1p(math.exp(x))
+    }
+  }
 }
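To make the failure mode in the scaladoc concrete: `math.exp(x)` overflows to `Infinity` once `x` exceeds about 709.78 (the largest exponent whose result fits in a `Double`), so the naive formula loses the answer entirely, while the rewritten form stays exact because `math.exp(-x)` merely underflows toward zero. A REPL-style illustration of both formulas on IEEE-754 doubles:

```scala
scala> math.log(1 + math.exp(710.0))        // exp overflows: log(Infinity)
res0: Double = Infinity

scala> 710.0 + math.log1p(math.exp(-710.0)) // exp(-710) ≈ 4.5e-309; the sum rounds to x
res1: Double = 710.0
```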
mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala

@@ -20,18 +20,17 @@ package org.apache.spark.mllib.util
 import java.io.File

 import scala.io.Source
-import scala.math

 import org.scalatest.FunSuite

-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
-  squaredDistance => breezeSquaredDistance}
+import breeze.linalg.{squaredDistance => breezeSquaredDistance}
 import com.google.common.base.Charsets
 import com.google.common.io.Files

 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils._
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.util.Utils

 class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
@@ -204,4 +203,12 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
     assert(points.collect().toSet === loaded.collect().toSet)
     Utils.deleteRecursively(tempDir)
   }
+
+  test("log1pExp") {
+    assert(log1pExp(76.3) ~== math.log1p(math.exp(76.3)) relTol 1E-10)
+    assert(log1pExp(87296763.234) ~== 87296763.234 relTol 1E-10)
+
+    assert(log1pExp(-13.8) ~== math.log1p(math.exp(-13.8)) absTol 1E-10)
+    assert(log1pExp(-238423789.865) ~== math.log1p(math.exp(-238423789.865)) absTol 1E-10)
+  }
 }
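A note on the tolerances in the new test (the `~==`, `relTol`, and `absTol` operators come from the `TestingUtils._` import added above): for the extreme negative input, `math.exp(-238423789.865)` underflows to exactly `0.0`, so both sides of the comparison are exactly zero and a relative tolerance would be measuring against zero; an absolute tolerance is the meaningful check there. For the large positive input the result is dominated by `x` itself, so a relative tolerance is the natural fit. (This reading of the test's intent is inferred from the values, not stated in the PR.)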