Skip to content

Commit

Permalink
fix(filter-indicator): show filters handled by jinja as applied (apac…
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro authored Oct 18, 2021
1 parent 565ee23 commit d7834f1
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 20 deletions.
7 changes: 5 additions & 2 deletions superset/common/query_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def _get_full(
datasource = _get_datasource(query_context, query_obj)
result_type = query_obj.result_type or query_context.result_type
payload = query_context.get_df_payload(query_obj, force_cached=force_cached)
applied_template_filters = payload.get("applied_template_filters", [])
df = payload["df"]
status = payload["status"]
if status != QueryStatus.FAILED:
Expand All @@ -113,12 +114,14 @@ def _get_full(
datasource, query_obj.applied_time_extras
)
payload["applied_filters"] = [
{"column": col} for col in filter_columns if col in columns
{"column": col}
for col in filter_columns
if col in columns or col in applied_template_filters
] + applied_time_columns
payload["rejected_filters"] = [
{"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, "column": col}
for col in filter_columns
if col not in columns
if col not in columns and col not in applied_template_filters
] + rejected_time_columns

if result_type == ChartDataResultType.RESULTS and status != QueryStatus.FAILED:
Expand Down
1 change: 1 addition & 0 deletions superset/common/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def get_df_payload(
"cached_dttm": cache.cache_dttm,
"cache_timeout": self.cache_timeout,
"df": cache.df,
"applied_template_filters": cache.applied_template_filters,
"annotation_data": cache.annotation_data,
"error": cache.error_message,
"is_cached": cache.is_cached,
Expand Down
9 changes: 8 additions & 1 deletion superset/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from flask_caching import Cache
from pandas import DataFrame
Expand Down Expand Up @@ -51,6 +51,7 @@ def __init__(
df: DataFrame = DataFrame(),
query: str = "",
annotation_data: Optional[Dict[str, Any]] = None,
applied_template_filters: Optional[List[str]] = None,
status: Optional[str] = None,
error_message: Optional[str] = None,
is_loaded: bool = False,
Expand All @@ -62,6 +63,7 @@ def __init__(
self.df = df
self.query = query
self.annotation_data = {} if annotation_data is None else annotation_data
self.applied_template_filters = applied_template_filters or []
self.status = status
self.error_message = error_message

Expand All @@ -88,6 +90,7 @@ def set_query_result(
try:
self.status = query_result.status
self.query = query_result.query
self.applied_template_filters = query_result.applied_template_filters
self.error_message = query_result.error_message
self.df = query_result.df
self.annotation_data = {} if annotation_data is None else annotation_data
Expand All @@ -101,6 +104,7 @@ def set_query_result(
value = {
"df": self.df,
"query": self.query,
"applied_template_filters": self.applied_template_filters,
"annotation_data": self.annotation_data,
}
if self.is_loaded and key and self.status != QueryStatus.FAILED:
Expand Down Expand Up @@ -141,6 +145,9 @@ def get(
query_cache.df = cache_value["df"]
query_cache.query = cache_value["query"]
query_cache.annotation_data = cache_value.get("annotation_data", {})
query_cache.applied_template_filters = cache_value.get(
"applied_template_filters", []
)
query_cache.status = QueryStatus.SUCCESS
query_cache.is_loaded = True
query_cache.is_cached = cache_value is not None
Expand Down
11 changes: 10 additions & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,15 @@


class SqlaQuery(NamedTuple):
applied_template_filters: List[str]
extra_cache_keys: List[Any]
labels_expected: List[str]
prequeries: List[str]
sqla_query: Select


class QueryStringExtended(NamedTuple):
applied_template_filters: Optional[List[str]]
labels_expected: List[str]
prequeries: List[str]
sql: str
Expand Down Expand Up @@ -755,7 +757,10 @@ def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExten
sql = sqlparse.format(sql, reindent=True)
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries
applied_template_filters=sqlaq.applied_template_filters,
labels_expected=sqlaq.labels_expected,
prequeries=sqlaq.prequeries,
sql=sql,
)

def get_query_str(self, query_obj: QueryObjectDict) -> str:
Expand Down Expand Up @@ -978,7 +983,9 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
extra_cache_keys: List[Any] = []
template_kwargs["extra_cache_keys"] = extra_cache_keys
removed_filters: List[str] = []
applied_template_filters: List[str] = []
template_kwargs["removed_filters"] = removed_filters
template_kwargs["applied_filters"] = applied_template_filters
template_processor = self.get_template_processor(**template_kwargs)
db_engine_spec = self.db_engine_spec
prequeries: List[str] = []
Expand Down Expand Up @@ -1394,6 +1401,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
labels_expected = [label]

return SqlaQuery(
applied_template_filters=applied_template_filters,
extra_cache_keys=extra_cache_keys,
labels_expected=labels_expected,
sqla_query=qry,
Expand Down Expand Up @@ -1491,6 +1499,7 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:
error_message = utils.error_msg_from_exception(ex)

return QueryResult(
applied_template_filters=query_str_ext.applied_template_filters,
status=status,
df=df,
duration=datetime.now() - qry_start_dttm,
Expand Down
8 changes: 8 additions & 0 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,12 @@ class ExtraCache:
def __init__(
self,
extra_cache_keys: Optional[List[Any]] = None,
applied_filters: Optional[List[str]] = None,
removed_filters: Optional[List[str]] = None,
dialect: Optional[Dialect] = None,
):
self.extra_cache_keys = extra_cache_keys
self.applied_filters = applied_filters if applied_filters is not None else []
self.removed_filters = removed_filters if removed_filters is not None else []
self.dialect = dialect

Expand Down Expand Up @@ -323,6 +325,9 @@ def get_filters(self, column: str, remove_filter: bool = False) -> List[Filter]:
if remove_filter:
if column not in self.removed_filters:
self.removed_filters.append(column)
if column not in self.applied_filters:
self.applied_filters.append(column)

if op in (
FilterOperator.IN.value,
FilterOperator.NOT_IN.value,
Expand Down Expand Up @@ -408,6 +413,7 @@ def __init__(
table: Optional["SqlaTable"] = None,
extra_cache_keys: Optional[List[Any]] = None,
removed_filters: Optional[List[str]] = None,
applied_filters: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
self._database = database
Expand All @@ -418,6 +424,7 @@ def __init__(
elif table:
self._schema = table.schema
self._extra_cache_keys = extra_cache_keys
self._applied_filters = applied_filters
self._removed_filters = removed_filters
self._context: Dict[str, Any] = {}
self._env = SandboxedEnvironment(undefined=DebugUndefined)
Expand Down Expand Up @@ -446,6 +453,7 @@ def set_context(self, **kwargs: Any) -> None:
super().set_context(**kwargs)
extra_cache = ExtraCache(
extra_cache_keys=self._extra_cache_keys,
applied_filters=self._applied_filters,
removed_filters=self._removed_filters,
dialect=self._database.get_dialect(),
)
Expand Down
2 changes: 2 additions & 0 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,13 +442,15 @@ def __init__( # pylint: disable=too-many-arguments
df: pd.DataFrame,
query: str,
duration: timedelta,
applied_template_filters: Optional[List[str]] = None,
status: str = QueryStatus.SUCCESS,
error_message: Optional[str] = None,
errors: Optional[List[Dict[str, Any]]] = None,
) -> None:
self.df = df
self.query = query
self.duration = duration
self.applied_template_filters = applied_template_filters or []
self.status = status
self.error_message = error_message
self.errors = errors or []
Expand Down
25 changes: 9 additions & 16 deletions superset/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,6 @@
"size",
]

# This regex is to get user defined filter column name, which is the first param in the
# filter_values function. See the definition of filter_values template:
# https://github.com/apache/superset/blob/24ad6063d736c1f38ad6f962e586b9b1a21946af/superset/jinja_context.py#L63
FILTER_VALUES_REGEX = re.compile(r"filter_values\(['\"](\w+)['\"]\,")


class BaseViz: # pylint: disable=too-many-public-methods

Expand Down Expand Up @@ -143,6 +138,7 @@ def __init__(
self.status: Optional[str] = None
self.error_msg = ""
self.results: Optional[QueryResult] = None
self.applied_template_filters: List[str] = []
self.errors: List[Dict[str, Any]] = []
self.force = force
self._force_cached = force_cached
Expand Down Expand Up @@ -270,6 +266,7 @@ def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame:

# The datasource here can be different backend but the interface is common
self.results = self.datasource.query(query_obj)
self.applied_template_filters = self.results.applied_template_filters or []
self.query = self.results.query
self.status = self.results.status
self.errors = self.results.errors
Expand Down Expand Up @@ -459,33 +456,26 @@ def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload
filters = self.form_data.get("filters", [])
filter_columns = [flt.get("col") for flt in filters]
columns = set(self.datasource.column_names)
filter_values_columns = []

# if using virtual datasource, check filter_values
if self.datasource.sql:
filter_values_columns = (
re.findall(FILTER_VALUES_REGEX, self.datasource.sql)
) or []

applied_template_filters = self.applied_template_filters or []
applied_time_extras = self.form_data.get("applied_time_extras", {})
applied_time_columns, rejected_time_columns = utils.get_time_filter_status(
self.datasource, applied_time_extras
)
payload["applied_filters"] = [
{"column": col}
for col in filter_columns
if col in columns or col in filter_values_columns
if col in columns or col in applied_template_filters
] + applied_time_columns
payload["rejected_filters"] = [
{"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, "column": col}
for col in filter_columns
if col not in columns and col not in filter_values_columns
if col not in columns and col not in applied_template_filters
] + rejected_time_columns
if df is not None:
payload["colnames"] = list(df.columns)
return payload

def get_df_payload(
def get_df_payload( # pylint: disable=too-many-statements
self, query_obj: Optional[QueryObjectDict] = None, **kwargs: Any
) -> Dict[str, Any]:
"""Handles caching around the df payload retrieval"""
Expand All @@ -504,6 +494,9 @@ def get_df_payload(
try:
df = cache_value["df"]
self.query = cache_value["query"]
self.applied_template_filters = cache_value.get(
"applied_template_filters", []
)
self.status = QueryStatus.SUCCESS
is_loaded = True
stats_logger.incr("loaded_from_cache")
Expand Down
5 changes: 5 additions & 0 deletions tests/integration_tests/jinja_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_filter_values_adhoc_filters(self) -> None:
):
cache = ExtraCache()
self.assertEqual(cache.filter_values("name"), ["foo"])
self.assertEqual(cache.applied_filters, ["name"])

with app.test_request_context(
data={
Expand All @@ -94,6 +95,7 @@ def test_filter_values_adhoc_filters(self) -> None:
):
cache = ExtraCache()
self.assertEqual(cache.filter_values("name"), ["foo", "bar"])
self.assertEqual(cache.applied_filters, ["name"])

def test_get_filters_adhoc_filters(self) -> None:
with app.test_request_context(
Expand All @@ -118,6 +120,7 @@ def test_get_filters_adhoc_filters(self) -> None:
cache.get_filters("name"), [{"op": "IN", "col": "name", "val": ["foo"]}]
)
self.assertEqual(cache.removed_filters, list())
self.assertEqual(cache.applied_filters, ["name"])

with app.test_request_context(
data={
Expand Down Expand Up @@ -166,6 +169,7 @@ def test_get_filters_adhoc_filters(self) -> None:
[{"op": "IN", "col": "name", "val": ["foo", "bar"]}],
)
self.assertEqual(cache.removed_filters, ["name"])
self.assertEqual(cache.applied_filters, ["name"])

def test_filter_values_extra_filters(self) -> None:
with app.test_request_context(
Expand All @@ -177,6 +181,7 @@ def test_filter_values_extra_filters(self) -> None:
):
cache = ExtraCache()
self.assertEqual(cache.filter_values("name"), ["foo"])
self.assertEqual(cache.applied_filters, ["name"])

def test_url_param_default(self) -> None:
with app.test_request_context():
Expand Down

0 comments on commit d7834f1

Please sign in to comment.