Skip to content

Commit

Permalink
Fix saving/loading ProphetModel (#1019)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Geekman authored Nov 30, 2022
1 parent 3a40f75 commit 6ca6f82
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
- Remove documentation warning about using pickle in saving/loading catboost ([#1020](https://github.com/tinkoff-ai/etna/pull/1020))
- Fix saving/loading ProphetModel ([#1019](https://github.com/tinkoff-ai/etna/pull/1019))
-
-
## [1.13.0] - 2022-10-10
### Added
Expand Down
50 changes: 40 additions & 10 deletions etna/models/prophet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from datetime import datetime
from typing import Dict
from typing import Iterable
Expand All @@ -16,6 +17,8 @@

if SETTINGS.prophet_required:
from prophet import Prophet
from prophet.serialize import model_from_dict
from prophet.serialize import model_to_dict


class _ProphetAdapter(BaseAdapter):
Expand Down Expand Up @@ -62,11 +65,16 @@ def __init__(
self.stan_backend = stan_backend
self.additional_seasonality_params = additional_seasonality_params

self.model = Prophet(
self.model = self._create_model()

self.regressor_columns: Optional[List[str]] = None

def _create_model(self) -> "Prophet":
model = Prophet(
growth=self.growth,
changepoints=changepoints,
n_changepoints=n_changepoints,
changepoint_range=changepoint_range,
changepoints=self.changepoints,
n_changepoints=self.n_changepoints,
changepoint_range=self.changepoint_range,
yearly_seasonality=self.yearly_seasonality,
weekly_seasonality=self.weekly_seasonality,
daily_seasonality=self.daily_seasonality,
Expand All @@ -84,7 +92,7 @@ def __init__(
for seasonality_params in self.additional_seasonality_params:
self.model.add_seasonality(**seasonality_params)

self.regressor_columns: Optional[List[str]] = None
return model

def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_ProphetAdapter":
"""
Expand Down Expand Up @@ -154,6 +162,33 @@ def get_model(self) -> Prophet:
"""
return self.model

def __getstate__(self):
state = self.__dict__.copy()
try:
model_dict = model_to_dict(self.model)
is_fitted = True
except ValueError:
is_fitted = False
model_dict = {}
del state["model"]
state["_is_fitted"] = is_fitted
state["_model_dict"] = model_dict
return state

def __setstate__(self, state):
local_state = deepcopy(state)
is_fitted = local_state["_is_fitted"]
model_dict = local_state["_model_dict"]
del local_state["_is_fitted"]
del local_state["_model_dict"]

self.__dict__.update(local_state)

if is_fitted:
self.model = model_from_dict(model_dict)
else:
self.model = self._create_model()


class ProphetModel(
PerSegmentModelMixin, PredictionIntervalContextIgnorantModelMixin, PredictionIntervalContextIgnorantAbstractModel
Expand All @@ -165,11 +200,6 @@ class ProphetModel(
Original Prophet can use features 'cap' and 'floor',
they should be added to the known_future list on dataset initialization.
Warning
-------
Currently, pickle is used in ``save`` and ``load`` methods.
It can work unreliably according to `documentation <https://facebook.github.io/prophet/docs/additional_topics.html>`_.
Examples
--------
>>> from etna.datasets import generate_periodic_df
Expand Down
8 changes: 1 addition & 7 deletions etna/transforms/outliers/point_outliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,7 @@ def detect_outliers(self, ts: TSDataset) -> Dict[str, List[pd.Timestamp]]:


class PredictionIntervalOutliersTransform(OutliersTransform):
"""Transform that uses :py:func:`~etna.analysis.outliers.prediction_interval_outliers.get_anomalies_prediction_interval` to find anomalies in data.
Warning
-------
Currently, pickle is used in ``save`` and ``load`` methods.
It can work unreliably according to `documentation <https://facebook.github.io/prophet/docs/additional_topics.html>`_.
"""
"""Transform that uses :py:func:`~etna.analysis.outliers.prediction_interval_outliers.get_anomalies_prediction_interval` to find anomalies in data."""

def __init__(
self,
Expand Down
74 changes: 74 additions & 0 deletions tests/test_models/test_prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import pandas as pd
import pytest
from prophet import Prophet
from prophet.serialize import model_to_dict

from etna.datasets.tsdataset import TSDataset
from etna.models import ProphetModel
from etna.models.prophet import _ProphetAdapter
from etna.pipeline import Pipeline
from tests.test_models.utils import assert_model_equals_loaded_original

Expand Down Expand Up @@ -123,6 +125,78 @@ def test_get_model_after_training(example_tsds):
assert isinstance(models_dict[segment], Prophet)


@pytest.fixture
def prophet_default_params():
params = {
"growth": "linear",
"changepoints": None,
"n_changepoints": 25,
"changepoint_range": 0.8,
"yearly_seasonality": "auto",
"weekly_seasonality": "auto",
"daily_seasonality": "auto",
"holidays": None,
"seasonality_mode": "additive",
"seasonality_prior_scale": 10.0,
"holidays_prior_scale": 10.0,
"changepoint_prior_scale": 0.05,
"mcmc_samples": 0,
"interval_width": 0.8,
"uncertainty_samples": 1000,
"stan_backend": None,
"additional_seasonality_params": (),
}
return params


def test_getstate_not_fitted(prophet_default_params):
model = _ProphetAdapter()
state = model.__getstate__()
expected_state = {
"_is_fitted": False,
"_model_dict": {},
"regressor_columns": None,
**prophet_default_params,
}
assert state == expected_state


def test_getstate_fitted(example_tsds, prophet_default_params):
model = _ProphetAdapter()
df = example_tsds.to_pandas()["segment_1"].reset_index()
model.fit(df, regressors=[])
state = model.__getstate__()
expected_state = {
"_is_fitted": True,
"_model_dict": model_to_dict(model.model),
"regressor_columns": [],
**prophet_default_params,
}
assert state == expected_state


def test_setstate_not_fitted():
model_1 = _ProphetAdapter(n_changepoints=25)
initial_state = model_1.__getstate__()

model_2 = _ProphetAdapter(n_changepoints=20)
model_2.__setstate__(initial_state)
new_state = model_2.__getstate__()
assert new_state == initial_state


def test_setstate_fitted(example_tsds):
model_1 = _ProphetAdapter()
df = example_tsds.to_pandas()["segment_1"].reset_index()
model_1.fit(df, regressors=[])
initial_state = model_1.__getstate__()

model_2 = _ProphetAdapter()
model_2.__setstate__(initial_state)
new_state = model_2.__getstate__()
assert new_state == initial_state


@pytest.mark.xfail(reason="Non native serialization, should be fixed in inference-v2.0")
def test_save_load(example_tsds):
model = ProphetModel()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def test_save_load(transform, outliers_solid_tsds):
assert_transformation_equals_loaded_original(transform=transform, ts=outliers_solid_tsds)


@pytest.mark.xfail(reason="Non native serialization, should be fixed in inference-v2.0")
@pytest.mark.parametrize(
"transform",
(
Expand Down

0 comments on commit 6ca6f82

Please sign in to comment.