[SPARK-1405] [mllib] Latent Dirichlet Allocation (LDA) using EM #4047
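For reference, the model being fit is standard smoothed LDA (editorial context, not text from the PR): each topic k draws a term distribution from a Dirichlet prior, each document d draws a topic mixture from a Dirichlet prior, and each token draws a topic and then a word:

\[
\phi_k \sim \operatorname{Dirichlet}(\beta), \qquad
\theta_d \sim \operatorname{Dirichlet}(\alpha), \qquad
z_{dn} \sim \operatorname{Mult}(\theta_d), \qquad
w_{dn} \sim \operatorname{Mult}(\phi_{z_{dn}})
\]

Here α and β correspond to the docConcentration and topicConcentration parameters in the example below, and the EM approach iteratively estimates θ and ϕ from the data.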

Closed · wants to merge 31 commits

Commits (31):
0cb7187  Added 3 files from dlwh LDA implementation (jkbradley, Dec 16, 2014)
2891e89  Prepped LDA main class for PR, but some cleanups remain (jkbradley, Jan 12, 2015)
2d40006  cleanups before PR (jkbradley, Jan 13, 2015)
377ebd9  separated LDA models into own file. more cleanups before PR (jkbradley, Jan 14, 2015)
9f2a492  Unit tests and fixes for LDA, now ready for PR (jkbradley, Jan 14, 2015)
75749e7  scala style fix (jkbradley, Jan 14, 2015)
ce53be9  fixed example name (jkbradley, Jan 14, 2015)
45cc7f2  mapPart -> flatMap (mengxr, Jan 16, 2015)
892530c  use axpy (mengxr, Jan 16, 2015)
9eb3d02  + -> += (mengxr, Jan 16, 2015)
6cb11b0  optimize computePTopic (mengxr, Jan 16, 2015)
cec0a9c  * -> *= (mengxr, Jan 16, 2015)
9fe0b95  optimize aggregateMessages (mengxr, Jan 16, 2015)
08d59a3  reset spacing (mengxr, Jan 16, 2015)
fb1e7b5  minor (mengxr, Jan 16, 2015)
77a2c85  Moved auto term,topic smoothing computation to get*Smoothing methods.… (jkbradley, Jan 21, 2015)
0b90393  renamed LDA LearningState.collectTopicTotals to globalTopicTotals (jkbradley, Jan 25, 2015)
43c1c40  small cleanup (jkbradley, Jan 26, 2015)
cb5a319  Added checkpointing to LDA (jkbradley, Jan 29, 2015)
993ca56  * Removed Document type in favor of (Long, Vector) (jkbradley, Jan 30, 2015)
b75472d  merged improvements from LDATiming into LDAExample. Will remove LDAT… (jkbradley, Jan 30, 2015)
91aadfe  Added Java-friendly run method to LDA. (jkbradley, Jan 30, 2015)
1a231b4  fixed scalastyle (jkbradley, Jan 30, 2015)
e8d8acf  Added catch for BreakIterator exception. Improved preprocessing to r… (jkbradley, Jan 30, 2015)
e391474  Removed LDATiming. Added PeriodicGraphCheckpointerSuite.scala. Smal… (jkbradley, Feb 3, 2015)
4ae2a7d  removed duplicate graphx dependency in mllib/pom.xml (jkbradley, Feb 3, 2015)
74487e5  Merge remote-tracking branch 'upstream/master' into davidhall-lda (jkbradley, Feb 3, 2015)
e3980d2  cleaned up PeriodicGraphCheckpointerSuite.scala (jkbradley, Feb 3, 2015)
589728b  Updates per code review. Main change was in LDAExample for faster vo… (jkbradley, Feb 3, 2015)
5c74345  cleaned up doc based on code review (jkbradley, Feb 3, 2015)
77e8814  small doc fix (jkbradley, Feb 3, 2015)
New file: examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -0,0 +1,283 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.mllib

import java.text.BreakIterator

import scala.collection.mutable

import scopt.OptionParser

import org.apache.log4j.{Level, Logger}
Contributor: organize imports

Member (Author): ? Wait what's out of order? scala/java > non-spark > spark

Contributor: https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide#SparkCodeStyleGuide-Imports

  • java.* and javax.*
  • scala.*
  • Third-party libraries (org., com., etc)
  • Project classes (org.apache.spark.*)

import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD


/**
 * An example Latent Dirichlet Allocation (LDA) app. Run with
 * {{{
 * ./bin/run-example mllib.LDAExample [options] <input>
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object LDAExample {

  private case class Params(
      input: Seq[String] = Seq.empty,
      k: Int = 20,
      maxIterations: Int = 10,
      docConcentration: Double = -1,
      topicConcentration: Double = -1,
      vocabSize: Int = 10000,
      stopwordFile: String = "",
      checkpointDir: Option[String] = None,
      checkpointInterval: Int = 10) extends AbstractParams[Params]

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("LDAExample") {
      head("LDAExample: an example LDA app for plain text data.")
      opt[Int]("k")
        .text(s"number of topics. default: ${defaultParams.k}")
        .action((x, c) => c.copy(k = x))
      opt[Int]("maxIterations")
        .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}")
        .action((x, c) => c.copy(maxIterations = x))
      opt[Double]("docConcentration")
        .text(s"amount of topic smoothing to use (> 1.0) (-1=auto)." +
          s" default: ${defaultParams.docConcentration}")
        .action((x, c) => c.copy(docConcentration = x))
      opt[Double]("topicConcentration")
        .text(s"amount of term (word) smoothing to use (> 1.0) (-1=auto)." +
          s" default: ${defaultParams.topicConcentration}")
        .action((x, c) => c.copy(topicConcentration = x))
      opt[Int]("vocabSize")
        .text(s"number of distinct word types to use, chosen by frequency. (-1=all)" +
          s" default: ${defaultParams.vocabSize}")
        .action((x, c) => c.copy(vocabSize = x))
      opt[String]("stopwordFile")
        .text(s"filepath for a list of stopwords. Note: This must fit on a single machine." +
          s" default: ${defaultParams.stopwordFile}")
        .action((x, c) => c.copy(stopwordFile = x))
      opt[String]("checkpointDir")
        .text(s"Directory for checkpointing intermediate results." +
          s" Checkpointing helps with recovery and eliminates temporary shuffle files on disk." +
          s" default: ${defaultParams.checkpointDir}")
        .action((x, c) => c.copy(checkpointDir = Some(x)))
      opt[Int]("checkpointInterval")
        .text(s"Iterations between each checkpoint. Only used if checkpointDir is set." +
          s" default: ${defaultParams.checkpointInterval}")
        .action((x, c) => c.copy(checkpointInterval = x))
      arg[String]("<input>...")
        .text("input paths (directories) to plain text corpora." +
          " Each text file line should hold 1 document.")
        .unbounded()
        .required()
        .action((x, c) => c.copy(input = c.input :+ x))
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    }.getOrElse {
      parser.showUsageAsError
      sys.exit(1)
    }
  }

  private def run(params: Params) {
    val conf = new SparkConf().setAppName(s"LDAExample with $params")
    val sc = new SparkContext(conf)

    Logger.getRootLogger.setLevel(Level.WARN)

    // Load documents, and prepare them for LDA.
    val preprocessStart = System.nanoTime()
    val (corpus, vocabArray, actualNumTokens) =
      preprocess(sc, params.input, params.vocabSize, params.stopwordFile)
    corpus.cache()
    val actualCorpusSize = corpus.count()
    val actualVocabSize = vocabArray.size
    val preprocessElapsed = (System.nanoTime() - preprocessStart) / 1e9

    println()
    println(s"Corpus summary:")
    println(s"\t Training set size: $actualCorpusSize documents")
    println(s"\t Vocabulary size: $actualVocabSize terms")
    println(s"\t Number of tokens: $actualNumTokens")
    println(s"\t Preprocessing time: $preprocessElapsed sec")
    println()

    // Run LDA.
    val lda = new LDA()
    lda.setK(params.k)
      .setMaxIterations(params.maxIterations)
      .setDocConcentration(params.docConcentration)
      .setTopicConcentration(params.topicConcentration)
      .setCheckpointInterval(params.checkpointInterval)
    if (params.checkpointDir.nonEmpty) {
      lda.setCheckpointDir(params.checkpointDir.get)
    }
    val startTime = System.nanoTime()
    val ldaModel = lda.run(corpus)
    val elapsed = (System.nanoTime() - startTime) / 1e9

    println(s"Finished training LDA model. Summary:")
    println(s"\t Training time: $elapsed sec")
    val avgLogLikelihood = ldaModel.logLikelihood / actualCorpusSize.toDouble
    println(s"\t Training data average log likelihood: $avgLogLikelihood")
    println()

    // Print the topics, showing the top-weighted terms for each topic.
    val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
    val topics = topicIndices.map { case (terms, termWeights) =>
      terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) }
    }
    println(s"${params.k} topics:")
    topics.zipWithIndex.foreach { case (topic, i) =>
      println(s"TOPIC $i")
      topic.foreach { case (term, weight) =>
        println(s"$term\t$weight")
      }
      println()
    }
  }

  /**
   * Load documents, tokenize them, create vocabulary, and prepare documents as term count vectors.
   * @return (corpus, vocabulary as array, total token count in corpus)
   */
  private def preprocess(
      sc: SparkContext,
      paths: Seq[String],
      vocabSize: Int,
      stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = {

    // Get dataset of document texts
    // One document per line in each text file.
    val textRDD: RDD[String] = sc.textFile(paths.mkString(","))

    // Split text into words
    val tokenizer = new SimpleTokenizer(sc, stopwordFile)
    val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) =>
      id -> tokenizer.getWords(text)
    }
    tokenized.cache()

    // Count words: RDD[(word, wordCount)]
    val wordCounts: RDD[(String, Long)] = tokenized
      .flatMap { case (_, tokens) => tokens.map(_ -> 1L) }
      .reduceByKey(_ + _)
    wordCounts.cache()
    val fullVocabSize = wordCounts.count()

    // Select vocab
    // (vocab: Map[word -> id], total tokens after selecting vocab)
    val (vocab: Map[String, Int], selectedTokenCount: Long) = {
      val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) {
        // Use all terms
        wordCounts.collect().sortBy(-_._2)
      } else {
        // Sort terms to select vocab
        wordCounts.sortBy(_._2, ascending = false).take(vocabSize)
      }
      (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum)
    }

    val documents = tokenized.map { case (id, tokens) =>
      // Filter tokens by vocabulary, and create word count vector representation of document.
      val wc = new mutable.HashMap[Int, Int]()
      tokens.foreach { term =>
        if (vocab.contains(term)) {
          val termIndex = vocab(term)
          wc(termIndex) = wc.getOrElse(termIndex, 0) + 1
        }
      }
      val indices = wc.keys.toArray.sorted
      val values = indices.map(i => wc(i).toDouble)

      val sb = Vectors.sparse(vocab.size, indices, values)
      (id, sb)
    }

    val vocabArray = new Array[String](vocab.size)
    vocab.foreach { case (term, i) => vocabArray(i) = term }

    (documents, vocabArray, selectedTokenCount)
  }
}

/**
 * Simple Tokenizer.
 *
 * TODO: Formalize the interface, and make this a public class in mllib.feature
 */
private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable {

  private val stopwords: Set[String] = if (stopwordFile.isEmpty) {
    Set.empty[String]
  } else {
    val stopwordText = sc.textFile(stopwordFile).collect()
    stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet
  }

  // Matches sequences of Unicode letters
  private val allWordRegex = "^(\\p{L}*)$".r

  // Ignore words shorter than this length.
  private val minWordLength = 3

  def getWords(text: String): IndexedSeq[String] = {

    val words = new mutable.ArrayBuffer[String]()

    // Use Java BreakIterator to tokenize text into words.
    val wb = BreakIterator.getWordInstance
    wb.setText(text)

    // `current` and `end` mark the start and end index of each word.
    var current = wb.first()
    var end = wb.next()
    while (end != BreakIterator.DONE) {
      // Convert to lowercase
      val word: String = text.substring(current, end).toLowerCase
      // Remove short words and strings that aren't only letters
      word match {
        case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) =>
          words += w
        case _ =>
      }

      current = end
      try {
        end = wb.next()
      } catch {
        case e: Exception =>
          // Ignore remaining text in line.
          // This is a known bug in BreakIterator (for some Java versions),
          // which fails when it sees certain characters.
          end = BreakIterator.DONE
      }
    }
    words
  }
}
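For reviewers who want to try the new API directly, here is a minimal sketch of the LDA usage this example exercises (an editorial illustration, assuming an existing SparkContext named sc; the corpus format in this PR is RDD[(Long, Vector)] of (document ID, term-count vector) pairs):

import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vectors

// Toy corpus: two documents over a 3-term vocabulary, as (docId, termCounts).
val corpus = sc.parallelize(Seq(
  (0L, Vectors.dense(1.0, 2.0, 0.0)),
  (1L, Vectors.dense(0.0, 1.0, 3.0))))

// Fit a 2-topic model with the EM-based LDA added in this PR.
val model = new LDA().setK(2).setMaxIterations(20).run(corpus)

// Each topic is a distribution over vocabulary indices; print the top terms.
model.describeTopics(maxTermsPerTopic = 3).zipWithIndex.foreach {
  case ((termIndices, termWeights), i) =>
    println(s"Topic $i: " + termIndices.zip(termWeights).mkString(", "))
}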