diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b98d52a85be..c76b4ffbc469 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,7 +82,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613) - Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609)) -- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668)) +- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668)) - Fixed a bug in which `data.to_heterogeneous()` filtered attributs in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522)) - Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517)) - Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512)) diff --git a/test/transforms/test_gcn_norm.py b/test/transforms/test_gcn_norm.py new file mode 100644 index 000000000000..9eb2bd5debc5 --- /dev/null +++ b/test/transforms/test_gcn_norm.py @@ -0,0 +1,45 @@ +import torch +from torch_sparse import SparseTensor + +from torch_geometric.data import Data +from torch_geometric.transforms import GCNNorm + + +def test_gcn_norm(): + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_weight = torch.ones(edge_index.size(1)) + adj_t = SparseTensor.from_edge_index(edge_index, edge_weight).t() + + transform = GCNNorm() + assert str(transform) == 'GCNNorm(add_self_loops=True)' + + expected_edge_index = [[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]] + expected_edge_weight = torch.tensor( + [0.4082, 0.4082, 0.4082, 0.4082, 0.5000, 0.3333, 0.5000]) + + data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3) + data = transform(data) + assert len(data) == 3 + assert data.num_nodes == 3 + assert data.edge_index.tolist() == expected_edge_index + assert torch.allclose(data.edge_weight, expected_edge_weight, atol=1e-4) + + data = Data(edge_index=edge_index, num_nodes=3) + data = transform(data) + assert len(data) == 3 + assert data.num_nodes == 3 + assert data.edge_index.tolist() == expected_edge_index + assert torch.allclose(data.edge_weight, expected_edge_weight, atol=1e-4) + + # For `SparseTensor`, expected outputs will be sorted: + expected_edge_index = [[0, 0, 1, 1, 1, 2, 2], [0, 1, 0, 1, 2, 1, 2]] + expected_edge_weight = torch.tensor( + [0.500, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000]) + + data = Data(adj_t=adj_t) + data = transform(data) + assert len(data) == 1 + row, col, value = data.adj_t.coo() + assert row.tolist() == expected_edge_index[0] + assert col.tolist() == expected_edge_index[1] + assert torch.allclose(value, expected_edge_weight, atol=1e-4) diff --git a/torch_geometric/transforms/gcn_norm.py b/torch_geometric/transforms/gcn_norm.py index aa52a4f020aa..66ca9e59dac3 100644 --- a/torch_geometric/transforms/gcn_norm.py +++ b/torch_geometric/transforms/gcn_norm.py @@ -24,14 +24,15 @@ def __call__(self, data: Data) -> Data: assert 'edge_index' in data or 'adj_t' in data if 'edge_index' in data: - edge_weight = data.edge_attr - if 'edge_weight' in data: - edge_weight = data.edge_weight data.edge_index, data.edge_weight = gcn_norm( - data.edge_index, edge_weight, data.num_nodes, + data.edge_index, data.edge_weight, data.num_nodes, add_self_loops=self.add_self_loops) else: data.adj_t = gcn_norm(data.adj_t, add_self_loops=self.add_self_loops) return data + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(' + f'add_self_loops={self.add_self_loops})')