Skip to content

Implement tests on inference scenarios for models #1082

Merged
merged 10 commits into from
Jan 31, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
-
-
-
- Add more scenarios into tests for models ([#1082](https://github.com/tinkoff-ai/etna/pull/1082))
-
-
-
Expand Down
2 changes: 2 additions & 0 deletions etna/models/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ def _make_predictions(self, ts: TSDataset, prediction_method: Callable, **kwargs
df = ts.to_pandas()
models = self._get_model()
for segment in ts.segments:
if segment not in models:
raise NotImplementedError("Per-segment models can't make predictions on new segments!")
segment_model = models[segment]
segment_predict = self._make_predictions_segment(
model=segment_model, segment=segment, df=df, prediction_method=prediction_method, **kwargs
Expand Down
18 changes: 17 additions & 1 deletion tests/test_models/test_inference/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import functools
from typing import List

import numpy as np
import pandas as pd
import pytest
from typing_extensions import get_args

Expand Down Expand Up @@ -29,6 +31,16 @@ def make_prediction(model, ts, prediction_size, method_name) -> TSDataset:
return ts


def select_segments_subset(ts: TSDataset, segments: List[str]) -> TSDataset:
df = ts.raw_df.loc[:, pd.IndexSlice[segments, :]]
df_exog = ts.df_exog
if df_exog is not None:
df_exog = df_exog.loc[:, pd.IndexSlice[segments, :]]
known_future = ts.known_future
freq = ts.freq
return TSDataset(df=df, df_exog=df_exog, known_future=known_future, freq=freq)


def _test_prediction_in_sample_full(ts, model, transforms, method_name):
df = ts.to_pandas()

Expand All @@ -45,6 +57,8 @@ def _test_prediction_in_sample_full(ts, model, transforms, method_name):
# checking
forecast_df = forecast_ts.to_pandas(flatten=True)
assert not np.any(forecast_df["target"].isna())
original_target = TSDataset.to_flatten(df)["target"]
assert not forecast_df["target"].equals(original_target)


def _test_prediction_in_sample_suffix(ts, model, transforms, method_name, num_skip_points):
Expand All @@ -57,10 +71,12 @@ def _test_prediction_in_sample_suffix(ts, model, transforms, method_name, num_sk
# forecasting
forecast_ts = TSDataset(df, freq="D")
forecast_ts.transform(ts.transforms)
forecast_ts.df = forecast_ts.df.iloc[(num_skip_points - model.context_size) :]
prediction_size = len(forecast_ts.index) - num_skip_points
forecast_ts.df = forecast_ts.df.iloc[(num_skip_points - model.context_size) :]
forecast_ts = make_prediction(model=model, ts=forecast_ts, prediction_size=prediction_size, method_name=method_name)

# checking
forecast_df = forecast_ts.to_pandas(flatten=True)
assert not np.any(forecast_df["target"].isna())
original_target = TSDataset.to_flatten(df.iloc[(num_skip_points - model.context_size) :])["target"]
assert not forecast_df["target"].equals(original_target)
Loading