diff --git a/torch_geometric/nn/conv/hypergraph_conv.py b/torch_geometric/nn/conv/hypergraph_conv.py index de1c00180ed75..3894def04bc2f 100644 --- a/torch_geometric/nn/conv/hypergraph_conv.py +++ b/torch_geometric/nn/conv/hypergraph_conv.py @@ -5,6 +5,7 @@ from torch import Tensor from torch.nn import Parameter +from torch_geometric.experimental import disable_dynamic_shapes from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import glorot, zeros @@ -107,9 +108,11 @@ def reset_parameters(self): glorot(self.att) zeros(self.bias) + @disable_dynamic_shapes(required_args=['num_edges']) def forward(self, x: Tensor, hyperedge_index: Tensor, hyperedge_weight: Optional[Tensor] = None, - hyperedge_attr: Optional[Tensor] = None) -> Tensor: + hyperedge_attr: Optional[Tensor] = None, + num_edges: Optional[int] = None) -> Tensor: r"""Runs the forward pass of the module. Args: @@ -125,10 +128,15 @@ def forward(self, x: Tensor, hyperedge_index: Tensor, in :math:`\mathbb{R}^{M \times F}`. These features only need to get passed in case :obj:`use_attention=True`. (default: :obj:`None`) + num_edges (int, optional) : Number of edges. (default: :obj:`None`) """ - num_nodes, num_edges = x.size(0), 0 - if hyperedge_index.numel() > 0: - num_edges = int(hyperedge_index[1].max()) + 1 + num_nodes = x.size(0) + + if num_edges is None: + if hyperedge_index.numel() > 0: + num_edges = int(hyperedge_index[1].max()) + 1 + else: + num_edges = 0 if hyperedge_weight is None: hyperedge_weight = x.new_ones(num_edges)