[SPARK-44250][ML][PYTHON][CONNECT] Implement classification evaluator
### What changes were proposed in this pull request?

Implement the classification evaluators `BinaryClassificationEvaluator` and `MulticlassClassificationEvaluator` for `pyspark.ml.connect`.

### Why are the changes needed?

Part of the distributed ML <> Spark Connect project.

### Does this PR introduce _any_ user-facing change?

Yes.
`BinaryClassificationEvaluator` and `MulticlassClassificationEvaluator` are added.
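For illustration, a minimal usage sketch modeled on the new unit tests (the DataFrame `df` and the column names `label`, `prob`, and `prediction` are placeholders; `torcheval` must be installed wherever evaluation runs):

```python
from pyspark.ml.connect.evaluation import (
    BinaryClassificationEvaluator,
    MulticlassClassificationEvaluator,
)

# df is assumed to be a Spark or pandas DataFrame containing the columns below.
auroc = BinaryClassificationEvaluator(
    metricName="areaUnderROC",  # or "areaUnderPR"
    labelCol="label",
    probabilityCol="prob",  # a scalar score or a [p0, p1] probability vector
).evaluate(df)

accuracy = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="label",
    predictionCol="prediction",
).evaluate(df)
```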

### How was this patch tested?

Unit tests: new test cases for the binary and multiclass evaluators in `test_legacy_mode_evaluation.py`.

Closes #41793 from WeichenXu123/classification-evaluator.

Authored-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
WeichenXu123 committed Jul 4, 2023
1 parent 7bc28d5 commit 7fcabef
Showing 2 changed files with 202 additions and 36 deletions.
161 changes: 126 additions & 35 deletions python/pyspark/ml/connect/evaluation.py
@@ -14,22 +14,61 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np

import pandas as pd
-from typing import Any, Union
+from typing import Any, Union, List, Tuple

from pyspark.ml.param import Param, Params, TypeConverters
-from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol
+from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasProbabilityCol
from pyspark.ml.connect.base import Evaluator
from pyspark.ml.connect.io_utils import ParamsReadWrite
from pyspark.ml.connect.util import aggregate_dataframe
from pyspark.sql import DataFrame

-import torch
-import torcheval.metrics as torchmetrics

+class _TorchMetricEvaluator(Evaluator):

-class RegressionEvaluator(Evaluator, HasLabelCol, HasPredictionCol, ParamsReadWrite):
metricName: Param[str] = Param(
Params._dummy(),
"metricName",
"metric name for the regression evaluator, valid values are 'mse' and 'r2'",
typeConverter=TypeConverters.toString,
)

def _get_torch_metric(self) -> Any:
raise NotImplementedError()

def _get_input_cols(self) -> List[str]:
raise NotImplementedError()

def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
raise NotImplementedError()

def _evaluate(self, dataset: Union["DataFrame", "pd.DataFrame"]) -> float:
torch_metric = self._get_torch_metric()

def local_agg_fn(pandas_df: "pd.DataFrame") -> "pd.DataFrame":
torch_metric.update(*self._get_metric_update_inputs(pandas_df))
return torch_metric

def merge_agg_state(state1: Any, state2: Any) -> Any:
state1.merge_state([state2])
return state1

def agg_state_to_result(state: Any) -> Any:
return state.compute().item()

return aggregate_dataframe(
dataset,
self._get_input_cols(),
local_agg_fn,
merge_agg_state,
agg_state_to_result,
)


class RegressionEvaluator(_TorchMetricEvaluator, HasLabelCol, HasPredictionCol, ParamsReadWrite):
"""
Evaluator for Regression, which expects input columns prediction and label.
Supported metrics are 'mse' and 'r2'.
@@ -41,14 +80,9 @@ def __init__(self, metricName: str, labelCol: str, predictionCol: str) -> None:
super().__init__()
self._set(metricName=metricName, labelCol=labelCol, predictionCol=predictionCol)

-    metricName: Param[str] = Param(
-        Params._dummy(),
-        "metricName",
-        "metric name for the regression evaluator, valid values are 'mse' and 'r2'",
-        typeConverter=TypeConverters.toString,
-    )
-
    def _get_torch_metric(self) -> Any:
+        import torcheval.metrics as torchmetrics

metric_name = self.getOrDefault(self.metricName)

if metric_name == "mse":
@@ -58,32 +92,89 @@ def _get_torch_metric(self) -> Any:

raise ValueError(f"Unsupported regressor evaluator metric name: {metric_name}")

-    def _evaluate(self, dataset: Union["DataFrame", "pd.DataFrame"]) -> float:
-        prediction_col = self.getPredictionCol()
-        label_col = self.getLabelCol()
+    def _get_input_cols(self) -> List[str]:
+        return [self.getPredictionCol(), self.getLabelCol()]

-        torch_metric = self._get_torch_metric()
+    def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
+        import torch

-        def local_agg_fn(pandas_df: "pd.DataFrame") -> "pd.DataFrame":
-            with torch.inference_mode():
-                preds_tensor = torch.tensor(pandas_df[prediction_col].values)
-                labels_tensor = torch.tensor(pandas_df[label_col].values)
-                torch_metric.update(preds_tensor, labels_tensor)
-                return torch_metric
+        preds_tensor = torch.tensor(dataset[self.getPredictionCol()].values)
+        labels_tensor = torch.tensor(dataset[self.getLabelCol()].values)
+        return preds_tensor, labels_tensor

-        def merge_agg_state(state1: Any, state2: Any) -> Any:
-            with torch.inference_mode():
-                state1.merge_state([state2])
-                return state1

-        def agg_state_to_result(state: Any) -> Any:
-            with torch.inference_mode():
-                return state.compute().item()
class BinaryClassificationEvaluator(
_TorchMetricEvaluator, HasLabelCol, HasProbabilityCol, ParamsReadWrite
):
"""
    Evaluator for binary classification, which expects input columns probability and label.
Supported metrics are 'areaUnderROC' and 'areaUnderPR'.
-        return aggregate_dataframe(
-            dataset,
-            [prediction_col, label_col],
-            local_agg_fn,
-            merge_agg_state,
-            agg_state_to_result,
.. versionadded:: 3.5.0
"""

def __init__(self, metricName: str, labelCol: str, probabilityCol: str) -> None:
super().__init__()
self._set(metricName=metricName, labelCol=labelCol, probabilityCol=probabilityCol)

def _get_torch_metric(self) -> Any:
import torcheval.metrics as torchmetrics

metric_name = self.getOrDefault(self.metricName)

if metric_name == "areaUnderROC":
return torchmetrics.BinaryAUROC()
if metric_name == "areaUnderPR":
return torchmetrics.BinaryAUPRC()

raise ValueError(f"Unsupported binary classification evaluator metric name: {metric_name}")

def _get_input_cols(self) -> List[str]:
return [self.getProbabilityCol(), self.getLabelCol()]

def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
import torch

values = np.stack(dataset[self.getProbabilityCol()].values) # type: ignore[call-overload]
preds_tensor = torch.tensor(values)
if preds_tensor.dim() == 2:
preds_tensor = preds_tensor[:, 1]
labels_tensor = torch.tensor(dataset[self.getLabelCol()].values)
return preds_tensor, labels_tensor


class MulticlassClassificationEvaluator(
_TorchMetricEvaluator, HasLabelCol, HasPredictionCol, ParamsReadWrite
):
"""
Evaluator for multiclass classification, which expects input columns prediction and label.
    The supported metric is 'accuracy'.
.. versionadded:: 3.5.0
"""

def __init__(self, metricName: str, labelCol: str, predictionCol: str) -> None:
super().__init__()
self._set(metricName=metricName, labelCol=labelCol, predictionCol=predictionCol)

def _get_torch_metric(self) -> Any:
import torcheval.metrics as torchmetrics

metric_name = self.getOrDefault(self.metricName)

if metric_name == "accuracy":
return torchmetrics.MulticlassAccuracy()

raise ValueError(
f"Unsupported multiclass classification evaluator metric name: {metric_name}"
)

def _get_input_cols(self) -> List[str]:
return [self.getPredictionCol(), self.getLabelCol()]

def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
import torch

preds_tensor = torch.tensor(dataset[self.getPredictionCol()].values)
labels_tensor = torch.tensor(dataset[self.getLabelCol()].values)
return preds_tensor, labels_tensor
77 changes: 76 additions & 1 deletion python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
@@ -18,7 +18,11 @@
import unittest
import numpy as np

-from pyspark.ml.connect.evaluation import RegressionEvaluator
+from pyspark.ml.connect.evaluation import (
+    RegressionEvaluator,
+    BinaryClassificationEvaluator,
+    MulticlassClassificationEvaluator,
+)
from pyspark.sql import SparkSession


@@ -66,6 +70,77 @@ def test_regressor_evaluator(self):
np.testing.assert_almost_equal(r2, expected_r2)
np.testing.assert_almost_equal(r2_local, expected_r2)

def test_binary_classifier_evaluator(self):
df1 = self.spark.createDataFrame(
[
(1, 0.2, [0.8, 0.2]),
(0, 0.6, [0.4, 0.6]),
(1, 0.8, [0.2, 0.8]),
(1, 0.7, [0.3, 0.7]),
(0, 0.4, [0.6, 0.4]),
(0, 0.3, [0.7, 0.3]),
],
schema=["label", "prob", "prob2"],
)

local_df1 = df1.toPandas()

for prob_col in ["prob", "prob2"]:
auroc_evaluator = BinaryClassificationEvaluator(
metricName="areaUnderROC",
labelCol="label",
probabilityCol=prob_col,
)

expected_auroc = 0.6667
auroc = auroc_evaluator.evaluate(df1)
auroc_local = auroc_evaluator.evaluate(local_df1)
np.testing.assert_almost_equal(auroc, expected_auroc, decimal=2)
np.testing.assert_almost_equal(auroc_local, expected_auroc, decimal=2)

auprc_evaluator = BinaryClassificationEvaluator(
metricName="areaUnderPR",
labelCol="label",
probabilityCol=prob_col,
)

expected_auprc = 0.8333
auprc = auprc_evaluator.evaluate(df1)
auprc_local = auprc_evaluator.evaluate(local_df1)
np.testing.assert_almost_equal(auprc, expected_auprc, decimal=2)
np.testing.assert_almost_equal(auprc_local, expected_auprc, decimal=2)

def test_multiclass_classifier_evaluator(self):
df1 = self.spark.createDataFrame(
[
(1, 1),
(1, 1),
(2, 3),
(0, 0),
(0, 1),
(3, 1),
(3, 3),
(2, 2),
(1, 0),
(2, 2),
],
schema=["label", "prediction"],
)

local_df1 = df1.toPandas()

accuracy_evaluator = MulticlassClassificationEvaluator(
metricName="accuracy",
labelCol="label",
predictionCol="prediction",
)

expected_accuracy = 0.600
accuracy = accuracy_evaluator.evaluate(df1)
accuracy_local = accuracy_evaluator.evaluate(local_df1)
np.testing.assert_almost_equal(accuracy, expected_accuracy, decimal=2)
np.testing.assert_almost_equal(accuracy_local, expected_accuracy, decimal=2)


@unittest.skipIf(not have_torcheval, "torcheval is required")
class EvaluationTests(EvaluationTestsMixin, unittest.TestCase):
