Skip to content

Commit

Permalink
nameing
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Oct 10, 2020
1 parent 557d32f commit d812512
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 55 deletions.
32 changes: 16 additions & 16 deletions tests/metrics/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,56 +25,56 @@
torch.manual_seed(42)


def _binary_prob_sk_metric(preds, target):
def _sk_accuracy_binary_prob(preds, target):
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.view(-1).numpy()

return accuracy_score(y_true=sk_target, y_pred=sk_preds)


def _binary_sk_metric(preds, target):
def _sk_accuracy_binary_sk_metric(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

return accuracy_score(y_true=sk_target, y_pred=sk_preds)


def _multilabel_prob_sk_metric(preds, target):
def _sk_accuracy_multilabel_prob(preds, target):
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.view(-1).numpy()

return accuracy_score(y_true=sk_target, y_pred=sk_preds)


def _multilabel_sk_metric(preds, target):
def _sk_accuracy_multilabel(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

return accuracy_score(y_true=sk_target, y_pred=sk_preds)


def _multiclass_prob_sk_metric(preds, target):
def _sk_accuracy_multiclass_prob(preds, target):
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
sk_target = target.view(-1).numpy()

return accuracy_score(y_true=sk_target, y_pred=sk_preds)


def _multiclass_sk_metric(preds, target):
def _sk_accuracy_multiclass(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

return accuracy_score(y_true=sk_target, y_pred=sk_preds)


def _multidim_multiclass_prob_sk_metric(preds, target):
def _sk_accuracy_multidim_multiclass_prob(preds, target):
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
sk_target = target.view(-1).numpy()

return accuracy_score(y_true=sk_target, y_pred=sk_preds)


def _multidim_multiclass_sk_metric(preds, target):
def _sk_accuracy_multidim_multiclass(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

Expand All @@ -90,21 +90,21 @@ def test_accuracy_invalid_shape():
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("preds, target, sk_metric", [
(_testing_binary_prob_inputs.preds, _testing_binary_prob_inputs.target, _binary_prob_sk_metric),
(_testing_binary_inputs.preds, _testing_binary_inputs.target, _binary_sk_metric),
(_testing_multilabel_prob_inputs.preds, _testing_multilabel_prob_inputs.target, _multilabel_prob_sk_metric),
(_testing_multilabel_inputs.preds, _testing_multilabel_inputs.target, _multilabel_sk_metric),
(_testing_multiclass_prob_inputs.preds, _testing_multiclass_prob_inputs.target, _multiclass_prob_sk_metric),
(_testing_multiclass_inputs.preds, _testing_multiclass_inputs.target, _multiclass_sk_metric),
(_testing_binary_prob_inputs.preds, _testing_binary_prob_inputs.target, _sk_accuracy_binary_prob),
(_testing_binary_inputs.preds, _testing_binary_inputs.target, _sk_accuracy_binary_sk_metric),
(_testing_multilabel_prob_inputs.preds, _testing_multilabel_prob_inputs.target, _sk_accuracy_multilabel_prob),
(_testing_multilabel_inputs.preds, _testing_multilabel_inputs.target, _sk_accuracy_multilabel),
(_testing_multiclass_prob_inputs.preds, _testing_multiclass_prob_inputs.target, _sk_accuracy_multiclass_prob),
(_testing_multiclass_inputs.preds, _testing_multiclass_inputs.target, _sk_accuracy_multiclass),
(
_testing_multidim_multiclass_prob_inputs.preds,
_testing_multidim_multiclass_prob_inputs.target,
_multidim_multiclass_prob_sk_metric
_sk_accuracy_multidim_multiclass_prob
),
(
_testing_multidim_multiclass_inputs.preds,
_testing_multidim_multiclass_inputs.target,
_multidim_multiclass_sk_metric
_sk_accuracy_multidim_multiclass
)
])
def test_accuracy(ddp, dist_sync_on_step, preds, target, sk_metric):
Expand Down
40 changes: 17 additions & 23 deletions tests/metrics/classification/test_f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,56 +26,56 @@
torch.manual_seed(42)


def _binary_prob_sk_metric(preds, target, average='micro', beta=1.):
def _sk_fbeta_binary_prob(preds, target, average='micro', beta=1.):
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.view(-1).numpy()

return fbeta_score(y_true=sk_target, y_pred=sk_preds, average='binary', beta=beta)


def _binary_sk_metric(preds, target, average='micro', beta=1.):
def _sk_fbeta_binary(preds, target, average='micro', beta=1.):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

return fbeta_score(y_true=sk_target, y_pred=sk_preds, average='binary', beta=beta)


def _multilabel_prob_sk_metric(preds, target, average='micro', beta=1.):
def _sk_fbeta_multilabel_prob(preds, target, average='micro', beta=1.):
sk_preds = (preds.view(-1, NUM_CLASSES).numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.view(-1, NUM_CLASSES).numpy()

return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta)


def _multilabel_sk_metric(preds, target, average='micro', beta=1.):
def _sk_fbeta_multilabel(preds, target, average='micro', beta=1.):
sk_preds = preds.view(-1, NUM_CLASSES).numpy()
sk_target = target.view(-1, NUM_CLASSES).numpy()

return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta)


def _multiclass_prob_sk_metric(preds, target, average='micro', beta=1.):
def _sk_fbeta_multiclass_prob(preds, target, average='micro', beta=1.):
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
sk_target = target.view(-1).numpy()

return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta)


def _multiclass_sk_metric(preds, target, average='micro', beta=1.):
def _sk_fbeta_multiclass(preds, target, average='micro', beta=1.):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta)


def _multidim_multiclass_prob_sk_metric(preds, target, average='micro', beta=1.):
def _sk_fbeta_multidim_multiclass_prob(preds, target, average='micro', beta=1.):
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
sk_target = target.view(-1).numpy()

return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta)


