diff --git a/tests/metrics/classification/inputs.py b/tests/metrics/classification/inputs.py
index 6cbcd755a264e9..de2ef04593b820 100644
--- a/tests/metrics/classification/inputs.py
+++ b/tests/metrics/classification/inputs.py
@@ -1,16 +1,12 @@
-import os
-import pytest
-import numpy as np
+from collections import namedtuple
+
 import torch
-from collections import namedtuple
 
 from tests.metrics.utils import (
     NUM_BATCHES,
-    NUM_PROCESSES,
     BATCH_SIZE,
     NUM_CLASSES,
-    EXTRA_DIM,
-    THRESHOLD
+    EXTRA_DIM
 )
 
 Input = namedtuple('Input', ["preds", "target"])
diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py
index 5adfb297b5eae0..9a0b582a2b923f 100644
--- a/tests/metrics/classification/test_accuracy.py
+++ b/tests/metrics/classification/test_accuracy.py
@@ -1,16 +1,9 @@
-import os
+import numpy as np
 import pytest
 import torch
-import numpy as np
-from collections import namedtuple
-from functools import partial
-
-from pytorch_lightning.metrics.classification.accuracy import Accuracy
 from sklearn.metrics import accuracy_score
 
-from tests.metrics.utils import compute_batch, setup_ddp
-from tests.metrics.utils import THRESHOLD
-
+from pytorch_lightning.metrics.classification.accuracy import Accuracy
 from tests.metrics.classification.inputs import (
     _testing_binary_prob_inputs,
     _testing_binary_inputs,
@@ -21,6 +14,8 @@
     _testing_multidim_multiclass_prob_inputs,
     _testing_multidim_multiclass_inputs,
 )
+from tests.metrics.utils import THRESHOLD
+from tests.metrics.utils import compute_batch
 
 torch.manual_seed(42)
 
@@ -32,7 +27,7 @@ def _sk_accuracy_binary_prob(preds, target):
     return accuracy_score(y_true=sk_target, y_pred=sk_preds)
 
 
-def _sk_accuracy_binary_sk_metric(preds, target):
+def _sk_accuracy(preds, target):
     sk_preds = preds.view(-1).numpy()
     sk_target = target.view(-1).numpy()
 
@@ -46,13 +41,6 @@ def _sk_accuracy_multilabel_prob(preds, target):
     return accuracy_score(y_true=sk_target, y_pred=sk_preds)
 
 
-def _sk_accuracy_multilabel(preds, target):
-    sk_preds = preds.view(-1).numpy()
-    sk_target = target.view(-1).numpy()
-
-    return accuracy_score(y_true=sk_target, y_pred=sk_preds)
-
-
 def _sk_accuracy_multiclass_prob(preds, target):
     sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
     sk_target = target.view(-1).numpy()
@@ -60,13 +48,6 @@ def _sk_accuracy_multiclass_prob(preds, target):
     return accuracy_score(y_true=sk_target, y_pred=sk_preds)
 
 
-def _sk_accuracy_multiclass(preds, target):
-    sk_preds = preds.view(-1).numpy()
-    sk_target = target.view(-1).numpy()
-
-    return accuracy_score(y_true=sk_target, y_pred=sk_preds)
-
-
 def _sk_accuracy_multidim_multiclass_prob(preds, target):
     sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
     sk_target = target.view(-1).numpy()
@@ -74,13 +55,6 @@ def _sk_accuracy_multidim_multiclass_prob(preds, target):
     return accuracy_score(y_true=sk_target, y_pred=sk_preds)
 
 
-def _sk_accuracy_multidim_multiclass(preds, target):
-    sk_preds = preds.view(-1).numpy()
-    sk_target = target.view(-1).numpy()
-
-    return accuracy_score(y_true=sk_target, y_pred=sk_preds)
-
-
 def test_accuracy_invalid_shape():
     with pytest.raises(ValueError):
         acc = Accuracy()
@@ -91,11 +65,11 @@ def test_accuracy_invalid_shape():
 @pytest.mark.parametrize("dist_sync_on_step", [True, False])
 @pytest.mark.parametrize("preds, target, sk_metric", [
     (_testing_binary_prob_inputs.preds, _testing_binary_prob_inputs.target, _sk_accuracy_binary_prob),
-    (_testing_binary_inputs.preds, _testing_binary_inputs.target, _sk_accuracy_binary_sk_metric),
+    (_testing_binary_inputs.preds, _testing_binary_inputs.target, _sk_accuracy),
     (_testing_multilabel_prob_inputs.preds, _testing_multilabel_prob_inputs.target, _sk_accuracy_multilabel_prob),
-    (_testing_multilabel_inputs.preds, _testing_multilabel_inputs.target, _sk_accuracy_multilabel),
+    (_testing_multilabel_inputs.preds, _testing_multilabel_inputs.target, _sk_accuracy),
     (_testing_multiclass_prob_inputs.preds, _testing_multiclass_prob_inputs.target, _sk_accuracy_multiclass_prob),
-    (_testing_multiclass_inputs.preds, _testing_multiclass_inputs.target, _sk_accuracy_multiclass),
+    (_testing_multiclass_inputs.preds, _testing_multiclass_inputs.target, _sk_accuracy),
     (
         _testing_multidim_multiclass_prob_inputs.preds,
         _testing_multidim_multiclass_prob_inputs.target,
@@ -104,7 +78,7 @@ def test_accuracy_invalid_shape():
     (
         _testing_multidim_multiclass_inputs.preds,
         _testing_multidim_multiclass_inputs.target,
-        _sk_accuracy_multidim_multiclass
+        _sk_accuracy
     )
 ])
 def test_accuracy(ddp, dist_sync_on_step, preds, target, sk_metric):
diff --git a/tests/metrics/classification/test_f_beta.py b/tests/metrics/classification/test_f_beta.py
index c7f70927445bda..f2113ff66127df 100644
--- a/tests/metrics/classification/test_f_beta.py
+++ b/tests/metrics/classification/test_f_beta.py
@@ -1,17 +1,12 @@
-import os
-import pytest
-import torch
-import numpy as np
-from collections import namedtuple
-
 from functools import partial
 
-from pytorch_lightning.metrics import Fbeta
+import numpy as np
+import pytest
+import torch
 from sklearn.metrics import fbeta_score
 
-from tests.metrics.utils import compute_batch, setup_ddp
-from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE, NUM_CLASSES, THRESHOLD
-
+from pytorch_lightning.metrics import Fbeta
 from tests.metrics.classification.inputs import (
     _testing_binary_prob_inputs,
     _testing_binary_inputs,
@@ -22,6 +17,8 @@
     _testing_multidim_multiclass_prob_inputs,
     _testing_multidim_multiclass_inputs,
 )
+from tests.metrics.utils import NUM_CLASSES, THRESHOLD
+from tests.metrics.utils import compute_batch
 
 torch.manual_seed(42)
 
@@ -75,13 +72,6 @@ def _sk_fbeta_multidim_multiclass_prob(preds, target, average='micro', beta=1.):
     return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta)
 
 
-def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.):
-    sk_preds = preds.view(-1).numpy()
-    sk_target = target.view(-1).numpy()
-
-    return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta)
-
-
 @pytest.mark.parametrize("ddp", [True, False])
 @pytest.mark.parametrize("dist_sync_on_step", [True, False])
 @pytest.mark.parametrize("average", ['micro', 'macro'])
@@ -102,7 +92,7 @@ def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.):
     (
         _testing_multidim_multiclass_inputs.preds,
         _testing_multidim_multiclass_inputs.target,
-        _sk_fbeta_multidim_multiclass,
+        _sk_fbeta_multiclass,
         NUM_CLASSES,
         False
     )
diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py
index cfc839cb1e603a..ff5ad65df1a594 100644
--- a/tests/metrics/classification/test_precision_recall.py
+++ b/tests/metrics/classification/test_precision_recall.py
@@ -1,17 +1,12 @@
-import os
-import pytest
-import torch
-import numpy as np
-from collections import namedtuple
-
 from functools import partial
 
-from pytorch_lightning.metrics import Precision, Recall
+import numpy as np
+import pytest
+import torch
 from sklearn.metrics import precision_score, recall_score
 
-from tests.metrics.utils import compute_batch, setup_ddp
-from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE, NUM_CLASSES, THRESHOLD
-
+from pytorch_lightning.metrics import Precision, Recall
 from tests.metrics.classification.inputs import (
     _testing_binary_prob_inputs,
     _testing_binary_inputs,
@@ -22,6 +17,8 @@
     _testing_multidim_multiclass_prob_inputs,
     _testing_multidim_multiclass_inputs,
 )
+from tests.metrics.utils import NUM_CLASSES, THRESHOLD
+from tests.metrics.utils import compute_batch
 
 torch.manual_seed(42)
 
@@ -75,13 +72,6 @@ def _sk_prec_recall_multidim_multiclass_prob(preds, target, sk_fn=precision_scor
     return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average)
 
 
-def _sk_prec_recall_multidim_multiclass(preds, target, sk_fn=precision_score, average='micro'):
-    sk_preds = preds.view(-1).numpy()
-    sk_target = target.view(-1).numpy()
-
-    return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average)
-
-
 @pytest.mark.parametrize("ddp", [True, False])
 @pytest.mark.parametrize("dist_sync_on_step", [True, False])
 @pytest.mark.parametrize("average", ['micro', 'macro'])
@@ -102,7 +92,7 @@ def _sk_prec_recall_multidim_multiclass(preds, target, sk_fn=precision_score, av
     (
         _testing_multidim_multiclass_inputs.preds,
         _testing_multidim_multiclass_inputs.target,
-        _sk_prec_recall_multidim_multiclass,
+        _sk_prec_recall_multiclass,
         NUM_CLASSES,
         False
    )
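Reviewer note: the deleted `_sk_accuracy_multilabel`, `_sk_accuracy_multiclass`, and `_sk_accuracy_multidim_multiclass` baselines were byte-for-byte identical flatten-and-compare helpers, which is why a single `_sk_accuracy` can serve every hard-label case; only the `*_prob` variants stay distinct, since they must argmax (or threshold) the probabilities first. A minimal standalone sketch of the consolidated helper, copied from the hunk above — the random tensors are hypothetical stand-ins for the `_testing_*` fixtures:

```python
import torch
from sklearn.metrics import accuracy_score


def _sk_accuracy(preds, target):
    # Flatten both label tensors; sklearn then scores element-wise matches,
    # which is the same computation regardless of the original shape.
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()
    return accuracy_score(y_true=sk_target, y_pred=sk_preds)


# Hypothetical stand-ins for the multidim-multiclass label fixtures:
preds = torch.randint(0, 4, (16, 5))
target = torch.randint(0, 4, (16, 5))
print(_sk_accuracy(preds, target))  # fraction of matching positions
```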