Skip to content

Commit

Permalink
Update tests/ignite/contrib/metrics/test_roc_curve.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-ai[bot] committed Jul 7, 2023
1 parent b919851 commit 27d64fd
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions tests/ignite/contrib/metrics/test_roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_roc_curve():
assert np.array_equal(fpr, sk_fpr)
assert np.array_equal(tpr, sk_tpr)
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
np.testing.assert_array_almost_equal(thresholds, [np.inf, 1.0, 0.711, 0.047])


def test_integration_roc_curve_with_output_transform():
Expand Down Expand Up @@ -78,10 +78,12 @@ def update_fn(engine, batch):

data = list(range(size // batch_size))
fpr, tpr, thresholds = engine.run(data, max_epochs=1).metrics["roc_curve"]
fpr, tpr, thresholds = engine.run(data, max_epochs=1).metrics["roc_curve"]

assert np.array_equal(fpr, sk_fpr)
assert np.array_equal(tpr, sk_tpr)
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, [np.inf, 1.0, 0.711, 0.047])
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)


Expand All @@ -108,13 +110,14 @@ def update_fn(engine, batch):

roc_curve_metric = RocCurve(output_transform=lambda x: (torch.sigmoid(x[1]), x[2]))
roc_curve_metric.attach(engine, "roc_curve")

data = list(range(size // batch_size))
fpr, tpr, thresholds = engine.run(data, max_epochs=1).metrics["roc_curve"]

assert np.array_equal(fpr, sk_fpr)
assert np.array_equal(tpr, sk_tpr)
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, [np.inf, 1.0, 0.711, 0.047])
assert np.array_equal(tpr, sk_tpr)
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)


Expand Down Expand Up @@ -159,6 +162,11 @@ def update(engine, i):

fpr, tpr, thresholds = engine.state.metrics["roc_curve"]

assert isinstance(fpr, torch.Tensor) and fpr.device == device
assert isinstance(tpr, torch.Tensor) and tpr.device == device
assert isinstance(thresholds, torch.Tensor) and thresholds.device == device
fpr, tpr, thresholds = engine.state.metrics["roc_curve"]

assert isinstance(fpr, torch.Tensor) and fpr.device == device
assert isinstance(tpr, torch.Tensor) and tpr.device == device
assert isinstance(thresholds, torch.Tensor) and thresholds.device == device
Expand All @@ -169,4 +177,5 @@ def update(engine, i):

np.testing.assert_array_almost_equal(fpr.cpu().numpy(), sk_fpr)
np.testing.assert_array_almost_equal(tpr.cpu().numpy(), sk_tpr)
np.testing.assert_array_almost_equal(thresholds.cpu().numpy(), sk_thresholds)
np.testing.assert_array_almost_equal(thresholds.cpu().numpy(), [np.inf, 1.0, 0.711, 0.047])

0 comments on commit 27d64fd

Please sign in to comment.