prototype for discussion
hhbyyh committed Apr 22, 2015
1 parent bdc5c16 commit ec2f857
Showing 2 changed files with 334 additions and 131 deletions.
183 changes: 52 additions & 131 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -26,6 +26,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.clustering.LDAOptimizer
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
@@ -68,6 +69,26 @@ class LDA private (
def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
seed = Utils.random.nextLong(), checkpointInterval = 10)

  private var ldaOptimizer: LDAOptimizer = new EMOptimizer(/* default parameters */)

  def getOptimizer(): LDAOptimizer = ldaOptimizer

  def setOptimizer(optimizer: LDAOptimizer): this.type = {
    this.ldaOptimizer = optimizer
    this
  }

  def setOptimizer(optimizer: String): this.type = {
    optimizer match {
      case "EM" => this.setOptimizer(new EMOptimizer(/* default parameters */))
      case "Gibbs" => this.setOptimizer(new GibbsOptimizer(/* default parameters */))
      case "Online" => this.setOptimizer(new OnlineLDAOptimizer(/* default parameters */))
      case other => throw new IllegalArgumentException(s"Unknown optimizer: $other")
    }
  }
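
A usage sketch of this API (assuming an active SparkContext sc; the tiny corpus and the no-argument optimizer call stand in for the prototype's defaults):

    import org.apache.spark.mllib.linalg.Vectors
    // Corpus of (docId, termCountVector) pairs.
    val corpus = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 0.0, 2.0)),
      (1L, Vectors.dense(0.0, 3.0, 1.0))))
    val model = new LDA()
      .setK(2)
      .setOptimizer("EM") // by name; or pass a configured instance via the other overload
      .run(corpus)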


/**
* Number of topics to infer. I.e., the number of soft cluster centers.
*/
@@ -230,19 +251,37 @@ class LDA private (
* @return Inferred LDA model
*/
  def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
    if (ldaOptimizer.isInstanceOf[EMOptimizer]) {
      val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
        checkpointInterval)
      var iter = 0
      val iterationTimes = Array.fill[Double](maxIterations)(0)
      while (iter < maxIterations) {
        val start = System.nanoTime()
        state.next()
        val elapsedSeconds = (System.nanoTime() - start) / 1e9
        iterationTimes(iter) = elapsedSeconds
        iter += 1
      }
      state.graphCheckpointer.deleteAllCheckpoints()
      new DistributedLDAModel(state, iterationTimes)
    } else if (ldaOptimizer.isInstanceOf[OnlineLDAOptimizer]) {
      val vocabSize = documents.first._2.size
      val D = documents.count().toInt // total document count
      // tau_0, kappa, and miniBatchFraction are assumed to come from the online
      // optimizer's configuration; they are not fields of LDA in this prototype.
      val onlineLDA = new OnlineLDAOptimizer(k, D, vocabSize, 1.0 / k, 1.0 / k, tau_0, kappa)

      // Cut the corpus into mini-batches once, then cycle through them.
      val arr = Array.fill(math.ceil(1.0 / miniBatchFraction).toInt)(miniBatchFraction)
      val splits = documents.randomSplit(arr)
      for (i <- 0 until maxIterations) {
        val index = i % splits.size
        onlineLDA.submitMiniBatch(splits(index))
      }
      onlineLDA.getTopicDistribution() // prototype: not yet wrapped in a DistributedLDAModel
    } else {
      throw new UnsupportedOperationException(
        s"Optimizer ${ldaOptimizer.getClass.getSimpleName} is not supported yet.")
    }
  }
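
For intuition, with miniBatchFraction = 0.25 the loop above cuts the corpus into ceil(1 / 0.25) = 4 random splits and consumes them round-robin. A minimal standalone sketch (corpus and processMiniBatch are hypothetical names):

    val miniBatchFraction = 0.25
    val weights = Array.fill(math.ceil(1.0 / miniBatchFraction).toInt)(miniBatchFraction)
    // randomSplit normalizes its weights, so each of the 4 splits holds ~25% of the corpus.
    val splits = corpus.randomSplit(weights)
    for (i <- 0 until maxIterations) {
      processMiniBatch(splits(i % splits.size)) // hypothetical per-batch update
    }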

/** Java-friendly version of [[run()]] */
@@ -307,126 +346,7 @@ private[clustering] object LDA {
* Vector over topics (length k) of token counts.
* The meaning of these counts can vary, and the vector may or may not be normalized to a distribution.
*/
private[clustering] type TopicCounts = BDV[Double]

private[clustering] type TokenCount = Double

/** Term vertex IDs are {-1, -2, ..., -vocabSize} */
private[clustering] def term2index(term: Int): Long = -(1 + term.toLong)

private[clustering] def index2term(termIndex: Long): Int = -(1 + termIndex).toInt

private[clustering] def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0

private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0
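
For example, term2index(0) == -1 and term2index(5) == -6, and index2term inverts the mapping (index2term(-6) == 5); since document vertices use the nonnegative document IDs, the two ID ranges can never collide.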

/**
* Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters.
*
* @param graph EM graph, storing current parameter estimates in vertex descriptors and
* data (token counts) in edge descriptors.
* @param k Number of topics
* @param vocabSize Number of unique terms
* @param docConcentration "alpha"
* @param topicConcentration "beta" or "eta"
*/
private[clustering] class EMOptimizer(
var graph: Graph[TopicCounts, TokenCount],
val k: Int,
val vocabSize: Int,
val docConcentration: Double,
val topicConcentration: Double,
checkpointInterval: Int) {

private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
graph, checkpointInterval)

def next(): EMOptimizer = {
val eta = topicConcentration
val W = vocabSize
val alpha = docConcentration

val N_k = globalTopicTotals
val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit =
(edgeContext) => {
// Compute N_{wj} gamma_{wjk}
val N_wj = edgeContext.attr
// E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count
// N_{wj}.
val scaledTopicDistribution: TopicCounts =
computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj
edgeContext.sendToDst((false, scaledTopicDistribution))
edgeContext.sendToSrc((false, scaledTopicDistribution))
}
// This is a hack to detect whether we could modify the values in-place.
// TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438)
val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) =
(m0, m1) => {
val sum =
if (m0._1) {
m0._2 += m1._2
} else if (m1._1) {
m1._2 += m0._2
} else {
m0._2 + m1._2
}
(true, sum)
}
// M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
val docTopicDistributions: VertexRDD[TopicCounts] =
graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg)
.mapValues(_._2)
// Update the vertex descriptors with the new counts.
val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
graph = newGraph
graphCheckpointer.updateGraph(newGraph)
globalTopicTotals = computeGlobalTopicTotals()
this
}

/**
* Aggregate distributions over topics from all term vertices.
*
* Note: This executes an action on the graph RDDs.
*/
var globalTopicTotals: TopicCounts = computeGlobalTopicTotals()

private def computeGlobalTopicTotals(): TopicCounts = {
val numTopics = k
graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _)
}

}
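
The Boolean flag in mergeMsg above marks whether a message buffer is already private to the reduction and therefore safe to mutate. A standalone sketch of the same trick, simplified from the code (not the committed implementation):

    import breeze.linalg.{DenseVector => BDV}

    // (owned, counts): once `owned` is true, `counts` is a private accumulator
    // and may be reused via in-place addition instead of a fresh allocation.
    def merge(m0: (Boolean, BDV[Double]), m1: (Boolean, BDV[Double])): (Boolean, BDV[Double]) = {
      val sum =
        if (m0._1) m0._2 += m1._2      // m0 already owned: accumulate in place
        else if (m1._1) m1._2 += m0._2 // m1 already owned
        else m0._2 + m1._2             // both shared: allocate one fresh vector
      (true, sum)
    }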

/**
* Compute gamma_{wjk}, a distribution over topics k.
*/
private def computePTopic(
docTopicCounts: TopicCounts,
termTopicCounts: TopicCounts,
totalTopicCounts: TopicCounts,
vocabSize: Int,
eta: Double,
alpha: Double): TopicCounts = {
val K = docTopicCounts.length
val N_j = docTopicCounts.data
val N_w = termTopicCounts.data
val N = totalTopicCounts.data
val eta1 = eta - 1.0
val alpha1 = alpha - 1.0
val Weta1 = vocabSize * eta1
var sum = 0.0
val gamma_wj = new Array[Double](K)
var k = 0
while (k < K) {
val gamma_wjk = (N_w(k) + eta1) * (N_j(k) + alpha1) / (N(k) + Weta1)
gamma_wj(k) = gamma_wjk
sum += gamma_wjk
k += 1
}
// normalize
BDV(gamma_wj) /= sum
}
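
In symbols, the loop above computes the smoothed topic responsibilities (same notation as the code: N_{wk} term-topic counts, N_{jk} doc-topic counts, N_k totals, W = vocabSize):

    \gamma_{wjk} = \frac{(N_{wk} + \eta - 1)(N_{jk} + \alpha - 1)}{N_k + W(\eta - 1)}
      \bigg/ \sum_{k'=1}^{K} \frac{(N_{wk'} + \eta - 1)(N_{jk'} + \alpha - 1)}{N_{k'} + W(\eta - 1)}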

/**
* Compute bipartite term/doc graph.
@@ -473,3 +393,4 @@ private[clustering] object LDA {
}

}
