From cd88b8e81ea3721f93d0cfdd0329f380a2c3d061 Mon Sep 17 00:00:00 2001 From: Shiva Raisinghani Date: Tue, 7 Dec 2021 23:56:18 -0800 Subject: [PATCH] feat(prophet): enable confidence intervals and y_hat without forecast (#17658) * enable confidence intervals and y_hat without forecast * fix if statement Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com> --- superset/charts/schemas.py | 2 +- superset/utils/pandas_postprocessing.py | 4 ++-- .../pandas_postprocessing_tests.py | 24 ++++++++++++++++++- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 8293065d9ee1b..f68c4ff4b8c95 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -542,7 +542,7 @@ class ChartDataProphetOptionsSchema(ChartDataPostProcessingOperationOptionsSchem ) periods = fields.Integer( descrption="Time periods (in units of `time_grain`) to predict into the future", - min=1, + min=0, example=7, required=True, ) diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py index 8ad5099552ff6..ea4e40986bf81 100644 --- a/superset/utils/pandas_postprocessing.py +++ b/superset/utils/pandas_postprocessing.py @@ -820,8 +820,8 @@ def prophet( # pylint: disable=too-many-arguments freq = PROPHET_TIME_GRAIN_MAP[time_grain] # check type at runtime due to marhsmallow schema not being able to handle # union types - if not periods or periods < 0 or not isinstance(periods, int): - raise QueryObjectValidationError(_("Periods must be a positive integer value")) + if not isinstance(periods, int) or periods < 0: + raise QueryObjectValidationError(_("Periods must be a whole number")) if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1: raise QueryObjectValidationError( _("Confidence interval must be between 0 and 1 (exclusive)") diff --git a/tests/integration_tests/pandas_postprocessing_tests.py b/tests/integration_tests/pandas_postprocessing_tests.py index 7221130be85bf..feabef6e2fbaf 100644 --- a/tests/integration_tests/pandas_postprocessing_tests.py +++ b/tests/integration_tests/pandas_postprocessing_tests.py @@ -830,6 +830,28 @@ def test_prophet_valid(self): assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31) assert len(df) == 9 + def test_prophet_valid_zero_periods(self): + pytest.importorskip("prophet") + + df = proc.prophet( + df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9 + ) + columns = {column for column in df.columns} + assert columns == { + DTTM_ALIAS, + "a__yhat", + "a__yhat_upper", + "a__yhat_lower", + "a", + "b__yhat", + "b__yhat_upper", + "b__yhat_lower", + "b", + } + assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) + assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2021, 12, 31) + assert len(df) == 4 + def test_prophet_import(self): prophet = find_spec("prophet") if prophet is None: @@ -875,7 +897,7 @@ def test_prophet_incorrect_periods(self): proc.prophet, df=prophet_df, time_grain="P1M", - periods=0, + periods=-1, confidence_interval=0.8, )