From c221db9a710c45e1ede0b0cdea623422639677d6 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 7 May 2015 01:07:49 -0700 Subject: [PATCH] overload StringArrayParam.w --- .../org/apache/spark/ml/param/params.scala | 22 +++++++++++-------- python/pyspark/ml/wrapper.py | 8 +++---- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 6525a5a9aee52..dd1f4a1759568 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -22,7 +22,7 @@ import java.util.NoSuchElementException import scala.annotation.varargs import scala.collection.mutable -import scala.reflect.ClassTag +import scala.collection.JavaConverters._ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.util.Identifiable @@ -228,7 +228,8 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value) - private[param] def wCast(value: Seq[String]): ParamPair[Array[String]] = w(value.toArray) + /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ + def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray) } /** @@ -323,13 +324,7 @@ trait Params extends Identifiable with Serializable { * Sets a parameter in the embedded param map. */ protected final def set[T](param: Param[T], value: T): this.type = { - shouldOwn(param) - if (param.isInstanceOf[StringArrayParam] && value.isInstanceOf[Seq[_]]) { - paramMap.put(param.asInstanceOf[StringArrayParam].wCast(value.asInstanceOf[Seq[String]])) - } else { - paramMap.put(param.w(value)) - } - this + set(param -> value) } /** @@ -339,6 +334,15 @@ trait Params extends Identifiable with Serializable { set(getParam(param), value) } + /** + * Sets a parameter in the embedded param map. + */ + protected final def set(paramPair: ParamPair[_]): this.type = { + shouldOwn(paramPair.param) + paramMap.put(paramPair) + this + } + /** * Optionally returns the user-supplied value of a param. */ diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index cf31c6266e09c..6f9cd9837befe 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -68,9 +68,8 @@ def _transfer_params_to_java(self, params, java_obj): for param in self.params: if param in paramMap: value = paramMap[param] - if isinstance(value, list): - value = _jvm().PythonUtils.toSeq(value) - java_obj.set(param.name, value) + java_param = java_obj.getParam(param.name) + java_obj.set(java_param.w(value)) def _empty_java_param_map(self): """ @@ -82,7 +81,8 @@ def _create_java_param_map(self, params, java_obj): paramMap = self._empty_java_param_map() for param, value in params.items(): if param.parent is self: - paramMap.put(java_obj.getParam(param.name), value) + java_param = java_obj.getParam(param.name) + paramMap.put(java_param.w(value)) return paramMap