forked from awslabs/gluonts
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 4 changed files with 756 additions and 0 deletions.
18 changes: 18 additions & 0 deletions
src/gluonts/torch/model/times_mixer/__init__.py
@@ -0,0 +1,18 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from .module import TimeMixerModel
from .lightning_module import TimeMixerLightningModule
from .estimator import TimeMixerEstimator

__all__ = ["TimeMixerModel", "TimeMixerLightningModule", "TimeMixerEstimator"]
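
For reference, downstream code would import the public names exported above like so (a minimal sketch; the package path is taken from the times_mixer file paths in this commit):

from gluonts.torch.model.times_mixer import (
    TimeMixerEstimator,
    TimeMixerLightningModule,
    TimeMixerModel,
)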
239 changes: 239 additions & 0 deletions
src/gluonts/torch/model/times_mixer/estimator.py
@@ -0,0 +1,239 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Optional, Iterable, Dict, Any

import torch
import lightning.pytorch as pl

from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.itertools import Cyclic
from gluonts.transform import (
    Transformation,
    AddObservedValuesIndicator,
    InstanceSampler,
    InstanceSplitter,
    ValidationSplitSampler,
    TestSplitSampler,
    ExpectedNumInstanceSampler,
    SelectFields,
)
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.distributions import Output, StudentTOutput

from .lightning_module import TimeMixerLightningModule

PREDICTION_INPUT_NAMES = ["past_target", "past_observed_values"]

TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
    "future_target",
    "future_observed_values",
]


class TimeMixerEstimator(PyTorchLightningEstimator):
""" | ||
An estimator training the Time Mixer model form the paper | ||
https://openreview.net/pdf?id=7oLshfEIC2 extended for probabilistic | ||
forecasting. | ||
This class is uses the model defined in ``TimeMixerModel``, | ||
and wraps it into a ``TimeMixerLightningModule`` for training | ||
purposes: training is performed using PyTorch Lightning's ``pl.Trainer`` | ||
class. | ||
Parameters | ||
---------- | ||
prediction_length | ||
Length of the prediction horizon. | ||
context_length | ||
Number of time steps prior to prediction time that the model | ||
takes as inputs (default: ``10 * prediction_length``). | ||
hidden_dimension | ||
Size of representation. | ||
lr | ||
Learning rate (default: ``1e-3``). | ||
weight_decay | ||
Weight decay regularization parameter (default: ``1e-8``). | ||
distr_output | ||
Distribution to use to evaluate observations and sample predictions | ||
(default: StudentTOutput()). | ||
kernel_size | ||
batch_size | ||
The size of the batches to be used for training (default: 32). | ||
num_batches_per_epoch | ||
Number of batches to be processed in each training epoch | ||
(default: 50). | ||
trainer_kwargs | ||
Additional arguments to provide to ``pl.Trainer`` for construction. | ||
train_sampler | ||
Controls the sampling of windows during training. | ||
validation_sampler | ||
Controls the sampling of windows during validation. | ||
""" | ||

    @validated()
    def __init__(
        self,
        prediction_length: int,
        context_length: Optional[int] = None,
        hidden_dimension: Optional[int] = None,
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
        scaling: Optional[str] = "mean",
        distr_output: Output = StudentTOutput(),
        kernel_size: int = 25,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        train_sampler: Optional[InstanceSampler] = None,
        validation_sampler: Optional[InstanceSampler] = None,
    ) -> None:
        default_trainer_kwargs = {
            "max_epochs": 100,
            "gradient_clip_val": 10.0,
        }
        if trainer_kwargs is not None:
            default_trainer_kwargs.update(trainer_kwargs)
        super().__init__(trainer_kwargs=default_trainer_kwargs)

        self.prediction_length = prediction_length
        self.context_length = context_length or 10 * prediction_length
        # TODO find way to enforce same defaults to network and estimator
        # somehow
        self.hidden_dimension = hidden_dimension or 20
        self.lr = lr
        self.weight_decay = weight_decay
        self.distr_output = distr_output
        self.scaling = scaling
        self.kernel_size = kernel_size
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch

        self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=prediction_length
        )
        self.validation_sampler = validation_sampler or ValidationSplitSampler(
            min_future=prediction_length
        )

    def create_transformation(self) -> Transformation:
        return SelectFields(
            [
                FieldName.ITEM_ID,
                FieldName.INFO,
                FieldName.START,
                FieldName.TARGET,
            ],
            allow_missing=True,
        ) + AddObservedValuesIndicator(
            target_field=FieldName.TARGET,
            output_field=FieldName.OBSERVED_VALUES,
        )

    def create_lightning_module(self) -> pl.LightningModule:
        return TimeMixerLightningModule(
            lr=self.lr,
            weight_decay=self.weight_decay,
            model_kwargs={
                "prediction_length": self.prediction_length,
                "context_length": self.context_length,
                "hidden_dimension": self.hidden_dimension,
                "distr_output": self.distr_output,
                "kernel_size": self.kernel_size,
                "scaling": self.scaling,
            },
        )

    def _create_instance_splitter(
        self, module: TimeMixerLightningModule, mode: str
    ):
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=self.context_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.OBSERVED_VALUES,
            ],
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: TimeMixerLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        data = Cyclic(data).stream()
        instances = self._create_instance_splitter(module, "training").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            shuffle_buffer_length=shuffle_buffer_length,
            field_names=TRAINING_INPUT_NAMES,
            output_type=torch.tensor,
            num_batches_per_epoch=self.num_batches_per_epoch,
        )

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: TimeMixerLightningModule,
        **kwargs,
    ) -> Iterable:
        instances = self._create_instance_splitter(module, "validation").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            field_names=TRAINING_INPUT_NAMES,
            output_type=torch.tensor,
        )

    def create_predictor(
        self,
        transformation: Transformation,
        module,
    ) -> PyTorchPredictor:
        prediction_splitter = self._create_instance_splitter(module, "test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module,
            forecast_generator=self.distr_output.forecast_generator,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            device="auto",
        )
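
Taken together, the estimator follows the usual GluonTS torch workflow: ``train()`` chains ``create_transformation()`` with the training instance splitter and data loader, fits the lightning module, and returns a ``PyTorchPredictor``. A minimal usage sketch, assuming the standard ``PyTorchLightningEstimator.train`` API and the import path from this commit's file layout (dataset choice and hyperparameters are illustrative only):

from gluonts.dataset.repository import get_dataset
from gluonts.torch.model.times_mixer import TimeMixerEstimator  # path assumed

dataset = get_dataset("electricity")  # any GluonTS dataset works here

estimator = TimeMixerEstimator(
    prediction_length=dataset.metadata.prediction_length,
    context_length=10 * dataset.metadata.prediction_length,
    trainer_kwargs={"max_epochs": 10},
)

# train() fits the TimeMixerLightningModule and wraps it in a predictor
predictor = estimator.train(dataset.train)

forecasts = list(predictor.predict(dataset.test))
print(forecasts[0].mean[:5])  # head of the first series' mean forecast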
100 changes: 100 additions & 0 deletions
src/gluonts/torch/model/times_mixer/lightning_module.py
@@ -0,0 +1,100 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import lightning.pytorch as pl
import torch

from gluonts.core.component import validated
from gluonts.itertools import select

from .module import TimeMixerModel


class TimeMixerLightningModule(pl.LightningModule):
""" | ||
A ``pl.LightningModule`` class that can be used to train a ``TimeMixerModel`` | ||
with PyTorch Lightning. | ||
This is a thin layer around a (wrapped) ``TimeMixerModel`` object, | ||
that exposes the methods to evaluate training and validation loss. | ||
Parameters | ||
---------- | ||
model_kwargs | ||
Keyword arguments to construct the ``TimeMixerModel`` to be trained. | ||
loss | ||
Loss function to be used for training. | ||
lr | ||
Learning rate. | ||
weight_decay | ||
Weight decay regularization parameter. | ||
""" | ||

    @validated()
    def __init__(
        self,
        model_kwargs: dict,
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.model = TimeMixerModel(**model_kwargs)
        self.lr = lr
        self.weight_decay = weight_decay
        self.inputs = self.model.describe_inputs()

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)

    def training_step(self, batch, batch_idx: int):  # type: ignore
        """
        Execute training step.
        """
        train_loss = self.model.loss(
            **select(self.inputs, batch),
            future_target=batch["future_target"],
            future_observed_values=batch["future_observed_values"],
        ).mean()
        self.log(
            "train_loss",
            train_loss,
            on_epoch=True,
            on_step=False,
            prog_bar=True,
        )
        return train_loss

    def validation_step(self, batch, batch_idx: int):  # type: ignore
        """
        Execute validation step.
        """
        val_loss = self.model.loss(
            **select(self.inputs, batch),
            future_target=batch["future_target"],
            future_observed_values=batch["future_observed_values"],
        ).mean()
        self.log(
            "val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True
        )
        return val_loss

    def configure_optimizers(self):
        """
        Returns the optimizer to use.
        """
        return torch.optim.Adam(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
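
The ``select(self.inputs, batch)`` calls above keep only the batch tensors named by the model's input spec, so ``loss`` receives ``past_target`` and ``past_observed_values`` plus the explicit ``future_*`` arguments. A hedged sketch of driving the module directly with ``pl.Trainer``, bypassing the estimator (the ``model_kwargs`` keys mirror ``create_lightning_module()`` in estimator.py; the values and import path are illustrative assumptions):

import lightning.pytorch as pl

from gluonts.torch.distributions import StudentTOutput
from gluonts.torch.model.times_mixer import TimeMixerLightningModule  # path assumed

module = TimeMixerLightningModule(
    model_kwargs={
        "prediction_length": 24,
        "context_length": 240,  # estimator default is 10 * prediction_length
        "hidden_dimension": 20,
        "distr_output": StudentTOutput(),
        "kernel_size": 25,
        "scaling": "mean",
    },
)

# same trainer defaults that the estimator installs
trainer = pl.Trainer(max_epochs=100, gradient_clip_val=10.0)
# trainer.fit(module, train_dataloaders=..., val_dataloaders=...)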