# Patch summary
#
# PowerIterationClustering.scala: when a weight column is set and non-empty,
#   SchemaUtils.checkNumericType(dataset.schema, $(weightCol)) is now called
#   before the column is cast to DoubleType. When no weight column is set,
#   the literal 1.0 is used and no check is performed.
#
# PowerIterationClusteringSuite.scala: adds the test
#   "check for invalid input types of weight". It builds a DataFrame whose
#   "weight" column holds strings and expects assignClusters to throw
#   IllegalArgumentException with a message containing
#   "requirement failed: Column weight must be of type numeric but was
#   actually of type string."
#
# NOTE(review): line breaks and leading indentation of the hunks below were
# reconstructed from a single-line copy of this patch; confirm the context
# lines against the target files before applying.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
index d9a330f67e8dc..149e99d2f195a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
@@ -166,6 +166,7 @@ class PowerIterationClustering private[clustering] (
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
       lit(1.0)
     } else {
+      SchemaUtils.checkNumericType(dataset.schema, $(weightCol))
       col($(weightCol)).cast(DoubleType)
     }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
index 55b460f1a4524..0ba3ffabb75d2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
@@ -145,6 +145,21 @@ class PowerIterationClusteringSuite extends SparkFunSuite
     assert(msg.contains("Similarity must be nonnegative"))
   }
 
+  test("check for invalid input types of weight") {
+    val invalidWeightData = spark.createDataFrame(Seq(
+      (0L, 1L, "a"),
+      (2L, 3L, "b")
+    )).toDF("src", "dst", "weight")
+
+    val msg = intercept[IllegalArgumentException] {
+      new PowerIterationClustering()
+        .setWeightCol("weight")
+        .assignClusters(invalidWeightData)
+    }.getMessage
+    assert(msg.contains("requirement failed: Column weight must be of type numeric" +
+      " but was actually of type string."))
+  }
+
   test("test default weight") {
     val dataWithoutWeight = data.sample(0.5, 1L).select('src, 'dst)
 