diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 70955bffa206b..1433ee9a0dd5a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -17,21 +17,30 @@

 package org.apache.spark.mllib.fpm

-import java.lang.{Iterable => JavaIterable}
 import java.{util => ju}
+import java.lang.{Iterable => JavaIterable}

-import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.collection.JavaConverters._
 import scala.reflect.ClassTag

-import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}

-class FPGrowthModel[Item](val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
-  def javaFreqItemsets(): JavaRDD[(Array[Item], Long)] = {
-    freqItemsets.toJavaRDD()
+/**
+ * Model trained by [[FPGrowth]], which holds frequent itemsets.
+ * @param freqItemsets frequent itemsets, as an RDD of (itemset, frequency) pairs
+ * @tparam Item item type
+ */
+class FPGrowthModel[Item: ClassTag](
+    val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
+
+  /** Returns frequent itemsets as a [[org.apache.spark.api.java.JavaPairRDD]]. */
+  def javaFreqItemsets(): JavaPairRDD[Array[Item], java.lang.Long] = {
+    JavaPairRDD.fromRDD(freqItemsets).asInstanceOf[JavaPairRDD[Array[Item], java.lang.Long]]
   }
 }

@@ -77,7 +86,7 @@ class FPGrowth private (
    * @param data input data set, each element contains a transaction
    * @return an [[FPGrowthModel]]
    */
-  def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item] = {
+  def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
     if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
@@ -85,14 +94,14 @@ class FPGrowth private (
     val minCount = math.ceil(minSupport * count).toLong
     val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
     val partitioner = new HashPartitioner(numParts)
-    val freqItems = genFreqItems[Item, Basket](data, minCount, partitioner)
-    val freqItemsets = genFreqItemsets[Item, Basket](data, minCount, freqItems, partitioner)
+    val freqItems = genFreqItems(data, minCount, partitioner)
+    val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
     new FPGrowthModel(freqItemsets)
   }

-  def run[Item: ClassTag, Basket <: JavaIterable[Item]](
-      data: JavaRDD[Basket]): FPGrowthModel[Item] = {
-    this.run(data.rdd.map(_.asScala))
+  def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
+    implicit val tag = fakeClassTag[Item]
+    run(data.rdd.map(_.asScala.toArray))
   }

   /**
@@ -101,8 +110,8 @@ class FPGrowth private (
    * @param partitioner partitioner used to distribute items
    * @return array of frequent patterns ordered by their frequencies
    */
-  private def genFreqItems[Item: ClassTag, Basket <: Iterable[Item]](
-      data: RDD[Basket],
+  private def genFreqItems[Item: ClassTag](
+      data: RDD[Array[Item]],
       minCount: Long,
       partitioner: Partitioner): Array[Item] = {
     data.flatMap { t =>
@@ -127,8 +136,8 @@ class FPGrowth private (
    * @param partitioner partitioner used to distribute transactions
    * @return an RDD of (frequent itemset, count)
    */
-  private def genFreqItemsets[Item: ClassTag, Basket <: Iterable[Item]](
-      data: RDD[Basket],
+  private def genFreqItemsets[Item: ClassTag](
+      data: RDD[Array[Item]],
       minCount: Long,
       freqItems: Array[Item],
       partitioner: Partitioner): RDD[(Array[Item], Long)] = {
@@ -152,13 +161,13 @@ class FPGrowth private (
    * @param partitioner partitioner used to distribute transactions
    * @return a map of (target partition, conditional transaction)
    */
-  private def genCondTransactions[Item: ClassTag, Basket <: Iterable[Item]](
-      transaction: Basket,
+  private def genCondTransactions[Item: ClassTag](
+      transaction: Array[Item],
       itemToRank: Map[Item, Int],
       partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
     val output = mutable.Map.empty[Int, Array[Int]]
     // Filter the basket by frequent items pattern and sort their ranks.
-    val filtered = transaction.flatMap(itemToRank.get).toArray
+    val filtered = transaction.flatMap(itemToRank.get)
     ju.Arrays.sort(filtered)
     val n = filtered.length
     var i = n - 1
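With the Basket type parameter gone, callers pass transactions as plain arrays and let Item be inferred. The following is a minimal usage sketch of the reworked Scala API; the SparkContext setup, transaction data, and object name are illustrative, not part of this patch, and it assumes the existing no-arg FPGrowth constructor with its default parameters.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.fpm.FPGrowth
import org.apache.spark.rdd.RDD

object FPGrowthExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "FPGrowthExample")
    // Each transaction is now a plain Array[String]; no Basket type is involved.
    val transactions: RDD[Array[String]] = sc.parallelize(Seq(
      Array("r", "z", "h", "k", "p"),
      Array("z", "y", "x", "w", "v", "u", "t", "s"),
      Array("z"))).cache()
    val model = new FPGrowth()
      .setMinSupport(0.5)
      .setNumPartitions(2)
      .run(transactions) // Item = String is inferred; no .run[String, Seq[String]](...)
    model.freqItemsets.collect().foreach { case (itemset, count) =>
      println(itemset.mkString("{", ",", "}") + ": " + count)
    }
    sc.stop()
  }
}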
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index c0b55691983ae..851707c8a19c4 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -19,78 +19,66 @@

 import java.io.Serializable;
 import java.util.ArrayList;
-import java.util.List;

 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
-import static org.junit.Assert.*;
-
 import com.google.common.collect.Lists;
+import static org.junit.Assert.*;

 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;

 public class JavaFPGrowthSuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaFPGrowth");
+  }

-  @Before
-  public void setUp() {
-    sc = new JavaSparkContext("local", "JavaFPGrowth");
-  }
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }

-  @After
-  public void tearDown() {
-    sc.stop();
-    sc = null;
-  }
+  @Test
+  public void runFPGrowth() {

-  @Test
-  public void runFPGrowth() {
-    JavaRDD<ArrayList<String>> rdd = sc.parallelize(Lists.newArrayList(
-      Lists.newArrayList("r z h k p".split(" ")),
-      Lists.newArrayList("z y x w v u t s".split(" ")),
-      Lists.newArrayList("s x o n r".split(" ")),
-      Lists.newArrayList("x z y m t s q e".split(" ")),
-      Lists.newArrayList("z".split(" ")),
-      Lists.newArrayList("x z y r q t p".split(" "))), 2);
+    @SuppressWarnings("unchecked")
+    JavaRDD<ArrayList<String>> rdd = sc.parallelize(Lists.newArrayList(
+      Lists.newArrayList("r z h k p".split(" ")),
+      Lists.newArrayList("z y x w v u t s".split(" ")),
+      Lists.newArrayList("s x o n r".split(" ")),
+      Lists.newArrayList("x z y m t s q e".split(" ")),
+      Lists.newArrayList("z".split(" ")),
+      Lists.newArrayList("x z y r q t p".split(" "))), 2);

-    FPGrowth fpg = new FPGrowth();
+    FPGrowth fpg = new FPGrowth();

-    /*
-    FPGrowthModel<String> model6 = fpg
-      .setMinSupport(0.9)
-      .setNumPartitions(1)
-      .run(rdd);
-    assert(model6.javaFreqItemsets().count() == 0);
+    FPGrowthModel<String> model6 = fpg
+      .setMinSupport(0.9)
+      .setNumPartitions(1)
+      .run(rdd);
+    assertEquals(0, model6.javaFreqItemsets().count());

-    FPGrowthModel<String> model3 = fpg
-      .setMinSupport(0.5)
-      .setNumPartitions(2)
-      .run(rdd);
-    val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
-      (items.toSet, count)
-    }
-    val expected = Set(
-      (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
-      (Set("r"), 3L),
-      (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
-      (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
-      (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
-      (Set("t", "y", "x"), 3L),
-      (Set("t", "y", "x", "z"), 3L))
-    assert(freqItemsets3.toSet === expected)
+    FPGrowthModel<String> model3 = fpg
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd);
+    assertEquals(18, model3.javaFreqItemsets().count());

-    val model2 = fpg
-      .setMinSupport(0.3)
-      .setNumPartitions(4)
-      .run[String](rdd)
-    assert(model2.freqItemsets.count() == 54)
+    FPGrowthModel<String> model2 = fpg
+      .setMinSupport(0.3)
+      .setNumPartitions(4)
+      .run(rdd);
+    assertEquals(54, model2.javaFreqItemsets().count());

-    val model1 = fpg
-      .setMinSupport(0.1)
-      .setNumPartitions(8)
-      .run[String](rdd)
-    assert(model1.freqItemsets.count() == 625) */
-  }
-}
\ No newline at end of file
+    FPGrowthModel<String> model1 = fpg
+      .setMinSupport(0.1)
+      .setNumPartitions(8)
+      .run(rdd);
+    assertEquals(625, model1.javaFreqItemsets().count());
+  }
+}
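The Java-facing run overload above compiles without a ClassTag from the caller because JavaSparkContext.fakeClassTag manufactures one after erasure. Below is a standalone sketch of that pattern; the fakeClassTag and toArrays definitions are illustrative re-creations for demonstration, not the patched code itself.

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

object FakeClassTagSketch {
  // Claims AnyRef erasure for any T; this is safe when the Array[Item] built
  // from it is only ever treated as an array of objects, as in the Java bridge.
  def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]

  // Mirrors what the new run(JavaRDD[Basket]) overload does per basket:
  // materialize a tag, then convert each java.lang.Iterable to an Array.
  def toArrays[Item](baskets: Seq[java.lang.Iterable[Item]]): Seq[Array[Item]] = {
    implicit val tag: ClassTag[Item] = fakeClassTag[Item]
    baskets.map(_.asScala.toArray)
  }

  def main(args: Array[String]): Unit = {
    val basket: java.lang.Iterable[String] = java.util.Arrays.asList("r", "z", "h")
    println(toArrays(Seq(basket)).head.mkString(", ")) // r, z, h
  }
}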
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 67dc246abfb24..68128284b8608 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -31,7 +31,7 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
       "x z y m t s q e",
       "z",
       "x z y r q t p")
-      .map(_.split(" ").toSeq)
+      .map(_.split(" "))
     val rdd = sc.parallelize(transactions, 2).cache()

     val fpg = new FPGrowth()
@@ -39,13 +39,13 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
     val model6 = fpg
       .setMinSupport(0.9)
       .setNumPartitions(1)
-      .run[String, Seq[String]](rdd)
+      .run(rdd)
     assert(model6.freqItemsets.count() === 0)

     val model3 = fpg
       .setMinSupport(0.5)
       .setNumPartitions(2)
-      .run[String, Seq[String]](rdd)
+      .run(rdd)
     val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
       (items.toSet, count)
     }
@@ -62,13 +62,13 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = fpg
       .setMinSupport(0.3)
       .setNumPartitions(4)
-      .run[String, Seq[String]](rdd)
+      .run(rdd)
     assert(model2.freqItemsets.count() === 54)

     val model1 = fpg
       .setMinSupport(0.1)
       .setNumPartitions(8)
-      .run[String, Seq[String]](rdd)
+      .run(rdd)
     assert(model1.freqItemsets.count() === 625)
   }

@@ -81,7 +81,7 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
       "2 4",
       "1 3",
       "1 7")
-      .map(_.split(" ").map(_.toInt).toList)
+      .map(_.split(" ").map(_.toInt).toArray)
     val rdd = sc.parallelize(transactions, 2).cache()

     val fpg = new FPGrowth()
@@ -89,13 +89,15 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
     val model6 = fpg
       .setMinSupport(0.9)
       .setNumPartitions(1)
-      .run[Int, List[Int]](rdd)
+      .run(rdd)
     assert(model6.freqItemsets.count() === 0)

     val model3 = fpg
       .setMinSupport(0.5)
       .setNumPartitions(2)
-      .run[Int, List[Int]](rdd)
+      .run(rdd)
+    assert(model3.freqItemsets.first()._1.getClass === Array(1).getClass,
+      "frequent itemsets should use primitive arrays")
     val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
       (items.toSet, count)
     }
@@ -108,13 +110,13 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = fpg
       .setMinSupport(0.3)
       .setNumPartitions(4)
-      .run[Int, List[Int]](rdd)
+      .run(rdd)
     assert(model2.freqItemsets.count() === 15)

     val model1 = fpg
       .setMinSupport(0.1)
       .setNumPartitions(8)
-      .run[Int, List[Int]](rdd)
+      .run(rdd)
     assert(model1.freqItemsets.count() === 65)
   }
 }
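The new assertion in the Int test pins down a behavioral detail: with the ClassTag context bound on run and FPGrowthModel, itemsets over a primitive type come back as primitive arrays (int[]) rather than boxed Object arrays. A self-contained sketch of the mechanism follows; the object and method names are illustrative, not from this patch.

import scala.reflect.ClassTag

object PrimitiveArraySketch {
  // With a ClassTag in scope, toArray chooses the primitive representation
  // for value types instead of a boxed Object[].
  def toItemset[Item: ClassTag](items: Item*): Array[Item] = items.toArray

  def main(args: Array[String]): Unit = {
    val ints = toItemset(1, 2, 3)
    assert(ints.getClass == classOf[Array[Int]]) // int[], not java.lang.Integer[]
    val strings = toItemset("a", "b")
    assert(strings.getClass == classOf[Array[String]])
  }
}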