Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Code Coverage] PNAConv and DegreeScalerAggregation #6600

Merged
merged 9 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600))
- Fixed a bug in which `data.to_heterogeneous()` filtered attributes in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522))
- Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517))
- Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512))
Expand Down
18 changes: 14 additions & 4 deletions test/nn/conv/test_pna_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
def test_pna_conv():
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
deg = torch.tensor([0, 3, 0, 1])
row, col = edge_index
value = torch.rand(row.size(0), 3)
adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))

conv = PNAConv(16, 32, aggregators, scalers,
deg=torch.tensor([0, 3, 0, 1]), edge_dim=3, towers=4)
assert conv.__repr__() == 'PNAConv(16, 32, towers=4, edge_dim=3)'
conv = PNAConv(16, 32, aggregators, scalers, deg=deg, edge_dim=3, towers=4,
pre_layers=2, post_layers=2)
assert str(conv) == 'PNAConv(16, 32, towers=4, edge_dim=3)'
out = conv(x, edge_index, value)
assert out.size() == (4, 32)
assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)
Expand All @@ -36,6 +36,16 @@ def test_pna_conv():
assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)


def test_pna_divide_input():
    """Exercise the ``divide_input=True`` and ``train_norm=True`` code paths
    of :class:`PNAConv` and check the output shape."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    # In-degree histogram of the training graph, consumed by the scalers.
    deg = torch.tensor([0, 3, 0, 1])
    conv = PNAConv(16, 32, aggregators, scalers, deg=deg, divide_input=True,
                   train_norm=True)
    out = conv(x, edge_index)
    assert out.size() == (4, 32)


def test_pna_conv_get_degree_histogram():
edge_index = torch.tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]])
data = Data(num_nodes=5, edge_index=edge_index)
Expand Down
7 changes: 4 additions & 3 deletions torch_geometric/nn/aggr/scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class DegreeScalerAggregation(Aggregation):
:obj:`"attenuation"`, :obj:`"linear"` and :obj:`"inverse_linear"`.
deg (Tensor): Histogram of in-degrees of nodes in the training set,
used by scalers to normalize.
train_norm (bool, optional) Whether normalization parameters
train_norm (bool, optional): Whether normalization parameters
are trainable. (default: :obj:`False`)
aggr_kwargs (Dict[str, Any], optional): Arguments passed to the
respective aggregation function in case it gets automatically
Expand Down Expand Up @@ -82,7 +82,7 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
out = self.aggr(x, index, ptr, dim_size, dim)

assert index is not None
deg = degree(index, num_nodes=dim_size, dtype=out.dtype).clamp_(1)
deg = degree(index, num_nodes=dim_size, dtype=out.dtype)
size = [1] * len(out.size())
size[dim] = -1
deg = deg.view(size)
Expand All @@ -98,7 +98,8 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
elif scaler == 'linear':
out_scaler = out * (deg / self.avg_deg_lin)
elif scaler == 'inverse_linear':
out_scaler = out * (self.avg_deg_lin / deg)
# Clamps minimum degree into one to avoid dividing by zero
out_scaler = out * (self.avg_deg_lin / deg.clamp(1))
else:
raise ValueError(f"Unknown scaler '{scaler}'")
outs.append(out_scaler)
Expand Down
20 changes: 10 additions & 10 deletions torch_geometric/nn/conv/pna_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
from torch import Tensor
from torch.nn import ModuleList, Sequential
from torch.utils.data import DataLoader

from torch_geometric.nn.aggr import DegreeScalerAggregation
from torch_geometric.nn.conv import MessagePassing
Expand Down Expand Up @@ -76,7 +77,7 @@ class PNAConv(MessagePassing):
act_kwargs (Dict[str, Any], optional): Arguments passed to the
respective activation function defined by :obj:`act`.
(default: :obj:`None`)
train_norm (bool, optional) Whether normalization parameters
train_norm (bool, optional): Whether normalization parameters
are trainable. (default: :obj:`False`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
Expand Down Expand Up @@ -192,17 +193,16 @@ def __repr__(self):
f'edge_dim={self.edge_dim})')

@staticmethod
def get_degree_histogram(loader: DataLoader) -> Tensor:
    """Return the in-degree histogram over all graphs in ``loader``.

    The histogram tensor is grown on the fly whenever a batch contains a
    node with a larger in-degree than any seen before, so the maximum
    degree does not need to be known in advance and the loader is only
    iterated once (instead of twice, as in the previous implementation).
    """
    deg_histogram = torch.zeros(1, dtype=torch.long)
    for data in loader:
        # In-degree of every node in this batch (edge_index[1] holds the
        # target node of each edge).
        d = degree(data.edge_index[1], num_nodes=data.num_nodes,
                   dtype=torch.long)
        d_bincount = torch.bincount(d, minlength=deg_histogram.numel())
        if d_bincount.size(0) > deg_histogram.size(0):
            # Batch introduced a new maximum degree: fold the running
            # histogram into the (larger) batch histogram and swap.
            d_bincount[:deg_histogram.size(0)] += deg_histogram
            deg_histogram = d_bincount
        else:
            # Accumulate in place; d_bincount never has fewer bins than
            # requested via `minlength`, so the slice is safe.
            deg_histogram[:d_bincount.size(0)] += d_bincount

    return deg_histogram