From e08c75a753dbda5925294014afe29fe293996dab Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 15 Jun 2024 13:06:53 +0200 Subject: [PATCH] initial time mixer --- .../torch/model/times_mixer/__init__.py | 18 + .../torch/model/times_mixer/estimator.py | 239 +++++++++++ .../model/times_mixer/lightning_module.py | 100 +++++ src/gluonts/torch/model/times_mixer/module.py | 399 ++++++++++++++++++ 4 files changed, 756 insertions(+) create mode 100644 src/gluonts/torch/model/times_mixer/__init__.py create mode 100644 src/gluonts/torch/model/times_mixer/estimator.py create mode 100644 src/gluonts/torch/model/times_mixer/lightning_module.py create mode 100644 src/gluonts/torch/model/times_mixer/module.py diff --git a/src/gluonts/torch/model/times_mixer/__init__.py b/src/gluonts/torch/model/times_mixer/__init__.py new file mode 100644 index 0000000000..71e833b655 --- /dev/null +++ b/src/gluonts/torch/model/times_mixer/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from .module import TimeMixerModel +from .lightning_module import TimeMixerLightningModule +from .estimator import TimeMixerEstimator + +__all__ = ["TimeMixerModel", "TimeMixerLightningModule", "TimeMixerEstimator"] diff --git a/src/gluonts/torch/model/times_mixer/estimator.py b/src/gluonts/torch/model/times_mixer/estimator.py new file mode 100644 index 0000000000..c9116f3a43 --- /dev/null +++ b/src/gluonts/torch/model/times_mixer/estimator.py @@ -0,0 +1,239 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+
+from typing import Optional, Iterable, Dict, Any
+
+import torch
+import lightning.pytorch as pl
+
+from gluonts.core.component import validated
+from gluonts.dataset.common import Dataset
+from gluonts.dataset.field_names import FieldName
+from gluonts.dataset.loader import as_stacked_batches
+from gluonts.itertools import Cyclic
+from gluonts.transform import (
+    Transformation,
+    AddObservedValuesIndicator,
+    InstanceSampler,
+    InstanceSplitter,
+    ValidationSplitSampler,
+    TestSplitSampler,
+    ExpectedNumInstanceSampler,
+    SelectFields,
+)
+from gluonts.torch.model.estimator import PyTorchLightningEstimator
+from gluonts.torch.model.predictor import PyTorchPredictor
+from gluonts.torch.distributions import Output, StudentTOutput
+
+from .lightning_module import TimeMixerLightningModule
+
+PREDICTION_INPUT_NAMES = ["past_target", "past_observed_values"]
+
+TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
+    "future_target",
+    "future_observed_values",
+]
+
+
+class TimeMixerEstimator(PyTorchLightningEstimator):
+    """
+    An estimator training the TimeMixer model from the paper
+    https://openreview.net/pdf?id=7oLshfEIC2 extended for probabilistic
+    forecasting.
+
+    This class uses the model defined in ``TimeMixerModel``,
+    and wraps it into a ``TimeMixerLightningModule`` for training
+    purposes: training is performed using PyTorch Lightning's ``pl.Trainer``
+    class.
+
+    Parameters
+    ----------
+    prediction_length
+        Length of the prediction horizon.
+    context_length
+        Number of time steps prior to prediction time that the model
+        takes as inputs (default: ``10 * prediction_length``).
+    hidden_dimension
+        Size of the hidden representation (default: 20).
+    lr
+        Learning rate (default: ``1e-3``).
+    weight_decay
+        Weight decay regularization parameter (default: ``1e-8``).
+    scaling
+        Which scaling method to use to scale the target values; one of
+        ``"mean"``, ``"std"`` or ``None`` (default: ``"mean"``).
+    distr_output
+        Distribution to use to evaluate observations and sample predictions
+        (default: ``StudentTOutput()``).
+    kernel_size
+        Size of the moving-average kernel used for series decomposition
+        (default: 25).
+    batch_size
+        The size of the batches to be used for training (default: 32).
+    num_batches_per_epoch
+        Number of batches to be processed in each training epoch
+        (default: 50).
+    trainer_kwargs
+        Additional arguments to provide to ``pl.Trainer`` for construction.
+    train_sampler
+        Controls the sampling of windows during training.
+    validation_sampler
+        Controls the sampling of windows during validation.
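+
+    A minimal usage sketch, assuming a GluonTS dataset object ``dataset``
+    with ``train`` and ``test`` splits; the argument values below are
+    illustrative only::
+
+        estimator = TimeMixerEstimator(
+            prediction_length=24,
+            context_length=240,
+            batch_size=32,
+            num_batches_per_epoch=50,
+            trainer_kwargs={"max_epochs": 10},
+        )
+        predictor = estimator.train(training_data=dataset.train)
+        forecasts = list(predictor.predict(dataset.test))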
+ """ + + @validated() + def __init__( + self, + prediction_length: int, + context_length: Optional[int] = None, + hidden_dimension: Optional[int] = None, + lr: float = 1e-3, + weight_decay: float = 1e-8, + scaling: Optional[str] = "mean", + distr_output: Output = StudentTOutput(), + kernel_size: int = 25, + batch_size: int = 32, + num_batches_per_epoch: int = 50, + trainer_kwargs: Optional[Dict[str, Any]] = None, + train_sampler: Optional[InstanceSampler] = None, + validation_sampler: Optional[InstanceSampler] = None, + ) -> None: + default_trainer_kwargs = { + "max_epochs": 100, + "gradient_clip_val": 10.0, + } + if trainer_kwargs is not None: + default_trainer_kwargs.update(trainer_kwargs) + super().__init__(trainer_kwargs=default_trainer_kwargs) + + self.prediction_length = prediction_length + self.context_length = context_length or 10 * prediction_length + # TODO find way to enforce same defaults to network and estimator + # somehow + self.hidden_dimension = hidden_dimension or 20 + self.lr = lr + self.weight_decay = weight_decay + self.distr_output = distr_output + self.scaling = scaling + self.kernel_size = kernel_size + self.batch_size = batch_size + self.num_batches_per_epoch = num_batches_per_epoch + + self.train_sampler = train_sampler or ExpectedNumInstanceSampler( + num_instances=1.0, min_future=prediction_length + ) + self.validation_sampler = validation_sampler or ValidationSplitSampler( + min_future=prediction_length + ) + + def create_transformation(self) -> Transformation: + return SelectFields( + [ + FieldName.ITEM_ID, + FieldName.INFO, + FieldName.START, + FieldName.TARGET, + ], + allow_missing=True, + ) + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ) + + def create_lightning_module(self) -> pl.LightningModule: + return TimeMixerLightningModule( + lr=self.lr, + weight_decay=self.weight_decay, + model_kwargs={ + "prediction_length": self.prediction_length, + "context_length": self.context_length, + "hidden_dimension": self.hidden_dimension, + "distr_output": self.distr_output, + "kernel_size": self.kernel_size, + "scaling": self.scaling, + }, + ) + + def _create_instance_splitter( + self, module: TimeMixerLightningModule, mode: str + ): + assert mode in ["training", "validation", "test"] + + instance_sampler = { + "training": self.train_sampler, + "validation": self.validation_sampler, + "test": TestSplitSampler(), + }[mode] + + return InstanceSplitter( + target_field=FieldName.TARGET, + is_pad_field=FieldName.IS_PAD, + start_field=FieldName.START, + forecast_start_field=FieldName.FORECAST_START, + instance_sampler=instance_sampler, + past_length=self.context_length, + future_length=self.prediction_length, + time_series_fields=[ + FieldName.OBSERVED_VALUES, + ], + dummy_value=self.distr_output.value_in_support, + ) + + def create_training_data_loader( + self, + data: Dataset, + module: TimeMixerLightningModule, + shuffle_buffer_length: Optional[int] = None, + **kwargs, + ) -> Iterable: + data = Cyclic(data).stream() + instances = self._create_instance_splitter(module, "training").apply( + data, is_train=True + ) + return as_stacked_batches( + instances, + batch_size=self.batch_size, + shuffle_buffer_length=shuffle_buffer_length, + field_names=TRAINING_INPUT_NAMES, + output_type=torch.tensor, + num_batches_per_epoch=self.num_batches_per_epoch, + ) + + def create_validation_data_loader( + self, + data: Dataset, + module: TimeMixerLightningModule, + **kwargs, + ) -> Iterable: + instances = 
self._create_instance_splitter(module, "validation").apply( + data, is_train=True + ) + return as_stacked_batches( + instances, + batch_size=self.batch_size, + field_names=TRAINING_INPUT_NAMES, + output_type=torch.tensor, + ) + + def create_predictor( + self, + transformation: Transformation, + module, + ) -> PyTorchPredictor: + prediction_splitter = self._create_instance_splitter(module, "test") + + return PyTorchPredictor( + input_transform=transformation + prediction_splitter, + input_names=PREDICTION_INPUT_NAMES, + prediction_net=module, + forecast_generator=self.distr_output.forecast_generator, + batch_size=self.batch_size, + prediction_length=self.prediction_length, + device="auto", + ) diff --git a/src/gluonts/torch/model/times_mixer/lightning_module.py b/src/gluonts/torch/model/times_mixer/lightning_module.py new file mode 100644 index 0000000000..e39a31b1fa --- /dev/null +++ b/src/gluonts/torch/model/times_mixer/lightning_module.py @@ -0,0 +1,100 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import lightning.pytorch as pl +import torch + +from gluonts.core.component import validated +from gluonts.itertools import select + +from .module import TimeMixerModel + + +class TimeMixerLightningModule(pl.LightningModule): + """ + A ``pl.LightningModule`` class that can be used to train a ``TimeMixerModel`` + with PyTorch Lightning. + + This is a thin layer around a (wrapped) ``TimeMixerModel`` object, + that exposes the methods to evaluate training and validation loss. + + Parameters + ---------- + model_kwargs + Keyword arguments to construct the ``TimeMixerModel`` to be trained. + loss + Loss function to be used for training. + lr + Learning rate. + weight_decay + Weight decay regularization parameter. + """ + + @validated() + def __init__( + self, + model_kwargs: dict, + lr: float = 1e-3, + weight_decay: float = 1e-8, + ): + super().__init__() + self.save_hyperparameters() + self.model = TimeMixerModel(**model_kwargs) + self.lr = lr + self.weight_decay = weight_decay + self.inputs = self.model.describe_inputs() + + def forward(self, *args, **kwargs): + return self.model.forward(*args, **kwargs) + + def training_step(self, batch, batch_idx: int): # type: ignore + """ + Execute training step. + """ + train_loss = self.model.loss( + **select(self.inputs, batch), + future_target=batch["future_target"], + future_observed_values=batch["future_observed_values"], + ).mean() + self.log( + "train_loss", + train_loss, + on_epoch=True, + on_step=False, + prog_bar=True, + ) + return train_loss + + def validation_step(self, batch, batch_idx: int): # type: ignore + """ + Execute validation step. + """ + val_loss = self.model.loss( + **select(self.inputs, batch), + future_target=batch["future_target"], + future_observed_values=batch["future_observed_values"], + ).mean() + self.log( + "val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True + ) + return val_loss + + def configure_optimizers(self): + """ + Returns the optimizer to use. 
+ """ + return torch.optim.Adam( + self.model.parameters(), + lr=self.lr, + weight_decay=self.weight_decay, + ) diff --git a/src/gluonts/torch/model/times_mixer/module.py b/src/gluonts/torch/model/times_mixer/module.py new file mode 100644 index 0000000000..be74e9da72 --- /dev/null +++ b/src/gluonts/torch/model/times_mixer/module.py @@ -0,0 +1,399 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from typing import Tuple + +import torch +from torch import nn + +from gluonts.core.component import validated +from gluonts.model import Input, InputSpec +from gluonts.torch.distributions import StudentTOutput +from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler +from gluonts.torch.model.simple_feedforward import make_linear_layer +from gluonts.torch.util import weighted_average + + +class MovingAvg(nn.Module): + """ + Moving average block to highlight the trend of time series. + + Parameters: + - kernel_size (int): The size of the kernel for the average pooling operation. + - stride (int): The stride of the average pooling operation. + + Attributes: + - kernel_size (int): The size of the kernel for the average pooling operation. + + Methods: + - forward(x): Performs the forward pass of the moving average block. + + """ + + def __init__(self, kernel_size, stride): + super().__init__() + self.kernel_size = kernel_size + self.avg = nn.AvgPool1d( + kernel_size=kernel_size, stride=stride, padding=0 + ) + + def forward(self, x): + # padding on the both ends of time series + front = x[:, 0:1, ...].repeat(1, (self.kernel_size - 1) // 2, 1) + end = x[:, -1:, ...].repeat(1, (self.kernel_size - 1) // 2, 1) + x = torch.cat([front, x, end], dim=1) + x = self.avg(x.permute(0, 2, 1)) + x = x.permute(0, 2, 1) + return x + + +class SeriesDecomp(nn.Module): + """ + Series decomposition block. + + This class represents a series decomposition block that decomposes a time series into its trend and residual components. + It takes a kernel size as input, which determines the size of the moving average window used for trend estimation. + + Parameters: + ----------- + kernel_size : int + The size of the moving average window used for trend estimation. + + Methods: + -------- + forward(x): + Performs the forward pass of the series decomposition block. + + Attributes: + ----------- + moving_avg : MovingAvg + The moving average module used for trend estimation. + """ + + def __init__(self, kernel_size): + super().__init__() + self.moving_avg = MovingAvg(kernel_size, stride=1) + + def forward(self, x): + moving_mean = self.moving_avg(x) + res = x - moving_mean + return res, moving_mean + + +class DFT_series_decomp(nn.Module): + """ + Series decomposition block + + This class represents a series decomposition block that performs decomposition of a time series into its seasonal and trend components using Discrete Fourier Transform (DFT). + + Parameters: + - top_k (int): The number of top frequencies to keep during decomposition. Default is 5. 
+ + Methods: + - forward(x): Performs the forward pass of the series decomposition block. + + Returns: + - x_season (torch.Tensor): The seasonal component of the input time series. + - x_trend (torch.Tensor): The trend component of the input time series. + """ + + def __init__(self, top_k=5): + super().__init__() + self.top_k = top_k + + def forward(self, x): + xf = torch.fft.rfft(x) + freq = abs(xf) + freq[0] = 0 + top_k_freq, top_list = torch.topk(freq, 5) + xf[freq <= top_k_freq.min()] = 0 + x_season = torch.fft.irfft(xf) + x_trend = x - x_season + return x_season, x_trend + + +class MultiScaleSeasonMixing(nn.Module): + """ + Bottom-up mixing season pattern + """ + + def __init__( + self, context_length, down_sampling_window, down_sampling_layers + ): + super().__init__() + + self.down_sampling_layers = torch.nn.ModuleList( + [ + nn.Sequential( + torch.nn.Linear( + context_length // (down_sampling_window**i), + context_length // (down_sampling_window ** (i + 1)), + ), + nn.GELU(), + torch.nn.Linear( + context_length // (down_sampling_window ** (i + 1)), + context_length // (down_sampling_window ** (i + 1)), + ), + ) + for i in range(down_sampling_layers) + ] + ) + + def forward(self, season_list): + # mixing high->low + out_high = season_list[0] + out_low = season_list[1] + out_season_list = [out_high.permute(0, 2, 1)] + + for i in range(len(season_list) - 1): + out_low_res = self.down_sampling_layers[i](out_high) + out_low = out_low + out_low_res + out_high = out_low + if i + 2 <= len(season_list) - 1: + out_low = season_list[i + 2] + out_season_list.append(out_high.permute(0, 2, 1)) + + return out_season_list + + +class MultiScaleTrendMixing(nn.Module): + """ + Top-down mixing trend pattern + """ + + def __init__( + self, context_length, down_sampling_window, down_sampling_layers + ): + super().__init__() + + self.up_sampling_layers = torch.nn.ModuleList( + [ + nn.Sequential( + torch.nn.Linear( + context_length // (down_sampling_window ** (i + 1)), + context_length // (down_sampling_window**i), + ), + nn.GELU(), + torch.nn.Linear( + context_length // (down_sampling_window**i), + context_length // (down_sampling_window**i), + ), + ) + for i in reversed(range(down_sampling_layers)) + ] + ) + + def forward(self, trend_list): + # mixing low->high + trend_list_reverse = trend_list.copy() + trend_list_reverse.reverse() + out_low = trend_list_reverse[0] + out_high = trend_list_reverse[1] + out_trend_list = [out_low.permute(0, 2, 1)] + + for i in range(len(trend_list_reverse) - 1): + out_high_res = self.up_sampling_layers[i](out_low) + out_high = out_high + out_high_res + out_low = out_high + if i + 2 <= len(trend_list_reverse) - 1: + out_high = trend_list_reverse[i + 2] + out_trend_list.append(out_low.permute(0, 2, 1)) + + out_trend_list.reverse() + return out_trend_list + + +class PastDecomposableMixing(nn.Module): + def __init__( + self, + context_length, + down_sampling_window, + down_sampling_layers, + d_model, + d_ff, + decomp_method, + kernel_size, + top_k, + ): + super().__init__() + + if decomp_method == "moving_avg": + self.decompsition = SeriesDecomp(kernel_size) + elif decomp_method == "dft_decomp": + self.decompsition = DFT_series_decomp(top_k) + else: + raise ValueError("unknown decomp_method") + + # Mixing season + self.mixing_multi_scale_season = MultiScaleSeasonMixing( + context_length, down_sampling_window, down_sampling_layers + ) + + # Mixing trend + self.mixing_multi_scale_trend = MultiScaleTrendMixing( + context_length, down_sampling_window, down_sampling_layers + ) + + 
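+        # Channel-wise feed-forward layer (d_model -> d_ff -> d_model)
+        # applied to the mixed season + trend output and added back to the
+        # original input as a residual connection in ``forward``.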
self.out_cross_layer = nn.Sequential(
+            nn.Linear(in_features=d_model, out_features=d_ff),
+            nn.GELU(),
+            nn.Linear(in_features=d_ff, out_features=d_model),
+        )
+
+    def forward(self, x_list):
+        length_list = []
+        for x in x_list:
+            _, T, _ = x.size()
+            length_list.append(T)
+
+        # Decompose to obtain the season and trend
+        season_list = []
+        trend_list = []
+        for x in x_list:
+            season, trend = self.decompsition(x)
+            season_list.append(season.permute(0, 2, 1))
+            trend_list.append(trend.permute(0, 2, 1))
+
+        # bottom-up season mixing
+        out_season_list = self.mixing_multi_scale_season(season_list)
+        # top-down trend mixing
+        out_trend_list = self.mixing_multi_scale_trend(trend_list)
+
+        out_list = []
+        for ori, out_season, out_trend, length in zip(
+            x_list, out_season_list, out_trend_list, length_list
+        ):
+            out = out_season + out_trend
+            out = ori + self.out_cross_layer(out)
+            out_list.append(out[:, :length, :])
+        return out_list
+
+
+class TimeMixerModel(nn.Module):
+    """
+    Module implementing the TimeMixer model from the paper
+    https://openreview.net/pdf?id=7oLshfEIC2 extended for probabilistic
+    forecasting.
+
+    Parameters
+    ----------
+    prediction_length
+        Number of time points to predict.
+    context_length
+        Number of time steps prior to prediction time that the model
+        takes as inputs.
+    down_sampling_window
+        Down-sampling factor applied between consecutive scales.
+    down_sampling_layers
+        Number of down-sampled scales used in addition to the original one.
+    e_layers
+        Number of ``PastDecomposableMixing`` blocks.
+    d_model
+        Size of the per-time-step representation.
+    d_ff
+        Hidden size of the feed-forward (cross) layers.
+    down_sampling_method
+        Method used to down-sample the input series (default: ``"max"``).
+    decomp_method
+        Series decomposition method, ``"moving_avg"`` or ``"dft_decomp"``.
+    distr_output
+        Distribution to use to evaluate observations and sample predictions.
+    kernel_size
+        Size of the moving-average kernel used when
+        ``decomp_method="moving_avg"``.
+    top_k
+        Number of top frequencies kept when ``decomp_method="dft_decomp"``.
+    scaling
+        Which scaling method to use; ``"mean"``, ``"std"`` or anything else
+        for no scaling.
+    """
+
+    @validated()
+    def __init__(
+        self,
+        prediction_length: int,
+        context_length: int,
+        down_sampling_window: int,
+        down_sampling_layers: int,
+        e_layers: int,
+        d_model: int,
+        d_ff: int,
+        down_sampling_method: str = "max",
+        decomp_method: str = "dft_decomp",
+        distr_output=StudentTOutput(),
+        kernel_size: int = 25,
+        top_k: int = 5,
+        scaling: str = "mean",
+    ) -> None:
+        super().__init__()
+
+        assert prediction_length > 0
+        assert context_length > 0
+
+        self.prediction_length = prediction_length
+        self.context_length = context_length
+
+        self.down_sampling_method = down_sampling_method
+
+        self.pdm_blocks = nn.ModuleList(
+            [
+                PastDecomposableMixing(
+                    context_length=context_length,
+                    down_sampling_window=down_sampling_window,
+                    down_sampling_layers=down_sampling_layers,
+                    d_model=d_model,
+                    d_ff=d_ff,
+                    decomp_method=decomp_method,
+                    kernel_size=kernel_size,
+                    top_k=top_k,
+                )
+                for _ in range(e_layers)
+            ]
+        )
+
+        self.distr_output = distr_output
+        if scaling == "mean":
+            self.scaler = MeanScaler(keepdim=True)
+        elif scaling == "std":
+            self.scaler = StdScaler(keepdim=True)
+        else:
+            self.scaler = NOPScaler(keepdim=True)
+
+        self.predict_layers = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(
+                    prediction_length // (down_sampling_window**i),
+                    prediction_length,
+                )
+                for i in range(down_sampling_layers + 1)
+            ]
+        )
+        self.args_proj = self.distr_output.get_args_proj(d_model)
+
+    def describe_inputs(self, batch_size=1) -> InputSpec:
+        return InputSpec(
+            {
+                "past_target": Input(
+                    shape=(batch_size, self.context_length), dtype=torch.float
+                ),
+                "past_observed_values": Input(
+                    shape=(batch_size, self.context_length), dtype=torch.float
+                ),
+            },
+            torch.zeros,
+        )
+
+    def forward(
+        self,
+        past_target: torch.Tensor,
+        past_observed_values: torch.Tensor,
+    ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]:
+        # scale the input
+        past_target_scaled, loc, scale = self.scaler(
+            past_target, past_observed_values
+        )
+
+        return distr_args, loc, scale
+
+    def loss(
+        self,
+        past_target: torch.Tensor,
+        past_observed_values: torch.Tensor,
+        future_target: torch.Tensor,
+
future_observed_values: torch.Tensor, + ) -> torch.Tensor: + distr_args, loc, scale = self( + past_target=past_target, past_observed_values=past_observed_values + ) + loss = self.distr_output.loss( + target=future_target, distr_args=distr_args, loc=loc, scale=scale + ) + return weighted_average(loss, weights=future_observed_values, dim=-1)
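
For reference, the moving-average decomposition used when ``decomp_method="moving_avg"`` can be exercised in isolation. The following is a minimal sketch against the classes added by this patch, using dummy data; it assumes inputs of shape (batch, time, channels) and an odd kernel size so the sequence length is preserved:

    import torch
    from gluonts.torch.model.times_mixer.module import SeriesDecomp

    x = torch.randn(4, 96, 1)              # (batch, time, channels)
    decomp = SeriesDecomp(kernel_size=25)
    season, trend = decomp(x)               # both (4, 96, 1)
    assert torch.allclose(x, season + trend)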