From da7be5d7c617c45c93cf76243009da786c0ae873 Mon Sep 17 00:00:00 2001
From: Kyle Whitecross <35904712+kpstesla@users.noreply.github.com>
Date: Tue, 12 Jul 2022 10:01:26 -0700
Subject: [PATCH] Stronger check for label dimensions in CorrectAndSmooth and
 LabelPropagation (#4970)

* check y.numel() to avoid calling F.one_hot on one-hot matrices

* changelog
---
 CHANGELOG.md                                    | 1 +
 torch_geometric/nn/models/correct_and_smooth.py | 4 ++--
 torch_geometric/nn/models/label_prop.py         | 2 +-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 63521df17363..e5c350e81fb8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -52,6 +52,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
 - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
 ### Changed
+- Fixed issue where one-hot tensors were passed to `F.one_hot` ([#4970](https://github.com/pyg-team/pytorch_geometric/pull/4970))
 - Fixed `bool` arugments in `argparse` in `benchmark/` ([#4967](https://github.com/pyg-team/pytorch_geometric/pull/4967))
 - Fixed `BasicGNN` for `num_layers=1`, which now respects a desired number of `out_channels` ([#4943](https://github.com/pyg-team/pytorch_geometric/pull/4943))
 - `len(batch)` will now return the number of graphs inside the batch, not the number of attributes ([#4931](https://github.com/pyg-team/pytorch_geometric/pull/4931))
diff --git a/torch_geometric/nn/models/correct_and_smooth.py b/torch_geometric/nn/models/correct_and_smooth.py
index 66e9568e8557..a78c5c83fce3 100644
--- a/torch_geometric/nn/models/correct_and_smooth.py
+++ b/torch_geometric/nn/models/correct_and_smooth.py
@@ -92,7 +92,7 @@ def correct(self, y_soft: Tensor, y_true: Tensor, mask: Tensor,
         numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)
         assert y_true.size(0) == numel

-        if y_true.dtype == torch.long:
+        if y_true.dtype == torch.long and y_true.size(0) == y_true.numel():
             y_true = F.one_hot(y_true.view(-1), y_soft.size(-1))
             y_true = y_true.to(y_soft.dtype)

@@ -125,7 +125,7 @@ def smooth(self, y_soft: Tensor, y_true: Tensor, mask: Tensor,
         numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)
         assert y_true.size(0) == numel

-        if y_true.dtype == torch.long:
+        if y_true.dtype == torch.long and y_true.size(0) == y_true.numel():
             y_true = F.one_hot(y_true.view(-1), y_soft.size(-1))
             y_true = y_true.to(y_soft.dtype)

diff --git a/torch_geometric/nn/models/label_prop.py b/torch_geometric/nn/models/label_prop.py
index c5c9f43456da..fa8b732e5390 100644
--- a/torch_geometric/nn/models/label_prop.py
+++ b/torch_geometric/nn/models/label_prop.py
@@ -38,7 +38,7 @@ def forward(
     ) -> Tensor:
         """"""

-        if y.dtype == torch.long:
+        if y.dtype == torch.long and y.size(0) == y.numel():
             y = F.one_hot(y.view(-1)).to(torch.float)

         out = y