From 305c70468c6e873e5fc6f9af955851b59454e131 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Mon, 18 Oct 2021 19:28:05 +0200 Subject: [PATCH] fix(filter-indicator): show filters handled by jinja as applied (#17140) --- superset/common/query_actions.py | 7 ++++-- superset/common/query_context.py | 1 + superset/common/utils.py | 9 ++++++- superset/connectors/sqla/models.py | 11 +++++++- superset/jinja_context.py | 8 ++++++ superset/models/helpers.py | 2 ++ superset/viz.py | 25 +++++++------------ .../integration_tests/jinja_context_tests.py | 5 ++++ 8 files changed, 48 insertions(+), 20 deletions(-) diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py index 86a687f08716a..925d4a19516ce 100644 --- a/superset/common/query_actions.py +++ b/superset/common/query_actions.py @@ -96,6 +96,7 @@ def _get_full( datasource = _get_datasource(query_context, query_obj) result_type = query_obj.result_type or query_context.result_type payload = query_context.get_df_payload(query_obj, force_cached=force_cached) + applied_template_filters = payload.get("applied_template_filters", []) df = payload["df"] status = payload["status"] if status != QueryStatus.FAILED: @@ -113,12 +114,14 @@ def _get_full( datasource, query_obj.applied_time_extras ) payload["applied_filters"] = [ - {"column": col} for col in filter_columns if col in columns + {"column": col} + for col in filter_columns + if col in columns or col in applied_template_filters ] + applied_time_columns payload["rejected_filters"] = [ {"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, "column": col} for col in filter_columns - if col not in columns + if col not in columns and col not in applied_template_filters ] + rejected_time_columns if result_type == ChartDataResultType.RESULTS and status != QueryStatus.FAILED: diff --git a/superset/common/query_context.py b/superset/common/query_context.py index c1162b9671ae2..eee2bbee42531 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -485,6 +485,7 @@ def get_df_payload( "cached_dttm": cache.cache_dttm, "cache_timeout": self.cache_timeout, "df": cache.df, + "applied_template_filters": cache.applied_template_filters, "annotation_data": cache.annotation_data, "error": cache.error_message, "is_cached": cache.is_cached, diff --git a/superset/common/utils.py b/superset/common/utils.py index 77a6baba7fba6..d5cad68eb26d8 100644 --- a/superset/common/utils.py +++ b/superset/common/utils.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from flask_caching import Cache from pandas import DataFrame @@ -51,6 +51,7 @@ def __init__( df: DataFrame = DataFrame(), query: str = "", annotation_data: Optional[Dict[str, Any]] = None, + applied_template_filters: Optional[List[str]] = None, status: Optional[str] = None, error_message: Optional[str] = None, is_loaded: bool = False, @@ -62,6 +63,7 @@ def __init__( self.df = df self.query = query self.annotation_data = {} if annotation_data is None else annotation_data + self.applied_template_filters = applied_template_filters or [] self.status = status self.error_message = error_message @@ -88,6 +90,7 @@ def set_query_result( try: self.status = query_result.status self.query = query_result.query + self.applied_template_filters = query_result.applied_template_filters self.error_message = query_result.error_message self.df = query_result.df self.annotation_data = {} if annotation_data is None else annotation_data @@ -101,6 +104,7 @@ def set_query_result( value = { "df": self.df, "query": self.query, + "applied_template_filters": self.applied_template_filters, "annotation_data": self.annotation_data, } if self.is_loaded and key and self.status != QueryStatus.FAILED: @@ -141,6 +145,9 @@ def get( query_cache.df = cache_value["df"] query_cache.query = cache_value["query"] query_cache.annotation_data = cache_value.get("annotation_data", {}) + query_cache.applied_template_filters = cache_value.get( + "applied_template_filters", [] + ) query_cache.status = QueryStatus.SUCCESS query_cache.is_loaded = True query_cache.is_cached = cache_value is not None diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 0f6b7052781c6..fcb40f2199b5e 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -103,6 +103,7 @@ class SqlaQuery(NamedTuple): + applied_template_filters: List[str] extra_cache_keys: List[Any] labels_expected: List[str] prequeries: List[str] @@ -110,6 +111,7 @@ class SqlaQuery(NamedTuple): class QueryStringExtended(NamedTuple): + applied_template_filters: Optional[List[str]] labels_expected: List[str] prequeries: List[str] sql: str @@ -755,7 +757,10 @@ def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExten sql = sqlparse.format(sql, reindent=True) sql = self.mutate_query_from_config(sql) return QueryStringExtended( - labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries + applied_template_filters=sqlaq.applied_template_filters, + labels_expected=sqlaq.labels_expected, + prequeries=sqlaq.prequeries, + sql=sql, ) def get_query_str(self, query_obj: QueryObjectDict) -> str: @@ -978,7 +983,9 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma extra_cache_keys: List[Any] = [] template_kwargs["extra_cache_keys"] = extra_cache_keys removed_filters: List[str] = [] + applied_template_filters: List[str] = [] template_kwargs["removed_filters"] = removed_filters + template_kwargs["applied_filters"] = applied_template_filters template_processor = self.get_template_processor(**template_kwargs) db_engine_spec = self.db_engine_spec prequeries: List[str] = [] @@ -1394,6 +1401,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma labels_expected = [label] return SqlaQuery( + applied_template_filters=applied_template_filters, extra_cache_keys=extra_cache_keys, labels_expected=labels_expected, sqla_query=qry, @@ -1491,6 +1499,7 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: error_message = utils.error_msg_from_exception(ex) return QueryResult( + applied_template_filters=query_str_ext.applied_template_filters, status=status, df=df, duration=datetime.now() - qry_start_dttm, diff --git a/superset/jinja_context.py b/superset/jinja_context.py index e6a4cab963fb7..f21fbbb1b745a 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -96,10 +96,12 @@ class ExtraCache: def __init__( self, extra_cache_keys: Optional[List[Any]] = None, + applied_filters: Optional[List[str]] = None, removed_filters: Optional[List[str]] = None, dialect: Optional[Dialect] = None, ): self.extra_cache_keys = extra_cache_keys + self.applied_filters = applied_filters if applied_filters is not None else [] self.removed_filters = removed_filters if removed_filters is not None else [] self.dialect = dialect @@ -323,6 +325,9 @@ def get_filters(self, column: str, remove_filter: bool = False) -> List[Filter]: if remove_filter: if column not in self.removed_filters: self.removed_filters.append(column) + if column not in self.applied_filters: + self.applied_filters.append(column) + if op in ( FilterOperator.IN.value, FilterOperator.NOT_IN.value, @@ -408,6 +413,7 @@ def __init__( table: Optional["SqlaTable"] = None, extra_cache_keys: Optional[List[Any]] = None, removed_filters: Optional[List[str]] = None, + applied_filters: Optional[List[str]] = None, **kwargs: Any, ) -> None: self._database = database @@ -418,6 +424,7 @@ def __init__( elif table: self._schema = table.schema self._extra_cache_keys = extra_cache_keys + self._applied_filters = applied_filters self._removed_filters = removed_filters self._context: Dict[str, Any] = {} self._env = SandboxedEnvironment(undefined=DebugUndefined) @@ -446,6 +453,7 @@ def set_context(self, **kwargs: Any) -> None: super().set_context(**kwargs) extra_cache = ExtraCache( extra_cache_keys=self._extra_cache_keys, + applied_filters=self._applied_filters, removed_filters=self._removed_filters, dialect=self._database.get_dialect(), ) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 30d5ab9696475..1875d247dc0a3 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -442,6 +442,7 @@ def __init__( # pylint: disable=too-many-arguments df: pd.DataFrame, query: str, duration: timedelta, + applied_template_filters: Optional[List[str]] = None, status: str = QueryStatus.SUCCESS, error_message: Optional[str] = None, errors: Optional[List[Dict[str, Any]]] = None, @@ -449,6 +450,7 @@ def __init__( # pylint: disable=too-many-arguments self.df = df self.query = query self.duration = duration + self.applied_template_filters = applied_template_filters or [] self.status = status self.error_message = error_message self.errors = errors or [] diff --git a/superset/viz.py b/superset/viz.py index 5e22114765108..ecf4f63c6a20e 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -102,11 +102,6 @@ "size", ] -# This regex is to get user defined filter column name, which is the first param in the -# filter_values function. See the definition of filter_values template: -# https://github.com/apache/superset/blob/24ad6063d736c1f38ad6f962e586b9b1a21946af/superset/jinja_context.py#L63 -FILTER_VALUES_REGEX = re.compile(r"filter_values\(['\"](\w+)['\"]\,") - class BaseViz: # pylint: disable=too-many-public-methods @@ -143,6 +138,7 @@ def __init__( self.status: Optional[str] = None self.error_msg = "" self.results: Optional[QueryResult] = None + self.applied_template_filters: List[str] = [] self.errors: List[Dict[str, Any]] = [] self.force = force self._force_cached = force_cached @@ -270,6 +266,7 @@ def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame: # The datasource here can be different backend but the interface is common self.results = self.datasource.query(query_obj) + self.applied_template_filters = self.results.applied_template_filters or [] self.query = self.results.query self.status = self.results.status self.errors = self.results.errors @@ -459,14 +456,7 @@ def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload filters = self.form_data.get("filters", []) filter_columns = [flt.get("col") for flt in filters] columns = set(self.datasource.column_names) - filter_values_columns = [] - - # if using virtual datasource, check filter_values - if self.datasource.sql: - filter_values_columns = ( - re.findall(FILTER_VALUES_REGEX, self.datasource.sql) - ) or [] - + applied_template_filters = self.applied_template_filters or [] applied_time_extras = self.form_data.get("applied_time_extras", {}) applied_time_columns, rejected_time_columns = utils.get_time_filter_status( self.datasource, applied_time_extras @@ -474,18 +464,18 @@ def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload payload["applied_filters"] = [ {"column": col} for col in filter_columns - if col in columns or col in filter_values_columns + if col in columns or col in applied_template_filters ] + applied_time_columns payload["rejected_filters"] = [ {"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, "column": col} for col in filter_columns - if col not in columns and col not in filter_values_columns + if col not in columns and col not in applied_template_filters ] + rejected_time_columns if df is not None: payload["colnames"] = list(df.columns) return payload - def get_df_payload( + def get_df_payload( # pylint: disable=too-many-statements self, query_obj: Optional[QueryObjectDict] = None, **kwargs: Any ) -> Dict[str, Any]: """Handles caching around the df payload retrieval""" @@ -504,6 +494,9 @@ def get_df_payload( try: df = cache_value["df"] self.query = cache_value["query"] + self.applied_template_filters = cache_value.get( + "applied_template_filters", [] + ) self.status = QueryStatus.SUCCESS is_loaded = True stats_logger.incr("loaded_from_cache") diff --git a/tests/integration_tests/jinja_context_tests.py b/tests/integration_tests/jinja_context_tests.py index b82adfa05f9dc..31f877740bb0c 100644 --- a/tests/integration_tests/jinja_context_tests.py +++ b/tests/integration_tests/jinja_context_tests.py @@ -74,6 +74,7 @@ def test_filter_values_adhoc_filters(self) -> None: ): cache = ExtraCache() self.assertEqual(cache.filter_values("name"), ["foo"]) + self.assertEqual(cache.applied_filters, ["name"]) with app.test_request_context( data={ @@ -94,6 +95,7 @@ def test_filter_values_adhoc_filters(self) -> None: ): cache = ExtraCache() self.assertEqual(cache.filter_values("name"), ["foo", "bar"]) + self.assertEqual(cache.applied_filters, ["name"]) def test_get_filters_adhoc_filters(self) -> None: with app.test_request_context( @@ -118,6 +120,7 @@ def test_get_filters_adhoc_filters(self) -> None: cache.get_filters("name"), [{"op": "IN", "col": "name", "val": ["foo"]}] ) self.assertEqual(cache.removed_filters, list()) + self.assertEqual(cache.applied_filters, ["name"]) with app.test_request_context( data={ @@ -166,6 +169,7 @@ def test_get_filters_adhoc_filters(self) -> None: [{"op": "IN", "col": "name", "val": ["foo", "bar"]}], ) self.assertEqual(cache.removed_filters, ["name"]) + self.assertEqual(cache.applied_filters, ["name"]) def test_filter_values_extra_filters(self) -> None: with app.test_request_context( @@ -177,6 +181,7 @@ def test_filter_values_extra_filters(self) -> None: ): cache = ExtraCache() self.assertEqual(cache.filter_values("name"), ["foo"]) + self.assertEqual(cache.applied_filters, ["name"]) def test_url_param_default(self) -> None: with app.test_request_context():