diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 5c563262e184d..79f8b651f83b3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -102,9 +102,10 @@ class PrefixSpan private (
     val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
     val itemCounts = sequences
-      .flatMap(_.distinct.map((_, 1L)))
+      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
+    var allPatternAndCounts = itemCounts.map(x => (List(x._1), x._2))
     val prefixSuffixPairs = {
       val frequentItems = itemCounts.map(_._1).collect()
@@ -114,14 +115,12 @@ class PrefixSpan private (
       candidates.flatMap { x =>
         frequentItems.map { y =>
           val sub = LocalPrefixSpan.getSuffix(y, x)
-          (ArrayBuffer(y), sub)
+          (List(y), sub)
         }.filter(_._2.nonEmpty)
       }
     }
-    prefixSuffixPairs.persist(StorageLevel.MEMORY_AND_DISK)
-
-    var allPatternAndCounts = itemCounts.map(x => (ArrayBuffer(x._1), x._2))
     var (smallPrefixSuffixPairs, largePrefixSuffixPairs) = splitPrefixSuffixPairs(prefixSuffixPairs)
+
     while (largePrefixSuffixPairs.count() != 0) {
       val (nextPatternAndCounts, nextPrefixSuffixPairs) =
         getPatternCountsAndPrefixSuffixPairs(minCount, largePrefixSuffixPairs)
@@ -135,9 +134,9 @@ class PrefixSpan private (
 
     if (smallPrefixSuffixPairs.count() > 0) {
       val projectedDatabase = smallPrefixSuffixPairs
-        .map(x => (x._1.toSeq, x._2))
+        // TODO aggregateByKey
         .groupByKey()
-        .map(x => (x._1.toArray, x._2.toArray))
+        .mapValues(_.toArray)
       val nextPatternAndCounts = getPatternsInLocal(minCount, projectedDatabase)
       allPatternAndCounts ++= nextPatternAndCounts
     }
@@ -154,8 +153,8 @@ class PrefixSpan private (
    *         (RDD[prefix, suffix], RDD[prefix, suffix ])
    */
   private def splitPrefixSuffixPairs(
-      prefixSuffixPairs: RDD[(ArrayBuffer[Int], Array[Int])]):
-    (RDD[(ArrayBuffer[Int], Array[Int])], RDD[(ArrayBuffer[Int], Array[Int])]) = {
+      prefixSuffixPairs: RDD[(List[Int], Array[Int])]):
+    (RDD[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
     val prefixToSuffixSize = prefixSuffixPairs
       .aggregateByKey(0)(
         seqOp = { case (count, suffix) => count + suffix.length },
@@ -179,14 +178,14 @@ class PrefixSpan private (
    */
   private def getPatternCountsAndPrefixSuffixPairs(
       minCount: Long,
-      prefixSuffixPairs: RDD[(ArrayBuffer[Int], Array[Int])]):
-    (RDD[(ArrayBuffer[Int], Long)], RDD[(ArrayBuffer[Int], Array[Int])]) = {
+      prefixSuffixPairs: RDD[(List[Int], Array[Int])]):
+    (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
     val prefixAndFrequentItemAndCounts = prefixSuffixPairs
       .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
     val patternAndCounts = prefixAndFrequentItemAndCounts
-      .map { case ((prefix, item), count) => (prefix :+ item, count) }
+      .map { case ((prefix, item), count) => (item :: prefix, count) }
     val prefixToFrequentNextItemsMap = prefixAndFrequentItemAndCounts
       .keys
       .groupByKey()
@@ -201,23 +200,12 @@ class PrefixSpan private (
       frequentNextItems.flatMap { item =>
         val suffix = LocalPrefixSpan.getSuffix(item, filteredSuffix)
         if (suffix.isEmpty) None
-        else Some(prefix :+ item, suffix)
+        else Some(item :: prefix, suffix)
       }
     }
     (patternAndCounts, nextPrefixSuffixPairs)
   }
 
-  /**
-   * Get the frequent prefixes and suffix pairs.
-   * @param frequentPrefixes frequent prefixes
-   * @param sequences sequences data
-   * @return prefixes and suffix pairs.
-   */
-  private def getPrefixSuffixPairs(
-      frequentPrefixes: Array[Int],
-      sequences: RDD[Array[Int]]): RDD[(ArrayBuffer[Int], Array[Int])] = {
-  }
-
   /**
    * calculate the patterns in local.
    * @param minCount the absolute minimum count
@@ -226,13 +214,13 @@ class PrefixSpan private (
    */
   private def getPatternsInLocal(
       minCount: Long,
-      data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(ArrayBuffer[Int], Long)] = {
+      data: RDD[(List[Int], Array[Array[Int]])]): RDD[(List[Int], Long)] = {
     data.flatMap {
-      case (prefix, projDB) =>
-        LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
-        .map { case (pattern: List[Int], count: Long) =>
-          (pattern.toArray.reverse.to[ArrayBuffer], count)
-        }
+      case (prefix, projDB) =>
+        LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
+          .map { case (pattern: List[Int], count: Long) =>
+            (pattern.reverse, count)
+          }
     }
   }
 }