def _multidim_multiclass_sk_metric(preds, target, average='micro', beta=1.):
def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

Expand All @@ -86,34 +86,28 @@ def _multidim_multiclass_sk_metric(preds, target, average='micro', beta=1.):
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("average", ['micro', 'macro'])
@pytest.mark.parametrize("preds, target, sk_metric, num_classes, multilabel", [
(_testing_binary_prob_inputs.preds, _testing_binary_prob_inputs.target, _binary_prob_sk_metric, 1, False),
(_testing_binary_inputs.preds, _testing_binary_inputs.target, _binary_sk_metric, 1, False),
(_testing_multilabel_prob_inputs.preds, _testing_multilabel_prob_inputs.target, _multilabel_prob_sk_metric, NUM_CLASSES, True),
(_testing_multilabel_inputs.preds, _testing_multilabel_inputs.target, _multilabel_sk_metric, NUM_CLASSES, True),
(_testing_multiclass_prob_inputs.preds, _testing_multiclass_prob_inputs.target, _multiclass_prob_sk_metric, NUM_CLASSES, False),
(_testing_multiclass_inputs.preds, _testing_multiclass_inputs.target, _multiclass_sk_metric, NUM_CLASSES, False),
(_testing_binary_prob_inputs.preds, _testing_binary_prob_inputs.target, _sk_fbeta_binary_prob, 1, False),
(_testing_binary_inputs.preds, _testing_binary_inputs.target, _sk_fbeta_binary, 1, False),
(_testing_multilabel_prob_inputs.preds, _testing_multilabel_prob_inputs.target, _sk_fbeta_multilabel_prob, NUM_CLASSES, True),
(_testing_multilabel_inputs.preds, _testing_multilabel_inputs.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
(_testing_multiclass_prob_inputs.preds, _testing_multiclass_prob_inputs.target, _sk_fbeta_multiclass_prob, NUM_CLASSES, False),
(_testing_multiclass_inputs.preds, _testing_multiclass_inputs.target, _sk_fbeta_multiclass, NUM_CLASSES, False),
(
_testing_multidim_multiclass_prob_inputs.preds,
_testing_multidim_multiclass_prob_inputs.target,
_multidim_multiclass_prob_sk_metric,
_sk_fbeta_multidim_multiclass_prob,
NUM_CLASSES,
False
),
(
_testing_multidim_multiclass_inputs.preds,
_testing_multidim_multiclass_inputs.target,
_multidim_multiclass_sk_metric,
_sk_fbeta_multidim_multiclass,
NUM_CLASSES,
False
)
])
@pytest.mark.parametrize(
"metric_class, beta",
[
(Fbeta, 0.5),
(Fbeta, 1.),
],
)
@pytest.mark.parametrize("metric_class, beta", [(Fbeta, 0.5), (Fbeta, 1.)])
def test_fbeta(
ddp,
dist_sync_on_step,
Expand Down
32 changes: 16 additions & 16 deletions tests/metrics/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,56 +26,56 @@
torch.manual_seed(42)


def _binary_prob_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
def _sk_prec_recall_binary_prob(preds, target, sk_fn=precision_score, average='micro'):
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.view(-1).numpy()

return sk_fn(y_true=sk_target, y_pred=sk_preds, average='binary')


def _binary_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
def _sk_prec_recall_binary(preds, target, sk_fn=precision_score, average='micro'):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

return sk_fn(y_true=sk_target, y_pred=sk_preds, average='binary')


def _multilabel_prob_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
def _sk_prec_recall_multilabel_prob(preds, target, sk_fn=precision_score, average='micro'):
sk_preds = (preds.view(-1, NUM_CLASSES).numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.view(-1, NUM_CLASSES).numpy()

return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average)


def _multilabel_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
def _sk_prec_recall_multilabel(preds, target, sk_fn=precision_score, average='micro'):
sk_preds = preds.view(-1, NUM_CLASSES).numpy()
sk_target = target.view(-1, NUM_CLASSES).numpy()

return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average)


def _multiclass_prob_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
def _sk_prec_recall_multiclass_prob(preds, target, sk_fn=precision_score, average='micro'):
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
sk_target = target.view(-1).numpy()

return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average)


