Commit

Add max_num_nodes parameter to the constructor of `EquilibriumAggregation` layer (#7530)

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
piotrchmiel and rusty1s authored Jun 7, 2023
1 parent e37d0ec commit 44ebf0d
Showing 3 changed files with 6 additions and 21 deletions.
1 change: 1 addition & 0 deletions — CHANGELOG.md

```diff
@@ -59,6 +59,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Use `dim_size` to initialize output size of the `EquilibriumAggregation` layer ([#7530](https://github.com/pyg-team/pytorch_geometric/pull/7530))
 - Added a `max_num_elements` parameter to the forward method of `GraphMultisetTransformer`, `GRUAggregation`, `LSTMAggregation` and `SetTransformerAggregation` ([#7529](https://github.com/pyg-team/pytorch_geometric/pull/7529))
 - Fixed empty edge indices handling in `SparseTensor` ([#7519](https://github.com/pyg-team/pytorch_geometric/pull/7519))
 - Move the `scaler` tensor in `GeneralConv` to the correct device ([#7484](https://github.com/pyg-team/pytorch_geometric/pull/7484))
```
6 changes: 0 additions & 6 deletions — test/nn/aggr/test_equilibrium.py

```diff
@@ -18,9 +18,6 @@ def test_equilibrium(iter, alpha):
     out = model(x)
     assert out.size() == (1, 2)
 
-    with pytest.raises(ValueError):
-        model(x, dim_size=0)
-
     out = model(x, dim_size=3)
     assert out.size() == (3, 2)
     assert torch.all(out[1:, :] == 0)
@@ -43,9 +40,6 @@ def test_equilibrium_batch(iter, alpha):
     out = model(x, batch)
     assert out.size() == (2, 2)
 
-    with pytest.raises(ValueError):
-        model(x, dim_size=0)
-
     out = model(x, dim_size=3)
     assert out.size() == (3, 2)
     assert torch.all(out[1:, :] == 0)
```
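The behavior these tests exercise can be mirrored with a minimal pure-Python stand-in (no `torch` or `torch_geometric` required). The function `mean_aggregate` and its use of plain lists are assumptions for illustration, not the library's API; only the `dim_size` semantics (output sized by `dim_size`, rows past the last group left at zero) follow the patch:

```python
from typing import List, Optional


def mean_aggregate(x: List[List[float]], index: Optional[List[int]] = None,
                   dim_size: Optional[int] = None) -> List[List[float]]:
    """Hypothetical stand-in: mean-pool rows of `x` per `index` group,
    allocating the output at `dim_size` rows up front."""
    if index is None:
        index = [0] * len(x)  # everything falls into a single group
    if dim_size is None:
        dim_size = max(index) + 1  # same default the patch introduces
    dim = len(x[0])
    out = [[0.0] * dim for _ in range(dim_size)]
    counts = [0] * dim_size
    for row, i in zip(x, index):
        counts[i] += 1
        for j, v in enumerate(row):
            out[i][j] += v
    for i in range(dim_size):
        if counts[i]:  # empty groups (padding rows) stay zero
            out[i] = [v / counts[i] for v in out[i]]
    return out


x = [[1.0, 2.0], [3.0, 4.0]]
out = mean_aggregate(x, dim_size=3)  # both rows land in group 0
assert len(out) == 3
assert all(v == 0.0 for row in out[1:] for v in row)  # padding rows are zero
```

Like `model(x, dim_size=3)` in the updated tests, requesting more rows than `index` implies simply yields zero rows at the end, with no `ValueError` path.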
20 changes: 5 additions & 15 deletions — torch_geometric/nn/aggr/equilibrium.py

```diff
@@ -146,9 +146,8 @@ def reset_parameters(self):
         reset(self.optimizer)
         reset(self.potential)
 
-    def init_output(self, index: Optional[Tensor] = None) -> Tensor:
-        index_size = 1 if index is None else int(index.max().item() + 1)
-        return torch.zeros(index_size, self.output_dim, requires_grad=True,
+    def init_output(self, dim_size: int) -> Tensor:
+        return torch.zeros(dim_size, self.output_dim, requires_grad=True,
                            device=self.lamb.device).float()
 
     def reg(self, y: Tensor) -> Tensor:
@@ -163,20 +162,11 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None,
 
         self.assert_index_present(index)
 
-        index_size = 1 if index is None else index.max() + 1
-        dim_size = index_size if dim_size is None else dim_size
-
-        if dim_size < index_size:
-            raise ValueError("`dim_size` is less than `index` "
-                             "implied size")
+        dim_size = int(index.max()) + 1 if dim_size is None else dim_size
 
         with torch.enable_grad():
-            y = self.optimizer(x, self.init_output(index), index, self.energy,
-                               iterations=self.grad_iter)
-
-        if dim_size > index_size:
-            zero = y.new_zeros(dim_size - index_size, *y.size()[1:])
-            y = torch.cat([y, zero])
+            y = self.optimizer(x, self.init_output(dim_size), index,
+                               self.energy, iterations=self.grad_iter)
 
         return y
```
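The net effect of the patch is that `dim_size` now flows straight into the output allocation, instead of the output being allocated at the `index`-implied size and then zero-padded (or rejected with a `ValueError`) afterwards. A minimal sketch of the new resolution rule, using plain Python instead of tensors and a hypothetical helper name:

```python
from typing import List, Optional


def resolve_dim_size(index: List[int], dim_size: Optional[int] = None) -> int:
    # New behavior: default to index.max() + 1, otherwise trust the caller.
    # The old `dim_size < index_size` ValueError check is gone entirely.
    return max(index) + 1 if dim_size is None else dim_size


assert resolve_dim_size([0, 0, 1]) == 2      # inferred from index
assert resolve_dim_size([0, 0, 1], 5) == 5   # caller-provided size wins
```

Passing the resolved size into `init_output` means the optimizer iterates over the full-size output from the start, so the separate `torch.cat` padding step is no longer needed.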
