Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement FreTS as an imputation model #370

Merged
merged 3 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/pypots.imputation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ pypots.imputation.transformer
:show-inheritance:
:inherited-members:

pypots.imputation.frets
------------------------------

.. automodule:: pypots.imputation.frets
:members:
:undoc-members:
:show-inheritance:
:inherited-members:

pypots.imputation.crossformer
------------------------------

Expand Down
12 changes: 12 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -554,4 +554,16 @@ @inproceedings{zhou2022film
url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/524ef58c2bd075775861234266e5e020-Paper-Conference.pdf},
volume = {35},
year = {2022}
}

@inproceedings{yi2023frets,
author = {Yi, Kun and Zhang, Qi and Fan, Wei and Wang, Shoujin and Wang, Pengyang and He, Hui and An, Ning and Lian, Defu and Cao, Longbing and Niu, Zhendong},
booktitle = {Advances in Neural Information Processing Systems},
editor = {A. Oh and T. Neumann and A. Globerson and K. Saenko and M. Hardt and S. Levine},
pages = {76656--76679},
publisher = {Curran Associates, Inc.},
title = {Frequency-domain MLPs are More Effective Learners in Time Series Forecasting},
url = {https://proceedings.neurips.cc/paper_files/paper/2023/file/f1d16af76939f476b5f040fd1398c0a3-Paper-Conference.pdf},
volume = {36},
year = {2023}
}
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .etsformer import ETSformer
from .fedformer import FEDformer
from .film import FiLM
from .frets import FreTS
from .crossformer import Crossformer
from .informer import Informer
from .autoformer import Autoformer
Expand All @@ -35,6 +36,7 @@
"ETSformer",
"FEDformer",
"FiLM",
"FreTS",
"Crossformer",
"TimesNet",
"PatchTST",
Expand Down
24 changes: 24 additions & 0 deletions pypots/imputation/frets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
The package of the partially-observed time-series imputation model FreTS.

Refer to the paper
`Kun Yi, Qi Zhang, Wei Fan, Shoujin Wang, Pengyang Wang, Hui He, Ning An, Defu Lian, Longbing Cao, and Zhendong Niu.
"Frequency-domain MLPs are More Effective Learners in Time Series Forecasting."
Advances in Neural Information Processing Systems 36 (2024).
<https://proceedings.neurips.cc/paper_files/paper/2023/file/f1d16af76939f476b5f040fd1398c0a3-Paper-Conference.pdf>`_

Notes
-----
Partial implementation uses code from https://github.com/aikunyi/FreTS

"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import FreTS

__all__ = [
"FreTS",
]
83 changes: 83 additions & 0 deletions pypots/imputation/frets/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""

"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from ...nn.modules.frets import BackboneFreTS
from ...nn.modules.saits import SaitsLoss
from ...nn.modules.transformer.embedding import DataEmbedding


class _FreTS(nn.Module):
def __init__(
self,
n_steps,
n_features,
embed_size: int = 128, # the default value is the same as the fixed one in the original implementation
hidden_size: int = 256, # the default value is the same as the fixed one in the original implementation
channel_independence: bool = False,
ORT_weight: float = 1,
MIT_weight: float = 1,
):
super().__init__()

self.n_steps = n_steps

self.enc_embedding = DataEmbedding(
n_features * 2,
embed_size,
dropout=0,
with_pos=False,
)
self.backbone = BackboneFreTS(
n_steps,
n_features,
embed_size,
n_steps,
hidden_size,
channel_independence,
)

# for the imputation task, the output dim is the same as input dim
self.output_projection = nn.Linear(embed_size, n_features)
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

# WDU: the original FreTS paper isn't proposed for imputation task. Hence the model doesn't take
# the missing mask into account, which means, in the process, the model doesn't know which part of
# the input data is missing, and this may hurt the model's imputation performance. Therefore, I add the
# embedding layers to project the concatenation of features and masks into a hidden space, as well as
# the output layers to project back from the hidden space to the original space.

# the same as SAITS, concatenate the time series data and the missing mask for embedding
input_X = torch.cat([X, missing_mask], dim=2)
enc_out = self.enc_embedding(input_X)

# FreTS processing
backbone_output = self.backbone(enc_out)
reconstruction = self.output_projection(backbone_output)

imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
results = {
"imputed_data": imputed_data,
}

# if in training mode, return results with losses
if training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(
reconstruction, X_ori, missing_mask, indicating_mask
)
results["ORT_loss"] = ORT_loss
results["MIT_loss"] = MIT_loss
# `loss` is always the item for backward propagating to update the model
results["loss"] = loss

return results
24 changes: 24 additions & 0 deletions pypots/imputation/frets/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Dataset class for FreTS.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForFreTS(DatasetForSAITS):
"""Actually FreTS uses the same data strategy as SAITS, needs MIT for training."""

def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_y: bool,
file_type: str = "hdf5",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_y, file_type, rate)
Loading
Loading