Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Code Coverage] transforms/gcn_norm.py #6673

Merged
merged 12 commits into from
Feb 12, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613)
- Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668))
- Fixed a bug in which `data.to_heterogeneous()` filtered attributs in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522))
- Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517))
- Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512))
Expand Down
45 changes: 45 additions & 0 deletions test/transforms/test_gcn_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.data import Data
from torch_geometric.transforms import GCNNorm


def test_gcn_norm():
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_weight = torch.ones(edge_index.size(1))
adj_t = SparseTensor.from_edge_index(edge_index, edge_weight).t()

transform = GCNNorm()
assert str(transform) == 'GCNNorm(add_self_loops=True)'

expected_edge_index = [[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]]
expected_edge_weight = torch.tensor(
[0.4082, 0.4082, 0.4082, 0.4082, 0.5000, 0.3333, 0.5000])

data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3)
data = transform(data)
assert len(data) == 3
assert data.num_nodes == 3
assert data.edge_index.tolist() == expected_edge_index
assert torch.allclose(data.edge_weight, expected_edge_weight, atol=1e-4)

data = Data(edge_index=edge_index, num_nodes=3)
data = transform(data)
assert len(data) == 3
assert data.num_nodes == 3
assert data.edge_index.tolist() == expected_edge_index
assert torch.allclose(data.edge_weight, expected_edge_weight, atol=1e-4)

# For `SparseTensor`, expected outputs will be sorted:
expected_edge_index = [[0, 0, 1, 1, 1, 2, 2], [0, 1, 0, 1, 2, 1, 2]]
expected_edge_weight = torch.tensor(
[0.500, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000])

data = Data(adj_t=adj_t)
data = transform(data)
assert len(data) == 1
row, col, value = data.adj_t.coo()
assert row.tolist() == expected_edge_index[0]
assert col.tolist() == expected_edge_index[1]
assert torch.allclose(value, expected_edge_weight, atol=1e-4)
9 changes: 5 additions & 4 deletions torch_geometric/transforms/gcn_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ def __call__(self, data: Data) -> Data:
assert 'edge_index' in data or 'adj_t' in data

if 'edge_index' in data:
edge_weight = data.edge_attr
if 'edge_weight' in data:
edge_weight = data.edge_weight
data.edge_index, data.edge_weight = gcn_norm(
data.edge_index, edge_weight, data.num_nodes,
data.edge_index, data.edge_weight, data.num_nodes,
add_self_loops=self.add_self_loops)
else:
data.adj_t = gcn_norm(data.adj_t,
add_self_loops=self.add_self_loops)

return data

def __repr__(self) -> str:
return (f'{self.__class__.__name__}('
f'add_self_loops={self.add_self_loops})')