From 7623272ba1206049f2c8c2eb678b8c4c0b96391f Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 21 Aug 2024 15:51:27 +0800 Subject: [PATCH 1/5] feat: add ModernTCN modules; --- pypots/nn/modules/moderntcn/__init__.py | 24 ++ pypots/nn/modules/moderntcn/backbone.py | 184 +++++++++++++ pypots/nn/modules/moderntcn/layers.py | 328 ++++++++++++++++++++++++ 3 files changed, 536 insertions(+) create mode 100644 pypots/nn/modules/moderntcn/__init__.py create mode 100644 pypots/nn/modules/moderntcn/backbone.py create mode 100644 pypots/nn/modules/moderntcn/layers.py diff --git a/pypots/nn/modules/moderntcn/__init__.py b/pypots/nn/modules/moderntcn/__init__.py new file mode 100644 index 00000000..bfca4e48 --- /dev/null +++ b/pypots/nn/modules/moderntcn/__init__.py @@ -0,0 +1,24 @@ +""" +The package including the modules of ModernTCN. + +Refer to the paper +`Donghao Luo, and Xue Wang. +ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis. +In The Twelfth International Conference on Learning Representations. 2024. +`_ + +Notes +----- +This implementation is inspired by the official one https://github.com/luodhhh/ModernTCN + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .backbone import BackboneModernTCN + +__all__ = [ + "BackboneModernTCN", +] diff --git a/pypots/nn/modules/moderntcn/backbone.py b/pypots/nn/modules/moderntcn/backbone.py new file mode 100644 index 00000000..a9e3b388 --- /dev/null +++ b/pypots/nn/modules/moderntcn/backbone.py @@ -0,0 +1,184 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch +import torch.nn as nn + +from .layers import Stage +from ..patchtst.layers import FlattenHead + + +class BackboneModernTCN(nn.Module): + def __init__( + self, + n_steps, + n_features, + n_predict_features, + patch_size, + patch_stride, + downsampling_ratio, + ffn_ratio, + num_blocks: list, + large_size: list, + small_size: list, + dims: list, + small_kernel_merged: bool = False, + backbone_dropout: float = 0.1, + head_dropout: float = 0.1, + use_multi_scale: bool = True, + individual: bool = False, + freq: str = "h", + ): + super().__init__() + + # stem layer & down sampling layers + self.downsample_layers = nn.ModuleList() + stem = nn.Linear(patch_size, dims[0]) + self.downsample_layers.append(stem) + + self.num_stage = len(num_blocks) + if self.num_stage > 1: + for i in range(self.num_stage - 1): + downsample_layer = nn.Sequential( + nn.BatchNorm1d(dims[i]), + nn.Conv1d( + dims[i], + dims[i + 1], + kernel_size=downsampling_ratio, + stride=downsampling_ratio, + ), + ) + self.downsample_layers.append(downsample_layer) + + self.patch_size = patch_size + self.patch_stride = patch_stride + self.downsample_ratio = downsampling_ratio + + if freq == "h": + time_feature_num = 4 + elif freq == "t": + time_feature_num = 5 + else: + raise NotImplementedError("time_feature_num should be 4 or 5") + + self.te_patch = nn.Sequential( + nn.Conv1d( + time_feature_num, + time_feature_num, + kernel_size=patch_size, + stride=patch_stride, + groups=time_feature_num, + ), + nn.Conv1d(time_feature_num, dims[0], kernel_size=1, stride=1, groups=1), + nn.BatchNorm1d(dims[0]), + ) + + # backbone + self.stages = nn.ModuleList() + for stage_idx in range(self.num_stage): + layer = Stage( + ffn_ratio, + num_blocks[stage_idx], + large_size[stage_idx], + small_size[stage_idx], + dmodel=dims[stage_idx], + nvars=n_features, + small_kernel_merged=small_kernel_merged, + drop=backbone_dropout, + ) + self.stages.append(layer) + + # Multi scale fusing + self.use_multi_scale = use_multi_scale + self.up_sample_ratio = downsampling_ratio + + self.lat_layer = nn.ModuleList() + self.smooth_layer = nn.ModuleList() + self.up_sample_conv = nn.ModuleList() + for i in range(self.num_stage): + align_dim = dims[-1] + lat = nn.Conv1d(dims[i], align_dim, kernel_size=1, stride=1) + self.lat_layer.append(lat) + smooth = nn.Conv1d(align_dim, align_dim, kernel_size=3, stride=1, padding=1) + self.smooth_layer.append(smooth) + up_conv = nn.Sequential( + nn.ConvTranspose1d( + align_dim, + align_dim, + kernel_size=self.up_sample_ratio, + stride=self.up_sample_ratio, + ), + nn.BatchNorm1d(align_dim), + ) + self.up_sample_conv.append(up_conv) + + # head + patch_num = n_steps // patch_stride + + self.n_features = n_features + self.individual = individual + d_model = dims[self.num_stage - 1] + + if use_multi_scale: + self.head_nf = d_model * patch_num + self.head = FlattenHead( + self.head_nf, + n_predict_features, + n_features, + head_dropout, + individual, + ) + else: + if patch_num % pow(downsampling_ratio, (self.num_stage - 1)) == 0: + self.head_nf = ( + d_model * patch_num // pow(downsampling_ratio, (self.num_stage - 1)) + ) + else: + self.head_nf = d_model * ( + patch_num // pow(downsampling_ratio, (self.num_stage - 1)) + 1 + ) + + self.head = FlattenHead( + self.head_nf, + n_predict_features, + n_features, + head_dropout, + individual, + ) + + def structural_reparam(self): + for m in self.modules(): + if hasattr(m, "merge_kernel"): + m.merge_kernel() + + def forward(self, x): + x = x.unsqueeze(-2) + + for i in range(self.num_stage): + B, M, D, N = x.shape + x = x.reshape(B * M, D, N) + + if i == 0: + if self.patch_size != self.patch_stride: + pad_len = self.patch_size - self.patch_stride + pad = x[:, :, -1:].repeat(1, 1, pad_len) + x = torch.cat([x, pad], dim=-1) + x = x.reshape(B, M, 1, -1).squeeze(-2) + x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride) + x = self.downsample_layers[i](x) + x = x.permute(0, 1, 3, 2) + + else: + if N % self.downsample_ratio != 0: + pad_len = self.downsample_ratio - (N % self.downsample_ratio) + x = torch.cat([x, x[:, :, -pad_len:]], dim=-1) + x = self.downsample_layers[i](x) + _, D_, N_ = x.shape + x = x.reshape(B, M, D_, N_) + + x = self.stages[i](x) + return x diff --git a/pypots/nn/modules/moderntcn/layers.py b/pypots/nn/modules/moderntcn/layers.py new file mode 100644 index 00000000..b7c21058 --- /dev/null +++ b/pypots/nn/modules/moderntcn/layers.py @@ -0,0 +1,328 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch +from torch import nn + + +def get_conv1d( + in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias +): + return nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + +def get_bn(channels): + return nn.BatchNorm1d(channels) + + +def conv_bn( + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups, + dilation=1, + bias=False, +): + if padding is None: + padding = kernel_size // 2 + result = nn.Sequential() + result.add_module( + "conv", + get_conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ), + ) + result.add_module("bn", get_bn(out_channels)) + return result + + +def fuse_bn(conv, bn): + kernel = conv.weight + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-6): + super().__init__() + self.norm = nn.LayerNorm(channels, eps=eps) + + def forward(self, x): + B, M, D, N = x.shape + x = x.permute(0, 1, 3, 2) + x = x.reshape(B * M, N, D) + x = self.norm(x) + x = x.reshape(B, M, N, D) + x = x.permute(0, 1, 3, 2) + return x + + +class ReparamLargeKernelConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + groups, + small_kernel, + small_kernel_merged=False, + nvars=7, + ): + super().__init__() + self.kernel_size = kernel_size + self.small_kernel = small_kernel + # We assume the conv does not change the feature map size, so padding = k//2. + # Otherwise, you may configure padding as you wish, and change the padding of small_conv accordingly. + padding = kernel_size // 2 + if small_kernel_merged: + self.lkb_reparam = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=1, + groups=groups, + bias=True, + ) + else: + self.lkb_origin = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=1, + groups=groups, + bias=False, + ) + if small_kernel is not None: + assert ( + small_kernel <= kernel_size + ), "The kernel size for re-param cannot be larger than the large kernel!" + self.small_conv = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=small_kernel, + stride=stride, + padding=small_kernel // 2, + groups=groups, + dilation=1, + bias=False, + ) + + def forward(self, inputs): + if hasattr(self, "lkb_reparam"): + out = self.lkb_reparam(inputs) + else: + out = self.lkb_origin(inputs) + if hasattr(self, "small_conv"): + out += self.small_conv(inputs) + return out + + def PaddingTwoEdge1d(self, x, pad_length_left, pad_length_right, pad_values=0): + D_out, D_in, ks = x.shape + if pad_values == 0: + pad_left = torch.zeros(D_out, D_in, pad_length_left) + pad_right = torch.zeros(D_out, D_in, pad_length_right) + else: + pad_left = torch.ones(D_out, D_in, pad_length_left) * pad_values + pad_right = torch.ones(D_out, D_in, pad_length_right) * pad_values + x = torch.cat([pad_left, x], dims=-1) + x = torch.cat([x, pad_right], dims=-1) + return x + + def get_equivalent_kernel_bias(self): + eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn) + if hasattr(self, "small_conv"): + small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn) + eq_b += small_b + eq_k += self.PaddingTwoEdge1d( + small_k, + (self.kernel_size - self.small_kernel) // 2, + (self.kernel_size - self.small_kernel) // 2, + 0, + ) + return eq_k, eq_b + + def merge_kernel(self): + eq_k, eq_b = self.get_equivalent_kernel_bias() + self.lkb_reparam = nn.Conv1d( + in_channels=self.lkb_origin.conv.in_channels, + out_channels=self.lkb_origin.conv.out_channels, + kernel_size=self.lkb_origin.conv.kernel_size, + stride=self.lkb_origin.conv.stride, + padding=self.lkb_origin.conv.padding, + dilation=self.lkb_origin.conv.dilation, + groups=self.lkb_origin.conv.groups, + bias=True, + ) + self.lkb_reparam.weight.data = eq_k + self.lkb_reparam.bias.data = eq_b + self.__delattr__("lkb_origin") + if hasattr(self, "small_conv"): + self.__delattr__("small_conv") + + +class Block(nn.Module): + def __init__( + self, + large_size, + small_size, + dmodel, + dff, + nvars, + small_kernel_merged=False, + drop=0.1, + ): + super().__init__() + self.dw = ReparamLargeKernelConv( + in_channels=nvars * dmodel, + out_channels=nvars * dmodel, + kernel_size=large_size, + stride=1, + groups=nvars * dmodel, + small_kernel=small_size, + small_kernel_merged=small_kernel_merged, + nvars=nvars, + ) + self.norm = nn.BatchNorm1d(dmodel) + + # convffn1 + self.ffn1pw1 = nn.Conv1d( + in_channels=nvars * dmodel, + out_channels=nvars * dff, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=nvars, + ) + self.ffn1act = nn.GELU() + self.ffn1pw2 = nn.Conv1d( + in_channels=nvars * dff, + out_channels=nvars * dmodel, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=nvars, + ) + self.ffn1drop1 = nn.Dropout(drop) + self.ffn1drop2 = nn.Dropout(drop) + + # convffn2 + self.ffn2pw1 = nn.Conv1d( + in_channels=nvars * dmodel, + out_channels=nvars * dff, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=dmodel, + ) + self.ffn2act = nn.GELU() + self.ffn2pw2 = nn.Conv1d( + in_channels=nvars * dff, + out_channels=nvars * dmodel, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=dmodel, + ) + self.ffn2drop1 = nn.Dropout(drop) + self.ffn2drop2 = nn.Dropout(drop) + + self.ffn_ratio = dff // dmodel + + def forward(self, x): + input = x + B, M, D, N = x.shape + x = x.reshape(B, M * D, N) + x = self.dw(x) + x = x.reshape(B, M, D, N) + x = x.reshape(B * M, D, N) + x = self.norm(x) + x = x.reshape(B, M, D, N) + x = x.reshape(B, M * D, N) + + x = self.ffn1drop1(self.ffn1pw1(x)) + x = self.ffn1act(x) + x = self.ffn1drop2(self.ffn1pw2(x)) + x = x.reshape(B, M, D, N) + + x = x.permute(0, 2, 1, 3) + x = x.reshape(B, D * M, N) + x = self.ffn2drop1(self.ffn2pw1(x)) + x = self.ffn2act(x) + x = self.ffn2drop2(self.ffn2pw2(x)) + x = x.reshape(B, D, M, N) + x = x.permute(0, 2, 1, 3) + + x = input + x + return x + + +class Stage(nn.Module): + def __init__( + self, + ffn_ratio, + num_blocks, + large_size, + small_size, + dmodel, + nvars, + small_kernel_merged=False, + drop=0.1, + ): + super().__init__() + d_ffn = dmodel * ffn_ratio + blks = [] + for i in range(num_blocks): + blk = Block( + large_size=large_size, + small_size=small_size, + dmodel=dmodel, + dff=d_ffn, + nvars=nvars, + small_kernel_merged=small_kernel_merged, + drop=drop, + ) + blks.append(blk) + self.blocks = nn.ModuleList(blks) + + def forward(self, x): + for blk in self.blocks: + x = blk(x) + + return x From e45424765807def433721291b3117396191bb801 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 22 Aug 2024 11:17:20 +0800 Subject: [PATCH 2/5] feat: add FlattenHead; --- pypots/nn/modules/patchtst/layers.py | 49 ++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 3 deletions(-) diff --git a/pypots/nn/modules/patchtst/layers.py b/pypots/nn/modules/patchtst/layers.py index 083c368a..3990954b 100644 --- a/pypots/nn/modules/patchtst/layers.py +++ b/pypots/nn/modules/patchtst/layers.py @@ -106,13 +106,13 @@ def __init__( head_dim = d_model * n_patches self.individual = individual - self.n_vars = n_features + self.n_features = n_features if self.individual: self.linears = nn.ModuleList() self.dropouts = nn.ModuleList() self.flattens = nn.ModuleList() - for i in range(self.n_vars): + for i in range(self.n_features): self.flattens.append(nn.Flatten(start_dim=-2)) self.linears.append(nn.Linear(head_dim, n_steps_forecast)) self.dropouts.append(nn.Dropout(head_dropout)) @@ -128,7 +128,7 @@ def forward(self, x): """ if self.individual: x_out = [] - for i in range(self.n_vars): + for i in range(self.n_features): z = self.flattens[i](x[:, i, :, :]) # z: [bs x d_model * num_patch] z = self.linears[i](z) # z: [bs x forecast_len] z = self.dropouts[i](z) @@ -139,3 +139,46 @@ def forward(self, x): x = self.dropout(x) x = self.linear(x) # x: [bs x nvars x forecast_len] return x.transpose(2, 1) # [bs x forecast_len x nvars] + + +class FlattenHead(nn.Module): + def __init__( + self, + d_input, + d_output, + n_features, + head_dropout=0, + individual=False, + ): + super().__init__() + + self.individual = individual + self.n_features = n_features + + if self.individual: + self.linears = nn.ModuleList() + self.dropouts = nn.ModuleList() + self.flattens = nn.ModuleList() + for i in range(self.n_features): + self.flattens.append(nn.Flatten(start_dim=-2)) + self.linears.append(nn.Linear(d_input, d_output)) + self.dropouts.append(nn.Dropout(head_dropout)) + else: + self.flatten = nn.Flatten(start_dim=-2) + self.linear = nn.Linear(d_input, d_output) + self.dropout = nn.Dropout(head_dropout) + + def forward(self, x): + if self.individual: + x_out = [] + for i in range(self.n_features): + z = self.flattens[i](x[:, i, :, :]) # z: [bs x d_model * patch_num] + z = self.linears[i](z) # z: [bs x target_window] + z = self.dropouts[i](z) + x_out.append(z) + x = torch.stack(x_out, dim=1) # x: [bs x nvars x target_window] + else: + x = self.flatten(x) + x = self.linear(x) + x = self.dropout(x) + return x From cdfe6dc9e2298dc3a4059e6c784f6ed9c7d7f983 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 22 Aug 2024 15:50:02 +0800 Subject: [PATCH 3/5] feat: implement ModernTCN as an imputation model; --- pypots/imputation/moderntcn/__init__.py | 24 ++ pypots/imputation/moderntcn/core.py | 95 +++++++ pypots/imputation/moderntcn/data.py | 24 ++ pypots/imputation/moderntcn/model.py | 342 ++++++++++++++++++++++++ 4 files changed, 485 insertions(+) create mode 100644 pypots/imputation/moderntcn/__init__.py create mode 100644 pypots/imputation/moderntcn/core.py create mode 100644 pypots/imputation/moderntcn/data.py create mode 100644 pypots/imputation/moderntcn/model.py diff --git a/pypots/imputation/moderntcn/__init__.py b/pypots/imputation/moderntcn/__init__.py new file mode 100644 index 00000000..f82da16b --- /dev/null +++ b/pypots/imputation/moderntcn/__init__.py @@ -0,0 +1,24 @@ +""" +The package of the partially-observed time-series imputation model ModernTCN. + +Refer to the paper +`Donghao Luo, and Xue Wang. +ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis. +In The Twelfth International Conference on Learning Representations. 2024. +`_ + +Notes +----- +This implementation is inspired by the official one https://github.com/luodhhh/ModernTCN + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .model import ModernTCN + +__all__ = [ + "ModernTCN", +] diff --git a/pypots/imputation/moderntcn/core.py b/pypots/imputation/moderntcn/core.py new file mode 100644 index 00000000..3ca8e8f5 --- /dev/null +++ b/pypots/imputation/moderntcn/core.py @@ -0,0 +1,95 @@ +""" +The core wrapper assembles the submodules of ModernTCN imputation model +and takes over the forward progress of the algorithm. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch.nn as nn + +from ...nn.functional import nonstationary_norm, nonstationary_denorm +from ...nn.modules.moderntcn import BackboneModernTCN +from ...nn.modules.patchtst.layers import FlattenHead +from ...utils.metrics import calc_mse + + +class _ModernTCN(nn.Module): + def __init__( + self, + n_steps, + n_features, + patch_size, + patch_stride, + downsampling_ratio, + ffn_ratio, + num_blocks: list, + large_size: list, + small_size: list, + dims: list, + small_kernel_merged: bool = False, + backbone_dropout: float = 0.1, + head_dropout: float = 0.1, + use_multi_scale: bool = True, + individual: bool = False, + apply_nonstationary_norm: bool = False, + ): + super().__init__() + + self.apply_nonstationary_norm = apply_nonstationary_norm + + self.backbone = BackboneModernTCN( + n_steps, + n_features, + n_features, + patch_size, + patch_stride, + downsampling_ratio, + ffn_ratio, + num_blocks, + large_size, + small_size, + dims, + small_kernel_merged, + backbone_dropout, + head_dropout, + use_multi_scale, + individual, + ) + + # for the imputation task, the output dim is the same as input dim + self.projection = FlattenHead( + self.backbone.head_nf, + n_steps, + n_features, + head_dropout, + individual, + ) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, missing_mask = inputs["X"], inputs["missing_mask"] + + if self.apply_nonstationary_norm: + # Normalization from Non-stationary Transformer + X, means, stdev = nonstationary_norm(X, missing_mask) + + in_X = X.permute(0, 2, 1) + in_X = self.backbone(in_X) + reconstruction = self.projection(in_X) + reconstruction = reconstruction.permute(0, 2, 1) + + if self.apply_nonstationary_norm: + # De-Normalization from Non-stationary Transformer + reconstruction = nonstationary_denorm(reconstruction, means, stdev) + + imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction + results = { + "imputed_data": imputed_data, + } + + # if in training mode, return results with losses + if training: + loss = calc_mse(reconstruction, inputs["X_ori"], inputs["indicating_mask"]) + results["loss"] = loss + + return results diff --git a/pypots/imputation/moderntcn/data.py b/pypots/imputation/moderntcn/data.py new file mode 100644 index 00000000..c296728a --- /dev/null +++ b/pypots/imputation/moderntcn/data.py @@ -0,0 +1,24 @@ +""" +Dataset class for ModernTCN. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ..saits.data import DatasetForSAITS + + +class DatasetForModernTCN(DatasetForSAITS): + """Actually ModernTCN uses the same data strategy as SAITS, needs MIT for training.""" + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_y: bool, + file_type: str = "hdf5", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_y, file_type, rate) diff --git a/pypots/imputation/moderntcn/model.py b/pypots/imputation/moderntcn/model.py new file mode 100644 index 00000000..2efb3fed --- /dev/null +++ b/pypots/imputation/moderntcn/model.py @@ -0,0 +1,342 @@ +""" +The implementation of ModernTCN for the partially-observed time-series imputation task. + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .core import _ModernTCN +from .data import DatasetForModernTCN +from ..base import BaseNNImputer +from ...data.checking import key_in_data_set +from ...data.dataset import BaseDataset +from ...optim.adam import Adam +from ...optim.base import Optimizer + + +class ModernTCN(BaseNNImputer): + """The PyTorch implementation of the ModernTCN model. + ModernTCN is originally proposed by Luo et al. in :cite:`luo2024moderntcn`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + patch_size : + The size of the patch for the patching mechanism. + + patch_stride : + The stride for the patching mechanism. + + downsampling_ratio : + The downsampling ratio for the downsampling mechanism. + + ffn_ratio : + The ratio for the feed-forward neural network in the model. + + num_blocks : + The number of blocks for the model. It should be a list of integers. + + large_size : + The size of the large kernel. It should be a list of odd integers. + + small_size : + The size of the small kernel. It should be a list of odd integers. + + dims : + The dimensions for the model. It should be a list of integers. + + small_kernel_merged : + Whether the small kernel is merged. + + backbone_dropout : + The dropout rate for the backbone of the model. + + head_dropout : + The dropout rate for the head of the model. + + use_multi_scale : + Whether to use multi-scale fusing. + + individual : + Whether to make a linear layer for each variate/channel/feature individually. + + apply_nonstationary_norm : + Whether to apply non-stationary normalization. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + verbose : + Whether to print out the training logs during the training process. + """ + + def __init__( + self, + n_steps: int, + n_features: int, + patch_size: int, + patch_stride: int, + downsampling_ratio: float, + ffn_ratio: float, + num_blocks: list, + large_size: list, + small_size: list, + dims: list, + small_kernel_merged: bool = False, + backbone_dropout: float = 0.1, + head_dropout: float = 0.1, + use_multi_scale: bool = True, + individual: bool = False, + apply_nonstationary_norm: bool = False, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + verbose: bool = True, + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + verbose, + ) + assert ( + len(num_blocks) == len(dims) == len(large_size) == len(small_size) + ), "The length of num_blocks, dims, large_size, and small_size should be the same." + + self.n_steps = n_steps + self.n_features = n_features + + # set up the model + self.model = _ModernTCN( + n_steps, + n_features, + patch_size, + patch_stride, + downsampling_ratio, + ffn_ratio, + num_blocks, + large_size, + small_size, + dims, + small_kernel_merged, + backbone_dropout, + head_dropout, + use_multi_scale, + individual, + apply_nonstationary_norm, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "hdf5", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForModernTCN( + train_set, return_X_ori=False, return_y=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not key_in_data_set("X_ori", val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForModernTCN( + val_set, return_X_ori=True, return_y=False, file_type=file_type + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=True) + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). + If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : + The type of the given file if test_set is a path string. + + Returns + ------- + file_type : + The dictionary containing the clustering results and latent variables if necessary. + + """ + # Step 1: wrap the input data with classes Dataset and DataLoader + self.model.eval() # set the model as eval status to freeze it. + test_set = BaseDataset( + test_set, + return_X_ori=False, + return_X_pred=False, + return_y=False, + file_type=file_type, + ) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + # Step 2: process the data with the model + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputation_collector.append(results["imputed_data"]) + + # Step 3: output collection and return + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Parameters + ---------- + test_set : + The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (n_steps), n_features], + Imputed data. + """ + + result_dict = self.predict(test_set, file_type=file_type) + return result_dict["imputation"] From 34b4f4e8dfe76cab150fabc383ed451d3b2e4b0f Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 4 Sep 2024 12:20:42 +0800 Subject: [PATCH 4/5] test: add ModernTCN tests; --- pypots/imputation/__init__.py | 2 + tests/imputation/moderntcn.py | 137 ++++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+) create mode 100644 tests/imputation/moderntcn.py diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 4272f36e..8136b0f8 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -37,6 +37,7 @@ from .stemgnn import StemGNN from .imputeformer import ImputeFormer from .timemixer import TimeMixer +from .moderntcn import ModernTCN # naive imputation methods from .locf import LOCF @@ -77,6 +78,7 @@ "StemGNN", "ImputeFormer", "TimeMixer", + "ModernTCN", # naive imputation methods "LOCF", "Mean", diff --git a/tests/imputation/moderntcn.py b/tests/imputation/moderntcn.py new file mode 100644 index 00000000..33b41269 --- /dev/null +++ b/tests/imputation/moderntcn.py @@ -0,0 +1,137 @@ +""" +Test cases for ModernTCN imputation model. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import ModernTCN +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from tests.global_test_config import ( + DATA, + EPOCHS, + DEVICE, + TRAIN_SET, + VAL_SET, + TEST_SET, + GENERAL_H5_TRAIN_SET_PATH, + GENERAL_H5_VAL_SET_PATH, + GENERAL_H5_TEST_SET_PATH, + RESULT_SAVING_DIR_FOR_IMPUTATION, + check_tb_and_model_checkpoints_existence, +) + + +class TestModernTCN(unittest.TestCase): + logger.info("Running tests for an imputation model ModernTCN...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "ModernTCN") + model_save_name = "saved_moderntcn_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a ModernTCN model + moderntcn = ModernTCN( + DATA["n_steps"], + DATA["n_features"], + patch_size=3, + patch_stride=2, + downsampling_ratio=2, + ffn_ratio=1, + num_blocks=[1], + large_size=[5], + small_size=[3], + dims=[32], + small_kernel_merged=False, + backbone_dropout=0.1, + head_dropout=0.1, + use_multi_scale=False, + individual=False, + apply_nonstationary_norm=False, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-moderntcn") + def test_0_fit(self): + self.moderntcn.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-moderntcn") + def test_1_impute(self): + imputation_results = self.moderntcn.predict(TEST_SET) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"ModernTCN test_MSE: {test_MSE}") + + @pytest.mark.xdist_group(name="imputation-moderntcn") + def test_2_parameters(self): + assert hasattr(self.moderntcn, "model") and self.moderntcn.model is not None + + assert ( + hasattr(self.moderntcn, "optimizer") + and self.moderntcn.optimizer is not None + ) + + assert hasattr(self.moderntcn, "best_loss") + self.assertNotEqual(self.moderntcn.best_loss, float("inf")) + + assert ( + hasattr(self.moderntcn, "best_model_dict") + and self.moderntcn.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-moderntcn") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.moderntcn) + + # save the trained model into file, and check if the path exists + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.moderntcn.save(saved_model_path) + + # test loading the saved model, not necessary, but need to test + self.moderntcn.load(saved_model_path) + + @pytest.mark.xdist_group(name="imputation-moderntcn") + def test_4_lazy_loading(self): + self.moderntcn.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH) + imputation_results = self.moderntcn.predict(GENERAL_H5_TEST_SET_PATH) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Lazy-loading ModernTCN test_MSE: {test_MSE}") + + +if __name__ == "__main__": + unittest.main() From 66da59c96b718e454b1542c51c3d2acd10052038 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 4 Sep 2024 12:31:09 +0800 Subject: [PATCH 5/5] Add ModernTCN docs (#503) --- README.md | 2 ++ README_zh.md | 4 +++- docs/index.rst | 2 ++ docs/pypots.imputation.rst | 9 +++++++++ 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 95e74695..4bf1684a 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,7 @@ The paper references and links are all listed at the bottom of this file. | LLM | Time-Series.AI [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | `Later in 2024` | | Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` | | Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | +| Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` | | Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` | | Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` | | Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` | @@ -397,3 +398,4 @@ PyPOTS community is open, transparent, and surely friendly. Let's work together Hard to perform multitask learning with your time series? Not problems no longer. We'll open application for public beta test recently ;-) Follow us, and stay tuned! Time-Series.AI [^37]: Wang, S., Wu, H., Shi, X., Hu, T., Luo, H., Ma, L., ... & ZHOU, J. (2024). [TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting](https://openreview.net/forum?id=7oLshfEIC2). *ICLR 2024* +[^38]: Luo, D., & Wang X. (2024). [ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis](https://openreview.net/forum?id=vpJMJerXHU). *ICLR 2024* diff --git a/README_zh.md b/README_zh.md index d3f760cc..4d55e68d 100644 --- a/README_zh.md +++ b/README_zh.md @@ -104,6 +104,7 @@ PyPOTS当前支持多变量POTS数据的插补,预测,分类,聚类以及 | LLM | Time-Series.AI [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | `Later in 2024` | | Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` | | Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | +| Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` | | Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` | | Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` | | Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` | @@ -365,4 +366,5 @@ PyPOTS社区是一个开放、透明、友好的社区,让我们共同努力 [^35]: Bai, S., Kolter, J. Z., & Koltun, V. (2018). [An empirical evaluation of generic convolutional and recurrent networks for sequence modeling](https://arxiv.org/abs/1803.01271). *arXiv 2018*. [^36]: Gungnir项目,世界上第一个时间序列多任务大模型,将很快与大家见面。🚀 数据集存在缺少值且样本长短不一?多任务建模场景困难?都不再是问题,让我们的大模型来帮你解决。我们将在近期开放公测申请 ;-) 关注我们,敬请期待! Time-Series.AI -[^37]: Wang, S., Wu, H., Shi, X., Hu, T., Luo, H., Ma, L., ... & ZHOU, J. (2024). [TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting](https://openreview.net/forum?id=7oLshfEIC2). *ICLR 2024* +[^37]: Wang, S., Wu, H., Shi, X., Hu, T., Luo, H., Ma, L., ... & ZHOU, J. (2024). [TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting](https://openreview.net/forum?id=7oLshfEIC2). *ICLR 2024* +[^38]: Luo, D., & Wang X. (2024). [ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis](https://openreview.net/forum?id=vpJMJerXHU). *ICLR 2024* diff --git a/docs/index.rst b/docs/index.rst index c0204b42..113fbb65 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -137,6 +137,8 @@ The paper references are all listed at the bottom of this readme file. +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ | Neural Net | iTransformer🧑‍🔧 :cite:`liu2024itransformer` | ✅ | | | | | ``2024 - ICLR`` | +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ +| Neural Net | ModernTCN :cite:`luo2024moderntcn` | ✅ | | | | | ``2024 - ICLR`` | ++----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ | Neural Net | ImputeFormer :cite:`nie2024imputeformer` | ✅ | | | | | ``2024 - KDD`` | +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ | Neural Net | SAITS :cite:`du2023SAITS` | ✅ | | | | | ``2023 - ESWA`` | diff --git a/docs/pypots.imputation.rst b/docs/pypots.imputation.rst index 8af2d63f..b7a94b03 100644 --- a/docs/pypots.imputation.rst +++ b/docs/pypots.imputation.rst @@ -28,6 +28,15 @@ pypots.imputation.timemixer :show-inheritance: :inherited-members: +pypots.imputation.moderntcn +------------------------------------ + +.. automodule:: pypots.imputation.moderntcn + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + pypots.imputation.imputeformer ------------------------------------