[OSCP] Implement the AP (average_precision_score) function with SPU (#801)
# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #727

Implemented the `average_precision_score` function for binary and
multi-class classification, supporting three averaging methods
(`macro`, `micro`, and `None`).
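
A minimal usage sketch (illustrative only; shown as plain JAX for brevity — in practice the inputs are sealed and executed on SPU, as in the emulation code below):

```python
import jax.numpy as jnp

from sml.metrics.classification.classification import average_precision_score

# Binary classification: `classes` defaults to (0, 1).
y_true = jnp.array([0, 0, 1, 1])
y_score = jnp.array([0.1, 0.4, 0.35, 0.8])
ap = average_precision_score(y_true, y_score)  # ~0.8333

# Multi-class classification: `classes` must be given explicitly,
# since SPU cannot derive it from the data (no dynamic shapes).
y_true_mc = jnp.array([0, 1, 2])
y_score_mc = jnp.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.3, 0.5]])
ap_macro = average_precision_score(
    y_true_mc, y_score_mc, classes=(0, 1, 2), average="macro"
)
```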
z0gSh1u committed Aug 20, 2024
1 parent c826c48 commit 28fef7d
Showing 5 changed files with 314 additions and 6 deletions.
1 change: 1 addition & 0 deletions sml/metrics/classification/BUILD.bazel
@@ -21,6 +21,7 @@ py_library(
srcs = ["classification.py"],
deps = [
":auc",
"//sml/preprocessing",
"//spu/ops/groupby",
],
)
7 changes: 5 additions & 2 deletions sml/metrics/classification/auc.py
@@ -20,12 +20,14 @@
from spu.ops.groupby import groupby_sorted


def binary_clf_curve(sorted_pairs: jnp.array) -> Tuple[jnp.array, jnp.array, jnp.array]:
def binary_clf_curve(
sorted_pairs: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Calculate true and false positives per binary classification
threshold (can be used for roc curve or precision/recall curve).
Results may include trailing zeros.
Args:
sorted_pairs: jnp.array
sorted_pairs: jnp.ndarray
y_true y_score pairs sorted by y_score in decreasing order
Returns:
fps: 1d ndarray
@@ -57,6 +59,7 @@ def binary_clf_curve(sorted_pairs: jnp.array) -> Tuple[jnp.array, jnp.array, jnp.array]:
fps = seg_end_marks * fps
thresholds = seg_end_marks * thresholds
thresholds, fps, tps = jax.lax.sort([-thresholds] + [fps, tps], num_keys=1)

return fps, tps, -thresholds


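For context (not part of the diff): `binary_clf_curve` consumes label/score pairs sorted by score in decreasing order and returns cumulative false/true positive counts per distinct threshold. A sketch of the expected contract on a tiny, tie-free input (with ties, tied entries are zeroed out and sorted to the end, hence the possible trailing zeros):

```python
import jax.numpy as jnp

# columns: (y_true, y_score), already sorted by y_score descending
sorted_pairs = jnp.array([[1, 0.9], [1, 0.6], [0, 0.3]])
fps, tps, thresholds = binary_clf_curve(sorted_pairs)
# expected: thresholds ~ [0.9, 0.6, 0.3]
#           tps        ~ [1, 2, 2]   (cumulative true positives)
#           fps        ~ [0, 0, 1]   (cumulative false positives)
```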
156 changes: 155 additions & 1 deletion sml/metrics/classification/classification.py
@@ -16,10 +16,12 @@

import jax
import jax.numpy as jnp
from auc import binary_roc_auc

from sml.preprocessing.preprocessing import label_binarize
from spu.ops.groupby import groupby, groupby_sum

from .auc import binary_clf_curve, binary_roc_auc


def roc_auc_score(y_true, y_pred):
sorted_arr = create_sorted_label_score_pair(y_true, y_pred)
@@ -222,3 +224,155 @@ def fun_score(
else:
raise ValueError("average should be None or 'binary'")
return fun_result


def precision_recall_curve(
y_true: jnp.ndarray, y_score: jnp.ndarray, pos_label=1, score_eps=1e-5
):
"""Compute precision-recall pairs for different probability thresholds.
Note: this implementation is restricted to the binary classification task.
Parameters
----------
y_true : 1d array-like of shape (n,). True binary labels.
y_score : 1d array-like of shape (n,). Target scores, non-negative.
pos_label : int, default=1. The label of the positive class.
score_eps : float, default=1e-5. The lower bound for y_score.
Returns
-------
precisions : ndarray of shape (n + 1,).
Precision values where element i + 1 is the precision s.t.
score >= thresholds[i]; element 0 is 1.
recalls : ndarray of shape (n + 1,).
Increasing recall values where element i + 1 is the recall s.t.
score >= thresholds[i]; element 0 is 0.
thresholds : ndarray of shape (n,).
Decreasing thresholds used to compute precision and recall.
Results may include trailing zeros.
"""

# normalize the input
y_true = jnp.where(y_true == pos_label, 1, 0)
y_score = jnp.where(
y_score < score_eps, score_eps, y_score
)  # clip scores below score_eps so a genuine zero score is not confused with trailing-zero padding

# compute TP and FP
sorted_pairs = create_sorted_label_score_pair(y_true, y_score)
fp, tp, thresholds = binary_clf_curve(sorted_pairs)

# compute precision and recalls
mask = jnp.where(thresholds > 0, 1, 0) # tied value entries have mask=0
precisions = jnp.where(mask, tp / (tp + fp + 1e-5), 0)  # epsilon guards against 0/0
max_tp = jnp.max(tp)
recalls = jnp.where(max_tp == 0, jnp.ones_like(tp), tp / max_tp)

return (
jnp.hstack((1, precisions)),
jnp.hstack((0, recalls)),
thresholds,
)


def average_precision_score(
y_true: jnp.ndarray,
y_score: jnp.ndarray,
classes=(0, 1),
average="macro",
pos_label=1,
score_eps=1e-5,
):
"""Compute average precision (AP) from prediction scores.
.. math::
\\text{AP} = \\sum_n (R_n - R_{n-1}) P_n
Parameters
----------
y_true : array-like of shape (n_samples,)
True labels.
y_score : array-like of shape (n_samples,) or (n_samples, n_classes)
Estimated target scores as returned by a classifier, non-negative.
classes : 1d array-like, shape (n_classes,), default=(0, 1) for binary classification.
Uniquely holds the label for each class.
SPU does not support dynamic shapes, so this parameter must be specified explicitly.
average : {'macro', 'micro', None}, default='macro'
This parameter is required for multiclass/multilabel targets and
will be ignored when y_true is binary.
'macro':
Calculate metrics for each label, and find their unweighted mean.
'micro':
Calculate metrics globally by considering each element of the label
indicator matrix as a label.
None:
Scores for each class are returned.
pos_label : int, default=1
The label of the positive class. Only applied to binary y_true.
score_eps : float, default=1e-5. The lower bound for y_score.
Returns
-------
average_precision : float, or ndarray of shape (n_classes,) when average=None.
Average precision score(s).
"""

assert average in (
'macro',
'micro',
None,
), 'average must be either "macro", "micro" or None'

def binary_average_precision(y_true, y_score, pos_label=1):
"""Compute the average precision for binary classification."""
precisions, recalls, _ = precision_recall_curve(
y_true, y_score, pos_label=pos_label, score_eps=score_eps
)

return jnp.sum(jnp.diff(recalls) * precisions[1:])

n_classes = len(classes)
if n_classes <= 2:
# binary classification
# y_true with all-identical labels is a special case treated as binary classification
return binary_average_precision(y_true, y_score, pos_label=pos_label)
else:
# multi-class classification
# binarize labels using one-vs-all scheme into multilabel-indicator
y_true = label_binarize(y_true, classes=classes, n_classes=n_classes)

if average == "micro":
y_true = y_true.ravel()
y_score = y_score.ravel()
elif average == "macro":
pass

# extend the classes dimension if needed
if y_true.ndim == 1:
y_true = y_true[:, jnp.newaxis]
if y_score.ndim == 1:
y_score = y_score[:, jnp.newaxis]

# compute score for each class
n_classes = y_score.shape[1]
score = jnp.zeros((n_classes,))
for c in range(n_classes):
binary_ap = binary_average_precision(
y_true[:, c].ravel(), y_score[:, c].ravel(), pos_label=pos_label
)
score = score.at[c].set(binary_ap)

# average the scores
return jnp.average(score) if average else score
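
As a hand-worked instance of the formula above (same data as the first binary test case in the emulation below; not part of the diff):

```python
import numpy as np

# y_true = [0, 0, 1, 1], y_score = [0.1, 0.4, 0.35, 0.8]
# sorted by score desc: (0.8, 1) (0.4, 0) (0.35, 1) (0.1, 0)
# threshold  tp  fp  precision  recall
#   0.80      1   0    1.00      0.5
#   0.40      1   1    0.50      0.5
#   0.35      2   1    0.67      1.0
#   0.10      2   2    0.50      1.0
precisions = np.array([1.0, 1.0, 0.5, 2 / 3, 0.5])  # leading 1 prepended
recalls = np.array([0.0, 0.5, 0.5, 1.0, 1.0])       # leading 0 prepended
ap = np.sum(np.diff(recalls) * precisions[1:])
print(ap)  # ~0.8333, matching sklearn's average_precision_score
```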
80 changes: 77 additions & 3 deletions sml/metrics/classification/classification_emul.py
@@ -18,13 +18,15 @@
import jax.numpy as jnp
import numpy as np
from sklearn import metrics
from sklearn.metrics import average_precision_score as sk_average_precision_score

# add ops dir to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))

import sml.utils.emulation as emulation
from sml.metrics.classification.classification import (
accuracy_score,
average_precision_score,
f1_score,
precision_score,
recall_score,
@@ -42,7 +44,7 @@ def emul_auc(mode: emulation.Mode.MULTIPROCESS):

# Run
result = emulator.run(roc_auc_score)(
y_true, y_pred
*emulator.seal(y_true, y_pred)
) # X, y should be two-dimension array
print(result)

@@ -97,7 +99,7 @@ def check(spu_result, sk_result):
y_true = jnp.array([0, 1, 1, 0, 1, 1])
y_pred = jnp.array([0, 0, 1, 0, 1, 1])
spu_result = emulator.run(proc, static_argnums=(2, 5))(
y_true, y_pred, 'binary', None, 1, False
*emulator.seal(y_true, y_pred), 'binary', None, 1, False
)
sk_result = sklearn_proc(y_true, y_pred)
check(spu_result, sk_result)
@@ -106,12 +108,83 @@ def check(spu_result, sk_result):
y_true = jnp.array([0, 1, 1, 0, 2, 1])
y_pred = jnp.array([0, 0, 1, 0, 2, 1])
spu_result = emulator.run(proc, static_argnums=(2, 5))(
y_true, y_pred, None, [0, 1, 2], 1, True
*emulator.seal(y_true, y_pred), None, [0, 1, 2], 1, True
)
sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
check(spu_result, sk_result)


def emul_average_precision_score(mode: emulation.Mode.MULTIPROCESS):
def procBinary(y_true, y_score, **kwargs):
sk_res = sk_average_precision_score(y_true, y_score, **kwargs)
spu_res = emulator.run(average_precision_score)(
*emulator.seal(y_true, y_score), **kwargs
)
return sk_res, spu_res

def check(res1, res2):
return np.testing.assert_allclose(res1, res2, rtol=1e-3, atol=1e-3)

# --- Test binary classification ---
# 0-1 labels, no tied value
y_true = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
y_score = jnp.array([0.1, 0.4, 0.35, 0.8], dtype=jnp.float32)
check(*procBinary(y_true, y_score))
# 0-1 labels, with tied value, even length
y_true = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
y_score = jnp.array([0.4, 0.4, 0.4, 0.25], dtype=jnp.float32)
check(*procBinary(y_true, y_score))
# 0-1 labels, with tied value, odd length
y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32)
y_score = jnp.array([0.4, 0.4, 0.4, 0.25, 0.25], dtype=jnp.float32)
check(*procBinary(y_true, y_score))
# customized labels
y_true = jnp.array([2, 2, 3, 3], dtype=jnp.int32)
y_score = jnp.array([0.1, 0.2, 0.3, 0.4], dtype=jnp.float32)
check(*procBinary(y_true, y_score, pos_label=3))
# larger random dataset
y_true = jnp.array(np.random.randint(0, 2, 100), dtype=jnp.int32)
y_score = jnp.array(np.hstack((0, 1, np.random.random(98))), dtype=jnp.float32)
check(*procBinary(y_true, y_score))
# single label edge case
y_true = jnp.array([0, 0, 0, 0], dtype=jnp.int32)
y_score = jnp.array([0.4, 0.25, 0.4, 0.25], dtype=jnp.float32)
check(*procBinary(y_true, y_score))
y_true = jnp.array([1, 1, 1, 1], dtype=jnp.int32)
y_score = jnp.array([0.4, 0.25, 0.4, 0.25], dtype=jnp.float32)
check(*procBinary(y_true, y_score))
# zero score edge case
y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32)
y_score = jnp.array([0, 0, 0, 0.25, 0.25], dtype=jnp.float32)
check(*procBinary(y_true, y_score))
# score > 1 edge case
y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32)
y_score = jnp.array([1.5, 1.5, 1.5, 0.25, 0.25], dtype=jnp.float32)
check(*procBinary(y_true, y_score))

# --- Test multiclass classification ---
y_true = np.array([0, 0, 1, 1, 2, 2], dtype=jnp.int32)
y_score = np.array(
[
[0.7, 0.2, 0.1],
[0.4, 0.3, 0.3],
[0.1, 0.8, 0.1],
[0.2, 0.3, 0.5],
[0.4, 0.4, 0.2],
[0.1, 0.2, 0.7],
],
dtype=jnp.float32,
)
classes = jnp.unique(y_true)
# test over three supported average options
for average in ["macro", "micro", None]:
sk_res = sk_average_precision_score(y_true, y_score, average=average)
spu_res = emulator.run(average_precision_score, static_argnums=(3,))(
*emulator.seal(y_true, y_score), classes, average
)
check(sk_res, spu_res)


if __name__ == "__main__":
try:
# bandwidth and latency only work for docker mode
@@ -124,5 +197,6 @@ def check(spu_result, sk_result):
emulator.up()
emul_auc(emulation.Mode.MULTIPROCESS)
emul_Classification(emulation.Mode.MULTIPROCESS)
emul_average_precision_score(emulation.Mode.MULTIPROCESS)
finally:
emulator.down()