Commit: Merge pull request #388 from WenjieDu/(feat)add_nonstationary_transformer
Implement Non-stationary Transformer as an imputation model
Showing 10 changed files with 887 additions and 1 deletion.
@@ -0,0 +1,24 @@
""" | ||
The package of the partially-observed time-series imputation model Nonstationary-Transformer. | ||
Refer to the paper | ||
`Yong Liu, Haixu Wu, Jianmin Wang, Mingsheng Long. | ||
Non-stationary Transformers: Exploring the Stationarity in Time Series Forecasting. | ||
Advances in Neural Information Processing Systems 35 (2022): 9881-9893. | ||
<https://proceedings.neurips.cc/paper_files/paper/2022/file/4054556fcaa934b0bf76da52cf4f92cb-Paper-Conference.pdf>`_ | ||
Notes | ||
----- | ||
This implementation is inspired by the official one https://github.com/thuml/Nonstationary_Transformers | ||
""" | ||
|
||
# Created by Wenjie Du <wenjay.du@gmail.com> | ||
# License: BSD-3-Clause | ||
|
||
|
||
from .model import NonstationaryTransformer | ||
|
||
__all__ = [ | ||
"NonstationaryTransformer", | ||
] |
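For context, a minimal usage sketch of the class exported above, assuming the standard PyPOTS imputation workflow (`fit` on a dict holding an `X` array, then `predict`). The hyperparameter values are illustrative only, and the public wrapper's exact signature (defined in `model.py`, which is part of this PR but not shown in this excerpt) may differ from the core module's:

# Hypothetical usage sketch, not part of this diff. Hyperparameter values are
# illustrative; the wrapper's exact signature lives in model.py of this PR.
import numpy as np
from pypots.imputation import NonstationaryTransformer

train_X = np.random.randn(128, 48, 37)   # (n_samples, n_steps, n_features)
train_X[train_X < -1.5] = np.nan         # introduce some missing values

model = NonstationaryTransformer(
    n_steps=48,
    n_features=37,
    n_layers=2,
    d_model=256,
    n_heads=4,
    d_ffn=512,
    d_projector_hidden=64,          # mirrors the core module's signature below
    n_projector_hidden_layers=2,
    dropout=0.1,
    attn_dropout=0.1,
    epochs=10,                      # assumption: inherited from the PyPOTS NN imputer base
)
model.fit({"X": train_X})
results = model.predict({"X": train_X})
imputed = results["imputation"]     # assumption: PyPOTS returns imputations under this key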
@@ -0,0 +1,111 @@
""" | ||
The core wrapper assembles the submodules of NonstationaryTransformer imputation model | ||
and takes over the forward progress of the algorithm. | ||
""" | ||
|
||
# Created by Wenjie Du <wenjay.du@gmail.com> | ||
# License: BSD-3-Clause | ||
|
||
import torch.nn as nn | ||
|
||
from ...nn.modules.nonstationary_transformer import ( | ||
NonstationaryTransformerEncoder, | ||
Projector, | ||
) | ||
from ...nn.modules.saits import SaitsLoss, SaitsEmbedding | ||
from ...nn.functional.normalization import nonstationary_norm, nonstationary_denorm | ||
|
||
|
||
class _NonstationaryTransformer(nn.Module): | ||
def __init__( | ||
self, | ||
n_steps: int, | ||
n_features: int, | ||
n_layers: int, | ||
d_model: int, | ||
n_heads: int, | ||
d_ffn: int, | ||
d_projector_hidden: int, | ||
n_projector_hidden_layers: int, | ||
dropout: float, | ||
attn_dropout: float, | ||
ORT_weight: float = 1, | ||
MIT_weight: float = 1, | ||
): | ||
super().__init__() | ||
|
||
d_k = d_v = d_model // n_heads | ||
self.n_steps = n_steps | ||
|
||
self.saits_embedding = SaitsEmbedding( | ||
n_features * 2, | ||
d_model, | ||
with_pos=False, | ||
dropout=dropout, | ||
) | ||
self.encoder = NonstationaryTransformerEncoder( | ||
n_layers, | ||
d_model, | ||
n_heads, | ||
d_k, | ||
d_v, | ||
d_ffn, | ||
dropout, | ||
attn_dropout, | ||
) | ||
self.tau_learner = Projector( | ||
d_in=n_features, | ||
n_steps=n_steps, | ||
d_hidden=d_projector_hidden, | ||
n_hidden_layers=n_projector_hidden_layers, | ||
d_output=1, | ||
) | ||
self.delta_learner = Projector( | ||
d_in=n_features, | ||
n_steps=n_steps, | ||
d_hidden=d_projector_hidden, | ||
n_hidden_layers=n_projector_hidden_layers, | ||
d_output=n_steps, | ||
) | ||
|
||
# for the imputation task, the output dim is the same as input dim | ||
self.output_projection = nn.Linear(d_model, n_features) | ||
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) | ||
|
||
def forward(self, inputs: dict, training: bool = True) -> dict: | ||
X, missing_mask = inputs["X"], inputs["missing_mask"] | ||
X_enc, means, stdev = nonstationary_norm(X, missing_mask) | ||
|
||
tau = self.tau_learner(X, stdev).exp() | ||
delta = self.delta_learner(X, means) | ||
|
||
# WDU: the original Nonstationary Transformer paper isn't proposed for imputation task. Hence the model doesn't | ||
# take the missing mask into account, which means, in the process, the model doesn't know which part of | ||
# the input data is missing, and this may hurt the model's imputation performance. Therefore, I apply the | ||
# SAITS embedding method to project the concatenation of features and masks into a hidden space, as well as | ||
# the output layers to project back from the hidden space to the original space. | ||
enc_out = self.saits_embedding(X, missing_mask) | ||
|
||
# NonstationaryTransformer encoder processing | ||
enc_out, attns = self.encoder(enc_out, tau=tau, delta=delta) | ||
# project back the original data space | ||
reconstruction = self.output_projection(enc_out) | ||
reconstruction = nonstationary_denorm(reconstruction, means, stdev) | ||
|
||
imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction | ||
results = { | ||
"imputed_data": imputed_data, | ||
} | ||
|
||
# if in training mode, return results with losses | ||
if training: | ||
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] | ||
loss, ORT_loss, MIT_loss = self.saits_loss_func( | ||
reconstruction, X_ori, missing_mask, indicating_mask | ||
) | ||
results["ORT_loss"] = ORT_loss | ||
results["MIT_loss"] = MIT_loss | ||
# `loss` is always the item for backward propagating to update the model | ||
results["loss"] = loss | ||
|
||
return results |
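The forward pass above relies on `nonstationary_norm`/`nonstationary_denorm` from `pypots.nn.functional.normalization`, which are not shown in this diff. A minimal sketch of what such masked normalization can look like, assuming per-feature statistics are computed over observed time steps only and missing positions are zeroed; the real implementation may differ in details such as clamping:

# Minimal sketch of masked (non-stationary) normalization, NOT the exact
# PyPOTS implementation: per-feature mean/stdev are computed over observed
# time steps only, following the paper's normalize-then-denormalize scheme.
import torch


def nonstationary_norm_sketch(X, missing_mask, eps=1e-5):
    # X, missing_mask: (batch, n_steps, n_features); mask is 1 = observed, 0 = missing
    n_observed = missing_mask.sum(dim=1, keepdim=True)
    means = (X * missing_mask).sum(dim=1, keepdim=True) / (n_observed + eps)
    variance = (((X - means) * missing_mask) ** 2).sum(dim=1, keepdim=True) / (n_observed + eps)
    stdev = torch.sqrt(variance + eps)
    X_enc = (X - means) / stdev * missing_mask  # zero out missing positions
    return X_enc, means, stdev


def nonstationary_denorm_sketch(X, means, stdev):
    # re-inject the series statistics removed by the normalization above
    return X * stdev + means

The removed statistics are not lost: the two `Projector`s learn `tau` (a scale) and `delta` (a shift) from the raw series together with `stdev` and `means`, and the encoder feeds them into the paper's de-stationary attention. The SAITS embedding used in place of a plain input projection concatenates the features with the missing mask along the feature dimension (hence `n_features * 2` in the constructor), so the encoder can tell observed from missing positions.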
@@ -0,0 +1,24 @@
""" | ||
Dataset class for NonstationaryTransformer. | ||
""" | ||
|
||
# Created by Wenjie Du <wenjay.du@gmail.com> | ||
# License: BSD-3-Clause | ||
|
||
from typing import Union | ||
|
||
from ..saits.data import DatasetForSAITS | ||
|
||
|
||
class DatasetForNonstationaryTransformer(DatasetForSAITS): | ||
"""Actually NonstationaryTransformer uses the same data strategy as SAITS, needs MIT for training.""" | ||
|
||
def __init__( | ||
self, | ||
data: Union[dict, str], | ||
return_X_ori: bool, | ||
return_y: bool, | ||
file_type: str = "hdf5", | ||
rate: float = 0.2, | ||
): | ||
super().__init__(data, return_X_ori, return_y, file_type, rate) |
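The SAITS data strategy inherited here artificially masks a further `rate` fraction of the observed values at training time, producing the `X_ori` and `indicating_mask` that the MIT loss in the core module consumes. A rough sketch of that masking, assuming uniform random selection over observed entries; the real `DatasetForSAITS` additionally handles file loading, labels, and batching, and the function name below is hypothetical:

# Rough sketch of SAITS-style artificial masking for MIT, NOT the exact
# DatasetForSAITS implementation. The function name is hypothetical.
import numpy as np


def saits_mcar_mask_sketch(X_ori, rate=0.2, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    observed = ~np.isnan(X_ori)
    # hide a `rate` fraction of the observed entries from the model
    drop = observed & (rng.random(X_ori.shape) < rate)
    X = X_ori.copy()
    X[drop] = np.nan
    missing_mask = (~np.isnan(X)).astype(np.float32)  # what the model sees as observed
    indicating_mask = drop.astype(np.float32)         # held-out targets scored by the MIT loss
    return X, missing_mask, indicating_mask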