
ev: Batch data for faster evaluation. #3051

Merged 4 commits into awslabs:dev from lostella:faster-eval on Nov 13, 2023
Conversation

@lostella (Contributor) commented on Nov 10, 2023

Description of changes: Add batching to gluonts.model.evaluation, reaching speedups of ~6x in some cases (the speedup also depends on how many metrics are being evaluated: more metrics, more speedup). The following example uses the m4_daily dataset:

4227it [00:46, 91.83it/s]
evaluation (batch_size=1): 46.04444639701978
4227it [00:27, 152.03it/s]
evaluation (batch_size=2): 27.80498409701977
4227it [00:15, 277.03it/s]
evaluation (batch_size=5): 15.259674712986453
4227it [00:11, 371.23it/s]
evaluation (batch_size=10): 11.387762713013217
4227it [00:09, 451.05it/s]
evaluation (batch_size=20): 9.372666671988554
4227it [00:08, 512.34it/s]
evaluation (batch_size=50): 8.252159972995287
4227it [00:07, 546.00it/s]
evaluation (batch_size=100): 7.743695029988885
4227it [00:07, 561.83it/s]
evaluation (batch_size=200): 7.524872001988115
4227it [00:07, 576.23it/s]
evaluation (batch_size=500): 7.336895904998528

Code:

import timeit
import numpy as np
import pandas as pd

from gluonts.ev.metrics import (
    SMAPE,
    MASE,
    NRMSE,
    ND,
    MeanWeightedSumQuantileLoss,
    AverageMeanScaledQuantileLoss,
    MAECoverage,
)
from gluonts.dataset.repository import get_dataset
from gluonts.dataset.split import split
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model import SampleForecast


quantile_levels = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

METRICS = [
    SMAPE(),
    MASE(),
    NRMSE(),
    ND(),
    MeanWeightedSumQuantileLoss(quantile_levels=quantile_levels),
    AverageMeanScaledQuantileLoss(quantile_levels=quantile_levels),
    MAECoverage(quantile_levels=quantile_levels),
]


if __name__ == "__main__":
    # Hold out the last 14 observations of each series as the test window.
    data = list(get_dataset("m4_daily").test)
    test_data = split(data, offset=-14)[1].generate_instances(
        prediction_length=14
    )
    inputs = list(test_data.input)
    labels = list(test_data.label)

    # Dummy forecasts: 100 random sample paths per 14-step label window.
    forecasts = [
        SampleForecast(
            start_date=entry["start"],
            samples=np.random.normal(size=(100, 14)),
        )
        for entry in labels
    ]

    batch_sizes = [1, 2, 5, 10, 20, 50, 100, 200, 500]

    # Use the batch_size=1 run as the reference and check that larger
    # batch sizes produce identical results.
    ref = None

    for batch_size in batch_sizes:
        t0 = timeit.default_timer()
        res = evaluate_forecasts(
            forecasts,
            test_data=test_data,
            metrics=METRICS,
            batch_size=batch_size,
            seasonality=7,
        )
        t1 = timeit.default_timer()
        if ref is None:
            ref = res
        else:
            pd.testing.assert_frame_equal(ref, res)
        print(f"evaluation ({batch_size=}): {t1 - t0}")

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this PR with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup

@lostella added the enhancement (New feature or request) and BREAKING (This is a breaking change; one of the PR's required labels) labels on Nov 10, 2023
  mask_invalid_label: bool = True,
  allow_nan_forecast: bool = False,
- seasonality: Optional[int] = None,
+ seasonality: int = 1,
Reviewer (Contributor):

Why is the default being changed to 1?

lostella (Author):

Good point: I had initially done this because I was also computing the seasonal error in batches, and that didn't play well with potentially different seasonalities within a batch; it turns out batching there is problematic anyway because of differing series lengths, so I should indeed be able to maintain the current behavior.
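
To make the length issue concrete, here is a rough per-entry sketch of the seasonal error behind MASE (a stand-in for illustration, not this PR's code): the past series in a batch can have different lengths, so they cannot simply be stacked into one array.

import numpy as np

# Stand-in sketch of a per-entry seasonal error (the MASE denominator).
# Series in a batch may differ in length and seasonality, which is why
# this stays a per-entry computation rather than a batched one.
def seasonal_error(past: np.ndarray, seasonality: int) -> float:
    # Fall back to a one-step difference when the series is shorter
    # than one full season.
    lag = seasonality if seasonality < len(past) else 1
    return float(np.mean(np.abs(past[lag:] - past[:-lag])))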

lostella (Author):

Done

label_batches = batcher(test_data.label, batch_size=batch_size)
forecast_batches = batcher(forecasts, batch_size=batch_size)

pbar = tqdm()
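
For context, batcher simply groups an iterable into fixed-size lists (gluonts provides such a helper in gluonts.itertools); a minimal stand-in with the same semantics:

from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

# Minimal stand-in for the batching helper: yields consecutive lists of
# up to batch_size items; the final batch may be shorter.
def batcher(iterable: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    it = iter(iterable)
    while batch := list(islice(it, batch_size)):
        yield batch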
Reviewer (Contributor):

Can we also set the total length here?

lostella (Author):

We could, but depending on what dataset type one is using, this may end up iterating the dataset just to get its length. I don't think the current code does that, so maybe let's do it in a separate change.

Reviewer (Contributor):

I guess you could try to call len on it as long as it doesn't consume it?

lostella (Author):

Yes, that's the idea. However, this may still take time (say the dataset is huge and len needs to iterate it), which I would like to avoid, at least in this PR.
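
For reference, the reviewer's suggestion would look roughly like the sketch below (a possible follow-up, not part of this PR; test_data is assumed to be in scope):

from tqdm import tqdm

# Try to set the progress-bar total; len() may be unavailable (TypeError)
# or, for some dataset types, require iterating the whole dataset, which
# is the concern raised above.
try:
    total = len(test_data.label)
except TypeError:
    total = None

pbar = tqdm(total=total)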

@jaheba changed the title from "Speed up evaluation 6x" to "Speed up evaluation by batching data." on Nov 13, 2023
@jaheba changed the title from "Speed up evaluation by batching data." to "ev: Batch data for faster evaluation." on Nov 13, 2023
@lostella merged commit 56eec11 into awslabs:dev on Nov 13, 2023 (19 checks passed)
@lostella deleted the faster-eval branch on November 13, 2023 at 10:13