
Commit

updated python code
imatiach-msft committed Feb 24, 2019
1 parent 079e114 commit a571adc
Showing 2 changed files with 30 additions and 10 deletions.
18 changes: 13 additions & 5 deletions python/pyspark/ml/evaluation.py
@@ -106,7 +106,7 @@ def isLargerBetter(self):


@inherit_doc
-class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol,
+class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol, HasWeightCol,
JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -130,6 +130,14 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
>>> evaluator2 = BinaryClassificationEvaluator.load(bce_path)
>>> str(evaluator2.getRawPredictionCol())
'raw'
+>>> scoreAndLabelsAndWeight = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1], x[2]),
+... [(0.1, 0.0, 1.0), (0.1, 1.0, 0.9), (0.4, 0.0, 0.7), (0.6, 0.0, 0.9),
+... (0.6, 1.0, 1.0), (0.6, 1.0, 0.3), (0.8, 1.0, 1.0)])
+>>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["raw", "label", "weight"])
+...
+>>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw", weightCol="weight")
+>>> evaluator.evaluate(dataset)
+0.70...
.. versionadded:: 1.4.0
"""
@@ -140,10 +148,10 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction

@keyword_only
def __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
-metricName="areaUnderROC"):
+metricName="areaUnderROC", weightCol=None):
"""
__init__(self, rawPredictionCol="rawPrediction", labelCol="label", \
-metricName="areaUnderROC")
+metricName="areaUnderROC", weightCol=None)
"""
super(BinaryClassificationEvaluator, self).__init__()
self._java_obj = self._new_java_obj(
@@ -169,10 +177,10 @@ def getMetricName(self):
@keyword_only
@since("1.4.0")
def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
-metricName="areaUnderROC"):
+metricName="areaUnderROC", weightCol=None):
"""
setParams(self, rawPredictionCol="rawPrediction", labelCol="label", \
-metricName="areaUnderROC")
+metricName="areaUnderROC", weightCol=None)
Sets params for binary classification evaluator.
"""
kwargs = self._input_kwargs
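For reference, the sketch below shows how the new weightCol parameter of BinaryClassificationEvaluator would be used end to end, mirroring the doctest added above. It is a minimal standalone example, not part of the commit: it assumes a local SparkSession (named spark here) and that this change is available in the installed PySpark.

    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.linalg import Vectors
    from pyspark.sql import SparkSession

    # Local session for illustration; any existing SparkSession works.
    spark = SparkSession.builder.master("local[2]").appName("weighted-auc").getOrCreate()

    # Each row carries a raw score vector, a label, and a per-row weight,
    # matching the (raw, label, weight) columns in the doctest above.
    rows = [(Vectors.dense([1.0 - s, s]), label, weight)
            for (s, label, weight) in [(0.1, 0.0, 1.0), (0.1, 1.0, 0.9), (0.4, 0.0, 0.7),
                                       (0.6, 0.0, 0.9), (0.6, 1.0, 1.0), (0.6, 1.0, 0.3),
                                       (0.8, 1.0, 1.0)]]
    dataset = spark.createDataFrame(rows, ["raw", "label", "weight"])

    # weightCol is the parameter introduced by this commit; when it is set,
    # areaUnderROC is computed from weighted counts.
    evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw", weightCol="weight")
    print(evaluator.evaluate(dataset))  # roughly 0.70 for this data, per the doctest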
22 changes: 17 additions & 5 deletions python/pyspark/mllib/evaluation.py
@@ -30,7 +30,7 @@ class BinaryClassificationMetrics(JavaModelWrapper):
"""
Evaluator for binary classification.
-:param scoreAndLabels: an RDD of (score, label) pairs
+:param scoreAndLabelsWithOptWeight: an RDD of score, label and optional weight.
>>> scoreAndLabels = sc.parallelize([
... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
@@ -40,16 +40,28 @@ class BinaryClassificationMetrics(JavaModelWrapper):
>>> metrics.areaUnderPR
0.83...
>>> metrics.unpersist()
+>>> scoreAndLabelsWithOptWeight = sc.parallelize([
+... (0.1, 0.0, 1.0), (0.1, 1.0, 0.4), (0.4, 0.0, 0.2), (0.6, 0.0, 0.6), (0.6, 1.0, 0.9),
+... (0.6, 1.0, 0.5), (0.8, 1.0, 0.7)], 2)
+>>> metrics = BinaryClassificationMetrics(scoreAndLabelsWithOptWeight)
+>>> metrics.areaUnderROC
+0.70...
+>>> metrics.areaUnderPR
+0.83...
.. versionadded:: 1.4.0
"""

-def __init__(self, scoreAndLabels):
-sc = scoreAndLabels.ctx
+def __init__(self, scoreAndLabelsWithOptWeight):
+sc = scoreAndLabelsWithOptWeight.ctx
sql_ctx = SQLContext.getOrCreate(sc)
-df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
+numCol = len(scoreAndLabelsWithOptWeight.first())
+schema = StructType([
StructField("score", DoubleType(), nullable=False),
-StructField("label", DoubleType(), nullable=False)]))
+StructField("label", DoubleType(), nullable=False)])
+if (numCol == 3):
+schema.add("weight", DoubleType(), False)
+df = sql_ctx.createDataFrame(scoreAndLabelsWithOptWeight, schema=schema)
java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
java_model = java_class(df._jdf)
super(BinaryClassificationMetrics, self).__init__(java_model)
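As a side note, the constructor change above builds the DataFrame schema conditionally: the weight field is appended only when the input tuples have three elements. A small isolated sketch of that pattern follows (the sample record is illustrative; the real code inspects the first element of the input RDD).

    from pyspark.sql.types import DoubleType, StructField, StructType

    # Base schema for (score, label) pairs, unchanged from the original code.
    schema = StructType([
        StructField("score", DoubleType(), nullable=False),
        StructField("label", DoubleType(), nullable=False)])

    first_record = (0.1, 0.0, 1.0)  # hypothetical sample; a 2-tuple would skip the branch
    if len(first_record) == 3:
        # StructType.add mutates the schema in place (and also returns it).
        schema.add("weight", DoubleType(), False)

    print([f.name for f in schema.fields])  # ['score', 'label', 'weight']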
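Similarly, a minimal standalone sketch of the RDD-based API after this change, mirroring the new doctest: it assumes an active SparkContext and that weighted tuples are supported as above. With plain (score, label) pairs the constructor behaves exactly as before.

    from pyspark import SparkContext
    from pyspark.mllib.evaluation import BinaryClassificationMetrics

    sc = SparkContext.getOrCreate()

    # Three-element tuples: (score, label, weight); the third element triggers
    # the optional weight column added by this commit.
    score_label_weight = sc.parallelize([
        (0.1, 0.0, 1.0), (0.1, 1.0, 0.4), (0.4, 0.0, 0.2), (0.6, 0.0, 0.6),
        (0.6, 1.0, 0.9), (0.6, 1.0, 0.5), (0.8, 1.0, 0.7)], 2)

    metrics = BinaryClassificationMetrics(score_label_weight)
    print(metrics.areaUnderROC)  # roughly 0.70, per the doctest
    print(metrics.areaUnderPR)   # roughly 0.83
    metrics.unpersist()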
