diff --git a/superset/jinja_context.py b/superset/jinja_context.py index e4a83422315d8..604e26b1dbb36 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -25,7 +25,7 @@ from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Union import dateutil -from flask import current_app, has_request_context, request +from flask import current_app, g, has_request_context, request from flask_babel import gettext as _ from jinja2 import DebugUndefined, Environment from jinja2.sandbox import SandboxedEnvironment @@ -847,35 +847,45 @@ def dataset_macro( def get_dataset_id_from_context(metric_key: str) -> int: """ - Retrives the Dataset ID from the request context. + Retrieves the Dataset ID from the request context. :param metric_key: the metric key. :returns: the dataset ID. """ # pylint: disable=import-outside-toplevel from superset.daos.chart import ChartDAO - from superset.views.utils import get_form_data + from superset.views.utils import loads_request_json + form_data: dict[str, Any] = {} exc_message = _( "Please specify the Dataset ID for the ``%(name)s`` metric in the Jinja macro.", name=metric_key, ) - form_data, chart = get_form_data() - if not (form_data or chart): - raise SupersetTemplateException(exc_message) + if has_request_context(): + if payload := request.get_json(cache=True) if request.is_json else None: + if dataset_id := payload.get("datasource", {}).get("id"): + return dataset_id + form_data.update(payload.get("form_data", {})) + request_form = loads_request_json(request.form.get("form_data")) + form_data.update(request_form) + request_args = loads_request_json(request.args.get("form_data")) + form_data.update(request_args) + + if form_data := (form_data or getattr(g, "form_data", {})): + if datasource_info := form_data.get("datasource"): + if isinstance(datasource_info, dict): + return datasource_info["id"] + return datasource_info.split("__")[0] + url_params = form_data.get("queries", [{}])[0].get("url_params", {}) + if dataset_id := url_params.get("datasource_id"): + return dataset_id + if chart_id := (form_data.get("slice_id") or url_params.get("slice_id")): + chart_data = ChartDAO.find_by_id(chart_id) + if not chart_data: + raise SupersetTemplateException(exc_message) + return chart_data.datasource_id - if chart and chart.datasource_id: - return chart.datasource_id - if dataset_id := form_data.get("url_params", {}).get("datasource_id"): - return dataset_id - if chart_id := ( - form_data.get("slice_id") or form_data.get("url_params", {}).get("slice_id") - ): - chart_data = ChartDAO.find_by_id(chart_id) - if not chart_data: - raise SupersetTemplateException(exc_message) - return chart_data.datasource_id raise SupersetTemplateException(exc_message) diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 922cbf67fd65e..2d7f6bf041bdd 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -200,8 +200,8 @@ def test_jinja_metrics_and_calc_columns(self, mock_username): db.session.delete(table) db.session.commit() - @patch("superset.views.utils.get_form_data") - def test_jinja_metric_macro(self, mock_form_data_context): + @patch("superset.jinja_context.get_dataset_id_from_context") + def test_jinja_metric_macro(self, mock_dataset_id_from_context): self.login(username="admin") table = self.get_table(name="birth_names") metric = SqlMetric( @@ -234,14 +234,8 @@ def test_jinja_metric_macro(self, mock_form_data_context): "filter": [], "extras": {"time_grain_sqla": "P1D"}, } - mock_form_data_context.return_value = [ - { - "url_params": { - "datasource_id": table.id, - } - }, - None, - ] + mock_dataset_id_from_context.return_value = table.id + sqla_query = table.get_sqla_query(**base_query_obj) query = table.database.compile_sqla_query(sqla_query.sqla_query) diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py index ced40c8119dea..391ead3f46277 100644 --- a/tests/unit_tests/jinja_context_test.py +++ b/tests/unit_tests/jinja_context_test.py @@ -584,15 +584,15 @@ def test_metric_macro_no_dataset_id_no_context(mocker: MockerFixture) -> None: not available in the context. """ DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") - mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") - mock_get_form_data.return_value = [None, None] - with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") - assert str(excinfo.value) == ( - "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." - ) - mock_get_form_data.assert_called_once() - DatasetDAO.find_by_id.assert_not_called() + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {} + with app.test_request_context(): + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("macro_key") + assert str(excinfo.value) == ( + "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + ) + DatasetDAO.find_by_id.assert_not_called() def test_metric_macro_no_dataset_id_with_context_missing_info( @@ -603,20 +603,31 @@ def test_metric_macro_no_dataset_id_with_context_missing_info( has context but no dataset/chart ID. """ DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") - mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") - mock_get_form_data.return_value = [ - { - "url_params": {}, - }, - None, - ] - with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") - assert str(excinfo.value) == ( - "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." - ) - mock_get_form_data.assert_called_once() - DatasetDAO.find_by_id.assert_not_called() + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {"queries": []} + with app.test_request_context( + data={ + "form_data": json.dumps( + { + "adhoc_filters": [ + { + "clause": "WHERE", + "comparator": "foo", + "expressionType": "SIMPLE", + "operator": "in", + "subject": "name", + } + ], + } + ), + } + ): + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("macro_key") + assert str(excinfo.value) == ( + "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + ) + DatasetDAO.find_by_id.assert_not_called() def test_metric_macro_no_dataset_id_with_context_datasource_id( @@ -636,18 +647,39 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id( schema="my_schema", sql=None, ) - mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") - mock_get_form_data.return_value = [ - { - "url_params": { - "datasource_id": 1, + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {} + + # Getting the data from the request context + with app.test_request_context( + data={ + "form_data": json.dumps( + { + "queries": [ + { + "url_params": { + "datasource_id": 1, + } + } + ], + } + ) + } + ): + assert metric_macro("macro_key") == "COUNT(*)" + + # Getting data from g's form_data + mock_g.form_data = { + "queries": [ + { + "url_params": { + "datasource_id": 1, + } } - }, - None, - ] - assert metric_macro("macro_key") == "COUNT(*)" - mock_get_form_data.assert_called_once() - DatasetDAO.find_by_id.assert_called_once_with(1) + ], + } + with app.test_request_context(): + assert metric_macro("macro_key") == "COUNT(*)" def test_metric_macro_no_dataset_id_with_context_datasource_id_none( @@ -657,26 +689,47 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none( Test the ``metric_macro`` when not specifying a dataset ID and it's set to None in the context (url_params.datasource_id). """ - ChartDAO = mocker.patch("superset.daos.chart.ChartDAO") - ChartDAO.find_by_id.return_value = None - DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") - mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") - mock_get_form_data.return_value = [ - { - "url_params": { - "datasource_id": None, - } - }, - None, - ] + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {} - with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") - assert str(excinfo.value) == ( - "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." - ) - mock_get_form_data.assert_called_once() - DatasetDAO.find_by_id.assert_not_called() + # Getting the data from the request context + with app.test_request_context( + data={ + "form_data": json.dumps( + { + "queries": [ + { + "url_params": { + "datasource_id": None, + } + } + ], + } + ) + } + ): + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("macro_key") + assert str(excinfo.value) == ( + "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + ) + + # Getting data from g's form_data + mock_g.form_data = { + "queries": [ + { + "url_params": { + "datasource_id": None, + } + } + ], + } + with app.test_request_context(): + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("macro_key") + assert str(excinfo.value) == ( + "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + ) def test_metric_macro_no_dataset_id_with_context_chart_id( @@ -700,16 +753,40 @@ def test_metric_macro_no_dataset_id_with_context_chart_id( schema="my_schema", sql=None, ) - mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") - mock_get_form_data.return_value = [ - { - "slice_id": 1, - }, - None, - ] - assert metric_macro("macro_key") == "COUNT(*)" - mock_get_form_data.assert_called_once() - DatasetDAO.find_by_id.assert_called_once_with(1) + + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {} + + # Getting the data from the request context + with app.test_request_context( + data={ + "form_data": json.dumps( + { + "queries": [ + { + "url_params": { + "slice_id": 1, + } + } + ], + } + ) + } + ): + assert metric_macro("macro_key") == "COUNT(*)" + + # Getting data from g's form_data + mock_g.form_data = { + "queries": [ + { + "url_params": { + "slice_id": 1, + } + } + ], + } + with app.test_request_context(): + assert metric_macro("macro_key") == "COUNT(*)" def test_metric_macro_no_dataset_id_with_context_slice_id_none( @@ -719,53 +796,47 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none( Test the ``metric_macro`` when not specifying a dataset ID and context includes slice_id set to None (url_params.slice_id). """ - ChartDAO = mocker.patch("superset.daos.chart.ChartDAO") - ChartDAO.find_by_id.return_value = None - DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") - mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") - mock_get_form_data.return_value = [ - { - "slice_id": None, - }, - None, - ] - - with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") - assert str(excinfo.value) == ( - "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." - ) - mock_get_form_data.assert_called_once() - DatasetDAO.find_by_id.assert_not_called() + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {} + # Getting the data from the request context + with app.test_request_context( + data={ + "form_data": json.dumps( + { + "queries": [ + { + "url_params": { + "slice_id": None, + } + } + ], + } + ) + } + ): + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("macro_key") + assert str(excinfo.value) == ( + "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + ) -def test_metric_macro_no_dataset_id_with_context_chart(mocker: MockerFixture) -> None: - """ - Test the ``metric_macro`` when not specifying a dataset ID and context - includes an existing chart (get_form_data()[1]). - """ - ChartDAO = mocker.patch("superset.daos.chart.ChartDAO") - DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") - DatasetDAO.find_by_id.return_value = SqlaTable( - table_name="test_dataset", - metrics=[ - SqlMetric(metric_name="macro_key", expression="COUNT(*)"), + # Getting data from g's form_data + mock_g.form_data = { + "queries": [ + { + "url_params": { + "slice_id": None, + } + } ], - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), - schema="my_schema", - sql=None, - ) - mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") - mock_get_form_data.return_value = [ - { - "slice_id": 1, - }, - Slice(datasource_id=1), - ] - assert metric_macro("macro_key") == "COUNT(*)" - mock_get_form_data.assert_called_once() - DatasetDAO.find_by_id.assert_called_once_with(1) - ChartDAO.find_by_id.assert_not_called() + } + with app.test_request_context(): + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("macro_key") + assert str(excinfo.value) == ( + "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + ) def test_metric_macro_no_dataset_id_with_context_deleted_chart( @@ -777,49 +848,91 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart( """ ChartDAO = mocker.patch("superset.daos.chart.ChartDAO") ChartDAO.find_by_id.return_value = None - DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") - mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") - mock_get_form_data.return_value = [ - { - "slice_id": 1, - }, - None, - ] + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {} - with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") - assert str(excinfo.value) == ( - "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." - ) - mock_get_form_data.assert_called_once() - DatasetDAO.find_by_id.assert_not_called() + # Getting the data from the request context + with app.test_request_context( + data={ + "form_data": json.dumps( + { + "queries": [ + { + "url_params": { + "slice_id": 1, + } + } + ], + } + ) + } + ): + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("macro_key") + assert str(excinfo.value) == ( + "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + ) + + # Getting data from g's form_data + mock_g.form_data = { + "queries": [ + { + "url_params": { + "slice_id": 1, + } + } + ], + } + with app.test_request_context(): + with pytest.raises(SupersetTemplateException) as excinfo: + metric_macro("macro_key") + assert str(excinfo.value) == ( + "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + ) -def test_metric_macro_no_dataset_id_with_context_chart_no_datasource_id( +def test_metric_macro_no_dataset_id_available_in_request_form_data( mocker: MockerFixture, ) -> None: """ Test the ``metric_macro`` when not specifying a dataset ID and context - includes an existing chart (get_form_data()[1]) with no dataset ID. + includes an existing dataset ID (datasource.id). """ - ChartDAO = mocker.patch("superset.daos.chart.ChartDAO") - ChartDAO.find_by_id.return_value = None DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") - mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") - mock_get_form_data.return_value = [ - {}, - Slice( - datasource_id=None, - ), - ] - - with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") - assert str(excinfo.value) == ( - "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." + DatasetDAO.find_by_id.return_value = SqlaTable( + table_name="test_dataset", + metrics=[ + SqlMetric(metric_name="macro_key", expression="COUNT(*)"), + ], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + schema="my_schema", + sql=None, ) - mock_get_form_data.assert_called_once() - DatasetDAO.find_by_id.assert_not_called() + + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {} + + # Getting the data from the request context + with app.test_request_context( + data={ + "form_data": json.dumps( + { + "datasource": { + "id": 1, + }, + } + ) + } + ): + assert metric_macro("macro_key") == "COUNT(*)" + + # Getting data from g's form_data + mock_g.form_data = { + "datasource": "1__table", + } + + with app.test_request_context(): + assert metric_macro("macro_key") == "COUNT(*)" @pytest.mark.parametrize(