Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PatchTST as an imputation model #323

Merged
merged 2 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .transformer import Transformer
from .timesnet import TimesNet
from .autoformer import Autoformer
from .patchtst import PatchTST
from .usgan import USGAN

# naive imputation methods
Expand All @@ -26,6 +27,7 @@
"SAITS",
"Transformer",
"TimesNet",
"PatchTST",
"Autoformer",
"BRITS",
"MRNN",
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/autoformer/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Dataset class for TimesNet.
Dataset class for Autoformer.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
Expand Down
12 changes: 6 additions & 6 deletions pypots/imputation/autoformer/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
The implementation of Transformer for the partially-observed time-series imputation task.
The implementation of Autoformer for the partially-observed time-series imputation task.

Refer to the paper "Wu, H., Xu, J., Wang, J., & Long, M. (2021).
Autoformer: Decomposition transformers with auto-correlation for long-term series forecasting. NeurIPS 2021.".
Expand Down Expand Up @@ -31,7 +31,7 @@

class Autoformer(BaseNNImputer):
"""The PyTorch implementation of the Autoformer model.
TimesNet is originally proposed by Wu et al. in :cite:`wu2021autoformer`.
Autoformer is originally proposed by Wu et al. in :cite:`wu2021autoformer`.

Parameters
----------
Expand All @@ -56,7 +56,7 @@ class Autoformer(BaseNNImputer):
factor :
The factor of the auto correlation mechanism for the Autoformer model.

moving_avg_kernel_size :
moving_avg_window_size :
The window size of moving average.

dropout :
Expand Down Expand Up @@ -120,7 +120,7 @@ def __init__(
d_model: int,
d_ffn: int,
factor: int,
moving_avg_kernel_size: int,
moving_avg_window_size: int,
dropout: float = 0,
batch_size: int = 32,
epochs: int = 100,
Expand Down Expand Up @@ -149,7 +149,7 @@ def __init__(
self.d_model = d_model
self.d_ffn = d_ffn
self.factor = factor
self.moving_avg_kernel_size = moving_avg_kernel_size
self.moving_avg_window_size = moving_avg_window_size
self.dropout = dropout

# set up the model
Expand All @@ -161,7 +161,7 @@ def __init__(
self.d_model,
self.d_ffn,
self.factor,
self.moving_avg_kernel_size,
self.moving_avg_window_size,
self.dropout,
)
self._send_model_to_given_device()
Expand Down
6 changes: 3 additions & 3 deletions pypots/imputation/autoformer/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
d_model,
d_ffn,
factor,
moving_avg_kernel_size,
moving_avg_window_size,
dropout,
activation="relu",
output_attention=False,
Expand All @@ -38,7 +38,7 @@ def __init__(

self.seq_len = n_steps
self.n_layers = n_layers
self.series_decomp = SeriesDecompositionBlock(moving_avg_kernel_size)
self.series_decomp = SeriesDecompositionBlock(moving_avg_window_size)
self.enc_embedding = DataEmbedding_wo_Pos(
n_features,
d_model,
Expand All @@ -54,7 +54,7 @@ def __init__(
),
d_model,
d_ffn,
moving_avg_kernel_size,
moving_avg_window_size,
dropout,
activation,
)
Expand Down
17 changes: 17 additions & 0 deletions pypots/imputation/patchtst/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
The package of the partially-observed time-series imputation model PatchTST.

Refer to the paper "Wu, H., Xu, J., Wang, J., & Long, M. (2021).
PatchTST: Decomposition transformers with auto-correlation for long-term series forecasting. NeurIPS 2021.".

"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import PatchTST

__all__ = [
"PatchTST",
]
24 changes: 24 additions & 0 deletions pypots/imputation/patchtst/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Dataset class for PatchTST.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForPatchTST(DatasetForSAITS):
"""Actually PatchTST uses the same data strategy as SAITS, needs MIT for training."""

def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_labels: bool,
file_type: str = "h5py",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_labels, file_type, rate)
Loading
Loading