diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 969e23be21623..ed2f8b41bcae5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -21,23 +21,57 @@ import java.lang.{Integer => JavaInteger}
 
 import org.jblas.DoubleMatrix
 
-import org.apache.spark.SparkContext._
+import org.apache.spark.Logging
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
 
 /**
  * Model representing the result of matrix factorization.
  *
+ * Note: If you create the model directly using the constructor, please be aware that fast
+ * prediction requires cached user/product features and their associated partitioners.
+ *
  * @param rank Rank for the features in this model.
  * @param userFeatures RDD of tuples where each tuple represents the userId and
  *                     the features computed for this user.
  * @param productFeatures RDD of tuples where each tuple represents the productId
  *                        and the features computed for this product.
  */
-class MatrixFactorizationModel private[mllib] (
+class MatrixFactorizationModel(
     val rank: Int,
     val userFeatures: RDD[(Int, Array[Double])],
-    val productFeatures: RDD[(Int, Array[Double])]) extends Serializable {
+    val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
+
+  require(rank > 0, s"Rank must be positive but got $rank.")
+  validateFeatures("User", userFeatures)
+  validateFeatures("Product", productFeatures)
+
+  /**
+   * Validates factors and warns users if there are performance concerns.
+   *
+   * @param name label used in error and warning messages, e.g. "User" or "Product"
+   * @param features factor RDD to validate
+   * @throws IllegalArgumentException if `features` is empty or its first factor does not
+   *                                  have `rank` elements
+   */
+  private def validateFeatures(name: String, features: RDD[(Int, Array[Double])]): Unit = {
+    // take(1) instead of first(): first() throws UnsupportedOperationException on an empty
+    // RDD, which would not match the IllegalArgumentException raised by the other checks.
+    val firstFactor = features.take(1).headOption.map(_._2)
+    require(firstFactor.isDefined, s"$name features must not be empty.")
+    // Only the first record is inspected, so the check does not scan the whole RDD.
+    require(firstFactor.forall(_.length == rank),
+      s"$name feature dimension does not match the rank $rank.")
+    if (features.partitioner.isEmpty) {
+      logWarning(s"$name factor does not have a partitioner. "
+        + "Prediction on individual records could be slow.")
+    }
+    if (features.getStorageLevel == StorageLevel.NONE) {
+      logWarning(s"$name factor is not cached. Prediction could be slow.")
+    }
+  }
+
   /** Predict the rating of one user for one product. */
   def predict(user: Int, product: Int): Double = {
     val userVector = new DoubleMatrix(userFeatures.lookup(user).head)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
new file mode 100644
index 0000000000000..b9caecc904a23
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.recommendation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.rdd.RDD
+
+/** Tests for constructing a [[MatrixFactorizationModel]] directly from factor RDDs. */
+class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {
+
+  val rank = 2
+  // Created through `sc` in beforeAll().
+  var userFeatures: RDD[(Int, Array[Double])] = _
+  var prodFeatures: RDD[(Int, Array[Double])] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    userFeatures = sc.parallelize(Seq((0, Array(1.0, 2.0)), (1, Array(3.0, 4.0))))
+    prodFeatures = sc.parallelize(Seq((2, Array(5.0, 6.0))))
+  }
+
+  test("constructor") {
+    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
+    // Expected value from user 0 (1.0, 2.0) and product 2 (5.0, 6.0): 1 * 5 + 2 * 6 = 17.
+    assert(model.predict(0, 2) ~== 17.0 relTol 1e-14)
+
+    // A rank that differs from the factor dimension (2 here) is rejected.
+    intercept[IllegalArgumentException] {
+      new MatrixFactorizationModel(1, userFeatures, prodFeatures)
+    }
+
+    // User factors whose dimension differs from the rank are rejected.
+    val userFeatures1 = sc.parallelize(Seq((0, Array(1.0)), (1, Array(3.0))))
+    intercept[IllegalArgumentException] {
+      new MatrixFactorizationModel(rank, userFeatures1, prodFeatures)
+    }
+
+    // Product factors whose dimension differs from the rank are rejected.
+    val prodFeatures1 = sc.parallelize(Seq((2, Array(5.0))))
+    intercept[IllegalArgumentException] {
+      new MatrixFactorizationModel(rank, userFeatures, prodFeatures1)
+    }
+
+    // Empty factors are rejected with the same exception type as other invalid input.
+    val emptyFeatures = sc.parallelize(Seq.empty[(Int, Array[Double])])
+    intercept[IllegalArgumentException] {
+      new MatrixFactorizationModel(rank, emptyFeatures, prodFeatures)
+    }
+    intercept[IllegalArgumentException] {
+      new MatrixFactorizationModel(rank, userFeatures, emptyFeatures)
+    }
+  }
+}