diff --git a/CHANGELOG.md b/CHANGELOG.md index 63521df17363..e5c350e81fb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581)) ### Changed +- Fixed issue where one-hot tensors were passed to `F.one_hot` ([4970](https://github.com/pyg-team/pytorch_geometric/pull/4970)) - Fixed `bool` arugments in `argparse` in `benchmark/` ([#4967](https://github.com/pyg-team/pytorch_geometric/pull/4967)) - Fixed `BasicGNN` for `num_layers=1`, which now respects a desired number of `out_channels` ([#4943](https://github.com/pyg-team/pytorch_geometric/pull/4943)) - `len(batch)` will now return the number of graphs inside the batch, not the number of attributes ([#4931](https://github.com/pyg-team/pytorch_geometric/pull/4931)) diff --git a/torch_geometric/nn/models/correct_and_smooth.py b/torch_geometric/nn/models/correct_and_smooth.py index 66e9568e8557..a78c5c83fce3 100644 --- a/torch_geometric/nn/models/correct_and_smooth.py +++ b/torch_geometric/nn/models/correct_and_smooth.py @@ -92,7 +92,7 @@ def correct(self, y_soft: Tensor, y_true: Tensor, mask: Tensor, numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0) assert y_true.size(0) == numel - if y_true.dtype == torch.long: + if y_true.dtype == torch.long and y_true.size(0) == y_true.numel(): y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)) y_true = y_true.to(y_soft.dtype) @@ -125,7 +125,7 @@ def smooth(self, y_soft: Tensor, y_true: Tensor, mask: Tensor, numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0) assert y_true.size(0) == numel - if y_true.dtype == torch.long: + if y_true.dtype == torch.long and y_true.size(0) == y_true.numel(): y_true = F.one_hot(y_true.view(-1), y_soft.size(-1)) y_true = y_true.to(y_soft.dtype) diff --git a/torch_geometric/nn/models/label_prop.py b/torch_geometric/nn/models/label_prop.py index c5c9f43456da..fa8b732e5390 100644 --- a/torch_geometric/nn/models/label_prop.py +++ b/torch_geometric/nn/models/label_prop.py @@ -38,7 +38,7 @@ def forward( ) -> Tensor: """""" - if y.dtype == torch.long: + if y.dtype == torch.long and y.size(0) == y.numel(): y = F.one_hot(y.view(-1)).to(torch.float) out = y