From 73e51e489b36326e02a7531021dce2f6a2cb0fcd Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 30 Oct 2023 09:22:02 +0100 Subject: [PATCH] Add option to provide function for season_length in SeasonalNaivePredictor --- .../model/seasonal_naive/_predictor.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/gluonts/model/seasonal_naive/_predictor.py b/src/gluonts/model/seasonal_naive/_predictor.py index deaf3978b8..c972b38e62 100644 --- a/src/gluonts/model/seasonal_naive/_predictor.py +++ b/src/gluonts/model/seasonal_naive/_predictor.py @@ -11,6 +11,8 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. +from typing import Callable, Union + import numpy as np from gluonts.core.component import validated @@ -43,7 +45,10 @@ class SeasonalNaivePredictor(RepresentablePredictor): prediction_length Number of time points to predict. season_length - Seasonality used to make predictions. + Seasonality used to make predictions. If this is an integer, then a fixed + sesasonlity is applied; if this is a function, then it will be called on each + given entry's ``freq`` attribute of the ``"start"`` field, and will return + the seasonality to use. imputation_method The imputation method to use in case of missing values. Defaults to :py:class:`LastValueImputation` which replaces each missing @@ -54,18 +59,25 @@ class SeasonalNaivePredictor(RepresentablePredictor): def __init__( self, prediction_length: int, - season_length: int, + season_length: Union[int, Callable], imputation_method: MissingValueImputation = LastValueImputation(), ) -> None: super().__init__(prediction_length=prediction_length) - assert season_length > 0, "The value of `season_length` should be > 0" + assert ( + not isinstance(season_length, int) or season_length > 0 + ), "The value of `season_length` should be > 0" self.prediction_length = prediction_length self.season_length = season_length self.imputation_method = imputation_method def predict_item(self, item: DataEntry) -> Forecast: + if isinstance(self.season_length, int): + season_length = self.season_length + else: + season_length = self.season_length(item["start"].freq) + target = np.asarray(item[FieldName.TARGET], np.float32) len_ts = len(target) forecast_start_time = forecast_start(item) @@ -78,9 +90,9 @@ def predict_item(self, item: DataEntry) -> Forecast: target = target.copy() target = self.imputation_method(target) - if len_ts >= self.season_length: + if len_ts >= season_length: indices = [ - len_ts - self.season_length + k % self.season_length + len_ts - season_length + k % season_length for k in range(self.prediction_length) ] samples = target[indices].reshape((1, self.prediction_length))