diff --git a/pypots/__init__.py b/pypots/__init__.py index 1bb2b9d8..62da4b21 100644 --- a/pypots/__init__.py +++ b/pypots/__init__.py @@ -22,7 +22,7 @@ # # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. # 'X.Y.dev0' is the canonical version of 'X.Y.dev' -__version__ = "0.4.1" +__version__ = "0.5" from . import imputation, classification, clustering, forecasting, optim, data, utils diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 07d93a45..4620eb4c 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -13,6 +13,7 @@ from .saits import SAITS from .transformer import Transformer from .itransformer import iTransformer +from .nonstationary_transformer import NonstationaryTransformer from .timesnet import TimesNet from .etsformer import ETSformer from .fedformer import FEDformer @@ -45,6 +46,7 @@ "DLinear", "Informer", "Autoformer", + "NonstationaryTransformer", "BRITS", "MRNN", "GPVAE", diff --git a/pypots/imputation/nonstationary_transformer/__init__.py b/pypots/imputation/nonstationary_transformer/__init__.py new file mode 100644 index 00000000..e3b18f8c --- /dev/null +++ b/pypots/imputation/nonstationary_transformer/__init__.py @@ -0,0 +1,24 @@ +""" +The package of the partially-observed time-series imputation model Nonstationary-Transformer. + +Refer to the paper +`Yong Liu, Haixu Wu, Jianmin Wang, Mingsheng Long. +Non-stationary Transformers: Exploring the Stationarity in Time Series Forecasting. +Advances in Neural Information Processing Systems 35 (2022): 9881-9893. +`_ + +Notes +----- +This implementation is inspired by the official one https://github.com/thuml/Nonstationary_Transformers + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .model import NonstationaryTransformer + +__all__ = [ + "NonstationaryTransformer", +] diff --git a/pypots/imputation/nonstationary_transformer/core.py b/pypots/imputation/nonstationary_transformer/core.py new file mode 100644 index 00000000..9ca21e1d --- /dev/null +++ b/pypots/imputation/nonstationary_transformer/core.py @@ -0,0 +1,111 @@ +""" +The core wrapper assembles the submodules of NonstationaryTransformer imputation model +and takes over the forward progress of the algorithm. 
+""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch.nn as nn + +from ...nn.modules.nonstationary_transformer import ( + NonstationaryTransformerEncoder, + Projector, +) +from ...nn.modules.saits import SaitsLoss, SaitsEmbedding +from ...nn.functional.normalization import nonstationary_norm, nonstationary_denorm + + +class _NonstationaryTransformer(nn.Module): + def __init__( + self, + n_steps: int, + n_features: int, + n_layers: int, + d_model: int, + n_heads: int, + d_ffn: int, + d_projector_hidden: int, + n_projector_hidden_layers: int, + dropout: float, + attn_dropout: float, + ORT_weight: float = 1, + MIT_weight: float = 1, + ): + super().__init__() + + d_k = d_v = d_model // n_heads + self.n_steps = n_steps + + self.saits_embedding = SaitsEmbedding( + n_features * 2, + d_model, + with_pos=False, + dropout=dropout, + ) + self.encoder = NonstationaryTransformerEncoder( + n_layers, + d_model, + n_heads, + d_k, + d_v, + d_ffn, + dropout, + attn_dropout, + ) + self.tau_learner = Projector( + d_in=n_features, + n_steps=n_steps, + d_hidden=d_projector_hidden, + n_hidden_layers=n_projector_hidden_layers, + d_output=1, + ) + self.delta_learner = Projector( + d_in=n_features, + n_steps=n_steps, + d_hidden=d_projector_hidden, + n_hidden_layers=n_projector_hidden_layers, + d_output=n_steps, + ) + + # for the imputation task, the output dim is the same as input dim + self.output_projection = nn.Linear(d_model, n_features) + self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, missing_mask = inputs["X"], inputs["missing_mask"] + X_enc, means, stdev = nonstationary_norm(X, missing_mask) + + tau = self.tau_learner(X, stdev).exp() + delta = self.delta_learner(X, means) + + # WDU: the original Nonstationary Transformer paper isn't proposed for imputation task. Hence the model doesn't + # take the missing mask into account, which means, in the process, the model doesn't know which part of + # the input data is missing, and this may hurt the model's imputation performance. Therefore, I apply the + # SAITS embedding method to project the concatenation of features and masks into a hidden space, as well as + # the output layers to project back from the hidden space to the original space. + enc_out = self.saits_embedding(X, missing_mask) + + # NonstationaryTransformer encoder processing + enc_out, attns = self.encoder(enc_out, tau=tau, delta=delta) + # project back the original data space + reconstruction = self.output_projection(enc_out) + reconstruction = nonstationary_denorm(reconstruction, means, stdev) + + imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction + results = { + "imputed_data": imputed_data, + } + + # if in training mode, return results with losses + if training: + X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] + loss, ORT_loss, MIT_loss = self.saits_loss_func( + reconstruction, X_ori, missing_mask, indicating_mask + ) + results["ORT_loss"] = ORT_loss + results["MIT_loss"] = MIT_loss + # `loss` is always the item for backward propagating to update the model + results["loss"] = loss + + return results diff --git a/pypots/imputation/nonstationary_transformer/data.py b/pypots/imputation/nonstationary_transformer/data.py new file mode 100644 index 00000000..3b703cb5 --- /dev/null +++ b/pypots/imputation/nonstationary_transformer/data.py @@ -0,0 +1,24 @@ +""" +Dataset class for NonstationaryTransformer. 
+""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ..saits.data import DatasetForSAITS + + +class DatasetForNonstationaryTransformer(DatasetForSAITS): + """Actually NonstationaryTransformer uses the same data strategy as SAITS, needs MIT for training.""" + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_y: bool, + file_type: str = "hdf5", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_y, file_type, rate) diff --git a/pypots/imputation/nonstationary_transformer/model.py b/pypots/imputation/nonstationary_transformer/model.py new file mode 100644 index 00000000..1e5f11f2 --- /dev/null +++ b/pypots/imputation/nonstationary_transformer/model.py @@ -0,0 +1,333 @@ +""" +The implementation of Nonstationary-Transformer for the partially-observed time-series imputation task. + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .core import _NonstationaryTransformer +from .data import DatasetForNonstationaryTransformer +from ..base import BaseNNImputer +from ...data.checking import key_in_data_set +from ...data.dataset import BaseDataset +from ...optim.adam import Adam +from ...optim.base import Optimizer +from ...utils.logging import logger + + +class NonstationaryTransformer(BaseNNImputer): + """The PyTorch implementation of the Nonstationary-Transformer model. + NonstationaryTransformer is originally proposed by Wu et al. in :cite:`liu2022nonstationary`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + n_layers : + The number of layers in the NonstationaryTransformer model. + + d_model : + The dimension of the model. + + n_heads : + The number of heads in each layer of NonstationaryTransformer. + + d_ffn : + The dimension of the feed-forward network. + + d_projector_hidden : + The dimensions of hidden layers in MLP projectors. + It should be a list of integers and the length of the list should be equal to n_projector_hidden_layers. + + n_projector_hidden_layers : + The number of hidden layers in MLP projectors. + + dropout : + The dropout rate for the model. + + ORT_weight : + The weight for the ORT loss, the same as SAITS. + + MIT_weight : + The weight for the MIT loss, the same as SAITS. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. 
['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + n_layers: int, + d_model: int, + n_heads: int, + d_ffn: int, + d_projector_hidden: list, + n_projector_hidden_layers: int, + dropout: float = 0, + ORT_weight: float = 1, + MIT_weight: float = 1, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + ) + assert len(d_projector_hidden) == n_projector_hidden_layers, ( + f"The length of d_hidden should be equal to n_hidden_layers, " + f"but got {len(d_projector_hidden)} and {n_projector_hidden_layers}." 
+ ) + + self.n_steps = n_steps + self.n_features = n_features + # model hype-parameters + self.n_heads = n_heads + self.n_layers = n_layers + self.d_model = d_model + self.d_ffn = d_ffn + self.d_projector_hidden = d_projector_hidden + self.n_projector_hidden_layers = n_projector_hidden_layers + self.dropout = dropout + self.ORT_weight = ORT_weight + self.MIT_weight = MIT_weight + + # set up the model + self.model = _NonstationaryTransformer( + self.n_steps, + self.n_features, + self.n_layers, + self.d_model, + self.n_heads, + self.d_ffn, + self.d_projector_hidden, + self.n_projector_hidden_layers, + self.dropout, + self.ORT_weight, + self.MIT_weight, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "hdf5", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForNonstationaryTransformer( + train_set, return_X_ori=False, return_y=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not key_in_data_set("X_ori", val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForNonstationaryTransformer( + val_set, return_X_ori=True, return_y=False, file_type=file_type + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=True) + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : + The type of the given file if test_set is a path string. 
+ + Returns + ------- + file_type : + The dictionary containing the clustering results and latent variables if necessary. + + """ + # Step 1: wrap the input data with classes Dataset and DataLoader + self.model.eval() # set the model as eval status to freeze it. + test_set = BaseDataset( + test_set, + return_X_ori=False, + return_X_pred=False, + return_y=False, + file_type=file_type, + ) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + # Step 2: process the data with the model + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputation_collector.append(results["imputed_data"]) + + # Step 3: output collection and return + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + X: Union[dict, str], + file_type: str = "hdf5", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Warnings + -------- + The method impute is deprecated. Please use `predict()` instead. + + Parameters + ---------- + X : + The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (time steps), n_features], + Imputed data. + """ + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/nn/modules/nonstationary_transformer/__init__.py b/pypots/nn/modules/nonstationary_transformer/__init__.py new file mode 100644 index 00000000..d92de271 --- /dev/null +++ b/pypots/nn/modules/nonstationary_transformer/__init__.py @@ -0,0 +1,26 @@ +""" +The package including the modules of Non-stationary Transformer. + +Refer to the paper +`Yong Liu, Haixu Wu, Jianmin Wang, Mingsheng Long. +Non-stationary Transformers: Exploring the Stationarity in Time Series Forecasting. +Advances in Neural Information Processing Systems 35 (2022): 9881-9893. +`_ + +Notes +----- +This implementation is inspired by the official one https://github.com/thuml/Nonstationary_Transformers + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from .autoencoder import NonstationaryTransformerEncoder +from .layers import DeStationaryAttention, Projector + +__all__ = [ + "NonstationaryTransformerEncoder", + "DeStationaryAttention", + "Projector", +] diff --git a/pypots/nn/modules/nonstationary_transformer/autoencoder.py b/pypots/nn/modules/nonstationary_transformer/autoencoder.py new file mode 100644 index 00000000..fcd7863f --- /dev/null +++ b/pypots/nn/modules/nonstationary_transformer/autoencoder.py @@ -0,0 +1,120 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from .layers import DeStationaryAttention +from ..transformer.layers import TransformerEncoderLayer + + +class NonstationaryTransformerEncoder(nn.Module): + """NonstationaryTransformer encoder. 
+ Its arch is the same with the original Transformer encoder, + but the attention operator is replaced by the DeStationaryAttention. + + Parameters + ---------- + n_layers: + The number of layers in the encoder. + + d_model: + The dimension of the module manipulation space. + The input tensor will be projected to a space with d_model dimensions. + + n_heads: + The number of heads in multi-head attention. + + d_k: + The dimension of the key and query tensor. + + d_v: + The dimension of the value tensor. + + d_ffn: + The dimension of the hidden layer in the feed-forward network. + + dropout: + The dropout rate. + + attn_dropout: + The dropout rate for the attention map. + + """ + + def __init__( + self, + n_layers: int, + d_model: int, + n_heads: int, + d_k: int, + d_v: int, + d_ffn: int, + dropout: float, + attn_dropout: float, + ): + super().__init__() + + self.enc_layer_stack = nn.ModuleList( + [ + TransformerEncoderLayer( + DeStationaryAttention(d_k**0.5, attn_dropout), + d_model, + n_heads, + d_k, + d_v, + d_ffn, + dropout, + ) + for _ in range(n_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]: + """Forward processing of the encoder. + + Parameters + ---------- + x: + Input tensor. + + src_mask: + Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps]. + + Returns + ------- + enc_output: + Output tensor. + + attn_weights_collector: + A list containing the attention map from each encoder layer. + + """ + attn_weights_collector = [] + enc_output = x + + if src_mask is None: + # triangular causal mask + bz, n_steps, _ = x.shape + mask_shape = [bz, n_steps, n_steps] + src_mask = torch.triu( + torch.ones(mask_shape, dtype=torch.bool), diagonal=1 + ).to(x.device) + + for layer in self.enc_layer_stack: + enc_output, attn_weights = layer(enc_output, src_mask, **kwargs) + attn_weights_collector.append(attn_weights) + + return enc_output, attn_weights_collector diff --git a/pypots/nn/modules/nonstationary_transformer/layers.py b/pypots/nn/modules/nonstationary_transformer/layers.py new file mode 100644 index 00000000..8464bc9e --- /dev/null +++ b/pypots/nn/modules/nonstationary_transformer/layers.py @@ -0,0 +1,107 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import math +from typing import Optional, Tuple + +import torch +import torch.fft +import torch.nn as nn + +from ..transformer.attention import AttentionOperator + + +class DeStationaryAttention(AttentionOperator): + """De-stationary Attention""" + + def __init__(self, temperature: float, attn_dropout: float = 0.1): + super().__init__() + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) + + def forward( + self, + q: torch.Tensor, + v: torch.Tensor, + k: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # q, k, v all have 4 dimensions [batch_size, n_steps, n_heads, d_tensor] + # d_tensor could be d_q, d_k, d_v + + B, L, H, E = q.shape + _, S, _, D = v.shape + temperature = self.temperature or 1.0 / math.sqrt(E) + + tau, delta = kwargs["tau"], kwargs["delta"] + tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x 1 + delta = ( + 0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1) + ) # B x 1 x 1 x S + + # De-stationary Attention, rescaling pre-softmax score with learned de-stationary factors + scores = torch.einsum("blhe,bshe->bhls", q, k) * tau + 
delta + + if attn_mask is not None: + scores.masked_fill_(attn_mask, -torch.inf) + + attn = self.dropout(torch.softmax(temperature * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", attn, v) + output = V.contiguous() + + return output, attn + + +class Projector(nn.Module): + """ + MLP to learn the De-stationary factors + """ + + def __init__( + self, + d_in: int, + n_steps: int, + d_hidden: list, + n_hidden_layers: int, + d_output: int, + kernel_size: int = 3, + ): + super().__init__() + + assert ( + len(d_hidden) == n_hidden_layers + ), f"The length of d_hidden should be equal to n_hidden_layers, but got {len(d_hidden)} and {n_hidden_layers}." + + padding = 1 if torch.__version__ >= "1.5.0" else 2 + self.series_conv = nn.Conv1d( + in_channels=n_steps, + out_channels=1, + kernel_size=kernel_size, + padding=padding, + padding_mode="circular", + bias=False, + ) + + layers = [nn.Linear(2 * d_in, d_hidden[0]), nn.ReLU()] + for i in range(n_hidden_layers - 1): + layers += [nn.Linear(d_hidden[i], d_hidden[i + 1]), nn.ReLU()] + + layers += [nn.Linear(d_hidden[-1], d_output, bias=False)] + self.backbone = nn.Sequential(*layers) + + def forward(self, x, stats): + # x: B x S x E + # stats: B x 1 x E + # y: B x O + batch_size = x.shape[0] + x = self.series_conv(x) # B x 1 x E + x = torch.cat([x, stats], dim=1) # B x 2 x E + x = x.view(batch_size, -1) # B x 2E + y = self.backbone(x) # B x O + + return y diff --git a/tests/imputation/nonstationary_transformer.py b/tests/imputation/nonstationary_transformer.py new file mode 100644 index 00000000..57dcf537 --- /dev/null +++ b/tests/imputation/nonstationary_transformer.py @@ -0,0 +1,139 @@ +""" +Test cases for NonstationaryTransformer imputation model. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import NonstationaryTransformer +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from tests.global_test_config import ( + DATA, + EPOCHS, + DEVICE, + TRAIN_SET, + VAL_SET, + TEST_SET, + GENERAL_H5_TRAIN_SET_PATH, + GENERAL_H5_VAL_SET_PATH, + GENERAL_H5_TEST_SET_PATH, + RESULT_SAVING_DIR_FOR_IMPUTATION, + check_tb_and_model_checkpoints_existence, +) + + +class TestNonstationaryTransformer(unittest.TestCase): + logger.info("Running tests for an imputation model NonstationaryTransformer...") + + # set the log and model saving path + saving_path = os.path.join( + RESULT_SAVING_DIR_FOR_IMPUTATION, "NonstationaryTransformer" + ) + model_save_name = "saved_nonstationary_transformer_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a NonstationaryTransformer model + nonstationary_transformer = NonstationaryTransformer( + DATA["n_steps"], + DATA["n_features"], + n_layers=2, + d_model=32, + n_heads=2, + d_ffn=32, + d_projector_hidden=[64, 64], + n_projector_hidden_layers=2, + dropout=0, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-nonstationary_transformer") + def test_0_fit(self): + self.nonstationary_transformer.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-nonstationary_transformer") + def test_1_impute(self): + imputation_results = self.nonstationary_transformer.predict(TEST_SET) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." 
+ + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"NonstationaryTransformer test_MSE: {test_MSE}") + + @pytest.mark.xdist_group(name="imputation-nonstationary_transformer") + def test_2_parameters(self): + assert ( + hasattr(self.nonstationary_transformer, "model") + and self.nonstationary_transformer.model is not None + ) + + assert ( + hasattr(self.nonstationary_transformer, "optimizer") + and self.nonstationary_transformer.optimizer is not None + ) + + assert hasattr(self.nonstationary_transformer, "best_loss") + self.assertNotEqual(self.nonstationary_transformer.best_loss, float("inf")) + + assert ( + hasattr(self.nonstationary_transformer, "best_model_dict") + and self.nonstationary_transformer.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-nonstationary_transformer") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.nonstationary_transformer) + + # save the trained model into file, and check if the path exists + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.nonstationary_transformer.save(saved_model_path) + + # test loading the saved model, not necessary, but need to test + self.nonstationary_transformer.load(saved_model_path) + + @pytest.mark.xdist_group(name="imputation-nonstationary_transformer") + def test_4_lazy_loading(self): + self.nonstationary_transformer.fit( + GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH + ) + imputation_results = self.nonstationary_transformer.predict( + GENERAL_H5_TEST_SET_PATH + ) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Lazy-loading NonstationaryTransformer test_MSE: {test_MSE}") + + +if __name__ == "__main__": + unittest.main()
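Not part of the patch itself: below is a minimal, hedged usage sketch of the imputer this diff adds, assembled only from the API visible above (the constructor arguments and Adam optimizer mirror the test case in tests/imputation/nonstationary_transformer.py, and the dataset-dict convention with key "X" of shape [n_samples, n_steps, n_features] follows the predict() docstring). The toy data, sizes, and missingness rate are illustrative assumptions, not values from the patch.

# Minimal usage sketch for the new NonstationaryTransformer imputer (illustrative only).
# Assumes a dataset dict with key "X" shaped [n_samples, n_steps, n_features], as in the
# class/predict() docstrings; hyperparameters copied from the test case in this patch.
import numpy as np

from pypots.imputation import NonstationaryTransformer
from pypots.optim import Adam

# toy data with artificially introduced missing values (hypothetical example data)
n_samples, n_steps, n_features = 64, 24, 5
X = np.random.randn(n_samples, n_steps, n_features)
X[X < -1.5] = np.nan  # inject missingness; observed values keep their original positions
train_set = {"X": X}

model = NonstationaryTransformer(
    n_steps=n_steps,
    n_features=n_features,
    n_layers=2,
    d_model=32,
    n_heads=2,
    d_ffn=32,
    d_projector_hidden=[64, 64],   # length must equal n_projector_hidden_layers
    n_projector_hidden_layers=2,
    dropout=0,
    epochs=10,
    optimizer=Adam(lr=0.001, weight_decay=1e-5),
)

model.fit(train_set)               # a val_set containing "X_ori" is optional
results = model.predict({"X": X})  # returns a dict with key "imputation"
imputation = results["imputation"]  # [n_samples, n_steps, n_features], NaNs filled in

Observed entries are passed through unchanged (imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction in core.py), so only the originally missing positions are replaced by model reconstructions.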