diff --git a/src/gluonts/model/evaluation.py b/src/gluonts/model/evaluation.py index ae551a19eb..15638dea84 100644 --- a/src/gluonts/model/evaluation.py +++ b/src/gluonts/model/evaluation.py @@ -13,7 +13,7 @@ import logging from collections import ChainMap -from typing import Iterable, Optional, Union +from typing import Iterable, List, Optional, Union from dataclasses import dataclass from toolz import first, valmap @@ -23,10 +23,10 @@ from gluonts.dataset import DataEntry from gluonts.dataset.split import TestData -from gluonts.time_feature.seasonality import get_seasonality -from gluonts.model import Forecast, Predictor from gluonts.ev.ts_stats import seasonal_error -from gluonts.itertools import prod +from gluonts.itertools import batcher, prod +from gluonts.model import Forecast, Predictor +from gluonts.time_feature.seasonality import get_seasonality logger = logging.getLogger(__name__) @@ -39,12 +39,14 @@ class BatchForecast: ``gluonts.ev``. """ - forecast: Forecast + forecasts: List[Forecast] allow_nan: bool = False def __getitem__(self, name): - value = self.forecast[name] - if np.isnan(value).any(): + values = [forecast[name].T for forecast in self.forecasts] + res = np.stack(values, axis=0) + + if np.isnan(res).any(): if not self.allow_nan: raise ValueError("Forecast contains NaN values") @@ -52,40 +54,45 @@ def __getitem__(self, name): "Forecast contains NaN values. Metrics may be incorrect." ) - return np.expand_dims(value.T, axis=0) + return res def _get_data_batch( - input_: DataEntry, - label: DataEntry, - forecast: Forecast, + input_batch: List[DataEntry], + label_batch: List[DataEntry], + forecast_batch: List[Forecast], seasonality: Optional[int] = None, mask_invalid_label: bool = True, allow_nan_forecast: bool = False, ) -> ChainMap: - forecast_dict = BatchForecast(forecast, allow_nan=allow_nan_forecast) - - freq = forecast.start_date.freqstr - if seasonality is None: - seasonality = get_seasonality(freq=freq) - - label_target = label["target"] - input_target = input_["target"] + label_target = np.stack([label["target"] for label in label_batch], axis=0) if mask_invalid_label: label_target = np.ma.masked_invalid(label_target) - input_target = np.ma.masked_invalid(input_target) other_data = { - "label": np.expand_dims(label_target, axis=0), - "seasonal_error": np.expand_dims( - seasonal_error( - input_target, seasonality=seasonality, time_axis=-1 - ), - axis=0, - ), + "label": label_target, } - return ChainMap(other_data, forecast_dict) # type: ignore + seasonal_error_values = [] + for input_ in input_batch: + seasonality_entry = seasonality + if seasonality_entry is None: + seasonality_entry = get_seasonality(input_["start"].freqstr) + input_target = input_["target"] + if mask_invalid_label: + input_target = np.ma.masked_invalid(input_target) + seasonal_error_values.append( + seasonal_error( + input_target, + seasonality=seasonality_entry, + time_axis=-1, + ) + ) + other_data["seasonal_error"] = np.array(seasonal_error_values) + + return ChainMap( + other_data, BatchForecast(forecast_batch, allow_nan=allow_nan_forecast) # type: ignore + ) def evaluate_forecasts_raw( @@ -94,6 +101,7 @@ def evaluate_forecasts_raw( test_data: TestData, metrics, axis: Optional[Union[int, tuple]] = None, + batch_size: int = 100, mask_invalid_label: bool = True, allow_nan_forecast: bool = False, seasonality: Optional[int] = None @@ -130,16 +138,26 @@ def evaluate_forecasts_raw( index_data = [] - for input_, label, forecast in tqdm( - zip(test_data.input, test_data.label, forecasts) + input_batches = batcher(test_data.input, batch_size=batch_size) + label_batches = batcher(test_data.label, batch_size=batch_size) + forecast_batches = batcher(forecasts, batch_size=batch_size) + + pbar = tqdm() + for input_batch, label_batch, forecast_batch in zip( + input_batches, label_batches, forecast_batches ): if 0 not in axis: - index_data.append((forecast.item_id, forecast.start_date)) + index_data.extend( + [ + (forecast.item_id, forecast.start_date) + for forecast in forecast_batch + ] + ) data_batch = _get_data_batch( - input_, - label, - forecast, + input_batch, + label_batch, + forecast_batch, seasonality=seasonality, mask_invalid_label=mask_invalid_label, allow_nan_forecast=allow_nan_forecast, @@ -148,6 +166,9 @@ def evaluate_forecasts_raw( for evaluator in evaluators.values(): evaluator.update(data_batch) + pbar.update(len(forecast_batch)) + pbar.close() + metrics_values = { metric_name: evaluator.get() for metric_name, evaluator in evaluators.items() @@ -165,6 +186,7 @@ def evaluate_forecasts( test_data: TestData, metrics, axis: Optional[Union[int, tuple]] = None, + batch_size: int = 100, mask_invalid_label: bool = True, allow_nan_forecast: bool = False, seasonality: Optional[int] = None @@ -188,6 +210,7 @@ def evaluate_forecasts( test_data=test_data, metrics=metrics, axis=axis, + batch_size=batch_size, mask_invalid_label=mask_invalid_label, allow_nan_forecast=allow_nan_forecast, seasonality=seasonality, @@ -217,6 +240,7 @@ def evaluate_model( test_data: TestData, metrics, axis: Optional[Union[int, tuple]] = None, + batch_size: int = 100, mask_invalid_label: bool = True, allow_nan_forecast: bool = False, seasonality: Optional[int] = None @@ -242,6 +266,7 @@ def evaluate_model( test_data=test_data, metrics=metrics, axis=axis, + batch_size=batch_size, mask_invalid_label=mask_invalid_label, allow_nan_forecast=allow_nan_forecast, seasonality=seasonality,