Skip to content

Commit

Permalink
Fix LinkNeighborLoader producing double-sized edge_label_time for…
Browse files Browse the repository at this point in the history
… homogeneous graphs (#7807)

Fixes `edge_label_time.size() == (2*batch_size,)` to have
`(batch_size,)`.
Adds a test case for #7791.
Part of #7796 and #6528.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored Jul 31, 2023
1 parent bc69d1a commit e8f752f
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Fixed the shape of `edge_label_time` when using temporal sampling on homogeneous graphs ([#7807](https://github.com/pyg-team/pytorch_geometric/pull/7807))
- Made `FieldStatus` enum picklable to avoid `PicklingError` in a multi-process setting ([#7808](https://github.com/pyg-team/pytorch_geometric/pull/7808))
- Fixed `edge_label_index` computation in `LinkNeighborLoader` for the homogeneous+`disjoint` mode ([#7791](https://github.com/pyg-team/pytorch_geometric/pull/7791))
- Fixed `CaptumExplainer` for `binary_classification` tasks ([#7787](https://github.com/pyg-team/pytorch_geometric/pull/7787))
Expand Down
32 changes: 32 additions & 0 deletions test/loader/test_link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,38 @@ def test_link_neighbor_loader_edge_label():
assert torch.all(batch.edge_label[10:] == 0)


@withPackage('pyg_lib')
@pytest.mark.parametrize('batch_size', [1])
def test_temporal_homo_link_neighbor_loader(batch_size):
data = Data(
x=torch.randn(10, 5),
edge_index=torch.randint(0, 10, (2, 123)),
time=torch.arange(10),
)

# Ensure that nodes exist at the time of the `edge_label_time`:
edge_label_time = torch.max(
data.time[data.edge_index[0]],
data.time[data.edge_index[1]],
)

loader = LinkNeighborLoader(
data,
num_neighbors=[-1],
time_attr='time',
edge_label=torch.ones(data.num_edges),
edge_label_time=edge_label_time,
batch_size=batch_size,
shuffle=True,
)

for batch in loader:
assert batch.edge_label_index.size() == (2, batch_size)
assert batch.edge_label_time.size() == (batch_size, )
assert batch.edge_label.size() == (batch_size, )
assert torch.all(batch.time <= batch.edge_label_time)


@withPackage('pyg_lib')
def test_temporal_hetero_link_neighbor_loader():
data = HeteroData()
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def edge_sample(
else:
edge_label_index = inverse_seed.view(2, -1)

out.metadata = (input_id, edge_label_index, edge_label, seed_time)
out.metadata = (input_id, edge_label_index, edge_label, src_time)

elif neg_sampling.is_triplet():
if disjoint:
Expand Down

0 comments on commit e8f752f

Please sign in to comment.