Skip to content

Commit

Permalink
cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Oct 10, 2020
1 parent d812512 commit e09faae
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 76 deletions.
10 changes: 3 additions & 7 deletions tests/metrics/classification/inputs.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import os
import pytest
import numpy as np
from collections import namedtuple

import torch

from collections import namedtuple
from tests.metrics.utils import (
NUM_BATCHES,
NUM_PROCESSES,
BATCH_SIZE,
NUM_CLASSES,
EXTRA_DIM,
THRESHOLD
EXTRA_DIM
)

Input = namedtuple('Input', ["preds", "target"])
Expand Down
44 changes: 9 additions & 35 deletions tests/metrics/classification/test_accuracy.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
import os
import numpy as np
import pytest
import torch
import numpy as np
from collections import namedtuple
from functools import partial

from pytorch_lightning.metrics.classification.accuracy import Accuracy
from sklearn.metrics import accuracy_score

from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import THRESHOLD

from pytorch_lightning.metrics.classification.accuracy import Accuracy
from tests.metrics.classification.inputs import (
_testing_binary_prob_inputs,
_testing_binary_inputs,
Expand All @@ -21,6 +14,8 @@
_testing_multidim_multiclass_prob_inputs,
_testing_multidim_multiclass_inputs,
)
from tests.metrics.utils import THRESHOLD
from tests.metrics.utils import compute_batch

torch.manual_seed(42)

Expand All @@ -32,7 +27,7 @@ def _sk_accuracy_binary_prob(preds, target):
return accuracy_score(y_true=sk_target, y_pred=sk_preds)


def _sk_accuracy_binary_sk_metric(preds, target):
def _sk_accuracy(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

Expand All @@ -46,41 +41,20 @@ def _sk_accuracy_multilabel_prob(preds, target):
return accuracy_score(y_true=sk_target, y_pred=sk_preds)


def _sk_accuracy_multilabel(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

return accuracy_score(y_true=sk_target, y_pred=sk_preds)


def _sk_accuracy_multiclass_prob(preds, target):
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
sk_target = target.view(-1).numpy()

return accuracy_score(y_true=sk_target, y_pred=sk_preds)


def _sk_accuracy_multiclass(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

return accuracy_score(y_true=sk_target, y_pred=sk_preds)


def _sk_accuracy_multidim_multiclass_prob(preds, target):
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
sk_target = target.view(-1).numpy()

return accuracy_score(y_true=sk_target, y_pred=sk_preds)


def _sk_accuracy_multidim_multiclass(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

return accuracy_score(y_true=sk_target, y_pred=sk_preds)


def test_accuracy_invalid_shape():
with pytest.raises(ValueError):
acc = Accuracy()
Expand All @@ -91,11 +65,11 @@ def test_accuracy_invalid_shape():
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("preds, target, sk_metric", [
(_testing_binary_prob_inputs.preds, _testing_binary_prob_inputs.target, _sk_accuracy_binary_prob),
(_testing_binary_inputs.preds, _testing_binary_inputs.target, _sk_accuracy_binary_sk_metric),
(_testing_binary_inputs.preds, _testing_binary_inputs.target, _sk_accuracy),
(_testing_multilabel_prob_inputs.preds, _testing_multilabel_prob_inputs.target, _sk_accuracy_multilabel_prob),
(_testing_multilabel_inputs.preds, _testing_multilabel_inputs.target, _sk_accuracy_multilabel),
(_testing_multilabel_inputs.preds, _testing_multilabel_inputs.target, _sk_accuracy),
(_testing_multiclass_prob_inputs.preds, _testing_multiclass_prob_inputs.target, _sk_accuracy_multiclass_prob),
(_testing_multiclass_inputs.preds, _testing_multiclass_inputs.target, _sk_accuracy_multiclass),
(_testing_multiclass_inputs.preds, _testing_multiclass_inputs.target, _sk_accuracy),
(
_testing_multidim_multiclass_prob_inputs.preds,
_testing_multidim_multiclass_prob_inputs.target,
Expand All @@ -104,7 +78,7 @@ def test_accuracy_invalid_shape():
(
_testing_multidim_multiclass_inputs.preds,
_testing_multidim_multiclass_inputs.target,
_sk_accuracy_multidim_multiclass
_sk_accuracy
)
])
def test_accuracy(ddp, dist_sync_on_step, preds, target, sk_metric):
Expand Down
24 changes: 7 additions & 17 deletions tests/metrics/classification/test_f_beta.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
import os
import pytest
import torch
import numpy as np
from collections import namedtuple

from functools import partial

from pytorch_lightning.metrics import Fbeta
import numpy as np
import pytest
import torch
from sklearn.metrics import fbeta_score

from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE, NUM_CLASSES, THRESHOLD

from pytorch_lightning.metrics import Fbeta
from tests.metrics.classification.inputs import (
_testing_binary_prob_inputs,
_testing_binary_inputs,
Expand All @@ -22,6 +17,8 @@
_testing_multidim_multiclass_prob_inputs,
_testing_multidim_multiclass_inputs,
)
from tests.metrics.utils import NUM_CLASSES, THRESHOLD
from tests.metrics.utils import compute_batch

torch.manual_seed(42)

Expand Down Expand Up @@ -75,13 +72,6 @@ def _sk_fbeta_multidim_multiclass_prob(preds, target, average='micro', beta=1.):
return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta)


def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta)


@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("average", ['micro', 'macro'])
Expand All @@ -102,7 +92,7 @@ def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.):
(
_testing_multidim_multiclass_inputs.preds,
_testing_multidim_multiclass_inputs.target,
_sk_fbeta_multidim_multiclass,
_sk_fbeta_multiclass,
NUM_CLASSES,
False
)
Expand Down
24 changes: 7 additions & 17 deletions tests/metrics/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
import os
import pytest
import torch
import numpy as np
from collections import namedtuple

from functools import partial

from pytorch_lightning.metrics import Precision, Recall
import numpy as np
import pytest
import torch
from sklearn.metrics import precision_score, recall_score

from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE, NUM_CLASSES, THRESHOLD

from pytorch_lightning.metrics import Precision, Recall
from tests.metrics.classification.inputs import (
_testing_binary_prob_inputs,
_testing_binary_inputs,
Expand All @@ -22,6 +17,8 @@
_testing_multidim_multiclass_prob_inputs,
_testing_multidim_multiclass_inputs,
)
from tests.metrics.utils import NUM_CLASSES, THRESHOLD
from tests.metrics.utils import compute_batch

torch.manual_seed(42)

Expand Down Expand Up @@ -75,13 +72,6 @@ def _sk_prec_recall_multidim_multiclass_prob(preds, target, sk_fn=precision_scor
return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average)


def _sk_prec_recall_multidim_multiclass(preds, target, sk_fn=precision_score, average='micro'):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()

return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average)


@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("average", ['micro', 'macro'])
Expand All @@ -102,7 +92,7 @@ def _sk_prec_recall_multidim_multiclass(preds, target, sk_fn=precision_score, av
(
_testing_multidim_multiclass_inputs.preds,
_testing_multidim_multiclass_inputs.target,
_sk_prec_recall_multidim_multiclass,
_sk_prec_recall_multiclass,
NUM_CLASSES,
False
)
Expand Down

0 comments on commit e09faae

Please sign in to comment.