Skip to content

Commit

Permalink
Streaming linear regression
Browse files Browse the repository at this point in the history
- Abstract class to support a variety of streaming regression analyses
- Example concrete class for streaming linear regression
- Example usage: continually train on one data stream and test on
another
  • Loading branch information
freeman-lab committed Jul 10, 2014
1 parent 604f4d7 commit c4b1143
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.mllib

import org.apache.spark.SparkConf
import org.apache.spark.mllib.util.MLStreamingUtils
import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD
import org.apache.spark.streaming.{Seconds, StreamingContext}

/**
* Continually update a model on one stream of data using streaming linear regression,
* while making predictions on another stream of data
*
*/
object StreamingLinearRegression {

def main(args: Array[String]) {

if (args.length != 4) {
System.err.println("Usage: StreamingLinearRegression <trainingData> <testData> <batchDuration> <numFeatures>")
System.exit(1)
}

val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression")
val ssc = new StreamingContext(conf, Seconds(args(2).toLong))

val trainingData = MLStreamingUtils.loadLabeledPointsFromText(ssc, args(0))
val testData = MLStreamingUtils.loadLabeledPointsFromText(ssc, args(1))

val model = StreamingLinearRegressionWithSGD.start(args(3).toInt)

model.trainOn(trainingData)
model.predictOn(testData).print()

ssc.start()
ssc.awaitTermination()

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.regression

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.annotation.Experimental

/**
* Train or predict a linear regression model on streaming data. Training uses
* Stochastic Gradient Descent to update the model based on each new batch of
* incoming data from a DStream (see LinearRegressionWithSGD for model equation)
*
* Each batch of data is assumed to be an RDD of LabeledPoints.
* The number of data points per batch can vary, but the number
* of features must be constant.
*/
@Experimental
class StreamingLinearRegressionWithSGD private (
private var stepSize: Double,
private var numIterations: Int,
private var miniBatchFraction: Double,
private var numFeatures: Int)
extends StreamingRegression[LinearRegressionModel, LinearRegressionWithSGD] with Serializable {

val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction)

var model = algorithm.createModel(Vectors.dense(new Array[Double](numFeatures)), 0.0)

}

/**
* Top-level methods for calling StreamingLinearRegressionWithSGD.
*/
@Experimental
object StreamingLinearRegressionWithSGD {

/**
* Start a streaming Linear Regression model by setting optimization parameters.
*
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param miniBatchFraction Fraction of data to be used per iteration.
* @param numFeatures Number of features per record, must be constant for all batches of data.
*/
def start(
stepSize: Double,
numIterations: Int,
miniBatchFraction: Double,
numFeatures: Int): StreamingLinearRegressionWithSGD = {
new StreamingLinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction, numFeatures)
}

/**
* Start a streaming Linear Regression model by setting optimization parameters.
*
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param numFeatures Number of features per record, must be constant for all batches of data.
*/
def start(
numIterations: Int,
stepSize: Double,
numFeatures: Int): StreamingLinearRegressionWithSGD = {
start(stepSize, numIterations, 1.0, numFeatures)
}

/**
* Start a streaming Linear Regression model by setting optimization parameters.
*
* @param numIterations Number of iterations of gradient descent to run.
* @param numFeatures Number of features per record, must be constant for all batches of data.
*/
def start(
numIterations: Int,
numFeatures: Int): StreamingLinearRegressionWithSGD = {
start(0.1, numIterations, 1.0, numFeatures)
}

/**
* Start a streaming Linear Regression model by setting optimization parameters.
*
* @param numFeatures Number of features per record, must be constant for all batches of data.
*/
def start(
numFeatures: Int): StreamingLinearRegressionWithSGD = {
start(0.1, 100, 1.0, numFeatures)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.regression

import org.apache.spark.Logging
import org.apache.spark.annotation.{Experimental, DeveloperApi}
import org.apache.spark.streaming.dstream.DStream

/**
* :: DeveloperApi ::
* StreamingRegression implements methods for training
* a linear regression model on streaming data, and using it
* for prediction on streaming data.
*
* This class takes as type parameters a GeneralizedLinearModel,
* and a GeneralizedLinearAlgorithm, making it easy to extend to construct
* streaming versions of arbitrary regression analyses. For example usage,
* see StreamingLinearRegressionWithSGD.
*
*/
@DeveloperApi
@Experimental
abstract class StreamingRegression[M <: GeneralizedLinearModel, A <: GeneralizedLinearAlgorithm[M]] extends Logging {

/** The model to be updated and used for prediction. */
var model: M

/** The algorithm to use for updating. */
val algorithm: A

/** Log the latest model parameters and return the model. */
def latest(): M = {
logInfo("Latest model: weights, %s".format(model.weights.toString))
logInfo("Latest model: intercept, %s".format(model.intercept.toString))
model
}

/**
* Update the model by training on batches of data from a DStream.
* This operation registers a DStream for training the model,
* and updates the model based on every subsequent non-empty
* batch of data from the stream.
*
* @param data DStream containing labeled data
*/
def trainOn(data: DStream[LabeledPoint]) {
data.foreachRDD{
rdd =>
if (rdd.count() > 0) {
model = algorithm.run(rdd, model.weights)
logInfo("Model updated")
}
this.latest()
}
}

/**
* Use the model to make predictions on batches of data from a DStream
*
* @param data DStream containing labeled data
* @return DStream containing predictions
*/
def predictOn(data: DStream[LabeledPoint]): DStream[Double] = {
data.map(x => model.predict(x.features))
}

}

0 comments on commit c4b1143

Please sign in to comment.