From 1235cfcc9367b546bcf564972a33b769f62da520 Mon Sep 17 00:00:00 2001
From: Feynman Liang
Date: Tue, 28 Jul 2015 15:30:29 -0700
Subject: [PATCH] Use Iterable[Array[_]] over Array[Array[_]] for database

---
 .../spark/mllib/fpm/LocalPrefixSpan.scala   |  6 +--
 .../apache/spark/mllib/fpm/PrefixSpan.scala | 37 +++++++++----------
 2 files changed, 21 insertions(+), 22 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
index 7ead6327486cc..0ea792081086d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
       minCount: Long,
       maxPatternLength: Int,
       prefixes: List[Int],
-      database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
+      database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
     if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
     val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
     val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
@@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
     }
   }
 
-  def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
+  def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
     database
       .map(getSuffix(prefix, _))
       .filter(_.nonEmpty)
@@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
    */
   private def getFreqItemAndCounts(
       minCount: Long,
-      database: Array[Array[Int]]): mutable.Map[Int, Long] = {
+      database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
     // TODO: use PrimitiveKeyOpenHashMap
     val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
     database.foreach { sequence =>
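The LocalPrefixSpan changes above are purely a relaxation of types: every helper only iterates its database, so Iterable[Array[Int]] is sufficient and callers no longer have to materialize an Array[Array[Int]] first. A standalone sketch of the counting pattern that motivates the weaker bound (illustrative only, not part of the patch; the IterableDatabaseSketch object and its toy data are invented here):

import scala.collection.mutable

object IterableDatabaseSketch {
  // Mirrors getFreqItemAndCounts: for each item, count the sequences of the
  // projected database that contain it, keeping only counts >= minCount.
  def freqItemCounts(minCount: Long, database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
    val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
    database.foreach { sequence =>
      sequence.distinct.foreach(item => counts(item) += 1L)
    }
    counts.filter(_._2 >= minCount)
  }

  def main(args: Array[String]): Unit = {
    // A plain List now works wherever an Array[Array[Int]] used to be required.
    val db: Iterable[Array[Int]] = List(Array(1, 2, 3), Array(1, 3), Array(2, 3))
    println(freqItemCounts(minCount = 2L, db)) // e.g. HashMap(1 -> 2, 2 -> 2, 3 -> 3)
  }
}

The same relaxation is what lets PrefixSpan.scala, below, feed groupByKey output straight to the local runner without a per-key toArray copy.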
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 79f8b651f83b3..bbdc75532ae6f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -45,7 +45,11 @@ class PrefixSpan private (
     private var minSupport: Double,
     private var maxPatternLength: Int) extends Logging with Serializable {
 
-  private val maxProjectedDBSizeBeforeLocalProcessing: Long = 10000
+  /**
+   * The maximum number of items allowed in a projected database before local processing. If a
+   * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
+   */
+  private val maxLocalProjDBSize: Long = 10000
 
   /**
    * Constructs a default instance with default parameters
@@ -63,8 +67,7 @@ class PrefixSpan private (
    * Sets the minimal support level (default: `0.1`).
    */
  def setMinSupport(minSupport: Double): this.type = {
-    require(minSupport >= 0 && minSupport <= 1,
-      "The minimum support value must be in [0, 1].")
+    require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
     this.minSupport = minSupport
     this
   }
@@ -79,8 +82,7 @@ class PrefixSpan private (
    */
   def setMaxPatternLength(maxPatternLength: Int): this.type = {
     // TODO: support unbounded pattern length when maxPatternLength = 0
-    require(maxPatternLength >= 1,
-      "The maximum pattern length value must be greater than 0.")
+    require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
     this.maxPatternLength = maxPatternLength
     this
   }
@@ -119,13 +121,13 @@ class PrefixSpan private (
         }.filter(_._2.nonEmpty)
       }
     }
 
-    var (smallPrefixSuffixPairs, largePrefixSuffixPairs) = splitPrefixSuffixPairs(prefixSuffixPairs)
+    var (smallPrefixSuffixPairs, largePrefixSuffixPairs) = partitionByProjDBSize(prefixSuffixPairs)
     while (largePrefixSuffixPairs.count() != 0) {
       val (nextPatternAndCounts, nextPrefixSuffixPairs) =
         getPatternCountsAndPrefixSuffixPairs(minCount, largePrefixSuffixPairs)
       largePrefixSuffixPairs.unpersist()
-      val (smallerPairsPart, largerPairsPart) = splitPrefixSuffixPairs(nextPrefixSuffixPairs)
+      val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
       largePrefixSuffixPairs = largerPairsPart
       largePrefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK)
       smallPrefixSuffixPairs ++= smallerPairsPart
@@ -136,7 +138,6 @@ class PrefixSpan private (
     val projectedDatabase = smallPrefixSuffixPairs
       // TODO aggregateByKey
       .groupByKey()
-      .mapValues(_.toArray)
     val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
     allPatternAndCounts ++= nextPatternAndCounts
   }
@@ -145,23 +146,21 @@ class PrefixSpan private (
   }
 
   /**
-   * Split prefix suffix pairs to two parts:
-   * Prefixes with projected databases smaller than maxSuffixesBeforeLocalProcessing and
-   * Prefixes with projected databases larger than maxSuffixesBeforeLocalProcessing
+   * Partitions the prefix-suffix pairs by projected database size.
+   *
    * @param prefixSuffixPairs prefix (length n) and suffix pairs,
-   * @return small size prefix suffix pairs and big size prefix suffix pairs
-   *         (RDD[prefix, suffix], RDD[prefix, suffix ])
+   * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
+   *         greater than [[maxLocalProjDBSize]]
    */
-  private def splitPrefixSuffixPairs(
-      prefixSuffixPairs: RDD[(List[Int], Array[Int])]):
-    (RDD[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
+  private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+    : (RDD[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
     val prefixToSuffixSize = prefixSuffixPairs
       .aggregateByKey(0)(
         seqOp = { case (count, suffix) => count + suffix.length },
         combOp = { _ + _ })
     val smallPrefixes = prefixToSuffixSize
-      .filter(_._2 <= maxProjectedDBSizeBeforeLocalProcessing)
-      .map(_._1)
+      .filter(_._2 <= maxLocalProjDBSize)
+      .keys
       .collect()
       .toSet
     val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
@@ -214,7 +213,7 @@ class PrefixSpan private (
    */
   private def getPatternsInLocal(
       minCount: Long,
-      data: RDD[(List[Int], Array[Array[Int]])]): RDD[(List[Int], Long)] = {
+      data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
     data.flatMap { case (prefix, projDB) =>
       LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
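On the driver side, the payoff is that groupByKey already yields RDD[(K, Iterable[V])], so once getPatternsInLocal accepts Iterable[Array[Int]] the .mapValues(_.toArray) step can simply be dropped. A hedged sketch of the two pair-RDD operations the patch leans on (the ProjDBSketch object and its method names are invented for illustration; they are not part of MLlib):

import org.apache.spark.rdd.RDD

object ProjDBSketch {
  // groupByKey produces RDD[(List[Int], Iterable[Array[Int]])], exactly the
  // type the revised getPatternsInLocal accepts; no per-key .toArray copy.
  def projectedDatabases(
      prefixSuffixPairs: RDD[(List[Int], Array[Int])])
    : RDD[(List[Int], Iterable[Array[Int]])] = {
    prefixSuffixPairs.groupByKey()
  }

  // Mirrors partitionByProjDBSize's bookkeeping: a prefix's projected database
  // size is the summed length of its suffixes, computed with aggregateByKey.
  def projDBSizes(prefixSuffixPairs: RDD[(List[Int], Array[Int])]): RDD[(List[Int], Int)] = {
    prefixSuffixPairs.aggregateByKey(0)(
      seqOp = (count, suffix) => count + suffix.length,
      combOp = _ + _)
  }
}

Only prefixes whose projected database size stays at or below maxLocalProjDBSize are grouped and mined locally; the rest are fed through another distributed iteration.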