Add a `num_edges` parameter to the `HypergraphConv` layer.
piotrchmiel committed Jun 12, 2023
1 parent e4297b1 commit 097321a
Showing 2 changed files with 13 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -60,6 +60,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

+- Added a `num_edges` parameter to the forward method of `HypergraphConv` layer ([#7560](https://github.com/pyg-team/pytorch_geometric/pull/7560))
 - Fixed `get_mesh_laplacian` for `normalization="sym"` ([#7544](https://github.com/pyg-team/pytorch_geometric/pull/7544))
 - Use `dim_size` to initialize output size of the `EquilibriumAggregation` layer ([#7530](https://github.com/pyg-team/pytorch_geometric/pull/7530))
 - Added a `max_num_elements` parameter to the forward method of `GraphMultisetTransformer`, `GRUAggregation`, `LSTMAggregation` and `SetTransformerAggregation` ([#7529](https://github.com/pyg-team/pytorch_geometric/pull/7529))
16 changes: 12 additions & 4 deletions torch_geometric/nn/conv/hypergraph_conv.py
@@ -5,6 +5,7 @@
 from torch import Tensor
 from torch.nn import Parameter

+from torch_geometric.experimental import disable_dynamic_shapes
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.inits import glorot, zeros
@@ -107,9 +108,11 @@ def reset_parameters(self):
         glorot(self.att)
         zeros(self.bias)

+    @disable_dynamic_shapes(required_args=['num_edges'])
     def forward(self, x: Tensor, hyperedge_index: Tensor,
                 hyperedge_weight: Optional[Tensor] = None,
-                hyperedge_attr: Optional[Tensor] = None) -> Tensor:
+                hyperedge_attr: Optional[Tensor] = None,
+                num_edges: Optional[int] = None) -> Tensor:
         r"""Runs the forward pass of the module.

         Args:
@@ -125,10 +128,15 @@ def forward(self, x: Tensor, hyperedge_index: Tensor,
                 in :math:`\mathbb{R}^{M \times F}`.
                 These features only need to get passed in case
                 :obj:`use_attention=True`. (default: :obj:`None`)
+            num_edges (int, optional): The number of edges. (default: :obj:`None`)
         """
-        num_nodes, num_edges = x.size(0), 0
-        if hyperedge_index.numel() > 0:
-            num_edges = int(hyperedge_index[1].max()) + 1
+        num_nodes = x.size(0)
+
+        if num_edges is None:
+            if hyperedge_index.numel() > 0:
+                num_edges = int(hyperedge_index[1].max()) + 1
+            else:
+                num_edges = 0

         if hyperedge_weight is None:
             hyperedge_weight = x.new_ones(num_edges)
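The effect of the change: callers can now pin `num_edges` to a fixed value rather than always inferring it from `hyperedge_index`, which is what the `disable_dynamic_shapes` guard relies on to keep tensor shapes static. A minimal sketch of the new precedence logic in plain Python — `infer_num_edges` is a hypothetical helper for illustration, not part of the PyG API:

```python
def infer_num_edges(hyperedge_ids, num_edges=None):
    # Mirrors the updated HypergraphConv.forward logic: an explicit
    # num_edges takes precedence; otherwise it is inferred as
    # max(hyperedge id) + 1, falling back to 0 for an empty index.
    if num_edges is None:
        num_edges = max(hyperedge_ids) + 1 if hyperedge_ids else 0
    return num_edges

print(infer_num_edges([0, 0, 1, 2]))     # inferred from data: 3
print(infer_num_edges([0, 0, 1, 2], 8))  # explicit value wins: 8
print(infer_num_edges([]))               # empty index: 0
```

Passing an explicit `num_edges` keeps the hyperedge dimension identical across batches even when individual samples contain different numbers of hyperedges, avoiding recompilation on backends that specialize on shapes.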