diff --git a/CHANGELOG.md b/CHANGELOG.md index a4b6482a6b19..9710bf227e6a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,6 +77,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Fixed the shape of `edge_label_time` when using temporal sampling on homogeneous graphs ([#7807](https://github.com/pyg-team/pytorch_geometric/pull/7807)) - Made `FieldStatus` enum picklable to avoid `PicklingError` in a multi-process setting ([#7808](https://github.com/pyg-team/pytorch_geometric/pull/7808)) - Fixed `edge_label_index` computation in `LinkNeighborLoader` for the homogeneous+`disjoint` mode ([#7791](https://github.com/pyg-team/pytorch_geometric/pull/7791)) - Fixed `CaptumExplainer` for `binary_classification` tasks ([#7787](https://github.com/pyg-team/pytorch_geometric/pull/7787)) diff --git a/test/loader/test_link_neighbor_loader.py b/test/loader/test_link_neighbor_loader.py index 80cda9c23f32..406e347b1106 100644 --- a/test/loader/test_link_neighbor_loader.py +++ b/test/loader/test_link_neighbor_loader.py @@ -204,6 +204,38 @@ def test_link_neighbor_loader_edge_label(): assert torch.all(batch.edge_label[10:] == 0) +@withPackage('pyg_lib') +@pytest.mark.parametrize('batch_size', [1]) +def test_temporal_homo_link_neighbor_loader(batch_size): + data = Data( + x=torch.randn(10, 5), + edge_index=torch.randint(0, 10, (2, 123)), + time=torch.arange(10), + ) + + # Ensure that nodes exist at the time of the `edge_label_time`: + edge_label_time = torch.max( + data.time[data.edge_index[0]], + data.time[data.edge_index[1]], + ) + + loader = LinkNeighborLoader( + data, + num_neighbors=[-1], + time_attr='time', + edge_label=torch.ones(data.num_edges), + edge_label_time=edge_label_time, + batch_size=batch_size, + shuffle=True, + ) + + for batch in loader: + assert batch.edge_label_index.size() == (2, batch_size) + assert batch.edge_label_time.size() == (batch_size, ) + assert batch.edge_label.size() == (batch_size, ) + assert torch.all(batch.time <= batch.edge_label_time) + + @withPackage('pyg_lib') def test_temporal_hetero_link_neighbor_loader(): data = HeteroData() diff --git a/torch_geometric/sampler/neighbor_sampler.py b/torch_geometric/sampler/neighbor_sampler.py index 702c7810667c..24c1e2834c19 100644 --- a/torch_geometric/sampler/neighbor_sampler.py +++ b/torch_geometric/sampler/neighbor_sampler.py @@ -576,7 +576,7 @@ def edge_sample( else: edge_label_index = inverse_seed.view(2, -1) - out.metadata = (input_id, edge_label_index, edge_label, seed_time) + out.metadata = (input_id, edge_label_index, edge_label, src_time) elif neg_sampling.is_triplet(): if disjoint: