Skip to content

Commit

Permalink
fix: Support Jinja template functions in global async queries (apache…
Browse files Browse the repository at this point in the history
…#16412)

* Support Jinja template functions in async queries

* Pylint

* Add tests for async tasks

* Remove redundant has_request_context check
  • Loading branch information
robdiciuccio authored and villebro committed Sep 6, 2021
1 parent 793c027 commit e7262bf
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 42 deletions.
5 changes: 3 additions & 2 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
Union,
)

from flask import current_app, g, request
from flask import current_app, g, has_request_context, request
from flask_babel import gettext as _
from jinja2 import DebugUndefined
from jinja2.sandbox import SandboxedEnvironment
Expand Down Expand Up @@ -172,8 +172,9 @@ def url_param(
# pylint: disable=import-outside-toplevel
from superset.views.utils import get_form_data

if request.args.get(param):
if has_request_context() and request.args.get(param): # type: ignore
return request.args.get(param, default)

form_data, _ = get_form_data()
url_params = form_data.get("url_params") or {}
result = url_params.get(param, default)
Expand Down
6 changes: 6 additions & 0 deletions superset/tasks/async_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def ensure_user_is_set(user_id: Optional[int]) -> None:
g.user = security_manager.get_anonymous_user()


def set_form_data(form_data: Dict[str, Any]) -> None:
g.form_data = form_data


@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
def load_chart_data_into_cache(
job_metadata: Dict[str, Any], form_data: Dict[str, Any],
Expand All @@ -55,6 +59,7 @@ def load_chart_data_into_cache(

try:
ensure_user_is_set(job_metadata.get("user_id"))
set_form_data(form_data)
command = ChartDataCommand()
command.set_query_context(form_data)
result = command.run(cache=True)
Expand Down Expand Up @@ -86,6 +91,7 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
cache_key_prefix = "ejr-" # ejr: explore_json request
try:
ensure_user_is_set(job_metadata.get("user_id"))
set_form_data(form_data)
datasource_id, datasource_type = get_datasource_info(None, None, form_data)

# Perform a deep copy here so that below we can cache the original
Expand Down
82 changes: 44 additions & 38 deletions superset/views/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import msgpack
import pyarrow as pa
import simplejson as json
from flask import g, request
from flask import g, has_request_context, request
from flask_appbuilder.security.sqla import models as ab_models
from flask_appbuilder.security.sqla.models import User
from flask_babel import _
Expand Down Expand Up @@ -130,46 +130,52 @@ def get_form_data( # pylint: disable=too-many-locals
slice_id: Optional[int] = None, use_slice_data: bool = False
) -> Tuple[Dict[str, Any], Optional[Slice]]:
form_data: Dict[str, Any] = {}
# chart data API requests are JSON
request_json_data = (
request.json["queries"][0]
if request.is_json and "queries" in request.json
else None
)

add_sqllab_custom_filters(form_data)

request_form_data = request.form.get("form_data")
request_args_data = request.args.get("form_data")
if request_json_data:
form_data.update(request_json_data)
if request_form_data:
parsed_form_data = loads_request_json(request_form_data)
# some chart data api requests are form_data
queries = parsed_form_data.get("queries")
if isinstance(queries, list):
form_data.update(queries[0])
else:
form_data.update(parsed_form_data)
# request params can overwrite the body
if request_args_data:
form_data.update(loads_request_json(request_args_data))

# Fallback to using the Flask globals (used for cache warmup) if defined.
if has_request_context(): # type: ignore
# chart data API requests are JSON
request_json_data = (
request.json["queries"][0]
if request.is_json and "queries" in request.json
else None
)

add_sqllab_custom_filters(form_data)

request_form_data = request.form.get("form_data")
request_args_data = request.args.get("form_data")
if request_json_data:
form_data.update(request_json_data)
if request_form_data:
parsed_form_data = loads_request_json(request_form_data)
# some chart data api requests are form_data
queries = parsed_form_data.get("queries")
if isinstance(queries, list):
form_data.update(queries[0])
else:
form_data.update(parsed_form_data)
# request params can overwrite the body
if request_args_data:
form_data.update(loads_request_json(request_args_data))

# Fallback to using the Flask globals (used for cache warmup and async queries)
if not form_data and hasattr(g, "form_data"):
form_data = getattr(g, "form_data")

url_id = request.args.get("r")
if url_id:
saved_url = db.session.query(models.Url).filter_by(id=url_id).first()
if saved_url:
url_str = parse.unquote_plus(
saved_url.url.split("?")[1][10:], encoding="utf-8"
)
url_form_data = loads_request_json(url_str)
# allow form_date in request override saved url
url_form_data.update(form_data)
form_data = url_form_data
# chart data API requests are JSON
json_data = form_data["queries"][0] if "queries" in form_data else {}
form_data.update(json_data)

if has_request_context(): # type: ignore
url_id = request.args.get("r")
if url_id:
saved_url = db.session.query(models.Url).filter_by(id=url_id).first()
if saved_url:
url_str = parse.unquote_plus(
saved_url.url.split("?")[1][10:], encoding="utf-8"
)
url_form_data = loads_request_json(url_str)
# allow form_date in request override saved url
url_form_data.update(form_data)
form_data = url_form_data

form_data = {k: v for k, v in form_data.items() if k not in REJECTED_FORM_DATA_KEYS}

Expand Down
10 changes: 8 additions & 2 deletions tests/integration_tests/tasks/async_queries_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
class TestAsyncQueries(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch.object(async_query_manager, "update_job")
def test_load_chart_data_into_cache(self, mock_update_job):
@mock.patch.object(async_queries, "set_form_data")
def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job):
async_query_manager.init_app(app)
query_context = get_query_context("birth_names")
user = security_manager.find_user("gamma")
Expand All @@ -63,6 +64,7 @@ def test_load_chart_data_into_cache(self, mock_update_job):
load_chart_data_into_cache(job_metadata, query_context)

ensure_user_is_set.assert_called_once_with(user.id)
mock_set_form_data.assert_called_once_with(query_context)
mock_update_job.assert_called_once_with(
job_metadata, "done", result_url=mock.ANY
)
Expand Down Expand Up @@ -154,7 +156,10 @@ def test_load_explore_json_into_cache(self, mock_update_job):
)

@mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache_error(self, mock_update_job):
@mock.patch.object(async_queries, "set_form_data")
def test_load_explore_json_into_cache_error(
self, mock_set_form_data, mock_update_job
):
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
form_data = {}
Expand All @@ -173,6 +178,7 @@ def test_load_explore_json_into_cache_error(self, mock_update_job):
load_explore_json_into_cache(job_metadata, form_data)
ensure_user_is_set.assert_called_once_with(user.id)

mock_set_form_data.assert_called_once_with(form_data)
errors = ["The dataset associated with this chart no longer exists"]
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)

Expand Down

0 comments on commit e7262bf

Please sign in to comment.