Skip to content

Commit

Permalink
Test for zero graph breaks on CUDA (pyg-team#7944)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored and erfanloghmani committed Aug 31, 2023
1 parent 4af12d1 commit 87df3fc
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions test/nn/models/test_basic_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,14 +311,19 @@ def test_trim_to_layer():
num_compile_calls = 0


@withCUDA
@onlyLinux
@disableExtensions
@withPackage('torch>=2.0.0')
@pytest.mark.parametrize('Model', [GCN, GraphSAGE, GIN, GAT, EdgeCNN, PNA])
@pytest.mark.skip(reason="Does not work yet in the full test suite")
def test_compile_graph_breaks(Model):
x = torch.randn(3, 8)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
def test_compile_graph_breaks(Model, device):
# TODO EdgeCNN and PNA currently lead to graph breaks on CUDA :(
if Model in {EdgeCNN, PNA} and device.type == 'cuda':
return

x = torch.randn(3, 8, device=device)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device)

kwargs = {}
if Model in {GCN, GAT}:
Expand All @@ -331,6 +336,7 @@ def test_compile_graph_breaks(Model):
kwargs['deg'] = torch.tensor([1, 2, 1])

model = Model(in_channels=8, hidden_channels=16, num_layers=2, **kwargs)
model = model.to(device)

def my_custom_backend(gm, *args):
global num_compile_calls
Expand Down

0 comments on commit 87df3fc

Please sign in to comment.