
[SPARK-1405] [mllib] Latent Dirichlet Allocation (LDA) using EM #4047

Closed
wants to merge 31 commits

Conversation

jkbradley
Member

This PR introduces an API + simple implementation for Latent Dirichlet Allocation (LDA).

The design doc for this PR has been updated since I initially posted it. In particular, see the API and Planning for the Future sections.

Goals

  • Settle on a public API which may eventually include:
    • more inference algorithms
    • more options / functionality
  • Have an initial easy-to-understand implementation which others may improve.
  • This is NOT intended to support every topic model out there. However, if there are suggestions for making this extensible or pluggable in the future, that could be nice, as long as it does not complicate the API or implementation too much.
  • This may not be very scalable currently. It will be important to check and improve accuracy. For correctness of the implementation, please check against the Asuncion et al. (2009) paper in the design doc.

Sketch of contents of this PR

Dependency: This makes MLlib depend on GraphX.

Files and classes:

  • LDA.scala (441 lines):
    • class LDA (main estimator class)
    • LDA.Document (term-count vector + document ID; see the usage sketch after this list)
  • LDAModel.scala (266 lines)
    • abstract class LDAModel
    • class LocalLDAModel
    • class DistributedLDAModel
  • LDAExample.scala (245 lines): script to run LDA + a simple (private) Tokenizer
  • LDASuite.scala (144 lines)
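For orientation, here is a hypothetical usage sketch based on the classes listed above; the setter names and the run signature are my assumptions (the design doc is authoritative), and sc is an existing SparkContext.

    import org.apache.spark.mllib.clustering.LDA
    import org.apache.spark.mllib.linalg.Vectors

    // Documents are bags of words: a term-count vector plus a unique ID.
    val corpus = sc.parallelize(Seq(
      LDA.Document(Vectors.dense(1.0, 0.0, 3.0), id = 0L),
      LDA.Document(Vectors.dense(0.0, 2.0, 1.0), id = 1L)))

    val lda = new LDA()
      .setK(2)              // number of topics
      .setMaxIterations(50)

    // With this PR's EM implementation, training yields a DistributedLDAModel.
    val model = lda.run(corpus)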

Data/model representation and algorithm:

Design notes

Please refer to the JIRA for more discussion + the design doc for this PR

Here, I list the main changes AFTER the design doc was posted.

Design decisions:

  • logLikelihood() computes the log likelihood of the data given the current point estimates of the parameters. This is different from the likelihood of the data given the hyperparameters, which would be harder to compute; I’d describe the current approach as more frequentist, whereas the harder approach would be more Bayesian. (See the sketch after this list.)
  • The current API takes Documents as token count vectors. I believe there should be an extended API taking RDD[String] or RDD[Array[String]] in a future PR. I have sketched this out in the design doc (as well as handier versions of getTopics returning Strings).
  • Hyperparameters should be set differently for different inference/learning algorithms. See Asuncion et al. (2009) in the design doc for a good demonstration. I encourage good behavior via defaults and warning messages.
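As a rough sketch of the logLikelihood distinction above (my notation, not the PR's): let \hat{\theta}_{dk} be the point estimate of topic weights for document d, \hat{\phi}_{kw} the term weights for topic k, and n_{dw} the count of term w in document d. The quantity computed is

    \log p(\text{data} \mid \hat{\theta}, \hat{\phi}) = \sum_{d} \sum_{w} n_{dw} \log \sum_{k} \hat{\theta}_{dk} \, \hat{\phi}_{kw}

whereas the harder, more Bayesian alternative would integrate the parameters out against the prior defined by the hyperparameters (\alpha, \beta):

    p(\text{data} \mid \alpha, \beta) = \int p(\text{data} \mid \theta, \phi) \, p(\theta, \phi \mid \alpha, \beta) \, d\theta \, d\phi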

Items planned for future PRs:

  • perplexity
  • API taking Strings

Questions for reviewers

  • Should LDA be called LatentDirichletAllocation (and LDAModel be LatentDirichletAllocationModel)?
    • Pro: We may someday want LinearDiscriminantAnalysis.
    • Con: Very long names
  • Should LDA reside in clustering? Or do we want a sub-package?
    • mllib.topicmodel
    • mllib.clustering.topicmodel
  • Does the API seem reasonable and extensible?
  • Unit tests:
    • Should there be a test which checks clustering results? E.g., train on a small, fake dataset with 2 very distinct topics/clusters, and ensure LDA finds those 2 topics/clusters. Does that sound useful or too flaky?

Other notes

This has not been tested much for scaling. I have run it on a laptop for 200 iterations on a 5MB dataset with 1000 terms and 5 topics; running it for 500 iterations made it fail because of GC problems. I'm running larger-scale tests and will post results here, but future PRs may need to improve the scaling.

Thanks to…

  • @dlwh for the initial implementation
    • + @jegonzal for some code in the initial implementation
  • The many contributors towards topic model implementations in Spark which were referenced as a basis for this PR: @akopich @witgo @yinxusen @dlwh @EntilZha @jegonzal @IlyaKozlov
    • Note: The plan is to include this full list in the authors if this PR gets merged. Please notify me if you prefer otherwise.

CC: @mengxr

@SparkQA

SparkQA commented Jan 14, 2015

Test build #25558 has started for PR 4047 at commit c6e4308.

  • This patch does not merge cleanly.

@SparkQA

SparkQA commented Jan 14, 2015

Test build #25558 has finished for PR 4047 at commit c6e4308.

  • This patch fails Scala style tests.
  • This patch does not merge cleanly.
  • This patch adds the following public classes (experimental):
    • case class Document(counts: Vector, id: Long)

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/25558/
Test FAILed.

/**
* An example Latent Dirichlet Allocation (LDA) app. Run with
* {{{
* ./bin/run-example mllib.DenseKMeans [options] <input>

(rename)

Member Author

Thanks!

@SparkQA

SparkQA commented Jan 14, 2015

Test build #25560 has started for PR 4047 at commit 984c414.

  • This patch merges cleanly.

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/25559/
Test FAILed.

@SparkQA

SparkQA commented Jan 14, 2015

Test build #25560 has finished for PR 4047 at commit 984c414.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@AmplabJenkins

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/25560/
Test PASSed.

@jkbradley
Member Author

By the way, I'm running larger-scale tests, and I'll post results once they are ready!

@EntilZha
Contributor

Mind posting the dataset size (vocab, docs, etc.) and the type of cluster? I'm about to start some performance tests, and it would be cool to hit both the Chinese dataset size and what you are testing.

@witgo
Contributor

witgo commented Jan 15, 2015

I have two questions:

  1. How can the model cover long-tail topic features?
  2. How does the algorithm perform with 10K–100K topics?

Here is a relevant industrial practice:
"Peacock: Learning Long-Tail Topic Features for Industrial Applications"

@jkbradley
Member Author

@EntilZha Here’s a sketch of my plan.

Datasets:

  • UCI ML Repository data (also used by Asuncion et al., 2009):
    • KOS
    • NIPS
    • NYTimes
    • PubMed (full)
  • Wikipedia?

Data preparation:

  • Converting to bags of words:
    • UCI datasets are given as word counts already.
    • Wikipedia dump is text.
      • I use the SimpleTokenizer in the LDAExample, which sets term = word and only accepts alphabetic characters.
      • Use stopwords from @dlwh located at [https://github.com/dlwh/spark/tree/feature/lda]
      • No stemming
  • Choosing vocab: For various vocabSize settings, I took the most common vocabSize terms (a sketch of this preprocessing follows this list).
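A minimal Scala sketch of that preprocessing; the regex, helper name, and stopword handling are my assumptions (the actual SimpleTokenizer lives in LDAExample.scala):

    import org.apache.spark.rdd.RDD

    // Tokenize (term = word, alphabetic characters only), drop stopwords,
    // and keep the vocabSize most common terms.
    def chooseVocab(docs: RDD[String],
                    stopwords: Set[String],
                    vocabSize: Int): Array[String] = {
      docs.flatMap(_.toLowerCase.split("[^a-z]+"))
        .filter(w => w.nonEmpty && !stopwords.contains(w))
        .map(w => (w, 1L))
        .reduceByKey(_ + _)
        .top(vocabSize)(Ordering.by[(String, Long), Long](_._2))
        .map(_._1)
    }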

Scaling tests: (doing these first)

  • corpus size
  • vocabSize
  • k
  • numIterations

Accuracy tests: (doing these second)

  • train on full datasets
  • Tune hyperparameters via grid search, following Asuncion et al. (2009) section 4.1.
  • Can hopefully compare with their results in Fig. 5.

These tests will run on a 16-node EC2 cluster of r3.2xlarge instances.

@jkbradley
Member Author

@witgo I agree that there are 2 different use regimes for LDA: interpretable topics and featurization. The current implementation follows pretty much every other graph-based implementation I’ve seen:

  • 1 vertex per document + 1 vertex per term
  • Each vertex stores a vector of length # topics.
  • On each iteration, each doc vertex must communicate its vector to any connected term vertices (and likewise for term vertices), via map-reduce stages over triplets.

I have not heard of methods which can avoid this amount of communication for LDA. I’m sure the implementation can be optimized, so please make comments here or JIRAs afterwards about that. For modified models, it might be possible to communicate less: sparsity-inducing priors, hierarchical models, etc.
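To make the vertex/edge layout concrete, here is a minimal GraphX sketch of the bipartite representation described above; the vertex-ID convention and names are my assumptions, not the PR's code:

    import org.apache.spark.graphx.{Edge, Graph}
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.rdd.RDD

    // Assumed convention: document vertex IDs are non-negative, term vertex
    // IDs are negative, so the two vertex sets never collide.
    def termVertexId(termIndex: Int): Long = -(termIndex + 1).toLong

    // One edge per (document, term) pair that co-occurs, weighted by its
    // count; every vertex starts with a length-k vector of topic counts.
    def buildGraph(docs: RDD[(Long, Vector)], k: Int): Graph[Array[Double], Double] = {
      val edges: RDD[Edge[Double]] = docs.flatMap { case (docId, counts) =>
        counts.toArray.zipWithIndex.collect {
          case (cnt, term) if cnt > 0.0 => Edge(docId, termVertexId(term), cnt)
        }
      }
      Graph.fromEdges(edges, Array.fill(k)(0.0))
    }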

@SparkQA

SparkQA commented Jan 16, 2015

Test build #25689 has started for PR 4047 at commit 648f66c.

  • This patch merges cleanly.

@SparkQA

SparkQA commented Jan 17, 2015

Test build #25689 has finished for PR 4047 at commit 648f66c.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • case class Document(counts: Vector, id: Long)

@AmplabJenkins

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/25689/
Test PASSed.

@jkbradley
Member Author

Here are some initial test results. There are 2 sets since I had run some before the updates from @mengxr and other tests after the updates.

Summary: Iterations keep getting longer. Will need to work on scalability, but it at least runs on medium-sized datasets. Updates from @mengxr improve scalability. Still need to test large numbers of topics.

How tests were run

I ran using this branch:
[https://github.com/jkbradley/spark/tree/lda-testing].
It includes a little more instrumentation and a Timing script: [https://github.com/jkbradley/spark/blob/lda-testing/examples/src/main/scala/org/apache/spark/examples/mllib/LDATiming.scala].

I used the collection of stopwords from @dlwh here: [https://github.com/dlwh/spark/blob/feature/lda/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleLatentDirichletAllocation.scala]

I ran using a (partial) dump of Wikipedia consisting of about 4.7GB of gzipped text files.

My goal is to do a few sets of tests, scaling:

  • corpus sizes: 10K, 100K, 1M
  • k: 10, 100, 1K, 10K, 100K
  • vocabSize: (not run yet)

I ran:

FOR SCALING CORPUS SIZE:

bin/spark-submit --class org.apache.spark.examples.mllib.LDATiming --master spark://MY_EC2_URL:7077 --driver-memory 20g /root/spark-git/examples/target/scala-2.10/spark-examples-1.3.0-SNAPSHOT-hadoop1.0.4.jar --corpusSizes "10000 100000 1000000 -1" --ks "100" --maxIterations 10 --topicSmoothing -1 --termSmoothing -1 --vocabSizes "1000000" --stopwordFile "stopwords.txt" "DATADIR"

SCALING K:

bin/spark-submit --class org.apache.spark.examples.mllib.LDATiming --master spark://MY_EC2_URL:7077 --driver-memory 20g /root/spark-git/examples/target/scala-2.10/spark-examples-1.3.0-SNAPSHOT-hadoop1.0.4.jar --corpusSizes "-1" --ks "10 100 1000 10000 100000" --maxIterations 10 --topicSmoothing -1 --termSmoothing -1 --vocabSizes "1000000" --stopwordFile "stopwords.txt" "DATADIR"

These used a 16-node EC2 cluster of r3.2xlarge machines.

Results (before recent GC-related updates)

Take-aways: Iterations keep getting longer. Will need to work on scalability, but it at least runs on medium-sized datasets.

Scaling corpus size

DATASET
     Training set size: 9999 documents
     Vocabulary size: 183313 terms
     Training set size: 2537186 tokens
     Preprocessing time: 53.03116398 sec
Finished training LDA model.  Summary:
     Training time: 51.766336365 sec
     Training data average log likelihood: -2374.150807757336
     Training times per iteration (sec):
    16.301772584
    2.714758941
    2.681336067
    2.812407396
    3.067381155
    3.148446287
    4.091387595
    4.845391099
    4.800537527
    5.784521297

Note that iteration times keep getting longer.

DATASET
     Training set size: 99657 documents
     Vocabulary size: 864755 terms
     Training set size: 25372240 tokens
     Preprocessing time: 56.172117053 sec
Finished training LDA model.  Summary:
     Training time: 272.724740335 sec
     Training data average log likelihood: -2453.2238815201995
     Training times per iteration (sec):
    36.27088487
    9.239099504
    12.899834887
    17.887326081
    22.548736594
    29.705019399
    34.532178918
    37.132915562
    43.264158967
    24.167606732
DATASET
     Training set size: 998875 documents
     Vocabulary size: 3718603 terms
     Training set size: 255269137 tokens
     Preprocessing time: 969.582218325 sec
(died)

Scaling k

DATASET
     Training set size: 4072243 documents
     Vocabulary size: 1000000 terms
     Training set size: 955849462 tokens
     Preprocessing time: 1023.173703836 sec
Finished training LDA model.  Summary:
     Training time: 734.18870584 sec
     Training data average log likelihood: -2487.378006538547
     Training times per iteration (sec):
    220.962623351
    43.31892217
    44.65509746
    49.119503552
    52.24947807
    53.822309875
    57.582740118
    64.41201
    70.151547256
    72.043746927

(larger tests died)

Results (after recent GC-related updates)

Main take-away: Updates from @mengxr improve scaling. (Notice that the later, larger tests no longer die.)

Scaling corpus size

DATASET
     Training set size: 9999 documents
     Vocabulary size: 549136 terms
     Training set size: 2651999 tokens
     Preprocessing time: 59.886667054 sec
Finished training LDA model.  Summary:
     Training time: 87.066636441 sec
     Training data average log likelihood: -2987.1021154536284
     Training times per iteration (sec):
    25.916230397
    3.348482537
    3.133761688
    4.325156952
    5.231940702
    6.117500071
    7.246081989
    8.584900244
    9.413571911
    9.114112645
DATASET
     Training set size: 99657 documents
     Vocabulary size: 1000000 terms
     Training set size: 24106059 tokens
     Preprocessing time: 79.624635494 sec
Finished training LDA model.  Summary:
     Training time: 295.883936257 sec
     Training data average log likelihood: -2608.5841219446515
     Training times per iteration (sec):
    41.455679987
    11.062455643
    15.526668004
    21.027575262
    26.262190857
    25.565775147
    30.829831734
    35.716967684
    37.592023917
    44.04621023

DATASET
     Training set size: 998875 documents
     Vocabulary size: 1000000 terms
     Training set size: 235682866 tokens
     Preprocessing time: 322.008531951 sec
Finished training LDA model.  Summary:
     Training time: 1073.726914484 sec
     Training data average log likelihood: -2496.418600245705
     Training times per iteration (sec):
    119.644333858
    41.555120562
    52.719948261
    64.48673763
    88.892069695
    100.981587858
    123.62990158
    150.65753992
    157.688974275
    168.515556567
DATASET
     Training set size: 4072243 documents
     Vocabulary size: 1000000 terms
     Training set size: 955849462 tokens
     Preprocessing time: 1110.123689033 sec
Finished training LDA model.  Summary:
     Training time: 4781.682695595 sec
     Training data average log likelihood: -2483.61085533687
     Training times per iteration (sec):
    363.747503418
    234.396490798
    264.977904783
    377.257946593
    447.054876375
    364.207562754
    408.152587705
    420.5513901
    1080.746177241
    813.866786165

@hhbyyh
Contributor

hhbyyh commented Jan 18, 2015

Is there a plan to include inference for new (unseen) documents based on the learned distributions? Thanks

@jkbradley
Member Author

@hhbyyh Yes, please review the design doc linked from the JIRA. There is quite a bit of functionality which will not be in this initial PR.

@EntilZha
Contributor

I've had a good chance to look at the PR while making changes to my own code. I really liked the Graph initialization code (especially the partition strategy); I was able to copy that and get an almost across-the-board 2x boost in the computation phases.
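(For context, a sketch of the kind of GraphX call involved; whether the PR uses exactly this strategy is my assumption:)

    import org.apache.spark.graphx.{Graph, PartitionStrategy}

    // Repartitioning edges (e.g., 2-D edge partitioning) bounds how many
    // machines each vertex's data must be replicated to during aggregation.
    def partitionForLda[VD, ED](g: Graph[VD, ED]): Graph[VD, ED] =
      g.partitionBy(PartitionStrategy.EdgePartition2D)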

Questions in PR:

Naming (LDA vs LatentDirichletAllocation)

I think that Linear Discriminant Analysis and Latent Dirichlet Allocation are different enough that it doesn't warrant making the class names significantly longer, so LDA is probably good.

Package Name

I think a separate topicmodeling package is fitting; it makes it very clear that LDA is for topic modeling. The second motivation for a topicmodeling package is that it looks like there will be this EM implementation and soon the Gibbs version I am using, so both could fit in this namespace well. Along those same lines, perhaps keep LDAModel for the abstract class and, for implementations, use some way to denote whether it is Gibbs (LDAGibbs, GibbsLDA, LDAWithGibbs) or EM based (LDAEM, EMLDA, LDAWithEM). Naming is hard...

Unit Test

On tests, I am for including a small training example, similar to here:
https://github.com/EntilZha/spark/blob/LDA/graphx/src/test/scala/org/apache/spark/graphx/lib/LDASuite.scala#L52-L76
It has been very helpful to have a sanity check test as I have been working on potentially breaking changes. In that example, the data set is 9 lines, so it would probably be better to refactor and place it in an Array[String] defined in code instead of an external file.

Lastly, it looks like the abstract classes/traits haven't made it in yet to have multiple implementations satisfy a common LDA API. I have taken a careful look at the design doc and your code, and am fairly confident that I have a way to do this with minimal code changes. I will post something once I refactor what I have to satisfy it, and post the proposed changes that would need to be made. The general sketch is to make LDA a trait, LearningState a trait, and have implementations have a signature something like object LDAGibbs extends LDA and provide a LearningState implementation. More soon.

@jkbradley
Member Author

@EntilZha Thanks for taking a look! W.r.t. the class names, I really hope we can keep a single LDA class and have an optimizer parameter which lets users specify the optimization/inference method. We are trying to move away from the AlgWithOptimizer naming conventions. That will end up affecting the abstractions you were discussing too (though maybe those abstractions will exist but be private/protected).

@EntilZha
Contributor

What might be the best way to have the EM and Gibbs LDA implementations play well with each other?

If the aim is to not have separate LDA classes, on first thought I think maybe then LearningState would be where to put the work in. So, that would mean creating the LearningState trait, then within LDA, there would be multiple definitions of classes that satisfy LearningState. So you might have GibbsLearningState and EMLearningState (solves that naming problem easily). It seems like most of the implementation would be within LearningState anyway, not per se within LDA.

This has the advantage you named: in the LDA constructor we could set the default LearningState type to either EM or Gibbs (probably whichever performs better), while still allowing users to specify which algorithm they would like. A rough sketch follows below.
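(Every name in this sketch is an assumption, not code from either branch:)

    trait LDAModel  // stub so the sketch is self-contained

    // One step of inference, independent of the algorithm.
    trait LearningState {
      def next(): LearningState   // run one iteration
      def logLikelihood: Double
      def toModel: LDAModel       // freeze the current state into a model
    }

    // Inside LDA, the chosen algorithm would supply the concrete state, e.g.
    //   optimizer match {
    //     case "em"    => new EMLearningState(corpus, k)
    //     case "gibbs" => new GibbsLearningState(corpus, k)
    //   }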

@jkbradley
Member Author

By the way, the 2 sets of timing results above are not that comparable since I realized that I limited the vocab size in the later tests. I'm re-running tests to see if the updates from @mengxr helped.

@jkbradley
Member Author

@EntilZha I agree with your sketch: abstracting LearningState, and having different versions for each algorithm.

@EntilZha
Contributor

Nice, I am not too far from completing my refactoring. What would be the best way to share it: open a different PR, or link to it here? Eventually it may end up as two PRs... I'm not sure of the best way to do this.

@jkbradley
Member Author

Sounds good! I'd recommend linking to it from here for now and creating a PR later.

@EntilZha
Contributor

Just finished refactoring; here is the combined API/LDAModel code + Gibbs implementation. It should give a pretty good idea of what I was thinking with LearningState. Notes: (1) although it compiles, I haven't run it, since there are some methods I still need to implement; (2) I need to update the docs after I finish implementing the missing methods; (3) I also need to go through and give objects/classes/methods/vars the correct access levels (public, private, etc.).
https://github.com/EntilZha/spark/tree/LDA-Refactor/mllib/src/main/scala/org/apache/spark/mllib/topicmodeling

this.topicSmoothing
} else {
(50.0 / k) + 1.0
}
Contributor

Minor thing: how about moving the if-else into getTermSmoothing? It separates the logic, and it also gives users an interface to retrieve the parameter.

Member Author

Good point, will do.

@jkbradley
Member Author

There may be issues with extra stuff being caught in the closure. Iterations seem to shuffle more and more data; see the attached image of the web UI.

[image: lda-webui]

*
* Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions),
* but values in (0,1) are not yet supported.
*/
Contributor

Same here. I feel the doc should be in the setters.

@jkbradley
Member Author

@mengxr Almost done with updating per your feedback, but 1 question: Do you recommend doc only in the setter methods, or in both the getters and setters?

…ab computation. Also updated PeriodicGraphCheckpointerSuite.scala to clean up checkpoint files at end
@jkbradley
Member Author

@mengxr Oops, that commit hid some of my responses to your feedback above. The 2 issues remaining are (1) Who is responsible for materializing graphs? and (2) Where should the param docs appear (getters and/or setters)? (My last commit moved the docs to the setters.)

@AmplabJenkins

Test FAILed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/26603/
Test FAILed.

@jkbradley
Member Author

Test failure was a system issue... will test again

@SparkQA

SparkQA commented Feb 3, 2015

Test build #574 has started for PR 4047 at commit 589728b.

  • This patch merges cleanly.

@mengxr
Contributor

mengxr commented Feb 3, 2015

@jkbradley Both the setter and the getter are public methods, and hence both of them should have JavaDoc. I prefer documenting the default value in the setter because that is what a user consults when they want to change the parameter.

@jkbradley
Member Author

@mengxr OK, I'll copy the doc to both, but only put the default value in the setter.
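A sketch of the resulting convention; the k parameter is from this PR, but the exact doc wording and default are illustrative:

    class LDA {
      private var k: Int = 10

      /** Number of topics to infer, i.e., the number of soft cluster centers. */
      def getK: Int = k

      /**
       * Number of topics to infer, i.e., the number of soft cluster centers.
       * (default = 10)
       */
      def setK(k: Int): this.type = {
        require(k > 0, s"LDA k must be > 0, but was set to $k")
        this.k = k
        this
      }
    }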

@SparkQA

SparkQA commented Feb 3, 2015

Test build #574 has finished for PR 4047 at commit 589728b.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Member Author

@mengxr I believe that's it. Thanks for the feedback!

@jkbradley
Member Author

Ok...now I believe that's it!

@SparkQA

SparkQA commented Feb 3, 2015

Test build #26624 has started for PR 4047 at commit 77e8814.

  • This patch merges cleanly.

@SparkQA

SparkQA commented Feb 3, 2015

Test build #26624 has finished for PR 4047 at commit 77e8814.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class EMOptimizer(

@AmplabJenkins

Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/26624/
Test PASSed.

@asfgit closed this in 980764f Feb 3, 2015
@mengxr
Contributor

mengxr commented Feb 3, 2015

LGTM. Merged into master. Thanks everyone for collaborating on LDA! @jkbradley Please create follow-up JIRAs and see who is interested in working on LDA features.

@mengxr mentioned this pull request Feb 3, 2015
@jkbradley
Member Author

@EntilZha @mengxr + everyone else: Thanks very much for all of the feedback!
