Skip to content

Commit

Permalink
Add a num_nodes parameter to the HypergraphConv layer.
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrchmiel committed Jun 12, 2023
1 parent e4297b1 commit f6c8a04
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions torch_geometric/nn/conv/hypergraph_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot, zeros
Expand Down Expand Up @@ -107,9 +108,11 @@ def reset_parameters(self):
glorot(self.att)
zeros(self.bias)

@disable_dynamic_shapes(required_args=['num_edges'])
def forward(self, x: Tensor, hyperedge_index: Tensor,
hyperedge_weight: Optional[Tensor] = None,
hyperedge_attr: Optional[Tensor] = None) -> Tensor:
hyperedge_attr: Optional[Tensor] = None,
num_edges: Optional[int] = None) -> Tensor:
r"""Runs the forward pass of the module.
Args:
Expand All @@ -125,10 +128,15 @@ def forward(self, x: Tensor, hyperedge_index: Tensor,
in :math:`\mathbb{R}^{M \times F}`.
These features only need to get passed in case
:obj:`use_attention=True`. (default: :obj:`None`)
num_edges (int, optional) : Number of edges. (default: :obj:`None`)
"""
num_nodes, num_edges = x.size(0), 0
if hyperedge_index.numel() > 0:
num_edges = int(hyperedge_index[1].max()) + 1
num_nodes = x.size(0)

if num_edges is None:
if hyperedge_index.numel() > 0:
num_edges = int(hyperedge_index[1].max()) + 1
else:
num_edges = 0

if hyperedge_weight is None:
hyperedge_weight = x.new_ones(num_edges)
Expand Down

0 comments on commit f6c8a04

Please sign in to comment.