Skip to content

Commit

Permalink
Add NodeEncoder from the GraphMixer paper (#7501)
Browse files Browse the repository at this point in the history
Not very sure it's the correct way to implement what they did in
GraphMixer, but hopefully not too far off..

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
wwymak and rusty1s authored Jun 5, 2023
1 parent 0fb30b8 commit d63627c
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `GraphMixer` model ([#7501](https://github.com/pyg-team/pytorch_geometric/pull/7501))
- Added the `disable_dynamic_shape` experimental flag ([#7246](https://github.com/pyg-team/pytorch_geometric/pull/7246))
- Added the option to override `use_segmm` selection in `HeteroLinear` ([#7474](https://github.com/pyg-team/pytorch_geometric/pull/7474))
- Added the `MovieLens-1M` heterogeneous dataset ([#7479](https://github.com/pyg-team/pytorch_geometric/pull/7479))
Expand Down
26 changes: 26 additions & 0 deletions test/nn/models/test_graph_mixer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch

from torch_geometric.nn.models.graph_mixer import NodeEncoder


def test_node_encoder():
x = torch.arange(4, dtype=torch.float).view(-1, 1)
edge_index = torch.tensor([[1, 2, 0, 0, 1, 3], [0, 0, 1, 2, 2, 2]])
edge_time = torch.tensor([0, 1, 1, 1, 2, 3])
seed_time = torch.tensor([2, 2, 2, 2])

encoder = NodeEncoder(time_window=2)
assert str(encoder) == 'NodeEncoder(time_window=2)'

out = encoder(x, edge_index, edge_time, seed_time)
# Node 0 aggregates information from node 2 (excluding node 1).
# Node 1 aggregates information from node 0.
# Node 2 aggregates information from node 0 and node 1 (exluding node 3).
# Node 3 aggregates no information.
expected = torch.tensor([
[0 + 2],
[1 + 0],
[2 + 0.5 * (0 + 1)],
[3],
])
assert torch.allclose(out, expected)
49 changes: 49 additions & 0 deletions torch_geometric/nn/models/graph_mixer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch
from torch import Tensor

from torch_geometric.utils import scatter


class NodeEncoder(torch.nn.Module):
r"""The node encoder module from the `"Do We Really Need Complicated
Model Architectures for Temporal Networks?"
<https://openreview.net/forum?id=ayPPc0SyLv1>`_ paper.
:class:`NodeEncoder` captures the 1-hop temporal neighborhood information
via mean pooling.
.. math::
\mathbf{x}_v^{\prime}(t_0) = \mathbf{x}_v + \textrm{mean} \left\{
\mathbf{x}_w : w \in \mathcal{N}(v, t_0 - T, t_0) \right\}
Args:
time_window (int): The temporal window size :math:`T` to define the
1-hop temporal neighborhood.
"""
def __init__(self, time_window: int):
super().__init__()
self.time_window = time_window

def forward(
self,
x: Tensor,
edge_index: Tensor,
edge_time: Tensor,
seed_time: Tensor,
) -> Tensor:
r"""
Args:
x (torch.Tensor): The input node features.
edge_index (torch.Tensor): The edge indices.
edge_time (torch.Tensor): The timestamp attached to every edge.
seed_time (torch.Tensor): The seed time :math:`t_0` for every
destination node.
"""
mask = ((edge_time <= seed_time[edge_index[1]]) &
(edge_time > seed_time[edge_index[1]] - self.time_window))

src, dst = edge_index[:, mask]
mean = scatter(x[src], dst, dim=0, dim_size=x.size(0), reduce='mean')
return x + mean

def __repr__(self) -> str:
return f'{self.__class__.__name__}(time_window={self.time_window})'

0 comments on commit d63627c

Please sign in to comment.