
Commit

feat: add TiDE modules;
WenjieDu committed May 9, 2024
1 parent 113d407 commit 036efc8
Showing 3 changed files with 172 additions and 0 deletions.
26 changes: 26 additions & 0 deletions pypots/nn/modules/tide/__init__.py
@@ -0,0 +1,26 @@
"""
The package including the modules of TiDE.
Refer to the paper
`Abhimanyu Das, Weihao Kong, Andrew Leach, Shaan Mathur, Rajat Sen, and Rose Yu.
"Long-term Forecasting with TiDE: Time-series Dense Encoder".
In Transactions on Machine Learning Research, 2023.
<https://openreview.net/pdf?id=pCbC3aQB5W>`_
Notes
-----
This implementation is inspired by the official one
https://github.com/google-research/google-research/blob/master/tide and https://github.com/lich99/TiDE
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .autoencoder import TideEncoder, TideDecoder

__all__ = [
    "TideEncoder",
    "TideDecoder",
]
102 changes: 102 additions & 0 deletions pypots/nn/modules/tide/autoencoder.py
@@ -0,0 +1,102 @@
"""
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from .layers import ResBlock


class TideEncoder(nn.Module):
    def __init__(
        self,
        n_steps: int,
        n_pred_steps: int,
        n_features: int,
        n_layers: int,
        d_hidden: int,
        d_feature_encode: int,
        dropout: float,
    ):
        super().__init__()
        self.n_steps = n_steps
        self.n_pred_steps = n_pred_steps
        self.n_features = n_features
        self.n_layers = n_layers
        self.d_hidden = d_hidden
        self.res_hidden = d_hidden
        self.dropout = dropout
        self.d_feature_encode = d_feature_encode

        # the encoder input is the flattened lookback series concatenated with the
        # feature encodings of both the lookback window and the prediction horizon
        flatten_dim = (
            self.n_steps + (self.n_steps + self.n_pred_steps) * self.d_feature_encode
        )
        self.feature_encoder = ResBlock(
            self.n_features, self.res_hidden, self.d_feature_encode, dropout
        )
        # NOTE: multiplying the list repeats the same ResBlock instance,
        # so the stacked layers after the first one share parameters
        self.encoder_layers = nn.Sequential(
            ResBlock(flatten_dim, self.res_hidden, self.d_hidden, dropout),
            *(
                [ResBlock(self.d_hidden, self.res_hidden, self.d_hidden, dropout)]
                * (self.n_layers - 1)
            ),
        )

    def forward(self, X, dynamic):
        # project the dynamic covariates into a low-dimensional feature encoding,
        # then flatten it and concatenate it with the lookback series before encoding
        feature = self.feature_encoder(dynamic)
        hidden = self.encoder_layers(
            torch.cat([X, feature.reshape(feature.shape[0], -1)], dim=-1)
        )
        return hidden


class TideDecoder(nn.Module):
    def __init__(
        self,
        n_steps: int,
        n_pred_steps: int,
        n_pred_features: int,
        n_layers: int,
        d_hidden: int,
        d_feature_encode: int,
        dropout: float,
    ):
        super().__init__()
        self.n_steps = n_steps
        self.n_pred_steps = n_pred_steps
        self.n_pred_features = n_pred_features
        self.d_hidden = d_hidden
        res_hidden = d_hidden

        self.decoder_layers = nn.Sequential(
            *([ResBlock(d_hidden, res_hidden, d_hidden, dropout)] * (n_layers - 1)),
            ResBlock(
                d_hidden,
                res_hidden,
                n_pred_features * n_pred_steps,
                dropout,
            ),
        )
        # per-step temporal decoder applied to the stacked decoded values together
        # with the feature encodings of the prediction horizon
        self.final_temporal_decoder = ResBlock(
            n_pred_features + d_feature_encode,
            d_hidden,
            1,
            dropout,
        )
        # residual skip connection mapping the lookback series onto the prediction horizon
        self.residual_proj = nn.Linear(self.n_steps, self.n_pred_steps)

    def forward(
        self,
        X,
        feature_encoding,
        hidden_states,
    ):
        decoded = self.decoder_layers(hidden_states).reshape(
            hidden_states.shape[0], self.n_pred_steps, self.n_pred_features
        )
        # feature_encoding[:, self.n_steps:] selects the horizon part of the covariate encodings
        dec_out = self.final_temporal_decoder(
            torch.cat([feature_encoding[:, self.n_steps :], decoded], dim=-1)
        ).squeeze(-1) + self.residual_proj(X)
        return dec_out
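
For readers skimming the diff, here is a minimal shape-check sketch of how the two modules above fit together. It is not part of the commit: the hyperparameters, tensor shapes, the single-feature forecast target, and the way the covariate encoding is recomputed via the encoder's feature_encoder submodule are illustrative assumptions, not the PyPOTS model's actual wiring.

import torch

from pypots.nn.modules.tide import TideEncoder, TideDecoder

# illustrative hyperparameters (assumptions, not PyPOTS defaults)
batch, n_steps, n_pred_steps, n_features = 8, 24, 12, 7
n_layers, d_hidden, d_feature_encode, dropout = 2, 64, 4, 0.1

encoder = TideEncoder(n_steps, n_pred_steps, n_features, n_layers, d_hidden, d_feature_encode, dropout)
decoder = TideDecoder(n_steps, n_pred_steps, 1, n_layers, d_hidden, d_feature_encode, dropout)

X = torch.randn(batch, n_steps)                                    # flattened lookback series
dynamic = torch.randn(batch, n_steps + n_pred_steps, n_features)   # covariates over lookback + horizon

feature_encoding = encoder.feature_encoder(dynamic)   # (batch, n_steps + n_pred_steps, d_feature_encode)
hidden = encoder(X, dynamic)                          # (batch, d_hidden)
forecast = decoder(X, feature_encoding, hidden)       # (batch, n_pred_steps)
print(forecast.shape)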
44 changes: 44 additions & 0 deletions pypots/nn/modules/tide/layers.py
@@ -0,0 +1,44 @@
"""
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, x):
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)


class ResBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1, bias=True):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias)
        self.fc3 = nn.Linear(input_dim, output_dim, bias=bias)  # linear skip connection
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.ln = LayerNorm(output_dim, bias=bias)

    def forward(self, x):
        # two-layer MLP with dropout, plus a linear skip path, followed by layer norm
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.dropout(out)
        out = out + self.fc3(x)
        out = self.ln(out)
        return out
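
A quick usage sketch for the residual block, not part of the commit; it assumes the ResBlock class above is in scope, and the dimensions are arbitrary examples. The Linear and LayerNorm layers act on the last dimension, so the block maps input_dim to output_dim while keeping all leading dimensions.

import torch

block = ResBlock(input_dim=7, hidden_dim=64, output_dim=4, dropout=0.1)
x = torch.randn(8, 24, 7)   # (batch, steps, features)
out = block(x)              # (8, 24, 4): MLP path plus the fc3 skip path, then LayerNorm
print(out.shape)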
