diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 82124703da6cd..3239af9d5df47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.optimization -import org.jblas.DoubleMatrix +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * Class used to compute the gradient for a loss function, given a single data point. @@ -26,17 +26,13 @@ abstract class Gradient extends Serializable { /** * Compute the gradient and loss given the features of a single data point. * - * @param data - Feature values for one data point. Column matrix of size dx1 - * where d is the number of features. - * @param label - Label for this data item. - * @param weights - Column matrix containing weights for every feature. - * - * @return A tuple of 2 elements. The first element is a column matrix containing the computed - * gradient and the second element is the loss computed at this data point. + * @param data features for one data point + * @param label label for this data point + * @param weights weights/coefficients corresponding to features * + * @return (gradient: Vector, loss: Double) */ - def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) + def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) } /** @@ -44,12 +40,12 @@ abstract class Gradient extends Serializable { * See also the documentation for the precise formulation. */ class LogisticGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { - val margin: Double = -1.0 * data.dot(weights) + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val brzData = data.toBreeze + val brzWeights = data.toBreeze + val margin: Double = -1.0 * brzWeights.dot(brzData) val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - - val gradient = data.mul(gradientMultiplier) + val gradient = brzData * gradientMultiplier val loss = if (label > 0) { math.log(1 + math.exp(margin)) @@ -57,7 +53,7 @@ class LogisticGradient extends Gradient { math.log(1 + math.exp(margin)) - margin } - (gradient, loss) + (Vectors.fromBreeze(gradient), loss) } } @@ -68,14 +64,14 @@ class LogisticGradient extends Gradient { * See also the documentation for the precise formulation. */ class LeastSquaresGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { - val diff: Double = data.dot(weights) - label - + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val brzData = data.toBreeze + val brzWeights = weights.toBreeze + val diff: Double = brzWeights.dot(brzData) - label val loss = diff * diff - val gradient = data.mul(2.0 * diff) + val gradient = brzData * (2.0 * diff) - (gradient, loss) + (Vectors.fromBreeze(gradient), loss) } } @@ -85,19 +81,19 @@ class LeastSquaresGradient extends Gradient { * NOTE: This assumes that the labels are {0,1} */ class HingeGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { - - val dotProduct = data.dot(weights) + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val brzData = data.toBreeze + val brzWeights = weights.toBreeze + val dotProduct = brzWeights.dot(brzData) // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 if (1.0 > labelScaled * dotProduct) { - (data.mul(-labelScaled), 1.0 - labelScaled * dotProduct) + (Vectors.fromBreeze(brzData * (-labelScaled)), 1.0 - labelScaled * dotProduct) } else { - (DoubleMatrix.zeros(1, weights.length), 0.0) + (Vectors.dense(new Array[Double](weights.size)), 0.0) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index b967b22e818d3..e5555cc7f73e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -20,10 +20,10 @@ package org.apache.spark.mllib.optimization import org.apache.spark.Logging import org.apache.spark.rdd.RDD -import org.jblas.DoubleMatrix - import scala.collection.mutable.ArrayBuffer +import org.apache.spark.mllib.linalg.Vector + /** * Class used to solve an optimization problem using Gradient Descent. * @param gradient Gradient function to be used. @@ -91,8 +91,7 @@ class GradientDescent(var gradient: Gradient, var updater: Updater) this } - def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]) - : Array[Double] = { + def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = { val (weights, stochasticLossHistory) = GradientDescent.runMiniBatchSGD( data, @@ -133,14 +132,14 @@ object GradientDescent extends Logging { * stochastic loss computed for every iteration. */ def runMiniBatchSGD( - data: RDD[(Double, Array[Double])], + data: RDD[(Double, Vector)], gradient: Gradient, updater: Updater, stepSize: Double, numIterations: Int, regParam: Double, miniBatchFraction: Double, - initialWeights: Array[Double]) : (Array[Double], Array[Double]) = { + initialWeights: Vector): (Vector, Vector) = { val stochasticLossHistory = new ArrayBuffer[Double](numIterations) @@ -148,7 +147,7 @@ object GradientDescent extends Logging { val miniBatchSize = nexamples * miniBatchFraction // Initialize weights as a column vector - var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*) + var weights = initialWeights.toBreeze.toDenseVector /** * For the first iteration, the regVal will be initialized as sum of sqrt of diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala index 94d30b56f212b..a62aecae5dd0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala @@ -19,11 +19,12 @@ package org.apache.spark.mllib.optimization import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector + trait Optimizer { /** * Solve the provided convex optimization problem. */ - def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]): Array[Double] - + def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index bf8f731459e99..6070071c5c18b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -20,6 +20,10 @@ package org.apache.spark.mllib.optimization import scala.math._ import org.jblas.DoubleMatrix +import breeze.linalg.{norm => brzNorm} + +import org.apache.spark.mllib.linalg.{Vectors, Vector} + /** * Class used to perform steps (weight update) using Gradient Descent methods. * @@ -47,8 +51,12 @@ abstract class Updater extends Serializable { * @return A tuple of 2 elements. The first element is a column matrix containing updated weights, * and the second element is the regularization value computed using updated weights. */ - def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int, - regParam: Double): (DoubleMatrix, Double) + def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) } /** @@ -56,11 +64,11 @@ abstract class Updater extends Serializable { * Uses a step-size decreasing with the square root of the number of iterations. */ class SimpleUpdater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { + override def compute(weightsOld: Vector, gradient: Vector, + stepSize: Double, iter: Int, regParam: Double): (Vector, Double) = { val thisIterStepSize = stepSize / math.sqrt(iter) - val step = gradient.mul(thisIterStepSize) - (weightsOld.sub(step), 0) + val brzWeights = weightsOld.toBreeze - gradient.toBreeze * thisIterStepSize + (Vectors.fromBreeze(brzWeights), 0) } } @@ -83,19 +91,23 @@ class SimpleUpdater extends Updater { * Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal) */ class L1Updater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { + override def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) = { val thisIterStepSize = stepSize / math.sqrt(iter) - val step = gradient.mul(thisIterStepSize) // Take gradient step - val newWeights = weightsOld.sub(step) + val brzWeights = weightsOld.toBreeze - gradient.toBreeze * thisIterStepSize // Apply proximal operator (soft thresholding) val shrinkageVal = regParam * thisIterStepSize - (0 until newWeights.length).foreach { i => - val wi = newWeights.get(i) - newWeights.put(i, signum(wi) * max(0.0, abs(wi) - shrinkageVal)) + (0 until brzWeights.length).foreach { i => + val wi = brzWeights(i) + brzWeights(i) = signum(wi) * max(0.0, abs(wi) - shrinkageVal) } - (newWeights, newWeights.norm1 * regParam) + + (Vectors.fromBreeze(brzWeights), brzNorm(brzWeights, 1.0) * regParam) } } @@ -105,16 +117,22 @@ class L1Updater extends Updater { * Uses a step-size decreasing with the square root of the number of iterations. */ class SquaredL2Updater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { - val thisIterStepSize = stepSize / math.sqrt(iter) - val step = gradient.mul(thisIterStepSize) + override def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (DoubleMatrix, Double) = { // add up both updates from the gradient of the loss (= step) as well as // the gradient of the regularizer (= regParam * weightsOld) // w' = w - thisIterStepSize * (gradient + regParam * w) // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient - val newWeights = weightsOld.mul(1.0 - thisIterStepSize * regParam).sub(step) - (newWeights, 0.5 * pow(newWeights.norm2, 2.0) * regParam) + val thisIterStepSize = stepSize / math.sqrt(iter) + val brzWeights = weightsOld.toBreeze * (1.0 - thisIterStepSize * regParam) - + (gradient.toBreeze * thisIterStepSize) + val norm = brzNorm(brzWeights, 2.0) + + (Vectors.fromBreeze(newWeights), 0.5 * regParam * norm * norm) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index b9621530efa22..e4e710a726308 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,11 +17,12 @@ package org.apache.spark.mllib.regression +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} + import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ - -import org.jblas.DoubleMatrix +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * GeneralizedLinearModel (GLM) represents a model trained using @@ -31,12 +32,9 @@ import org.jblas.DoubleMatrix * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. */ -abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: Double) +abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double) extends Serializable { - // Create a column vector that can be used for predictions - private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*) - /** * Predict the result given a data point and the weights learned. * @@ -44,8 +42,7 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: * @param weightMatrix Column vector containing the weights of the model * @param intercept Intercept of the model. */ - def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double): Double + def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double /** * Predict values for the given data set using the model trained. @@ -53,16 +50,13 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction */ - def predict(testData: RDD[Array[Double]]): RDD[Double] = { + def predict(testData: RDD[Vector]): RDD[Double] = { // A small optimization to avoid serializing the entire model. Only the weightsMatrix // and intercept is needed. - val localWeights = weightsMatrix + val localWeights = weights val localIntercept = intercept - testData.map { x => - val dataMatrix = new DoubleMatrix(1, x.length, x:_*) - predictPoint(dataMatrix, localWeights, localIntercept) - } + testData.map(v => predictPoint(v, localWeights, localIntercept)) } /** @@ -71,9 +65,8 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: * @param testData array representing a single data point * @return Double prediction from the trained model */ - def predict(testData: Array[Double]): Double = { - val dataMat = new DoubleMatrix(1, testData.length, testData:_*) - predictPoint(dataMat, weightsMatrix, intercept) + def predict(testData: Vector): Double = { + predictPoint(testData, weights, intercept) } } @@ -95,7 +88,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Create a model given the weights and intercept */ - protected def createModel(weights: Array[Double], intercept: Double): M + protected def createModel(weights: Vector, intercept: Double): M /** * Set if the algorithm should add an intercept. Default true. @@ -117,17 +110,26 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. */ - def run(input: RDD[LabeledPoint]) : M = { - val nfeatures: Int = input.first().features.length - val initialWeights = new Array[Double](nfeatures) + def run(input: RDD[LabeledPoint]): M = { + val numFeatures: Int = input.first().features.size + val initialWeights = Vectors.dense(new Array[Double](numFeatures)) run(input, initialWeights) } + private def prependOne(vector: Vector): Vector = { + val vectorWithIntercept = vector match { + case dv: BDV[Double] => BDV.vertcat(BDV.ones(1), dv) + case sv: BSV[Double] => BSV.vertcat(new BSV[Double](Array(0), Array(1.0), 1), sv) + case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + Vectors.fromBreeze(vectorWithIntercept) + } + /** * Run the algorithm with the configured parameters on an input RDD * of LabeledPoint entries starting from the initial weights provided. */ - def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = { + def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { // Check the data properties before running the optimizer if (validateData && !validators.forall(func => func(input))) { @@ -136,25 +138,24 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] // Prepend an extra variable consisting of all 1.0's for the intercept. val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0))) + input.map(labeledPoint => (labeledPoint.label, prependOne(labeledPoint.features))) } else { input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) } val initialWeightsWithIntercept = if (addIntercept) { - initialWeights.+:(1.0) + prependOne(initialWeights) } else { initialWeights } - val weights = optimizer.optimize(data, initialWeightsWithIntercept) - val intercept = weights(0) - val weightsScaled = weights.tail + val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) + val brzWeightsWithIntercept = weightsWithIntercept.toBreeze + val intercept = if (addIntercept) brzWeightsWithIntercept(0) else 0.0 + val brzWeights = if (addIntercept) brzWeightsWithIntercept(1 to -1) else brzWeightsWithIntercept - val model = createModel(weightsScaled, intercept) + val model = createModel(Vectors.fromBreeze(brzWeights), intercept) - logInfo("Final model weights " + model.weights.mkString(",")) - logInfo("Final model intercept " + model.intercept) model } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 1a18292fe3f3b..3deab1ab785b9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -17,14 +17,16 @@ package org.apache.spark.mllib.regression +import org.apache.spark.mllib.linalg.Vector + /** * Class that represents the features and labels of a data point. * * @param label Label for this data point. * @param features List of features for this data point. */ -case class LabeledPoint(label: Double, features: Array[Double]) { +case class LabeledPoint(label: Double, features: Vector) { override def toString: String = { - "LabeledPoint(%s, %s)".format(label, features.mkString("[", ", ", "]")) + "LabeledPoint(%s, %s)".format(label, features) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index c504d3d40c773..b8ce4602b53ef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -21,8 +21,7 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.util.MLUtils - -import org.jblas.DoubleMatrix +import org.apache.spark.mllib.linalg.Vector /** * Regression model trained using RidgeRegression. @@ -31,7 +30,7 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class RidgeRegressionModel( - override val weights: Array[Double], + override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {