Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-7388][SPARK-7383] wrapper for VectorAssembler in Python #5930

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.types._

/**
* :: AlphaComponent ::
* A feature transformer than merge multiple columns into a vector column.
* A feature transformer that merges multiple columns into a vector column.
*/
@AlphaComponent
class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
Expand Down
23 changes: 22 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.util.NoSuchElementException

import scala.annotation.varargs
import scala.collection.mutable
import scala.reflect.ClassTag
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this line.


import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.util.Identifiable
Expand Down Expand Up @@ -218,6 +219,22 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}

/** Specialized version of `Param[Array[T]]` for Java. */
class ArrayParam[T : ClassTag](
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having ClassTag is not Java friendly. Array[T] will be translated into Object in Java to handle both primitive arrays and object arrays. How about adding StringArrayParam instead of ArrayParam[T] in this PR?

parent: Params,
name: String,
doc: String,
isValid: Array[T] => Boolean)
extends Param[Array[T]](parent, name, doc, isValid) {

def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)

override def w(value: Array[T]): ParamPair[Array[T]] = super.w(value)

private[param] def wCast(value: Seq[T]): ParamPair[Array[T]] = w(value.toArray)
}

/**
* A param and its value.
*/
Expand Down Expand Up @@ -311,7 +328,11 @@ trait Params extends Identifiable with Serializable {
*/
/**
 * Sets a parameter in the embedded param map.
 *
 * @param param the param to set; must be owned by this instance (checked by `shouldOwn`)
 * @param value the value to assign. When `param` is an [[ArrayParam]] and the value arrives
 *              as a `Seq` (e.g. from the Python API, which sends Scala Seqs rather than
 *              Arrays), the value is converted to an `Array` via `wCast` so the stored type
 *              matches the param's declared `Array[T]` type.
 */
protected final def set[T](param: Param[T], value: T): this.type = {
  shouldOwn(param)
  // Pattern match instead of isInstanceOf/asInstanceOf chains: same runtime checks,
  // but the casts are localized to the matched bindings.
  (param, value) match {
    case (arrayParam: ArrayParam[_], seq: Seq[_]) =>
      paramMap.put(arrayParam.asInstanceOf[ArrayParam[Any]].wCast(seq.asInstanceOf[Seq[Any]]))
    case _ =>
      paramMap.put(param.w(value))
  }
  this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ private[shared] object SharedParamsCodeGen {
case _ if c == classOf[Float] => "FloatParam"
case _ if c == classOf[Double] => "DoubleParam"
case _ if c == classOf[Boolean] => "BooleanParam"
case _ if c.isArray => s"ArrayParam[${getTypeString(c.getComponentType)}]"
case _ => s"Param[${getTypeString(c)}]"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ private[ml] trait HasInputCols extends Params {
* Param for input column names.
* @group param
*/
final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names")
final val inputCols: ArrayParam[String] = new ArrayParam[String](this, "inputCols", "input column names")

/** @group getParam */
final def getInputCols: Array[String] = $(inputCols)
Expand Down
43 changes: 41 additions & 2 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
#

from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
from pyspark.ml.param.shared import HasInputCol, HasInputCols, HasOutputCol, HasNumFeatures
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaTransformer
from pyspark.mllib.common import inherit_doc

__all__ = ['Tokenizer', 'HashingTF']
__all__ = ['Tokenizer', 'HashingTF', 'VectorAssembler']


@inherit_doc
Expand Down Expand Up @@ -112,6 +112,45 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
return self._set(**kwargs)


@inherit_doc
class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
    """
    A feature transformer that merges multiple columns into a vector column.

    >>> from pyspark.sql import Row
    >>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF()
    >>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
    >>> vecAssembler.transform(df).head().features
    SparseVector(3, {0: 1.0, 2: 3.0})
    >>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
    SparseVector(3, {0: 1.0, 2: 3.0})
    >>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
    >>> vecAssembler.transform(df, params).head().vector
    SparseVector(2, {1: 1.0})
    """

    # Fully-qualified name of the JVM transformer this Python class wraps.
    _java_class = "org.apache.spark.ml.feature.VectorAssembler"

    @keyword_only
    def __init__(self, inputCols=None, outputCol=None):
        """
        __init__(self, inputCols=None, outputCol=None)
        """
        super(VectorAssembler, self).__init__()
        self._setDefault()
        # @keyword_only stores the caller's keyword arguments on the bound
        # method as _input_kwargs; forward them to setParams.
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCols=None, outputCol=None):
        """
        setParams(self, inputCols=None, outputCol=None)
        Sets params for this VectorAssembler.
        """
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)


if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext
Expand Down
1 change: 1 addition & 0 deletions python/pyspark/ml/param/_shared_params_code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def get$Name(self):
("predictionCol", "prediction column name", "'prediction'"),
("rawPredictionCol", "raw prediction column name", "'rawPrediction'"),
("inputCol", "input column name", None),
("inputCols", "input column names", None),
("outputCol", "output column name", None),
("numFeatures", "number of features", None)]
code = []
Expand Down
29 changes: 29 additions & 0 deletions python/pyspark/ml/param/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,35 @@ def getInputCol(self):
return self.getOrDefault(self.inputCol)


class HasInputCols(Params):
    """
    Mixin for param inputCols: input column names.

    NOTE(review): this file is produced by _shared_params_code_gen.py; keep any
    hand edits in sync with the generator.
    """

    # A placeholder to make the param appear in the generated doc.
    inputCols = Param(Params._dummy(), "inputCols", "input column names")

    def __init__(self):
        super(HasInputCols, self).__init__()
        #: param for input column names
        self.inputCols = Param(self, "inputCols", "input column names")
        # The generator emitted `if None is not None: self._setDefault(inputCols=None)`
        # here, which can never execute; inputCols has no default value, so the
        # dead guard is removed.

    def setInputCols(self, value):
        """
        Sets the value of :py:attr:`inputCols`.
        """
        self.paramMap[self.inputCols] = value
        return self

    def getInputCols(self):
        """
        Gets the value of inputCols or its default value.
        """
        return self.getOrDefault(self.inputCols)


class HasOutputCol(Params):
"""
Mixin for param outputCol: output column name.
Expand Down
11 changes: 6 additions & 5 deletions python/pyspark/ml/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ def _transfer_params_to_java(self, params, java_obj):
paramMap = self.extractParamMap(params)
for param in self.params:
if param in paramMap:
java_obj.set(param.name, paramMap[param])
value = paramMap[param]
if isinstance(value, list):
value = _jvm().PythonUtils.toSeq(value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this special treatment if we overload w with JList<String> in StringArrayParam?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have a setAll with varargs, then we will need to cast it to a seq.

java_obj.set(param.name, value)

def _empty_java_param_map(self):
"""
Expand Down Expand Up @@ -126,10 +129,8 @@ class JavaTransformer(Transformer, JavaWrapper):

def transform(self, dataset, params={}):
java_obj = self._java_obj()
self._transfer_params_to_java({}, java_obj)
java_param_map = self._create_java_param_map(params, java_obj)
return DataFrame(java_obj.transform(dataset._jdf, java_param_map),
dataset.sql_ctx)
self._transfer_params_to_java(params, java_obj)
return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx)


@inherit_doc
Expand Down