diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
index 7ead6327486cc..0ea792081086d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       minCount: Long,
       maxPatternLength: Int,
       prefixes: List[Int],
-      database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
+      database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
     if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
     val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
     val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
@@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
     }
   }
 
-  def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
+  def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
     database
       .map(getSuffix(prefix, _))
       .filter(_.nonEmpty)
@@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
    */
   private def getFreqItemAndCounts(
       minCount: Long,
-      database: Array[Array[Int]]): mutable.Map[Int, Long] = {
+      database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
     // TODO: use PrimitiveKeyOpenHashMap
     val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
     database.foreach { sequence =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 6f52db7b073ae..e6752332cdeeb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.fpm
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
@@ -43,28 +45,45 @@ class PrefixSpan private (
     private var minSupport: Double,
     private var maxPatternLength: Int) extends Logging with Serializable {
 
+  /**
+   * The maximum number of items allowed in a projected database before local processing. If a
+   * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
+   */
+  // TODO: make configurable with a better default value, 10000 may be too small
+  private val maxLocalProjDBSize: Long = 10000
+
   /**
    * Constructs a default instance with default parameters
    * {minSupport: `0.1`, maxPatternLength: `10`}.
    */
   def this() = this(0.1, 10)
 
+  /**
+   * Gets the minimal support (i.e. the frequency of occurrence before a pattern is considered
+   * frequent).
+   */
+  def getMinSupport: Double = this.minSupport
+
   /**
    * Sets the minimal support level (default: `0.1`).
    */
   def setMinSupport(minSupport: Double): this.type = {
-    require(minSupport >= 0 && minSupport <= 1,
-      "The minimum support value must be between 0 and 1, including 0 and 1.")
+    require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
     this.minSupport = minSupport
     this
   }
 
+  /**
+   * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider).
+   */
+  def getMaxPatternLength: Int = this.maxPatternLength
+
   /**
    * Sets maximal pattern length (default: `10`).
    */
   def setMaxPatternLength(maxPatternLength: Int): this.type = {
-    require(maxPatternLength >= 1,
-      "The maximum pattern length value must be greater than 0.")
+    // TODO: support unbounded pattern length when maxPatternLength = 0
+    require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
     this.maxPatternLength = maxPatternLength
     this
   }
@@ -78,81 +97,153 @@ class PrefixSpan private (
    * the value of pair is the pattern's count.
    */
   def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+    val sc = sequences.sparkContext
+
     if (sequences.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
-    val minCount = getMinCount(sequences)
-    val lengthOnePatternsAndCounts =
-      getFreqItemAndCounts(minCount, sequences).collect()
-    val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
-      lengthOnePatternsAndCounts.map(_._1), sequences)
-    val groupedProjectedDatabase = prefixAndProjectedDatabase
-      .map(x => (x._1.toSeq, x._2))
-      .groupByKey()
-      .map(x => (x._1.toArray, x._2.toArray))
-    val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
-    val lengthOnePatternsAndCountsRdd =
-      sequences.sparkContext.parallelize(
-        lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
-    val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
-    allPatterns
+
+    // Convert min support to a min number of transactions for this dataset
+    val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+
+    // (Frequent items -> number of occurrences), all items here satisfy the `minSupport` threshold
+    val freqItemCounts = sequences
+      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
+      .reduceByKey(_ + _)
+      .filter(_._2 >= minCount)
+      .collect()
+
+    // Pairs of (length 1 prefix, suffix consisting of frequent items)
+    val itemSuffixPairs = {
+      val freqItems = freqItemCounts.map(_._1).toSet
+      sequences.flatMap { seq =>
+        val filteredSeq = seq.filter(freqItems.contains(_))
+        freqItems.flatMap { item =>
+          val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
+          candidateSuffix match {
+            case suffix if !suffix.isEmpty => Some((List(item), suffix))
+            case _ => None
+          }
+        }
+      }
+    }
+
+    // Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
+    // frequent length-one prefixes)
+    var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
+
+    // Remaining work to be locally and distributively processed respectively
+    var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
+
+    // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
+    // projected database sizes <= `maxLocalProjDBSize`)
+    while (pairsForDistributed.count() != 0) {
+      val (nextPatternAndCounts, nextPrefixSuffixPairs) =
+        extendPrefixes(minCount, pairsForDistributed)
+      pairsForDistributed.unpersist()
+      val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
+      pairsForDistributed = largerPairsPart
+      pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
+      pairsForLocal ++= smallerPairsPart
+      resultsAccumulator ++= nextPatternAndCounts.collect()
+    }
+
+    // Process the small projected databases locally
+    val remainingResults = getPatternsInLocal(
+      minCount, sc.parallelize(pairsForLocal, 1).groupByKey())
+
+    (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
+      .map { case (pattern, count) => (pattern.toArray, count) }
   }
 
+
   /**
-   * Get the minimum count (sequences count * minSupport).
-   * @param sequences input data set, contains a set of sequences,
-   * @return minimum count,
+   * Partitions the prefix-suffix pairs by projected database size.
+   * @param prefixSuffixPairs prefix (length N) and suffix pairs,
+   * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
+   *         greater than [[maxLocalProjDBSize]]
    */
-  private def getMinCount(sequences: RDD[Array[Int]]): Long = {
-    if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+  private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
+    val prefixToSuffixSize = prefixSuffixPairs
+      .aggregateByKey(0)(
+        seqOp = { case (count, suffix) => count + suffix.length },
+        combOp = { _ + _ })
+    val smallPrefixes = prefixToSuffixSize
+      .filter(_._2 <= maxLocalProjDBSize)
+      .keys
+      .collect()
+      .toSet
+    val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
+    val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
+    (small.collect(), large)
   }
 
   /**
-   * Generates frequent items by filtering the input data using minimal count level.
-   * @param minCount the absolute minimum count
-   * @param sequences original sequences data
-   * @return array of item and count pair
+   * Extends all prefixes by one item from their suffix and computes the resulting frequent
+   * prefixes and remaining work.
+   * @param minCount minimum count
+   * @param prefixSuffixPairs prefix (length N) and suffix pairs,
+   * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
+   *         prefix, corresponding suffix) pairs.
    */
-  private def getFreqItemAndCounts(
+  private def extendPrefixes(
       minCount: Long,
-      sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
-    sequences.flatMap(_.distinct.map((_, 1L)))
+      prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
+
+    // (length N prefix, item from suffix) pairs and their corresponding number of occurrences
+    // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport`
+    val prefixItemPairAndCounts = prefixSuffixPairs
+      .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
-  }
 
-  /**
-   * Get the frequent prefixes' projected database.
-   * @param frequentPrefixes frequent prefixes
-   * @param sequences sequences data
-   * @return prefixes and projected database
-   */
-  private def getPrefixAndProjectedDatabase(
-      frequentPrefixes: Array[Int],
-      sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
-    val filteredSequences = sequences.map { p =>
-      p.filter (frequentPrefixes.contains(_) )
-    }
-    filteredSequences.flatMap { x =>
-      frequentPrefixes.map { y =>
-        val sub = LocalPrefixSpan.getSuffix(y, x)
-        (Array(y), sub)
-      }.filter(_._2.nonEmpty)
-    }
+    // Map from prefix to set of possible next items from suffix
+    val prefixToNextItems = prefixItemPairAndCounts
+      .keys
+      .groupByKey()
+      .mapValues(_.toSet)
+      .collect()
+      .toMap
+
+
+    // Frequent patterns with length N+1 and their corresponding counts
+    val extendedPrefixAndCounts = prefixItemPairAndCounts
+      .map { case ((prefix, item), count) => (item :: prefix, count) }
+
+    // Remaining work, all prefixes will have length N+1
+    val extendedPrefixAndSuffix = prefixSuffixPairs
+      .filter(x => prefixToNextItems.contains(x._1))
+      .flatMap { case (prefix, suffix) =>
+        val frequentNextItems = prefixToNextItems(prefix)
+        val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
+        frequentNextItems.flatMap { item =>
+          LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
+            case suffix if !suffix.isEmpty => Some((item :: prefix, suffix))
+            case _ => None
+          }
+        }
+      }
+
+    (extendedPrefixAndCounts, extendedPrefixAndSuffix)
   }
 
   /**
-   * calculate the patterns in local.
+   * Calculate the patterns in local.
    * @param minCount the absolute minimum count
-   * @param data patterns and projected sequences data data
+   * @param data prefixes and projected sequences data
    * @return patterns
    */
   private def getPatternsInLocal(
       minCount: Long,
-      data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
-    data.flatMap { case (prefix, projDB) =>
-      LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
-        .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
+      data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
+    data.flatMap {
+      case (prefix, projDB) =>
+        LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
+          .map { case (pattern: List[Int], count: Long) =>
            (pattern.reverse, count)
+          }
     }
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index 9f107c89f6d80..6dd2dc926acc5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     val rdd = sc.parallelize(sequences, 2).cache()
 
-    def compareResult(
-      expectedValue: Array[(Array[Int], Long)],
-      actualValue: Array[(Array[Int], Long)]): Boolean = {
-      expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
-        actualValue.map(x => (x._1.toSeq, x._2)).toSet
-    }
-
     val prefixspan = new PrefixSpan()
       .setMinSupport(0.33)
       .setMaxPatternLength(50)
@@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4, 5), 2L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue1, result1.collect()))
+    assert(compareResults(expectedValue1, result1.collect()))
 
     prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
     val result2 = prefixspan.run(rdd)
@@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4), 4L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue2, result2.collect()))
+    assert(compareResults(expectedValue2, result2.collect()))
 
     prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
     val result3 = prefixspan.run(rdd)
@@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4, 5), 2L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue3, result3.collect()))
+    assert(compareResults(expectedValue3, result3.collect()))
+  }
+
+  private def compareResults(
+      expectedValue: Array[(Array[Int], Long)],
+      actualValue: Array[(Array[Int], Long)]): Boolean = {
+    expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
+      actualValue.map(x => (x._1.toSeq, x._2)).toSet
   }
+
 }
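
Usage sketch (not part of the patch): how the public PrefixSpan API touched above is exercised end to end. It assumes a live SparkContext `sc` (e.g. in spark-shell); the input sequences and the minSupport/maxPatternLength values are illustrative, not taken from the patch or its tests.

import org.apache.spark.mllib.fpm.PrefixSpan
import org.apache.spark.rdd.RDD

// Each sequence is an array of item IDs; cache the input, since run() warns otherwise.
val sequences: RDD[Array[Int]] = sc.parallelize(Seq(
  Array(1, 3, 4, 5),
  Array(2, 3, 1),
  Array(2, 4, 1)), 2).cache()

// Configure support and pattern length via the setters shown in the diff.
val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)
  .setMaxPatternLength(5)

// run() returns (pattern, count) pairs for patterns meeting the support threshold.
val patterns: RDD[(Array[Int], Long)] = prefixSpan.run(sequences)
patterns.collect().foreach { case (pattern, count) =>
  println(pattern.mkString("[", ",", "]") + ": " + count)
}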