diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index de3555a5b74d0..f8b56293e3ccc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.Param import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType @@ -45,9 +45,9 @@ class ElementwiseProduct extends UnaryTransformer[Vector, Vector, ElementwisePro /** @group getParam */ def getScalingVec: Vector = getOrDefault(scalingVec) - override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { - require(paramMap.contains(scalingVec), s"transformation requires a weight vector: $scalingVec") - val elemScaler = new feature.ElementwiseProduct(paramMap(scalingVec)) + override protected def createTransformFunc: Vector => Vector = { + require(params.contains(scalingVec), s"transformation requires a weight vector") + val elemScaler = new feature.ElementwiseProduct($(scalingVec)) elemScaler.transform }