Skip to content

Commit

Permalink
revert default seasonality
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella committed Nov 13, 2023
1 parent a6c1fb7 commit b64abb2
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions src/gluonts/model/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@

from gluonts.dataset import DataEntry
from gluonts.dataset.split import TestData
from gluonts.model import Forecast, Predictor
from gluonts.ev.ts_stats import seasonal_error
from gluonts.itertools import batcher, prod
from gluonts.model import Forecast, Predictor
from gluonts.time_feature.seasonality import get_seasonality

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,7 +61,7 @@ def _get_data_batch(
input_batch: List[DataEntry],
label_batch: List[DataEntry],
forecast_batch: List[Forecast],
seasonality: int = 1,
seasonality: Optional[int] = None,
mask_invalid_label: bool = True,
allow_nan_forecast: bool = False,
) -> ChainMap:
Expand All @@ -74,11 +75,18 @@ def _get_data_batch(

seasonal_error_values = []
for input_ in input_batch:
seasonality_entry = seasonality
if seasonality_entry is None:
seasonality_entry = get_seasonality(input_["start"].freqstr)
input_target = input_["target"]
if mask_invalid_label:
input_target = np.ma.masked_invalid(input_target)
seasonal_error_values.append(
seasonal_error(input_target, seasonality=seasonality, time_axis=-1)
seasonal_error(
input_target,
seasonality=seasonality_entry,
time_axis=-1,
)
)
other_data["seasonal_error"] = np.array(seasonal_error_values)

Expand All @@ -96,7 +104,7 @@ def evaluate_forecasts_raw(
batch_size: int = 100,
mask_invalid_label: bool = True,
allow_nan_forecast: bool = False,
seasonality: int = 1
seasonality: Optional[int] = None,
) -> dict:
"""
Evaluate ``forecasts`` by comparing them with ``test_data``, according
Expand Down Expand Up @@ -181,7 +189,7 @@ def evaluate_forecasts(
batch_size: int = 100,
mask_invalid_label: bool = True,
allow_nan_forecast: bool = False,
seasonality: int = 1
seasonality: Optional[int] = None,
) -> pd.DataFrame:
"""
Evaluate ``forecasts`` by comparing them with ``test_data``, according
Expand Down Expand Up @@ -235,7 +243,7 @@ def evaluate_model(
batch_size: int = 100,
mask_invalid_label: bool = True,
allow_nan_forecast: bool = False,
seasonality: int = 1
seasonality: Optional[int] = None,
) -> pd.DataFrame:
"""
Evaluate ``model`` when applied to ``test_data``, according
Expand Down

0 comments on commit b64abb2

Please sign in to comment.