Skip to content

Commit

Permalink
fix(explore): strip semicolons in virtual table SQL (#13801)
Browse files Browse the repository at this point in the history
* add method to strip semicolon

* address comments

* test the test

* Update tests/sqla_models_tests.py

Co-authored-by: Jesse Yang <jesse.yang@airbnb.com>

* Update tests/sqla_models_tests.py

Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>

* fix test

* add suggestion

* fix trailing space

* remove logger

* fix unit test

Co-authored-by: Jesse Yang <jesse.yang@airbnb.com>
Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 6, 2021
1 parent c0888dc commit 34991f5
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
4 changes: 2 additions & 2 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,6 @@ def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
logger.info(sql)
sql = sqlparse.format(sql, reindent=True)
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
Expand Down Expand Up @@ -818,6 +817,7 @@ def get_rendered_sql(
"""
Render sql with template engine (Jinja).
"""

sql = self.sql
if template_processor:
try:
Expand All @@ -829,7 +829,7 @@ def get_rendered_sql(
msg=ex.message,
)
)
sql = sqlparse.format(sql, strip_comments=True)
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
if len(sqlparse.split(sql)) > 1:
Expand Down
22 changes: 22 additions & 0 deletions tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,28 @@ def test_incorrect_jinja_syntax_raises_correct_exception(self):
with pytest.raises(QueryObjectValidationError):
table.get_sqla_query(**query_obj)

def test_query_format_strip_trailing_semicolon(self):
query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["user"],
"metrics": [],
"is_timeseries": False,
"filter": [],
"extras": {},
}

# Table with Jinja callable.
table = SqlaTable(
table_name="test_table",
sql="SELECT * from test_table;",
database=get_example_database(),
)
sqlaq = table.get_sqla_query(**query_obj)
sql = table.database.compile_sqla_query(sqlaq.sqla_query)
assert sql[-1] != ";"

def test_multiple_sql_statements_raises_exception(self):
base_query_obj = {
"granularity": None,
Expand Down

0 comments on commit 34991f5

Please sign in to comment.