Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sql lab] a better approach at limiting queries #4947

Merged
merged 5 commits into from
May 14, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.sql import text
from sqlalchemy.sql.expression import TextAsFrom
import sqlparse
import unicodecsv
from werkzeug.utils import secure_filename
Expand All @@ -55,6 +56,7 @@ class LimitMethod(object):
"""Enum the ways that limits can be applied"""
FETCH_MANY = 'fetch_many'
WRAP_SQL = 'wrap_sql'
FORCE_LIMIT = 'force_limit'


class BaseEngineSpec(object):
Expand All @@ -65,7 +67,7 @@ class BaseEngineSpec(object):
cursor_execute_kwargs = {}
time_grains = tuple()
time_groupby_inline = False
limit_method = LimitMethod.FETCH_MANY
limit_method = LimitMethod.FORCE_LIMIT
time_secondary_columns = False
inner_joins = True

Expand All @@ -88,6 +90,23 @@ def extra_table_metadata(cls, database, table_name, schema_name):
"""Returns engine-specific table metadata"""
return {}

@classmethod
def apply_limit_to_sql(cls, sql, limit, database):
    """Alters the SQL statement to apply a LIMIT clause

    The strategy is selected by ``cls.limit_method``:

    * ``LimitMethod.WRAP_SQL``: the statement is wrapped as a subquery
      aliased ``inner_qry`` and the limit is applied to the outer
      ``SELECT *``; the result is compiled with
      ``database.compile_sqla_query``.
    * ``LimitMethod.FORCE_LIMIT``: a trailing ``LIMIT <n>`` clause
      (case-insensitive, plus any trailing whitespace or semicolons)
      is stripped from the statement so a fresh limit can be applied.
    * any other method: the statement is not rewritten.

    :param sql: string, the SQL statement to limit
    :param limit: int, the maximum number of rows
    :param database: object exposing ``compile_sqla_query``; only used
        by the ``WRAP_SQL`` strategy
    :return: string, the SQL statement with the limit applied
    """
    if cls.limit_method == LimitMethod.WRAP_SQL:
        qry = (
            select('*')
            .select_from(
                TextAsFrom(text(sql), ['*']).alias('inner_qry'),
            )
            .limit(limit)
        )
        return database.compile_sqla_query(qry)
    elif cls.limit_method == LimitMethod.FORCE_LIMIT:
        # Compare against the engine's configured method. A bare
        # ``LimitMethod.FORCE_LIMIT`` is a non-empty string and therefore
        # always truthy, which made this branch run for every engine that
        # is not WRAP_SQL (including FETCH_MANY ones).
        #
        # Drop an existing trailing LIMIT clause, together with any
        # trailing whitespace/semicolons, anchored at the end of the
        # statement so LIMITs inside subqueries are left alone.
        no_limit = re.sub(r'(?i)\s+LIMIT\s+\d+;?(\s|;)*$', '', sql)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scary! :-O

BTW, you can use backreferences here to replace the LIMIT number in one pass (untested):

sql = re.sub(r'(?i)(.*\s+LIMIT\s+)\d+\s*;?\s*$', r'\1 {0}'.format(limit), sql)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I barely understand what I wrote here :/

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some people, when confronted with a problem, think "I know, I'll use regular expressions." Now they have two problems.

Yeah, I just realized my example does not work when there's no LIMIT clause. :-/

You know about verbose regexes? Eg:

no_limit = re.sub(r"""
    (?ix)        # case insensitive, verbose
    \s+          # whitespace
    LIMIT\s+\d+  # LIMIT $ROWS
    ;?           # optional semi-colon
    (\s|;)*$     # any number of trailing whitespace or semi-colons, til the end
""", '', sql)

Though in this case I think it's pretty straightforward. You might want to rstrip whitespace and semicolons from the end of sql; that would simplify the regex, no?

return '{no_limit} LIMIT {limit}'.format(**locals())
return sql

@staticmethod
def csv_to_df(**kwargs):
kwargs['filepath_or_buffer'] = \
Expand Down Expand Up @@ -337,7 +356,6 @@ def get_table_names(cls, schema, inspector):

class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = 'snowflake'

time_grains = (
Grain('Time Column', _('Time Column'), '{col}', None),
Grain('second', _('second'), "DATE_TRUNC('SECOND', {col})", 'PT1S'),
Expand All @@ -361,6 +379,7 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec):

class OracleEngineSpec(PostgresBaseEngineSpec):
engine = 'oracle'
limit_method = LimitMethod.WRAP_SQL

