Skip to content

Commit

Permalink
Stronger check for label dimensions in CorrectAndSmooth and LabelProp…
Browse files Browse the repository at this point in the history
…agation (#4970)

* check y.numel() to avoid calling F.one_hot on one-hot matrices

* changelog
  • Loading branch information
kswhitecross authored Jul 12, 2022
1 parent 38befa3 commit da7be5d
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
### Changed
- Fixed issue where one-hot tensors were passed to `F.one_hot` ([4970](https://github.com/pyg-team/pytorch_geometric/pull/4970))
- Fixed `bool` arugments in `argparse` in `benchmark/` ([#4967](https://github.com/pyg-team/pytorch_geometric/pull/4967))
- Fixed `BasicGNN` for `num_layers=1`, which now respects a desired number of `out_channels` ([#4943](https://github.com/pyg-team/pytorch_geometric/pull/4943))
- `len(batch)` will now return the number of graphs inside the batch, not the number of attributes ([#4931](https://github.com/pyg-team/pytorch_geometric/pull/4931))
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/nn/models/correct_and_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def correct(self, y_soft: Tensor, y_true: Tensor, mask: Tensor,
numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)
assert y_true.size(0) == numel

if y_true.dtype == torch.long:
if y_true.dtype == torch.long and y_true.size(0) == y_true.numel():
y_true = F.one_hot(y_true.view(-1), y_soft.size(-1))
y_true = y_true.to(y_soft.dtype)

Expand Down Expand Up @@ -125,7 +125,7 @@ def smooth(self, y_soft: Tensor, y_true: Tensor, mask: Tensor,
numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)
assert y_true.size(0) == numel

if y_true.dtype == torch.long:
if y_true.dtype == torch.long and y_true.size(0) == y_true.numel():
y_true = F.one_hot(y_true.view(-1), y_soft.size(-1))
y_true = y_true.to(y_soft.dtype)

Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/nn/models/label_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def forward(
) -> Tensor:
""""""

if y.dtype == torch.long:
if y.dtype == torch.long and y.size(0) == y.numel():
y = F.one_hot(y.view(-1)).to(torch.float)

out = y
Expand Down

0 comments on commit da7be5d

Please sign in to comment.