Skip to content

Commit

Permalink
add Estimator and Transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Dec 31, 2014
1 parent 46eea43 commit a3015cf
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
14 changes: 13 additions & 1 deletion python/pyspark/ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pyspark import SparkContext
from pyspark.ml.param import Param

__all__ = ["Pipeline"]
__all__ = ["Pipeline", "Transformer", "Estimator"]

# An implementation of PEP3102 for Python 2.
_keyword_only_secret = 70861589
Expand Down Expand Up @@ -60,3 +60,15 @@ def transform(self, dataset):
for t in self.transformers:
dataset = t.transform(dataset)
return dataset


class Estimator(object):

def fit(self, dataset, params={}):
raise NotImplementedError()


class Transformer(object):

def transform(self, dataset, paramMap={}):
raise NotImplementedError()
6 changes: 3 additions & 3 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from pyspark.sql import SchemaRDD
from pyspark.ml import _jvm
from pyspark.ml import Estimator, Transformer, _jvm
from pyspark.ml.param import Param


class LogisticRegression(object):
class LogisticRegression(Estimator):
"""
Logistic regression.
"""
Expand Down Expand Up @@ -45,7 +45,7 @@ def fit(self, dataset, params=None):
return LogisticRegressionModel(java_model)


class LogisticRegressionModel(object):
class LogisticRegressionModel(Transformer):
"""
Model fitted by LogisticRegression.
"""
Expand Down

0 comments on commit a3015cf

Please sign in to comment.