time_grains = (
Grain('Time Column', _('Time Column'), '{col}', None),
Expand All @@ -382,6 +401,7 @@ def convert_dttm(cls, target_type, dttm):

class Db2EngineSpec(BaseEngineSpec):
engine = 'ibm_db_sa'
limit_method = LimitMethod.WRAP_SQL
time_grains = (
Grain('Time Column', _('Time Column'), '{col}', None),
Grain('second', _('second'),
Expand Down Expand Up @@ -1106,6 +1126,7 @@ def get_configuration_for_impersonation(cls, uri, impersonate_user, username):
class MssqlEngineSpec(BaseEngineSpec):
engine = 'mssql'
epoch_to_dttm = "dateadd(S, {col}, '1970-01-01')"
limit_method = LimitMethod.WRAP_SQL

time_grains = (
Grain('Time Column', _('Time Column'), '{col}', None),
Expand Down Expand Up @@ -1310,7 +1331,6 @@ def get_schema_names(cls, inspector):
class DruidEngineSpec(BaseEngineSpec):
"""Engine spec for Druid.io"""
engine = 'druid'
limit_method = LimitMethod.FETCH_MANY
inner_joins = False


Expand Down
13 changes: 2 additions & 11 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,14 @@
import sqlalchemy as sqla
from sqlalchemy import (
Boolean, Column, create_engine, DateTime, ForeignKey, Integer,
MetaData, select, String, Table, Text,
MetaData, String, Table, Text,
)
from sqlalchemy.engine import url
from sqlalchemy.engine.url import make_url
from sqlalchemy.orm import relationship, subqueryload
from sqlalchemy.orm.session import make_transient
from sqlalchemy.pool import NullPool
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import text
from sqlalchemy.sql.expression import TextAsFrom
from sqlalchemy_utils import EncryptedType

from superset import app, db, db_engine_specs, security_manager, utils
Expand Down Expand Up @@ -722,14 +720,7 @@ def select_star(
indent=indent, latest_partition=latest_partition, cols=cols)

def wrap_sql_limit(self, sql, limit=1000):
qry = (
select('*')
.select_from(
TextAsFrom(text(sql), ['*'])
.alias('inner_qry'),
).limit(limit)
)
return self.compile_sqla_query(qry)
return self.db_engine_spec.apply_limit_to_sql(sql, limit, self)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better update the method name here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, rename wrap_sql_limit to something else.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

went with apply_limit_to_sql


def safe_sqlalchemy_uri(self):
return self.sqlalchemy_uri
Expand Down
4 changes: 1 addition & 3 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from sqlalchemy.pool import NullPool

from superset import app, dataframe, db, results_backend, security_manager, utils
from superset.db_engine_specs import LimitMethod
from superset.models.sql_lab import Query
from superset.sql_parse import SupersetQuery
from superset.utils import get_celery_app, QueryStatus
Expand Down Expand Up @@ -188,8 +187,7 @@ def handle_error(msg):
query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S'))
executed_sql = superset_query.as_create_table(query.tmp_table_name)
query.select_as_cta_used = True
elif (query.limit and superset_query.is_select() and
db_engine_spec.limit_method == LimitMethod.WRAP_SQL):
elif (query.limit and superset_query.is_select()):
executed_sql = database.wrap_sql_limit(executed_sql, query.limit)
query.limit_used = True

Expand Down
24 changes: 7 additions & 17 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}


# TODO: some sql_lab logic here.
class SupersetQuery(object):
def __init__(self, sql_statement):
self.sql = sql_statement
self._table_names = set()
self._alias_names = set()
# TODO: multistatement support

logging.info('Parsing with sqlparse statement {}'.format(self.sql))
self._parsed = sqlparse.parse(self.sql)
for statement in self._parsed:
Expand All @@ -36,11 +36,7 @@ def is_select(self):
return self._parsed[0].get_type() == 'SELECT'

def stripped(self):
sql = self.sql
if sql:
while sql[-1] in (' ', ';', '\n', '\t'):
sql = sql[:-1]
return sql
return self.sql.strip(' \t\n;')

@staticmethod
def __precedes_table_name(token_value):
Expand All @@ -65,13 +61,12 @@ def __is_result_operation(keyword):

@staticmethod
def __is_identifier(token):
return (
isinstance(token, IdentifierList) or isinstance(token, Identifier))
return isinstance(token, (IdentifierList, Identifier))

def __process_identifier(self, identifier):
# exclude subselects
if '(' not in '{}'.format(identifier):
self._table_names.add(SupersetQuery.__get_full_name(identifier))
self._table_names.add(self.__get_full_name(identifier))
return

# store aliases
Expand All @@ -94,11 +89,6 @@ def as_create_table(self, table_name, overwrite=False):
:param overwrite, boolean, table table_name will be dropped if true
:return: string, create table as query
"""
# TODO(bkyryliuk): enforce that all the columns have names.
# Presto requires it for the CTA operation.
# TODO(bkyryliuk): drop table if allowed, check the namespace and
# the permissions.
# TODO raise if multi-statement
exec_sql = ''
sql = self.stripped()
if overwrite:
Expand All @@ -117,15 +107,15 @@ def __extract_from_token(self, token):
self.__extract_from_token(item)

if item.ttype in Keyword:
if SupersetQuery.__precedes_table_name(item.value.upper()):
if self.__precedes_table_name(item.value.upper()):
table_name_preceding_token = True
continue

if not table_name_preceding_token:
continue

if item.ttype in Keyword:
if SupersetQuery.__is_result_operation(item.value):
if self.__is_result_operation(item.value):
table_name_preceding_token = False
continue
# FROM clause is over
Expand All @@ -136,5 +126,5 @@ def __extract_from_token(self, token):

if isinstance(item, IdentifierList):
for token in item.tokens:
if SupersetQuery.__is_identifier(token):
if self.__is_identifier(token):
self.__process_identifier(token)
32 changes: 0 additions & 32 deletions tests/celery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,38 +139,6 @@ def run_sql(self, db_id, sql, client_id, cta='false', tmp_table='tmp',
self.logout()
return json.loads(resp.data.decode('utf-8'))

def test_add_limit_to_the_query(self):
main_db = self.get_main_database(db.session)

select_query = 'SELECT * FROM outer_space;'
updated_select_query = main_db.wrap_sql_limit(select_query, 100)
# Different DB engines have their own spacing while compiling
# the queries, that's why ' '.join(query.split()) is used.
# In addition some of the engines do not include OFFSET 0.
self.assertTrue(
'SELECT * FROM (SELECT * FROM outer_space;) AS inner_qry '
'LIMIT 100' in ' '.join(updated_select_query.split()),
)

select_query_no_semicolon = 'SELECT * FROM outer_space'
updated_select_query_no_semicolon = main_db.wrap_sql_limit(
select_query_no_semicolon, 100)
self.assertTrue(
'SELECT * FROM (SELECT * FROM outer_space) AS inner_qry '
'LIMIT 100' in
' '.join(updated_select_query_no_semicolon.split()),
)

multi_line_query = (
"SELECT * FROM planets WHERE\n Luke_Father = 'Darth Vader';"
)
updated_multi_line_query = main_db.wrap_sql_limit(multi_line_query, 100)
self.assertTrue(
'SELECT * FROM (SELECT * FROM planets WHERE '
"Luke_Father = 'Darth Vader';) AS inner_qry LIMIT 100" in
' '.join(updated_multi_line_query.split()),
)

def test_run_sync_query_dont_exist(self):
main_db = self.get_main_database(db.session)
db_id = main_db.id
Expand Down
51 changes: 47 additions & 4 deletions tests/db_engine_specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from __future__ import print_function
from __future__ import unicode_literals

import unittest
from superset.db_engine_specs import (
HiveEngineSpec, MssqlEngineSpec, MySQLEngineSpec)
from superset.models.core import Database
from .base_tests import SupersetTestCase

from superset.db_engine_specs import HiveEngineSpec


class DbEngineSpecsTestCase(unittest.TestCase):
class DbEngineSpecsTestCase(SupersetTestCase):
def test_0_progress(self):
log = """
17/02/07 18:26:27 INFO log.PerfLogger: <PERFLOG method=compile from=org.apache.hadoop.hive.ql.Driver>
Expand Down Expand Up @@ -80,3 +81,45 @@ def test_job_2_launched_stage_2_stages_progress(self):
17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0%
""".split('\n') # noqa ignore: E501
self.assertEquals(60, HiveEngineSpec.progress(log))

def test_wrapped_query(self):
sql = 'SELECT * FROM a'
db = Database(sqlalchemy_uri='mysql://localhost')
limited = MssqlEngineSpec.apply_limit_to_sql(sql, 1000, db)
expected = 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000'
self.assertEquals(expected, limited)

def test_simple_limit_query(self):
sql = 'SELECT * FROM a'
db = Database(sqlalchemy_uri='mysql://localhost')
limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db)
expected = 'SELECT * FROM a LIMIT 1000'
self.assertEquals(expected, limited)

def test_modify_limit_query(self):
sql = 'SELECT * FROM a LIMIT 9999'
db = Database(sqlalchemy_uri='mysql://localhost')
limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db)
expected = 'SELECT * FROM a LIMIT 1000'
self.assertEquals(expected, limited)

def test_modify_newline_query(self):
sql = 'SELECT * FROM a\nLIMIT 9999'
db = Database(sqlalchemy_uri='mysql://localhost')
limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db)
expected = 'SELECT * FROM a LIMIT 1000'
self.assertEquals(expected, limited)

def test_modify_lcase_limit_query(self):
sql = 'SELECT * FROM a\tlimit 9999'
db = Database(sqlalchemy_uri='mysql://localhost')
limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db)
expected = 'SELECT * FROM a LIMIT 1000'
self.assertEquals(expected, limited)

def test_limit_query_with_limit_subquery(self):
sql = 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999'
db = Database(sqlalchemy_uri='mysql://localhost')
limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db)
expected = 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000'
self.assertEquals(expected, limited)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm nervous about the regular expression; I'd add a few more unit tests covering some weird cases, e.g.:

SELECT
    'LIMIT 1000' AS a
  , b
FROM
    table
LIMIT
    1000
<EOF>

And maybe some stuff with extra space at the end:

...
LIMIT      1000  ;    
  <EOF>