Skip to content

Commit

Permalink
chore(sqla): assert query is single read-only statement (apache#11236)
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro authored and dpgaspar committed Oct 12, 2020
1 parent 9aba607 commit 76f6e85
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
9 changes: 9 additions & 0 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from superset.models.annotations import Annotation
from superset.models.core import Database
from superset.models.helpers import AuditMixinNullable, QueryResult
from superset.sql_parse import ParsedQuery
from superset.typing import Metric, QueryObjectDict
from superset.utils import core as utils, import_datasource

Expand Down Expand Up @@ -755,6 +756,14 @@ def get_from_clause(
)

from_sql = sqlparse.format(from_sql, strip_comments=True)
if len(sqlparse.split(from_sql)) > 1:
raise QueryObjectValidationError(
_("Virtual dataset query cannot consist of multiple statements")
)
if not ParsedQuery(from_sql).is_readonly():
raise QueryObjectValidationError(
_("Virtual dataset query must be read-only")
)
return TextAsFrom(sa.text(from_sql), []).alias("expr_qry")
return self.get_sqla_table()

Expand Down
42 changes: 42 additions & 0 deletions tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,45 @@ def test_incorrect_jinja_syntax_raises_correct_exception(self):
if get_example_database().backend != "presto":
with pytest.raises(QueryObjectValidationError):
table.get_sqla_query(**query_obj)

def test_multiple_sql_statements_raises_exception(self):
base_query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["grp"],
"metrics": [],
"is_timeseries": False,
"filter": [],
}

table = SqlaTable(
table_name="test_has_extra_cache_keys_table",
sql="SELECT 'foo' as grp, 1 as num; SELECT 'bar' as grp, 2 as num",
database=get_example_database(),
)

query_obj = dict(**base_query_obj, extras={})
with pytest.raises(QueryObjectValidationError):
table.get_sqla_query(**query_obj)

def test_dml_statement_raises_exception(self):
base_query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["grp"],
"metrics": [],
"is_timeseries": False,
"filter": [],
}

table = SqlaTable(
table_name="test_has_extra_cache_keys_table",
sql="DELETE FROM foo",
database=get_example_database(),
)

query_obj = dict(**base_query_obj, extras={})
with pytest.raises(QueryObjectValidationError):
table.get_sqla_query(**query_obj)

0 comments on commit 76f6e85

Please sign in to comment.