Skip to content

Commit

Permalink
[SPARK-4604][MLLIB] make MatrixFactorizationModel public
Browse files Browse the repository at this point in the history
Users can now construct an MF model directly. I added a note about the performance implications of doing so.

Author: Xiangrui Meng <meng@databricks.com>

Closes #3459 from mengxr/SPARK-4604 and squashes the following commits:

f64bcd3 [Xiangrui Meng] organize imports
ed08214 [Xiangrui Meng] check preconditions and unit tests
a624c12 [Xiangrui Meng] make MatrixFactorizationModel public

(cherry picked from commit b5fb141)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
  • Loading branch information
mengxr committed Nov 26, 2014
1 parent e866972 commit b749000
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,45 @@ import java.lang.{Integer => JavaInteger}

import org.jblas.DoubleMatrix

import org.apache.spark.SparkContext._
import org.apache.spark.Logging
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

/**
* Model representing the result of matrix factorization.
*
* Note: If you create the model directly using the constructor, please be aware that fast
* prediction requires cached user/product features and their associated partitioners.
*
* @param rank Rank for the features in this model.
* @param userFeatures RDD of tuples where each tuple represents the userId and
* the features computed for this user.
* @param productFeatures RDD of tuples where each tuple represents the productId
* and the features computed for this product.
*/
class MatrixFactorizationModel private[mllib] (
class MatrixFactorizationModel(
val rank: Int,
val userFeatures: RDD[(Int, Array[Double])],
val productFeatures: RDD[(Int, Array[Double])]) extends Serializable {
val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {

require(rank > 0)
validateFeatures("User", userFeatures)
validateFeatures("Product", productFeatures)

/**
 * Validates factors and warns users if there are performance concerns.
 *
 * Fails with an `IllegalArgumentException` when the feature RDD is empty or when the length
 * of the first feature vector differs from `rank`. Only the first record is inspected.
 * Logs a warning when the RDD has no partitioner, and another when its storage level is
 * `StorageLevel.NONE`.
 *
 * @param name label used in the error and warning messages (called with "User" and "Product")
 * @param features RDD of (id, feature vector) pairs to check
 */
private def validateFeatures(name: String, features: RDD[(Int, Array[Double])]): Unit = {
  // take(1) rather than first(): first() throws on an empty RDD, whereas this lets an empty
  // input be reported as a precondition failure like the other checks.
  val firstFactor = features.take(1).headOption
  require(firstFactor.nonEmpty, s"$name features must not be empty.")
  require(firstFactor.exists(_._2.length == rank),
    s"$name feature dimension does not match the rank $rank.")
  if (features.partitioner.isEmpty) {
    logWarning(s"$name factor does not have a partitioner. "
      + "Prediction on individual records could be slow.")
  }
  if (features.getStorageLevel == StorageLevel.NONE) {
    logWarning(s"$name factor is not cached. Prediction could be slow.")
  }
}

/** Predict the rating of one user for one product. */
def predict(user: Int, product: Int): Double = {
val userVector = new DoubleMatrix(userFeatures.lookup(user).head)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.recommendation

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD

/**
 * Tests for building a [[MatrixFactorizationModel]] directly through its constructor.
 */
class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {

  // Length of every feature vector in the shared fixtures below.
  val rank = 2
  // Assigned in beforeAll() because the fixtures are built from `sc`
  // (presumably supplied by MLlibTestSparkContext -- confirm).
  var userFeatures: RDD[(Int, Array[Double])] = _
  var prodFeatures: RDD[(Int, Array[Double])] = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    val users = Seq(0 -> Array(1.0, 2.0), 1 -> Array(3.0, 4.0))
    val products = Seq(2 -> Array(5.0, 6.0))
    userFeatures = sc.parallelize(users)
    prodFeatures = sc.parallelize(products)
  }

  test("constructor") {
    // Valid input: user 0 = (1, 2) and product 2 = (5, 6); 1*5 + 2*6 = 17.
    val validModel = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    assert(validModel.predict(0, 2) ~== 17.0 relTol 1e-14)

    // Declared rank 1 disagrees with the two-element fixture vectors.
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(1, userFeatures, prodFeatures)
    }

    // User vectors have one element while the declared rank is 2.
    val shortUserFeatures = sc.parallelize(Seq(0 -> Array(1.0), 1 -> Array(3.0)))
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(rank, shortUserFeatures, prodFeatures)
    }

    // Product vectors have one element while the declared rank is 2.
    val shortProdFeatures = sc.parallelize(Seq(2 -> Array(5.0)))
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(rank, userFeatures, shortProdFeatures)
    }
  }
}

0 comments on commit b749000

Please sign in to comment.