use axpy in Updater
mengxr committed Mar 27, 2014
1 parent db808a1 commit e981396
Showing 2 changed files with 21 additions and 10 deletions.
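
The commit swaps out-of-place vector arithmetic (which allocates a temporary vector for every expression) for Breeze's in-place axpy, the BLAS-1 update y := a * x + y. A minimal sketch of the operation, assuming only that breeze is on the classpath; the object name is made up for illustration:

    import breeze.linalg.{DenseVector, axpy => brzAxpy}

    object AxpySketch {
      def main(args: Array[String]): Unit = {
        val y = DenseVector(1.0, 2.0, 3.0)  // accumulator, updated in place
        val x = DenseVector(0.5, 0.5, 0.5)
        brzAxpy(-0.1, x, y)                 // y := y + (-0.1) * x; no temporary vector allocated
        println(y)                          // approximately DenseVector(0.95, 1.95, 2.95)
      }
    }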
GradientDescent.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.optimization
 
 import scala.collection.mutable.ArrayBuffer
 
+import breeze.linalg.{Vector => BV, DenseVector => BDV}
+
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
@@ -157,11 +159,16 @@ object GradientDescent extends Logging {
     for (i <- 1 to numIterations) {
       // Sample a subset (fraction miniBatchFraction) of the total data
       // compute and sum up the subgradients on this subset (this is one map-reduce)
-      val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i).map {
-        case (y, features) =>
-          val (grad, loss) = gradient.compute(features, y, weights)
-          (grad.toBreeze, loss)
-      }.reduce((a, b) => (a._1 += b._1, a._2 + b._2))
+      val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
+        .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+          seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
+            val (g, l) = gradient.compute(features, label, weights)
+            (grad += g.toBreeze, loss + l)
+          },
+          combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
+            (grad1 += grad2, loss1 + loss2)
+          }
+        )
 
       /**
        * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
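
The GradientDescent change replaces map followed by reduce with aggregate, so each partition folds its records into one reused dense vector instead of materializing a new Breeze vector per record; the mutable zero value is safe on an RDD because it is serialized to each task separately. Below is a standalone sketch of that seqOp/combOp pattern outside Spark; the data, object name, and the computeGradient stand-in are made up for illustration, not MLlib code:

    import breeze.linalg.{DenseVector => BDV}

    object AggregateSketch {
      // hypothetical stand-in for MLlib's gradient.compute: (subgradient, loss) for one record
      def computeGradient(features: BDV[Double], label: Double,
          weights: BDV[Double]): (BDV[Double], Double) = {
        val err = (weights dot features) - label   // least-squares example
        (features * err, 0.5 * err * err)
      }

      def main(args: Array[String]): Unit = {
        val weights = BDV(0.1, 0.2)
        val data: Seq[(Double, BDV[Double])] = Seq((1.0, BDV(1.0, 0.0)), (0.0, BDV(0.0, 1.0)))
        val partitions = data.grouped(1).toSeq     // pretend each group is one RDD partition

        // seqOp: fold one record into the running (gradientSum, lossSum), reusing the vector
        val seqOp = (c: (BDV[Double], Double), v: (Double, BDV[Double])) => {
          val (grad, loss) = c
          val (label, features) = v
          val (g, l) = computeGradient(features, label, weights)
          (grad += g, loss + l)
        }
        // combOp: merge two per-partition partial results
        val combOp = (c1: (BDV[Double], Double), c2: (BDV[Double], Double)) =>
          (c1._1 += c2._1, c1._2 + c2._2)

        val (gradientSum, lossSum) = partitions
          .map(part => part.foldLeft((BDV.zeros[Double](weights.size), 0.0))(seqOp))
          .reduce(combOp)
        println(s"gradientSum = $gradientSum, lossSum = $lossSum")
      }
    }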
Updater.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.optimization
 
 import scala.math._
 
-import breeze.linalg.{norm => brzNorm}
+import breeze.linalg.{norm => brzNorm, axpy => brzAxpy, Vector => BV}
 
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
 
@@ -70,7 +70,9 @@ class SimpleUpdater extends Updater {
       iter: Int,
       regParam: Double): (Vector, Double) = {
     val thisIterStepSize = stepSize / math.sqrt(iter)
-    val brzWeights = weightsOld.toBreeze - gradient.toBreeze * thisIterStepSize
+    val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
+    brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
+
     (Vectors.fromBreeze(brzWeights), 0)
   }
 }
@@ -102,7 +104,8 @@ class L1Updater extends Updater {
       regParam: Double): (Vector, Double) = {
     val thisIterStepSize = stepSize / math.sqrt(iter)
     // Take gradient step
-    val brzWeights = weightsOld.toBreeze - gradient.toBreeze * thisIterStepSize
+    val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
+    brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
     // Apply proximal operator (soft thresholding)
     val shrinkageVal = regParam * thisIterStepSize
     var i = 0
@@ -133,8 +136,9 @@ class SquaredL2Updater extends Updater {
     // w' = w - thisIterStepSize * (gradient + regParam * w)
     // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient
     val thisIterStepSize = stepSize / math.sqrt(iter)
-    val brzWeights = weightsOld.toBreeze * (1.0 - thisIterStepSize * regParam) -
-      (gradient.toBreeze * thisIterStepSize)
+    val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
+    brzWeights :*= (1.0 - thisIterStepSize * regParam)
+    brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
     val norm = brzNorm(brzWeights, 2.0)
 
     (Vectors.fromBreeze(brzWeights), 0.5 * regParam * norm * norm)
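
Put together, the updated SquaredL2Updater step scales the weights in place and then applies axpy, matching the commented formula w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient without allocating intermediate vectors. A minimal standalone sketch with made-up inputs and a made-up object name, assuming breeze on the classpath:

    import breeze.linalg.{DenseVector => BDV, axpy => brzAxpy, norm => brzNorm}

    object SquaredL2StepSketch {
      def main(args: Array[String]): Unit = {
        val stepSize = 1.0                       // made-up hyperparameters
        val regParam = 0.1
        val iter = 3
        val weightsOld = BDV(0.5, -0.25, 1.0)    // made-up current weights
        val gradient = BDV(0.2, 0.1, -0.4)       // made-up gradient

        val thisIterStepSize = stepSize / math.sqrt(iter)
        val brzWeights = weightsOld.copy                    // work on a copy; updates are in place
        brzWeights :*= (1.0 - thisIterStepSize * regParam)  // scale: (1 - step * regParam) * w
        brzAxpy(-thisIterStepSize, gradient, brzWeights)    // subtract step * gradient
        val norm = brzNorm(brzWeights, 2.0)

        println(s"new weights = $brzWeights, regularization term = ${0.5 * regParam * norm * norm}")
      }
    }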
