Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sqllab] force limit queries only when there is no existing limit #5023

Merged
merged 3 commits into from
May 31, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,32 @@ def apply_limit_to_sql(cls, sql, limit, database):
)
return database.compile_sqla_query(qry)
elif LimitMethod.FORCE_LIMIT:
no_limit = re.sub(r"""
sql_without_limit = cls.get_query_without_limit(sql)
return '{sql_without_limit} LIMIT {limit}'.format(**locals())
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic can generate SQL like:

SELECT id,
       name
FROM main.ab_permission
LIMIT 100
OFFSET 0 LIMIT 5000

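which is wrong: the existing `LIMIT 100 OFFSET 0` is not at the very end of the statement, so it is not stripped, and a second `LIMIT` is appended after it.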

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which engine is this on?

return sql

@classmethod
def get_limit_from_sql(cls, sql):
limit_pattern = re.compile(r"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize that this logic existed previously but shouldn't we use something like sqlparse to obtain the limit rather than using a regex?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I explored the sqlparse option, but there was no nice way to just get the limit without recursively parsing through the query.

(?ix) # case insensitive, verbose
\s+ # whitespace
LIMIT\s+(\d+) # LIMIT $ROWS
;? # optional semi-colon
(\s|;)*$ # remove trailing spaces tabs or semicolons
""")
matches = limit_pattern.findall(sql)
if matches:
return int(matches[0][0])

@classmethod
def get_query_without_limit(cls, sql):
    """Return ``sql`` with a trailing ``LIMIT <n>`` clause removed.

    Only a clause at the very end of the statement is stripped, together
    with any whitespace and semicolons that follow it. A ``LIMIT`` that is
    followed by anything else (for example the closing parenthesis of a
    subquery) does not match, and ``sql`` is returned unchanged.

    :param sql: the SQL statement as a string
    :return: the statement without its trailing ``LIMIT`` clause
    """
    # The flags are passed as an argument rather than as an inline
    # ``(?ix)`` group: the pattern literal starts with a newline, which put
    # the inline group "not at the start of the expression" (deprecated
    # since Python 3.6, rejected by newer interpreters).
    return re.sub(
        r"""
        \s+            # whitespace
        LIMIT\s+\d+    # LIMIT $ROWS
        [\s;]*$        # trailing spaces, tabs or semicolons
        """,
        '',
        sql,
        flags=re.IGNORECASE | re.VERBOSE,
    )
return '{no_limit} LIMIT {limit}'.format(**locals())
return sql

@staticmethod
def csv_to_df(**kwargs):
Expand Down
6 changes: 4 additions & 2 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def handle_error(msg):
# Limit enforced only for retrieving the data, not for the CTA queries.
superset_query = SupersetQuery(rendered_query)
executed_sql = superset_query.stripped()
SQL_MAX_ROWS = app.config.get('SQL_MAX_ROW')
if not superset_query.is_select() and not database.allow_dml:
return handle_error(
'Only `SELECT` statements are allowed against this database')
Expand All @@ -185,9 +186,10 @@ def handle_error(msg):
query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S'))
executed_sql = superset_query.as_create_table(query.tmp_table_name)
query.select_as_cta_used = True
elif (query.limit and superset_query.is_select()):
if (superset_query.is_select() and SQL_MAX_ROWS and
(not query.limit or query.limit > SQL_MAX_ROWS)):
query.limit = SQL_MAX_ROWS
executed_sql = database.apply_limit_to_sql(executed_sql, query.limit)
query.limit_used = True

# Hook to allow environment-specific mutation (usually comments) to the SQL
SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR')
Expand Down
2 changes: 1 addition & 1 deletion superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2394,7 +2394,7 @@ def sql_json(self):

query = Query(
database_id=int(database_id),
limit=int(app.config.get('SQL_MAX_ROW', None)),
limit=mydb.db_engine_spec.get_limit_from_sql(sql),
sql=sql,
schema=schema,
select_as_cta=request.form.get('select_as_cta') == 'true',
Expand Down
30 changes: 28 additions & 2 deletions tests/celery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,17 +196,43 @@ def test_run_async_query(self):
self.assertEqual([{'name': 'Admin'}], df.to_dict(orient='records'))
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue('FROM tmp_async_1' in query.select_sql)
self.assertTrue('LIMIT 666' in query.select_sql)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why were these checks removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

self.assertEqual(
'CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role '
"WHERE name='Admin'", query.executed_sql)
"WHERE name='Admin' LIMIT 666", query.executed_sql)
self.assertEqual(sql_where, query.sql)
self.assertEqual(0, query.rows)
self.assertEqual(666, query.limit)
self.assertEqual(False, query.limit_used)
self.assertEqual(True, query.select_as_cta)
self.assertEqual(True, query.select_as_cta_used)

def test_run_async_query_with_lower_limit(self):
    """Run an async CTA query whose SQL already ends in ``LIMIT 1``.

    Checks that the query succeeds, that the executed statement is the
    ``CREATE TABLE ... AS`` form of the submitted SQL with its own
    ``LIMIT 1`` intact, and that the stored query limit is 1.
    """
    main_db = self.get_main_database(db.session)
    eng = main_db.get_sqla_engine()
    sql_where = "SELECT name FROM ab_role WHERE name='Alpha' LIMIT 1"
    # ``async`` is a reserved keyword from Python 3.7 on, so writing it as a
    # literal keyword argument is a SyntaxError there; unpacking a dict
    # passes the same argument on every Python version.
    result = self.run_sql(
        main_db.id, sql_where, '5', tmp_table='tmp_async_2', cta='true',
        **{'async': 'true'})
    self.assertIn(
        result['query']['state'],
        (QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS))

    # Give the asynchronous worker time to finish before reading the result.
    time.sleep(1)

    query = self.get_query_by_id(result['query']['serverId'])
    df = pd.read_sql_query(query.select_sql, con=eng)
    self.assertEqual(QueryStatus.SUCCESS, query.status)
    self.assertEqual([{'name': 'Alpha'}], df.to_dict(orient='records'))
    self.assertTrue('FROM tmp_async_2' in query.select_sql)
    self.assertEqual(
        'CREATE TABLE tmp_async_2 AS \nSELECT name FROM ab_role '
        "WHERE name='Alpha' LIMIT 1", query.executed_sql)
    self.assertEqual(sql_where, query.sql)
    self.assertEqual(0, query.rows)
    self.assertEqual(1, query.limit)
    self.assertEqual(True, query.select_as_cta)
    self.assertEqual(True, query.select_as_cta_used)

@staticmethod
def de_unicode_dict(d):
def str_if_basestring(o):
Expand Down
13 changes: 13 additions & 0 deletions tests/db_engine_specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,19 @@ def sql_limit_regex(
limited = engine_spec_class.apply_limit_to_sql(sql, limit, main)
self.assertEquals(expected_sql, limited)

def test_extract_limit_from_query(self, engine_spec_class=MySQLEngineSpec):
    """Check ``get_limit_from_sql`` against queries with and without a LIMIT.

    Each case pairs a statement with the limit expected from it; ``None``
    is expected when the statement has no LIMIT at its very end.
    """
    cases = (
        # no LIMIT at all
        ('select * from table', None),
        # plain trailing LIMIT
        ('select * from mytable limit 10', 10),
        # inner and outer LIMIT: the outer one is expected
        ('select * from (select * from my_subquery limit 10) '
         'where col=1 limit 20', 20),
        # LIMIT only inside the subquery
        ('select * from (select * from my_subquery limit 10);', None),
        # outer LIMIT followed by a semicolon
        ('select * from (select * from my_subquery limit 10) '
         'where col=1 limit 20;', 20),
    )
    for statement, expected_limit in cases:
        self.assertEqual(
            engine_spec_class.get_limit_from_sql(statement), expected_limit)

def test_wrapped_query(self):
self.sql_limit_regex(
'SELECT * FROM a',
Expand Down