Skip to content

Commit

Permalink
Test HANConv with empty tensors (#4756)
Browse files Browse the repository at this point in the history
* initial commit

* changelog
  • Loading branch information
rusty1s authored Jun 2, 2022
1 parent ef78db3 commit cf2010b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Test `HANConv` with empty tensors ([#4756](https://github.com/pyg-team/pytorch_geometric/pull/4756))
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
Expand Down
32 changes: 26 additions & 6 deletions test/nn/conv/test_han_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@


def test_han_conv():

x_dict = {
'author': torch.randn(6, 16),
'paper': torch.randn(5, 12),
Expand All @@ -16,8 +15,8 @@ def test_han_conv():
edge3 = torch.randint(0, 3, (2, 5), dtype=torch.long)
edge_index_dict = {
('author', 'metapath0', 'author'): edge1,
('paper', 'matapath1', 'paper'): edge2,
('paper', 'matapath2', 'paper'): edge3,
('paper', 'metapath1', 'paper'): edge2,
('paper', 'metapath2', 'paper'): edge3,
}

adj_t_dict = {}
Expand Down Expand Up @@ -57,16 +56,15 @@ def test_han_conv():


def test_han_conv_lazy():

x_dict = {
'author': torch.randn(6, 16),
'paper': torch.randn(5, 12),
}
edge1 = torch.randint(0, 6, (2, 8), dtype=torch.long)
edge2 = torch.randint(0, 5, (2, 6), dtype=torch.long)
edge_index_dict = {
('author', 'metapath0', 'author'): edge1,
('paper', 'metapath1', 'paper'): edge2,
('author', 'to', 'author'): edge1,
('paper', 'to', 'paper'): edge2,
}

adj_t_dict = {}
Expand All @@ -90,3 +88,25 @@ def test_han_conv_lazy():
for node_type in out_dict1.keys():
assert torch.allclose(out_dict1[node_type], out_dict2[node_type],
atol=1e-6)


def test_han_conv_empty_tensor():
x_dict = {
'author': torch.randn(6, 16),
'paper': torch.empty(0, 12),
}
edge_index_dict = {
('paper', 'to', 'author'): torch.empty((2, 0), dtype=torch.long),
('author', 'to', 'paper'): torch.empty((2, 0), dtype=torch.long),
('paper', 'to', 'paper'): torch.empty((2, 0), dtype=torch.long),
}

metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))
in_channels = {'author': 16, 'paper': 12}
conv = HANConv(in_channels, 16, metadata, heads=2)

out_dict = conv(x_dict, edge_index_dict)
assert len(out_dict) == 2
assert out_dict['author'].size() == (6, 16)
assert torch.all(out_dict['author'] == 0)
assert out_dict['paper'].size() == (0, 16)
5 changes: 2 additions & 3 deletions torch_geometric/nn/conv/han_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,8 @@ def forward(
x_node_dict, out_dict = {}, {}

# Iterate over node types:
for node_type, x_node in x_dict.items():
x_node_dict[node_type] = self.proj[node_type](x_node).view(
-1, H, D)
for node_type, x in x_dict.items():
x_node_dict[node_type] = self.proj[node_type](x).view(-1, H, D)
out_dict[node_type] = []

# Iterate over edge types:
Expand Down

0 comments on commit cf2010b

Please sign in to comment.