Skip to content

Commit

Permalink
Merge pull request #504 from WenjieDu/dev
Browse files Browse the repository at this point in the history
Add ModernTCN
  • Loading branch information
WenjieDu authored Sep 4, 2024
2 parents 2d702aa + 66da59c commit 2e8063a
Show file tree
Hide file tree
Showing 14 changed files with 1,222 additions and 4 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ The paper references and links are all listed at the bottom of this file.
| LLM | <a href="https://time-series.ai"><img src="https://time-series.ai/static/figs/robot.svg" width="26px"> Time-Series.AI</a> [^36] |||||| `Later in 2024` |
| Neural Net | TimeMixer[^37] || | | | | `2024 - ICLR` |
| Neural Net | iTransformer🧑‍🔧[^24] || | | | | `2024 - ICLR` |
| Neural Net | ModernTCN[^38] || | | | | `2024 - ICLR` |
| Neural Net | ImputeFormer🧑‍🔧[^34] || | | | | `2024 - KDD` |
| Neural Net | SAITS[^1] || | | | | `2023 - ESWA` |
| Neural Net | FreTS🧑‍🔧[^23] || | | | | `2023 - NeurIPS` |
Expand Down Expand Up @@ -397,3 +398,4 @@ PyPOTS community is open, transparent, and surely friendly. Let's work together
Hard to perform multitask learning with your time series? Not problems no longer. We'll open application for public beta test recently ;-) Follow us, and stay tuned!
<a href="https://time-series.ai"><img src="https://time-series.ai/static/figs/robot.svg" width="20px" align="center"> Time-Series.AI</a>
[^37]: Wang, S., Wu, H., Shi, X., Hu, T., Luo, H., Ma, L., ... & ZHOU, J. (2024). [TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting](https://openreview.net/forum?id=7oLshfEIC2). *ICLR 2024*
[^38]: Luo, D., & Wang X. (2024). [ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis](https://openreview.net/forum?id=vpJMJerXHU). *ICLR 2024*
4 changes: 3 additions & 1 deletion README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ PyPOTS当前支持多变量POTS数据的插补,预测,分类,聚类以及
| LLM | <a href="https://time-series.ai"><img src="https://time-series.ai/static/figs/robot.svg" width="26px"> Time-Series.AI</a> [^36] |||||| `Later in 2024` |
| Neural Net | TimeMixer[^37] || | | | | `2024 - ICLR` |
| Neural Net | iTransformer🧑‍🔧[^24] || | | | | `2024 - ICLR` |
| Neural Net | ModernTCN[^38] || | | | | `2024 - ICLR` |
| Neural Net | ImputeFormer🧑‍🔧[^34] || | | | | `2024 - KDD` |
| Neural Net | SAITS[^1] || | | | | `2023 - ESWA` |
| Neural Net | FreTS🧑‍🔧[^23] || | | | | `2023 - NeurIPS` |
Expand Down Expand Up @@ -365,4 +366,5 @@ PyPOTS社区是一个开放、透明、友好的社区,让我们共同努力
[^35]: Bai, S., Kolter, J. Z., & Koltun, V. (2018). [An empirical evaluation of generic convolutional and recurrent networks for sequence modeling](https://arxiv.org/abs/1803.01271). *arXiv 2018*.
[^36]: Gungnir项目,世界上第一个时间序列多任务大模型,将很快与大家见面。🚀 数据集存在缺少值且样本长短不一?多任务建模场景困难?都不再是问题,让我们的大模型来帮你解决。我们将在近期开放公测申请 ;-) 关注我们,敬请期待!
<a href="https://time-series.ai"><img src="https://time-series.ai/static/figs/robot.svg" width="20px" align="center"> Time-Series.AI</a>
[^37]: Wang, S., Wu, H., Shi, X., Hu, T., Luo, H., Ma, L., ... & ZHOU, J. (2024). [TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting](https://openreview.net/forum?id=7oLshfEIC2). *ICLR 2024*
[^37]: Wang, S., Wu, H., Shi, X., Hu, T., Luo, H., Ma, L., ... & ZHOU, J. (2024). [TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting](https://openreview.net/forum?id=7oLshfEIC2). *ICLR 2024*
[^38]: Luo, D., & Wang X. (2024). [ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis](https://openreview.net/forum?id=vpJMJerXHU). *ICLR 2024*
2 changes: 2 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ The paper references are all listed at the bottom of this readme file.
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | iTransformer🧑‍🔧 :cite:`liu2024itransformer` || | | | | ``2024 - ICLR`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | ModernTCN :cite:`luo2024moderntcn` || | | | | ``2024 - ICLR`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | ImputeFormer :cite:`nie2024imputeformer` || | | | | ``2024 - KDD`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | SAITS :cite:`du2023SAITS` || | | | | ``2023 - ESWA`` |
Expand Down
9 changes: 9 additions & 0 deletions docs/pypots.imputation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ pypots.imputation.timemixer
:show-inheritance:
:inherited-members:

pypots.imputation.moderntcn
------------------------------------

.. automodule:: pypots.imputation.moderntcn
:members:
:undoc-members:
:show-inheritance:
:inherited-members:

pypots.imputation.imputeformer
------------------------------------

Expand Down
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .stemgnn import StemGNN
from .imputeformer import ImputeFormer
from .timemixer import TimeMixer
from .moderntcn import ModernTCN

# naive imputation methods
from .locf import LOCF
Expand Down Expand Up @@ -77,6 +78,7 @@
"StemGNN",
"ImputeFormer",
"TimeMixer",
"ModernTCN",
# naive imputation methods
"LOCF",
"Mean",
Expand Down
24 changes: 24 additions & 0 deletions pypots/imputation/moderntcn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
The package of the partially-observed time-series imputation model ModernTCN.
Refer to the paper
`Donghao Luo, and Xue Wang.
ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis.
In The Twelfth International Conference on Learning Representations. 2024.
<https://openreview.net/pdf?id=vpJMJerXHU>`_
Notes
-----
This implementation is inspired by the official one https://github.com/luodhhh/ModernTCN
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import ModernTCN

__all__ = [
"ModernTCN",
]
95 changes: 95 additions & 0 deletions pypots/imputation/moderntcn/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
The core wrapper assembles the submodules of ModernTCN imputation model
and takes over the forward progress of the algorithm.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch.nn as nn

from ...nn.functional import nonstationary_norm, nonstationary_denorm
from ...nn.modules.moderntcn import BackboneModernTCN
from ...nn.modules.patchtst.layers import FlattenHead
from ...utils.metrics import calc_mse


class _ModernTCN(nn.Module):
def __init__(
self,
n_steps,
n_features,
patch_size,
patch_stride,
downsampling_ratio,
ffn_ratio,
num_blocks: list,
large_size: list,
small_size: list,
dims: list,
small_kernel_merged: bool = False,
backbone_dropout: float = 0.1,
head_dropout: float = 0.1,
use_multi_scale: bool = True,
individual: bool = False,
apply_nonstationary_norm: bool = False,
):
super().__init__()

self.apply_nonstationary_norm = apply_nonstationary_norm

self.backbone = BackboneModernTCN(
n_steps,
n_features,
n_features,
patch_size,
patch_stride,
downsampling_ratio,
ffn_ratio,
num_blocks,
large_size,
small_size,
dims,
small_kernel_merged,
backbone_dropout,
head_dropout,
use_multi_scale,
individual,
)

# for the imputation task, the output dim is the same as input dim
self.projection = FlattenHead(
self.backbone.head_nf,
n_steps,
n_features,
head_dropout,
individual,
)

def forward(self, inputs: dict, training: bool = True) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

if self.apply_nonstationary_norm:
# Normalization from Non-stationary Transformer
X, means, stdev = nonstationary_norm(X, missing_mask)

in_X = X.permute(0, 2, 1)
in_X = self.backbone(in_X)
reconstruction = self.projection(in_X)
reconstruction = reconstruction.permute(0, 2, 1)

if self.apply_nonstationary_norm:
# De-Normalization from Non-stationary Transformer
reconstruction = nonstationary_denorm(reconstruction, means, stdev)

imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
results = {
"imputed_data": imputed_data,
}

# if in training mode, return results with losses
if training:
loss = calc_mse(reconstruction, inputs["X_ori"], inputs["indicating_mask"])
results["loss"] = loss

return results
24 changes: 24 additions & 0 deletions pypots/imputation/moderntcn/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Dataset class for ModernTCN.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForModernTCN(DatasetForSAITS):
"""Actually ModernTCN uses the same data strategy as SAITS, needs MIT for training."""

def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_y: bool,
file_type: str = "hdf5",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_y, file_type, rate)
Loading

0 comments on commit 2e8063a

Please sign in to comment.