diff --git a/README.md b/README.md
index 9b294fa2..3152f3c4 100644
--- a/README.md
+++ b/README.md
@@ -187,13 +187,14 @@ mae = calc_mae(imputation, np.nan_to_num(X_ori), indicating_mask) # calculate m
 ## ❖ Available Algorithms
 PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values. The currently available algorithms of four tasks are cataloged in the following table with four partitions.
-The paper references are all listed at the bottom of this readme file. Please refer to them if you want more details.
+The paper references are all listed at the bottom of this readme file.
 
 🌟 Since **v0.2**, all neural-network models in PyPOTS has got hyperparameter-optimization support.
 This functionality is implemented with the [Microsoft NNI](https://github.com/microsoft/nni) framework.
 
 🔥 Note that Transformer, Crossformer, PatchTST, DLinear, ETSformer, FEDformer, Informer, Autoformer are not proposed as imputation methods in their original papers,
-and they cannot accept POTS as input. **To make them applicable on POTS data, we apply the embedding strategy the same as we did in [SAITS paper](https://arxiv.org/pdf/2202.08516).**
+and they cannot accept POTS as input. **To make them applicable to POTS data, we apply the embedding strategy and training approach (ORT+MIT)
+the same as we did in the [SAITS paper](https://arxiv.org/pdf/2202.08516).**
 
 | ***`Imputation`*** | 🚥 | 🚥 | 🚥 |
 |:----------------------:|:-----------:|:-----------------------------------------------------------------------------------------------:|:--------:|
diff --git a/docs/index.rst b/docs/index.rst
index 217fc05e..6d662cb0 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -78,9 +78,14 @@ Welcome to PyPOTS docs!
 -----------------
 
-⦿ `Motivation`: Due to all kinds of reasons like failure of collection sensors, communication error, and unexpected malfunction, missing values are common to see in time series from the real-world environment. This makes partially-observed time series (POTS) a pervasive problem in open-world modeling and prevents advanced data analysis. Although this problem is important, the area of data mining on POTS still lacks a dedicated toolkit. PyPOTS is created to fill in this blank.
+⦿ `Motivation`: Due to all kinds of reasons like failure of collection sensors, communication error, and unexpected malfunction, missing values are common to see in time series from the real-world environment.
+This makes partially-observed time series (POTS) a pervasive problem in open-world modeling and prevents advanced data analysis.
+Although this problem is important, the area of data mining on POTS still lacks a dedicated toolkit. PyPOTS is created to fill in this blank.
 
-⦿ `Mission`: PyPOTS is born to become a handy toolbox that is going to make data mining on POTS easy rather than tedious, to help engineers and researchers focus more on the core problems in their hands rather than on how to deal with the missing parts in their data. PyPOTS will keep integrating classical and the latest state-of-the-art data mining algorithms for partially-observed multivariate time series. For sure, besides various algorithms, PyPOTS is going to have unified APIs together with detailed documentation and interactive examples across algorithms as tutorials.
+⦿ `Mission`: PyPOTS is born to become a handy toolbox that is going to make data mining on POTS easy rather than tedious,
+to help engineers and researchers focus more on the core problems in their hands rather than on how to deal with the missing parts in their data.
+PyPOTS will keep integrating classical and the latest state-of-the-art data mining algorithms for partially-observed multivariate time series.
+For sure, besides various algorithms, PyPOTS is going to have unified APIs together with detailed documentation and interactive examples across algorithms as tutorials.
 
 🤗 **Please** star this repo to help others notice PyPOTS if you think it is a useful toolkit.
 **Please** properly `cite PyPOTS `_ in your publications
@@ -164,12 +169,16 @@ Additionally, we present you a usage example of imputing missing values in time
 ❖ Available Algorithms
 ^^^^^^^^^^^^^^^^^^^^^^^
-PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values. The currently available algorithms of four tasks are cataloged in the following table with four partitions. The paper references are all listed at the bottom of this readme file. Please refer to them if you want more details.
-
+PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values.
+The currently available algorithms of four tasks are cataloged in the following table with four partitions. The paper references are all listed `on the reference page `_.
 
 🌟 Since **v0.2**, all neural-network models in PyPOTS has got hyperparameter-optimization support.
 This functionality is implemented with the `Microsoft NNI <https://github.com/microsoft/nni>`_ framework.
 
+🔥 Note that Transformer, Crossformer, PatchTST, DLinear, ETSformer, FEDformer, Informer, Autoformer are not proposed as imputation methods in their original papers,
+and they cannot accept POTS as input. To make them applicable to POTS data, we apply the embedding strategy and training approach (ORT+MIT)
+the same as we did in the `SAITS paper <https://arxiv.org/pdf/2202.08516>`_.
+
 ============================== ================ ========================================================================================================= ====== =========
 Task                           Type             Algorithm                                                                                                 Year   Reference
 ============================== ================ ========================================================================================================= ====== =========
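The "(ORT+MIT)" in the note above is the SAITS joint training objective that every loss hunk below implements: ORT (Observed Reconstruction Task) penalizes error on the values that remain observed in the input, while MIT (Masked Imputation Task) penalizes error on values that were artificially masked out before training. A minimal sketch of the arithmetic, assuming `calc_mse` averages squared error over the positions flagged by the mask (which matches how it is called in the hunks below):

```python
# Sketch of the SAITS joint loss; `masked_mse` is a stand-in for
# pypots.utils.metrics.calc_mse, assumed to average squared error
# over positions where the mask equals 1.
import torch

def masked_mse(predictions: torch.Tensor, targets: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    return torch.sum(((predictions - targets) ** 2) * masks) / (masks.sum() + 1e-12)

def saits_loss(output, X, masks, X_ori, indicating_mask, ORT_weight=1.0, MIT_weight=1.0):
    ORT_loss = masked_mse(output, X, masks)                # error on still-observed values
    MIT_loss = masked_mse(output, X_ori, indicating_mask)  # error on artificially-masked values
    return ORT_weight * ORT_loss + MIT_weight * MIT_loss
```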
diff --git a/pypots/imputation/autoformer/model.py b/pypots/imputation/autoformer/model.py
index 9f0a3072..c021c2cd 100644
--- a/pypots/imputation/autoformer/model.py
+++ b/pypots/imputation/autoformer/model.py
@@ -63,6 +63,12 @@ class Autoformer(BaseNNImputer):
     dropout :
         The dropout rate for the model.
 
+    ORT_weight :
+        The weight for the ORT loss, the same as in SAITS.
+
+    MIT_weight :
+        The weight for the MIT loss, the same as in SAITS.
+
     batch_size :
         The batch size for training and evaluating the model.
 
@@ -115,6 +121,8 @@ def __init__(
         factor: int,
         moving_avg_window_size: int,
         dropout: float = 0,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
         batch_size: int = 32,
         epochs: int = 100,
         patience: int = None,
@@ -144,6 +152,8 @@
         self.factor = factor
         self.moving_avg_window_size = moving_avg_window_size
         self.dropout = dropout
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
 
         # set up the model
         self.model = _Autoformer(
@@ -156,6 +166,8 @@
             self.factor,
             self.moving_avg_window_size,
             self.dropout,
+            self.ORT_weight,
+            self.MIT_weight,
         )
         self._send_model_to_given_device()
         self._print_model_size()
diff --git a/pypots/imputation/autoformer/modules/core.py b/pypots/imputation/autoformer/modules/core.py
index 14cdb53c..ceb953a1 100644
--- a/pypots/imputation/autoformer/modules/core.py
+++ b/pypots/imputation/autoformer/modules/core.py
@@ -31,13 +31,17 @@ def __init__(
         factor,
         moving_avg_window_size,
         dropout,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
         activation="relu",
-        output_attention=False,
     ):
         super().__init__()
 
         self.seq_len = n_steps
         self.n_layers = n_layers
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
+
         self.enc_embedding = DataEmbedding(
             n_features * 2,
             d_model,
@@ -48,7 +52,7 @@
             [
                 AutoformerEncoderLayer(
                     AutoCorrelationLayer(
-                        AutoCorrelation(False, factor, dropout, output_attention),
+                        AutoCorrelation(factor, dropout),
                         d_model,
                         n_heads,
                     ),
@@ -91,8 +95,11 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
         }
 
         if training:
+            # apply SAITS loss function to Autoformer on the imputation task
+            ORT_loss = calc_mse(output, X, masks)
+            MIT_loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
             # `loss` is always the item for backward propagating to update the model
-            loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
+            loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
             results["loss"] = loss
 
         return results
diff --git a/pypots/imputation/autoformer/modules/submodules.py b/pypots/imputation/autoformer/modules/submodules.py
index 6eb3d9e2..791c4373 100644
--- a/pypots/imputation/autoformer/modules/submodules.py
+++ b/pypots/imputation/autoformer/modules/submodules.py
@@ -24,17 +24,11 @@ class AutoCorrelation(nn.Module):
 
     def __init__(
         self,
-        mask_flag=True,
         factor=1,
-        scale=None,
         attention_dropout=0.1,
-        output_attention=False,
     ):
         super().__init__()
         self.factor = factor
-        self.scale = scale
-        self.mask_flag = mask_flag
-        self.output_attention = output_attention
         self.dropout = nn.Dropout(attention_dropout)
 
     def time_delay_agg_training(self, values, corr):
@@ -165,10 +159,7 @@ def forward(self, queries, keys, values, attn_mask):
             values.permute(0, 2, 3, 1).contiguous(), corr
         ).permute(0, 3, 1, 2)
 
-        if self.output_attention:
-            return (V.contiguous(), corr.permute(0, 3, 1, 2))
-        else:
-            return (V.contiguous(), None)
+        return V.contiguous(), corr.permute(0, 3, 1, 2)
 
 
 class AutoCorrelationLayer(nn.Module):
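With the Autoformer diff complete, the two new arguments are user-facing. A hypothetical usage sketch — the constructor parameters are inferred from the hunks and docstring above, all values are illustrative, and `fit`/`impute` are assumed to follow the dataset-dict convention visible in the README's `calc_mae` example:

```python
# Illustrative only: training an imputer with the new SAITS-style loss weights.
import numpy as np
from pypots.imputation import Autoformer

X = np.random.randn(500, 48, 37)
X[np.random.rand(*X.shape) < 0.1] = np.nan  # toy POTS data: ~10% missing

model = Autoformer(
    n_steps=48,
    n_features=37,
    n_layers=2,
    n_heads=4,
    d_model=64,
    d_ffn=128,
    factor=3,
    moving_avg_window_size=5,
    dropout=0.1,
    ORT_weight=1,  # new: weight of the reconstruction (ORT) term
    MIT_weight=1,  # new: weight of the masked-imputation (MIT) term
    epochs=10,
)
model.fit({"X": X})
imputation = model.impute({"X": X})
```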
diff --git a/pypots/imputation/crossformer/model.py b/pypots/imputation/crossformer/model.py
index e6fc9b2f..692aeaa1 100644
--- a/pypots/imputation/crossformer/model.py
+++ b/pypots/imputation/crossformer/model.py
@@ -67,6 +67,12 @@ class Crossformer(BaseNNImputer):
     dropout :
         The dropout rate for the model.
 
+    ORT_weight :
+        The weight for the ORT loss, the same as in SAITS.
+
+    MIT_weight :
+        The weight for the MIT loss, the same as in SAITS.
+
     batch_size :
         The batch size for training and evaluating the model.
 
@@ -120,6 +126,8 @@ def __init__(
         seg_len: int,
         win_size: int,
         dropout: float = 0,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
         batch_size: int = 32,
         epochs: int = 100,
         patience: int = None,
@@ -150,6 +158,8 @@
         self.seg_len = seg_len
         self.win_size = win_size
         self.dropout = dropout
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
 
         # set up the model
         self.model = _Crossformer(
@@ -163,6 +173,8 @@
             self.seg_len,
             self.win_size,
             self.dropout,
+            self.ORT_weight,
+            self.MIT_weight,
         )
         self._send_model_to_given_device()
         self._print_model_size()
diff --git a/pypots/imputation/crossformer/modules/core.py b/pypots/imputation/crossformer/modules/core.py
index 8eb04df6..614f58e8 100644
--- a/pypots/imputation/crossformer/modules/core.py
+++ b/pypots/imputation/crossformer/modules/core.py
@@ -29,11 +29,15 @@ def __init__(
         seg_len,
         win_size,
         dropout,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
     ):
         super().__init__()
 
         self.n_features = n_features
         self.d_model = d_model
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
 
         # The padding operation to handle invisible sgemnet length
         pad_in_len = ceil(1.0 * n_steps / seg_len) * seg_len
@@ -104,8 +108,11 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
         }
 
         if training:
+            # apply SAITS loss function to Crossformer on the imputation task
+            ORT_loss = calc_mse(output, X, masks)
+            MIT_loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
             # `loss` is always the item for backward propagating to update the model
-            loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
+            loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
             results["loss"] = loss
 
         return results
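The loss hunks all consume the same four tensors. For readers new to the SAITS setup, a hypothetical helper (not part of PyPOTS) showing how they relate: `X` is `X_ori` with a fraction of observed values additionally held out, `masks` flags what is still observed (fed to the ORT term), and `indicating_mask` flags what was held out (fed to the MIT term):

```python
# Illustrative helper (not in PyPOTS): derive X, masks, indicating_mask from X_ori.
import numpy as np

def artificially_mask(X_ori: np.ndarray, rate: float = 0.1, seed: int = 0):
    rng = np.random.default_rng(seed)
    X = X_ori.copy()
    observed = np.flatnonzero(~np.isnan(X))                  # positions observed in X_ori
    held_out = rng.choice(observed, int(rate * observed.size), replace=False)
    X.ravel()[held_out] = np.nan                             # hide them from the model
    masks = (~np.isnan(X)).astype("float32")                 # 1 = still observed -> ORT
    indicating_mask = (np.isnan(X) & ~np.isnan(X_ori)).astype("float32")  # 1 = held out -> MIT
    return X, masks, indicating_mask
```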
diff --git a/pypots/imputation/dlinear/model.py b/pypots/imputation/dlinear/model.py
index f6c89976..615cdde7 100644
--- a/pypots/imputation/dlinear/model.py
+++ b/pypots/imputation/dlinear/model.py
@@ -53,6 +53,12 @@ class DLinear(BaseNNImputer):
         The dimension of the space in which the time-series data will be embedded and modeled.
         It is necessary only for DLinear in the non-individual mode.
 
+    ORT_weight :
+        The weight for the ORT loss, the same as in SAITS.
+
+    MIT_weight :
+        The weight for the MIT loss, the same as in SAITS.
+
     batch_size :
         The batch size for training and evaluating the model.
 
@@ -101,6 +107,8 @@ def __init__(
         moving_avg_window_size: int,
         individual: bool = False,
         d_model: Optional[int] = None,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
         batch_size: int = 32,
         epochs: int = 100,
         patience: int = None,
@@ -126,14 +134,18 @@
         self.moving_avg_window_size = moving_avg_window_size
         self.individual = individual
         self.d_model = d_model
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
 
         # set up the model
         self.model = _DLinear(
-            n_steps,
-            n_features,
-            moving_avg_window_size,
-            individual,
-            d_model,
+            self.n_steps,
+            self.n_features,
+            self.moving_avg_window_size,
+            self.individual,
+            self.d_model,
+            self.ORT_weight,
+            self.MIT_weight,
         )
         self._send_model_to_given_device()
         self._print_model_size()
diff --git a/pypots/imputation/dlinear/modules/core.py b/pypots/imputation/dlinear/modules/core.py
index 18f33cec..c4f5b24b 100644
--- a/pypots/imputation/dlinear/modules/core.py
+++ b/pypots/imputation/dlinear/modules/core.py
@@ -22,13 +22,18 @@ def __init__(
         moving_avg_window_size: int,
         individual: bool = False,
         d_model: Optional[int] = None,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
     ):
         super().__init__()
 
         self.n_steps = n_steps
         self.n_features = n_features
-        self.series_decomp = SeriesDecompositionBlock(moving_avg_window_size)
         self.individual = individual
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
+
+        self.series_decomp = SeriesDecompositionBlock(moving_avg_window_size)
 
         if individual:
             # create linear layers for each feature individually
@@ -119,8 +124,11 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
         }
 
         if training:
+            # apply SAITS loss function to DLinear on the imputation task
+            ORT_loss = calc_mse(output, X, masks)
+            MIT_loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
             # `loss` is always the item for backward propagating to update the model
-            loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
+            loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
             results["loss"] = loss
 
         return results
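DLinear's `moving_avg_window_size` (and Autoformer's, above) drives the trend/seasonal split performed by `SeriesDecompositionBlock`. A sketch of the standard moving-average decomposition, assuming PyPOTS follows the original Autoformer/DLinear formulation:

```python
# Standard moving-average series decomposition (assumed formulation, for reference).
import torch
import torch.nn as nn

class SeriesDecomposition(nn.Module):
    def __init__(self, window_size: int):
        super().__init__()
        self.window_size = window_size
        self.avg = nn.AvgPool1d(kernel_size=window_size, stride=1, padding=0)

    def forward(self, x: torch.Tensor):  # x: [batch, time, features]
        # replicate the endpoints so the moving average keeps the sequence length
        front = x[:, :1, :].repeat(1, (self.window_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, self.window_size // 2, 1)
        trend = self.avg(torch.cat([front, x, end], dim=1).permute(0, 2, 1)).permute(0, 2, 1)
        seasonal = x - trend  # the residual is modeled as the seasonal part
        return seasonal, trend
```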
diff --git a/pypots/imputation/etsformer/model.py b/pypots/imputation/etsformer/model.py
index bf5fbc0f..8af9ddfb 100644
--- a/pypots/imputation/etsformer/model.py
+++ b/pypots/imputation/etsformer/model.py
@@ -63,6 +63,12 @@ class ETSformer(BaseNNImputer):
     dropout :
         The dropout rate for the model.
 
+    ORT_weight :
+        The weight for the ORT loss, the same as in SAITS.
+
+    MIT_weight :
+        The weight for the MIT loss, the same as in SAITS.
+
     batch_size :
         The batch size for training and evaluating the model.
 
@@ -115,6 +121,8 @@ def __init__(
         d_ffn,
         top_k,
         dropout: float = 0,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
         batch_size: int = 32,
         epochs: int = 100,
         patience: int = None,
@@ -144,6 +152,8 @@
         self.d_ffn = d_ffn
         self.dropout = dropout
         self.top_k = top_k
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
 
         # set up the model
         self.model = _ETSformer(
@@ -156,6 +166,8 @@
             self.d_ffn,
             self.dropout,
             self.top_k,
+            self.ORT_weight,
+            self.MIT_weight,
         )
         self._send_model_to_given_device()
         self._print_model_size()
diff --git a/pypots/imputation/etsformer/modules/core.py b/pypots/imputation/etsformer/modules/core.py
index 57faa6de..a1ad1de7 100644
--- a/pypots/imputation/etsformer/modules/core.py
+++ b/pypots/imputation/etsformer/modules/core.py
@@ -30,11 +30,15 @@ def __init__(
         d_ffn,
         dropout,
         top_k,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
         activation="sigmoid",
     ):
         super().__init__()
 
         self.n_steps = n_steps
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
 
         self.enc_embedding = DataEmbedding(
             n_features * 2,
@@ -98,8 +102,11 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
         }
 
         if training:
+            # apply SAITS loss function to ETSformer on the imputation task
+            ORT_loss = calc_mse(output, X, masks)
+            MIT_loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
             # `loss` is always the item for backward propagating to update the model
-            loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
+            loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
             results["loss"] = loss
 
         return results
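Every core in this PR embeds `n_features * 2` input channels (see the `DataEmbedding(n_features * 2, ...)` calls). That is the SAITS embedding strategy the README note refers to: the missing mask is concatenated with the zero-filled values before projection to `d_model`, so the model can tell imputed zeros apart from observed zeros. A minimal sketch of the idea, leaving out the positional encoding and dropout the real `DataEmbedding` presumably adds:

```python
# Sketch of the SAITS-style embedding: concatenate values with the missing
# mask, then project 2 * n_features channels down to d_model.
import torch
import torch.nn as nn

batch, n_steps, n_features, d_model = 8, 48, 37, 64
projection = nn.Linear(n_features * 2, d_model)

X = torch.randn(batch, n_steps, n_features)
missing_mask = (torch.rand(batch, n_steps, n_features) > 0.2).float()  # 1 = observed
X = torch.nan_to_num(X) * missing_mask          # zero out the missing positions
input_X = torch.cat([X, missing_mask], dim=2)   # [batch, n_steps, 2 * n_features]
embedded = projection(input_X)                  # [batch, n_steps, d_model]
```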
diff --git a/pypots/imputation/fedformer/model.py b/pypots/imputation/fedformer/model.py
index f9376753..761d0362 100644
--- a/pypots/imputation/fedformer/model.py
+++ b/pypots/imputation/fedformer/model.py
@@ -71,6 +71,12 @@ class FEDformer(BaseNNImputer):
         Get modes on frequency domain. It has to "random" or "low". The default value is "random".
         'random' means sampling randomly; 'low' means sampling the lowest modes;
 
+    ORT_weight :
+        The weight for the ORT loss, the same as in SAITS.
+
+    MIT_weight :
+        The weight for the MIT loss, the same as in SAITS.
+
     batch_size :
         The batch size for training and evaluating the model.
 
@@ -125,6 +131,8 @@ def __init__(
         version="Fourier",
         modes=32,
         mode_select="random",
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
         batch_size: int = 32,
         epochs: int = 100,
         patience: int = None,
@@ -151,11 +159,13 @@
         self.n_heads = n_heads
         self.d_model = d_model
         self.d_ffn = d_ffn
-        self.modes = modes
-        self.mode_select = mode_select
         self.moving_avg_window_size = moving_avg_window_size
         self.dropout = dropout
         self.version = version
+        self.modes = modes
+        self.mode_select = mode_select
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
 
         # set up the model
         self.model = _FEDformer(
@@ -170,6 +180,8 @@
             self.version,
             self.modes,
             self.mode_select,
+            self.ORT_weight,
+            self.MIT_weight,
         )
         self._send_model_to_given_device()
         self._print_model_size()
diff --git a/pypots/imputation/fedformer/modules/core.py b/pypots/imputation/fedformer/modules/core.py
index 00f5241a..97f31c1f 100644
--- a/pypots/imputation/fedformer/modules/core.py
+++ b/pypots/imputation/fedformer/modules/core.py
@@ -33,10 +33,15 @@ def __init__(
         version="Fourier",
         modes=32,
         mode_select="random",
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
         activation="relu",
     ):
         super().__init__()
 
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
+
         self.enc_embedding = DataEmbedding(
             n_features * 2,
             d_model,
@@ -101,8 +106,11 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
         }
 
         if training:
+            # apply SAITS loss function to FEDformer on the imputation task
+            ORT_loss = calc_mse(output, X, masks)
+            MIT_loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
             # `loss` is always the item for backward propagating to update the model
-            loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
+            loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
             results["loss"] = loss
 
         return results
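For context on the `modes`/`mode_select` pair in FEDformer's docstring: the frequency-enhanced blocks keep only a subset of Fourier modes, chosen either randomly or as the lowest frequencies. A sketch following the original FEDformer reference code (an assumption; not copied from the PyPOTS source):

```python
# Sketch of FEDformer's frequency-mode selection ("random" vs. "low").
import numpy as np

def get_frequency_modes(seq_len: int, modes: int = 32, mode_select: str = "random"):
    modes = min(modes, seq_len // 2)   # cannot keep more modes than FFT bins
    if mode_select == "low":
        return list(range(modes))      # the `modes` lowest frequencies
    index = list(range(seq_len // 2))  # "random": sample `modes` bins uniformly
    np.random.shuffle(index)
    return sorted(index[:modes])
```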
diff --git a/pypots/imputation/informer/model.py b/pypots/imputation/informer/model.py
index 070873ff..81afaf9a 100644
--- a/pypots/imputation/informer/model.py
+++ b/pypots/imputation/informer/model.py
@@ -61,6 +61,12 @@ class Informer(BaseNNImputer):
     dropout :
         The dropout rate for the model.
 
+    ORT_weight :
+        The weight for the ORT loss, the same as in SAITS.
+
+    MIT_weight :
+        The weight for the MIT loss, the same as in SAITS.
+
     batch_size :
         The batch size for training and evaluating the model.
 
@@ -112,6 +118,8 @@ def __init__(
         d_ffn: int,
         factor: int,
         dropout: float = 0,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
         batch_size: int = 32,
         epochs: int = 100,
         patience: int = None,
@@ -140,6 +148,8 @@
         self.d_ffn = d_ffn
         self.factor = factor
         self.dropout = dropout
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
 
         # set up the model
         self.model = _Informer(
@@ -151,6 +161,8 @@
             self.d_ffn,
             self.factor,
             self.dropout,
+            self.ORT_weight,
+            self.MIT_weight,
         )
         self._send_model_to_given_device()
         self._print_model_size()
diff --git a/pypots/imputation/informer/modules/core.py b/pypots/imputation/informer/modules/core.py
index e6240c63..ef6b3f20 100644
--- a/pypots/imputation/informer/modules/core.py
+++ b/pypots/imputation/informer/modules/core.py
@@ -25,14 +25,18 @@ def __init__(
         d_ffn,
         factor,
         dropout,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
         distil=False,
         activation="relu",
-        output_attention=False,
     ):
         super().__init__()
 
         self.seq_len = n_steps
         self.n_layers = n_layers
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
+
         self.enc_embedding = DataEmbedding(
             n_features * 2,
             d_model,
@@ -46,7 +50,7 @@
                         d_model,
                         d_model // n_heads,
                         d_model // n_heads,
-                        ProbAttention(False, factor, dropout, output_attention),
+                        ProbAttention(False, factor, dropout),
                     ),
                     d_model,
                     d_ffn,
@@ -87,8 +91,11 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
         }
 
         if training:
+            # apply SAITS loss function to Informer on the imputation task
+            ORT_loss = calc_mse(output, X, masks)
+            MIT_loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
             # `loss` is always the item for backward propagating to update the model
-            loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
+            loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
             results["loss"] = loss
 
         return results
diff --git a/pypots/imputation/informer/modules/submodules.py b/pypots/imputation/informer/modules/submodules.py
index 8465feeb..4bb5af4b 100644
--- a/pypots/imputation/informer/modules/submodules.py
+++ b/pypots/imputation/informer/modules/submodules.py
@@ -60,15 +60,13 @@ def __init__(
         self,
         mask_flag=True,
         factor=5,
-        scale=None,
         attention_dropout=0.1,
-        output_attention=False,
+        scale=None,
     ):
         super().__init__()
         self.factor = factor
         self.scale = scale
         self.mask_flag = mask_flag
-        self.output_attention = output_attention
         self.dropout = nn.Dropout(attention_dropout)
 
     def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
@@ -121,14 +119,12 @@ def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
         context_in[
             torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
         ] = torch.matmul(attn, V).type_as(context_in)
-        if self.output_attention:
-            attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
-            attns[
-                torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
-            ] = attn
-            return (context_in, attns)
-        else:
-            return (context_in, None)
+
+        attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
+        attns[
+            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
+        ] = attn
+        return context_in, attns
 
     def forward(
         self,
diff --git a/pypots/imputation/patchtst/model.py b/pypots/imputation/patchtst/model.py
index ec1810fd..a5d8ee0a 100644
--- a/pypots/imputation/patchtst/model.py
+++ b/pypots/imputation/patchtst/model.py
@@ -73,6 +73,12 @@ class PatchTST(BaseNNImputer):
     dropout :
         The dropout rate for the model.
 
+    ORT_weight :
+        The weight for the ORT loss, the same as in SAITS.
+
+    MIT_weight :
+        The weight for the MIT loss, the same as in SAITS.
+
     batch_size :
         The batch size for training and evaluating the model.
 
@@ -128,6 +134,8 @@ def __init__(
         d_ffn: int,
         dropout: float,
         attn_dropout: float,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
         batch_size: int = 32,
         epochs: int = 100,
         patience: int = None,
@@ -169,6 +177,8 @@
         self.d_ffn = d_ffn
         self.dropout = dropout
         self.attn_dropout = attn_dropout
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
 
         # set up the model
         self.model = _PatchTST(
@@ -184,6 +194,8 @@
             self.stride,
             self.dropout,
             self.attn_dropout,
+            self.ORT_weight,
+            self.MIT_weight,
         )
         self._send_model_to_given_device()
         self._print_model_size()
diff --git a/pypots/imputation/patchtst/modules/core.py b/pypots/imputation/patchtst/modules/core.py
index c1fc97c7..e826bea4 100644
--- a/pypots/imputation/patchtst/modules/core.py
+++ b/pypots/imputation/patchtst/modules/core.py
@@ -29,6 +29,8 @@ def __init__(
         stride: int,
         dropout: float,
         attn_dropout: float,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
     ):
         super().__init__()
 
@@ -40,6 +42,8 @@
         self.n_features = n_features
         self.n_layers = n_layers
         self.d_model = d_model
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
 
         self.embedding = nn.Linear(n_features * 2, d_model)
         self.patch_embedding = PatchEmbedding(
@@ -97,8 +101,11 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
         }
 
         if training:
+            # apply SAITS loss function to PatchTST on the imputation task
+            ORT_loss = calc_mse(output, X, masks)
+            MIT_loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
             # `loss` is always the item for backward propagating to update the model
-            loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
+            loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
             results["loss"] = loss
 
         return results
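One behavioral note to close: both weights default to 1, so this PR changes the default objective for all eight models. Previously, training minimized only `calc_mse(output, X_ori, indicating_mask)`, i.e. the MIT term. Based on the loss hunks alone (not verified against a training run), the old objective should be recoverable by passing `ORT_weight=0, MIT_weight=1`; a self-contained check of that reduction:

```python
# Check that ORT_weight=0, MIT_weight=1 reduces the new SAITS loss to the old one.
import torch

def masked_mse(pred, target, mask):
    return torch.sum(((pred - target) ** 2) * mask) / (mask.sum() + 1e-12)

torch.manual_seed(0)
output = torch.randn(2, 5, 3)                               # model output
X_ori = torch.randn(2, 5, 3)                                # fully-observed ground truth
masks = torch.randint(0, 2, (2, 5, 3)).float()              # still-observed positions
indicating_mask = (1 - masks) * torch.randint(0, 2, (2, 5, 3)).float()
X = X_ori * masks                                           # input with held-out values zeroed

old_loss = masked_mse(output, X_ori, indicating_mask)
new_loss = 0 * masked_mse(output, X, masks) + 1 * masked_mse(output, X_ori, indicating_mask)
assert torch.allclose(old_loss, new_loss)
```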