diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 68328c215049e..ffcf497d6f231 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -30,7 +30,7 @@ Union, ) -from flask import current_app, g, request +from flask import current_app, g, has_request_context, request from flask_babel import gettext as _ from jinja2 import DebugUndefined from jinja2.sandbox import SandboxedEnvironment @@ -172,8 +172,9 @@ def url_param( # pylint: disable=import-outside-toplevel from superset.views.utils import get_form_data - if request.args.get(param): + if has_request_context() and request.args.get(param): # type: ignore return request.args.get(param, default) + form_data, _ = get_form_data() url_params = form_data.get("url_params") or {} result = url_params.get(param, default) diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py index 19fbef297af70..926b39bd65861 100644 --- a/superset/tasks/async_queries.py +++ b/superset/tasks/async_queries.py @@ -46,6 +46,10 @@ def ensure_user_is_set(user_id: Optional[int]) -> None: g.user = security_manager.get_anonymous_user() +def set_form_data(form_data: Dict[str, Any]) -> None: + g.form_data = form_data + + @celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout) def load_chart_data_into_cache( job_metadata: Dict[str, Any], form_data: Dict[str, Any], @@ -55,6 +59,7 @@ def load_chart_data_into_cache( try: ensure_user_is_set(job_metadata.get("user_id")) + set_form_data(form_data) command = ChartDataCommand() command.set_query_context(form_data) result = command.run(cache=True) @@ -86,6 +91,7 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals cache_key_prefix = "ejr-" # ejr: explore_json request try: ensure_user_is_set(job_metadata.get("user_id")) + set_form_data(form_data) datasource_id, datasource_type = get_datasource_info(None, None, form_data) # Perform a deep copy here so that below we can cache the original diff --git a/superset/views/utils.py b/superset/views/utils.py index 6a61f66b0e708..37d83619fe115 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -24,7 +24,7 @@ import msgpack import pyarrow as pa import simplejson as json -from flask import g, request +from flask import g, has_request_context, request from flask_appbuilder.security.sqla import models as ab_models from flask_appbuilder.security.sqla.models import User from flask_babel import _ @@ -130,46 +130,52 @@ def get_form_data( # pylint: disable=too-many-locals slice_id: Optional[int] = None, use_slice_data: bool = False ) -> Tuple[Dict[str, Any], Optional[Slice]]: form_data: Dict[str, Any] = {} - # chart data API requests are JSON - request_json_data = ( - request.json["queries"][0] - if request.is_json and "queries" in request.json - else None - ) - add_sqllab_custom_filters(form_data) - - request_form_data = request.form.get("form_data") - request_args_data = request.args.get("form_data") - if request_json_data: - form_data.update(request_json_data) - if request_form_data: - parsed_form_data = loads_request_json(request_form_data) - # some chart data api requests are form_data - queries = parsed_form_data.get("queries") - if isinstance(queries, list): - form_data.update(queries[0]) - else: - form_data.update(parsed_form_data) - # request params can overwrite the body - if request_args_data: - form_data.update(loads_request_json(request_args_data)) - - # Fallback to using the Flask globals (used for cache warmup) if defined. + if has_request_context(): # type: ignore + # chart data API requests are JSON + request_json_data = ( + request.json["queries"][0] + if request.is_json and "queries" in request.json + else None + ) + + add_sqllab_custom_filters(form_data) + + request_form_data = request.form.get("form_data") + request_args_data = request.args.get("form_data") + if request_json_data: + form_data.update(request_json_data) + if request_form_data: + parsed_form_data = loads_request_json(request_form_data) + # some chart data api requests are form_data + queries = parsed_form_data.get("queries") + if isinstance(queries, list): + form_data.update(queries[0]) + else: + form_data.update(parsed_form_data) + # request params can overwrite the body + if request_args_data: + form_data.update(loads_request_json(request_args_data)) + + # Fallback to using the Flask globals (used for cache warmup and async queries) if not form_data and hasattr(g, "form_data"): form_data = getattr(g, "form_data") - - url_id = request.args.get("r") - if url_id: - saved_url = db.session.query(models.Url).filter_by(id=url_id).first() - if saved_url: - url_str = parse.unquote_plus( - saved_url.url.split("?")[1][10:], encoding="utf-8" - ) - url_form_data = loads_request_json(url_str) - # allow form_date in request override saved url - url_form_data.update(form_data) - form_data = url_form_data + # chart data API requests are JSON + json_data = form_data["queries"][0] if "queries" in form_data else {} + form_data.update(json_data) + + if has_request_context(): # type: ignore + url_id = request.args.get("r") + if url_id: + saved_url = db.session.query(models.Url).filter_by(id=url_id).first() + if saved_url: + url_str = parse.unquote_plus( + saved_url.url.split("?")[1][10:], encoding="utf-8" + ) + url_form_data = loads_request_json(url_str) + # allow form_date in request override saved url + url_form_data.update(form_data) + form_data = url_form_data form_data = {k: v for k, v in form_data.items() if k not in REJECTED_FORM_DATA_KEYS} diff --git a/tests/integration_tests/tasks/async_queries_tests.py b/tests/integration_tests/tasks/async_queries_tests.py index 57b0df5ad2273..3ea1c6f0ce6de 100644 --- a/tests/integration_tests/tasks/async_queries_tests.py +++ b/tests/integration_tests/tasks/async_queries_tests.py @@ -45,7 +45,8 @@ class TestAsyncQueries(SupersetTestCase): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch.object(async_query_manager, "update_job") - def test_load_chart_data_into_cache(self, mock_update_job): + @mock.patch.object(async_queries, "set_form_data") + def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job): async_query_manager.init_app(app) query_context = get_query_context("birth_names") user = security_manager.find_user("gamma") @@ -63,6 +64,7 @@ def test_load_chart_data_into_cache(self, mock_update_job): load_chart_data_into_cache(job_metadata, query_context) ensure_user_is_set.assert_called_once_with(user.id) + mock_set_form_data.assert_called_once_with(query_context) mock_update_job.assert_called_once_with( job_metadata, "done", result_url=mock.ANY ) @@ -154,7 +156,10 @@ def test_load_explore_json_into_cache(self, mock_update_job): ) @mock.patch.object(async_query_manager, "update_job") - def test_load_explore_json_into_cache_error(self, mock_update_job): + @mock.patch.object(async_queries, "set_form_data") + def test_load_explore_json_into_cache_error( + self, mock_set_form_data, mock_update_job + ): async_query_manager.init_app(app) user = security_manager.find_user("gamma") form_data = {} @@ -173,6 +178,7 @@ def test_load_explore_json_into_cache_error(self, mock_update_job): load_explore_json_into_cache(job_metadata, form_data) ensure_user_is_set.assert_called_once_with(user.id) + mock_set_form_data.assert_called_once_with(form_data) errors = ["The dataset associated with this chart no longer exists"] mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)