Skip to content

Make in-sample predictions of SARIMAXModel non-dynamic in all cases #812

Merged
merged 3 commits into from
Jul 26, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Make in-sample predictions of SARIMAXModel non-dynamic in all cases ([#812](https://github.com/tinkoff-ai/etna/pull/812))
- Teach BATS/TBATS to work with in-sample, out-sample predictions correctly ([#806](https://github.com/tinkoff-ai/etna/pull/806))
-
- Github actions cache issue with poetry update ([#778](https://github.com/tinkoff-ai/etna/pull/778))
Expand Down
2 changes: 1 addition & 1 deletion etna/models/sarimax.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequen
y_pred[f"mean_{quantile:.4g}"] = series
else:
forecast = self._result.get_prediction(
start=df["timestamp"].min(), end=df["timestamp"].max(), dynamic=True, exog=exog_future
start=df["timestamp"].min(), end=df["timestamp"].max(), dynamic=False, exog=exog_future
)
y_pred = forecast.predicted_mean
y_pred.name = "mean"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def test_forecast_out_sample_suffix_failed(model, transforms, example_tsds):
(ElasticPerSegmentModel(), [LagTransform(in_column="target", lags=[5, 6])]),
(ElasticMultiSegmentModel(), [LagTransform(in_column="target", lags=[5, 6])]),
(ProphetModel(), []),
(SARIMAXModel(), []),
(HoltModel(), []),
(HoltWintersModel(), []),
(SimpleExpSmoothingModel(), []),
Expand All @@ -431,7 +432,6 @@ def test_forecast_mixed_in_out_sample(model, transforms, example_tsds):
@pytest.mark.parametrize(
"model, transforms",
[
(SARIMAXModel(), []),
(AutoARIMAModel(), []),
(
DeepARModel(max_epochs=5, learning_rate=[0.01]),
Expand Down