Fix issue with uninitialized module in GAT+hetero+DDP scenario (#8397)
This PR fixes the error described in #8382.

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
DamianSzwichtenberg and rusty1s authored Nov 19, 2023
1 parent cf24b4b commit 24a661f
Showing 4 changed files with 41 additions and 13 deletions.
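For context, the scenario named in the commit title combines a `GATConv`-based model with `to_hetero` and `DistributedDataParallel`. Below is a minimal sketch of that kind of setup (hypothetical shapes and metadata; the actual reproducer lives in #8382, and the DDP wrapping is omitted for brevity):

import torch
from torch_geometric.nn import GATConv, to_hetero

class GAT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Lazy, possibly bipartite input sizes; self-loops are disabled
        # since bipartite edge types have no notion of self-loops:
        self.conv = GATConv((-1, -1), 32, add_self_loops=False)

    def forward(self, x, edge_index):
        return self.conv(x, edge_index)

# Hypothetical heterogeneous schema:
metadata = (['paper', 'author'], [('author', 'writes', 'paper')])
model = to_hetero(GAT(), metadata)
# In the reported scenario, `model` is additionally wrapped in
# `torch.nn.parallel.DistributedDataParallel`.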
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -37,6 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Breaking Change: Properly initialize modules in `GATConv` depending on whether the input is bipartite or non-bipartite ([#8397](https://github.com/pyg-team/pytorch_geometric/pull/8397))
 - Fixed `input_id` computation in `NeighborLoader` in case a `mask` is given ([#8312](https://github.com/pyg-team/pytorch_geometric/pull/8312))
 - Respect current device when deep-copying `Linear` layers ([#8311](https://github.com/pyg-team/pytorch_geometric/pull/8311))
 - Fixed `Data.subgraph()`/`HeteroData.subgraph()` in case `edge_index` is not defined ([#8277](https://github.com/pyg-team/pytorch_geometric/pull/8277))
9 changes: 6 additions & 3 deletions test/nn/conv/test_point_conv.py
@@ -53,12 +53,14 @@ def test_point_net_conv():
     assert out.size() == (2, 32)
     assert torch.allclose(conv((x1, None), (pos1, pos2), edge_index), out)
     assert torch.allclose(conv(x1, (pos1, pos2), adj1.t()), out, atol=1e-6)
-    assert torch.allclose(conv((x1, None), (pos1, pos2), adj1.t()), out)
+    assert torch.allclose(conv((x1, None), (pos1, pos2), adj1.t()), out,
+                          atol=1e-6)
 
     if torch_geometric.typing.WITH_TORCH_SPARSE:
         adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))
         assert torch.allclose(conv(x1, (pos1, pos2), adj2.t()), out, atol=1e-6)
-        assert torch.allclose(conv((x1, None), (pos1, pos2), adj2.t()), out)
+        assert torch.allclose(conv((x1, None), (pos1, pos2), adj2.t()), out,
+                              atol=1e-6)
 
     if is_full_test():
         t = '(PairOptTensor, PairTensor, Tensor) -> Tensor'
@@ -68,4 +70,5 @@ def test_point_net_conv():
     if is_full_test() and torch_geometric.typing.WITH_TORCH_SPARSE:
         t = '(PairOptTensor, PairTensor, SparseTensor) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit((x1, None), (pos1, pos2), adj2.t()), out)
+        assert torch.allclose(jit((x1, None), (pos1, pos2), adj2.t()), out,
+                              atol=1e-6)
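The only functional change in these tests is a looser floating-point comparison: `atol` absorbs the tiny numerical drift that differently ordered sparse aggregations can introduce. A small illustration of what the tolerance buys:

import torch

a = torch.tensor([1.0])
b = a + 1e-7
assert not torch.equal(a, b)            # bitwise comparison fails
assert torch.allclose(a, b, atol=1e-6)  # tolerance-based comparison passes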
2 changes: 1 addition & 1 deletion test/nn/dense/test_dense_gat_conv.py
@@ -14,7 +14,7 @@ def test_dense_gat_conv(heads, concat):
     assert str(dense_conv) == f'DenseGATConv(16, 16, heads={heads})'
 
     # Ensure same weights and bias:
-    dense_conv.lin = sparse_conv.lin_src
+    dense_conv.lin = sparse_conv.lin
     dense_conv.att_src = sparse_conv.att_src
     dense_conv.att_dst = sparse_conv.att_dst
     dense_conv.bias = sparse_conv.bias
42 changes: 33 additions & 9 deletions torch_geometric/nn/conv/gat_conv.py
@@ -149,10 +149,10 @@ def __init__(
 
         # In case we are operating in bipartite graphs, we apply separate
         # transformations 'lin_src' and 'lin_dst' to source and target nodes:
+        self.lin = self.lin_src = self.lin_dst = None
         if isinstance(in_channels, int):
-            self.lin_src = Linear(in_channels, heads * out_channels,
-                                  bias=False, weight_initializer='glorot')
-            self.lin_dst = self.lin_src
+            self.lin = Linear(in_channels, heads * out_channels, bias=False,
+                              weight_initializer='glorot')
         else:
             self.lin_src = Linear(in_channels[0], heads * out_channels, False,
                                   weight_initializer='glorot')
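With this change, exactly one of the two code paths materializes its linear layers, so the unused path no longer registers modules. A quick sketch of the resulting layout (assuming a build that includes this commit):

from torch_geometric.nn import GATConv

conv = GATConv(16, 32, heads=2)        # non-bipartite input
assert conv.lin is not None            # one shared transformation
assert conv.lin_src is None and conv.lin_dst is None

conv = GATConv((16, 8), 32, heads=2)   # bipartite input
assert conv.lin is None
assert conv.lin_src is not None and conv.lin_dst is not None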
@@ -182,8 +182,12 @@ def __init__(
 
     def reset_parameters(self):
         super().reset_parameters()
-        self.lin_src.reset_parameters()
-        self.lin_dst.reset_parameters()
+        if self.lin is not None:
+            self.lin.reset_parameters()
+        if self.lin_src is not None:
+            self.lin_src.reset_parameters()
+        if self.lin_dst is not None:
+            self.lin_dst.reset_parameters()
         if self.lin_edge is not None:
             self.lin_edge.reset_parameters()
         glorot(self.att_src)
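Since any of `lin`, `lin_src`, and `lin_dst` may now legitimately be `None`, `reset_parameters()` guards each branch. For instance (again assuming the patched module):

conv = GATConv((16, 8), 32)  # bipartite: `conv.lin` stays None
conv.reset_parameters()      # no AttributeError on the absent `lin`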
@@ -232,13 +236,33 @@ def forward(
         # transform source and target node features via separate weights:
         if isinstance(x, Tensor):
             assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
-            x_src = x_dst = self.lin_src(x).view(-1, H, C)
+
+            if self.lin is not None:
+                x_src = x_dst = self.lin(x).view(-1, H, C)
+            else:
+                # If the module is initialized as bipartite, transform source
+                # and destination node features separately:
+                assert self.lin_src is not None and self.lin_dst is not None
+                x_src = self.lin_src(x).view(-1, H, C)
+                x_dst = self.lin_dst(x).view(-1, H, C)
+
         else:  # Tuple of source and target node features:
             x_src, x_dst = x
             assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
-            x_src = self.lin_src(x_src).view(-1, H, C)
-            if x_dst is not None:
-                x_dst = self.lin_dst(x_dst).view(-1, H, C)
+
+            if self.lin is not None:
+                # If the module is initialized as non-bipartite, we expect
+                # that source and destination node features have the same
+                # shape and that their transformations are shared:
+                x_src = self.lin(x_src).view(-1, H, C)
+                if x_dst is not None:
+                    x_dst = self.lin(x_dst).view(-1, H, C)
+            else:
+                assert self.lin_src is not None and self.lin_dst is not None
+
+                x_src = self.lin_src(x_src).view(-1, H, C)
+                if x_dst is not None:
+                    x_dst = self.lin_dst(x_dst).view(-1, H, C)
 
         x = (x_src, x_dst)
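The rewritten `forward()` then dispatches on how the module was initialized. A usage sketch of both paths (hypothetical tensors and edge indices):

import torch
from torch_geometric.nn import GATConv

# Non-bipartite: a single `lin` transforms source and destination alike:
conv = GATConv(16, 32)
out = conv(torch.randn(3, 16), torch.tensor([[0, 1, 2], [1, 2, 0]]))  # [3, 32]

# Bipartite: separate `lin_src`/`lin_dst` for each feature set:
conv = GATConv((16, 8), 32, add_self_loops=False)
x = (torch.randn(3, 16), torch.randn(2, 8))
out = conv(x, torch.tensor([[0, 1, 2], [0, 1, 1]]))  # [2, 32]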
