Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ev: Batch data for faster evaluation. #3051

Merged
merged 4 commits into from
Nov 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 59 additions & 34 deletions src/gluonts/model/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import logging
from collections import ChainMap
from typing import Iterable, Optional, Union
from typing import Iterable, List, Optional, Union
from dataclasses import dataclass
from toolz import first, valmap

Expand All @@ -23,10 +23,10 @@

from gluonts.dataset import DataEntry
from gluonts.dataset.split import TestData
from gluonts.time_feature.seasonality import get_seasonality
from gluonts.model import Forecast, Predictor
from gluonts.ev.ts_stats import seasonal_error
from gluonts.itertools import prod
from gluonts.itertools import batcher, prod
from gluonts.model import Forecast, Predictor
from gluonts.time_feature.seasonality import get_seasonality

logger = logging.getLogger(__name__)

Expand All @@ -39,53 +39,60 @@ class BatchForecast:
``gluonts.ev``.
"""

forecast: Forecast
forecasts: List[Forecast]
allow_nan: bool = False

def __getitem__(self, name):
value = self.forecast[name]
if np.isnan(value).any():
values = [forecast[name].T for forecast in self.forecasts]
res = np.stack(values, axis=0)

if np.isnan(res).any():
if not self.allow_nan:
raise ValueError("Forecast contains NaN values")

logger.warning(
"Forecast contains NaN values. Metrics may be incorrect."
)

return np.expand_dims(value.T, axis=0)
return res


def _get_data_batch(
input_: DataEntry,
label: DataEntry,
forecast: Forecast,
input_batch: List[DataEntry],
label_batch: List[DataEntry],
forecast_batch: List[Forecast],
seasonality: Optional[int] = None,
mask_invalid_label: bool = True,
allow_nan_forecast: bool = False,
) -> ChainMap:
forecast_dict = BatchForecast(forecast, allow_nan=allow_nan_forecast)

freq = forecast.start_date.freqstr
if seasonality is None:
seasonality = get_seasonality(freq=freq)

label_target = label["target"]
input_target = input_["target"]
label_target = np.stack([label["target"] for label in label_batch], axis=0)
if mask_invalid_label:
label_target = np.ma.masked_invalid(label_target)
input_target = np.ma.masked_invalid(input_target)

other_data = {
"label": np.expand_dims(label_target, axis=0),
"seasonal_error": np.expand_dims(
seasonal_error(
input_target, seasonality=seasonality, time_axis=-1
),
axis=0,
),
"label": label_target,
}

return ChainMap(other_data, forecast_dict) # type: ignore
seasonal_error_values = []
for input_ in input_batch:
seasonality_entry = seasonality
if seasonality_entry is None:
seasonality_entry = get_seasonality(input_["start"].freqstr)
input_target = input_["target"]
if mask_invalid_label:
input_target = np.ma.masked_invalid(input_target)
seasonal_error_values.append(
seasonal_error(
input_target,
seasonality=seasonality_entry,
time_axis=-1,
)
)
other_data["seasonal_error"] = np.array(seasonal_error_values)

return ChainMap(
other_data, BatchForecast(forecast_batch, allow_nan=allow_nan_forecast) # type: ignore
)


def evaluate_forecasts_raw(
Expand All @@ -94,6 +101,7 @@ def evaluate_forecasts_raw(
test_data: TestData,
metrics,
axis: Optional[Union[int, tuple]] = None,
batch_size: int = 100,
mask_invalid_label: bool = True,
allow_nan_forecast: bool = False,
seasonality: Optional[int] = None
Expand Down Expand Up @@ -130,16 +138,26 @@ def evaluate_forecasts_raw(

index_data = []

for input_, label, forecast in tqdm(
zip(test_data.input, test_data.label, forecasts)
input_batches = batcher(test_data.input, batch_size=batch_size)
label_batches = batcher(test_data.label, batch_size=batch_size)
forecast_batches = batcher(forecasts, batch_size=batch_size)

pbar = tqdm()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also set the total length here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could, but depending on what dataset type one is using, this may end up iterating the dataset to get its length. I don’t think the current code does it, so maybe let’s do it in a separate change.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you could try to call len on it as long as it doesn't consume it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's the idea. However this may still take time (say the dataset is huge and len needs to iterate it), which I would like to avoid at least in this PR

for input_batch, label_batch, forecast_batch in zip(
input_batches, label_batches, forecast_batches
):
if 0 not in axis:
index_data.append((forecast.item_id, forecast.start_date))
index_data.extend(
[
(forecast.item_id, forecast.start_date)
for forecast in forecast_batch
]
)

data_batch = _get_data_batch(
input_,
label,
forecast,
input_batch,
label_batch,
forecast_batch,
seasonality=seasonality,
mask_invalid_label=mask_invalid_label,
allow_nan_forecast=allow_nan_forecast,
Expand All @@ -148,6 +166,9 @@ def evaluate_forecasts_raw(
for evaluator in evaluators.values():
evaluator.update(data_batch)

pbar.update(len(forecast_batch))
pbar.close()

metrics_values = {
metric_name: evaluator.get()
for metric_name, evaluator in evaluators.items()
Expand All @@ -165,6 +186,7 @@ def evaluate_forecasts(
test_data: TestData,
metrics,
axis: Optional[Union[int, tuple]] = None,
batch_size: int = 100,
mask_invalid_label: bool = True,
allow_nan_forecast: bool = False,
seasonality: Optional[int] = None
Expand All @@ -188,6 +210,7 @@ def evaluate_forecasts(
test_data=test_data,
metrics=metrics,
axis=axis,
batch_size=batch_size,
mask_invalid_label=mask_invalid_label,
allow_nan_forecast=allow_nan_forecast,
seasonality=seasonality,
Expand Down Expand Up @@ -217,6 +240,7 @@ def evaluate_model(
test_data: TestData,
metrics,
axis: Optional[Union[int, tuple]] = None,
batch_size: int = 100,
mask_invalid_label: bool = True,
allow_nan_forecast: bool = False,
seasonality: Optional[int] = None
Expand All @@ -242,6 +266,7 @@ def evaluate_model(
test_data=test_data,
metrics=metrics,
axis=axis,
batch_size=batch_size,
mask_invalid_label=mask_invalid_label,
allow_nan_forecast=allow_nan_forecast,
seasonality=seasonality,
Expand Down
Loading