Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the TemporalEncoding module #6785

Merged
merged 4 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `pip` caching in CI ([#6462](https://github.com/pyg-team/pytorch_geometric/pull/6462))
- Added the `TemporalEncoding` module ([#6785](https://github.com/pyg-team/pytorch_geometric/pull/6785))
- Added CPU-optimized `spmm_reduce` functionality via CSR format ([#6699](https://github.com/pyg-team/pytorch_geometric/pull/6699), [#6759](https://github.com/pyg-team/pytorch_geometric/pull/6759))
- Added support for the revised version of the `MD17` dataset ([#6734](https://github.com/pyg-team/pytorch_geometric/pull/6734))
- Added TorchScript support to the `RECT_L` model ([#6727](https://github.com/pyg-team/pytorch_geometric/pull/6727))
Expand Down
19 changes: 15 additions & 4 deletions test/nn/test_encoding.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import torch

from torch_geometric.nn import PositionalEncoding
from torch_geometric.nn import PositionalEncoding, TemporalEncoding
from torch_geometric.testing import withCUDA


def test_positional_encoding():
encoder = PositionalEncoding(64)
@withCUDA
def test_positional_encoding(device):
    # Smoke test: construction, repr, and output shape on each available device.
    model = PositionalEncoding(64).to(device)
    assert str(model) == 'PositionalEncoding(64)'

    inputs = torch.tensor([1.0, 2.0, 3.0], device=device)
    out = model(inputs)
    assert out.size() == (3, 64)


@withCUDA
def test_temporal_encoding(device):
    # Smoke test: construction, repr, and output shape on each available device.
    model = TemporalEncoding(64).to(device)
    assert str(model) == 'TemporalEncoding(64)'

    inputs = torch.tensor([1.0, 2.0, 3.0], device=device)
    out = model(inputs)
    assert out.size() == (3, 64)
3 changes: 2 additions & 1 deletion torch_geometric/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .to_hetero_transformer import to_hetero
from .to_hetero_with_bases_transformer import to_hetero_with_bases
from .to_fixed_size_transformer import to_fixed_size
from .encoding import PositionalEncoding
from .encoding import PositionalEncoding, TemporalEncoding
from .model_hub import PyGModelHubMixin
from .summary import summary

Expand All @@ -27,6 +27,7 @@
'to_hetero_with_bases',
'to_fixed_size',
'PositionalEncoding',
'TemporalEncoding',
'PyGModelHubMixin',
'summary',
]
42 changes: 41 additions & 1 deletion torch_geometric/nn/encoding.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import math

import torch
from torch import Tensor


class PositionalEncoding(torch.nn.Module):
r"""The positional encoding scheme from `"Attention Is All You Need"
r"""The positional encoding scheme from the `"Attention Is All You Need"
<https://arxiv.org/pdf/1706.03762.pdf>`_ paper

.. math::
Expand Down Expand Up @@ -54,3 +56,41 @@ def forward(self, x: Tensor) -> Tensor:

def __repr__(self) -> str:
return f'{self.__class__.__name__}({self.out_channels})'


class TemporalEncoding(torch.nn.Module):
    r"""The time-encoding function from the `"Do We Really Need Complicated
    Model Architectures for Temporal Networks?"
    <https://openreview.net/forum?id=ayPPc0SyLv1>`_ paper.
    :class:`TemporalEncoding` first maps each entry to a vector with
    monotonically exponentially decreasing values, and then uses the cosine
    function to project all values to range :math:`[-1, 1]`

    .. math::
        y_{i} = \cos \left(x \cdot \sqrt{d}^{-(i - 1)/\sqrt{d}} \right)

    where :math:`d` defines the output feature dimension, and
    :math:`1 \leq i \leq d`.

    Args:
        out_channels (int): Size :math:`d` of each output sample.
    """
    def __init__(self, out_channels: int):
        super().__init__()
        self.out_channels = out_channels

        # Fixed (non-trainable) frequencies: 1 / sqrt(d)^t for t evenly
        # spaced in [0, sqrt(d)], stored as a (1, d) row for broadcasting.
        root = math.sqrt(out_channels)
        exponents = torch.linspace(0, root, out_channels)
        self.register_buffer('weight', (1.0 / root**exponents).view(1, -1))

        self.reset_parameters()

    def reset_parameters(self):
        # The encoding is deterministic; there are no learnable parameters.
        pass

    def forward(self, x: Tensor) -> Tensor:
        """Maps each (scalar) entry of :obj:`x` to a :obj:`out_channels`-dimensional
        cosine feature vector, returning a tensor of shape ``(x.numel(), out_channels)``."""
        return torch.cos(x.view(-1, 1) @ self.weight)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.out_channels})'