Skip to content

Fix saving/loading ProphetModel #1019

Merged
merged 4 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
- Remove documentation warning about using pickle in saving/loading catboost ([#1020](https://github.com/tinkoff-ai/etna/pull/1020))
- Fix saving/loading ProphetModel ([#1019](https://github.com/tinkoff-ai/etna/pull/1019))
-
-
## [1.13.0] - 2022-10-10
### Added
Expand Down
50 changes: 40 additions & 10 deletions etna/models/prophet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from datetime import datetime
from typing import Dict
from typing import Iterable
Expand All @@ -16,6 +17,8 @@

if SETTINGS.prophet_required:
from prophet import Prophet
from prophet.serialize import model_from_dict
from prophet.serialize import model_to_dict


class _ProphetAdapter(BaseAdapter):
Expand Down Expand Up @@ -62,11 +65,16 @@ def __init__(
self.stan_backend = stan_backend
self.additional_seasonality_params = additional_seasonality_params

self.model = Prophet(
self.model = self._create_model()

self.regressor_columns: Optional[List[str]] = None

def _create_model(self) -> "Prophet":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a description (docstring) to the method.

model = Prophet(
growth=self.growth,
changepoints=changepoints,
n_changepoints=n_changepoints,
changepoint_range=changepoint_range,
changepoints=self.changepoints,
n_changepoints=self.n_changepoints,
changepoint_range=self.changepoint_range,
yearly_seasonality=self.yearly_seasonality,
weekly_seasonality=self.weekly_seasonality,
daily_seasonality=self.daily_seasonality,
Expand All @@ -84,7 +92,7 @@ def __init__(
for seasonality_params in self.additional_seasonality_params:
self.model.add_seasonality(**seasonality_params)

self.regressor_columns: Optional[List[str]] = None
return model

def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_ProphetAdapter":
"""
Expand Down Expand Up @@ -154,6 +162,33 @@ def get_model(self) -> Prophet:
"""
return self.model

def __getstate__(self):
    """Build the state used to pickle the adapter.

    The ``model`` attribute is not stored in the state. Instead the state
    carries ``_model_dict`` (the result of ``model_to_dict``) and the
    ``_is_fitted`` flag. If ``model_to_dict`` raises ``ValueError`` the
    model is treated as not fitted: the flag is ``False`` and
    ``_model_dict`` is an empty dict.

    Returns
    -------
    :
        Shallow copy of the instance ``__dict__`` without ``model`` and with
        the ``_is_fitted`` and ``_model_dict`` entries added.
    """
    try:
        serialized_model = model_to_dict(self.model)
    except ValueError:
        # Serialization was refused, so only the hyperparameters are kept.
        fitted = False
        serialized_model = {}
    else:
        fitted = True

    state = {name: value for name, value in self.__dict__.items() if name != "model"}
    state["_is_fitted"] = fitted
    state["_model_dict"] = serialized_model
    return state

def __setstate__(self, state):
local_state = deepcopy(state)
is_fitted = local_state["_is_fitted"]
model_dict = local_state["_model_dict"]
del local_state["_is_fitted"]
del local_state["_model_dict"]

self.__dict__.update(local_state)

if is_fitted:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could move the model creation logic into the `_create_model` method, i.e. add an optional `state` parameter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really think that is a good idea. Right now `_create_model` doesn't know anything about serialization: it just handles creating the model from the hyperparameters and processing the additional seasonalities (I added this method mostly because of them).

self.model = model_from_dict(model_dict)
else:
self.model = self._create_model()


class ProphetModel(
PerSegmentModelMixin, PredictionIntervalContextIgnorantModelMixin, PredictionIntervalContextIgnorantAbstractModel
Expand All @@ -165,11 +200,6 @@ class ProphetModel(
Original Prophet can use features 'cap' and 'floor',
they should be added to the known_future list on dataset initialization.

Warning
-------
Currently, pickle is used in ``save`` and ``load`` methods.
It can work unreliably according to `documentation <https://facebook.github.io/prophet/docs/additional_topics.html>`_.

Examples
--------
>>> from etna.datasets import generate_periodic_df
Expand Down
8 changes: 1 addition & 7 deletions etna/transforms/outliers/point_outliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,7 @@ def detect_outliers(self, ts: TSDataset) -> Dict[str, List[pd.Timestamp]]:


class PredictionIntervalOutliersTransform(OutliersTransform):
"""Transform that uses :py:func:`~etna.analysis.outliers.prediction_interval_outliers.get_anomalies_prediction_interval` to find anomalies in data.

Warning
-------
Currently, pickle is used in ``save`` and ``load`` methods.
It can work unreliably according to `documentation <https://facebook.github.io/prophet/docs/additional_topics.html>`_.
"""
"""Transform that uses :py:func:`~etna.analysis.outliers.prediction_interval_outliers.get_anomalies_prediction_interval` to find anomalies in data."""

def __init__(
self,
Expand Down
74 changes: 74 additions & 0 deletions tests/test_models/test_prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import pandas as pd
import pytest
from prophet import Prophet
from prophet.serialize import model_to_dict

from etna.datasets.tsdataset import TSDataset
from etna.models import ProphetModel
from etna.models.prophet import _ProphetAdapter
from etna.pipeline import Pipeline
from tests.test_models.utils import assert_model_equals_loaded_original

Expand Down Expand Up @@ -123,6 +125,78 @@ def test_get_model_after_training(example_tsds):
assert isinstance(models_dict[segment], Prophet)


@pytest.fixture
def prophet_default_params():
    """Return the hyperparameter values expected in the state of a default ``_ProphetAdapter``."""
    return dict(
        growth="linear",
        changepoints=None,
        n_changepoints=25,
        changepoint_range=0.8,
        yearly_seasonality="auto",
        weekly_seasonality="auto",
        daily_seasonality="auto",
        holidays=None,
        seasonality_mode="additive",
        seasonality_prior_scale=10.0,
        holidays_prior_scale=10.0,
        changepoint_prior_scale=0.05,
        mcmc_samples=0,
        interval_width=0.8,
        uncertainty_samples=1000,
        stan_backend=None,
        additional_seasonality_params=(),
    )


def test_getstate_not_fitted(prophet_default_params):
    """Check ``__getstate__`` of an adapter on which ``fit`` was never called."""
    adapter = _ProphetAdapter()
    actual = adapter.__getstate__()

    # Serialization-specific entries first, then the hyperparameters.
    expected = {"_is_fitted": False, "_model_dict": {}, "regressor_columns": None}
    expected.update(prophet_default_params)

    assert actual == expected


def test_getstate_fitted(example_tsds, prophet_default_params):
    """Check ``__getstate__`` of an adapter after ``fit`` on the ``segment_1`` data."""
    adapter = _ProphetAdapter()
    segment_df = example_tsds.to_pandas()["segment_1"].reset_index()
    adapter.fit(segment_df, regressors=[])
    actual = adapter.__getstate__()

    # The expected model dict is taken from the very same fitted model.
    expected = {
        "_is_fitted": True,
        "_model_dict": model_to_dict(adapter.model),
        "regressor_columns": [],
    }
    expected.update(prophet_default_params)

    assert actual == expected


def test_setstate_not_fitted():
    """Check that ``__setstate__`` with an unfitted adapter's state reproduces that state."""
    source = _ProphetAdapter(n_changepoints=25)
    source_state = source.__getstate__()

    # The receiver starts with a different hyperparameter value on purpose.
    target = _ProphetAdapter(n_changepoints=20)
    target.__setstate__(source_state)

    assert target.__getstate__() == source_state


def test_setstate_fitted(example_tsds):
    """Check that ``__setstate__`` with a fitted adapter's state reproduces that state."""
    source = _ProphetAdapter()
    segment_df = example_tsds.to_pandas()["segment_1"].reset_index()
    source.fit(segment_df, regressors=[])
    source_state = source.__getstate__()

    target = _ProphetAdapter()
    target.__setstate__(source_state)

    assert target.__getstate__() == source_state


@pytest.mark.xfail(reason="Non native serialization, should be fixed in inference-v2.0")
def test_save_load(example_tsds):
model = ProphetModel()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def test_save_load(transform, outliers_solid_tsds):
assert_transformation_equals_loaded_original(transform=transform, ts=outliers_solid_tsds)


@pytest.mark.xfail(reason="Non native serialization, should be fixed in inference-v2.0")
@pytest.mark.parametrize(
"transform",
(
Expand Down