Skip to content

Commit

Permalink
implement save/load for MFM
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Feb 6, 2015
1 parent 4d8d070 commit 62fc43c
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.mllib.recommendation

import org.apache.spark.Logging
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.recommendation.{ALS => NewALS}
import org.apache.spark.rdd.RDD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ package org.apache.spark.mllib.recommendation

import java.lang.{Integer => JavaInteger}

import org.apache.hadoop.fs.Path
import org.jblas.DoubleMatrix

import org.apache.spark.Logging
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel.SaveLoadV1_0
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.storage.StorageLevel

/**
Expand All @@ -41,7 +45,8 @@ import org.apache.spark.storage.StorageLevel
class MatrixFactorizationModel(
val rank: Int,
val userFeatures: RDD[(Int, Array[Double])],
val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
val productFeatures: RDD[(Int, Array[Double])])
extends Saveable with Serializable with Logging {

require(rank > 0)
validateFeatures("User", userFeatures)
Expand Down Expand Up @@ -125,6 +130,11 @@ class MatrixFactorizationModel(
recommend(productFeatures.lookup(product).head, userFeatures, num)
.map(t => Rating(t._1, product, t._2))


/**
 * Persists this model under `path` using the version 1.0 on-disk format.
 *
 * The work is delegated to `SaveLoadV1_0.save`. Note that `sc` is not read here:
 * the delegate obtains its SparkContext from the model's own `userFeatures` RDD.
 *
 * @param sc   Spark context required by the `Saveable` signature (unused by this method)
 * @param path destination directory for the saved model
 */
override def save(sc: SparkContext, path: String): Unit =
  SaveLoadV1_0.save(this, path)

private def recommend(
recommendToFeatures: Array[Double],
recommendableFeatures: RDD[(Int, Array[Double])],
Expand All @@ -136,3 +146,53 @@ class MatrixFactorizationModel(
scored.top(num)(Ordering.by(_._2))
}
}

/**
 * Companion of [[MatrixFactorizationModel]] providing persistence support.
 *
 * It fulfils the `Loader` contract through `load`, and keeps the concrete
 * version 1.0 read/write logic in the nested `SaveLoadV1_0` object.
 */
private object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {

  import org.apache.spark.mllib.util.Loader._

  /**
   * Loads a model that was written by `MatrixFactorizationModel.save`.
   *
   * Delegates to the version 1.0 reader, which validates the stored class name
   * and format version before reading the factor data.
   *
   * @param sc   Spark context used to read the saved files
   * @param path directory the model was saved to
   * @return the reconstructed model
   */
  override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
    SaveLoadV1_0.load(sc, path)
  }

  /** Reader and writer for format version "1.0". */
  private object SaveLoadV1_0 {

    private val thisFormatVersion = "1.0"

    // Kept as a literal so the stored identifier does not change if the class is moved.
    private val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel"

    /**
     * Writes `model` under `path`:
     *  - a single-record JSON text file at `metadataPath(path)` holding
     *    ("class", "version", "rank"),
     *  - the user factors as Parquet at `userPath(path)`,
     *  - the product factors as Parquet at `productPath(path)`.
     *
     * The SparkContext is taken from the model's `userFeatures` RDD.
     *
     * @param model model to persist
     * @param path  destination directory
     */
    def save(model: MatrixFactorizationModel, path: String): Unit = {
      val sc = model.userFeatures.sparkContext
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits.createDataFrame
      val metadata = (thisClassName, thisFormatVersion, model.rank)
      val metadataDF = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
      metadataDF.toJSON.saveAsTextFile(metadataPath(path))
      model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
      model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
    }

    /**
     * Reads a model stored by `save`.
     *
     * Asserts that the stored class name and format version match this
     * reader before touching the factor data.
     *
     * @param sc   Spark context used to read the saved files
     * @param path directory the model was saved to
     * @return the reconstructed model
     */
    def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
      val sqlContext = new SQLContext(sc)
      val (className, formatVersion, metadata) = loadMetadata(sc, path)
      assert(className == thisClassName,
        s"Expected class name $thisClassName but found $className in metadata at $path")
      assert(formatVersion == thisFormatVersion,
        s"Expected format version $thisFormatVersion but found $formatVersion in metadata at $path")
      val rank = metadata.select("rank").map { case Row(r: Int) =>
        r
      }.first()
      val userFeatures = loadFeatures(sqlContext, userPath(path))
      val productFeatures = loadFeatures(sqlContext, productPath(path))
      new MatrixFactorizationModel(rank, userFeatures, productFeatures)
    }

    /** Reads an (id, features) Parquet file back into the RDD shape the model expects. */
    private def loadFeatures(
        sqlContext: SQLContext,
        featuresPath: String): RDD[(Int, Array[Double])] = {
      sqlContext.parquetFile(featuresPath)
        .map { case Row(id: Int, features: Seq[Double]) =>
          (id, features.toArray)
        }
    }

    /** Location of the user factors: the "user" child of `dataPath(path)`. */
    private def userPath(path: String): String = {
      new Path(dataPath(path), "user").toUri.toString
    }

    /** Location of the product factors: the "product" child of `dataPath(path)`. */
    private def productPath(path: String): String = {
      new Path(dataPath(path), "product").toUri.toString
    }
  }
}

0 comments on commit 62fc43c

Please sign in to comment.