From 095fb8c8520c50ff217e2fc8d221e08c2dee9c63 Mon Sep 17 00:00:00 2001
From: Ville Brofeldt
Date: Wed, 15 Sep 2021 10:10:07 +0300
Subject: [PATCH] fix test

---
 superset/common/query_context.py               |  5 +----
 superset/common/query_object.py                | 10 +++++++---
 superset/utils/sqllab_execution_context.py     |  2 +-
 superset/viz.py                                |  2 +-
 tests/integration_tests/charts/schema_tests.py |  6 ++++--
 tests/integration_tests/query_context_tests.py |  1 +
 6 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 5c71b2917c938..83070f70c34a2 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -100,14 +100,11 @@ def __init__(
         self.datasource = ConnectorRegistry.get_datasource(
             str(datasource["type"]), int(datasource["id"]), db.session
         )
-        self.queries = [
-            QueryObject(**query_obj, result_type_qc=result_type)
-            for query_obj in queries
-        ]
         self.force = force
         self.custom_cache_timeout = custom_cache_timeout
         self.result_type = result_type or ChartDataResultType.FULL
         self.result_format = result_format or ChartDataResultFormat.JSON
+        self.queries = [QueryObject(self, **query_obj) for query_obj in queries]
         self.cache_values = {
             "datasource": datasource,
             "queries": queries,
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 5332de56fe257..72995c0668059 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -16,7 +16,7 @@
 # under the License.
 import logging
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, NamedTuple, Optional
+from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
 
 from flask_babel import gettext as _
 from pandas import DataFrame
@@ -42,6 +42,10 @@
 from superset.utils.hashing import md5_sha_from_dict
 from superset.views.utils import get_time_range_endpoints
 
+if TYPE_CHECKING:
+    from superset.common.query_context import QueryContext
+
+
 config = app.config
 logger = logging.getLogger(__name__)
 
@@ -101,6 +105,7 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
 
     def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         self,
+        query_context: "QueryContext",
         datasource: Optional[DatasourceDict] = None,
         result_type: Optional[ChartDataResultType] = None,
         annotation_layers: Optional[List[Dict[str, Any]]] = None,
@@ -123,7 +128,6 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         orderby: Optional[List[OrderBy]] = None,
         post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
         is_rowcount: bool = False,
-        result_type_qc: Optional[ChartDataResultType] = None,
         **kwargs: Any,
     ):
         columns = columns or []
@@ -140,7 +144,7 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         self.datasource = ConnectorRegistry.get_datasource(
             str(datasource["type"]), int(datasource["id"]), db.session
         )
-        self.result_type = result_type or result_type_qc
+        self.result_type = result_type or query_context.result_type
         self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
         self.annotation_layers = [
             layer
diff --git a/superset/utils/sqllab_execution_context.py b/superset/utils/sqllab_execution_context.py
index 38feeb225ebc9..58d7f3104fa80 100644
--- a/superset/utils/sqllab_execution_context.py
+++ b/superset/utils/sqllab_execution_context.py
@@ -104,7 +104,7 @@ def _get_template_params(query_params: Dict[str, Any]) -> Dict[str, Any]:
     @staticmethod
     def _get_limit_param(query_params: Dict[str, Any]) -> int:
         limit = apply_max_row_limit(
-            app.config["SQL_MAX_ROW"], query_params.get("queryLimit", 0)
+            app.config["SQL_MAX_ROW"], query_params.get("queryLimit") or 0
         )
         if limit < 0:
             logger.warning(
diff --git a/superset/viz.py b/superset/viz.py
index 3277f0c5ea89f..4ac3d665098cc 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -333,7 +333,7 @@ def query_obj(self) -> QueryObjectDict:
         timeseries_limit_metric = form_data.get("timeseries_limit_metric")
 
         # apply row limit to query
-        row_limit = form_data.get("row_limit") or config["ROW_LIMIT"]
+        row_limit = int(form_data.get("row_limit") or config["ROW_LIMIT"])
         row_limit = apply_max_row_limit(config["SQL_MAX_ROW"], row_limit)
 
         # default order direction
diff --git a/tests/integration_tests/charts/schema_tests.py b/tests/integration_tests/charts/schema_tests.py
index e34b7d71fb418..5cdf22ac338e1 100644
--- a/tests/integration_tests/charts/schema_tests.py
+++ b/tests/integration_tests/charts/schema_tests.py
@@ -16,17 +16,19 @@
 # under the License.
 # isort:skip_file
 """Unit tests for Superset"""
-from typing import Any, Dict, Tuple
+from unittest import mock
 
 from marshmallow import ValidationError
 from tests.integration_tests.test_app import app
 from superset.charts.schemas import ChartDataQueryContextSchema
-from superset.common.query_context import QueryContext
 from tests.integration_tests.base_tests import SupersetTestCase
 from tests.integration_tests.fixtures.query_context import get_query_context
 
 
 class TestSchema(SupersetTestCase):
+    @mock.patch(
+        "superset.common.query_object.config", {**app.config, "SQL_MAX_ROW": 100000},
+    )
     def test_query_context_limit_and_offset(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index 87af4c931b988..e895e7e2f6f3e 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -90,6 +90,7 @@ def test_schema_deserialization(self):
         self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
         self.assertEqual(post_proc["options"], payload_post_proc["options"])
 
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_cache(self):
         table_name = "birth_names"
         table = self.get_table(name=table_name)
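
Reviewer note (not part of the patch): a minimal sketch, in plain Python, of
the dict.get() semantics that the queryLimit change above guards against. The
query_params dict below is a hypothetical stand-in for the SQL Lab payload;
only the .get() behavior is the point.

    # A JSON payload of {"queryLimit": null} deserializes to a dict where the
    # key is present with value None. dict.get(key, default) returns the
    # default only when the key is absent, so it still yields None here,
    # which apply_max_row_limit cannot compare against an int.
    query_params = {"queryLimit": None}

    query_params.get("queryLimit", 0)    # -> None: key exists, default ignored
    query_params.get("queryLimit") or 0  # -> 0: also normalizes an explicit None

The int() coercion in the viz.py hunk is the same kind of normalization,
presumably because form_data can carry row_limit as a string; coercing it
before the apply_max_row_limit call keeps the comparison against
config["SQL_MAX_ROW"] between integers.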