diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 1cb30a0420ab5..ad33b4b23c096 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -20,10 +20,10 @@ from pyspark.ml.param import Params, Param from pyspark.ml import Estimator, Model -from pyspark.sql.functions import rand from pyspark.ml.util import keyword_only +from pyspark.sql.functions import rand -__all__ = ['ParamGridBuilder', 'CrossValidator'] +__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel'] class ParamGridBuilder(object): @@ -84,6 +84,7 @@ def build(self): grid_values = self._param_grid.values() return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] + class CrossValidator(Estimator): """ K-fold cross validation. @@ -99,9 +100,7 @@ class CrossValidator(Estimator): ... (Vectors.dense([0.51, 5.0]), 1.0)] * 10, ... ["features", "label"]) >>> lr = LogisticRegression() - >>> grid = ParamGridBuilder() \ - .addGrid(lr.maxIter, [0, 1, 5]) \ - .build() + >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build() >>> evaluator = BinaryClassificationEvaluator() >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) >>> cvModel = cv.fit(dataset)