diff --git a/test/loader/test_link_neighbor_loader.py b/test/loader/test_link_neighbor_loader.py index 4da61d966cac..fcdb0ee6f425 100644 --- a/test/loader/test_link_neighbor_loader.py +++ b/test/loader/test_link_neighbor_loader.py @@ -224,8 +224,9 @@ def test_temporal_homo_link_neighbor_loader(): ) for sample in loader: - assert sample.edge_label_index.size() == (batch_size,) + assert sample.edge_label_index.size() == (2, batch_size) assert sample.edge_label_time.size() == (batch_size,) + assert sample.edge_label.size() == (batch_size,) assert torch.all( sample.edge_time <= sample.edge_label_time ), "The target time should be later than all timestamps in the subgraph"