diff --git a/nbs/common.base_recurrent.ipynb b/nbs/common.base_recurrent.ipynb index f84087969..da17c378b 100644 --- a/nbs/common.base_recurrent.ipynb +++ b/nbs/common.base_recurrent.ipynb @@ -151,14 +151,15 @@ " self.early_stop_patience_steps = early_stop_patience_steps\n", " self.val_check_steps = val_check_steps\n", "\n", - " # Scaler\n", - " self.scaler = TemporalNorm(scaler_type=scaler_type, dim=-1) # Time dimension is -1.\n", - "\n", " # Variables\n", " self.futr_exog_list = futr_exog_list if futr_exog_list is not None else []\n", " self.hist_exog_list = hist_exog_list if hist_exog_list is not None else []\n", " self.stat_exog_list = stat_exog_list if stat_exog_list is not None else []\n", "\n", + " # Scaler\n", + " self.scaler = TemporalNorm(scaler_type=scaler_type, dim=-1, # Time dimension is -1.\n", + " num_features=1+len(self.hist_exog_list)+len(self.futr_exog_list)) \n", + "\n", " # Fit arguments\n", " self.val_size = 0\n", " self.test_size = 0\n", @@ -219,13 +220,17 @@ " 'interval': 'step'}\n", " return {'optimizer': optimizer, 'lr_scheduler': scheduler}\n", "\n", - " def _normalization(self, batch, val_size=0, test_size=0):\n", + " def _get_temporal_data_cols(self, temporal_cols):\n", + " temporal_data_cols = ['y'] + list(set(temporal_cols.tolist()) &\\\n", + " set(self.hist_exog_list + self.futr_exog_list))\n", + " return temporal_data_cols\n", "\n", + " def _normalization(self, batch, val_size=0, test_size=0):\n", " temporal = batch['temporal'] # B, C, T\n", " temporal_cols = batch['temporal_cols'].copy()\n", "\n", " # Separate data and mask\n", - " temporal_data_cols = temporal_cols.drop('available_mask').tolist()\n", + " temporal_data_cols = self._get_temporal_data_cols(temporal_cols=temporal_cols)\n", " temporal_data = temporal[:, temporal_cols.get_indexer(temporal_data_cols), :]\n", " temporal_mask = temporal[:, temporal_cols.get_loc('available_mask'), :].clone()\n", "\n", @@ -679,6 +684,52 @@ "show_doc(BaseRecurrent.predict, title_level=3)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# add h=0,1 unit test for _parse_windows \n", + "from neuralforecast.losses.pytorch import MAE\n", + "from neuralforecast.utils import AirPassengersDF\n", + "from neuralforecast.tsdataset import TimeSeriesDataset, TimeSeriesDataModule\n", + "\n", + "# Declare batch\n", + "AirPassengersDF['x'] = np.array(len(AirPassengersDF))\n", + "AirPassengersDF['x2'] = np.array(len(AirPassengersDF)) * 2\n", + "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=AirPassengersDF)\n", + "data = TimeSeriesDataModule(dataset=dataset, batch_size=1, drop_last=True)\n", + "\n", + "train_loader = data.train_dataloader()\n", + "batch = next(iter(train_loader))\n", + "\n", + "# Test that hist_exog_list and futr_exog_list correctly filter data that is sent to scaler.\n", + "baserecurrent = BaseRecurrent(h=12,\n", + " input_size=117,\n", + " hist_exog_list=['x', 'x2'],\n", + " futr_exog_list=['x'],\n", + " loss=MAE(),\n", + " valid_loss=MAE(),\n", + " learning_rate=0.001,\n", + " max_steps=1,\n", + " val_check_steps=0,\n", + " batch_size=1,\n", + " valid_batch_size=1,\n", + " windows_batch_size=10,\n", + " inference_input_size=2,\n", + " start_padding_enabled=True)\n", + "\n", + "windows = baserecurrent._create_windows(batch, step='train')\n", + "\n", + "temporal_cols = windows['temporal_cols'].copy() # B, L+H, C\n", + "temporal_data_cols = baserecurrent._get_temporal_data_cols(temporal_cols=temporal_cols)\n", + "\n", + 
"test_eq(set(temporal_data_cols), set(['y', 'x', 'x2']))\n", + "test_eq(windows['temporal'].shape, torch.Size([1,len(['y', 'x', 'x2', 'available_mask']),117,12+1]))" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/common.base_windows.ipynb b/nbs/common.base_windows.ipynb index 6cb63be19..cab27b4df 100644 --- a/nbs/common.base_windows.ipynb +++ b/nbs/common.base_windows.ipynb @@ -158,15 +158,16 @@ " self.windows_batch_size = windows_batch_size\n", " self.step_size = step_size\n", "\n", - " # Scaler\n", - " self.scaler = TemporalNorm(scaler_type=scaler_type, dim=1) # Time dimension is 1.\n", - "\n", " # Variables\n", " self.futr_exog_list = futr_exog_list if futr_exog_list is not None else []\n", " self.hist_exog_list = hist_exog_list if hist_exog_list is not None else []\n", " self.stat_exog_list = stat_exog_list if stat_exog_list is not None else []\n", " self.exclude_insample_y = exclude_insample_y\n", "\n", + " # Scaler\n", + " self.scaler = TemporalNorm(scaler_type=scaler_type, dim=1, # Time dimension is 1.\n", + " num_features=1+len(self.hist_exog_list)+len(self.futr_exog_list))\n", + "\n", " # Fit arguments\n", " self.val_size = 0\n", " self.test_size = 0\n", @@ -353,6 +354,11 @@ " return windows_batch\n", " else:\n", " raise ValueError(f'Unknown step {step}')\n", + "\n", + " def _get_temporal_data_cols(self, temporal_cols):\n", + " temporal_data_cols = ['y'] + list(set(temporal_cols.tolist()) &\\\n", + " set(self.hist_exog_list + self.futr_exog_list))\n", + " return temporal_data_cols\n", " \n", " def _normalization(self, windows):\n", " # windows are already filtered by train/validation/test\n", @@ -361,7 +367,8 @@ " temporal_cols = windows['temporal_cols'].copy() # B, L+H, C\n", "\n", " # To avoid leakage uses only the lags\n", - " temporal_data_cols = temporal_cols.drop('available_mask').tolist()\n", + " #temporal_data_cols = temporal_cols.drop('available_mask').tolist()\n", + " temporal_data_cols = self._get_temporal_data_cols(temporal_cols=temporal_cols)\n", " temporal_data = temporal[:, :, temporal_cols.get_indexer(temporal_data_cols)]\n", " temporal_mask = temporal[:, :, temporal_cols.get_loc('available_mask')].clone()\n", " if self.h > 0:\n", @@ -822,6 +829,7 @@ "\n", "# Declare batch\n", "AirPassengersDF['x'] = np.array(len(AirPassengersDF))\n", + "AirPassengersDF['x2'] = np.array(len(AirPassengersDF)) * 2\n", "dataset, indices, dates, ds = TimeSeriesDataset.from_df(df=AirPassengersDF)\n", "data = TimeSeriesDataModule(dataset=dataset, batch_size=1, drop_last=True)\n", "\n", @@ -902,8 +910,51 @@ "windows = basewindows._create_windows(batch, step='predict')\n", "windows = basewindows._normalization(windows=windows)\n", "insample_y, insample_mask, outsample_y, outsample_mask, \\\n", - " hist_exog, futr_exog, stat_exog = basewindows._parse_windows(batch, windows)\n" + " hist_exog, futr_exog, stat_exog = basewindows._parse_windows(batch, windows)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54d2e850", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "\n", + "# Test that hist_exog_list and futr_exog_list correctly filter data.\n", + "# that is sent to scaler.\n", + "basewindows = BaseWindows(h=12,\n", + " input_size=500,\n", + " hist_exog_list=['x', 'x2'],\n", + " futr_exog_list=['x'],\n", + " loss=MAE(),\n", + " valid_loss=MAE(),\n", + " learning_rate=0.001,\n", + " max_steps=1,\n", + " val_check_steps=0,\n", + " batch_size=1,\n", + " valid_batch_size=1,\n", + " windows_batch_size=10,\n", + " 
inference_windows_batch_size=2,\n", + " start_padding_enabled=True)\n", + "\n", + "windows = basewindows._create_windows(batch, step='train')\n", + "\n", + "temporal_cols = windows['temporal_cols'].copy() # B, L+H, C\n", + "temporal_data_cols = basewindows._get_temporal_data_cols(temporal_cols=temporal_cols)\n", + "\n", + "test_eq(set(temporal_data_cols), set(['y', 'x', 'x2']))\n", + "test_eq(windows['temporal'].shape, torch.Size([10,500+12,len(['y', 'x', 'x2', 'available_mask'])]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf493ff9", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/nbs/common.scalers.ipynb b/nbs/common.scalers.ipynb index d6a72953a..50caf9a2f 100644 --- a/nbs/common.scalers.ipynb +++ b/nbs/common.scalers.ipynb @@ -30,7 +30,7 @@ "source": [ "# TemporalNorm\n", "\n", - "> Temporal normalization has proven to be essential in neural forecasting tasks, as it enables network's non-linearities to express themselves. Forecasting scaling methods take particular interest in the temporal dimension where most of the variance dwells, contrary to other deep learning techniques like `BatchNorm` that normalizes across batch and temporal dimensions, and `LayerNorm` that normalizes across the feature dimension. Currently we support the following techniques: `std`, `median`, `norm`, `norm1`, `invariant`.

" + "> Temporal normalization has proven to be essential in neural forecasting tasks, as it enables network's non-linearities to express themselves. Forecasting scaling methods take particular interest in the temporal dimension where most of the variance dwells, contrary to other deep learning techniques like `BatchNorm` that normalizes across batch and temporal dimensions, and `LayerNorm` that normalizes across the feature dimension. Currently we support the following techniques: `std`, `median`, `norm`, `norm1`, `invariant`, `revin`.

**References**
- [Kin G. Olivares, David Luo, Cristian Challu, Stefania La Vattiata, Max Mergenthaler, Artur Dubrawski (2023). \"HINT: Hierarchical Mixture Networks For Coherent Probabilistic Forecasting\". Neural Information Processing Systems, submitted. Working Paper version available at arxiv.](https://arxiv.org/abs/2305.07089)
- [Taesung Kim, Jinhee Kim, Yunwon Tae, Cheonbok Park, Jang-Ho Choi, Jaegul Choo (2022). \"Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift\". International Conference on Learning Representations.](https://openreview.net/pdf?id=cGDAkQo1C0p)<br>
- [David Salinas, Valentin Flunkert, Jan Gasthaus, Tim Januschowski (2020). \"DeepAR: Probabilistic forecasting with autoregressive recurrent networks\". International Journal of Forecasting.](https://www.sciencedirect.com/science/article/pii/S0169207019301888)
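<br>
A quick sketch of how the new `revin` option behaves (this restates the `transform` implementation added later in this diff, nothing beyond it): the learnable parameters $\hat{\gamma}, \hat{\beta}$ are folded into the shift and scale statistics computed over time, instead of being applied as a separate output-side affine, which keeps the inverse transform and the distribution scale-decouple working unchanged.<br>
$$\tilde{\mu} = \hat{\mu} + \hat{\beta}, \qquad \tilde{\sigma} = \hat{\sigma}\,\mathrm{ReLU}(\hat{\gamma} + \epsilon), \qquad \mathbf{z}_{[B,T,C]} = \frac{\mathbf{x}_{[B,T,C]} - \tilde{\mu}}{\tilde{\sigma}}$$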
" ] }, { @@ -42,6 +42,19 @@ "![Figure 1. Illustration of temporal normalization (left), layer normalization (center) and batch normalization (right). The entries in green show the components used to compute the normalizing statistics.](imgs_models/temporal_norm.png)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5400f41", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "import os\n", + "os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"" + ] + }, { "cell_type": "code", "execution_count": null, @@ -202,8 +215,17 @@ " x_range = x_max - x_min\n", " x_range[x_range==0] = 1.0\n", " x_range = x_range + eps\n", - " return x_min, x_range\n", - "\n", + " return x_min, x_range" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39fa429b", + "metadata": {}, + "outputs": [], + "source": [ + "#| exporti\n", "def minmax_scaler(x, x_min, x_range):\n", " return (x - x_min) / x_range\n", "\n", @@ -263,8 +285,17 @@ " x_range = x_max - x_min\n", " x_range[x_range==0] = 1.0\n", " x_range = x_range + eps\n", - " return x_min, x_range\n", - "\n", + " return x_min, x_range" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a19ed5a8", + "metadata": {}, + "outputs": [], + "source": [ + "#| exporti\n", "def minmax1_scaler(x, x_min, x_range):\n", " x = (x - x_min) / x_range\n", " z = x * (2) - 1\n", @@ -316,12 +347,21 @@ " \"\"\"\n", " x_means = masked_mean(x=x, mask=mask, dim=dim)\n", " x_stds = torch.sqrt(masked_mean(x=(x-x_means)**2, mask=mask, dim=dim))\n", - " \n", + "\n", " # Protect against division by zero\n", " x_stds[x_stds==0] = 1.0\n", " x_stds = x_stds + eps\n", - " return x_means, x_stds\n", - "\n", + " return x_means, x_stds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17f90821", + "metadata": {}, + "outputs": [], + "source": [ + "#| exporti\n", "def std_scaler(x, x_means, x_stds):\n", " return (x - x_means) / x_stds\n", "\n", @@ -386,8 +426,17 @@ " # Protect against division by zero\n", " x_mad[x_mad==0] = 1.0\n", " x_mad = x_mad + eps\n", - " return x_median, x_mad\n", - "\n", + " return x_median, x_mad" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33f3cf28", + "metadata": {}, + "outputs": [], + "source": [ + "#| exporti\n", "def robust_scaler(x, x_median, x_mad):\n", " return (x - x_median) / x_mad\n", "\n", @@ -450,8 +499,17 @@ " # Protect against division by zero\n", " x_mad[x_mad==0] = 1.0\n", " x_mad = x_mad + eps\n", - " return x_median, x_mad\n", - "\n", + " return x_median, x_mad" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24cca2bf", + "metadata": {}, + "outputs": [], + "source": [ + "#| exporti\n", "def invariant_scaler(x, x_median, x_mad):\n", " return torch.arcsinh((x - x_median) / x_mad)\n", "\n", @@ -500,8 +558,17 @@ " x_shift = torch.zeros(shape)\n", " x_scale = torch.ones(shape)\n", "\n", - " return x_shift, x_scale\n", - "\n", + " return x_shift, x_scale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d7b313e", + "metadata": {}, + "outputs": [], + "source": [ + "#| exporti\n", "def identity_scaler(x, x_shift, x_scale):\n", " return x\n", "\n", @@ -546,18 +613,27 @@ "\n", " $$\\mathbf{z}_{[B,T,C]} = \\\\textrm{Scaler}(\\mathbf{x}_{[B,T,C]})$$\n", "\n", + " If `scaler_type` is `revin` learnable normalization parameters are added on top of\n", + " the usual normalization technique, the parameters are learned through scale decouple\n", + " 
global skip connections. The technique is available for point and probabilistic outputs.\n", + "\n", + " $$\\mathbf{\\hat{z}}_{[B,T,C]} = \\\\boldsymbol{\\hat{\\\\gamma}}_{[1,1,C]} \\mathbf{z}_{[B,T,C]} +\\\\boldsymbol{\\hat{\\\\beta}}_{[1,1,C]}$$\n", + "\n", " **Parameters:**
\n", - " `scaler_type`: str, defines the type of scaler used by TemporalNorm.\n", - " available [`identity`, `standard`, `robust`, `minmax`, `minmax1`, `invariant`].
\n", + " `scaler_type`: str, defines the type of scaler used by TemporalNorm. Available [`identity`, `standard`, `robust`, `minmax`, `minmax1`, `invariant`, `revin`].
\n", " `dim` (int, optional): Dimension over to compute scale and shift. Defaults to -1.
\n", " `eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.
\n", - " \n", - " \"\"\" \n", - " def __init__(self, scaler_type='robust', dim=-1, eps=1e-6):\n", + " `num_features`: int=None, for RevIN-like learnable affine parameters initialization.
\n", + "\n", + " **References**
\n", + " - [Kin G. Olivares, David Luo, Cristian Challu, Stefania La Vattiata, Max Mergenthaler, Artur Dubrawski (2023). \"HINT: Hierarchical Mixture Networks For Coherent Probabilistic Forecasting\". Neural Information Processing Systems, submitted. Working Paper version available at arxiv.](https://arxiv.org/abs/2305.07089)
\n", + " \"\"\"\n", + " def __init__(self, scaler_type='robust', dim=-1, eps=1e-6, num_features=None):\n", " super().__init__()\n", " compute_statistics = {None: identity_statistics,\n", " 'identity': identity_statistics,\n", " 'standard': std_statistics,\n", + " 'revin': std_statistics,\n", " 'robust': robust_statistics,\n", " 'minmax': minmax_statistics,\n", " 'minmax1': minmax1_statistics,\n", @@ -565,6 +641,7 @@ " scalers = {None: identity_scaler,\n", " 'identity': identity_scaler,\n", " 'standard': std_scaler,\n", + " 'revin': std_scaler,\n", " 'robust': robust_scaler,\n", " 'minmax': minmax_scaler,\n", " 'minmax1': minmax1_scaler,\n", @@ -572,11 +649,14 @@ " inverse_scalers = {None: inv_identity_scaler,\n", " 'identity': inv_identity_scaler,\n", " 'standard': inv_std_scaler,\n", + " 'revin': inv_std_scaler,\n", " 'robust': inv_robust_scaler,\n", " 'minmax': inv_minmax_scaler,\n", " 'minmax1': inv_minmax1_scaler,\n", " 'invariant': inv_invariant_scaler,}\n", " assert (scaler_type in scalers.keys()), f'{scaler_type} not defined'\n", + " if (scaler_type=='revin') and (num_features is None):\n", + " raise Exception('You must pass num_features for ReVIN scaler.')\n", "\n", " self.compute_statistics = compute_statistics[scaler_type]\n", " self.scaler = scalers[scaler_type]\n", @@ -585,6 +665,18 @@ " self.dim = dim\n", " self.eps = eps\n", "\n", + " if (scaler_type=='revin'):\n", + " self._init_params(num_features=num_features)\n", + "\n", + " def _init_params(self, num_features):\n", + " # Initialize RevIN scaler params to broadcast:\n", + " if self.dim==1: # [B,T,C] [1,1,C]\n", + " self.revin_bias = nn.Parameter(torch.zeros(1,1,num_features))\n", + " self.revin_weight = nn.Parameter(torch.ones(1,1,num_features))\n", + " elif self.dim==-1: # [B,C,T] [1,C,1]\n", + " self.revin_bias = nn.Parameter(torch.zeros(1,num_features,1))\n", + " self.revin_weight = nn.Parameter(torch.ones(1,num_features,1))\n", + "\n", " #@torch.no_grad()\n", " def transform(self, x, mask):\n", " \"\"\" Center and scale the data.\n", @@ -594,13 +686,23 @@ " `mask`: torch Tensor bool, shape [batch, time] where `x` is valid and False\n", " where `x` should be masked. Mask should not be all False in any column of\n", " dimension dim to avoid NaNs from zero division.
\n", - " \n", + "\n", " **Returns:**
\n", - " `z`: torch.Tensor same shape as `x`, except scaled. \n", + " `z`: torch.Tensor same shape as `x`, except scaled.\n", " \"\"\"\n", " x_shift, x_scale = self.compute_statistics(x=x, mask=mask, dim=self.dim, eps=self.eps)\n", " self.x_shift = x_shift\n", " self.x_scale = x_scale\n", + "\n", + " # Original Revin performs this operation\n", + " # z = self.revin_weight * z\n", + " # z = z + self.revin_bias\n", + " # However this is only valid for point forecast not for\n", + " # distribution's scale decouple technique.\n", + " if self.scaler_type=='revin':\n", + " self.x_shift = self.x_shift + self.revin_bias\n", + " self.x_scale = self.x_scale * torch.relu(self.revin_weight+self.eps)\n", + "\n", " z = self.scaler(x, x_shift, x_scale)\n", " return z\n", "\n", @@ -614,13 +716,24 @@ " **Returns:**
\n", " `x`: torch.Tensor original data.\n", " \"\"\"\n", + "\n", " if x_shift is None:\n", " x_shift = self.x_shift\n", " if x_scale is None:\n", " x_scale = self.x_scale\n", "\n", + " # Original Revin performs this operation\n", + " # z = z - self.revin_bias\n", + " # z = (z / (self.revin_weight + self.eps))\n", + " # However this is only valid for point forecast not for\n", + " # distribution's scale decouple technique.\n", + "\n", " x = self.inverse_scaler(z, x_shift, x_scale)\n", - " return x" + " return x\n", + "\n", + " def forward(self, x):\n", + " # The gradients are optained from BaseWindows/BaseRecurrent forwards.\n", + " pass" ] }, { @@ -739,22 +852,13 @@ "source": [ "#| hide\n", "# Validate scalers\n", - "for scaler_type in [None, 'identity', 'standard', 'robust', 'minmax', 'minmax1', 'invariant']:\n", + "for scaler_type in [None, 'identity', 'standard', 'robust', 'minmax', 'minmax1', 'invariant', 'revin']:\n", " x = 1.0*torch.tensor(np_x)\n", " mask = torch.tensor(np_mask)\n", - " scaler = TemporalNorm(scaler_type=scaler_type, dim=1)\n", + " scaler = TemporalNorm(scaler_type=scaler_type, dim=1, num_features=np_x.shape[-1])\n", " x_scaled = scaler.transform(x=x, mask=mask)\n", " x_recovered = scaler.inverse_transform(x_scaled)\n", - " assert torch.allclose(x, x_recovered, atol=1e-5), f'Recovered data is not the same as original with {scaler_type}'" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "fb1207bd", - "metadata": {}, - "source": [ - "# Test Predict (masked)" + " assert torch.allclose(x, x_recovered, atol=1e-3), f'Recovered data is not the same as original with {scaler_type}'" ] }, { @@ -765,6 +869,8 @@ "outputs": [], "source": [ "#| hide\n", + "\n", + "# Unit test for masked predict filtering\n", "import pandas as pd\n", "\n", "from neuralforecast import NeuralForecast\n", @@ -783,6 +889,86 @@ "Y_hat = nf.predict(df=Y_df)\n", "assert pd.isnull(Y_hat).sum().sum() == 0, 'Predictions should not have NaNs'" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa6e6a40", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "\n", + "# Unit test for ReVIN, and its compatibility with distribution's scale decouple\n", + "from neuralforecast import NeuralForecast\n", + "from neuralforecast.models import NHITS, RNN\n", + "from neuralforecast.losses.pytorch import DistributionLoss, HuberLoss, GMM, MAE\n", + "from neuralforecast.tsdataset import TimeSeriesDataset\n", + "from neuralforecast.utils import AirPassengers, AirPassengersPanel, AirPassengersStatic\n", + "\n", + "Y_df = AirPassengersPanel\n", + "# del Y_df['trend']\n", + "\n", + "# Instantiate BaseWindow model and test revin dynamic dimensionality with hist_exog_list\n", + "model = NHITS(h=12,\n", + " input_size=24,\n", + " loss=GMM(n_components=10, level=[90]),\n", + " hist_exog_list=['y_[lag12]'],\n", + " max_steps=1,\n", + " early_stop_patience_steps=10,\n", + " val_check_steps=50,\n", + " scaler_type='revin',\n", + " learning_rate=1e-3)\n", + "nf = NeuralForecast(models=[model], freq='MS')\n", + "Y_hat_df = nf.cross_validation(df=Y_df, val_size=12, n_windows=1)\n", + "\n", + "# Instantiate BaseWindow model and test revin dynamic dimensionality with hist_exog_list\n", + "model = NHITS(h=12,\n", + " input_size=24,\n", + " loss=HuberLoss(),\n", + " hist_exog_list=['trend', 'y_[lag12]'],\n", + " max_steps=1,\n", + " early_stop_patience_steps=10,\n", + " val_check_steps=50,\n", + " scaler_type='revin',\n", + " learning_rate=1e-3)\n", + "nf = NeuralForecast(models=[model], 
freq='MS')\n", + "Y_hat_df = nf.cross_validation(df=Y_df, val_size=12, n_windows=1)\n", + "\n", + "# Instantiate BaseRecurrent model and test revin dynamic dimensionality with hist_exog_list\n", + "model = RNN(h=12,\n", + " input_size=24,\n", + " loss=GMM(n_components=10, level=[90]),\n", + " hist_exog_list=['trend', 'y_[lag12]'],\n", + " max_steps=1,\n", + " early_stop_patience_steps=10,\n", + " val_check_steps=50,\n", + " scaler_type='revin',\n", + " learning_rate=1e-3)\n", + "nf = NeuralForecast(models=[model], freq='MS')\n", + "Y_hat_df = nf.cross_validation(df=Y_df, val_size=12, n_windows=1)\n", + "\n", + "# Instantiate BaseRecurrent model and test revin dynamic dimensionality with hist_exog_list\n", + "model = RNN(h=12,\n", + " input_size=24,\n", + " loss=HuberLoss(),\n", + " hist_exog_list=['trend'],\n", + " max_steps=1,\n", + " early_stop_patience_steps=10,\n", + " val_check_steps=50,\n", + " scaler_type='revin',\n", + " learning_rate=1e-3)\n", + "nf = NeuralForecast(models=[model], freq='MS')\n", + "Y_hat_df = nf.cross_validation(df=Y_df, val_size=12, n_windows=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2f50bd8", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/neuralforecast/common/_base_recurrent.py b/neuralforecast/common/_base_recurrent.py index a206ea51b..e6ea16b63 100644 --- a/neuralforecast/common/_base_recurrent.py +++ b/neuralforecast/common/_base_recurrent.py @@ -102,16 +102,18 @@ def __init__( self.early_stop_patience_steps = early_stop_patience_steps self.val_check_steps = val_check_steps - # Scaler - self.scaler = TemporalNorm( - scaler_type=scaler_type, dim=-1 - ) # Time dimension is -1. - # Variables self.futr_exog_list = futr_exog_list if futr_exog_list is not None else [] self.hist_exog_list = hist_exog_list if hist_exog_list is not None else [] self.stat_exog_list = stat_exog_list if stat_exog_list is not None else [] + # Scaler + self.scaler = TemporalNorm( + scaler_type=scaler_type, + dim=-1, # Time dimension is -1. + num_features=1 + len(self.hist_exog_list) + len(self.futr_exog_list), + ) + # Fit arguments self.val_size = 0 self.test_size = 0 @@ -176,12 +178,18 @@ def configure_optimizers(self): } return {"optimizer": optimizer, "lr_scheduler": scheduler} + def _get_temporal_data_cols(self, temporal_cols): + temporal_data_cols = ["y"] + list( + set(temporal_cols.tolist()) & set(self.hist_exog_list + self.futr_exog_list) + ) + return temporal_data_cols + def _normalization(self, batch, val_size=0, test_size=0): temporal = batch["temporal"] # B, C, T temporal_cols = batch["temporal_cols"].copy() # Separate data and mask - temporal_data_cols = temporal_cols.drop("available_mask").tolist() + temporal_data_cols = self._get_temporal_data_cols(temporal_cols=temporal_cols) temporal_data = temporal[:, temporal_cols.get_indexer(temporal_data_cols), :] temporal_mask = temporal[:, temporal_cols.get_loc("available_mask"), :].clone() diff --git a/neuralforecast/common/_base_windows.py b/neuralforecast/common/_base_windows.py index f4516b36a..3ccd58dbc 100644 --- a/neuralforecast/common/_base_windows.py +++ b/neuralforecast/common/_base_windows.py @@ -108,17 +108,19 @@ def __init__( self.windows_batch_size = windows_batch_size self.step_size = step_size - # Scaler - self.scaler = TemporalNorm( - scaler_type=scaler_type, dim=1 - ) # Time dimension is 1. 
- # Variables self.futr_exog_list = futr_exog_list if futr_exog_list is not None else [] self.hist_exog_list = hist_exog_list if hist_exog_list is not None else [] self.stat_exog_list = stat_exog_list if stat_exog_list is not None else [] self.exclude_insample_y = exclude_insample_y + # Scaler + self.scaler = TemporalNorm( + scaler_type=scaler_type, + dim=1, # Time dimension is 1. + num_features=1 + len(self.hist_exog_list) + len(self.futr_exog_list), + ) + # Fit arguments self.val_size = 0 self.test_size = 0 @@ -329,6 +331,12 @@ def _create_windows(self, batch, step, w_idxs=None): else: raise ValueError(f"Unknown step {step}") + def _get_temporal_data_cols(self, temporal_cols): + temporal_data_cols = ["y"] + list( + set(temporal_cols.tolist()) & set(self.hist_exog_list + self.futr_exog_list) + ) + return temporal_data_cols + def _normalization(self, windows): # windows are already filtered by train/validation/test # from the `create_windows_method` nor leakage risk @@ -336,7 +344,8 @@ def _normalization(self, windows): temporal_cols = windows["temporal_cols"].copy() # B, L+H, C # To avoid leakage uses only the lags - temporal_data_cols = temporal_cols.drop("available_mask").tolist() + # temporal_data_cols = temporal_cols.drop('available_mask').tolist() + temporal_data_cols = self._get_temporal_data_cols(temporal_cols=temporal_cols) temporal_data = temporal[:, :, temporal_cols.get_indexer(temporal_data_cols)] temporal_mask = temporal[:, :, temporal_cols.get_loc("available_mask")].clone() if self.h > 0: diff --git a/neuralforecast/common/_scalers.py b/neuralforecast/common/_scalers.py index 551c05a8a..a388dd9a8 100644 --- a/neuralforecast/common/_scalers.py +++ b/neuralforecast/common/_scalers.py @@ -1,16 +1,14 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/common.scalers.ipynb. 
# %% auto 0 -__all__ = ['masked_median', 'masked_mean', 'minmax_statistics', 'minmax_scaler', 'inv_minmax_scaler', 'minmax1_statistics', - 'minmax1_scaler', 'inv_minmax1_scaler', 'std_statistics', 'std_scaler', 'inv_std_scaler', - 'robust_statistics', 'robust_scaler', 'inv_robust_scaler', 'invariant_statistics', 'invariant_scaler', - 'inv_invariant_scaler', 'identity_statistics', 'identity_scaler', 'inv_identity_scaler', 'TemporalNorm'] +__all__ = ['masked_median', 'masked_mean', 'minmax_statistics', 'minmax1_statistics', 'std_statistics', 'robust_statistics', + 'invariant_statistics', 'identity_statistics', 'TemporalNorm'] -# %% ../../nbs/common.scalers.ipynb 4 +# %% ../../nbs/common.scalers.ipynb 5 import torch import torch.nn as nn -# %% ../../nbs/common.scalers.ipynb 7 +# %% ../../nbs/common.scalers.ipynb 8 def masked_median(x, mask, dim=-1, keepdim=True): """Masked Median @@ -56,7 +54,7 @@ def masked_mean(x, mask, dim=-1, keepdim=True): x_mean = torch.nan_to_num(x_mean, nan=0.0) return x_mean -# %% ../../nbs/common.scalers.ipynb 11 +# %% ../../nbs/common.scalers.ipynb 12 def minmax_statistics(x, mask, eps=1e-6, dim=-1): """MinMax Scaler @@ -96,7 +94,7 @@ def minmax_statistics(x, mask, eps=1e-6, dim=-1): x_range = x_range + eps return x_min, x_range - +# %% ../../nbs/common.scalers.ipynb 13 def minmax_scaler(x, x_min, x_range): return (x - x_min) / x_range @@ -104,7 +102,7 @@ def minmax_scaler(x, x_min, x_range): def inv_minmax_scaler(z, x_min, x_range): return z * x_range + x_min -# %% ../../nbs/common.scalers.ipynb 13 +# %% ../../nbs/common.scalers.ipynb 15 def minmax1_statistics(x, mask, eps=1e-6, dim=-1): """MinMax1 Scaler @@ -145,7 +143,7 @@ def minmax1_statistics(x, mask, eps=1e-6, dim=-1): x_range = x_range + eps return x_min, x_range - +# %% ../../nbs/common.scalers.ipynb 16 def minmax1_scaler(x, x_min, x_range): x = (x - x_min) / x_range z = x * (2) - 1 @@ -156,7 +154,7 @@ def inv_minmax1_scaler(z, x_min, x_range): z = (z + 1) / 2 return z * x_range + x_min -# %% ../../nbs/common.scalers.ipynb 15 +# %% ../../nbs/common.scalers.ipynb 18 def std_statistics(x, mask, dim=-1, eps=1e-6): """Standard Scaler @@ -186,7 +184,7 @@ def std_statistics(x, mask, dim=-1, eps=1e-6): x_stds = x_stds + eps return x_means, x_stds - +# %% ../../nbs/common.scalers.ipynb 19 def std_scaler(x, x_means, x_stds): return (x - x_means) / x_stds @@ -194,7 +192,7 @@ def std_scaler(x, x_means, x_stds): def inv_std_scaler(z, x_mean, x_std): return (z * x_std) + x_mean -# %% ../../nbs/common.scalers.ipynb 17 +# %% ../../nbs/common.scalers.ipynb 21 def robust_statistics(x, mask, dim=-1, eps=1e-6): """Robust Median Scaler @@ -236,7 +234,7 @@ def robust_statistics(x, mask, dim=-1, eps=1e-6): x_mad = x_mad + eps return x_median, x_mad - +# %% ../../nbs/common.scalers.ipynb 22 def robust_scaler(x, x_median, x_mad): return (x - x_median) / x_mad @@ -244,7 +242,7 @@ def robust_scaler(x, x_median, x_mad): def inv_robust_scaler(z, x_median, x_mad): return z * x_mad + x_median -# %% ../../nbs/common.scalers.ipynb 19 +# %% ../../nbs/common.scalers.ipynb 24 def invariant_statistics(x, mask, dim=-1, eps=1e-6): """Invariant Median Scaler @@ -284,7 +282,7 @@ def invariant_statistics(x, mask, dim=-1, eps=1e-6): x_mad = x_mad + eps return x_median, x_mad - +# %% ../../nbs/common.scalers.ipynb 25 def invariant_scaler(x, x_median, x_mad): return torch.arcsinh((x - x_median) / x_mad) @@ -292,7 +290,7 @@ def invariant_scaler(x, x_median, x_mad): def inv_invariant_scaler(z, x_median, x_mad): return torch.sinh(z) * x_mad + x_median 
-# %% ../../nbs/common.scalers.ipynb 21 +# %% ../../nbs/common.scalers.ipynb 27 def identity_statistics(x, mask, dim=-1, eps=1e-6): """Identity Scaler @@ -318,7 +316,7 @@ def identity_statistics(x, mask, dim=-1, eps=1e-6): return x_shift, x_scale - +# %% ../../nbs/common.scalers.ipynb 28 def identity_scaler(x, x_shift, x_scale): return x @@ -326,7 +324,7 @@ def identity_scaler(x, x_shift, x_scale): def inv_identity_scaler(z, x_shift, x_scale): return z -# %% ../../nbs/common.scalers.ipynb 24 +# %% ../../nbs/common.scalers.ipynb 31 class TemporalNorm(nn.Module): """Temporal Normalization @@ -337,20 +335,29 @@ class TemporalNorm(nn.Module): $$\mathbf{z}_{[B,T,C]} = \\textrm{Scaler}(\mathbf{x}_{[B,T,C]})$$ + If `scaler_type` is `revin` learnable normalization parameters are added on top of + the usual normalization technique, the parameters are learned through scale decouple + global skip connections. The technique is available for point and probabilistic outputs. + + $$\mathbf{\hat{z}}_{[B,T,C]} = \\boldsymbol{\hat{\\gamma}}_{[1,1,C]} \mathbf{z}_{[B,T,C]} +\\boldsymbol{\hat{\\beta}}_{[1,1,C]}$$ + **Parameters:**
- `scaler_type`: str, defines the type of scaler used by TemporalNorm. - available [`identity`, `standard`, `robust`, `minmax`, `minmax1`, `invariant`].
+ `scaler_type`: str, defines the type of scaler used by TemporalNorm. Available [`identity`, `standard`, `robust`, `minmax`, `minmax1`, `invariant`, `revin`].
`dim` (int, optional): Dimension over to compute scale and shift. Defaults to -1.
`eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.
+ `num_features`: int=None, number of temporal channels; required to initialize the RevIN-like learnable affine parameters.<br>
+ **References**
+ - [Kin G. Olivares, David Luo, Cristian Challu, Stefania La Vattiata, Max Mergenthaler, Artur Dubrawski (2023). "HINT: Hierarchical Mixture Networks For Coherent Probabilistic Forecasting". Neural Information Processing Systems, submitted. Working Paper version available at arxiv.](https://arxiv.org/abs/2305.07089)
""" - def __init__(self, scaler_type="robust", dim=-1, eps=1e-6): + def __init__(self, scaler_type="robust", dim=-1, eps=1e-6, num_features=None): super().__init__() compute_statistics = { None: identity_statistics, "identity": identity_statistics, "standard": std_statistics, + "revin": std_statistics, "robust": robust_statistics, "minmax": minmax_statistics, "minmax1": minmax1_statistics, @@ -360,6 +367,7 @@ def __init__(self, scaler_type="robust", dim=-1, eps=1e-6): None: identity_scaler, "identity": identity_scaler, "standard": std_scaler, + "revin": std_scaler, "robust": robust_scaler, "minmax": minmax_scaler, "minmax1": minmax1_scaler, @@ -369,12 +377,15 @@ def __init__(self, scaler_type="robust", dim=-1, eps=1e-6): None: inv_identity_scaler, "identity": inv_identity_scaler, "standard": inv_std_scaler, + "revin": inv_std_scaler, "robust": inv_robust_scaler, "minmax": inv_minmax_scaler, "minmax1": inv_minmax1_scaler, "invariant": inv_invariant_scaler, } assert scaler_type in scalers.keys(), f"{scaler_type} not defined" + if (scaler_type == "revin") and (num_features is None): + raise Exception("You must pass num_features for ReVIN scaler.") self.compute_statistics = compute_statistics[scaler_type] self.scaler = scalers[scaler_type] @@ -383,6 +394,18 @@ def __init__(self, scaler_type="robust", dim=-1, eps=1e-6): self.dim = dim self.eps = eps + if scaler_type == "revin": + self._init_params(num_features=num_features) + + def _init_params(self, num_features): + # Initialize RevIN scaler params to broadcast: + if self.dim == 1: # [B,T,C] [1,1,C] + self.revin_bias = nn.Parameter(torch.zeros(1, 1, num_features)) + self.revin_weight = nn.Parameter(torch.ones(1, 1, num_features)) + elif self.dim == -1: # [B,C,T] [1,C,1] + self.revin_bias = nn.Parameter(torch.zeros(1, num_features, 1)) + self.revin_weight = nn.Parameter(torch.ones(1, num_features, 1)) + # @torch.no_grad() def transform(self, x, mask): """Center and scale the data. @@ -401,6 +424,16 @@ def transform(self, x, mask): ) self.x_shift = x_shift self.x_scale = x_scale + + # Original Revin performs this operation + # z = self.revin_weight * z + # z = z + self.revin_bias + # However this is only valid for point forecast not for + # distribution's scale decouple technique. + if self.scaler_type == "revin": + self.x_shift = self.x_shift + self.revin_bias + self.x_scale = self.x_scale * torch.relu(self.revin_weight + self.eps) + z = self.scaler(x, x_shift, x_scale) return z @@ -414,10 +447,21 @@ def inverse_transform(self, z, x_shift=None, x_scale=None): **Returns:**
 `x`: torch.Tensor original data.
 """
+
 if x_shift is None:
 x_shift = self.x_shift
 if x_scale is None:
 x_scale = self.x_scale

+ # Original Revin performs this operation
+ # z = z - self.revin_bias
+ # z = (z / (self.revin_weight + self.eps))
+ # However this is only valid for point forecast not for
+ # distribution's scale decouple technique.
+
 x = self.inverse_scaler(z, x_shift, x_scale)
 return x
+
+ def forward(self, x):
+ # The gradients are obtained from BaseWindows/BaseRecurrent forwards.
+ pass
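
For reference, a minimal usage sketch of the new `revin` option. This is only a sketch: it assumes the exported `neuralforecast.common._scalers` module path from the diff above and the `[B, T, C]` layout that `BaseWindows` uses with `dim=1`.

```python
import torch
from neuralforecast.common._scalers import TemporalNorm

# Toy batch: 2 series, 24 time steps, 3 temporal channels -> [B, T, C], so dim=1.
x = torch.randn(2, 24, 3)
mask = torch.ones_like(x)  # every observation is available

# num_features is required for 'revin' (the constructor raises otherwise)
# and should match the channel dimension C.
scaler = TemporalNorm(scaler_type='revin', dim=1, num_features=x.shape[-1])

# transform standardizes over the time dimension and folds the learnable
# affine parameters into x_shift / x_scale; inverse_transform reuses the
# stored statistics, so the round trip recovers the input at initialization.
z = scaler.transform(x=x, mask=mask)
x_rec = scaler.inverse_transform(z)
assert torch.allclose(x, x_rec, atol=1e-3)
```

Inside the models, these parameters receive their gradients through the `BaseWindows`/`BaseRecurrent` forward passes, which is why `TemporalNorm.forward` itself is a no-op.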