Skip to content

Commit

Permalink
add keyword args
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed May 5, 2015
1 parent cdddecd commit acac727
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions python/pyspark/ml/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pyspark.ml.param import Params, Param
from pyspark.ml import Estimator, Model
from pyspark.sql.functions import rand
from pyspark.ml.util import keyword_only

__all__ = ['ParamGridBuilder', 'CrossValidator']

Expand Down Expand Up @@ -102,11 +103,7 @@ class CrossValidator(Estimator):
.addGrid(lr.maxIter, [0, 1, 5]) \
.build()
>>> evaluator = BinaryClassificationEvaluator()
>>> cv = CrossValidator() \
.setEstimator(lr) \
.setEstimatorParamMaps(grid) \
.setEvaluator(evaluator) \
.setNumFolds(3)
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
>>> cvModel = cv.fit(dataset)
>>> expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
>>> cvModel.transform(dataset).collect() == expected.collect()
Expand All @@ -125,7 +122,11 @@ class CrossValidator(Estimator):
# a placeholder to make it appear in the generated doc
numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")

def __init__(self):
@keyword_only
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
"""
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
"""
super(CrossValidator, self).__init__()
#: param for estimator to be cross-validated
self.estimator = Param(self, "estimator", "estimator to be cross-validated")
Expand All @@ -135,6 +136,18 @@ def __init__(self):
self.evaluator = Param(self, "evaluator", "evaluator for selection")
#: param for number of folds for cross validation
self.numFolds = Param(self, "numFolds", "number of folds for cross validation")
self._setDefault(numFolds=3)
kwargs = self.__init__._input_kwargs
self._set(**kwargs)

@keyword_only
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
"""
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
Sets params for cross validator.
"""
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)

def setEstimator(self, value):
"""
Expand Down

0 comments on commit acac727

Please sign in to comment.