Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow function as season_length argument in SeasonalNaivePredictor #3033

Merged
merged 2 commits into from
Oct 31, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions src/gluonts/model/seasonal_naive/_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Callable, Union

import numpy as np

from gluonts.core.component import validated
Expand Down Expand Up @@ -43,7 +45,10 @@ class SeasonalNaivePredictor(RepresentablePredictor):
prediction_length
Number of time points to predict.
season_length
Seasonality used to make predictions.
Seasonality used to make predictions. If this is an integer, then a fixed
sesasonlity is applied; if this is a function, then it will be called on each
given entry's ``freq`` attribute of the ``"start"`` field, and will return
the seasonality to use.
imputation_method
The imputation method to use in case of missing values.
Defaults to :py:class:`LastValueImputation` which replaces each missing
Expand All @@ -54,18 +59,25 @@ class SeasonalNaivePredictor(RepresentablePredictor):
def __init__(
self,
prediction_length: int,
season_length: int,
season_length: Union[int, Callable],
imputation_method: MissingValueImputation = LastValueImputation(),
) -> None:
super().__init__(prediction_length=prediction_length)

assert season_length > 0, "The value of `season_length` should be > 0"
assert (
not isinstance(season_length, int) or season_length > 0
), "The value of `season_length` should be > 0"

self.prediction_length = prediction_length
self.season_length = season_length
self.imputation_method = imputation_method

def predict_item(self, item: DataEntry) -> Forecast:
if isinstance(self.season_length, int):
season_length = self.season_length
else:
season_length = self.season_length(item["start"].freq)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This requires that there is a freq for the time series, which might not be the case in the future?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this mechanism is really only useful if you have frequency information attached to the series. If that's not the case, there's no point in using this option, and the model should be configured with an int value.

Also: I believe the assumption that item["start"] is a pd.Period (hence has frequency information) is already baked into this method since it uses forecast_start?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the moment we require that all time series have frequency information attached -- but that's an arbitrary requirement.


target = np.asarray(item[FieldName.TARGET], np.float32)
len_ts = len(target)
forecast_start_time = forecast_start(item)
Expand All @@ -78,9 +90,9 @@ def predict_item(self, item: DataEntry) -> Forecast:
target = target.copy()
target = self.imputation_method(target)

if len_ts >= self.season_length:
if len_ts >= season_length:
indices = [
len_ts - self.season_length + k % self.season_length
len_ts - season_length + k % season_length
for k in range(self.prediction_length)
]
samples = target[indices].reshape((1, self.prediction_length))
Expand Down
Loading