-
Notifications
You must be signed in to change notification settings - Fork 28.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-5598][MLLIB] model save/load for ALS #4422
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change
---|---|---
|
@@ -17,13 +17,17 @@ | |
|
||
package org.apache.spark.mllib.recommendation | ||
|
||
import java.io.IOException | ||
import java.lang.{Integer => JavaInteger} | ||
|
||
import org.apache.hadoop.fs.Path | ||
import org.jblas.DoubleMatrix | ||
|
||
import org.apache.spark.Logging | ||
import org.apache.spark.{Logging, SparkContext} | ||
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} | ||
import org.apache.spark.mllib.util.{Loader, Saveable} | ||
import org.apache.spark.rdd.RDD | ||
import org.apache.spark.sql.{Row, SQLContext} | ||
import org.apache.spark.storage.StorageLevel | ||
|
||
/** | ||
|
@@ -41,7 +45,8 @@ import org.apache.spark.storage.StorageLevel | |
class MatrixFactorizationModel( | ||
val rank: Int, | ||
val userFeatures: RDD[(Int, Array[Double])], | ||
val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging { | ||
val productFeatures: RDD[(Int, Array[Double])]) | ||
extends Saveable with Serializable with Logging { | ||
|
||
require(rank > 0) | ||
validateFeatures("User", userFeatures) | ||
|
@@ -125,6 +130,12 @@ class MatrixFactorizationModel( | |
recommend(productFeatures.lookup(product).head, userFeatures, num) | ||
.map(t => Rating(t._1, product, t._2)) | ||
|
||
override val formatVersion: String = "1.0" | ||
|
||
override def save(sc: SparkContext, path: String): Unit = { | ||
MatrixFactorizationModel.SaveLoadV1_0.save(this, path) | ||
} | ||
|
||
private def recommend( | ||
recommendToFeatures: Array[Double], | ||
recommendableFeatures: RDD[(Int, Array[Double])], | ||
|
@@ -136,3 +147,69 @@ class MatrixFactorizationModel( | |
scored.top(num)(Ordering.by(_._2)) | ||
} | ||
} | ||
|
||
/**
 * Companion object of [[MatrixFactorizationModel]] implementing [[Loader]], so that a
 * previously saved model can be reloaded via `MatrixFactorizationModel.load(sc, path)`.
 *
 * Fix (flagged in review): this object must NOT be `private` — it is the public entry
 * point for loading a saved model.
 */
object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {

  import org.apache.spark.mllib.util.Loader._

  /**
   * Loads a [[MatrixFactorizationModel]] from `path`, dispatching on the
   * (class name, format version) pair recorded in the saved metadata.
   *
   * @throws IOException if the saved metadata does not match a supported (class, version) pair
   */
  override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
    // The third element (the metadata DataFrame) is only needed by the versioned loader.
    val (loadedClassName, formatVersion, _) = loadMetadata(sc, path)
    val classNameV1_0 = SaveLoadV1_0.thisClassName
    (loadedClassName, formatVersion) match {
      case (className, "1.0") if className == classNameV1_0 =>
        SaveLoadV1_0.load(sc, path)
      case _ =>
        // Bug fix: the original message read "...model with(class: ..." — a space was
        // missing between "with" and the interpolated detail. The leading `"" +` hack
        // (an IDE artifact, per the review thread) is also removed.
        throw new IOException(
          s"MatrixFactorizationModel.load did not recognize model with " +
          s"(class: $loadedClassName, version: $formatVersion). Supported:\n" +
          s"  ($classNameV1_0, 1.0)")
    }
  }

  // Format-version 1.0 persistence. Private and no longer extends Loader (flagged in
  // review): it is an internal helper dispatched to by the `load` above, not a public
  // loader in its own right.
  private object SaveLoadV1_0 {

    /** Format version written into the metadata of every model saved by this object. */
    private val thisFormatVersion = "1.0"

    /** Class name recorded in saved metadata and verified again on load. */
    val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel"

    /**
     * Saves a [[MatrixFactorizationModel]], where user features are saved under `data/user`
     * and product features are saved under `data/product` (doc fixed to match
     * `userPath`/`productPath` below, which use the singular forms).
     */
    def save(model: MatrixFactorizationModel, path: String): Unit = {
      val sc = model.userFeatures.sparkContext
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits.createDataFrame
      // Metadata (class name, format version, rank) is stored as a single-row JSON file.
      val metadata = (thisClassName, thisFormatVersion, model.rank)
      val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
      model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
      model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
    }

    /** Loads a model previously written by [[save]]; validates class name and version. */
    def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
      val sqlContext = new SQLContext(sc)
      val (className, formatVersion, metadata) = loadMetadata(sc, path)
      assert(className == thisClassName)
      assert(formatVersion == thisFormatVersion)
      val rank = metadata.select("rank").first().getInt(0)
      // Bug fix: matching `case Row(id: Int, features: Seq[Double])` is an unchecked
      // pattern — type erasure means the element type is never verified at runtime, and
      // `toArray` on a Seq of boxed values would not produce Array[Double]. Match on
      // Seq[_] and cast the element type explicitly.
      val userFeatures = sqlContext.parquetFile(userPath(path))
        .map { case Row(id: Int, features: Seq[_]) =>
          (id, features.asInstanceOf[Seq[Double]].toArray)
        }
      val productFeatures = sqlContext.parquetFile(productPath(path))
        .map { case Row(id: Int, features: Seq[_]) =>
          (id, features.asInstanceOf[Seq[Double]].toArray)
        }
      new MatrixFactorizationModel(rank, userFeatures, productFeatures)
    }

    /** Directory holding the saved user features: `<path>/data/user`. */
    private def userPath(path: String): String = {
      new Path(dataPath(path), "user").toUri.toString
    }

    /** Directory holding the saved product features: `<path>/data/product`. */
    private def productPath(path: String): String = {
      new Path(dataPath(path), "product").toUri.toString
    }
  }
}
Review comment: suggest marking this member `protected`.
Author reply: done.