Skip to content

Commit

Permalink
[sqla] Adding check for invalid filter columns (#7888)
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley authored Jul 18, 2019
1 parent 174a48a commit 2b3e7fe
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 38 deletions.
79 changes: 41 additions & 38 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ def get_sqla_query( # sqla
elif m in metrics_dict:
metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
else:
raise Exception(_("Metric '{}' is not valid".format(m)))
raise Exception(_("Metric '%(metric)s' does not exist", metric=m))
if metrics_exprs:
main_metric_expr = metrics_exprs[0]
else:
Expand Down Expand Up @@ -721,43 +721,46 @@ def get_sqla_query( # sqla
if not all([flt.get(s) for s in ["col", "op"]]):
continue
col = flt["col"]

if col not in cols:
raise Exception(_("Column '%(column)s' does not exist", column=col))

op = flt["op"]
col_obj = cols.get(col)
if col_obj:
is_list_target = op in ("in", "not in")
eq = self.filter_values_handler(
flt.get("val"),
target_column_is_numeric=col_obj.is_num,
is_list_target=is_list_target,
)
if op in ("in", "not in"):
cond = col_obj.get_sqla_col().in_(eq)
if "<NULL>" in eq:
cond = or_(cond, col_obj.get_sqla_col() == None) # noqa
if op == "not in":
cond = ~cond
where_clause_and.append(cond)
else:
if col_obj.is_num:
eq = utils.string_to_num(flt["val"])
if op == "==":
where_clause_and.append(col_obj.get_sqla_col() == eq)
elif op == "!=":
where_clause_and.append(col_obj.get_sqla_col() != eq)
elif op == ">":
where_clause_and.append(col_obj.get_sqla_col() > eq)
elif op == "<":
where_clause_and.append(col_obj.get_sqla_col() < eq)
elif op == ">=":
where_clause_and.append(col_obj.get_sqla_col() >= eq)
elif op == "<=":
where_clause_and.append(col_obj.get_sqla_col() <= eq)
elif op == "LIKE":
where_clause_and.append(col_obj.get_sqla_col().like(eq))
elif op == "IS NULL":
where_clause_and.append(col_obj.get_sqla_col() == None) # noqa
elif op == "IS NOT NULL":
where_clause_and.append(col_obj.get_sqla_col() != None) # noqa
col_obj = cols[col]
is_list_target = op in ("in", "not in")
eq = self.filter_values_handler(
flt.get("val"),
target_column_is_numeric=col_obj.is_num,
is_list_target=is_list_target,
)
if op in ("in", "not in"):
cond = col_obj.get_sqla_col().in_(eq)
if "<NULL>" in eq:
cond = or_(cond, col_obj.get_sqla_col() == None) # noqa
if op == "not in":
cond = ~cond
where_clause_and.append(cond)
else:
if col_obj.is_num:
eq = utils.string_to_num(flt["val"])
if op == "==":
where_clause_and.append(col_obj.get_sqla_col() == eq)
elif op == "!=":
where_clause_and.append(col_obj.get_sqla_col() != eq)
elif op == ">":
where_clause_and.append(col_obj.get_sqla_col() > eq)
elif op == "<":
where_clause_and.append(col_obj.get_sqla_col() < eq)
elif op == ">=":
where_clause_and.append(col_obj.get_sqla_col() >= eq)
elif op == "<=":
where_clause_and.append(col_obj.get_sqla_col() <= eq)
elif op == "LIKE":
where_clause_and.append(col_obj.get_sqla_col().like(eq))
elif op == "IS NULL":
where_clause_and.append(col_obj.get_sqla_col() == None) # noqa
elif op == "IS NOT NULL":
where_clause_and.append(col_obj.get_sqla_col() != None) # noqa
if extras:
where = extras.get("where")
if where:
Expand Down Expand Up @@ -877,7 +880,7 @@ def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict, cols):
ob = timeseries_limit_metric.get_sqla_col()
else:
raise Exception(
_("Metric '{}' is not valid".format(timeseries_limit_metric))
_("Metric '%(metric)s' does not exist", metric=timeseries_limit_metric)
)

return ob
Expand Down
42 changes: 42 additions & 0 deletions tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,45 @@ def mutator(*args):
self.assertIn("--COMMENT", sql)

app.config["SQL_QUERY_MUTATOR"] = None

def test_query_with_non_existent_metrics(self):
tbl = self.get_table_by_name("birth_names")

query_obj = dict(
groupby=[],
metrics=["invalid"],
filter=[],
is_timeseries=False,
columns=["name"],
granularity=None,
from_dttm=None,
to_dttm=None,
is_prequery=False,
extras={},
)

with self.assertRaises(Exception) as context:
tbl.get_query_str(query_obj)

self.assertTrue("Metric 'invalid' does not exist", context.exception)

def test_query_with_non_existent_filter_columns(self):
tbl = self.get_table_by_name("birth_names")

query_obj = dict(
groupby=[],
metrics=["count"],
filter=[{"col": "invalid", "op": "==", "val": "male"}],
is_timeseries=False,
columns=["name"],
granularity=None,
from_dttm=None,
to_dttm=None,
is_prequery=False,
extras={},
)

with self.assertRaises(Exception) as context:
tbl.get_query_str(query_obj)

self.assertTrue("Column 'invalid' does not exist", context.exception)

0 comments on commit 2b3e7fe

Please sign in to comment.