Commit ec803e1

update

rusty1s committed Feb 24, 2023
1 parent d336d13 commit ec803e1
Showing 3 changed files with 58 additions and 6 deletions.
19 changes: 15 additions & 4 deletions test/nn/test_encoding.py
@@ -1,11 +1,22 @@
 import torch
 
-from torch_geometric.nn import PositionalEncoding
+from torch_geometric.nn import PositionalEncoding, TemporalEncoding
+from torch_geometric.testing import withCUDA
 
 
-def test_positional_encoding():
-    encoder = PositionalEncoding(64)
+@withCUDA
+def test_positional_encoding(device):
+    encoder = PositionalEncoding(64).to(device)
     assert str(encoder) == 'PositionalEncoding(64)'
 
-    x = torch.tensor([1.0, 2.0, 3.0])
+    x = torch.tensor([1.0, 2.0, 3.0], device=device)
     assert encoder(x).size() == (3, 64)
+
+
+@withCUDA
+def test_temporal_encoding(device):
+    encoder = TemporalEncoding(64).to(device)
+    assert str(encoder) == 'TemporalEncoding(64)'
+
+    x = torch.tensor([1.0, 2.0, 3.0], device=device)
+    assert encoder(x).size() == (3, 64)
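
For readers unfamiliar with the testing helper: `withCUDA` from `torch_geometric.testing` runs the decorated test once per available device, passing each one as the `device` argument. The sketch below is an assumed stand-in for that decorator (not the actual library source), written with plain `pytest.mark.parametrize` to show the pattern the new tests rely on.

# Hypothetical stand-in for torch_geometric.testing.withCUDA, illustrating the
# device parametrization used by the new tests. Names here are assumptions.
import pytest
import torch

from torch_geometric.nn import TemporalEncoding

_devices = [torch.device('cpu')]
if torch.cuda.is_available():
    _devices.append(torch.device('cuda:0'))


@pytest.mark.parametrize('device', _devices)
def test_temporal_encoding_on_each_device(device):
    encoder = TemporalEncoding(64).to(device)
    x = torch.tensor([1.0, 2.0, 3.0], device=device)
    assert encoder(x).size() == (3, 64)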
3 changes: 2 additions & 1 deletion torch_geometric/nn/__init__.py
@@ -4,7 +4,7 @@
 from .to_hetero_transformer import to_hetero
 from .to_hetero_with_bases_transformer import to_hetero_with_bases
 from .to_fixed_size_transformer import to_fixed_size
-from .encoding import PositionalEncoding
+from .encoding import PositionalEncoding, TemporalEncoding
 from .model_hub import PyGModelHubMixin
 from .summary import summary
 
@@ -27,6 +27,7 @@
     'to_hetero_with_bases',
     'to_fixed_size',
     'PositionalEncoding',
+    'TemporalEncoding',
     'PyGModelHubMixin',
     'summary',
 ]
42 changes: 41 additions & 1 deletion torch_geometric/nn/encoding.py
@@ -1,9 +1,11 @@
+import math
+
 import torch
 from torch import Tensor
 
 
 class PositionalEncoding(torch.nn.Module):
-    r"""The positional encoding scheme from `"Attention Is All You Need"
+    r"""The positional encoding scheme from the `"Attention Is All You Need"
     <https://arxiv.org/pdf/1706.03762.pdf>`_ paper
 
     .. math::
@@ -54,3 +56,41 @@ def forward(self, x: Tensor) -> Tensor:
 
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}({self.out_channels})'
+
+
+class TemporalEncoding(torch.nn.Module):
+    r"""The time-encoding function from the `"Do We Really Need Complicated
+    Model Architectures for Temporal Networks?"
+    <https://openreview.net/forum?id=ayPPc0SyLv1>`_ paper, which first maps
+    each entry to a vector with monotonically exponentially decreasing values,
+    and then uses the cosine function to project all values to range
+    :math:`[-1, 1]`:
+
+    .. math::
+        y_{i} = \cos \left( x \cdot \sqrt{d}^{-(i - 1)/\sqrt{d}} \right)
+
+    where :math:`d` defines the output feature dimension, and
+    :math:`1 \leq i \leq d`.
+
+    Args:
+        out_channels (int): Size :math:`d` of each output sample.
+    """
+    def __init__(self, out_channels: int):
+        super().__init__()
+        self.out_channels = out_channels
+
+        sqrt = math.sqrt(out_channels)
+        weight = 1.0 / sqrt**torch.linspace(0, sqrt, out_channels).view(1, -1)
+        self.register_buffer('weight', weight)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        pass
+
+    def forward(self, x: Tensor) -> Tensor:
+        """"""
+        return torch.cos(x.view(-1, 1) @ self.weight)
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}({self.out_channels})'
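
As a quick, minimal usage sketch (not part of the commit, and assuming an installed version of PyG that includes this change): `TemporalEncoding` maps a one-dimensional tensor of time values to a `[num_values, out_channels]` matrix of cosines, and its output can be reproduced by hand from the `__init__` and `forward` code above.

import math

import torch

from torch_geometric.nn import TemporalEncoding

encoder = TemporalEncoding(out_channels=64)
t = torch.tensor([0.0, 1.0, 10.0, 100.0])  # e.g. raw event timestamps
out = encoder(t)

assert out.size() == (4, 64)
assert out.min() >= -1.0  # cosine keeps every entry in [-1, 1]
assert out.max() <= 1.0

# The same mapping written out explicitly, mirroring __init__ and forward above:
d = 64
sqrt = math.sqrt(d)
weight = 1.0 / sqrt**torch.linspace(0, sqrt, d).view(1, -1)
assert torch.allclose(out, torch.cos(t.view(-1, 1) @ weight))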
