Skip to content

Commit

Permalink
Update tests/ignite/contrib/metrics/test_roc_curve.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-ai[bot] authored Jul 7, 2023
1 parent e0ab81b commit 6122c76
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions tests/ignite/contrib/metrics/test_roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,14 @@ def update_fn(engine, batch):
return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

engine = Engine(update_fn)

roc_curve_metric = RocCurve(output_transform=lambda x: (x[1], x[2]))
roc_curve_metric.attach(engine, "roc_curve")

data = list(range(size // batch_size))

fpr, tpr, thresholds = engine.run(data, max_epochs=1).metrics["roc_curve"]
assert np.array_equal(fpr, sk_fpr)
assert np.array_equal(tpr, sk_tpr)
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, [np.inf, 1.0, 0.711, 0.047])
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)


def test_integration_roc_curve_with_activated_output_transform():
Expand All @@ -105,18 +102,13 @@ def update_fn(engine, batch):
return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

engine = Engine(update_fn)

roc_curve_metric = RocCurve(output_transform=lambda x: (torch.sigmoid(x[1]), x[2]))
roc_curve_metric.attach(engine, "roc_curve")
data = list(range(size // batch_size))
fpr, tpr, thresholds = engine.run(data, max_epochs=1).metrics["roc_curve"]

assert np.array_equal(fpr, sk_fpr)
assert np.array_equal(tpr, sk_tpr)
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, [np.inf, 1.0, 0.711, 0.047])
assert np.array_equal(tpr, sk_tpr)
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)


def test_check_compute_fn():
Expand Down

0 comments on commit 6122c76

Please sign in to comment.