From 6fd8f43f85249a8872f797ae9ad7a6703118d3cd Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sun, 5 Jun 2022 08:47:59 +0200 Subject: [PATCH] update --- torch_geometric/nn/aggr/base.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py index e8c3a2c6a14c..ae24b3d9d11b 100644 --- a/torch_geometric/nn/aggr/base.py +++ b/torch_geometric/nn/aggr/base.py @@ -43,18 +43,24 @@ def __call__(self, x: Tensor, index: Optional[Tensor] = None, *, raise ValueError(f"Encountered invalid dimension '{dim}' of " f"source tensor with {x.dim()} dimensions") - if (ptr is not None and dim_size is not None - and dim_size != ptr.numel() - 1): - raise ValueError(f"Encountered mismatch between 'dim_size' (got " - f"'{dim_size}') and 'ptr' (got '{ptr.size(0)}')") - if index is None and ptr is None: index = x.new_zeros(x.size(dim), dtype=torch.long) - if dim_size is None and ptr is not None: - dim_size = ptr.numel() - 1 - elif dim_size is None: - dim_size = int(index.max()) + 1 if index.numel() > 0 else 0 + if ptr is not None: + if dim_size is None: + dim_size = ptr.numel() - 1 + elif dim_size != ptr.numel() - 1: + raise ValueError(f"Encountered invalid 'dim_size' (got " + f"'{dim_size}' but expected " + f"'{ptr.numel() - 1}')") + + if index is not None: + if dim_size is None: + dim_size = int(index.max()) + 1 if index.numel() > 0 else 0 + elif index.numel() > 0 and dim_size <= int(index.max()): + raise ValueError(f"Encountered invalid 'dim_size' (got " + f"'{dim_size}' but expected " + f">= '{int(index.max()) + 1}')") return super().__call__(x, index, ptr=ptr, dim_size=dim_size, dim=dim)