def _multiclass_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
def _sk_prec_recall_multiclass(preds, target, sk_fn=precision_score, average='micro'):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average)


def _multidim_multiclass_prob_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
def _sk_prec_recall_multidim_multiclass_prob(preds, target, sk_fn=precision_score, average='micro'):
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
sk_target = target.view(-1).numpy()

return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average)


def _multidim_multiclass_sk_metric(preds, target, sk_fn=precision_score, average='micro'):
def _sk_prec_recall_multidim_multiclass(preds, target, sk_fn=precision_score, average='micro'):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

Expand All @@ -86,23 +86,23 @@ def _multidim_multiclass_sk_metric(preds, target, sk_fn=precision_score, average
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("average", ['micro', 'macro'])
@pytest.mark.parametrize("preds, target, sk_metric, num_classes, multilabel", [
(_testing_binary_prob_inputs.preds, _testing_binary_prob_inputs.target, _binary_prob_sk_metric, 1, False),
(_testing_binary_inputs.preds, _testing_binary_inputs.target, _binary_sk_metric, 1, False),
(_testing_multilabel_prob_inputs.preds, _testing_multilabel_prob_inputs.target, _multilabel_prob_sk_metric, NUM_CLASSES, True),
(_testing_multilabel_inputs.preds, _testing_multilabel_inputs.target, _multilabel_sk_metric, NUM_CLASSES, True),
(_testing_multiclass_prob_inputs.preds, _testing_multiclass_prob_inputs.target, _multiclass_prob_sk_metric, NUM_CLASSES, False),
(_testing_multiclass_inputs.preds, _testing_multiclass_inputs.target, _multiclass_sk_metric, NUM_CLASSES, False),
(_testing_binary_prob_inputs.preds, _testing_binary_prob_inputs.target, _sk_prec_recall_binary_prob, 1, False),
(_testing_binary_inputs.preds, _testing_binary_inputs.target, _sk_prec_recall_binary, 1, False),
(_testing_multilabel_prob_inputs.preds, _testing_multilabel_prob_inputs.target, _sk_prec_recall_multilabel_prob, NUM_CLASSES, True),
(_testing_multilabel_inputs.preds, _testing_multilabel_inputs.target, _sk_prec_recall_multilabel, NUM_CLASSES, True),
(_testing_multiclass_prob_inputs.preds, _testing_multiclass_prob_inputs.target, _sk_prec_recall_multiclass_prob, NUM_CLASSES, False),
(_testing_multiclass_inputs.preds, _testing_multiclass_inputs.target, _sk_prec_recall_multiclass, NUM_CLASSES, False),
(
_testing_multidim_multiclass_prob_inputs.preds,
_testing_multidim_multiclass_prob_inputs.target,
_multidim_multiclass_prob_sk_metric,
_sk_prec_recall_multidim_multiclass_prob,
NUM_CLASSES,
False
),
(
_testing_multidim_multiclass_inputs.preds,
_testing_multidim_multiclass_inputs.target,
_multidim_multiclass_sk_metric,
_sk_prec_recall_multidim_multiclass,
NUM_CLASSES,
False
)
Expand Down

0 comments on commit d812512

Please sign in to comment.