diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 330ae2938f4e0..15e23466250b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -113,9 +113,9 @@ private[spark] abstract class ProbabilisticClassificationModel[ } if ($(predictionCol).nonEmpty) { val predUDF = if ($(rawPredictionCol).nonEmpty) { - callUDF(raw2prediction _, DoubleType, col($(rawPredictionCol))) + udf[Double, Vector](raw2prediction).apply(col($(rawPredictionCol))) } else if ($(probabilityCol).nonEmpty) { - callUDF(probability2prediction _, DoubleType, col($(probabilityCol))) + udf[Double, Vector](probability2prediction).apply(col($(probabilityCol))) } else { callUDF(predict _, DoubleType, col($(featuresCol))) }