diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 00a6a6e5633e3..20c4dfab1a57e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -128,12 +128,10 @@ class LogisticRegressionModel private[ml] ( 1.0 / (1.0 + math.exp(-margin)) } val t = map(threshold) - val predict: Vector => Double = (v) => { - if (score(v) > t) 1.0 else 0.0 + val predict: Double => Double = (score) => { + if (score > t) 1.0 else 0.0 } - dataset.select( - Star(None), - score.call(map(featuresCol).attr) as map(scoreCol), - predict.call(map(featuresCol).attr) as map(predictionCol)) + dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol)) + .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol)) } }