Commit

Fix a bug in `QuantileAggregation` with the `dim_size` parameter passed

Passing the `dim_size` parameter led to an `index out of range` error
during the `index_select` operation. Setting the `dim_size` and
`fill_value` parameters had not been tested; appropriate tests have been added.
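To see why the error occurs, here is a plain-Python sketch of the index arithmetic inside `QuantileAggregation` (a simplification for illustration; the real implementation uses `torch.bincount`, `torch.cumsum`, and `torch.index_select` on tensors, and the function name below is hypothetical):

```python
# Plain-Python sketch of QuantileAggregation's gather-position arithmetic.
# `index` maps each input row to a group; `dim_size` declares how many
# groups the output should have; `q` is the requested quantile in [0, 1].
def quantile_gather_points(index, dim_size, q):
    # Per-group element counts, padded out to dim_size groups.
    count = [0] * dim_size
    for i in index:
        count[i] += 1
    # Exclusive cumulative sum: starting row offset of each group.
    offsets, running = [], 0
    for c in count:
        offsets.append(running)
        running += c
    # Gather position per group: q * (count - 1) + offset.  For an empty
    # trailing group, the offset equals len(index), i.e. one past the last
    # valid row -- the source of the `index out of range` error.
    return [q * (c - 1) + s for c, s in zip(count, offsets)]

# Three rows, all in group 0, but dim_size=2 declares a second, empty group:
print(quantile_gather_points(index=[0, 0, 0], dim_size=2, q=0.0))
# [0.0, 3.0] -- position 3.0 is out of range for rows 0..2
```

With `dim_size=1` (the previously tested path) no empty trailing group exists, so the out-of-range position never arises.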
piotrchmiel committed May 22, 2023
1 parent 1de9dee commit 62a295d
Showing 3 changed files with 26 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -71,6 +71,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added an optional `batch_size` and `max_num_nodes` arguments to `MemPooling` layer ([#7239](https://github.com/pyg-team/pytorch_geometric/pull/7239))
- Fixed training issues of the GraphGPS example ([#7377](https://github.com/pyg-team/pytorch_geometric/pull/7377))
- Allowed `CaptumExplainer` to be called multiple times in a row ([#7391](https://github.com/pyg-team/pytorch_geometric/pull/7391))
- Fixed an `index out of range` bug in `QuantileAggregation` with the `dim_size` parameter passed ([#7407](https://github.com/pyg-team/pytorch_geometric/pull/7407))

### Removed

24 changes: 20 additions & 4 deletions test/nn/aggr/test_quantile.py
@@ -7,7 +7,9 @@
@pytest.mark.parametrize('q', [0., .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.])
@pytest.mark.parametrize('interpolation', QuantileAggregation.interpolations)
@pytest.mark.parametrize('dim', [0, 1])
-def test_quantile_aggregation(q, interpolation, dim):
+@pytest.mark.parametrize('use_dim_size', [True, False])
+@pytest.mark.parametrize('fill_value', [0.0, 10.0])
+def test_quantile_aggregation(q, interpolation, dim, use_dim_size, fill_value):
x = torch.tensor([
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
@@ -22,12 +24,26 @@ def test_quantile_aggregation(q, interpolation, dim):
])
index = torch.zeros(x.size(dim), dtype=torch.long)

-    aggr = QuantileAggregation(q=q, interpolation=interpolation)
+    aggr = QuantileAggregation(q=q, interpolation=interpolation,
+                               fill_value=fill_value)
assert str(aggr) == f"QuantileAggregation(q={q})"

-    out = aggr(x, index, dim=dim)
+    dim_size = None
+    if use_dim_size:
+        dim_size = x.size(dim)
+
+    out = aggr(x, index, dim=dim, dim_size=dim_size)
     expected = x.quantile(q, dim, interpolation=interpolation, keepdim=True)
-    assert torch.allclose(out, expected)

+    out_quantile = torch.index_select(out, dim,
+                                      torch.tensor(0, dtype=torch.long))
+    assert torch.allclose(out_quantile, expected)
+
+    if use_dim_size:
+        out_fill_value = torch.index_select(
+            out, dim, torch.tensor(range(1, dim_size), dtype=torch.long))
+        assert torch.allclose(out_fill_value,
+                              torch.tensor([fill_value], dtype=out.dtype))
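The assertions above encode the expected output layout: with `dim_size` groups but only group 0 present in `index`, row 0 holds the quantile and every remaining row holds `fill_value`. A plain-Python sketch of that layout (hypothetical helper, not the PyG API):

```python
# Sketch of the output layout the new test asserts: one quantile row,
# followed by (dim_size - 1) rows filled with fill_value.
def expected_rows(quantile_row, dim_size, fill_value):
    width = len(quantile_row)
    return [list(quantile_row)] + [[fill_value] * width
                                   for _ in range(dim_size - 1)]

print(expected_rows([1.0, 2.0, 3.0], dim_size=3, fill_value=10.0))
# [[1.0, 2.0, 3.0], [10.0, 10.0, 10.0], [10.0, 10.0, 10.0]]
```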


def test_median_aggregation():
5 changes: 5 additions & 0 deletions torch_geometric/nn/aggr/quantile.py
@@ -78,6 +78,11 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
count = torch.bincount(index, minlength=dim_size or 0)
cumsum = torch.cumsum(count, dim=0) - count

+        if dim_size is not None:
+            cumsum = torch.where(
+                cumsum >= torch.tensor(x.shape[dim], dtype=cumsum.dtype),
+                torch.tensor(x.shape[dim], dtype=cumsum.dtype) - 1, cumsum)

q_point = self.q * (count - 1) + cumsum
q_point = q_point.t().reshape(-1)
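The guard added above can be mimicked in plain Python (an assumed equivalent of the `torch.where` clamp, for illustration only): any group offset that reaches the number of input rows is pulled back to the last valid row, so the later gather can never read past the end of the input.

```python
# Plain-Python sketch of the fix: clamp empty-group offsets to the last
# valid row index instead of one past the end.
def clamp_offsets(offsets, num_rows):
    return [min(s, num_rows - 1) for s in offsets]

print(clamp_offsets([0, 3], num_rows=3))  # [0, 2]
```

The clamped offset still lands inside group 0's rows, but the corresponding output rows are subsequently overwritten with `fill_value`, so the stray value is never observed.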

