Skip to content

Commit

Permalink
create FPTree class
Browse files Browse the repository at this point in the history
  • Loading branch information
jackylk committed Jan 30, 2015
1 parent d110ab2 commit 93f3280
Show file tree
Hide file tree
Showing 5 changed files with 480 additions and 117 deletions.
156 changes: 53 additions & 103 deletions mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,34 @@

package org.apache.spark.mllib.fpm

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.broadcast._
import org.apache.spark.rdd.RDD

import scala.collection.mutable.{ArrayBuffer, Map}


/**
* This class implements Parallel FPGrowth algorithm to do frequent pattern matching on input data.
* This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data.
* Parallel FPGrowth (PFP) partitions computation in such a way that each machine executes an
* independent group of mining tasks. More detail of this algorithm can be found at
* http://infolab.stanford.edu/~echang/recsys08-69.pdf
* [[http://dx.doi.org/10.1145/1454008.1454027 PFP]], and the original FP-growth paper can be found at
* [[http://dx.doi.org/10.1145/335191.335372 FP-growth]]
*
* @param minSupport the minimal support level of a frequent pattern; any pattern that appears
* at least (minSupport * size-of-the-dataset) times will be output
*/
class FPGrowth private(private var minSupport: Double) extends Logging with Serializable {

/**
* Constructs a FPGrowth instance with default parameters:
* {minSupport: 0.5}
* {minSupport: 0.3}
*/
def this() = this(0.5)
def this() = this(0.3)

/**
* set the minimal support level, default is 0.5
* set the minimal support level, default is 0.3
* @param minSupport minimal support level
*/
def setMinSupport(minSupport: Double): this.type = {
Expand All @@ -49,87 +54,82 @@ class FPGrowth private(private var minSupport: Double) extends Logging with Seri

/**
* Compute a FPGrowth Model that contains frequent pattern result.
* @param data input data set
* @param data input data set, each element contains a transaction
* @return FPGrowth Model
*/
def run(data: RDD[Array[String]]): FPGrowthModel = {
  // Public entry point; the whole computation is delegated to runAlgorithm.
  runAlgorithm(data)
}

/**
* Implementation of PFP.
*/
private def runAlgorithm(data: RDD[Array[String]]): FPGrowthModel = {
val count = data.count()
val minCount = minSupport * count
val single = generateSingleItem(data, minCount)
val combinations = generateCombinations(data, minCount, single)
new FPGrowthModel(single ++ combinations)
val all = single.map(v => (Array[String](v._1), v._2)).union(combinations)
new FPGrowthModel(all.collect())
}

/**
* Generate single item pattern by filtering the input data using minimal support level
* @return array of frequent pattern with its count
*/
private def generateSingleItem(
data: RDD[Array[String]],
minCount: Double): Array[(String, Int)] = {
data.flatMap(v => v)
.map(v => (v, 1))
minCount: Double): RDD[(String, Long)] = {
val single = data.flatMap(v => v.toSet)
.map(v => (v, 1L))
.reduceByKey(_ + _)
.filter(_._2 >= minCount)
.collect()
.distinct
.sortWith(_._2 > _._2)
.sortBy(_._2)
single
}

/**
* Generate combination of items by computing on FPTree,
* Generate combination of frequent pattern by computing on FPTree,
* the computation is done on each FPTree partition.
* @return frequent patterns with their counts
*/
private def generateCombinations(
data: RDD[Array[String]],
minCount: Double,
singleItem: Array[(String, Int)]): Array[(String, Int)] = {
val single = data.context.broadcast(singleItem)
data.flatMap(basket => createFPTree(basket, single))
.groupByKey()
.flatMap(partition => runFPTree(partition, minCount))
.collect()
singleItem: RDD[(String, Long)]): RDD[(Array[String], Long)] = {
val single = data.context.broadcast(singleItem.collect())
data.flatMap(transaction => createConditionPatternBase(transaction, single))
.aggregateByKey(new FPTree)(
(aggregator, condPattBase) => aggregator.add(condPattBase),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2))
.flatMap(partition => partition._2.mine(minCount, partition._1))
}

/**
* Create the FP-Tree partitions for the given basket
* @return an array of tuples, where the first element of each tuple is a single
* item (hash key) and the second element is its conditional pattern base
private def createFPTree(
basket: Array[String],
singleItem: Broadcast[Array[(String, Int)]]): Array[(String, Array[String])] = {
private def createConditionPatternBase(
transaction: Array[String],
singleBC: Broadcast[Array[(String, Long)]]): Array[(String, Array[String])] = {
var output = ArrayBuffer[(String, Array[String])]()
var combination = ArrayBuffer[String]()
val single = singleItem.value
var items = ArrayBuffer[(String, Int)]()

// Filter the basket by single item pattern
val iterator = basket.iterator
while (iterator.hasNext){
val item = iterator.next
val opt = single.find(_._1.equals(item))
if (opt != None) {
items ++= opt
}
}

// Sort it and create the item combinations
val sortedItems = items.sortWith(_._1 > _._1).sortWith(_._2 > _._2).toArray
val itemIterator = sortedItems.iterator
var items = ArrayBuffer[(String, Long)]()
val single = singleBC.value
val singleMap = single.toMap

// Filter the basket by single item pattern and sort
// by single item and its count
val candidates = transaction
.filter(singleMap.contains)
.map(item => (item, singleMap(item)))
.sortBy(_._1)
.sortBy(_._2)
.toArray

val itemIterator = candidates.iterator
while (itemIterator.hasNext) {
combination.clear()
val item = itemIterator.next
val firstNItems = sortedItems.take(sortedItems.indexOf(item))
val item = itemIterator.next()
val firstNItems = candidates.take(candidates.indexOf(item))
if (firstNItems.length > 0) {
val iterator = firstNItems.iterator
while (iterator.hasNext) {
val elem = iterator.next
val elem = iterator.next()
combination += elem._1
}
output += ((item._1, combination.toArray))
Expand All @@ -138,56 +138,6 @@ class FPGrowth private(private var minSupport: Double) extends Logging with Seri
output.toArray
}

/**
 * Generate frequent patterns for one FPTree partition by level-wise enumeration.
 *
 * For k = 1, 2, ... every k-item combination of each conditional pattern is extended with
 * the partition key, sorted in descending order and joined with a single space to form the
 * pattern string. Patterns that occur at least `minCount` times are kept. Enumeration stops
 * at the first level that yields no such pattern.
 *
 * @param partition tuple of the partition key (a single item) and its conditional patterns
 * @param minCount minimal number of occurrences for a pattern to be reported
 * @return the frequent patterns with their counts; the order within a level is unspecified
 */
private def runFPTree(
    partition: (String, Iterable[Array[String]]),
    minCount: Double): Array[(String, Int)] = {
  val (key, patterns) = partition
  val output = ArrayBuffer[(String, Int)]()

  // Walk through the FPTree partition to generate all combinations that satisfy
  // the minimal support level, one combination size (k) at a time.
  var k = 1
  var searching = true
  while (searching) {
    // Patterns shorter than k cannot contribute a k-item combination.
    val candidates = patterns.iterator
      .filter(_.length >= k)
      .flatMap(_.toList.combinations(k))
      .map(items => (items :+ key).sortWith(_ > _).mkString(" "))
      .toList

    val eligible = candidates
      .groupBy(identity)
      .map { case (pattern, occurrences) => (pattern, occurrences.size) }
      .filter(_._2 >= minCount)

    output ++= eligible
    // Stop as soon as a level produces no pattern reaching minCount.
    searching = eligible.nonEmpty
    k += 1
  }
  output.toArray
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ package org.apache.spark.mllib.fpm
/**
* An FPGrowth model; each element is a frequent pattern with its count.
*/
class FPGrowthModel (val frequentPattern: Array[(String, Int)]) extends Serializable {
class FPGrowthModel (val frequentPattern: Array[(Array[String], Long)]) extends Serializable {
}
Loading

0 comments on commit 93f3280

Please sign in to comment.