From 24a661fde1c628822c4ea861a9f9d54dff1db9b4 Mon Sep 17 00:00:00 2001 From: Damian Szwichtenberg Date: Sun, 19 Nov 2023 09:15:40 +0100 Subject: [PATCH] Fix issue with uninitialized module in `GAT+hetero+DDP` scenario (#8397) This PR fixes the error described in #8382. --------- Co-authored-by: rusty1s --- CHANGELOG.md | 1 + test/nn/conv/test_point_conv.py | 9 ++++-- test/nn/dense/test_dense_gat_conv.py | 2 +- torch_geometric/nn/conv/gat_conv.py | 42 ++++++++++++++++++++++------ 4 files changed, 41 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 196f79ba6f69..4ecc97adb613 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Breaking Change: Properly initialize modules in `GATConv` depending on whether the input is bipartite or non-bipartite ([#8397](https://github.com/pyg-team/pytorch_geometric/pull/8397)) - Fixed `input_id` computation in `NeighborLoader` in case a `mask` is given ([#8312](https://github.com/pyg-team/pytorch_geometric/pull/8312)) - Respect current device when deep-copying `Linear` layers ([#8311](https://github.com/pyg-team/pytorch_geometric/pull/8311)) - Fixed `Data.subgraph()`/`HeteroData.subgraph()` in case `edge_index` is not defined ([#8277](https://github.com/pyg-team/pytorch_geometric/pull/8277)) diff --git a/test/nn/conv/test_point_conv.py b/test/nn/conv/test_point_conv.py index 84c0069ed0d5..496b2442b0cb 100644 --- a/test/nn/conv/test_point_conv.py +++ b/test/nn/conv/test_point_conv.py @@ -53,12 +53,14 @@ def test_point_net_conv(): assert out.size() == (2, 32) assert torch.allclose(conv((x1, None), (pos1, pos2), edge_index), out) assert torch.allclose(conv(x1, (pos1, pos2), adj1.t()), out, atol=1e-6) - assert torch.allclose(conv((x1, None), (pos1, pos2), adj1.t()), out) + assert torch.allclose(conv((x1, None), (pos1, pos2), adj1.t()), out, + atol=1e-6) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) assert torch.allclose(conv(x1, (pos1, pos2), adj2.t()), out, atol=1e-6) - assert torch.allclose(conv((x1, None), (pos1, pos2), adj2.t()), out) + assert torch.allclose(conv((x1, None), (pos1, pos2), adj2.t()), out, + atol=1e-6) if is_full_test(): t = '(PairOptTensor, PairTensor, Tensor) -> Tensor' @@ -68,4 +70,5 @@ def test_point_net_conv(): if is_full_test() and torch_geometric.typing.WITH_TORCH_SPARSE: t = '(PairOptTensor, PairTensor, SparseTensor) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert torch.allclose(jit((x1, None), (pos1, pos2), adj2.t()), out) + assert torch.allclose(jit((x1, None), (pos1, pos2), adj2.t()), out, + atol=1e-6) diff --git a/test/nn/dense/test_dense_gat_conv.py b/test/nn/dense/test_dense_gat_conv.py index d54ee28b496e..7c2c17dc8241 100644 --- a/test/nn/dense/test_dense_gat_conv.py +++ b/test/nn/dense/test_dense_gat_conv.py @@ -14,7 +14,7 @@ def test_dense_gat_conv(heads, concat): assert str(dense_conv) == f'DenseGATConv(16, 16, heads={heads})' # Ensure same weights and bias: - dense_conv.lin = sparse_conv.lin_src + dense_conv.lin = sparse_conv.lin dense_conv.att_src = sparse_conv.att_src dense_conv.att_dst = sparse_conv.att_dst dense_conv.bias = sparse_conv.bias diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index ad1d66359da4..3e4dcd69cc99 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -149,10 +149,10 @@ def __init__( # In case we are operating in 
bipartite graphs, we apply separate
         # transformations 'lin_src' and 'lin_dst' to source and target nodes:
+        self.lin = self.lin_src = self.lin_dst = None
         if isinstance(in_channels, int):
-            self.lin_src = Linear(in_channels, heads * out_channels,
-                                  bias=False, weight_initializer='glorot')
-            self.lin_dst = self.lin_src
+            self.lin = Linear(in_channels, heads * out_channels, bias=False,
+                              weight_initializer='glorot')
         else:
             self.lin_src = Linear(in_channels[0], heads * out_channels, False,
                                   weight_initializer='glorot')
@@ -182,8 +182,12 @@ def __init__(
 
     def reset_parameters(self):
         super().reset_parameters()
-        self.lin_src.reset_parameters()
-        self.lin_dst.reset_parameters()
+        if self.lin is not None:
+            self.lin.reset_parameters()
+        if self.lin_src is not None:
+            self.lin_src.reset_parameters()
+        if self.lin_dst is not None:
+            self.lin_dst.reset_parameters()
         if self.lin_edge is not None:
             self.lin_edge.reset_parameters()
         glorot(self.att_src)
@@ -232,13 +236,33 @@ def forward(
         # transform source and target node features via separate weights:
         if isinstance(x, Tensor):
             assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
-            x_src = x_dst = self.lin_src(x).view(-1, H, C)
+
+            if self.lin is not None:
+                x_src = x_dst = self.lin(x).view(-1, H, C)
+            else:
+                # If the module is initialized as bipartite, transform source
+                # and destination node features separately:
+                assert self.lin_src is not None and self.lin_dst is not None
+                x_src = self.lin_src(x).view(-1, H, C)
+                x_dst = self.lin_dst(x).view(-1, H, C)
+
         else:  # Tuple of source and target node features:
             x_src, x_dst = x
             assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
-            x_src = self.lin_src(x_src).view(-1, H, C)
-            if x_dst is not None:
-                x_dst = self.lin_dst(x_dst).view(-1, H, C)
+
+            if self.lin is not None:
+                # If the module is initialized as non-bipartite, we expect that
+                # source and destination node features have the same shape and
+                # that their transformations are shared:
+                x_src = self.lin(x_src).view(-1, H, C)
+                if x_dst is not None:
+                    x_dst = self.lin(x_dst).view(-1, H, C)
+            else:
+                assert self.lin_src is not None and self.lin_dst is not None
+
+                x_src = self.lin_src(x_src).view(-1, H, C)
+                if x_dst is not None:
+                    x_dst = self.lin_dst(x_dst).view(-1, H, C)
 
         x = (x_src, x_dst)
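
Illustrative usage (a minimal sketch, not part of the patch): the snippet below
exercises the two initialization paths introduced above, assuming a PyTorch
Geometric build that includes this fix. An integer `in_channels` now creates
only the shared `self.lin`, while a tuple creates only
`self.lin_src`/`self.lin_dst`; the unused attributes stay `None`, so wrappers
such as `torch.nn.parallel.DistributedDataParallel` no longer encounter an
uninitialized module in the `GAT+hetero+DDP` scenario.

    import torch
    from torch_geometric.nn import GATConv

    # Non-bipartite input: a single shared transformation `lin` is created.
    conv = GATConv(in_channels=16, out_channels=32, heads=2)
    assert conv.lin is not None
    assert conv.lin_src is None and conv.lin_dst is None

    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
    out = conv(x, edge_index)
    assert out.size() == (4, 2 * 32)  # heads * out_channels (concat=True)

    # Bipartite input: separate `lin_src`/`lin_dst`; `lin` stays `None`.
    conv = GATConv(in_channels=(16, 8), out_channels=32, heads=2)
    assert conv.lin is None
    assert conv.lin_src is not None and conv.lin_dst is not None

    x_src, x_dst = torch.randn(4, 16), torch.randn(2, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    out = conv((x_src, x_dst), edge_index)
    assert out.size() == (2, 2 * 32)  # one row per destination node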