From 7f7ea2afcebf2f0b7be102fec6421ee5f3e3770a Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 5 May 2015 19:56:09 -0700 Subject: [PATCH] [SPARK-7388][SPARK-7383] wrapper for VectorAssembler in Python --- .../spark/ml/feature/VectorAssembler.scala | 2 +- .../org/apache/spark/ml/param/params.scala | 19 ++++++++- .../ml/param/shared/SharedParamsCodeGen.scala | 1 + .../spark/ml/param/shared/sharedParams.scala | 2 +- python/pyspark/ml/feature.py | 41 ++++++++++++++++++- python/pyspark/ml/param/shared.py | 29 +++++++++++++ python/pyspark/ml/wrapper.py | 11 ++--- 7 files changed, 96 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 8f2e62a8e2081..b5a69cee6daf3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.types._ /** * :: AlphaComponent :: - * A feature transformer than merge multiple columns into a vector column. + * A feature transformer that merges multiple columns into a vector column. */ @AlphaComponent class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 51ce19d29cd29..0579ba3623a64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -22,6 +22,7 @@ import java.util.NoSuchElementException import scala.annotation.varargs import scala.collection.mutable +import scala.reflect.ClassTag import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.util.Identifiable @@ -218,6 +219,18 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } +/** Specialized version of [[Param[Array[T]]]] for Java. */ +class ArrayParam[T : ClassTag](parent: Params, name: String, doc: String, isValid: Array[T] => Boolean) + extends Param[Array[T]](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) + + override def w(value: Array[T]): ParamPair[Array[T]] = super.w(value) + + private[param] def wCast(value: Seq[T]): ParamPair[Array[T]] = w(value.toArray) +} + /** * A param amd its value. */ @@ -311,7 +324,11 @@ trait Params extends Identifiable with Serializable { */ protected final def set[T](param: Param[T], value: T): this.type = { shouldOwn(param) - paramMap.put(param.asInstanceOf[Param[Any]], value) + if (param.isInstanceOf[ArrayParam[_]] && value.isInstanceOf[Seq[_]]) { + paramMap.put(param.asInstanceOf[ArrayParam[Any]].wCast(value.asInstanceOf[Seq[Any]])) + } else { + paramMap.put(param.w(value)) + } this } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index d379172e0bf53..ae0950653d8dc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -83,6 +83,7 @@ private[shared] object SharedParamsCodeGen { case _ if c == classOf[Float] => "FloatParam" case _ if c == classOf[Double] => "DoubleParam" case _ if c == classOf[Boolean] => "BooleanParam" + case _ if c.isArray => s"ArrayParam[${getTypeString(c.getComponentType)}]" case _ => s"Param[${getTypeString(c)}]" } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index fb1874ccfc8dc..736374114b82c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -178,7 +178,7 @@ private[ml] trait HasInputCols extends Params { * Param for input column names. * @group param */ - final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names") + final val inputCols: ArrayParam[String] = new ArrayParam[String](this, "inputCols", "input column names") /** @group getParam */ final def getInputCols: Array[String] = $(inputCols) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 4e4614b859ac6..e2fcc18c4774d 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -16,7 +16,7 @@ # from pyspark.rdd import ignore_unicode_prefix -from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures +from pyspark.ml.param.shared import HasInputCol, HasInputCols, HasOutputCol, HasNumFeatures from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaTransformer from pyspark.mllib.common import inherit_doc @@ -112,6 +112,45 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): return self._set(**kwargs) +@inherit_doc +class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol): + """ + A feature transformer that merges multiple columns into a vector column. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF() + >>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features") + >>> vecAssembler.transform(df).head().features + SparseVector(3, {0: 1.0, 2: 3.0}) + >>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs + SparseVector(3, {0: 1.0, 2: 3.0}) + >>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"} + >>> vecAssembler.transform(df, params).head().vector + SparseVector(2, {1: 1.0}) + """ + + _java_class = "org.apache.spark.ml.feature.VectorAssembler" + + @keyword_only + def __init__(self, inputCols=None, outputCol=None): + """ + __init__(self, inputCols=None, outputCol=None) + """ + super(VectorAssembler, self).__init__() + self._setDefault() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCols=None, outputCol=None): + """ + setParams(self, inputCols=None, outputCol=None) + Sets params for this VectorAssembler. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 4f243844f8caa..aaf80f00085bf 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -223,6 +223,35 @@ def getInputCol(self): return self.getOrDefault(self.inputCol) +class HasInputCols(Params): + """ + Mixin for param inputCols: input column names. + """ + + # a placeholder to make it appear in the generated doc + inputCols = Param(Params._dummy(), "inputCols", "input column names") + + def __init__(self): + super(HasInputCols, self).__init__() + #: param for input column names + self.inputCols = Param(self, "inputCols", "input column names") + if None is not None: + self._setDefault(inputCols=None) + + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + self.paramMap[self.inputCols] = value + return self + + def getInputCols(self): + """ + Gets the value of inputCols or its default value. + """ + return self.getOrDefault(self.inputCols) + + class HasOutputCol(Params): """ Mixin for param outputCol: output column name. diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 73741c4b40dfb..cf31c6266e09c 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -67,7 +67,10 @@ def _transfer_params_to_java(self, params, java_obj): paramMap = self.extractParamMap(params) for param in self.params: if param in paramMap: - java_obj.set(param.name, paramMap[param]) + value = paramMap[param] + if isinstance(value, list): + value = _jvm().PythonUtils.toSeq(value) + java_obj.set(param.name, value) def _empty_java_param_map(self): """ @@ -126,10 +129,8 @@ class JavaTransformer(Transformer, JavaWrapper): def transform(self, dataset, params={}): java_obj = self._java_obj() - self._transfer_params_to_java({}, java_obj) - java_param_map = self._create_java_param_map(params, java_obj) - return DataFrame(java_obj.transform(dataset._jdf, java_param_map), - dataset.sql_ctx) + self._transfer_params_to_java(params, java_obj) + return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx) @inherit_doc