-
Notifications
You must be signed in to change notification settings - Fork 28.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-4531] [MLlib] cache serialized java object #3397
Changes from 7 commits
f1063e1
d572f00
c2bdfc2
dff33e1
7da0332
63b984e
4b52edd
7f6e6ce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -74,10 +74,28 @@ class PythonMLLibAPI extends Serializable { | |
learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], | ||
data: JavaRDD[LabeledPoint], | ||
initialWeights: Vector): JList[Object] = { | ||
// Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. | ||
learner.disableUncachedWarning() | ||
val model = learner.run(data.rdd, initialWeights) | ||
List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava | ||
try { | ||
val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights) | ||
List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava | ||
} finally { | ||
data.rdd.unpersist(blocking = false) | ||
} | ||
} | ||
|
||
/** | ||
* Return the Updater from string | ||
*/ | ||
def getUpdateFromString(regType: String): Updater = { | ||
if (regType == "l2") { | ||
new SquaredL2Updater | ||
} else if (regType == "l1") { | ||
new L1Updater | ||
} else if (regType == null || regType == "none") { | ||
new SimpleUpdater | ||
} else { | ||
throw new IllegalArgumentException("Invalid value for 'regType' parameter." | ||
+ " Can only be initialized using the following string values: ['l1', 'l2', None].") | ||
} | ||
} | ||
|
||
/** | ||
|
@@ -99,16 +117,7 @@ class PythonMLLibAPI extends Serializable { | |
.setRegParam(regParam) | ||
.setStepSize(stepSize) | ||
.setMiniBatchFraction(miniBatchFraction) | ||
if (regType == "l2") { | ||
lrAlg.optimizer.setUpdater(new SquaredL2Updater) | ||
} else if (regType == "l1") { | ||
lrAlg.optimizer.setUpdater(new L1Updater) | ||
} else if (regType == null) { | ||
lrAlg.optimizer.setUpdater(new SimpleUpdater) | ||
} else { | ||
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." | ||
+ " Can only be initialized using the following string values: ['l1', 'l2', None].") | ||
} | ||
lrAlg.optimizer.setUpdater(getUpdateFromString(regType)) | ||
trainRegressionModel( | ||
lrAlg, | ||
data, | ||
|
@@ -178,16 +187,7 @@ class PythonMLLibAPI extends Serializable { | |
.setRegParam(regParam) | ||
.setStepSize(stepSize) | ||
.setMiniBatchFraction(miniBatchFraction) | ||
if (regType == "l2") { | ||
SVMAlg.optimizer.setUpdater(new SquaredL2Updater) | ||
} else if (regType == "l1") { | ||
SVMAlg.optimizer.setUpdater(new L1Updater) | ||
} else if (regType == null) { | ||
SVMAlg.optimizer.setUpdater(new SimpleUpdater) | ||
} else { | ||
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." | ||
+ " Can only be initialized using the following string values: ['l1', 'l2', None].") | ||
} | ||
SVMAlg.optimizer.setUpdater(getUpdateFromString(regType)) | ||
trainRegressionModel( | ||
SVMAlg, | ||
data, | ||
|
@@ -213,16 +213,7 @@ class PythonMLLibAPI extends Serializable { | |
.setRegParam(regParam) | ||
.setStepSize(stepSize) | ||
.setMiniBatchFraction(miniBatchFraction) | ||
if (regType == "l2") { | ||
LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) | ||
} else if (regType == "l1") { | ||
LogRegAlg.optimizer.setUpdater(new L1Updater) | ||
} else if (regType == null) { | ||
LogRegAlg.optimizer.setUpdater(new SimpleUpdater) | ||
} else { | ||
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." | ||
+ " Can only be initialized using the following string values: ['l1', 'l2', None].") | ||
} | ||
LogRegAlg.optimizer.setUpdater(getUpdateFromString(regType)) | ||
trainRegressionModel( | ||
LogRegAlg, | ||
data, | ||
|
@@ -248,16 +239,7 @@ class PythonMLLibAPI extends Serializable { | |
.setRegParam(regParam) | ||
.setNumCorrections(corrections) | ||
.setConvergenceTol(tolerance) | ||
if (regType == "l2") { | ||
LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) | ||
} else if (regType == "l1") { | ||
LogRegAlg.optimizer.setUpdater(new L1Updater) | ||
} else if (regType == null) { | ||
LogRegAlg.optimizer.setUpdater(new SimpleUpdater) | ||
} else { | ||
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." | ||
+ " Can only be initialized using the following string values: ['l1', 'l2', None].") | ||
} | ||
LogRegAlg.optimizer.setUpdater(getUpdateFromString(regType)) | ||
trainRegressionModel( | ||
LogRegAlg, | ||
data, | ||
|
@@ -289,9 +271,11 @@ class PythonMLLibAPI extends Serializable { | |
.setMaxIterations(maxIterations) | ||
.setRuns(runs) | ||
.setInitializationMode(initializationMode) | ||
// Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. | ||
.disableUncachedWarning() | ||
kMeansAlg.run(data.rdd) | ||
try { | ||
kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) | ||
} finally { | ||
data.rdd.unpersist(blocking = false) | ||
} | ||
} | ||
|
||
/** | ||
|
@@ -425,16 +409,18 @@ class PythonMLLibAPI extends Serializable { | |
numPartitions: Int, | ||
numIterations: Int, | ||
seed: Long): Word2VecModelWrapper = { | ||
val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER) | ||
val word2vec = new Word2Vec() | ||
.setVectorSize(vectorSize) | ||
.setLearningRate(learningRate) | ||
.setNumPartitions(numPartitions) | ||
.setNumIterations(numIterations) | ||
.setSeed(seed) | ||
val model = word2vec.fit(data) | ||
data.unpersist() | ||
new Word2VecModelWrapper(model) | ||
try { | ||
val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)) | ||
new Word2VecModelWrapper(model) | ||
} finally { | ||
dataJRDD.rdd.unpersist(blocking = false) | ||
} | ||
} | ||
|
||
private[python] class Word2VecModelWrapper(model: Word2VecModel) { | ||
|
@@ -495,8 +481,11 @@ class PythonMLLibAPI extends Serializable { | |
categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap, | ||
minInstancesPerNode = minInstancesPerNode, | ||
minInfoGain = minInfoGain) | ||
|
||
DecisionTree.train(data.rdd, strategy) | ||
try { | ||
DecisionTree.train(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), strategy) | ||
} finally { | ||
data.rdd.unpersist(blocking = false) | ||
} | ||
} | ||
|
||
/** | ||
|
@@ -526,10 +515,15 @@ class PythonMLLibAPI extends Serializable { | |
numClassesForClassification = numClasses, | ||
maxBins = maxBins, | ||
categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap) | ||
if (algo == Algo.Classification) { | ||
RandomForest.trainClassifier(data.rdd, strategy, numTrees, featureSubsetStrategy, seed) | ||
} else { | ||
RandomForest.trainRegressor(data.rdd, strategy, numTrees, featureSubsetStrategy, seed) | ||
val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) | ||
try { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why have the try-finally block? (Or did you mean for the caching to happen inside the block?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have returned value here, I think is easy to 'insert' a logic between final statement before return the value, or it will be:
|
||
if (algo == Algo.Classification) { | ||
RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, seed) | ||
} else { | ||
RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, seed) | ||
} | ||
} finally { | ||
cached.unpersist(blocking = false) | ||
} | ||
} | ||
|
||
|
@@ -711,7 +705,7 @@ private[spark] object SerDe extends Serializable { | |
def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { | ||
if (obj == this) { | ||
out.write(Opcodes.GLOBAL) | ||
out.write((module + "\n" + name + "\n").getBytes()) | ||
out.write((module + "\n" + name + "\n").getBytes) | ||
} else { | ||
pickler.save(this) // it will be memorized by Pickler | ||
saveState(obj, out, pickler) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update --> Updater