Skip to content

Commit

Permalink
update doctest
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed May 5, 2015
1 parent acac727 commit 060f7c3
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions python/pyspark/ml/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

from pyspark.ml.param import Params, Param
from pyspark.ml import Estimator, Model
from pyspark.sql.functions import rand
from pyspark.ml.util import keyword_only
from pyspark.sql.functions import rand

__all__ = ['ParamGridBuilder', 'CrossValidator']
__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel']


class ParamGridBuilder(object):
Expand Down Expand Up @@ -84,6 +84,7 @@ def build(self):
grid_values = self._param_grid.values()
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]


class CrossValidator(Estimator):
"""
K-fold cross validation.
Expand All @@ -99,9 +100,7 @@ class CrossValidator(Estimator):
... (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
... ["features", "label"])
>>> lr = LogisticRegression()
>>> grid = ParamGridBuilder() \
.addGrid(lr.maxIter, [0, 1, 5]) \
.build()
>>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
>>> cvModel = cv.fit(dataset)
Expand Down

0 comments on commit 060f7c3

Please sign in to comment.