Skip to content

Commit

Permalink
[Code Coverage] transforms/gcn_norm.py (#6673)
Browse files Browse the repository at this point in the history
Co-authored-by: wsad1 <jinu.sunil@gmail.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored Feb 12, 2023
1 parent 2a80dae commit 105c0d8
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613)
- Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668))
- Fixed a bug in which `data.to_heterogeneous()` filtered attributs in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522))
- Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517))
- Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512))
Expand Down
45 changes: 45 additions & 0 deletions test/transforms/test_gcn_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.data import Data
from torch_geometric.transforms import GCNNorm


def test_gcn_norm():
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_weight = torch.ones(edge_index.size(1))
adj_t = SparseTensor.from_edge_index(edge_index, edge_weight).t()

transform = GCNNorm()
assert str(transform) == 'GCNNorm(add_self_loops=True)'

expected_edge_index = [[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]]
expected_edge_weight = torch.tensor(
[0.4082, 0.4082, 0.4082, 0.4082, 0.5000, 0.3333, 0.5000])

data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3)
data = transform(data)
assert len(data) == 3
assert data.num_nodes == 3
assert data.edge_index.tolist() == expected_edge_index
assert torch.allclose(data.edge_weight, expected_edge_weight, atol=1e-4)

data = Data(edge_index=edge_index, num_nodes=3)
data = transform(data)
assert len(data) == 3
assert data.num_nodes == 3
assert data.edge_index.tolist() == expected_edge_index
assert torch.allclose(data.edge_weight, expected_edge_weight, atol=1e-4)

# For `SparseTensor`, expected outputs will be sorted:
expected_edge_index = [[0, 0, 1, 1, 1, 2, 2], [0, 1, 0, 1, 2, 1, 2]]
expected_edge_weight = torch.tensor(
[0.500, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000])

data = Data(adj_t=adj_t)
data = transform(data)
assert len(data) == 1
row, col, value = data.adj_t.coo()
assert row.tolist() == expected_edge_index[0]
assert col.tolist() == expected_edge_index[1]
assert torch.allclose(value, expected_edge_weight, atol=1e-4)
9 changes: 5 additions & 4 deletions torch_geometric/transforms/gcn_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ def __call__(self, data: Data) -> Data:
assert 'edge_index' in data or 'adj_t' in data

if 'edge_index' in data:
edge_weight = data.edge_attr
if 'edge_weight' in data:
edge_weight = data.edge_weight
data.edge_index, data.edge_weight = gcn_norm(
data.edge_index, edge_weight, data.num_nodes,
data.edge_index, data.edge_weight, data.num_nodes,
add_self_loops=self.add_self_loops)
else:
data.adj_t = gcn_norm(data.adj_t,
add_self_loops=self.add_self_loops)

return data

def __repr__(self) -> str:
return (f'{self.__class__.__name__}('
f'add_self_loops={self.add_self_loops})')

0 comments on commit 105c0d8

Please sign in to comment.