[SPARK-6893][ML] default pipeline parameter handling in python
Same as apache#5431 but for Python. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes apache#5534 from mengxr/SPARK-6893 and squashes the following commits:

d3b519b [Xiangrui Meng] address comments
ebaccc6 [Xiangrui Meng] style update
fce244e [Xiangrui Meng] update explainParams with test
4d6b07a [Xiangrui Meng] add tests
5294500 [Xiangrui Meng] update default param handling in python
mengxr committed Apr 16, 2015
1 parent 52c3439 commit 57cd1e8
Showing 11 changed files with 270 additions and 121 deletions.
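In practice, the patch gives PySpark pipeline components the same split between default and user-supplied parameter values that apache#5431 introduced on the Scala side. A minimal sketch of the user-visible behavior (illustrative, not part of the diff), assuming a PySpark build that includes this patch:

from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(maxIter=30)
lr.hasDefault(lr.maxIter)    # True: 100 was registered via _setDefault in __init__
lr.isSet(lr.maxIter)         # True: the caller passed maxIter=30
lr.getOrDefault(lr.maxIter)  # 30: the user-supplied value wins over the default
lr.isSet(lr.regParam)        # False: only the default 0.1 exists for regParam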
@@ -25,7 +25,7 @@ import java.util.UUID
private[ml] trait Identifiable extends Serializable {

/**
* A unique id for the object. The default implementation concatenates the class name, "-", and 8
* A unique id for the object. The default implementation concatenates the class name, "_", and 8
* random hex chars.
*/
private[ml] val uid: String =
@@ -17,16 +17,13 @@

package org.apache.spark.ml.param

import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter}

/** A subclass of Params for testing. */
class TestParams extends Params {
class TestParams extends Params with HasMaxIter with HasInputCol {

val maxIter = new IntParam(this, "maxIter", "max number of iterations")
def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
def getMaxIter: Int = getOrDefault(maxIter)

val inputCol = new Param[String](this, "inputCol", "input column name")
def setInputCol(value: String): this.type = { set(inputCol, value); this }
def getInputCol: String = getOrDefault(inputCol)

setDefault(maxIter -> 10)

3 changes: 2 additions & 1 deletion python/pyspark/ml/classification.py
@@ -59,6 +59,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
maxIter=100, regParam=0.1)
"""
super(LogisticRegression, self).__init__()
self._setDefault(maxIter=100, regParam=0.1)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@@ -71,7 +72,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
Sets params for logistic regression.
"""
kwargs = self.setParams._input_kwargs
return self._set_params(**kwargs)
return self._set(**kwargs)

def _create_model(self, java_model):
return LogisticRegressionModel(java_model)
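The LogisticRegression change above illustrates the pattern this commit establishes for Python wrappers: register defaults with _setDefault in __init__, then forward only the keyword arguments the caller actually passed to setParams, which records them with _set. A hedged sketch of that pattern with a hypothetical estimator (keyword_only is the decorator used above; its import path is assumed here and not shown in this diff):

from pyspark.ml.param import Param, Params
from pyspark.ml.util import keyword_only  # assumed import path


class MyEstimator(Params):
    """Hypothetical component following the new default-handling pattern."""

    @keyword_only
    def __init__(self, maxIter=100, regParam=0.1):
        super(MyEstimator, self).__init__()
        self.maxIter = Param(self, "maxIter", "max number of iterations")
        self.regParam = Param(self, "regParam", "regularization constant")
        # Defaults live in defaultParamMap, so they are never mistaken for
        # user-supplied values.
        self._setDefault(maxIter=100, regParam=0.1)
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, maxIter=100, regParam=0.1):
        # Only the keyword arguments the caller actually passed are recorded.
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)


m = MyEstimator(maxIter=30)
m.getOrDefault(m.maxIter)   # 30: user-supplied value
m.getOrDefault(m.regParam)  # 0.1: falls back to the default; isSet(m.regParam) is False

Constructing MyEstimator() with no arguments leaves paramMap empty, which is the point of the change: defaults are no longer indistinguishable from values the user set explicitly.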
19 changes: 10 additions & 9 deletions python/pyspark/ml/feature.py
@@ -52,22 +52,22 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
_java_class = "org.apache.spark.ml.feature.Tokenizer"

@keyword_only
def __init__(self, inputCol="input", outputCol="output"):
def __init__(self, inputCol=None, outputCol=None):
"""
__init__(self, inputCol="input", outputCol="output")
__init__(self, inputCol=None, outputCol=None)
"""
super(Tokenizer, self).__init__()
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@keyword_only
def setParams(self, inputCol="input", outputCol="output"):
def setParams(self, inputCol=None, outputCol=None):
"""
setParams(self, inputCol="input", outputCol="output")
Sets params for this Tokenizer.
"""
kwargs = self.setParams._input_kwargs
return self._set_params(**kwargs)
return self._set(**kwargs)


@inherit_doc
@@ -91,22 +91,23 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
_java_class = "org.apache.spark.ml.feature.HashingTF"

@keyword_only
def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
"""
__init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output")
__init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
"""
super(HashingTF, self).__init__()
self._setDefault(numFeatures=1 << 18)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@keyword_only
def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
"""
setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output")
setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
Sets params for this HashingTF.
"""
kwargs = self.setParams._input_kwargs
return self._set_params(**kwargs)
return self._set(**kwargs)


if __name__ == "__main__":
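The net effect for these transformers: inputCol and outputCol no longer pick up implicit "input"/"output" defaults, while numFeatures keeps 1 << 18 as a proper default. An illustrative sketch, assuming a PySpark build with this patch:

from pyspark.ml.feature import HashingTF, Tokenizer

tok = Tokenizer()                # no more implicit inputCol="input", outputCol="output"
tok.isDefined(tok.inputCol)      # False: not set by the user and no default
tok.setParams(inputCol="text", outputCol="words")
tok.isSet(tok.inputCol)          # True: now a user-supplied value

tf = HashingTF(inputCol="words", outputCol="features")
tf.isSet(tf.numFeatures)         # False: the user did not pass it...
tf.getOrDefault(tf.numFeatures)  # ...but 262144 (1 << 18) is available as the default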
146 changes: 127 additions & 19 deletions python/pyspark/ml/param/__init__.py
@@ -25,23 +25,21 @@

class Param(object):
"""
A param with self-contained documentation and optionally default value.
A param with self-contained documentation.
"""

def __init__(self, parent, name, doc, defaultValue=None):
if not isinstance(parent, Identifiable):
raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__)
def __init__(self, parent, name, doc):
if not isinstance(parent, Params):
raise ValueError("Parent must be a Params but got type %s." % type(parent).__name__)
self.parent = parent
self.name = str(name)
self.doc = str(doc)
self.defaultValue = defaultValue

def __str__(self):
return str(self.parent) + "-" + self.name
return str(self.parent) + "__" + self.name

def __repr__(self):
return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \
(self.parent, self.name, self.doc, self.defaultValue)
return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc)


class Params(Identifiable):
@@ -52,26 +50,128 @@ class Params(Identifiable):

__metaclass__ = ABCMeta

def __init__(self):
super(Params, self).__init__()
#: embedded param map
self.paramMap = {}
#: internal param map for user-supplied values
paramMap = {}

#: internal param map for default values
defaultParamMap = {}

@property
def params(self):
"""
Returns all params. The default implementation uses
:py:func:`dir` to get all attributes of type
Returns all params ordered by name. The default implementation
uses :py:func:`dir` to get all attributes of type
:py:class:`Param`.
"""
return filter(lambda attr: isinstance(attr, Param),
[getattr(self, x) for x in dir(self) if x != "params"])

def _merge_params(self, params):
paramMap = self.paramMap.copy()
paramMap.update(params)
def _explain(self, param):
"""
Explains a single param and returns its name, doc, and optional
default value and user-supplied value in a string.
"""
param = self._resolveParam(param)
values = []
if self.isDefined(param):
if param in self.defaultParamMap:
values.append("default: %s" % self.defaultParamMap[param])
if param in self.paramMap:
values.append("current: %s" % self.paramMap[param])
else:
values.append("undefined")
valueStr = "(" + ", ".join(values) + ")"
return "%s: %s %s" % (param.name, param.doc, valueStr)

def explainParams(self):
"""
Returns the documentation of all params with their optional
default values and user-supplied values.
"""
return "\n".join([self._explain(param) for param in self.params])

def getParam(self, paramName):
"""
Gets a param by its name.
"""
param = getattr(self, paramName)
if isinstance(param, Param):
return param
else:
raise ValueError("Cannot find param with name %s." % paramName)

def isSet(self, param):
"""
Checks whether a param is explicitly set by user.
"""
param = self._resolveParam(param)
return param in self.paramMap

def hasDefault(self, param):
"""
Checks whether a param has a default value.
"""
param = self._resolveParam(param)
return param in self.defaultParamMap

def isDefined(self, param):
"""
Checks whether a param is explicitly set by user or has a default value.
"""
return self.isSet(param) or self.hasDefault(param)

def getOrDefault(self, param):
"""
Gets the value of a param in the user-supplied param map or its
default value. Raises an error if neither is set.
"""
if isinstance(param, Param):
if param in self.paramMap:
return self.paramMap[param]
else:
return self.defaultParamMap[param]
elif isinstance(param, str):
return self.getOrDefault(self.getParam(param))
else:
raise KeyError("Cannot recognize %r as a param." % param)

def extractParamMap(self, extraParamMap={}):
"""
Extracts the embedded default param values and user-supplied
values, and then merges them with extra values from input into
a flat param map, where the latter value is used if there exist
conflicts, i.e., with ordering: default param values <
user-supplied values < extraParamMap.
:param extraParamMap: extra param values
:return: merged param map
"""
paramMap = self.defaultParamMap.copy()
paramMap.update(self.paramMap)
paramMap.update(extraParamMap)
return paramMap

def _shouldOwn(self, param):
"""
Validates that the input param belongs to this Params instance.
"""
if param.parent is not self:
raise ValueError("Param %r does not belong to %r." % (param, self))

def _resolveParam(self, param):
"""
Resolves a param and validates the ownership.
:param param: param name or the param instance, which must
belong to this Params instance
:return: resolved param instance
"""
if isinstance(param, Param):
self._shouldOwn(param)
return param
elif isinstance(param, str):
return self.getParam(param)
else:
raise ValueError("Cannot resolve %r as a param." % param)

@staticmethod
def _dummy():
"""
@@ -81,10 +181,18 @@ def _dummy():
dummy.uid = "undefined"
return dummy

def _set_params(self, **kwargs):
def _set(self, **kwargs):
"""
Sets params.
Sets user-supplied params.
"""
for param, value in kwargs.iteritems():
self.paramMap[getattr(self, param)] = value
return self

def _setDefault(self, **kwargs):
"""
Sets default params.
"""
for param, value in kwargs.iteritems():
self.defaultParamMap[getattr(self, param)] = value
return self
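With the two maps kept separate, explainParams can now distinguish defaults, user-supplied values, and undefined params. A hypothetical example (Classifierish and its params are made up; the commented output is approximate):

from pyspark.ml.param import Param, Params


class Classifierish(Params):  # hypothetical
    def __init__(self):
        super(Classifierish, self).__init__()
        self.maxIter = Param(self, "maxIter", "max number of iterations")
        self.threshold = Param(self, "threshold", "decision threshold")
        self.weightCol = Param(self, "weightCol", "instance weight column")
        self._setDefault(maxIter=100, threshold=0.5)


c = Classifierish()
c._set(threshold=0.3)  # simulate a user-supplied value, as setParams would
print(c.explainParams())
# maxIter: max number of iterations (default: 100)
# threshold: decision threshold (default: 0.5, current: 0.3)
# weightCol: instance weight column (undefined)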
@@ -32,29 +32,34 @@
# limitations under the License.
#"""

# Code generator for shared params (shared.py). Run under this folder with:
# python _shared_params_code_gen.py > shared.py

def _gen_param_code(name, doc, defaultValue):

def _gen_param_code(name, doc, defaultValueStr):
"""
Generates Python code for a shared param class.
:param name: param name
:param doc: param doc
:param defaultValue: string representation of the param
:param defaultValueStr: string representation of the default value
:return: code string
"""
# TODO: How to correctly inherit instance attributes?
template = '''class Has$Name(Params):
"""
Params with $name.
Mixin for param $name: $doc.
"""
# a placeholder to make it appear in the generated doc
$name = Param(Params._dummy(), "$name", "$doc", $defaultValue)
$name = Param(Params._dummy(), "$name", "$doc")
def __init__(self):
super(Has$Name, self).__init__()
#: param for $doc
self.$name = Param(self, "$name", "$doc", $defaultValue)
self.$name = Param(self, "$name", "$doc")
if $defaultValueStr is not None:
self._setDefault($name=$defaultValueStr)
def set$Name(self, value):
"""
@@ -67,32 +72,29 @@ def get$Name(self):
"""
Gets the value of $name or its default value.
"""
if self.$name in self.paramMap:
return self.paramMap[self.$name]
else:
return self.$name.defaultValue'''
return self.getOrDefault(self.$name)'''

upperCamelName = name[0].upper() + name[1:]
Name = name[0].upper() + name[1:]
return template \
.replace("$name", name) \
.replace("$Name", upperCamelName) \
.replace("$Name", Name) \
.replace("$doc", doc) \
.replace("$defaultValue", defaultValue)
.replace("$defaultValueStr", str(defaultValueStr))

if __name__ == "__main__":
print header
print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n"
print "\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n"
print "from pyspark.ml.param import Param, Params\n\n"
shared = [
("maxIter", "max number of iterations", "100"),
("regParam", "regularization constant", "0.1"),
("maxIter", "max number of iterations", None),
("regParam", "regularization constant", None),
("featuresCol", "features column name", "'features'"),
("labelCol", "label column name", "'label'"),
("predictionCol", "prediction column name", "'prediction'"),
("inputCol", "input column name", "'input'"),
("outputCol", "output column name", "'output'"),
("numFeatures", "number of features", "1 << 18")]
("inputCol", "input column name", None),
("outputCol", "output column name", None),
("numFeatures", "number of features", None)]
code = []
for name, doc, defaultValue in shared:
code.append(_gen_param_code(name, doc, defaultValue))
for name, doc, defaultValueStr in shared:
code.append(_gen_param_code(name, doc, defaultValueStr))
print "\n\n\n".join(code)
(Diffs for the remaining changed files are not shown.)
