From 6122c76c4b02ed2df8bf33415a0913686b931c35 Mon Sep 17 00:00:00 2001 From: "sweep-ai[bot]" <128439645+sweep-ai[bot]@users.noreply.github.com> Date: Fri, 7 Jul 2023 18:14:25 +0000 Subject: [PATCH] Update tests/ignite/contrib/metrics/test_roc_curve.py --- tests/ignite/contrib/metrics/test_roc_curve.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/tests/ignite/contrib/metrics/test_roc_curve.py b/tests/ignite/contrib/metrics/test_roc_curve.py index 97bfbcd3097..7f776a855b1 100644 --- a/tests/ignite/contrib/metrics/test_roc_curve.py +++ b/tests/ignite/contrib/metrics/test_roc_curve.py @@ -72,17 +72,14 @@ def update_fn(engine, batch): return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) engine = Engine(update_fn) - roc_curve_metric = RocCurve(output_transform=lambda x: (x[1], x[2])) roc_curve_metric.attach(engine, "roc_curve") - data = list(range(size // batch_size)) - + fpr, tpr, thresholds = engine.run(data, max_epochs=1).metrics["roc_curve"] assert np.array_equal(fpr, sk_fpr) assert np.array_equal(tpr, sk_tpr) # assert thresholds almost equal, due to numpy->torch->numpy conversion np.testing.assert_array_almost_equal(thresholds, [np.inf, 1.0, 0.711, 0.047]) - np.testing.assert_array_almost_equal(thresholds, sk_thresholds) def test_integration_roc_curve_with_activated_output_transform(): @@ -105,18 +102,13 @@ def update_fn(engine, batch): return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) engine = Engine(update_fn) - roc_curve_metric = RocCurve(output_transform=lambda x: (torch.sigmoid(x[1]), x[2])) roc_curve_metric.attach(engine, "roc_curve") + data = list(range(size // batch_size)) fpr, tpr, thresholds = engine.run(data, max_epochs=1).metrics["roc_curve"] - - assert np.array_equal(fpr, sk_fpr) assert np.array_equal(tpr, sk_tpr) # assert thresholds almost equal, due to numpy->torch->numpy conversion np.testing.assert_array_almost_equal(thresholds, [np.inf, 1.0, 0.711, 0.047]) - assert np.array_equal(tpr, sk_tpr) - # assert thresholds almost equal, due to numpy->torch->numpy conversion - np.testing.assert_array_almost_equal(thresholds, sk_thresholds) def test_check_compute_fn():