Skip to content

Commit

Permalink
Add new cat option to the aggr argument for HeteroConv (#6634)
Browse files Browse the repository at this point in the history
This new aggregation type will allow users to have the option of having
the features from different edge types be concatenated together during
the grouping stage of `HeteroConv`'s `forward` function.

---------

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
DylanSand and rusty1s authored Feb 7, 2023
1 parent ddda325 commit a5e5a48
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `cat` aggregation type to the `HeteroConv` class so that features can be concatenated during grouping ([#6634](https://github.com/pyg-team/pytorch_geometric/pull/6634))
- Added `torch.compile` support and benchmark study ([#6610](https://github.com/pyg-team/pytorch_geometric/pull/6610))
- Added the `AntiSymmetricConv` layer ([#6577](https://github.com/pyg-team/pytorch_geometric/pull/6577))
- Added a mixin for Huggingface model hub integration ([#5930](https://github.com/pyg-team/pytorch_geometric/pull/5930))
Expand Down
7 changes: 5 additions & 2 deletions test/nn/conv/test_hetero_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
return torch.stack([row, col], dim=0)


@pytest.mark.parametrize('aggr', ['sum', 'mean', 'min', 'max', None])
@pytest.mark.parametrize('aggr', ['sum', 'mean', 'min', 'max', 'cat', None])
def test_hetero_conv(aggr):
data = HeteroData()
data['paper'].x = torch.randn(50, 32)
Expand Down Expand Up @@ -49,7 +49,10 @@ def test_hetero_conv(aggr):
edge_weight_dict=data.edge_weight_dict)

assert len(out) == 2
if aggr is not None:
if aggr == 'cat':
assert out['paper'].size() == (50, 128)
assert out['author'].size() == (30, 64)
elif aggr is not None:
assert out['paper'].size() == (50, 64)
assert out['author'].size() == (30, 64)
else:
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/nn/conv/hetero_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class HeteroConv(Module):
aggr (str, optional): The aggregation scheme to use for grouping node
embeddings generated by different relations
(:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
:obj:`None`). (default: :obj:`"sum"`)
:obj:`"cat"`, :obj:`None`). (default: :obj:`"sum"`)
"""
def __init__(self, convs: Dict[EdgeType, Module],
aggr: Optional[str] = "sum"):
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/nn/conv/hgt_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
return torch.stack(xs, dim=1)
elif len(xs) == 1:
return xs[0]
elif aggr == "cat":
return torch.cat(xs, dim=-1)
else:
out = torch.stack(xs, dim=0)
out = getattr(torch, aggr)(out, dim=0)
Expand Down

0 comments on commit a5e5a48

Please sign in to comment.