diff --git a/requirements/testing.in b/requirements/testing.in index 9a40c90753da1..856c5272dc0b2 100644 --- a/requirements/testing.in +++ b/requirements/testing.in @@ -16,7 +16,7 @@ # -r development.in -r integration.in --e file:.[bigquery,hive,presto,trino] +-e file:.[bigquery,hive,presto,prophet,trino] docker flask-testing freezegun diff --git a/requirements/testing.txt b/requirements/testing.txt index efa48fcad9398..a97aaff8524ce 100644 --- a/requirements/testing.txt +++ b/requirements/testing.txt @@ -1,4 +1,4 @@ -# SHA1:623feb0dd2b6bd376238ecf75069bc82136c2d70 +# SHA1:78fe89f88adf34ac75513d363d7d9d0b5cc8cd1c # # This file is autogenerated by pip-compile-multi # To update, run: @@ -12,16 +12,26 @@ # -r requirements/base.in # -r requirements/development.in # -r requirements/testing.in +cmdstanpy==1.1.0 + # via prophet +contourpy==1.0.7 + # via matplotlib coverage[toml]==7.2.5 # via pytest-cov +cycler==0.11.0 + # via matplotlib db-dtypes==1.1.1 # via pandas-gbq docker==6.1.1 # via -r requirements/testing.in +ephem==4.1.4 + # via lunarcalendar exceptiongroup==1.1.1 # via pytest flask-testing==0.8.1 # via -r requirements/testing.in +fonttools==4.39.4 + # via matplotlib freezegun==1.2.2 # via -r requirements/testing.in google-api-core[grpc]==2.11.0 @@ -73,6 +83,12 @@ iniconfig==2.0.0 # via pytest jsonschema-spec==0.1.4 # via openapi-spec-validator +kiwisolver==1.4.4 + # via matplotlib +lunarcalendar==0.0.9 + # via prophet +matplotlib==3.7.1 + # via prophet oauthlib==3.2.2 # via requests-oauthlib openapi-schema-validator==0.4.4 @@ -85,6 +101,8 @@ parameterized==0.9.0 # via -r requirements/testing.in pathable==0.4.3 # via jsonschema-spec +prophet==1.1.3 + # via apache-superset proto-plus==1.22.2 # via # google-cloud-bigquery @@ -107,8 +125,6 @@ pydata-google-auth==1.7.0 # via pandas-gbq pyfakefs==5.2.2 # via -r requirements/testing.in -pyhive[presto]==0.6.5 - # via apache-superset pytest==7.3.1 # via # -r requirements/testing.in @@ -130,6 +146,10 @@ sqlalchemy-bigquery==1.6.1 # via apache-superset statsd==4.0.1 # via -r requirements/testing.in +tqdm==4.65.0 + # via + # cmdstanpy + # prophet trino==0.324.0 # via apache-superset tzdata==2023.3 diff --git a/setup.py b/setup.py index 5bd67ca0c2a3e..d07a1ea0a28f5 100644 --- a/setup.py +++ b/setup.py @@ -176,7 +176,7 @@ def get_git_sha() -> str: "postgres": ["psycopg2-binary==2.9.6"], "presto": ["pyhive[presto]>=0.6.5"], "trino": ["trino>=0.324.0"], - "prophet": ["prophet>=1.0.1, <1.1", "pystan<3.0"], + "prophet": ["prophet>=1.1.0, <2.0.0"], "redshift": ["sqlalchemy-redshift>=0.8.1, < 0.9"], "rockset": ["rockset>=0.8.10, <0.9"], "shillelagh": [ diff --git a/superset/utils/pandas_postprocessing/prophet.py b/superset/utils/pandas_postprocessing/prophet.py index 6d733296adf54..a23f7838bf838 100644 --- a/superset/utils/pandas_postprocessing/prophet.py +++ b/superset/utils/pandas_postprocessing/prophet.py @@ -17,6 +17,7 @@ import logging from typing import Optional, Union +import pandas as pd from flask_babel import gettext as _ from pandas import DataFrame @@ -134,7 +135,13 @@ def prophet( # pylint: disable=too-many-arguments raise InvalidPostProcessingError(_("DataFrame include at least one series")) target_df = DataFrame() - for column in [column for column in df.columns if column != index]: + + for column in [ + column + for column in df.columns + if column != index + and pd.to_numeric(df[column], errors="coerce").notnull().all() + ]: fit_df = _prophet_fit_and_predict( df=df[[index, column]].rename(columns={index: "ds", column: "y"}), confidence_interval=confidence_interval, diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 4c8d1996c958b..da3a28f1ba81b 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -444,11 +444,11 @@ def test_chart_data_dttm_filter(self): else: raise Exception("ds column not found") + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_prophet(self): """ Chart data API: Ensure prophet post transformation works """ - pytest.importorskip("prophet") time_grain = "P1Y" self.query_context_payload["queries"][0]["is_timeseries"] = True self.query_context_payload["queries"][0]["groupby"] = [] @@ -476,7 +476,7 @@ def test_chart_data_prophet(self): self.assertIn("sum__num__yhat", row) self.assertIn("sum__num__yhat_upper", row) self.assertIn("sum__num__yhat_lower", row) - self.assertEqual(result["rowcount"], 47) + self.assertEqual(result["rowcount"], 103) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_invalid_post_processing(self): diff --git a/tests/unit_tests/pandas_postprocessing/test_prophet.py b/tests/unit_tests/pandas_postprocessing/test_prophet.py index 6da3a7a591a3d..4d9acdb066998 100644 --- a/tests/unit_tests/pandas_postprocessing/test_prophet.py +++ b/tests/unit_tests/pandas_postprocessing/test_prophet.py @@ -27,8 +27,6 @@ def test_prophet_valid(): - pytest.importorskip("prophet") - df = prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9) columns = {column for column in df.columns} assert columns == { @@ -113,8 +111,6 @@ def test_prophet_valid(): def test_prophet_valid_zero_periods(): - pytest.importorskip("prophet") - df = prophet(df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9) columns = {column for column in df.columns} assert columns == {