From 097321aca5cac3455c10c368e21d803388da25d3 Mon Sep 17 00:00:00 2001
From: Piotr Chmiel
Date: Mon, 12 Jun 2023 13:19:24 +0100
Subject: [PATCH] Add a `num_edges` parameter to the `HypergraphConv` layer.

---
 CHANGELOG.md                               |  1 +
 torch_geometric/nn/conv/hypergraph_conv.py | 16 ++++++++++++----
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ae0f0b87c9c6..27d69fc21e35 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -60,6 +60,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Added a `num_edges` parameter to the forward method of `HypergraphConv` layer ([#7560](https://github.com/pyg-team/pytorch_geometric/pull/7560))
 - Fixed `get_mesh_laplacian` for `normalization="sym"` ([#7544](https://github.com/pyg-team/pytorch_geometric/pull/7544))
 - Use `dim_size` to initialize output size of the `EquilibriumAggregation` layer ([#7530](https://github.com/pyg-team/pytorch_geometric/pull/7530))
 - Added a `max_num_elements` parameter to the forward method of `GraphMultisetTransformer`, `GRUAggregation`, `LSTMAggregation` and `SetTransformerAggregation` ([#7529](https://github.com/pyg-team/pytorch_geometric/pull/7529))
diff --git a/torch_geometric/nn/conv/hypergraph_conv.py b/torch_geometric/nn/conv/hypergraph_conv.py
index de1c00180ed7..3894def04bc2 100644
--- a/torch_geometric/nn/conv/hypergraph_conv.py
+++ b/torch_geometric/nn/conv/hypergraph_conv.py
@@ -5,6 +5,7 @@
 from torch import Tensor
 from torch.nn import Parameter
 
+from torch_geometric.experimental import disable_dynamic_shapes
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.inits import glorot, zeros
@@ -107,9 +108,11 @@ def reset_parameters(self):
         glorot(self.att)
         zeros(self.bias)
 
+    @disable_dynamic_shapes(required_args=['num_edges'])
     def forward(self, x: Tensor, hyperedge_index: Tensor,
                 hyperedge_weight: Optional[Tensor] = None,
-                hyperedge_attr: Optional[Tensor] = None) -> Tensor:
+                hyperedge_attr: Optional[Tensor] = None,
+                num_edges: Optional[int] = None) -> Tensor:
         r"""Runs the forward pass of the module.
 
         Args:
@@ -125,10 +128,15 @@ def forward(self, x: Tensor, hyperedge_index: Tensor,
                 in :math:`\mathbb{R}^{M \times F}`.
                 These features only need to get passed in case
                 :obj:`use_attention=True`. (default: :obj:`None`)
+            num_edges (int, optional): The number of edges. (default: :obj:`None`)
         """
-        num_nodes, num_edges = x.size(0), 0
-        if hyperedge_index.numel() > 0:
-            num_edges = int(hyperedge_index[1].max()) + 1
+        num_nodes = x.size(0)
+
+        if num_edges is None:
+            if hyperedge_index.numel() > 0:
+                num_edges = int(hyperedge_index[1].max()) + 1
+            else:
+                num_edges = 0
 
         if hyperedge_weight is None:
             hyperedge_weight = x.new_ones(num_edges)
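
Usage note (not part of the patch): a minimal sketch of calling the patched forward() with an explicit `num_edges`. It assumes the patch above has been applied; the toy hypergraph, feature sizes, and hyperedge count are illustrative only.

    import torch
    from torch_geometric.nn import HypergraphConv

    # Toy hypergraph: 4 nodes, 2 hyperedges ({0, 1, 2} and {1, 2, 3}).
    # Row 0 holds node indices, row 1 the hyperedge each entry belongs to.
    x = torch.randn(4, 16)
    hyperedge_index = torch.tensor([[0, 1, 2, 1, 2, 3],
                                    [0, 0, 0, 1, 1, 1]])

    conv = HypergraphConv(16, 32)

    # Passing num_edges explicitly skips the data-dependent
    # `int(hyperedge_index[1].max()) + 1` computation inside forward(),
    # which is what `disable_dynamic_shapes` relies on.
    out = conv(x, hyperedge_index, num_edges=2)
    print(out.shape)  # torch.Size([4, 32])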