From 6149ca6438b0e625e364953a5dda103c73d7c8ab Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sat, 2 May 2015 00:48:14 +0800 Subject: [PATCH] fix for setOptimizer --- .../main/scala/org/apache/spark/mllib/clustering/LDA.scala | 5 +++-- .../org/apache/spark/mllib/clustering/LDAOptimizer.scala | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 597f17b0972a0..c8daa2388e868 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -210,14 +210,15 @@ class LDA private ( /** * Set the LDAOptimizer used to perform the actual calculation by algorithm name. - * Currently "em" is supported. + * Currently "em", "online" is supported. */ def setOptimizer(optimizerName: String): this.type = { this.ldaOptimizer = optimizerName.toLowerCase match { case "em" => new EMLDAOptimizer + case "online" => new OnlineLDAOptimizer case other => - throw new IllegalArgumentException(s"Only em is supported but got $other.") + throw new IllegalArgumentException(s"Only em, online are supported but got $other.") } this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index b352847ab6cc2..4353463aca050 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -396,7 +396,7 @@ class OnlineLDAOptimizer extends LDAOptimizer { val batchResult = statsSum :* expElogbeta // Note that this is an optimization to avoid batch.count - update(batchResult, iteration, (miniBatchFraction * corpusSize).toInt) + update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt) this }