From a01c4c95db9bd099758b5cf49119e4ad536613e8 Mon Sep 17 00:00:00 2001 From: Grace Guo Date: Thu, 16 Dec 2021 04:12:43 -0800 Subject: [PATCH] fix: [alert] should run alert query from report account (#17499) * fix: [alert] should run alert query from report account * add solution2: override username for get_df * add integration test --- superset/models/core.py | 5 +++-- superset/reports/commands/alert.py | 7 +++++-- tests/integration_tests/model_tests.py | 12 ++++++++++++ 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/superset/models/core.py b/superset/models/core.py index 3aaa35769bd27..fbccb38f60e10 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -397,11 +397,12 @@ def get_df( # pylint: disable=too-many-locals sql: str, schema: Optional[str] = None, mutator: Optional[Callable[[pd.DataFrame], None]] = None, + username: Optional[str] = None, ) -> pd.DataFrame: sqls = [str(s).strip(" ;") for s in sqlparse.parse(sql)] - engine = self.get_sqla_engine(schema=schema) - username = utils.get_username() + engine = self.get_sqla_engine(schema=schema, user_name=username) + username = utils.get_username() or username def needs_conversion(df_series: pd.Series) -> bool: return ( diff --git a/superset/reports/commands/alert.py b/superset/reports/commands/alert.py index a13916f59ceb9..e00ac9f2df5c1 100644 --- a/superset/reports/commands/alert.py +++ b/superset/reports/commands/alert.py @@ -25,7 +25,7 @@ from celery.exceptions import SoftTimeLimitExceeded from flask_babel import lazy_gettext as _ -from superset import jinja_context +from superset import app, jinja_context from superset.commands.base import BaseCommand from superset.models.reports import ReportSchedule, ReportScheduleValidatorType from superset.reports.commands.exceptions import ( @@ -146,8 +146,11 @@ def _execute_query(self) -> pd.DataFrame: limited_rendered_sql = self._report_schedule.database.apply_limit_to_sql( rendered_sql, ALERT_SQL_LIMIT ) + query_username = app.config["THUMBNAIL_SELENIUM_USER"] start = default_timer() - df = self._report_schedule.database.get_df(limited_rendered_sql) + df = self._report_schedule.database.get_df( + sql=limited_rendered_sql, username=query_username + ) stop = default_timer() logger.info( "Query for %s took %.2f ms", diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index c7f7b0ce1b222..7ffb173cd35fd 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -340,6 +340,18 @@ def test_multi_statement(self): df = main_db.get_df("USE superset; SELECT ';';", None) self.assertEqual(df.iat[0, 0], ";") + @mock.patch("superset.models.core.Database.get_sqla_engine") + def test_username_param(self, mocked_get_sqla_engine): + main_db = get_example_database() + main_db.impersonate_user = True + test_username = "test_username_param" + + if main_db.backend == "mysql": + main_db.get_df("USE superset; SELECT 1", username=test_username) + mocked_get_sqla_engine.assert_called_with( + schema=None, user_name="test_username_param", + ) + @mock.patch("superset.models.core.create_engine") def test_get_sqla_engine(self, mocked_create_engine): model = Database(