Skip to content

Commit

Permalink
Check that local edge_label_index maps to global edge_label_index
Browse files Browse the repository at this point in the history
… in `LinkNeighborLoader` (#7943)

Related: #7900
  • Loading branch information
rusty1s authored Aug 28, 2023
1 parent 32a76cb commit ecc4b76
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions test/loader/test_link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_homo_link_neighbor_loader_basic(device, subgraph_type,
neg_edge_index = get_random_edge_index(50, 50, 500, device=device)
neg_edge_index += 50

edge_label_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
input_edges = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
edge_label = torch.cat([
torch.ones(500, device=device),
torch.zeros(500, device=device),
Expand All @@ -45,7 +45,7 @@ def test_homo_link_neighbor_loader_basic(device, subgraph_type,
data,
num_neighbors=[-1] * 2,
batch_size=20,
edge_label_index=edge_label_index,
edge_label_index=input_edges,
edge_label=edge_label if neg_sampling_ratio is None else None,
subgraph_type=subgraph_type,
neg_sampling_ratio=neg_sampling_ratio,
Expand All @@ -58,8 +58,8 @@ def test_homo_link_neighbor_loader_basic(device, subgraph_type,

batch = loader([0])
assert isinstance(batch, Data)
assert int(edge_label_index[0, 0]) in batch.n_id.tolist()
assert int(edge_label_index[1, 0]) in batch.n_id.tolist()
assert int(input_edges[0, 0]) in batch.n_id.tolist()
assert int(input_edges[1, 0]) in batch.n_id.tolist()

for batch in loader:
assert isinstance(batch, Data)
Expand Down Expand Up @@ -97,6 +97,14 @@ def test_homo_link_neighbor_loader_basic(device, subgraph_type,
assert torch.all(batch.edge_label[:20] == 1)
assert torch.all(batch.edge_label[20:] == 0)

# Ensure local `edge_label_index` correctly maps to input edges.
global_edge_label_index = batch.n_id[batch.edge_label_index]
global_edge_label_index = (
global_edge_label_index[:, batch.edge_label >= 1])
global_edge_label_index = unique_edge_pairs(global_edge_label_index)
assert (len(global_edge_label_index & unique_edge_pairs(input_edges))
== len(global_edge_label_index))


@onlyNeighborSampler
@pytest.mark.parametrize('subgraph_type', ['directional', 'bidirectional'])
Expand Down

0 comments on commit ecc4b76

Please sign in to comment.