From 036efc8b2ccf6ab98fd89f4d1722358f100a5ed5 Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Fri, 10 May 2024 01:23:05 +0800
Subject: [PATCH] feat: add TiDE modules;

---
 pypots/nn/modules/tide/__init__.py    |  26 ++++++
 pypots/nn/modules/tide/autoencoder.py | 114 +++++++++++++++++++++++++
 pypots/nn/modules/tide/layers.py      |  44 ++++++++++
 3 files changed, 184 insertions(+)
 create mode 100644 pypots/nn/modules/tide/__init__.py
 create mode 100644 pypots/nn/modules/tide/autoencoder.py
 create mode 100644 pypots/nn/modules/tide/layers.py

diff --git a/pypots/nn/modules/tide/__init__.py b/pypots/nn/modules/tide/__init__.py
new file mode 100644
index 00000000..4bcbf925
--- /dev/null
+++ b/pypots/nn/modules/tide/__init__.py
@@ -0,0 +1,26 @@
+"""
+The package including the modules of TiDE.
+
+Refer to the paper
+`Abhimanyu Das, Weihao Kong, Andrew Leach, Shaan Mathur, Rajat Sen, and Rose Yu.
+"Long-term Forecasting with TiDE: Time-series Dense Encoder".
+In Transactions on Machine Learning Research, 2023.
+`_
+
+Notes
+-----
+This implementation is inspired by the official one
+https://github.com/google-research/google-research/blob/master/tide and https://github.com/lich99/TiDE
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+from .autoencoder import TideEncoder, TideDecoder
+
+__all__ = [
+    "TideEncoder",
+    "TideDecoder",
+]

diff --git a/pypots/nn/modules/tide/autoencoder.py b/pypots/nn/modules/tide/autoencoder.py
new file mode 100644
index 00000000..e626f511
--- /dev/null
+++ b/pypots/nn/modules/tide/autoencoder.py
@@ -0,0 +1,114 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import torch
+import torch.nn as nn
+
+from .layers import ResBlock
+
+
+class TideEncoder(nn.Module):
+    def __init__(
+        self,
+        n_steps: int,
+        n_pred_steps: int,
+        n_features: int,
+        n_layers: int,
+        d_hidden: int,
+        d_feature_encode: int,
+        dropout: float,
+    ):
+        super().__init__()
+        self.n_steps = n_steps
+        self.n_pred_steps = n_pred_steps
+        self.n_features = n_features
+        self.n_layers = n_layers
+        self.d_hidden = d_hidden
+        self.res_hidden = d_hidden
+        self.dropout = dropout
+        self.d_feature_encode = d_feature_encode
+
+        # the flattened input is the lookback series concatenated with the
+        # encoded dynamic covariates of both the lookback and the horizon
+        flatten_dim = (
+            self.n_steps + (self.n_steps + self.n_pred_steps) * self.d_feature_encode
+        )
+        self.feature_encoder = ResBlock(
+            self.n_features, self.res_hidden, self.d_feature_encode, dropout
+        )
+        # instantiate each block separately: repeating a single instance with
+        # list `*` replication would make all layers share the same weights
+        self.encoder_layers = nn.Sequential(
+            ResBlock(flatten_dim, self.res_hidden, self.d_hidden, dropout),
+            *[
+                ResBlock(self.d_hidden, self.res_hidden, self.d_hidden, dropout)
+                for _ in range(self.n_layers - 1)
+            ],
+        )
+
+    def forward(self, X, dynamic):
+        feature = self.feature_encoder(dynamic)
+        hidden = self.encoder_layers(
+            torch.cat([X, feature.reshape(feature.shape[0], -1)], dim=-1)
+        )
+        return hidden
+
+
+class TideDecoder(nn.Module):
+    def __init__(
+        self,
+        n_steps: int,
+        n_pred_steps: int,
+        n_pred_features: int,
+        n_layers: int,
+        d_hidden: int,
+        d_feature_encode: int,
+        dropout: float,
+    ):
+        super().__init__()
+        self.n_steps = n_steps
+        self.n_pred_steps = n_pred_steps
+        self.n_pred_features = n_pred_features
+        self.d_hidden = d_hidden
+        res_hidden = d_hidden
+
+        self.decoder_layers = nn.Sequential(
+            *[
+                ResBlock(d_hidden, res_hidden, d_hidden, dropout)
+                for _ in range(n_layers - 1)
+            ],
+            ResBlock(
+                d_hidden,
+                res_hidden,
+                n_pred_features * n_pred_steps,
+                dropout,
+            ),
+        )
+        self.final_temporal_decoder = ResBlock(
+            n_pred_features + d_feature_encode,
+            d_hidden,
+            1,
+            dropout,
+        )
+        # project the lookback series onto the horizon for the residual connection
+        self.residual_proj = nn.Linear(self.n_steps, self.n_pred_steps)
+
+    def forward(
+        self,
+        X,
+        feature_encoding,
+        hidden_states,
+    ):
+        decoded = self.decoder_layers(hidden_states).reshape(
+            hidden_states.shape[0], self.n_pred_steps, self.n_pred_features
+        )
+        # decode each horizon step together with its covariate encoding,
+        # then add the residual projection of the lookback series
+        dec_out = self.final_temporal_decoder(
+            torch.cat([feature_encoding[:, self.n_steps :], decoded], dim=-1)
+        ).squeeze(-1) + self.residual_proj(X)
+        return dec_out
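
A minimal shape-check sketch for the two modules above (not part of the diff): the batch size, window lengths, and dimensions are illustrative, and routing the covariate encoding through `encoder.feature_encoder` is one plausible wiring rather than necessarily how the PyPOTS model wrapper does it.

import torch

from pypots.nn.modules.tide import TideEncoder, TideDecoder

batch, n_steps, n_pred_steps = 8, 96, 24
n_features, d_hidden, d_feature_encode = 7, 256, 4

encoder = TideEncoder(
    n_steps=n_steps,
    n_pred_steps=n_pred_steps,
    n_features=n_features,
    n_layers=2,
    d_hidden=d_hidden,
    d_feature_encode=d_feature_encode,
    dropout=0.1,
)
decoder = TideDecoder(
    n_steps=n_steps,
    n_pred_steps=n_pred_steps,
    n_pred_features=1,  # channel-independent: one target series at a time
    n_layers=2,
    d_hidden=d_hidden,
    d_feature_encode=d_feature_encode,
    dropout=0.1,
)

X = torch.randn(batch, n_steps)  # the lookback series of one channel
dynamic = torch.randn(batch, n_steps + n_pred_steps, n_features)  # covariates

hidden = encoder(X, dynamic)  # -> (batch, d_hidden)
feature_encoding = encoder.feature_encoder(dynamic)  # -> (batch, n_steps + n_pred_steps, d_feature_encode)
forecast = decoder(X, feature_encoding, hidden)  # -> (batch, n_pred_steps)
assert forecast.shape == (batch, n_pred_steps)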
diff --git a/pypots/nn/modules/tide/layers.py b/pypots/nn/modules/tide/layers.py
new file mode 100644
index 00000000..05e938cd
--- /dev/null
+++ b/pypots/nn/modules/tide/layers.py
@@ -0,0 +1,44 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LayerNorm(nn.Module):
+    """LayerNorm with an optional bias, since older PyTorch versions of nn.LayerNorm don't support simply setting bias=False."""
+
+    def __init__(self, ndim, bias):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(ndim))
+        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+
+    def forward(self, x):
+        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
+
+
+class ResBlock(nn.Module):
+    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1, bias=True):
+        super().__init__()
+
+        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
+        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias)
+        # linear skip connection projecting the input to the output dimension
+        self.fc3 = nn.Linear(input_dim, output_dim, bias=bias)
+        self.dropout = nn.Dropout(dropout)
+        self.relu = nn.ReLU()
+        self.ln = LayerNorm(output_dim, bias=bias)
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out = self.relu(out)
+        out = self.fc2(out)
+        out = self.dropout(out)
+        out = out + self.fc3(x)
+        out = self.ln(out)
+        return out
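
A quick standalone sketch of ResBlock (not part of the diff): it is the residual MLP unit from the TiDE paper, out = LayerNorm(Dropout(fc2(ReLU(fc1(x)))) + fc3(x)), checked here with illustrative sizes.

import torch

from pypots.nn.modules.tide.layers import ResBlock

block = ResBlock(input_dim=32, hidden_dim=64, output_dim=16, dropout=0.0)
x = torch.randn(8, 32)
y = block(x)  # LayerNorm over the MLP branch plus the linear skip of x
assert y.shape == (8, 16)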