Merge branch 'master' into cached-loader
rusty1s authored Aug 28, 2023
2 parents 5b81af4 + 6e6634a commit 014b2fa
Showing 2 changed files with 21 additions and 7 deletions.
16 changes: 12 additions & 4 deletions test/loader/test_link_neighbor_loader.py
@@ -29,7 +29,7 @@ def test_homo_link_neighbor_loader_basic(device, subgraph_type,
     neg_edge_index = get_random_edge_index(50, 50, 500, device=device)
     neg_edge_index += 50

-    edge_label_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
+    input_edges = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
     edge_label = torch.cat([
         torch.ones(500, device=device),
         torch.zeros(500, device=device),
@@ -45,7 +45,7 @@ def test_homo_link_neighbor_loader_basic(device, subgraph_type,
         data,
         num_neighbors=[-1] * 2,
         batch_size=20,
-        edge_label_index=edge_label_index,
+        edge_label_index=input_edges,
         edge_label=edge_label if neg_sampling_ratio is None else None,
         subgraph_type=subgraph_type,
         neg_sampling_ratio=neg_sampling_ratio,
@@ -58,8 +58,8 @@ def test_homo_link_neighbor_loader_basic(device, subgraph_type,

     batch = loader([0])
     assert isinstance(batch, Data)
-    assert int(edge_label_index[0, 0]) in batch.n_id.tolist()
-    assert int(edge_label_index[1, 0]) in batch.n_id.tolist()
+    assert int(input_edges[0, 0]) in batch.n_id.tolist()
+    assert int(input_edges[1, 0]) in batch.n_id.tolist()

     for batch in loader:
         assert isinstance(batch, Data)
@@ -97,6 +97,14 @@ def test_homo_link_neighbor_loader_basic(device, subgraph_type,
             assert torch.all(batch.edge_label[:20] == 1)
             assert torch.all(batch.edge_label[20:] == 0)

+        # Ensure local `edge_label_index` correctly maps to input edges.
+        global_edge_label_index = batch.n_id[batch.edge_label_index]
+        global_edge_label_index = (
+            global_edge_label_index[:, batch.edge_label >= 1])
+        global_edge_label_index = unique_edge_pairs(global_edge_label_index)
+        assert (len(global_edge_label_index & unique_edge_pairs(input_edges))
+                == len(global_edge_label_index))
+

 @onlyNeighborSampler
 @pytest.mark.parametrize('subgraph_type', ['directional', 'bidirectional'])
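The new assertion block relies on a `unique_edge_pairs` helper that is defined elsewhere in the test file and does not appear in this diff. As a rough sketch of what such a helper might look like (an assumption, not the repository's actual implementation), it only needs to turn a `[2, num_edges]` edge index into a set of `(src, dst)` tuples so that edge sets can be compared with `&` and `len`:

import torch
from torch import Tensor

def unique_edge_pairs(edge_index: Tensor) -> set:
    # Hypothetical helper: convert each column of the edge index into a
    # (src, dst) tuple so that edge sets can be intersected, as done in the
    # assertion above.
    return set(map(tuple, edge_index.t().tolist()))

edge_index = torch.tensor([[0, 1], [1, 0]])
assert unique_edge_pairs(edge_index) == {(0, 1), (1, 0)}
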
12 changes: 9 additions & 3 deletions test/nn/models/test_basic_gnn.py
@@ -311,14 +311,19 @@ def test_trim_to_layer():
 num_compile_calls = 0


+@withCUDA
 @onlyLinux
 @disableExtensions
 @withPackage('torch>=2.0.0')
 @pytest.mark.parametrize('Model', [GCN, GraphSAGE, GIN, GAT, EdgeCNN, PNA])
 @pytest.mark.skip(reason="Does not work yet in the full test suite")
-def test_compile_graph_breaks(Model):
-    x = torch.randn(3, 8)
-    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
+def test_compile_graph_breaks(Model, device):
+    # TODO EdgeCNN and PNA currently lead to graph breaks on CUDA :(
+    if Model in {EdgeCNN, PNA} and device.type == 'cuda':
+        return
+
+    x = torch.randn(3, 8, device=device)
+    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device)

     kwargs = {}
     if Model in {GCN, GAT}:
@@ -331,6 +336,7 @@ def test_compile_graph_breaks(Model):
         kwargs['deg'] = torch.tensor([1, 2, 1])

     model = Model(in_channels=8, hidden_channels=16, num_layers=2, **kwargs)
+    model = model.to(device)

     def my_custom_backend(gm, *args):
         global num_compile_calls
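The `my_custom_backend` counter at the end of this hunk is the core of the graph-break check; the rest of the test body sits outside the diff. As a hedged illustration of the underlying idea (standard `torch.compile` behavior only, not the test's exact code): a backend that counts its own invocations detects graph breaks, because Dynamo calls the backend once per captured graph and a graph break splits the model into several graphs.

import torch

num_compile_calls = 0

def my_custom_backend(gm, *args):
    # torch.compile invokes the backend once per captured graph, so a count
    # greater than one for a single forward pass indicates a graph break.
    global num_compile_calls
    num_compile_calls += 1
    return gm.forward

model = torch.compile(torch.nn.Linear(8, 16), backend=my_custom_backend)
model(torch.randn(3, 8))
assert num_compile_calls == 1  # a break-free model compiles to a single graph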
