Skip to content

Commit

Permalink
simplify java models
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Jan 27, 2015
1 parent 036ca04 commit 5153cff
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
new ParamMap(this.map ++ other.map)
}


/**
* Adds all parameters from the input param map into this param map.
*/
Expand Down
50 changes: 43 additions & 7 deletions python/pyspark/ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,18 @@ def transform(self, dataset, params={}):
raise NotImplementedError()


@inherit_doc
class Model(Transformer):
    """
    Abstract class for models fitted by :py:class:`Estimator`s.
    """

    # The metaclass hook must be spelled with exactly two leading
    # underscores; any other spelling is an ordinary class attribute and
    # does not declare a metaclass.
    __metaclass__ = ABCMeta

    def __init__(self):
        super(Model, self).__init__()


@inherit_doc
class Pipeline(Estimator):
"""
Expand Down Expand Up @@ -169,7 +181,7 @@ def fit(self, dataset, params={}):


@inherit_doc
class PipelineModel(Transformer):
class PipelineModel(Model):
"""
Represents a compiled pipeline with transformers and fitted models.
"""
Expand Down Expand Up @@ -204,9 +216,9 @@ def _java_class(self):
"""
raise NotImplementedError

def _create_java_obj(self):
def _java_obj(self):
"""
Creates a new Java object and returns its reference.
Returns or creates a Java object.
"""
java_obj = _jvm()
for name in self._java_class.split("."):
Expand All @@ -231,6 +243,13 @@ def _empty_java_param_map(self):
"""
return _jvm().org.apache.spark.ml.param.ParamMap()

def _create_java_param_map(self, params, java_obj):
paramMap = self._empty_java_param_map()
for param, value in params.items():
if param.parent is self:
paramMap.put(java_obj.getParam(param.name), value)
return paramMap


@inherit_doc
class JavaEstimator(Estimator, JavaWrapper):
Expand Down Expand Up @@ -259,7 +278,7 @@ def _fit_java(self, dataset, params={}):
:param params: additional params (overwriting embedded values)
:return: fitted Java model
"""
java_obj = self._create_java_obj()
java_obj = self._java_obj()
self._transfer_params_to_java(params, java_obj)
return java_obj.fit(dataset._jschema_rdd, self._empty_java_param_map())

Expand All @@ -281,7 +300,24 @@ def __init__(self):
super(JavaTransformer, self).__init__()

def transform(self, dataset, params={}):
java_obj = self._create_java_obj()
self._transfer_params_to_java(params, java_obj)
return SchemaRDD(java_obj.transform(dataset._jschema_rdd, self._empty_java_param_map()),
java_obj = self._java_obj()
self._transfer_params_to_java({}, java_obj)
java_param_map = self._create_java_param_map(params, java_obj)
return SchemaRDD(java_obj.transform(dataset._jschema_rdd, java_param_map),
dataset.sql_ctx)


@inherit_doc
class JavaModel(JavaTransformer):
    """
    Base class for :py:class:`Model`s that wrap Java/Scala
    implementations.

    :py:meth:`_java_obj` returns ``self._java_model``; this class does
    not assign that attribute itself, so it has to be set before
    :py:meth:`_java_obj` is called.
    """

    __metaclass__ = ABCMeta

    def __init__(self):
        # Name this class (not its parent) in super() so the lookup starts
        # right after JavaModel in the MRO and JavaTransformer.__init__ is
        # not skipped.
        super(JavaModel, self).__init__()

    def _java_obj(self):
        """
        Returns the Java model stored in ``self._java_model``.
        """
        return self._java_model
14 changes: 6 additions & 8 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# limitations under the License.
#

from pyspark.sql import SchemaRDD, inherit_doc
from pyspark.ml import JavaEstimator, Transformer, _jvm
from pyspark.sql import inherit_doc
from pyspark.ml import JavaEstimator, JavaModel
from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
HasRegParam

Expand All @@ -40,7 +40,7 @@ def _create_model(self, java_model):


@inherit_doc
class LogisticRegressionModel(Transformer):
class LogisticRegressionModel(JavaModel):
"""
Model fitted by LogisticRegression.
"""
Expand All @@ -49,8 +49,6 @@ def __init__(self, java_model):
super(LogisticRegressionModel, self).__init__()
self._java_model = java_model

def transform(self, dataset, params={}):
# TODO: handle params here.
return SchemaRDD(self._java_model.transform(
dataset._jschema_rdd,
_jvm().org.apache.spark.ml.param.ParamMap()), dataset.sql_ctx)
@property
def _java_class(self):
return "org.apache.spark.ml.classification.LogisticRegressionModel"

0 comments on commit 5153cff

Please sign in to comment.