Skip to content

Commit

Permalink
Merge pull request #346 from WenjieDu/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu authored Apr 11, 2024
2 parents eb03a15 + 77c7ab2 commit 985a5c2
Show file tree
Hide file tree
Showing 18 changed files with 178 additions and 46 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,14 @@ mae = calc_mae(imputation, np.nan_to_num(X_ori), indicating_mask) # calculate m
## ❖ Available Algorithms
PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values.
The currently available algorithms of four tasks are cataloged in the following table with four partitions.
The paper references are all listed at the bottom of this readme file. Please refer to them if you want more details.
The paper references are all listed at the bottom of this readme file.

🌟 Since **v0.2**, all neural-network models in PyPOTS has got hyperparameter-optimization support.
This functionality is implemented with the [Microsoft NNI](https://github.com/microsoft/nni) framework.

🔥 Note that Transformer, Crossformer, PatchTST, DLinear, ETSformer, FEDformer, Informer, Autoformer are not proposed as imputation methods in their original papers,
and they cannot accept POTS as input. **To make them applicable on POTS data, we apply the embedding strategy the same as we did in [SAITS paper](https://arxiv.org/pdf/2202.08516).**
and they cannot accept POTS as input. **To make them applicable on POTS data, we apply the embedding strategy and training approach (ORT+MIT)
the same as we did in [SAITS paper](https://arxiv.org/pdf/2202.08516).**

| ***`Imputation`*** | 🚥 | 🚥 | 🚥 |
|:----------------------:|:-----------:|:-----------------------------------------------------------------------------------------------:|:--------:|
Expand Down
17 changes: 13 additions & 4 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,14 @@ Welcome to PyPOTS docs!
-----------------


⦿ `Motivation`: Due to all kinds of reasons like failure of collection sensors, communication error, and unexpected malfunction, missing values are common to see in time series from the real-world environment. This makes partially-observed time series (POTS) a pervasive problem in open-world modeling and prevents advanced data analysis. Although this problem is important, the area of data mining on POTS still lacks a dedicated toolkit. PyPOTS is created to fill in this blank.
⦿ `Motivation`: Due to all kinds of reasons like failure of collection sensors, communication error, and unexpected malfunction, missing values are common to see in time series from the real-world environment.
This makes partially-observed time series (POTS) a pervasive problem in open-world modeling and prevents advanced data analysis.
Although this problem is important, the area of data mining on POTS still lacks a dedicated toolkit. PyPOTS is created to fill in this blank.

⦿ `Mission`: PyPOTS is born to become a handy toolbox that is going to make data mining on POTS easy rather than tedious, to help engineers and researchers focus more on the core problems in their hands rather than on how to deal with the missing parts in their data. PyPOTS will keep integrating classical and the latest state-of-the-art data mining algorithms for partially-observed multivariate time series. For sure, besides various algorithms, PyPOTS is going to have unified APIs together with detailed documentation and interactive examples across algorithms as tutorials.
⦿ `Mission`: PyPOTS is born to become a handy toolbox that is going to make data mining on POTS easy rather than tedious,
to help engineers and researchers focus more on the core problems in their hands rather than on how to deal with the missing parts in their data.
PyPOTS will keep integrating classical and the latest state-of-the-art data mining algorithms for partially-observed multivariate time series.
For sure, besides various algorithms, PyPOTS is going to have unified APIs together with detailed documentation and interactive examples across algorithms as tutorials.

🤗 **Please** star this repo to help others notice PyPOTS if you think it is a useful toolkit.
**Please** properly `cite PyPOTS <https://docs.pypots.com/en/latest/milestones.html#citing-pypots>`_ in your publications
Expand Down Expand Up @@ -164,12 +169,16 @@ Additionally, we present you a usage example of imputing missing values in time

❖ Available Algorithms
^^^^^^^^^^^^^^^^^^^^^^^
PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values. The currently available algorithms of four tasks are cataloged in the following table with four partitions. The paper references are all listed at the bottom of this readme file. Please refer to them if you want more details.

PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values.
The currently available algorithms of four tasks are cataloged in the following table with four partitions. The paper references are all listed `on the reference page </references.html>`_.

🌟 Since **v0.2**, all neural-network models in PyPOTS has got hyperparameter-optimization support.
This functionality is implemented with the `Microsoft NNI <https://github.com/microsoft/nni>`_ framework.

🔥 Note that Transformer, Crossformer, PatchTST, DLinear, ETSformer, FEDformer, Informer, Autoformer are not proposed as imputation methods in their original papers,
and they cannot accept POTS as input. To make them applicable on POTS data, we apply the embedding strategy and training approach (ORT+MIT)
the same as we did in `SAITS paper <https://arxiv.org/pdf/2202.08516)>`_.

============================== ================ ========================================================================================================= ====== =========
Task Type Algorithm Year Reference
============================== ================ ========================================================================================================= ====== =========
Expand Down
12 changes: 12 additions & 0 deletions pypots/imputation/autoformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ class Autoformer(BaseNNImputer):
dropout :
The dropout rate for the model.
ORT_weight :
The weight for the ORT loss, the same as SAITS.
MIT_weight :
The weight for the MIT loss, the same as SAITS.
batch_size :
The batch size for training and evaluating the model.
Expand Down Expand Up @@ -115,6 +121,8 @@ def __init__(
factor: int,
moving_avg_window_size: int,
dropout: float = 0,
ORT_weight: float = 1,
MIT_weight: float = 1,
batch_size: int = 32,
epochs: int = 100,
patience: int = None,
Expand Down Expand Up @@ -144,6 +152,8 @@ def __init__(
self.factor = factor
self.moving_avg_window_size = moving_avg_window_size
self.dropout = dropout
self.ORT_weight = ORT_weight
self.MIT_weight = MIT_weight

# set up the model
self.model = _Autoformer(
Expand All @@ -156,6 +166,8 @@ def __init__(
self.factor,
self.moving_avg_window_size,
self.dropout,
self.ORT_weight,
self.MIT_weight,
)
self._send_model_to_given_device()
self._print_model_size()
Expand Down
13 changes: 10 additions & 3 deletions pypots/imputation/autoformer/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,17 @@ def __init__(
factor,
moving_avg_window_size,
dropout,
ORT_weight: float = 1,
MIT_weight: float = 1,
activation="relu",
output_attention=False,
):
super().__init__()

self.seq_len = n_steps
self.n_layers = n_layers
self.ORT_weight = ORT_weight
self.MIT_weight = MIT_weight

self.enc_embedding = DataEmbedding(
n_features * 2,
d_model,
Expand All @@ -48,7 +52,7 @@ def __init__(
[
AutoformerEncoderLayer(
AutoCorrelationLayer(
AutoCorrelation(False, factor, dropout, output_attention),
AutoCorrelation(factor, dropout),
d_model,
n_heads,
),
Expand Down Expand Up @@ -91,8 +95,11 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

if training:
# apply SAITS loss function to Autoformer on the imputation task
ORT_loss = calc_mse(output, X, masks)
MIT_loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
# `loss` is always the item for backward propagating to update the model
loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
results["loss"] = loss

return results
11 changes: 1 addition & 10 deletions pypots/imputation/autoformer/modules/submodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,11 @@ class AutoCorrelation(nn.Module):

def __init__(
self,
mask_flag=True,
factor=1,
scale=None,
attention_dropout=0.1,
output_attention=False,
):
super().__init__()
self.factor = factor
self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)

def time_delay_agg_training(self, values, corr):
Expand Down Expand Up @@ -165,10 +159,7 @@ def forward(self, queries, keys, values, attn_mask):
values.permute(0, 2, 3, 1).contiguous(), corr
).permute(0, 3, 1, 2)

if self.output_attention:
return (V.contiguous(), corr.permute(0, 3, 1, 2))
else:
return (V.contiguous(), None)
return V.contiguous(), corr.permute(0, 3, 1, 2)


class AutoCorrelationLayer(nn.Module):
Expand Down
12 changes: 12 additions & 0 deletions pypots/imputation/crossformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ class Crossformer(BaseNNImputer):
dropout :
The dropout rate for the model.
ORT_weight :
The weight for the ORT loss, the same as SAITS.
MIT_weight :
The weight for the MIT loss, the same as SAITS.
batch_size :
The batch size for training and evaluating the model.
Expand Down Expand Up @@ -120,6 +126,8 @@ def __init__(
seg_len: int,
win_size: int,
dropout: float = 0,
ORT_weight: float = 1,
MIT_weight: float = 1,
batch_size: int = 32,
epochs: int = 100,
patience: int = None,
Expand Down Expand Up @@ -150,6 +158,8 @@ def __init__(
self.seg_len = seg_len
self.win_size = win_size
self.dropout = dropout
self.ORT_weight = ORT_weight
self.MIT_weight = MIT_weight

# set up the model
self.model = _Crossformer(
Expand All @@ -163,6 +173,8 @@ def __init__(
self.seg_len,
self.win_size,
self.dropout,
self.ORT_weight,
self.MIT_weight,
)
self._send_model_to_given_device()
self._print_model_size()
Expand Down
9 changes: 8 additions & 1 deletion pypots/imputation/crossformer/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@ def __init__(
seg_len,
win_size,
dropout,
ORT_weight: float = 1,
MIT_weight: float = 1,
):
super().__init__()

self.n_features = n_features
self.d_model = d_model
self.ORT_weight = ORT_weight
self.MIT_weight = MIT_weight

# The padding operation to handle invisible sgemnet length
pad_in_len = ceil(1.0 * n_steps / seg_len) * seg_len
Expand Down Expand Up @@ -104,8 +108,11 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

if training:
# apply SAITS loss function to Crossformer on the imputation task
ORT_loss = calc_mse(output, X, masks)
MIT_loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
# `loss` is always the item for backward propagating to update the model
loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
results["loss"] = loss

return results
22 changes: 17 additions & 5 deletions pypots/imputation/dlinear/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ class DLinear(BaseNNImputer):
The dimension of the space in which the time-series data will be embedded and modeled.
It is necessary only for DLinear in the non-individual mode.
ORT_weight :
The weight for the ORT loss, the same as SAITS.
MIT_weight :
The weight for the MIT loss, the same as SAITS.
batch_size :
The batch size for training and evaluating the model.
Expand Down Expand Up @@ -101,6 +107,8 @@ def __init__(
moving_avg_window_size: int,
individual: bool = False,
d_model: Optional[int] = None,
ORT_weight: float = 1,
MIT_weight: float = 1,
batch_size: int = 32,
epochs: int = 100,
patience: int = None,
Expand All @@ -126,14 +134,18 @@ def __init__(
self.moving_avg_window_size = moving_avg_window_size
self.individual = individual
self.d_model = d_model
self.ORT_weight = ORT_weight
self.MIT_weight = MIT_weight

# set up the model
self.model = _DLinear(
n_steps,
n_features,
moving_avg_window_size,
individual,
d_model,
self.n_steps,
self.n_features,
self.moving_avg_window_size,
self.individual,
self.d_model,
self.ORT_weight,
self.MIT_weight,
)
self._send_model_to_given_device()
self._print_model_size()
Expand Down
12 changes: 10 additions & 2 deletions pypots/imputation/dlinear/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,18 @@ def __init__(
moving_avg_window_size: int,
individual: bool = False,
d_model: Optional[int] = None,
ORT_weight: float = 1,
MIT_weight: float = 1,
):
super().__init__()

self.n_steps = n_steps
self.n_features = n_features
self.series_decomp = SeriesDecompositionBlock(moving_avg_window_size)
self.individual = individual
self.ORT_weight = ORT_weight
self.MIT_weight = MIT_weight

self.series_decomp = SeriesDecompositionBlock(moving_avg_window_size)

if individual:
# create linear layers for each feature individually
Expand Down Expand Up @@ -119,8 +124,11 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

if training:
# apply SAITS loss function to DLinear on the imputation task
ORT_loss = calc_mse(output, X, masks)
MIT_loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
# `loss` is always the item for backward propagating to update the model
loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
results["loss"] = loss

return results
12 changes: 12 additions & 0 deletions pypots/imputation/etsformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ class ETSformer(BaseNNImputer):
dropout :
The dropout rate for the model.
ORT_weight :
The weight for the ORT loss, the same as SAITS.
MIT_weight :
The weight for the MIT loss, the same as SAITS.
batch_size :
The batch size for training and evaluating the model.
Expand Down Expand Up @@ -115,6 +121,8 @@ def __init__(
d_ffn,
top_k,
dropout: float = 0,
ORT_weight: float = 1,
MIT_weight: float = 1,
batch_size: int = 32,
epochs: int = 100,
patience: int = None,
Expand Down Expand Up @@ -144,6 +152,8 @@ def __init__(
self.d_ffn = d_ffn
self.dropout = dropout
self.top_k = top_k
self.ORT_weight = ORT_weight
self.MIT_weight = MIT_weight

# set up the model
self.model = _ETSformer(
Expand All @@ -156,6 +166,8 @@ def __init__(
self.d_ffn,
self.dropout,
self.top_k,
self.ORT_weight,
self.MIT_weight,
)
self._send_model_to_given_device()
self._print_model_size()
Expand Down
9 changes: 8 additions & 1 deletion pypots/imputation/etsformer/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@ def __init__(
d_ffn,
dropout,
top_k,
ORT_weight: float = 1,
MIT_weight: float = 1,
activation="sigmoid",
):
super().__init__()

self.n_steps = n_steps
self.ORT_weight = ORT_weight
self.MIT_weight = MIT_weight

self.enc_embedding = DataEmbedding(
n_features * 2,
Expand Down Expand Up @@ -98,8 +102,11 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

if training:
# apply SAITS loss function to ETSformer on the imputation task
ORT_loss = calc_mse(output, X, masks)
MIT_loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
# `loss` is always the item for backward propagating to update the model
loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"])
loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
results["loss"] = loss

return results
Loading

0 comments on commit 985a5c2

Please sign in to comment.