diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 9591c7966e06a..96baac07cb5d4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -17,15 +17,23 @@
 
 package org.apache.spark.mllib.fpm
 
+import java.lang.{Iterable => JavaIterable}
 import java.{util => ju}
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.reflect.ClassTag
 
-import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner}
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
 
-class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable
+class FPGrowthModel[Item](val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
+  def javaFreqItemsets(): JavaRDD[(Array[Item], Long)] = {
+    freqItemsets.toJavaRDD()
+  }
+}
 
 /**
  * This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data.
@@ -69,7 +77,7 @@ class FPGrowth private (
    * @param data input data set, each element contains a transaction
    * @return an [[FPGrowthModel]]
    */
-  def run(data: RDD[Array[String]]): FPGrowthModel = {
+  def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item] = {
     if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
@@ -77,24 +85,28 @@ class FPGrowth private (
     val minCount = math.ceil(minSupport * count).toLong
     val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
     val partitioner = new HashPartitioner(numParts)
-    val freqItems = genFreqItems(data, minCount, partitioner)
-    val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
+    val freqItems = genFreqItems[Item, Basket](data, minCount, partitioner)
+    val freqItemsets = genFreqItemsets[Item, Basket](data, minCount, freqItems, partitioner)
     new FPGrowthModel(freqItemsets)
   }
 
+  def run[Item: ClassTag, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
+    this.run(data.rdd.map(_.asScala))
+  }
+
   /**
    * Generates frequent items by filtering the input data using minimal support level.
    * @param minCount minimum count for frequent itemsets
   * @param partitioner partitioner used to distribute items
   * @return array of frequent pattern ordered by their frequencies
   */
-  private def genFreqItems(
-      data: RDD[Array[String]],
+  private def genFreqItems[Item: ClassTag, Basket <: Iterable[Item]](
+      data: RDD[Basket],
       minCount: Long,
-      partitioner: Partitioner): Array[String] = {
+      partitioner: Partitioner): Array[Item] = {
     data.flatMap { t =>
       val uniq = t.toSet
-      if (t.length != uniq.size) {
+      if (t.size != uniq.size) {
         throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
       }
       t
@@ -114,11 +126,11 @@ class FPGrowth private (
   * @param partitioner partitioner used to distribute transactions
   * @return an RDD of (frequent itemset, count)
   */
-  private def genFreqItemsets(
-      data: RDD[Array[String]],
+  private def genFreqItemsets[Item: ClassTag, Basket <: Iterable[Item]](
+      data: RDD[Basket],
       minCount: Long,
-      freqItems: Array[String],
-      partitioner: Partitioner): RDD[(Array[String], Long)] = {
+      freqItems: Array[Item],
+      partitioner: Partitioner): RDD[(Array[Item], Long)] = {
     val itemToRank = freqItems.zipWithIndex.toMap
     data.flatMap { transaction =>
       genCondTransactions(transaction, itemToRank, partitioner)
@@ -139,13 +151,13 @@ class FPGrowth private (
   * @param partitioner partitioner used to distribute transactions
   * @return a map of (target partition, conditional transaction)
   */
-  private def genCondTransactions(
-      transaction: Array[String],
-      itemToRank: Map[String, Int],
+  private def genCondTransactions[Item: ClassTag, Basket <: Iterable[Item]](
+      transaction: Basket,
+      itemToRank: Map[Item, Int],
       partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
     val output = mutable.Map.empty[Int, Array[Int]]
     // Filter the basket by frequent items pattern and sort their ranks.
-    val filtered = transaction.flatMap(itemToRank.get)
+    val filtered = transaction.flatMap(itemToRank.get).toArray
     ju.Arrays.sort(filtered)
     val n = filtered.length
     var i = n - 1
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 71ef60da6dd32..67dc246abfb24 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -22,7 +22,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
 
-  test("FP-Growth") {
+
+  test("FP-Growth using String type") {
     val transactions = Seq(
       "r z h k p",
       "z y x w v u t s",
@@ -30,7 +31,7 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
       "x z y m t s q e",
       "z",
       "x z y r q t p")
-      .map(_.split(" "))
+      .map(_.split(" ").toSeq)
     val rdd = sc.parallelize(transactions, 2).cache()
 
     val fpg = new FPGrowth()
@@ -38,13 +39,13 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
     val model6 = fpg
       .setMinSupport(0.9)
       .setNumPartitions(1)
-      .run(rdd)
+      .run[String, Seq[String]](rdd)
     assert(model6.freqItemsets.count() === 0)
 
     val model3 = fpg
       .setMinSupport(0.5)
       .setNumPartitions(2)
-      .run(rdd)
+      .run[String, Seq[String]](rdd)
     val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
       (items.toSet, count)
     }
@@ -61,13 +62,59 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = fpg
       .setMinSupport(0.3)
       .setNumPartitions(4)
-      .run(rdd)
+      .run[String, Seq[String]](rdd)
     assert(model2.freqItemsets.count() === 54)
 
     val model1 = fpg
       .setMinSupport(0.1)
       .setNumPartitions(8)
-      .run(rdd)
+      .run[String, Seq[String]](rdd)
     assert(model1.freqItemsets.count() === 625)
   }
+
+  test("FP-Growth using Int type") {
+    val transactions = Seq(
+      "1 2 3",
+      "1 2 3 4",
+      "5 4 3 2 1",
+      "6 5 4 3 2 1",
+      "2 4",
+      "1 3",
+      "1 7")
+      .map(_.split(" ").map(_.toInt).toList)
+    val rdd = sc.parallelize(transactions, 2).cache()
+
+    val fpg = new FPGrowth()
+
+    val model6 = fpg
+      .setMinSupport(0.9)
+      .setNumPartitions(1)
+      .run[Int, List[Int]](rdd)
+    assert(model6.freqItemsets.count() === 0)
+
+    val model3 = fpg
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run[Int, List[Int]](rdd)
+    val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
+      (items.toSet, count)
+    }
+    val expected = Set(
+      (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L),
+      (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L),
+      (Set(2, 4), 4L), (Set(1, 2, 3), 4L))
+    assert(freqItemsets3.toSet === expected)
+
+    val model2 = fpg
+      .setMinSupport(0.3)
+      .setNumPartitions(4)
+      .run[Int, List[Int]](rdd)
+    assert(model2.freqItemsets.count() === 15)
+
+    val model1 = fpg
+      .setMinSupport(0.1)
+      .setNumPartitions(8)
+      .run[Int, List[Int]](rdd)
+    assert(model1.freqItemsets.count() === 65)
+  }
 }
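
Usage note (not part of the patch): a minimal sketch of how the generic API added above could be called from Scala. It mirrors the String-typed test, assumes an existing SparkContext named sc, and only uses members visible in this diff (setMinSupport, setNumPartitions, the generic run, and freqItemsets).

    import org.apache.spark.mllib.fpm.FPGrowth

    // Baskets can now be any Iterable of an arbitrary item type, not just
    // Array[String]. Here each transaction is a Seq[String].
    val transactions = sc.parallelize(Seq(
      "r z h k p",
      "z y x w v u t s",
      "x z y r q t p").map(_.split(" ").toSeq)).cache()

    val model = new FPGrowth()
      .setMinSupport(0.5)
      .setNumPartitions(2)
      .run[String, Seq[String]](transactions)  // returns FPGrowthModel[String]

    // freqItemsets is an RDD[(Array[String], Long)] of (itemset, support count).
    model.freqItemsets.collect().foreach { case (items, count) =>
      println(items.mkString("[", ",", "]") + ", " + count)
    }

The same call shape works for other item types, e.g. run[Int, List[Int]] over an RDD[List[Int]], as exercised by the new Int-typed test.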