diff --git a/docs/pypots.imputation.rst b/docs/pypots.imputation.rst
index 9baa6e75..c7013c63 100644
--- a/docs/pypots.imputation.rst
+++ b/docs/pypots.imputation.rst
@@ -19,6 +19,15 @@ pypots.imputation.transformer
    :show-inheritance:
    :inherited-members:
 
+pypots.imputation.frets
+------------------------------
+
+.. automodule:: pypots.imputation.frets
+   :members:
+   :undoc-members:
+   :show-inheritance:
+   :inherited-members:
+
 pypots.imputation.crossformer
 ------------------------------
 
diff --git a/docs/references.bib b/docs/references.bib
index 724d878f..2dcedc22 100644
--- a/docs/references.bib
+++ b/docs/references.bib
@@ -554,4 +554,16 @@ @inproceedings{zhou2022film
 url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/524ef58c2bd075775861234266e5e020-Paper-Conference.pdf},
 volume = {35},
 year = {2022}
+}
+
+@inproceedings{yi2023frets,
+author = {Yi, Kun and Zhang, Qi and Fan, Wei and Wang, Shoujin and Wang, Pengyang and He, Hui and An, Ning and Lian, Defu and Cao, Longbing and Niu, Zhendong},
+booktitle = {Advances in Neural Information Processing Systems},
+editor = {A. Oh and T. Naumann and A. Globerson and K. Saenko and M. Hardt and S. Levine},
+pages = {76656--76679},
+publisher = {Curran Associates, Inc.},
+title = {Frequency-domain MLPs are More Effective Learners in Time Series Forecasting},
+url = {https://proceedings.neurips.cc/paper_files/paper/2023/file/f1d16af76939f476b5f040fd1398c0a3-Paper-Conference.pdf},
+volume = {36},
+year = {2023}
 }
\ No newline at end of file
diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py
index d352f001..1be5e0cb 100644
--- a/pypots/imputation/__init__.py
+++ b/pypots/imputation/__init__.py
@@ -16,6 +16,7 @@
 from .etsformer import ETSformer
 from .fedformer import FEDformer
 from .film import FiLM
+from .frets import FreTS
 from .crossformer import Crossformer
 from .informer import Informer
 from .autoformer import Autoformer
@@ -35,6 +36,7 @@
     "ETSformer",
     "FEDformer",
     "FiLM",
+    "FreTS",
     "Crossformer",
     "TimesNet",
     "PatchTST",
diff --git a/pypots/imputation/frets/__init__.py b/pypots/imputation/frets/__init__.py
new file mode 100644
index 00000000..2c690829
--- /dev/null
+++ b/pypots/imputation/frets/__init__.py
@@ -0,0 +1,24 @@
+"""
+The package of the partially-observed time-series imputation model FreTS.
+
+Refer to the paper
+`Kun Yi, Qi Zhang, Wei Fan, Shoujin Wang, Pengyang Wang, Hui He, Ning An, Defu Lian, Longbing Cao, and Zhendong Niu.
+"Frequency-domain MLPs are More Effective Learners in Time Series Forecasting."
+Advances in Neural Information Processing Systems 36 (2024).
+<https://proceedings.neurips.cc/paper_files/paper/2023/file/f1d16af76939f476b5f040fd1398c0a3-Paper-Conference.pdf>`_
+
+Notes
+-----
+Partial implementation uses code from https://github.com/aikunyi/FreTS
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+from .model import FreTS
+
+__all__ = [
+    "FreTS",
+]
diff --git a/pypots/imputation/frets/core.py b/pypots/imputation/frets/core.py
new file mode 100644
index 00000000..53a4ecfc
--- /dev/null
+++ b/pypots/imputation/frets/core.py
@@ -0,0 +1,83 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import torch
+import torch.nn as nn
+
+from ...nn.modules.frets import BackboneFreTS
+from ...nn.modules.saits import SaitsLoss
+from ...nn.modules.transformer.embedding import DataEmbedding
+
+
+class _FreTS(nn.Module):
+    def __init__(
+        self,
+        n_steps,
+        n_features,
+        embed_size: int = 128,  # the default value is the same as the fixed one in the original implementation
+        hidden_size: int = 256,  # the default value is the same as the fixed one in the original implementation
+        channel_independence: bool = False,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
+    ):
+        super().__init__()
+
+        self.n_steps = n_steps
+
+        self.enc_embedding = DataEmbedding(
+            n_features * 2,
+            embed_size,
+            dropout=0,
+            with_pos=False,
+        )
+        self.backbone = BackboneFreTS(
+            n_steps,
+            n_features,
+            embed_size,
+            n_steps,
+            hidden_size,
+            channel_independence,
+        )
+
+        # for the imputation task, the output dim is the same as the input dim
+        self.output_projection = nn.Linear(embed_size, n_features)
+        self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)
+
+    def forward(self, inputs: dict, training: bool = True) -> dict:
+        X, missing_mask = inputs["X"], inputs["missing_mask"]
+
+        # WDU: the original FreTS model is not designed for the imputation task, hence it does not take
+        # the missing mask into account, which means the model does not know which parts of
+        # the input data are missing, and this may hurt its imputation performance. Therefore, I add
+        # embedding layers to project the concatenation of features and masks into a hidden space, as well as
+        # output layers to project back from the hidden space to the original space.
+
+        # the same as SAITS, concatenate the time series data and the missing mask for embedding
+        input_X = torch.cat([X, missing_mask], dim=2)
+        enc_out = self.enc_embedding(input_X)
+
+        # FreTS processing
+        backbone_output = self.backbone(enc_out)
+        reconstruction = self.output_projection(backbone_output)
+
+        imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
+        results = {
+            "imputed_data": imputed_data,
+        }
+
+        # if in training mode, return results with losses
+        if training:
+            X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
+            loss, ORT_loss, MIT_loss = self.saits_loss_func(
+                reconstruction, X_ori, missing_mask, indicating_mask
+            )
+            results["ORT_loss"] = ORT_loss
+            results["MIT_loss"] = MIT_loss
+            # `loss` is always the item for backward propagating to update the model
+            results["loss"] = loss
+
+        return results
diff --git a/pypots/imputation/frets/data.py b/pypots/imputation/frets/data.py
new file mode 100644
index 00000000..203131f0
--- /dev/null
+++ b/pypots/imputation/frets/data.py
@@ -0,0 +1,24 @@
+"""
+Dataset class for FreTS.
+""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ..saits.data import DatasetForSAITS + + +class DatasetForFreTS(DatasetForSAITS): + """Actually FreTS uses the same data strategy as SAITS, needs MIT for training.""" + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_y: bool, + file_type: str = "hdf5", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_y, file_type, rate) diff --git a/pypots/imputation/frets/model.py b/pypots/imputation/frets/model.py new file mode 100644 index 00000000..42101ac4 --- /dev/null +++ b/pypots/imputation/frets/model.py @@ -0,0 +1,296 @@ +""" +The implementation of FreTS for the partially-observed time-series imputation task. + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .core import _FreTS +from .data import DatasetForFreTS +from ..base import BaseNNImputer +from ...data.checking import key_in_data_set +from ...data.dataset import BaseDataset +from ...optim.adam import Adam +from ...optim.base import Optimizer + + +class FreTS(BaseNNImputer): + """The PyTorch implementation of the FreTS model. + FreTS is originally proposed by Yi et al. in :cite:`yi2023frets`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + embed_size : + The size of the embedding layer in the FreTS model. + + hidden_size : + The size of the hidden layer in the FreTS model. + + channel_independence : + Whether to use the channel independence mechanism in the FreTS model. + + ORT_weight : + The weight for the ORT loss, the same as SAITS. + + MIT_weight : + The weight for the MIT loss, the same as SAITS. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. 
+        The "best" strategy will only automatically save the best model after the training is finished.
+        The "better" strategy will automatically save the model during training whenever the model performs
+        better than in previous epochs.
+        The "all" strategy will save every model after each training epoch.
+
+    """
+
+    def __init__(
+        self,
+        n_steps: int,
+        n_features: int,
+        embed_size: int = 128,
+        hidden_size: int = 256,
+        channel_independence: bool = False,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
+        batch_size: int = 32,
+        epochs: int = 100,
+        patience: Optional[int] = None,
+        optimizer: Optional[Optimizer] = Adam(),
+        num_workers: int = 0,
+        device: Optional[Union[str, torch.device, list]] = None,
+        saving_path: Optional[str] = None,
+        model_saving_strategy: Optional[str] = "best",
+    ):
+        super().__init__(
+            batch_size,
+            epochs,
+            patience,
+            num_workers,
+            device,
+            saving_path,
+            model_saving_strategy,
+        )
+
+        self.n_steps = n_steps
+        self.n_features = n_features
+        # model hyperparameters
+        self.embed_size = embed_size
+        self.hidden_size = hidden_size
+        self.channel_independence = channel_independence
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
+
+        # set up the model
+        self.model = _FreTS(
+            self.n_steps,
+            self.n_features,
+            self.embed_size,
+            self.hidden_size,
+            self.channel_independence,
+            self.ORT_weight,
+            self.MIT_weight,
+        )
+        self._send_model_to_given_device()
+        self._print_model_size()
+
+        # set up the optimizer
+        self.optimizer = optimizer
+        self.optimizer.init_optimizer(self.model.parameters())
+
+    def _assemble_input_for_training(self, data: list) -> dict:
+        (
+            indices,
+            X,
+            missing_mask,
+            X_ori,
+            indicating_mask,
+        ) = self._send_data_to_given_device(data)
+
+        inputs = {
+            "X": X,
+            "missing_mask": missing_mask,
+            "X_ori": X_ori,
+            "indicating_mask": indicating_mask,
+        }
+
+        return inputs
+
+    def _assemble_input_for_validating(self, data: list) -> dict:
+        return self._assemble_input_for_training(data)
+
+    def _assemble_input_for_testing(self, data: list) -> dict:
+        indices, X, missing_mask = self._send_data_to_given_device(data)
+
+        inputs = {
+            "X": X,
+            "missing_mask": missing_mask,
+        }
+
+        return inputs
+
+    def fit(
+        self,
+        train_set: Union[dict, str],
+        val_set: Optional[Union[dict, str]] = None,
+        file_type: str = "hdf5",
+    ) -> None:
+        # Step 1: wrap the input data with classes Dataset and DataLoader
+        training_set = DatasetForFreTS(
+            train_set, return_X_ori=False, return_y=False, file_type=file_type
+        )
+        training_loader = DataLoader(
+            training_set,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+        )
+        val_loader = None
+        if val_set is not None:
+            if not key_in_data_set("X_ori", val_set):
+                raise ValueError("val_set must contain 'X_ori' for model validation.")
+            val_set = DatasetForFreTS(
+                val_set, return_X_ori=True, return_y=False, file_type=file_type
+            )
+            val_loader = DataLoader(
+                val_set,
+                batch_size=self.batch_size,
+                shuffle=False,
+                num_workers=self.num_workers,
+            )
+
+        # Step 2: train the model and freeze it
+        self._train_model(training_loader, val_loader)
+        self.model.load_state_dict(self.best_model_dict)
+        self.model.eval()  # set the model as eval status to freeze it.
+
+        # Step 3: save the model if necessary
+        self._auto_save_model_if_necessary(confirm_saving=True)
+
+    def predict(
+        self,
+        test_set: Union[dict, str],
+        file_type: str = "hdf5",
+    ) -> dict:
+        """Make predictions for the input data with the trained model.
+
+        Parameters
+        ----------
+        test_set : dict or str
+            The dataset for model testing, should be a dictionary including the key 'X',
+            or a path string locating a data file supported by PyPOTS (e.g. h5 file).
+            If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
+            which is the time-series data for testing; it can contain missing values.
+            No label array 'y' is needed for the imputation task.
+            If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
+            key-value pairs like a dict and has to include the key 'X'.
+
+        file_type :
+            The type of the given file if test_set is a path string.
+
+        Returns
+        -------
+        result_dict :
+            The dictionary containing the imputation results, with the key 'imputation' storing the imputed data.
+
+        """
+        # Step 1: wrap the input data with classes Dataset and DataLoader
+        self.model.eval()  # set the model as eval status to freeze it.
+        test_set = BaseDataset(
+            test_set,
+            return_X_ori=False,
+            return_X_pred=False,
+            return_y=False,
+            file_type=file_type,
+        )
+        test_loader = DataLoader(
+            test_set,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+        )
+        imputation_collector = []
+
+        # Step 2: process the data with the model
+        with torch.no_grad():
+            for idx, data in enumerate(test_loader):
+                inputs = self._assemble_input_for_testing(data)
+                results = self.model.forward(inputs, training=False)
+                imputation_collector.append(results["imputed_data"])
+
+        # Step 3: output collection and return
+        imputation = torch.cat(imputation_collector).cpu().detach().numpy()
+        result_dict = {
+            "imputation": imputation,
+        }
+        return result_dict
+
+    def impute(
+        self,
+        test_set: Union[dict, str],
+        file_type: str = "hdf5",
+    ) -> np.ndarray:
+        """Impute missing values in the given data with the trained model.
+
+        Parameters
+        ----------
+        test_set :
+            The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
+            n_features], or a path string locating a data file, e.g. h5 file.
+
+        file_type :
+            The type of the given file if test_set is a path string.
+
+        Returns
+        -------
+        array-like, shape [n_samples, sequence length (n_steps), n_features],
+            Imputed data.
+        """
+
+        result_dict = self.predict(test_set, file_type=file_type)
+        return result_dict["imputation"]
diff --git a/pypots/nn/modules/frets/__init__.py b/pypots/nn/modules/frets/__init__.py
new file mode 100644
index 00000000..7ebc1dce
--- /dev/null
+++ b/pypots/nn/modules/frets/__init__.py
@@ -0,0 +1,20 @@
+"""
+The package including the modules of FreTS.
+
+Refer to the paper
+`Kun Yi, Qi Zhang, Wei Fan, Shoujin Wang, Pengyang Wang, Hui He, Ning An, Defu Lian, Longbing Cao, and Zhendong Niu.
+"Frequency-domain MLPs are More Effective Learners in Time Series Forecasting."
+Advances in Neural Information Processing Systems 36 (2024).
+<https://proceedings.neurips.cc/paper_files/paper/2023/file/f1d16af76939f476b5f040fd1398c0a3-Paper-Conference.pdf>`_
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+from .backbone import BackboneFreTS
+
+__all__ = [
+    "BackboneFreTS",
+]
diff --git a/pypots/nn/modules/frets/backbone.py b/pypots/nn/modules/frets/backbone.py
new file mode 100644
index 00000000..2b53af10
--- /dev/null
+++ b/pypots/nn/modules/frets/backbone.py
@@ -0,0 +1,131 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import torch
+import torch.nn as nn
+
+import torch.nn.functional as F
+
+
+class BackboneFreTS(nn.Module):
+    def __init__(
+        self,
+        n_steps: int,
+        n_features: int,
+        embed_size: int,
+        n_pred_steps: int,
+        hidden_size: int,
+        channel_independence: bool = False,
+    ):
+        super().__init__()
+
+        self.n_steps = n_steps
+        self.n_features = n_features
+        self.n_pred_steps = n_pred_steps
+        self.embed_size = embed_size  # embed_size, the input is already embedded
+        self.hidden_size = hidden_size  # hidden_size
+        self.channel_independence = channel_independence
+        self.sparsity_threshold = 0.01
+        self.scale = 0.02
+
+        # self.embeddings = nn.Parameter(torch.randn(1, self.embed_size))  # original embedding method, deprecated here
+        self.r1 = nn.Parameter(
+            self.scale * torch.randn(self.embed_size, self.embed_size)
+        )
+        self.i1 = nn.Parameter(
+            self.scale * torch.randn(self.embed_size, self.embed_size)
+        )
+        self.rb1 = nn.Parameter(self.scale * torch.randn(self.embed_size))
+        self.ib1 = nn.Parameter(self.scale * torch.randn(self.embed_size))
+        self.r2 = nn.Parameter(
+            self.scale * torch.randn(self.embed_size, self.embed_size)
+        )
+        self.i2 = nn.Parameter(
+            self.scale * torch.randn(self.embed_size, self.embed_size)
+        )
+        self.rb2 = nn.Parameter(self.scale * torch.randn(self.embed_size))
+        self.ib2 = nn.Parameter(self.scale * torch.randn(self.embed_size))
+
+        self.fc = nn.Sequential(
+            nn.Linear(self.n_steps * self.embed_size, self.hidden_size),
+            nn.LeakyReLU(),
+            nn.Linear(self.hidden_size, self.n_pred_steps),
+        )
+
+    # # dimension extension
+    # def tokenEmb(self, x):
+    #     # x: [Batch, Input length, Channel]
+    #     x = x.permute(0, 2, 1)
+    #     x = x.unsqueeze(3)
+    #     # N*T*1 x 1*D = N*T*D
+    #     y = self.embeddings
+    #     return x * y
+
+    # frequency temporal learner
+    def MLP_temporal(self, x, B, N, L):
+        # [B, N, T, D]
+        x = torch.fft.rfft(x, dim=2, norm="ortho")  # FFT on L dimension
+        y = self.FreMLP(B, N, L, x, self.r2, self.i2, self.rb2, self.ib2)
+        x = torch.fft.irfft(y, n=self.n_steps, dim=2, norm="ortho")
+        return x
+
+    # frequency channel learner
+    def MLP_channel(self, x, B, N, L):
+        # [B, N, T, D]
+        x = x.permute(0, 2, 1, 3)
+        # [B, T, N, D]
+        x = torch.fft.rfft(x, dim=2, norm="ortho")  # FFT on N dimension
+        y = self.FreMLP(B, L, N, x, self.r1, self.i1, self.rb1, self.ib1)
+        x = torch.fft.irfft(y, n=self.n_features, dim=2, norm="ortho")
+        x = x.permute(0, 2, 1, 3)
+        # [B, N, T, D]
+        return x
+
+    # frequency-domain MLPs
+    # dimension: FFT along the dimension, r: the real part of weights, i: the imaginary part of weights
+    # rb: the real part of bias, ib: the imaginary part of bias
+    def FreMLP(self, B, nd, dimension, x, r, i, rb, ib):
+        o1_real = torch.zeros(
+            [B, nd, dimension // 2 + 1, self.embed_size], device=x.device
+        )
+        o1_imag = torch.zeros(
+            [B, nd, dimension // 2 + 1, self.embed_size], device=x.device
+        )
+
+        o1_real = F.relu(
+            torch.einsum("bijd,dd->bijd", x.real, r)
+            - torch.einsum("bijd,dd->bijd", x.imag, i)
+            + rb
+        )
+
+        o1_imag = F.relu(
+            torch.einsum("bijd,dd->bijd", x.imag, r)
+            + torch.einsum("bijd,dd->bijd", x.real, i)
+            + ib
+        )
+
+        y = torch.stack([o1_real, o1_imag], dim=-1)
+        y = F.softshrink(y, lambd=self.sparsity_threshold)
+        y = torch.view_as_complex(y)
+        return y
+
+    def forward(self, x):
+        # x: [Batch, n_steps, embed_size]
+        B, T, N = x.shape
+
+        x = x.permute(0, 2, 1)
+        x = x.unsqueeze(3)
+
+        bias = x
+        # [B, N, T, D]
+        if self.channel_independence == "0":  # NOTE: always False here since channel_independence is a bool, so the frequency channel learner is skipped
+            x = self.MLP_channel(x, B, N, T)
+        # [B, N, T, D]
+        x = self.MLP_temporal(x, B, N, T)
+        x = x + bias
+        x = self.fc(x.reshape(B, N, -1)).permute(0, 2, 1)
+        return x
diff --git a/tests/imputation/frets.py b/tests/imputation/frets.py
new file mode 100644
index 00000000..8821b36a
--- /dev/null
+++ b/tests/imputation/frets.py
@@ -0,0 +1,123 @@
+"""
+Test cases for the FreTS imputation model.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+import os.path
+import unittest
+
+import numpy as np
+import pytest
+
+from pypots.imputation import FreTS
+from pypots.optim import Adam
+from pypots.utils.logging import logger
+from pypots.utils.metrics import calc_mse
+from tests.global_test_config import (
+    DATA,
+    EPOCHS,
+    DEVICE,
+    TRAIN_SET,
+    VAL_SET,
+    TEST_SET,
+    GENERAL_H5_TRAIN_SET_PATH,
+    GENERAL_H5_VAL_SET_PATH,
+    GENERAL_H5_TEST_SET_PATH,
+    RESULT_SAVING_DIR_FOR_IMPUTATION,
+    check_tb_and_model_checkpoints_existence,
+)
+
+
+class TestFreTS(unittest.TestCase):
+    logger.info("Running tests for an imputation model FreTS...")
+
+    # set the log and model saving path
+    saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "FreTS")
+    model_save_name = "saved_frets_model.pypots"
+
+    # initialize an Adam optimizer
+    optimizer = Adam(lr=0.001, weight_decay=1e-5)
+
+    # initialize a FreTS model
+    frets = FreTS(
+        DATA["n_steps"],
+        DATA["n_features"],
+        embed_size=128,
+        hidden_size=256,
+        channel_independence=False,
+        epochs=EPOCHS,
+        saving_path=saving_path,
+        optimizer=optimizer,
+        device=DEVICE,
+    )
+
+    @pytest.mark.xdist_group(name="imputation-frets")
+    def test_0_fit(self):
+        self.frets.fit(TRAIN_SET, VAL_SET)
+
+    @pytest.mark.xdist_group(name="imputation-frets")
+    def test_1_impute(self):
+        imputation_results = self.frets.predict(TEST_SET)
+        assert not np.isnan(
+            imputation_results["imputation"]
+        ).any(), "Output still has missing values after running impute()."
+
+        test_MSE = calc_mse(
+            imputation_results["imputation"],
+            DATA["test_X_ori"],
+            DATA["test_X_indicating_mask"],
+        )
+        logger.info(f"FreTS test_MSE: {test_MSE}")
+
+    @pytest.mark.xdist_group(name="imputation-frets")
+    def test_2_parameters(self):
+        assert hasattr(self.frets, "model") and self.frets.model is not None
+
+        assert hasattr(self.frets, "optimizer") and self.frets.optimizer is not None
+
+        assert hasattr(self.frets, "best_loss")
+        self.assertNotEqual(self.frets.best_loss, float("inf"))
+
+        assert (
+            hasattr(self.frets, "best_model_dict")
+            and self.frets.best_model_dict is not None
+        )
+
+    @pytest.mark.xdist_group(name="imputation-frets")
+    def test_3_saving_path(self):
+        # whether the root saving dir exists, which should be created by save_log_into_tb_file
+        assert os.path.exists(
+            self.saving_path
+        ), f"file {self.saving_path} does not exist"
+
+        # check if the tensorboard file and model checkpoints exist
+        check_tb_and_model_checkpoints_existence(self.frets)
+
+        # save the trained model into file, and check if the path exists
+        saved_model_path = os.path.join(self.saving_path, self.model_save_name)
+        self.frets.save(saved_model_path)
+
+        # test loading the saved model; not strictly necessary, but worth verifying
+        self.frets.load(saved_model_path)
+
+    @pytest.mark.xdist_group(name="imputation-frets")
+    def test_4_lazy_loading(self):
+        self.frets.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)
+        imputation_results = self.frets.predict(GENERAL_H5_TEST_SET_PATH)
+        assert not np.isnan(
+            imputation_results["imputation"]
+        ).any(), "Output still has missing values after running impute()."
+
+        test_MSE = calc_mse(
+            imputation_results["imputation"],
+            DATA["test_X_ori"],
+            DATA["test_X_indicating_mask"],
+        )
+        logger.info(f"Lazy-loading FreTS test_MSE: {test_MSE}")
+
+
+if __name__ == "__main__":
+    unittest.main()
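For reviewers who want to try the new model end to end, below is a minimal usage sketch (not part of the patch) based on the fit/impute API this PR adds in pypots/imputation/frets/model.py. The toy data shape, masking rate, random seed, and epoch count are illustrative assumptions only; the dict-with-'X' input format and the FreTS constructor arguments come from the code above.

# usage_sketch_frets.py -- illustrative only, assumes pypots with this PR is installed
import numpy as np
from pypots.imputation import FreTS

# toy data: 64 samples, 24 steps, 5 features, with roughly 10% of values masked as missing
rng = np.random.default_rng(0)
X = rng.standard_normal((64, 24, 5)).astype(np.float32)
X[rng.random(X.shape) < 0.1] = np.nan
train_set = {"X": X}

# hyperparameters mirror the defaults introduced in this PR
model = FreTS(
    n_steps=24,
    n_features=5,
    embed_size=128,
    hidden_size=256,
    channel_independence=False,
    epochs=5,
)
model.fit(train_set)  # MIT masking is handled internally by DatasetForFreTS

# impute() returns an array of shape [n_samples, n_steps, n_features] with the NaNs filled in
imputed = model.impute({"X": X})
print(imputed.shape)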