refactor LDA with Optimizer

hhbyyh committed Apr 23, 2015
1 parent ec2f857 commit 0bb8400
Showing 7 changed files with 214 additions and 322 deletions.
@@ -58,7 +58,7 @@ public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
corpus.cache();

// Cluster the documents into three topics using LDA
DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus);
DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus);

// Output topics. Each is a distribution over words (matching word count vectors)
System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
@@ -26,7 +26,7 @@ import scopt.OptionParser
import org.apache.log4j.{Level, Logger}

import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

@@ -137,7 +137,7 @@ object LDAExample {
sc.setCheckpointDir(params.checkpointDir.get)
}
val startTime = System.nanoTime()
val ldaModel = lda.run(corpus)
val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
val elapsed = (System.nanoTime() - startTime) / 1e9

println(s"Finished training LDA model. Summary:")
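Both example changes above follow from run() now returning the abstract LDAModel instead of DistributedLDAModel, so callers that want the distributed (EM) model must cast. A minimal sketch of an alternative that avoids asInstanceOf by pattern matching on the concrete type (the helper name describeModel is hypothetical, not part of the commit):

import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA, LDAModel}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Hypothetical helper: branch on the concrete model type instead of casting.
def describeModel(corpus: RDD[(Long, Vector)]): Unit = {
  val model: LDAModel = new LDA().setK(3).run(corpus)
  model match {
    case distributed: DistributedLDAModel =>
      println(s"EM model over ${distributed.vocabSize} terms, " +
        s"log likelihood = ${distributed.logLikelihood}")
    case other =>
      println(s"Non-distributed model over ${other.vocabSize} terms")
  }
}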
191 changes: 82 additions & 109 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -17,17 +17,11 @@

package org.apache.spark.mllib.clustering

import java.util.Random

import breeze.linalg.{DenseVector => BDV, normalize}

import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.clustering.LDAOptimizer
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -43,17 +37,6 @@ import org.apache.spark.util.Utils
* - "token": instance of a term appearing in a document
* - "topic": multinomial distribution over words representing some concept
*
* Currently, the underlying implementation uses Expectation-Maximization (EM), implemented
* according to the Asuncion et al. (2009) paper referenced below.
*
* References:
* - Original LDA paper (journal version):
* Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
* - This class implements their "smoothed" LDA model.
* - Paper which clearly explains several algorithms, including EM:
* Asuncion, Welling, Smyth, and Teh.
* "On Smoothing and Inference for Topic Models." UAI, 2009.
*
* @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation
* (Wikipedia)]]
*/
@@ -69,25 +52,7 @@ class LDA private (
def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
seed = Utils.random.nextLong(), checkpointInterval = 10)

var ldaOptimizer = setOptimizer("EM")

def getOptimizer(): LDAOptimizer = {
ldaOptimizer
}

def setOptimizer(optimizer: LDAOptimizer): this.type = {
this.ldaOptimizer = optimizer
this
}

def setOptimizer(optimizer: String): this.type = {
optimizer match{
case "EM" => this.setOptimizer(new EMOptimizer(default parameter))
case "Gibbs"=> this.setOptimizer(new GibbsOptimizer(default parameter))
case "Online"=> this.setOptimizer(new OnlineLDAOptimizer(default parameter))
}
}

private var ldaOptimizer: LDAOptimizer = getDefaultOptimizer("EM")

/**
* Number of topics to infer. I.e., the number of soft cluster centers.
@@ -241,6 +206,38 @@ class LDA private (
this
}


/** LDAOptimizer used to perform the actual calculation */
def getOptimizer(): LDAOptimizer = ldaOptimizer

/**
* LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer)
*/
def setOptimizer(optimizer: LDAOptimizer): this.type = {
this.ldaOptimizer = optimizer
this
}

/**
* Set the LDAOptimizer used to perform the actual calculation by algorithm name.
* Currently "EM" is supported.
*/
def setOptimizer(optimizerName: String): this.type = {
this.ldaOptimizer = getDefaultOptimizer(optimizerName)
this
}

/**
* Get the default optimizer for the given algorithm name.
*/
private def getDefaultOptimizer(optimizerName: String): LDAOptimizer = {
optimizerName match {
case "EM" => new EMLDAOptimizer()
case other =>
throw new UnsupportedOperationException(s"Only EM is supported but got $other.")
}
}
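A short usage sketch of the new setter pair, using the setters already defined on LDA (setK, setMaxIterations); the specific k and iteration values are only for illustration:

val lda = new LDA()
  .setK(10)
  .setMaxIterations(50)
  .setOptimizer("EM")  // by name; anything other than "EM" throws UnsupportedOperationException

// Alternatively, assuming EMLDAOptimizer exposes a public constructor,
// an optimizer instance can be passed directly:
// lda.setOptimizer(new EMLDAOptimizer())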

/**
* Learn an LDA model using the given dataset.
*
@@ -250,42 +247,23 @@ class LDA private (
* Document IDs must be unique and >= 0.
* @return Inferred LDA model
*/
def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
if(ldaOptimizer.isInstanceOf[EMOptimizer]){
val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
checkpointInterval)
var iter = 0
val iterationTimes = Array.fill[Double](maxIterations)(0)
while (iter < maxIterations) {
val start = System.nanoTime()
state.next()
val elapsedSeconds = (System.nanoTime() - start) / 1e9
iterationTimes(iter) = elapsedSeconds
iter += 1
}
state.graphCheckpointer.deleteAllCheckpoints()
new DistributedLDAModel(state, iterationTimes)
}
else if(ldaOptimizer.isInstanceOf[OnlineLDAOptimizer]){
val vocabSize = documents.first._2.size
val D = documents.count().toInt // total documents count
val onlineLDA = new OnlineLDAOptimizer(k, D, vocabSize, 1.0/k, 1.0/k, tau_0, kappa)

val arr = Array.fill(math.ceil(1.0 / miniBatchFraction).toInt)(miniBatchFraction)
val splits = documents.randomSplit(arr)
for(i <- 0 until numIterations){
val index = i % splits.size
onlineLDA.submitMiniBatch(splits(index))
}
onlineLDA.getTopicDistribution()
def run(documents: RDD[(Long, Vector)]): LDAModel = {
val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration,
seed, checkpointInterval)
var iter = 0
val iterationTimes = Array.fill[Double](maxIterations)(0)
while (iter < maxIterations) {
val start = System.nanoTime()
state.next()
val elapsedSeconds = (System.nanoTime() - start) / 1e9
iterationTimes(iter) = elapsedSeconds
iter += 1
}



state.getLDAModel(iterationTimes)
}

/** Java-friendly version of [[run()]] */
def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
}
}
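After this change run() only orchestrates: it asks the optimizer for an initial state, calls next() maxIterations times while recording per-iteration timings, and asks the state to build the model. Based solely on those call sites, the optimizer contract presumably looks roughly like the sketch below (names and signatures inferred from the usage in run, not copied from the LDAOptimizer file in this commit):

import org.apache.spark.mllib.clustering.LDAModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Inferred sketch of the optimizer contract; the real trait may differ in detail.
trait LDAOptimizerSketch {
  /** Build the initial internal state from the corpus and the LDA hyperparameters. */
  def initialState(
      docs: RDD[(Long, Vector)],
      k: Int,
      docConcentration: Double,
      topicConcentration: Double,
      randomSeed: Long,
      checkpointInterval: Int): LDAOptimizerSketch

  /** Run one iteration over the current state. */
  def next(): LDAOptimizerSketch

  /** Produce the concrete model (a DistributedLDAModel for the EM optimizer). */
  def getLDAModel(iterationTimes: Array[Double]): LDAModel
}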
@@ -346,51 +324,46 @@ private[clustering] object LDA {
* Vector over topics (length k) of token counts.
* The meaning of these counts can vary, and it may or may not be normalized to be a distribution.
*/
private[clustering] type TopicCounts = BDV[Double]

private[clustering] type TokenCount = Double

/**
* Compute bipartite term/doc graph.
*/
private def initialState(
docs: RDD[(Long, Vector)],
k: Int,
docConcentration: Double,
topicConcentration: Double,
randomSeed: Long,
checkpointInterval: Int): EMOptimizer = {
// For each document, create an edge (Document -> Term) for each unique term in the document.
val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) =>
// Add edges for terms with non-zero counts.
termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) =>
Edge(docID, term2index(term), cnt)
}
}
/** Term vertex IDs are {-1, -2, ..., -vocabSize} */
private[clustering] def term2index(term: Int): Long = -(1 + term.toLong)

val vocabSize = docs.take(1).head._2.size

// Create vertices.
// Initially, we use random soft assignments of tokens to topics (random gamma).
def createVertices(): RDD[(VertexId, TopicCounts)] = {
val verticesTMP: RDD[(VertexId, TopicCounts)] =
edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
val random = new Random(partIndex + randomSeed)
partEdges.flatMap { edge =>
val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
val sum = gamma * edge.attr
Seq((edge.srcId, sum), (edge.dstId, sum))
}
}
verticesTMP.reduceByKey(_ + _)
}
private[clustering] def index2term(termIndex: Long): Int = -(1 + termIndex).toInt

val docTermVertices = createVertices()
private[clustering] def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0

// Partition such that edges are grouped by document
val graph = Graph(docTermVertices, edges)
.partitionBy(PartitionStrategy.EdgePartition1D)
private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0

new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval)
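Documents and terms share one vertex-ID space: document vertices keep their non-negative document IDs, and term t is stored at vertex -(1 + t). A tiny illustration of the round-trip (these helpers are private[clustering], so the checks are only illustrative, not callable from user code):

// term 0 -> vertex -1, term 1 -> vertex -2, ..., term (vocabSize - 1) -> vertex -vocabSize
assert(term2index(0) == -1L)
assert(index2term(-1L) == 0)
assert(index2term(term2index(41)) == 41)
// Hence isDocumentVertex tests id >= 0 and isTermVertex tests id < 0.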
/**
* Compute gamma_{wjk}, a distribution over topics k.
*/
private[clustering] def computePTopic(
docTopicCounts: TopicCounts,
termTopicCounts: TopicCounts,
totalTopicCounts: TopicCounts,
vocabSize: Int,
eta: Double,
alpha: Double): TopicCounts = {
val K = docTopicCounts.length
val N_j = docTopicCounts.data
val N_w = termTopicCounts.data
val N = totalTopicCounts.data
val eta1 = eta - 1.0
val alpha1 = alpha - 1.0
val Weta1 = vocabSize * eta1
var sum = 0.0
val gamma_wj = new Array[Double](K)
var k = 0
while (k < K) {
val gamma_wjk = (N_w(k) + eta1) * (N_j(k) + alpha1) / (N(k) + Weta1)
gamma_wj(k) = gamma_wjk
sum += gamma_wjk
k += 1
}
// normalize
BDV(gamma_wj) /= sum
}
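In equation form (reconstructed from the loop above, writing W for vocabSize and N_{jk}, N_{wk}, N_k for docTopicCounts(k), termTopicCounts(k), totalTopicCounts(k)):

\gamma_{wjk} \;\propto\; \frac{(N_{wk} + \eta - 1)\,(N_{jk} + \alpha - 1)}{N_k + W(\eta - 1)},
\qquad \sum_{k=1}^{K} \gamma_{wjk} = 1 \text{ after the final normalization.}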

}

@@ -203,7 +203,7 @@ class DistributedLDAModel private (

import LDA._

private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = {
private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = {
this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration,
state.topicConcentration, iterationTimes)
}