Commit

refactor
Davies Liu committed Nov 21, 2014
1 parent d572f00 commit c2bdfc2
Showing 6 changed files with 39 additions and 83 deletions.
mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -74,12 +74,26 @@ class PythonMLLibAPI extends Serializable {
learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
data: JavaRDD[LabeledPoint],
initialWeights: Vector): JList[Object] = {
- // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
- learner.disableUncachedWarning()
- val model = learner.run(data.rdd, initialWeights)
+ val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER), initialWeights)
List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
}

+ /**
+  * Return the Updater from string
+  */
+ def getUpdateFromString(regType: String): Updater = {
+   if (regType == "l2") {
+     new SquaredL2Updater
+   } else if (regType == "l1") {
+     new L1Updater
+   } else if (regType == null || regType == "none") {
+     new SimpleUpdater
+   } else {
+     throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+       + " Can only be initialized using the following string values: ['l1', 'l2', None].")
+   }
+ }
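
Note: this helper replaces the duplicated regType if/else chains removed from the trainers below. As a minimal standalone sketch of the same mapping, written as an idiomatic pattern match (the name updaterFor is illustrative, not the committed code):

import org.apache.spark.mllib.optimization.{L1Updater, SimpleUpdater, SquaredL2Updater, Updater}

// Same mapping as getUpdateFromString above, as a pattern match.
def updaterFor(regType: String): Updater = regType match {
  case "l2"          => new SquaredL2Updater  // squared-L2 (ridge) penalty
  case "l1"          => new L1Updater         // L1 (lasso) penalty
  case null | "none" => new SimpleUpdater     // no regularization
  case other => throw new IllegalArgumentException(
    s"Invalid value for 'regType' parameter: $other")
}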

/**
* Java stub for Python mllib LinearRegressionWithSGD.train()
*/
@@ -99,16 +113,6 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
- if (regType == "l2") {
-   lrAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
-   lrAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType == null) {
-   lrAlg.optimizer.setUpdater(new SimpleUpdater)
- } else {
-   throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-     + " Can only be initialized using the following string values: ['l1', 'l2', None].")
- }
trainRegressionModel(
lrAlg,
data,
@@ -178,16 +182,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
- if (regType == "l2") {
-   SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
-   SVMAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType == null) {
-   SVMAlg.optimizer.setUpdater(new SimpleUpdater)
- } else {
-   throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-     + " Can only be initialized using the following string values: ['l1', 'l2', None].")
- }
+ SVMAlg.optimizer.setUpdater(getUpdateFromString(regType))
trainRegressionModel(
SVMAlg,
data,
@@ -213,16 +208,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
- if (regType == "l2") {
-   LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
-   LogRegAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType == null) {
-   LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
- } else {
-   throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-     + " Can only be initialized using the following string values: ['l1', 'l2', None].")
- }
+ LogRegAlg.optimizer.setUpdater(getUpdateFromString(regType))
trainRegressionModel(
LogRegAlg,
data,
@@ -248,16 +234,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setNumCorrections(corrections)
.setConvergenceTol(tolerance)
- if (regType == "l2") {
-   LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
-   LogRegAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType == null) {
-   LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
- } else {
-   throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-     + " Can only be initialized using the following string values: ['l1', 'l2', None].")
- }
+ LogRegAlg.optimizer.setUpdater(getUpdateFromString(regType))
trainRegressionModel(
LogRegAlg,
data,
@@ -289,9 +266,7 @@ class PythonMLLibAPI extends Serializable {
.setMaxIterations(maxIterations)
.setRuns(runs)
.setInitializationMode(initializationMode)
-   // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
-   .disableUncachedWarning()
- kMeansAlg.run(data.rdd)
+ kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
}
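
Note: trainKMeansModel now persists the deserialized input once, in serialized form, before the iterative run; StorageLevel.MEMORY_AND_DISK_SER keeps the cached copy compact and spills to disk instead of recomputing lineage. A self-contained caller-side sketch of the same pattern (a standalone example, not part of this commit):

import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.storage.StorageLevel

object CachedKMeansExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "cached-kmeans")
    // Persist before training: k-means rescans the input every iteration,
    // so an uncached RDD would be recomputed from its lineage each time.
    val points = sc.parallelize(Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(0.5, 0.2),
      Vectors.dense(9.0, 8.0), Vectors.dense(8.5, 9.1)
    )).persist(StorageLevel.MEMORY_AND_DISK_SER)
    val model = KMeans.train(points, 2, 10)  // k = 2, maxIterations = 10
    model.clusterCenters.foreach(println)
    sc.stop()
  }
}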

/**
@@ -333,7 +308,7 @@ class PythonMLLibAPI extends Serializable {

if (seed != null) als.setSeed(seed)

- val model = als.run(ratingsJRDD.rdd)
+ val model = als.run(ratingsJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
new MatrixFactorizationModelWrapper(model)
}

@@ -364,7 +339,7 @@ class PythonMLLibAPI extends Serializable {

if (seed != null) als.setSeed(seed)

- val model = als.run(ratingsJRDD.rdd)
+ val model = als.run(ratingsJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
new MatrixFactorizationModelWrapper(model)
}

@@ -495,8 +470,8 @@ class PythonMLLibAPI extends Serializable {
categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap,
minInstancesPerNode = minInstancesPerNode,
minInfoGain = minInfoGain)

- DecisionTree.train(data.rdd, strategy)
+ val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)
+ DecisionTree.train(cached, strategy)
}

/**
@@ -526,10 +501,11 @@ class PythonMLLibAPI extends Serializable {
numClassesForClassification = numClasses,
maxBins = maxBins,
categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap)
+ val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)
if (algo == Algo.Classification) {
-   RandomForest.trainClassifier(data.rdd, strategy, numTrees, featureSubsetStrategy, seed)
+   RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, seed)
} else {
-   RandomForest.trainRegressor(data.rdd, strategy, numTrees, featureSubsetStrategy, seed)
+   RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, seed)
}
}
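
Note: the classification trainers above now wire the updater chosen by getUpdateFromString into their optimizers. A minimal sketch of what the "l2" path ends up configuring, assuming the public no-arg constructor and public optimizer field that Spark 1.x exposed at the time of this commit:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.SquaredL2Updater
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.storage.StorageLevel

object RegTypeExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "regtype")
    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0))
    )).persist(StorageLevel.MEMORY_AND_DISK_SER)

    val alg = new LogisticRegressionWithSGD()
    // regType "l2" resolves to SquaredL2Updater via the new helper.
    alg.optimizer.setUpdater(new SquaredL2Updater)
    alg.optimizer.setRegParam(0.01)
    val model = alg.run(data)
    println(s"weights: ${model.weights}, intercept: ${model.intercept}")
    sc.stop()
  }
}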

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -113,22 +113,13 @@ class KMeans private (
this
}

- /** Whether a warning should be logged if the input RDD is uncached. */
- private var warnOnUncachedInput = true
-
- /** Disable warnings about uncached input. */
- private[spark] def disableUncachedWarning(): this.type = {
-   warnOnUncachedInput = false
-   this
- }

/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
*/
def run(data: RDD[Vector]): KMeansModel = {

- if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+ if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
@@ -143,7 +134,7 @@ class KMeans private (
norms.unpersist()

// Warn at the end of the run as well, for increased visibility.
- if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+ if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data was not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
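
Note: with disableUncachedWarning() removed, this storage-level check runs unconditionally; the Python path no longer needs the escape hatch because PythonMLLibAPI persists the RDD before calling run(). Reduced to a standalone predicate (hypothetical helper name, shown only to make the guard explicit):

import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// True when the RDD itself is cached at some storage level (its parents
// may still be uncached, which is what the log message warns about).
def isDirectlyCached(data: RDD[_]): Boolean =
  data.getStorageLevel != StorageLevel.NONE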
mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -136,15 +136,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
this
}

- /** Whether a warning should be logged if the input RDD is uncached. */
- private var warnOnUncachedInput = true
-
- /** Disable warnings about uncached input. */
- private[spark] def disableUncachedWarning(): this.type = {
-   warnOnUncachedInput = false
-   this
- }

/**
* Run the algorithm with the configured parameters on an input
* RDD of LabeledPoint entries.
@@ -161,7 +152,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/
def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {

- if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+ if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
@@ -241,7 +232,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
}

// Warn at the end of the run as well, for increased visibility.
- if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+ if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data was not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
python/pyspark/mllib/clustering.py (7 changes: 3 additions & 4 deletions)
@@ -16,7 +16,7 @@
#

from pyspark import SparkContext
- from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _to_java_object_rdd
+ from pyspark.mllib.common import callMLlibFunc, callJavaFunc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector

__all__ = ['KMeansModel', 'KMeans']
@@ -80,9 +80,8 @@ class KMeans(object):
@classmethod
def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
"""Train a k-means clustering model."""
- jrdd = _to_java_object_rdd(rdd.map(_convert_to_vector))
- model = callMLlibFunc("trainKMeansModel", jrdd.cache(), k, maxIterations, runs,
-                       initializationMode)
+ model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
+                       runs, initializationMode)
centers = callJavaFunc(rdd.context, model.clusterCenters)
return KMeansModel([c.toArray() for c in centers])

python/pyspark/mllib/recommendation.py (4 changes: 2 additions & 2 deletions)
@@ -19,7 +19,7 @@

from pyspark import SparkContext
from pyspark.rdd import RDD
- from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd
+ from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc

__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']

@@ -110,7 +110,7 @@ def _prepare(cls, ratings):
ratings = ratings.map(lambda x: Rating(*x))
else:
raise ValueError("rating should be RDD of Rating or tuple/list")
- return _to_java_object_rdd(ratings, True)
+ return ratings

@classmethod
def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False,
python/pyspark/mllib/regression.py (5 changes: 2 additions & 3 deletions)
@@ -18,7 +18,7 @@
import numpy as np
from numpy import array

- from pyspark.mllib.common import callMLlibFunc, _to_java_object_rdd
+ from pyspark.mllib.common import callMLlibFunc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector

__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel',
@@ -129,8 +129,7 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
if not isinstance(first, LabeledPoint):
raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
initial_weights = initial_weights or [0.0] * len(data.first().features)
- weights, intercept = train_func(_to_java_object_rdd(data).cache(),
-                                 _convert_to_vector(initial_weights))
+ weights, intercept = train_func(data, _convert_to_vector(initial_weights))
return modelClass(weights, intercept)


