Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jun 5, 2022
1 parent 6327c22 commit 6fd8f43
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions torch_geometric/nn/aggr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,24 @@ def __call__(self, x: Tensor, index: Optional[Tensor] = None, *,
raise ValueError(f"Encountered invalid dimension '{dim}' of "
f"source tensor with {x.dim()} dimensions")

if (ptr is not None and dim_size is not None
and dim_size != ptr.numel() - 1):
raise ValueError(f"Encountered mismatch between 'dim_size' (got "
f"'{dim_size}') and 'ptr' (got '{ptr.size(0)}')")

if index is None and ptr is None:
index = x.new_zeros(x.size(dim), dtype=torch.long)

if dim_size is None and ptr is not None:
dim_size = ptr.numel() - 1
elif dim_size is None:
dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
if ptr is not None:
if dim_size is None:
dim_size = ptr.numel() - 1
elif dim_size != ptr.numel() - 1:
raise ValueError(f"Encountered invalid 'dim_size' (got "
f"'{dim_size}' but expected "
f"'{ptr.numel() - 1}')")

if index is not None:
if dim_size is None:
dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
elif index.numel() > 0 and dim_size <= int(index.max()):
raise ValueError(f"Encountered invalid 'dim_size' (got "
f"'{dim_size}' but expected "
f">= '{int(index.max()) + 1}')")

return super().__call__(x, index, ptr=ptr, dim_size=dim_size, dim=dim)

Expand Down

0 comments on commit 6fd8f43

Please sign in to comment.