Skip to content

Commit

Permalink
Add batch_size and max_num_nodes argument to MemPooling layer.
Browse files Browse the repository at this point in the history
It can be used to avoid additional calculations if a user is using
fixed-size batch.
  • Loading branch information
piotrchmiel committed Apr 26, 2023
1 parent 0c4ea3a commit 8647370
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug in `FastHGTConv` that computed values via parameters used to compute the keys ([#7050](https://github.com/pyg-team/pytorch_geometric/pull/7050))
- Accelerated sparse tensor conversion routines ([#7042](https://github.com/pyg-team/pytorch_geometric/pull/7042), [#7043](https://github.com/pyg-team/pytorch_geometric/pull/7043))
- Change `torch_sparse.SparseTensor` logic to utilize `torch.sparse_csr` instead ([#7041](https://github.com/pyg-team/pytorch_geometric/pull/7041))

- Added an optional `batch_size` and `max_num_nodes` arguments to `MemPooling` layer([#7239](https://github.com/pyg-team/pytorch_geometric/pull/7239))
### Removed

- Replaced `FastHGTConv` with `HGTConv` ([#7117](https://github.com/pyg-team/pytorch_geometric/pull/7117))
Expand Down
12 changes: 10 additions & 2 deletions torch_geometric/nn/pool/mem_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def kl_loss(S: Tensor) -> Tensor:
return loss(S.clamp(EPS).log(), P.clamp(EPS))

def forward(self, x: Tensor, batch: Optional[Tensor] = None,
mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
mask: Optional[Tensor] = None,
max_num_nodes: Optional[int] = None,
batch_size: Optional[int] = None) -> Tuple[Tensor, Tensor]:
r"""
Args:
x (torch.Tensor): The node feature tensor of shape
Expand All @@ -97,9 +99,15 @@ def forward(self, x: Tensor, batch: Optional[Tensor] = None,
node features of shape
:math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`.
(default: :obj:`None`)
max_num_nodes (int, optional): The size of the :math:`B` node
dimension. Automatically calculated if not given.
(default: :obj:`None`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
"""
if x.dim() <= 2:
x, mask = to_dense_batch(x, batch)
x, mask = to_dense_batch(x, batch, max_num_nodes=max_num_nodes,
batch_size=batch_size)
elif mask is None:
mask = x.new_ones((x.size(0), x.size(1)), dtype=torch.bool)

Expand Down

0 comments on commit 8647370

Please sign in to comment.