initial time mixer
kashif committed Jun 15, 2024
1 parent 36164f6 commit e08c75a
Showing 4 changed files with 756 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/gluonts/torch/model/times_mixer/__init__.py
@@ -0,0 +1,18 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from .module import TimeMixerModel
from .lightning_module import TimeMixerLightningModule
from .estimator import TimeMixerEstimator

__all__ = ["TimeMixerModel", "TimeMixerLightningModule", "TimeMixerEstimator"]
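
The package exposes the usual GluonTS triple: model, lightning module, and estimator. A minimal end-to-end usage sketch, assuming a dataset from ``gluonts.dataset.repository`` (the dataset name and epoch count are illustrative, not part of this commit):

from gluonts.dataset.repository import get_dataset
from gluonts.torch.model.times_mixer import TimeMixerEstimator

# Any GluonTS dataset works; "m4_hourly" is just an example choice.
dataset = get_dataset("m4_hourly")

estimator = TimeMixerEstimator(
    prediction_length=dataset.metadata.prediction_length,
    trainer_kwargs={"max_epochs": 5},  # illustrative; default is 100
)
predictor = estimator.train(dataset.train)
forecasts = list(predictor.predict(dataset.test))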
239 changes: 239 additions & 0 deletions src/gluonts/torch/model/times_mixer/estimator.py
@@ -0,0 +1,239 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Optional, Iterable, Dict, Any

import torch
import lightning.pytorch as pl

from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.itertools import Cyclic
from gluonts.transform import (
    Transformation,
    AddObservedValuesIndicator,
    InstanceSampler,
    InstanceSplitter,
    ValidationSplitSampler,
    TestSplitSampler,
    ExpectedNumInstanceSampler,
    SelectFields,
)
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.distributions import Output, StudentTOutput

from .lightning_module import TimeMixerLightningModule

PREDICTION_INPUT_NAMES = ["past_target", "past_observed_values"]

TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
    "future_target",
    "future_observed_values",
]


class TimeMixerEstimator(PyTorchLightningEstimator):
    """
    An estimator for training the TimeMixer model from the paper
    https://openreview.net/pdf?id=7oLshfEIC2, extended for probabilistic
    forecasting.

    This class uses the model defined in ``TimeMixerModel``,
    and wraps it into a ``TimeMixerLightningModule`` for training
    purposes: training is performed using PyTorch Lightning's ``pl.Trainer``
    class.

    Parameters
    ----------
    prediction_length
        Length of the prediction horizon.
    context_length
        Number of time steps prior to prediction time that the model
        takes as inputs (default: ``10 * prediction_length``).
    hidden_dimension
        Size of the hidden representation (default: 20).
    lr
        Learning rate (default: ``1e-3``).
    weight_decay
        Weight decay regularization parameter (default: ``1e-8``).
    scaling
        Scaling method applied to the target values before encoding
        (default: ``"mean"``).
    distr_output
        Distribution to use to evaluate observations and sample predictions
        (default: StudentTOutput()).
    kernel_size
        Size of the moving-average kernel used for series decomposition
        (default: 25).
    batch_size
        The size of the batches to be used for training (default: 32).
    num_batches_per_epoch
        Number of batches to be processed in each training epoch
        (default: 50).
    trainer_kwargs
        Additional arguments to provide to ``pl.Trainer`` for construction.
    train_sampler
        Controls the sampling of windows during training.
    validation_sampler
        Controls the sampling of windows during validation.
    """

    @validated()
    def __init__(
        self,
        prediction_length: int,
        context_length: Optional[int] = None,
        hidden_dimension: Optional[int] = None,
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
        scaling: Optional[str] = "mean",
        distr_output: Output = StudentTOutput(),
        kernel_size: int = 25,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        train_sampler: Optional[InstanceSampler] = None,
        validation_sampler: Optional[InstanceSampler] = None,
    ) -> None:
        default_trainer_kwargs = {
            "max_epochs": 100,
            "gradient_clip_val": 10.0,
        }
        if trainer_kwargs is not None:
            default_trainer_kwargs.update(trainer_kwargs)
        super().__init__(trainer_kwargs=default_trainer_kwargs)

        self.prediction_length = prediction_length
        self.context_length = context_length or 10 * prediction_length
        # TODO find way to enforce same defaults to network and estimator
        # somehow
        self.hidden_dimension = hidden_dimension or 20
        self.lr = lr
        self.weight_decay = weight_decay
        self.distr_output = distr_output
        self.scaling = scaling
        self.kernel_size = kernel_size
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch

        self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=prediction_length
        )
        self.validation_sampler = validation_sampler or ValidationSplitSampler(
            min_future=prediction_length
        )

    def create_transformation(self) -> Transformation:
        return SelectFields(
            [
                FieldName.ITEM_ID,
                FieldName.INFO,
                FieldName.START,
                FieldName.TARGET,
            ],
            allow_missing=True,
        ) + AddObservedValuesIndicator(
            target_field=FieldName.TARGET,
            output_field=FieldName.OBSERVED_VALUES,
        )

    def create_lightning_module(self) -> pl.LightningModule:
        return TimeMixerLightningModule(
            lr=self.lr,
            weight_decay=self.weight_decay,
            model_kwargs={
                "prediction_length": self.prediction_length,
                "context_length": self.context_length,
                "hidden_dimension": self.hidden_dimension,
                "distr_output": self.distr_output,
                "kernel_size": self.kernel_size,
                "scaling": self.scaling,
            },
        )

    def _create_instance_splitter(
        self, module: TimeMixerLightningModule, mode: str
    ):
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=self.context_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.OBSERVED_VALUES,
            ],
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: TimeMixerLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        data = Cyclic(data).stream()
        instances = self._create_instance_splitter(module, "training").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            shuffle_buffer_length=shuffle_buffer_length,
            field_names=TRAINING_INPUT_NAMES,
            output_type=torch.tensor,
            num_batches_per_epoch=self.num_batches_per_epoch,
        )

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: TimeMixerLightningModule,
        **kwargs,
    ) -> Iterable:
        instances = self._create_instance_splitter(module, "validation").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            field_names=TRAINING_INPUT_NAMES,
            output_type=torch.tensor,
        )

    def create_predictor(
        self,
        transformation: Transformation,
        module,
    ) -> PyTorchPredictor:
        prediction_splitter = self._create_instance_splitter(module, "test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module,
            forecast_generator=self.distr_output.forecast_generator,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            device="auto",
        )
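
Because ``distr_output`` and the samplers are plain constructor arguments, the probabilistic head and the window sampling can be swapped without subclassing. A hedged sketch, assuming ``NegativeBinomialOutput`` from ``gluonts.torch.distributions``; all values are illustrative:

from gluonts.torch.distributions import NegativeBinomialOutput
from gluonts.transform import ExpectedNumInstanceSampler

from gluonts.torch.model.times_mixer import TimeMixerEstimator

# Count-valued targets: replace the default StudentTOutput with a
# NegativeBinomialOutput and draw more training windows per series.
estimator = TimeMixerEstimator(
    prediction_length=24,
    context_length=240,
    distr_output=NegativeBinomialOutput(),
    train_sampler=ExpectedNumInstanceSampler(
        num_instances=2.0, min_future=24
    ),
    batch_size=64,
)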
100 changes: 100 additions & 0 deletions src/gluonts/torch/model/times_mixer/lightning_module.py
@@ -0,0 +1,100 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import lightning.pytorch as pl
import torch

from gluonts.core.component import validated
from gluonts.itertools import select

from .module import TimeMixerModel


class TimeMixerLightningModule(pl.LightningModule):
    """
    A ``pl.LightningModule`` class that can be used to train a
    ``TimeMixerModel`` with PyTorch Lightning.

    This is a thin layer around a (wrapped) ``TimeMixerModel`` object
    that exposes the methods to evaluate training and validation loss.

    Parameters
    ----------
    model_kwargs
        Keyword arguments to construct the ``TimeMixerModel`` to be trained.
    lr
        Learning rate.
    weight_decay
        Weight decay regularization parameter.
    """

    @validated()
    def __init__(
        self,
        model_kwargs: dict,
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.model = TimeMixerModel(**model_kwargs)
        self.lr = lr
        self.weight_decay = weight_decay
        self.inputs = self.model.describe_inputs()

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)

    def training_step(self, batch, batch_idx: int):  # type: ignore
        """
        Execute training step.
        """
        train_loss = self.model.loss(
            **select(self.inputs, batch),
            future_target=batch["future_target"],
            future_observed_values=batch["future_observed_values"],
        ).mean()
        self.log(
            "train_loss",
            train_loss,
            on_epoch=True,
            on_step=False,
            prog_bar=True,
        )
        return train_loss

    def validation_step(self, batch, batch_idx: int):  # type: ignore
        """
        Execute validation step.
        """
        val_loss = self.model.loss(
            **select(self.inputs, batch),
            future_target=batch["future_target"],
            future_observed_values=batch["future_observed_values"],
        ).mean()
        self.log(
            "val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True
        )
        return val_loss

    def configure_optimizers(self):
        """
        Returns the optimizer to use.
        """
        return torch.optim.Adam(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
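
The lightning module can also be driven by a ``pl.Trainer`` directly, bypassing the estimator. A sketch under the assumption that ``TimeMixerModel`` (defined in ``module.py``, whose diff is not shown above) accepts exactly the kwargs that ``create_lightning_module`` passes; the concrete values and the data loaders are placeholders:

import lightning.pytorch as pl
from gluonts.torch.distributions import StudentTOutput

from gluonts.torch.model.times_mixer import TimeMixerLightningModule

# model_kwargs mirror what TimeMixerEstimator.create_lightning_module passes;
# the values here are assumptions for illustration only.
module = TimeMixerLightningModule(
    model_kwargs={
        "prediction_length": 24,
        "context_length": 240,
        "hidden_dimension": 20,
        "distr_output": StudentTOutput(),
        "kernel_size": 25,
        "scaling": "mean",
    },
    lr=1e-3,
    weight_decay=1e-8,
)
trainer = pl.Trainer(max_epochs=100, gradient_clip_val=10.0)
# trainer.fit(module, train_dataloaders=..., val_dataloaders=...)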
399 changes: 399 additions & 0 deletions src/gluonts/torch/model/times_mixer/module.py (diff not loaded)
