[SPARK-4531] [MLlib] cache serialized java object #3397

Status: Closed · wants to merge 8 commits
mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -74,10 +74,28 @@ class PythonMLLibAPI extends Serializable {
learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
data: JavaRDD[LabeledPoint],
initialWeights: Vector): JList[Object] = {
-    // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
-    learner.disableUncachedWarning()
-    val model = learner.run(data.rdd, initialWeights)
-    List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+    try {
+      val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights)
+      List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+    } finally {
+      data.rdd.unpersist(blocking = false)
+    }
}
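The pattern above, persisting the input before an iterative algorithm runs and unpersisting it in a finally block so cleanup happens even when training throws, is the core of this PR. A minimal self-contained sketch of the pattern (the `sumTrain` stand-in and the local SparkContext are illustrative, not part of the PR):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

object PersistPatternSketch {
  // Stand-in for an iterative MLlib solver that makes several passes over `data`.
  def sumTrain(data: RDD[Double]): Double = data.reduce(_ + _)

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "persist-pattern-sketch")
    val input = sc.parallelize(1 to 1000).map(_.toDouble)
    input.persist(StorageLevel.MEMORY_AND_DISK) // cache once, reuse across passes
    try {
      println(sumTrain(input))
    } finally {
      input.unpersist(blocking = false) // runs even if training throws
    }
    sc.stop()
  }
}
```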

+  /**
+   * Return the Updater from string
+   */
+  def getUpdateFromString(regType: String): Updater = {
Review comment (Member): Update --> Updater

+    if (regType == "l2") {
+      new SquaredL2Updater
+    } else if (regType == "l1") {
+      new L1Updater
+    } else if (regType == null || regType == "none") {
+      new SimpleUpdater
+    } else {
+      throw new IllegalArgumentException("Invalid value for 'regType' parameter."
+        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
+    }
+  }
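The new `getUpdateFromString` helper centralizes the regType-to-Updater mapping that was previously duplicated in four identical if/else chains in the training methods below. A REPL-style restatement of its contract using pattern matching (illustrative, not the PR's code; the name `updaterFor` is hypothetical):

```scala
import org.apache.spark.mllib.optimization.{L1Updater, SimpleUpdater, SquaredL2Updater, Updater}

// Same mapping as getUpdateFromString above, restated as a match (illustrative).
def updaterFor(regType: String): Updater = regType match {
  case "l2"          => new SquaredL2Updater
  case "l1"          => new L1Updater
  case null | "none" => new SimpleUpdater
  case _ => throw new IllegalArgumentException("Invalid value for 'regType' parameter."
    + " Can only be initialized using the following string values: ['l1', 'l2', None].")
}

assert(updaterFor("l1").isInstanceOf[L1Updater])
assert(updaterFor(null).isInstanceOf[SimpleUpdater])
```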

/**
@@ -99,16 +117,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
-    if (regType == "l2") {
-      lrAlg.optimizer.setUpdater(new SquaredL2Updater)
-    } else if (regType == "l1") {
-      lrAlg.optimizer.setUpdater(new L1Updater)
-    } else if (regType == null) {
-      lrAlg.optimizer.setUpdater(new SimpleUpdater)
-    } else {
-      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
-    }
+    lrAlg.optimizer.setUpdater(getUpdateFromString(regType))
trainRegressionModel(
lrAlg,
data,
@@ -178,16 +187,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
-    if (regType == "l2") {
-      SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
-    } else if (regType == "l1") {
-      SVMAlg.optimizer.setUpdater(new L1Updater)
-    } else if (regType == null) {
-      SVMAlg.optimizer.setUpdater(new SimpleUpdater)
-    } else {
-      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
-    }
+    SVMAlg.optimizer.setUpdater(getUpdateFromString(regType))
trainRegressionModel(
SVMAlg,
data,
@@ -213,16 +213,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
-    if (regType == "l2") {
-      LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
-    } else if (regType == "l1") {
-      LogRegAlg.optimizer.setUpdater(new L1Updater)
-    } else if (regType == null) {
-      LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
-    } else {
-      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
-    }
+    LogRegAlg.optimizer.setUpdater(getUpdateFromString(regType))
trainRegressionModel(
LogRegAlg,
data,
@@ -248,16 +239,7 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setNumCorrections(corrections)
.setConvergenceTol(tolerance)
-    if (regType == "l2") {
-      LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
-    } else if (regType == "l1") {
-      LogRegAlg.optimizer.setUpdater(new L1Updater)
-    } else if (regType == null) {
-      LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
-    } else {
-      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
-        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
-    }
+    LogRegAlg.optimizer.setUpdater(getUpdateFromString(regType))
trainRegressionModel(
LogRegAlg,
data,
@@ -289,9 +271,11 @@ class PythonMLLibAPI extends Serializable {
.setMaxIterations(maxIterations)
.setRuns(runs)
.setInitializationMode(initializationMode)
-      // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
-      .disableUncachedWarning()
-    kMeansAlg.run(data.rdd)
+    try {
+      kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
+    } finally {
+      data.rdd.unpersist(blocking = false)
+    }
}

/**
@@ -425,16 +409,18 @@ class PythonMLLibAPI extends Serializable {
numPartitions: Int,
numIterations: Int,
seed: Long): Word2VecModelWrapper = {
-    val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)
val word2vec = new Word2Vec()
.setVectorSize(vectorSize)
.setLearningRate(learningRate)
.setNumPartitions(numPartitions)
.setNumIterations(numIterations)
.setSeed(seed)
-    val model = word2vec.fit(data)
-    data.unpersist()
-    new Word2VecModelWrapper(model)
+    try {
+      val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
+      new Word2VecModelWrapper(model)
+    } finally {
+      dataJRDD.rdd.unpersist(blocking = false)
+    }
}
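A side note on storage levels: Word2Vec keeps `MEMORY_AND_DISK_SER` while the other call sites use `MEMORY_AND_DISK`. The `_SER` variants store serialized bytes, trading extra CPU on each access for a smaller memory footprint. A REPL-style check of the distinction (illustrative):

```scala
import org.apache.spark.storage.StorageLevel

// MEMORY_AND_DISK keeps deserialized objects; MEMORY_AND_DISK_SER keeps
// serialized bytes (smaller in memory, deserialized on each access).
assert(StorageLevel.MEMORY_AND_DISK.deserialized)
assert(!StorageLevel.MEMORY_AND_DISK_SER.deserialized)
```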

private[python] class Word2VecModelWrapper(model: Word2VecModel) {
@@ -495,8 +481,11 @@ class PythonMLLibAPI extends Serializable {
categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap,
minInstancesPerNode = minInstancesPerNode,
minInfoGain = minInfoGain)

-    DecisionTree.train(data.rdd, strategy)
+    try {
+      DecisionTree.train(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), strategy)
+    } finally {
+      data.rdd.unpersist(blocking = false)
+    }
}

/**
@@ -526,10 +515,15 @@ class PythonMLLibAPI extends Serializable {
numClassesForClassification = numClasses,
maxBins = maxBins,
categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap)
-    if (algo == Algo.Classification) {
-      RandomForest.trainClassifier(data.rdd, strategy, numTrees, featureSubsetStrategy, seed)
-    } else {
-      RandomForest.trainRegressor(data.rdd, strategy, numTrees, featureSubsetStrategy, seed)
+    val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
+    try {
Review comment (Member): Why have the try-finally block? (Or did you mean for the caching to happen inside the block?)

Reply (Contributor, Author): There is a return value here; with try-finally it is easy to insert logic between the final statement and the return of the value. Otherwise it would be:

val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
val model = if (algo == Algo.Classification) {
  RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, seed)
} else {
  RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, seed)
}
cached.unpersist(blocking = false)
model
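(One further observation, not from the thread: the try-finally version also unpersists when training throws, while the alternative above would leave `cached` persisted if either train call failed.)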

+      if (algo == Algo.Classification) {
+        RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, seed)
+      } else {
+        RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, seed)
+      }
+    } finally {
+      cached.unpersist(blocking = false)
+    }
}

@@ -711,7 +705,7 @@ private[spark] object SerDe extends Serializable {
def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
if (obj == this) {
out.write(Opcodes.GLOBAL)
-        out.write((module + "\n" + name + "\n").getBytes())
+        out.write((module + "\n" + name + "\n").getBytes)
} else {
pickler.save(this) // it will be memorized by Pickler
saveState(obj, out, pickler)
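A small cleanup rather than a behavior change: `getBytes` and `getBytes()` both use the platform default charset, and dropping the empty parentheses follows the Scala convention for side-effect-free accessors.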
mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -113,22 +113,13 @@ class KMeans private (
this
}

-  /** Whether a warning should be logged if the input RDD is uncached. */
-  private var warnOnUncachedInput = true
-
-  /** Disable warnings about uncached input. */
-  private[spark] def disableUncachedWarning(): this.type = {
-    warnOnUncachedInput = false
-    this
-  }

/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
*/
def run(data: RDD[Vector]): KMeansModel = {

-    if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+    if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
@@ -143,7 +134,7 @@ class KMeans private (
norms.unpersist()

// Warn at the end of the run as well, for increased visibility.
-    if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+    if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data was not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
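With `disableUncachedWarning()` gone, `run` always inspects the input's storage level; the Python wrappers above now avoid the warning by persisting before they call in. A spark-shell style illustration of the check itself (assumes the shell's `sc`):

```scala
import org.apache.spark.storage.StorageLevel

val data = sc.parallelize(1 to 10)
assert(data.getStorageLevel == StorageLevel.NONE)            // would log the warning
data.persist(StorageLevel.MEMORY_AND_DISK)
assert(data.getStorageLevel == StorageLevel.MEMORY_AND_DISK) // would not
```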
mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -136,15 +136,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
this
}

-  /** Whether a warning should be logged if the input RDD is uncached. */
-  private var warnOnUncachedInput = true
-
-  /** Disable warnings about uncached input. */
-  private[spark] def disableUncachedWarning(): this.type = {
-    warnOnUncachedInput = false
-    this
-  }

/**
* Run the algorithm with the configured parameters on an input
* RDD of LabeledPoint entries.
@@ -161,7 +152,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/
def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {

-    if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+    if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
@@ -241,7 +232,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
}

// Warn at the end of the run as well, for increased visibility.
-    if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+    if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data was not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
8 changes: 3 additions & 5 deletions python/pyspark/mllib/clustering.py
@@ -16,7 +16,7 @@
#

from pyspark import SparkContext
-from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, callJavaFunc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector

__all__ = ['KMeansModel', 'KMeans']
@@ -80,10 +80,8 @@ class KMeans(object):
@classmethod
def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
"""Train a k-means clustering model."""
-        # cache serialized data to avoid objects over head in JVM
-        jcached = _to_java_object_rdd(rdd.map(_convert_to_vector), cache=True)
-        model = callMLlibFunc("trainKMeansModel", jcached, k, maxIterations, runs,
-                              initializationMode)
model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
runs, initializationMode)
centers = callJavaFunc(rdd.context, model.clusterCenters)
return KMeansModel([c.toArray() for c in centers])
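With this change the Python wrapper no longer pre-caches the converted Java RDD; per the PR title, caching of the serialized Java objects now happens exactly once, on the Scala side, via the `persist`/`unpersist` pairs shown above.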

4 changes: 1 addition & 3 deletions python/pyspark/mllib/common.py
@@ -54,15 +54,13 @@ def _new_smart_decode(obj):


# this will call the MLlib version of pythonToJava()
-def _to_java_object_rdd(rdd, cache=False):
+def _to_java_object_rdd(rdd):
""" Return an JavaRDD of Object by unpickling

It will convert each Python object into Java object by Pyrolite, whenever the
RDD is serialized in batch or not.
"""
rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
-    if cache:
-        rdd.cache()
return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True)


4 changes: 2 additions & 2 deletions python/pyspark/mllib/recommendation.py
@@ -19,7 +19,7 @@

from pyspark import SparkContext
from pyspark.rdd import RDD
-from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc

__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']

@@ -110,7 +110,7 @@ def _prepare(cls, ratings):
ratings = ratings.map(lambda x: Rating(*x))
else:
raise ValueError("rating should be RDD of Rating or tuple/list")
-        return _to_java_object_rdd(ratings, True)
+        return ratings

@classmethod
def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False,
5 changes: 2 additions & 3 deletions python/pyspark/mllib/regression.py
@@ -18,7 +18,7 @@
import numpy as np
from numpy import array

-from pyspark.mllib.common import callMLlibFunc, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector

__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel',
@@ -129,8 +129,7 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
if not isinstance(first, LabeledPoint):
raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
initial_weights = initial_weights or [0.0] * len(data.first().features)
-    weights, intercept = train_func(_to_java_object_rdd(data, cache=True),
-                                    _convert_to_vector(initial_weights))
+    weights, intercept = train_func(data, _convert_to_vector(initial_weights))
return modelClass(weights, intercept)

