Skip to content

Commit

Permalink
[Code Coverage] models/test_correct_and_smooth.py (#6637)
Browse files Browse the repository at this point in the history
Add tests for `CorrectAndSmooth`
Minor documentation fix
Also minor fix of `smooth` in `CorrectAndSmooth` to avoid changing the
`y_soft` tensor in the `smooth` function

---------

Co-authored-by: Jinu Sunil <jinu.sunil@gmail.com>
  • Loading branch information
zechengz and wsad1 authored Feb 8, 2023
1 parent 41897d9 commit 98626c6
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613)
- Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640),[#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637))
- Fixed a bug in which `data.to_heterogeneous()` filtered attributs in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522))
- Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517))
- Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512))
Expand Down
35 changes: 35 additions & 0 deletions test/nn/models/test_correct_and_smooth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.nn.models import CorrectAndSmooth


def test_correct_and_smooth():
y_soft = torch.tensor([0.1, 0.5, 0.4]).repeat(6, 1)
y_true = torch.tensor([1, 0, 0, 2, 1, 1])
edge_index = torch.tensor([[0, 1, 1, 2, 4, 5], [1, 0, 2, 1, 5, 4]])
adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(6, 6))
mask = torch.randint(0, 2, (6, ), dtype=torch.bool)

model = CorrectAndSmooth(num_correction_layers=2, correction_alpha=0.5,
num_smoothing_layers=2, smoothing_alpha=0.5)
assert str(model) == ('CorrectAndSmooth(\n'
' correct: num_layers=2, alpha=0.5\n'
' smooth: num_layers=2, alpha=0.5\n'
' autoscale=True, scale=1.0\n'
')')

correct_out = model.correct(y_soft, y_true[mask], mask, edge_index)
assert correct_out.size() == (6, 3)
assert torch.allclose(correct_out,
model.correct(y_soft, y_true[mask], mask, adj))
smooth_out = model.smooth(y_soft, y_true[mask], mask, edge_index)
assert smooth_out.size() == (6, 3)
assert torch.allclose(smooth_out,
model.smooth(y_soft, y_true[mask], mask, adj))

# Test without autoscale:
model = CorrectAndSmooth(num_correction_layers=2, correction_alpha=0.5,
num_smoothing_layers=2, smoothing_alpha=0.5,
autoscale=False)
model.correct(y_soft, y_true[mask], mask, edge_index)
3 changes: 0 additions & 3 deletions test/nn/models/test_label_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

from torch_geometric.nn.models import LabelPropagation

num_layers = [0, 5]
alphas = [0, 0.5]


def test_label_prop():
y = torch.tensor([1, 0, 0, 2, 1, 1])
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/models/correct_and_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class CorrectAndSmooth(torch.nn.Module):
.. math::
\mathbf{\hat{Z}}^{(\ell)} = \alpha_2 \mathbf{D}^{-1/2}\mathbf{A}
\mathbf{D}^{-1/2} \mathbf{\hat{Z}}^{(\ell - 1)} +
(1 - \alpha_1) \mathbf{\hat{Z}}^{(\ell - 1)}
(1 - \alpha_2) \mathbf{\hat{Z}}^{(\ell - 1)}
to obtain the final prediction :math:`\mathbf{\hat{Z}}^{(L_2)}`.
Expand Down Expand Up @@ -128,6 +128,7 @@ def smooth(self, y_soft: Tensor, y_true: Tensor, mask: Tensor,
y_true = F.one_hot(y_true.view(-1), y_soft.size(-1))
y_true = y_true.to(y_soft.dtype)

y_soft = y_soft.clone()
y_soft[mask] = y_true

return self.prop2(y_soft, edge_index, edge_weight=edge_weight)
Expand Down

0 comments on commit 98626c6

Please sign in to comment.