From 87df3fc0dbdd55a27527138608d6e4c4d9e05d89 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Mon, 28 Aug 2023 13:40:31 +0200 Subject: [PATCH] Test for zero graph breaks on CUDA (#7944) --- test/nn/models/test_basic_gnn.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test/nn/models/test_basic_gnn.py b/test/nn/models/test_basic_gnn.py index 0b98b57ca32c4..4f91704be54ca 100644 --- a/test/nn/models/test_basic_gnn.py +++ b/test/nn/models/test_basic_gnn.py @@ -311,14 +311,19 @@ def test_trim_to_layer(): num_compile_calls = 0 +@withCUDA @onlyLinux @disableExtensions @withPackage('torch>=2.0.0') @pytest.mark.parametrize('Model', [GCN, GraphSAGE, GIN, GAT, EdgeCNN, PNA]) @pytest.mark.skip(reason="Does not work yet in the full test suite") -def test_compile_graph_breaks(Model): - x = torch.randn(3, 8) - edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) +def test_compile_graph_breaks(Model, device): + # TODO EdgeCNN and PNA currently lead to graph breaks on CUDA :( + if Model in {EdgeCNN, PNA} and device.type == 'cuda': + return + + x = torch.randn(3, 8, device=device) + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device) kwargs = {} if Model in {GCN, GAT}: @@ -331,6 +336,7 @@ def test_compile_graph_breaks(Model): kwargs['deg'] = torch.tensor([1, 2, 1]) model = Model(in_channels=8, hidden_channels=16, num_layers=2, **kwargs) + model = model.to(device) def my_custom_backend(gm, *args): global num_compile_calls