[SPARK-1405] [mllib] Latent Dirichlet Allocation (LDA) using EM #4047
Closed
31 commits
0cb7187  Added 3 files from dlwh LDA implementation  (jkbradley)
2891e89  Prepped LDA main class for PR, but some cleanups remain  (jkbradley)
2d40006  cleanups before PR  (jkbradley)
377ebd9  separated LDA models into own file. more cleanups before PR  (jkbradley)
9f2a492  Unit tests and fixes for LDA, now ready for PR  (jkbradley)
75749e7  scala style fix  (jkbradley)
ce53be9  fixed example name  (jkbradley)
45cc7f2  mapPart -> flatMap  (mengxr)
892530c  use axpy  (mengxr)
9eb3d02  + -> +=  (mengxr)
6cb11b0  optimize computePTopic  (mengxr)
cec0a9c  * -> *=  (mengxr)
9fe0b95  optimize aggregateMessages  (mengxr)
08d59a3  reset spacing  (mengxr)
fb1e7b5  minor  (mengxr)
77a2c85  Moved auto term,topic smoothing computation to get*Smoothing methods.…  (jkbradley)
0b90393  renamed LDA LearningState.collectTopicTotals to globalTopicTotals  (jkbradley)
43c1c40  small cleanup  (jkbradley)
cb5a319  Added checkpointing to LDA  (jkbradley)
993ca56  * Removed Document type in favor of (Long, Vector)  (jkbradley)
b75472d  merged improvements from LDATiming into LDAExample. Will remove LDAT…  (jkbradley)
91aadfe  Added Java-friendly run method to LDA.  (jkbradley)
1a231b4  fixed scalastyle  (jkbradley)
e8d8acf  Added catch for BreakIterator exception. Improved preprocessing to r…  (jkbradley)
e391474  Removed LDATiming. Added PeriodicGraphCheckpointerSuite.scala. Smal…  (jkbradley)
4ae2a7d  removed duplicate graphx dependency in mllib/pom.xml  (jkbradley)
74487e5  Merge remote-tracking branch 'upstream/master' into davidhall-lda  (jkbradley)
e3980d2  cleaned up PeriodicGraphCheckpointerSuite.scala  (jkbradley)
589728b  Updates per code review. Main change was in LDAExample for faster vo…  (jkbradley)
5c74345  cleaned up doc based on code review  (jkbradley)
77e8814  small doc fix  (jkbradley)
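
Several of mengxr's commits above (892530c "use axpy", 9eb3d02 "+ -> +=", cec0a9c "* -> *=") are allocation-avoiding micro-optimizations. Below is a hedged sketch of the general pattern those messages describe — not the PR's actual implementation — assuming Breeze, which MLlib uses internally for linear algebra:

import breeze.linalg.{axpy, DenseVector => BDV}

object AxpySketch {
  def main(args: Array[String]): Unit = {
    // Two k-dimensional "messages" to accumulate, as in an EM aggregation step.
    val messages = Seq(BDV(1.0, 0.0, 2.0, 0.0), BDV(0.0, 3.0, 0.0, 1.0))

    // Before: `total = total + m` allocates a fresh k-vector per message.
    // After: axpy(1.0, m, total) computes total += 1.0 * m in place.
    val total = BDV.zeros[Double](4)
    messages.foreach(m => axpy(1.0, m, total))
    println(total) // DenseVector(1.0, 3.0, 2.0, 1.0)
  }
}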
examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala (283 additions, 0 deletions)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.mllib

import java.text.BreakIterator

import scala.collection.mutable

import scopt.OptionParser

import org.apache.log4j.{Level, Logger}

import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD


/**
 * An example Latent Dirichlet Allocation (LDA) app. Run with
 * {{{
 * ./bin/run-example mllib.LDAExample [options] <input>
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object LDAExample {

  private case class Params(
      input: Seq[String] = Seq.empty,
      k: Int = 20,
      maxIterations: Int = 10,
      docConcentration: Double = -1,
      topicConcentration: Double = -1,
      vocabSize: Int = 10000,
      stopwordFile: String = "",
      checkpointDir: Option[String] = None,
      checkpointInterval: Int = 10) extends AbstractParams[Params]

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("LDAExample") {
      head("LDAExample: an example LDA app for plain text data.")
      opt[Int]("k")
        .text(s"number of topics. default: ${defaultParams.k}")
        .action((x, c) => c.copy(k = x))
      opt[Int]("maxIterations")
        .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}")
        .action((x, c) => c.copy(maxIterations = x))
      opt[Double]("docConcentration")
        .text(s"amount of topic smoothing to use (> 1.0) (-1=auto)." +
        s" default: ${defaultParams.docConcentration}")
        .action((x, c) => c.copy(docConcentration = x))
      opt[Double]("topicConcentration")
        .text(s"amount of term (word) smoothing to use (> 1.0) (-1=auto)." +
        s" default: ${defaultParams.topicConcentration}")
        .action((x, c) => c.copy(topicConcentration = x))
      opt[Int]("vocabSize")
        .text(s"number of distinct word types to use, chosen by frequency. (-1=all)" +
        s" default: ${defaultParams.vocabSize}")
        .action((x, c) => c.copy(vocabSize = x))
      opt[String]("stopwordFile")
        .text(s"filepath for a list of stopwords. Note: This must fit on a single machine." +
        s" default: ${defaultParams.stopwordFile}")
        .action((x, c) => c.copy(stopwordFile = x))
      opt[String]("checkpointDir")
        .text(s"Directory for checkpointing intermediate results." +
        s" Checkpointing helps with recovery and eliminates temporary shuffle files on disk." +
        s" default: ${defaultParams.checkpointDir}")
        .action((x, c) => c.copy(checkpointDir = Some(x)))
      opt[Int]("checkpointInterval")
        .text(s"Iterations between each checkpoint. Only used if checkpointDir is set." +
        s" default: ${defaultParams.checkpointInterval}")
        .action((x, c) => c.copy(checkpointInterval = x))
      arg[String]("<input>...")
        .text("input paths (directories) to plain text corpora." +
        " Each text file line should hold 1 document.")
        .unbounded()
        .required()
        .action((x, c) => c.copy(input = c.input :+ x))
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    }.getOrElse {
      parser.showUsageAsError
      sys.exit(1)
    }
  }

  private def run(params: Params) {
    val conf = new SparkConf().setAppName(s"LDAExample with $params")
    val sc = new SparkContext(conf)

    Logger.getRootLogger.setLevel(Level.WARN)

    // Load documents, and prepare them for LDA.
    val preprocessStart = System.nanoTime()
    val (corpus, vocabArray, actualNumTokens) =
      preprocess(sc, params.input, params.vocabSize, params.stopwordFile)
    corpus.cache()
    val actualCorpusSize = corpus.count()
    val actualVocabSize = vocabArray.size
    val preprocessElapsed = (System.nanoTime() - preprocessStart) / 1e9

    println()
    println(s"Corpus summary:")
    println(s"\t Training set size: $actualCorpusSize documents")
    println(s"\t Vocabulary size: $actualVocabSize terms")
    println(s"\t Training set size: $actualNumTokens tokens")
    println(s"\t Preprocessing time: $preprocessElapsed sec")
    println()

    // Run LDA.
    val lda = new LDA()
    lda.setK(params.k)
      .setMaxIterations(params.maxIterations)
      .setDocConcentration(params.docConcentration)
      .setTopicConcentration(params.topicConcentration)
      .setCheckpointInterval(params.checkpointInterval)
    if (params.checkpointDir.nonEmpty) {
      lda.setCheckpointDir(params.checkpointDir.get)
    }
    val startTime = System.nanoTime()
    val ldaModel = lda.run(corpus)
    val elapsed = (System.nanoTime() - startTime) / 1e9

    println(s"Finished training LDA model. Summary:")
    println(s"\t Training time: $elapsed sec")
    val avgLogLikelihood = ldaModel.logLikelihood / actualCorpusSize.toDouble
    println(s"\t Training data average log likelihood: $avgLogLikelihood")
    println()

    // Print the topics, showing the top-weighted terms for each topic.
    val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
    val topics = topicIndices.map { case (terms, termWeights) =>
      terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) }
    }
    println(s"${params.k} topics:")
    topics.zipWithIndex.foreach { case (topic, i) =>
      println(s"TOPIC $i")
      topic.foreach { case (term, weight) =>
        println(s"$term\t$weight")
      }
      println()
    }

  }

  /**
   * Load documents, tokenize them, create vocabulary, and prepare documents as term count vectors.
   * @return (corpus, vocabulary as array, total token count in corpus)
   */
  private def preprocess(
      sc: SparkContext,
      paths: Seq[String],
      vocabSize: Int,
      stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = {

    // Get dataset of document texts
    // One document per line in each text file.
    val textRDD: RDD[String] = sc.textFile(paths.mkString(","))

    // Split text into words
    val tokenizer = new SimpleTokenizer(sc, stopwordFile)
    val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) =>
      id -> tokenizer.getWords(text)
    }
    tokenized.cache()

    // Counts words: RDD[(word, wordCount)]
    val wordCounts: RDD[(String, Long)] = tokenized
      .flatMap { case (_, tokens) => tokens.map(_ -> 1L) }
      .reduceByKey(_ + _)
    wordCounts.cache()
    val fullVocabSize = wordCounts.count()
    // Select vocab
    //  (vocab: Map[word -> id], total tokens after selecting vocab)
    val (vocab: Map[String, Int], selectedTokenCount: Long) = {
      val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) {
        // Use all terms
        wordCounts.collect().sortBy(-_._2)
      } else {
        // Sort terms to select vocab
        wordCounts.sortBy(_._2, ascending = false).take(vocabSize)
      }
      (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum)
    }

    val documents = tokenized.map { case (id, tokens) =>
      // Filter tokens by vocabulary, and create word count vector representation of document.
      val wc = new mutable.HashMap[Int, Int]()
      tokens.foreach { term =>
        if (vocab.contains(term)) {
          val termIndex = vocab(term)
          wc(termIndex) = wc.getOrElse(termIndex, 0) + 1
        }
      }
      val indices = wc.keys.toArray.sorted
      val values = indices.map(i => wc(i).toDouble)

      val sb = Vectors.sparse(vocab.size, indices, values)
      (id, sb)
    }

    val vocabArray = new Array[String](vocab.size)
    vocab.foreach { case (term, i) => vocabArray(i) = term }

    (documents, vocabArray, selectedTokenCount)
  }
}

/**
 * Simple Tokenizer.
 *
 * TODO: Formalize the interface, and make this a public class in mllib.feature
 */
private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable {

  private val stopwords: Set[String] = if (stopwordFile.isEmpty) {
    Set.empty[String]
  } else {
    val stopwordText = sc.textFile(stopwordFile).collect()
    stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet
  }

  // Matches sequences of Unicode letters
  private val allWordRegex = "^(\\p{L}*)$".r

  // Ignore words shorter than this length.
  private val minWordLength = 3

  def getWords(text: String): IndexedSeq[String] = {

    val words = new mutable.ArrayBuffer[String]()

    // Use Java BreakIterator to tokenize text into words.
    val wb = BreakIterator.getWordInstance
    wb.setText(text)

    // current,end index start,end of each word
    var current = wb.first()
    var end = wb.next()
    while (end != BreakIterator.DONE) {
      // Convert to lowercase
      val word: String = text.substring(current, end).toLowerCase
      // Remove short words and strings that aren't only letters
      word match {
        case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) =>
          words += w
        case _ =>
      }

      current = end
      try {
        end = wb.next()
      } catch {
        case e: Exception =>
          // Ignore remaining text in line.
          // This is a known bug in BreakIterator (for some Java versions),
          // which fails when it sees certain characters.
          end = BreakIterator.DONE
      }
    }
    words
  }

}
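
For orientation, here is a minimal, hedged sketch of the LDA API this example exercises. The class, the chained setters, run, and describeTopics all appear in the diff above; the two-document inline corpus and the object name LDAQuickSketch are illustrative only:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vectors

object LDAQuickSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("LDAQuickSketch").setMaster("local[2]"))
    // A corpus is RDD[(docId, term-count vector)], the representation
    // that LDAExample.preprocess above produces.
    val corpus = sc.parallelize(Seq(
      (0L, Vectors.sparse(4, Array(0, 1), Array(2.0, 1.0))),
      (1L, Vectors.sparse(4, Array(2, 3), Array(1.0, 3.0)))))
    // Concentration parameters are left at their defaults (auto), as when
    // LDAExample is run with docConcentration/topicConcentration = -1.
    val ldaModel = new LDA().setK(2).setMaxIterations(10).run(corpus)
    // describeTopics returns a (termIndices, termWeights) pair per topic.
    ldaModel.describeTopics(maxTermsPerTopic = 2).zipWithIndex.foreach {
      case ((terms, weights), i) =>
        println(s"TOPIC $i: " + terms.zip(weights).mkString(", "))
    }
    sc.stop()
  }
}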
Review comment: organize imports
Reply: Wait, what's out of order? scala/java > non-spark > spark
https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide#SparkCodeStyleGuide-Imports
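
For readers outside the thread: the linked style guide groups imports as java/javax, then scala, then third-party libraries, then org.apache.spark, with a blank line between groups. A sketch of that grouping, drawn from this file's own imports (the numbered annotations are mine, not part of the guide):

// Import ordering per the Spark style guide (sketch):
import java.text.BreakIterator           // 1. java / javax

import scala.collection.mutable          // 2. scala

import org.apache.log4j.{Level, Logger}  // 3. third-party libraries
import scopt.OptionParser

import org.apache.spark.SparkContext     // 4. org.apache.spark last
import org.apache.spark.mllib.clustering.LDA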