From be9084ae3cf7084bbc530cbf7e50c1ff075db8a7 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Mon, 7 May 2018 18:18:31 -0700 Subject: [PATCH 1/5] [sql lab] a better approach at limiting queries Currently there are two mechanisms that we use to enforce the row limiting constraints, depending on the database engine: 1. use dbapi's `cursor.fetchmany()` 2. wrap the SQL into a limiting subquery Method 1 isn't great as it can result in the database server storing larger-than-required result sets in memory expecting another fetch command while we know we don't need that. Method 2 has a positive side of working with all database engines, whether they use LIMIT, ROWNUM, TOP or whatever else since sqlalchemy does the work as specified for the dialect. On the downside, though, the query optimizer might not be able to optimize this as much as an approach that doesn't use a subquery. Since most modern DBs use the LIMIT syntax, this adds a regex approach to modify the query and force a LIMIT clause without using a subquery for the databases that support this syntax and uses method 2 for all others. 
--- superset/db_engine_specs.py | 26 +++++++++++++++--- superset/models/core.py | 9 +------ superset/sql_lab.py | 3 +-- superset/sql_parse.py | 24 +++++------------ tests/db_engine_specs_test.py | 50 ++++++++++++++++++++++++++++++++--- 5 files changed, 78 insertions(+), 34 deletions(-) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index a718a0d62ca85..189fb713db36d 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -35,6 +35,7 @@ from sqlalchemy.engine import create_engine from sqlalchemy.engine.url import make_url from sqlalchemy.sql import text +from sqlalchemy.sql.expression import TextAsFrom import sqlparse import unicodecsv from werkzeug.utils import secure_filename @@ -55,6 +56,7 @@ class LimitMethod(object): """Enum the ways that limits can be applied""" FETCH_MANY = 'fetch_many' WRAP_SQL = 'wrap_sql' + FORCE_LIMIT = 'force_limit' class BaseEngineSpec(object): @@ -65,7 +67,7 @@ class BaseEngineSpec(object): cursor_execute_kwargs = {} time_grains = tuple() time_groupby_inline = False - limit_method = LimitMethod.FETCH_MANY + limit_method = LimitMethod.FORCE_LIMIT time_secondary_columns = False inner_joins = True @@ -88,6 +90,23 @@ def extra_table_metadata(cls, database, table_name, schema_name): """Returns engine-specific table metadata""" return {} + @classmethod + def apply_limit_to_sql(cls, sql, limit, database): + """Alters the SQL statement to apply a LIMIT clause""" + if cls.limit_method == LimitMethod.WRAP_SQL: + qry = ( + select('*') + .select_from( + TextAsFrom(text(sql), ['*']).alias('inner_qry'), + ) + .limit(limit) + ) + return database.compile_sqla_query(qry) + elif LimitMethod.FORCE_LIMIT: + no_limit = re.sub(r"(?i)\s+LIMIT\s+\d+;?(\s|;)*$", '', sql) + return "{no_limit} LIMIT {limit}".format(**locals()) + return sql + @staticmethod def csv_to_df(**kwargs): kwargs['filepath_or_buffer'] = \ @@ -337,7 +356,6 @@ def get_table_names(cls, schema, inspector): class 
SnowflakeEngineSpec(PostgresBaseEngineSpec): engine = 'snowflake' - time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), Grain('second', _('second'), "DATE_TRUNC('SECOND', {col})", 'PT1S'), @@ -361,6 +379,7 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec): class OracleEngineSpec(PostgresBaseEngineSpec): engine = 'oracle' + limit_method = LimitMethod.WRAP_SQL time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), @@ -382,6 +401,7 @@ def convert_dttm(cls, target_type, dttm): class Db2EngineSpec(BaseEngineSpec): engine = 'ibm_db_sa' + limit_method = LimitMethod.WRAP_SQL time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), Grain('second', _('second'), @@ -1106,6 +1126,7 @@ def get_configuration_for_impersonation(cls, uri, impersonate_user, username): class MssqlEngineSpec(BaseEngineSpec): engine = 'mssql' epoch_to_dttm = "dateadd(S, {col}, '1970-01-01')" + limit_method = LimitMethod.WRAP_SQL time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), @@ -1310,7 +1331,6 @@ def get_schema_names(cls, inspector): class DruidEngineSpec(BaseEngineSpec): """Engine spec for Druid.io""" engine = 'druid' - limit_method = LimitMethod.FETCH_MANY inner_joins = False diff --git a/superset/models/core.py b/superset/models/core.py index 2ad20faca85c6..2187b0540bd6b 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -722,14 +722,7 @@ def select_star( indent=indent, latest_partition=latest_partition, cols=cols) def wrap_sql_limit(self, sql, limit=1000): - qry = ( - select('*') - .select_from( - TextAsFrom(text(sql), ['*']) - .alias('inner_qry'), - ).limit(limit) - ) - return self.compile_sqla_query(qry) + return self.db_engine_spec.apply_limit_to_sql(sql, limit, self) def safe_sqlalchemy_uri(self): return self.sqlalchemy_uri diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 856ea4880fb71..99ec957d80d1c 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -188,8 +188,7 @@ def 
handle_error(msg): query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S')) executed_sql = superset_query.as_create_table(query.tmp_table_name) query.select_as_cta_used = True - elif (query.limit and superset_query.is_select() and - db_engine_spec.limit_method == LimitMethod.WRAP_SQL): + elif (query.limit and superset_query.is_select()): executed_sql = database.wrap_sql_limit(executed_sql, query.limit) query.limit_used = True diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 790371ae35706..ea1c9c38851c1 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -15,13 +15,13 @@ PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'} -# TODO: some sql_lab logic here. class SupersetQuery(object): def __init__(self, sql_statement): self.sql = sql_statement self._table_names = set() self._alias_names = set() # TODO: multistatement support + logging.info('Parsing with sqlparse statement {}'.format(self.sql)) self._parsed = sqlparse.parse(self.sql) for statement in self._parsed: @@ -36,11 +36,7 @@ def is_select(self): return self._parsed[0].get_type() == 'SELECT' def stripped(self): - sql = self.sql - if sql: - while sql[-1] in (' ', ';', '\n', '\t'): - sql = sql[:-1] - return sql + return self.sql.strip(' \t\n;') @staticmethod def __precedes_table_name(token_value): @@ -65,13 +61,12 @@ def __is_result_operation(keyword): @staticmethod def __is_identifier(token): - return ( - isinstance(token, IdentifierList) or isinstance(token, Identifier)) + return isinstance(token, (IdentifierList, Identifier)) def __process_identifier(self, identifier): # exclude subselects if '(' not in '{}'.format(identifier): - self._table_names.add(SupersetQuery.__get_full_name(identifier)) + self._table_names.add(self.__get_full_name(identifier)) return # store aliases @@ -94,11 +89,6 @@ def as_create_table(self, table_name, overwrite=False): :param overwrite, boolean, table table_name will be dropped if true :return: string, create table as query """ - # 
TODO(bkyryliuk): enforce that all the columns have names. - # Presto requires it for the CTA operation. - # TODO(bkyryliuk): drop table if allowed, check the namespace and - # the permissions. - # TODO raise if multi-statement exec_sql = '' sql = self.stripped() if overwrite: @@ -117,7 +107,7 @@ def __extract_from_token(self, token): self.__extract_from_token(item) if item.ttype in Keyword: - if SupersetQuery.__precedes_table_name(item.value.upper()): + if self.__precedes_table_name(item.value.upper()): table_name_preceding_token = True continue @@ -125,7 +115,7 @@ def __extract_from_token(self, token): continue if item.ttype in Keyword: - if SupersetQuery.__is_result_operation(item.value): + if self.__is_result_operation(item.value): table_name_preceding_token = False continue # FROM clause is over @@ -136,5 +126,5 @@ def __extract_from_token(self, token): if isinstance(item, IdentifierList): for token in item.tokens: - if SupersetQuery.__is_identifier(token): + if self.__is_identifier(token): self.__process_identifier(token) diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 1a1282ad1a47f..71e6bc493bb8b 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -4,12 +4,12 @@ from __future__ import print_function from __future__ import unicode_literals -import unittest +from superset.db_engine_specs import MssqlEngineSpec, HiveEngineSpec, MySQLEngineSpec +from superset.models.core import Database +from .base_tests import SupersetTestCase -from superset.db_engine_specs import HiveEngineSpec - -class DbEngineSpecsTestCase(unittest.TestCase): +class DbEngineSpecsTestCase(SupersetTestCase): def test_0_progress(self): log = """ 17/02/07 18:26:27 INFO log.PerfLogger: @@ -80,3 +80,45 @@ def test_job_2_launched_stage_2_stages_progress(self): 17/02/07 19:16:09 INFO exec.Task: 2017-02-07 19:16:09,173 Stage-1 map = 40%, reduce = 0% """.split('\n') # noqa ignore: E501 self.assertEquals(60, HiveEngineSpec.progress(log)) + + 
def test_wrapped_query(self): + sql = "SELECT * FROM a" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MssqlEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_simple_limit_query(self): + sql = "SELECT * FROM a" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * FROM a LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_modify_limit_query(self): + sql = "SELECT * FROM a LIMIT 9999" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * FROM a LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_modify_newline_query(self): + sql = "SELECT * FROM a\nLIMIT 9999" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * FROM a LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_modify_lcase_limit_query(self): + sql = "SELECT * FROM a\tlimit 9999" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = """SELECT * FROM a LIMIT 1000""" + self.assertEquals(expected, limited) + + def test_limit_query_with_limit_subquery(self): + sql = "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999" + db = Database(sqlalchemy_uri="mysql://localhost") + limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) + expected = "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000" + self.assertEquals(expected, limited) From 08ff644539139b625149a72b19ef12a982fff88b Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Thu, 10 May 2018 13:58:31 -0700 Subject: [PATCH 2/5] Fixing build --- superset/db_engine_specs.py | 2 +- superset/models/core.py | 4 +--- superset/sql_lab.py | 1 - 
tests/celery_tests.py | 32 ---------------------------- tests/db_engine_specs_test.py | 39 ++++++++++++++++++----------------- 5 files changed, 22 insertions(+), 56 deletions(-) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 189fb713db36d..2cf4891d73cf8 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -103,7 +103,7 @@ def apply_limit_to_sql(cls, sql, limit, database): ) return database.compile_sqla_query(qry) elif LimitMethod.FORCE_LIMIT: - no_limit = re.sub(r"(?i)\s+LIMIT\s+\d+;?(\s|;)*$", '', sql) + no_limit = re.sub(r'(?i)\s+LIMIT\s+\d+;?(\s|;)*$', '', sql) return "{no_limit} LIMIT {limit}".format(**locals()) return sql diff --git a/superset/models/core.py b/superset/models/core.py index 2187b0540bd6b..ed33461f23c5e 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -22,7 +22,7 @@ import sqlalchemy as sqla from sqlalchemy import ( Boolean, Column, create_engine, DateTime, ForeignKey, Integer, - MetaData, select, String, Table, Text, + MetaData, String, Table, Text, ) from sqlalchemy.engine import url from sqlalchemy.engine.url import make_url @@ -30,8 +30,6 @@ from sqlalchemy.orm.session import make_transient from sqlalchemy.pool import NullPool from sqlalchemy.schema import UniqueConstraint -from sqlalchemy.sql import text -from sqlalchemy.sql.expression import TextAsFrom from sqlalchemy_utils import EncryptedType from superset import app, db, db_engine_specs, security_manager, utils diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 99ec957d80d1c..881271a2281f9 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -17,7 +17,6 @@ from sqlalchemy.pool import NullPool from superset import app, dataframe, db, results_backend, security_manager, utils -from superset.db_engine_specs import LimitMethod from superset.models.sql_lab import Query from superset.sql_parse import SupersetQuery from superset.utils import get_celery_app, QueryStatus diff --git 
a/tests/celery_tests.py b/tests/celery_tests.py index 79b71e986e2aa..f6d1a2958fd1c 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -139,38 +139,6 @@ def run_sql(self, db_id, sql, client_id, cta='false', tmp_table='tmp', self.logout() return json.loads(resp.data.decode('utf-8')) - def test_add_limit_to_the_query(self): - main_db = self.get_main_database(db.session) - - select_query = 'SELECT * FROM outer_space;' - updated_select_query = main_db.wrap_sql_limit(select_query, 100) - # Different DB engines have their own spacing while compiling - # the queries, that's why ' '.join(query.split()) is used. - # In addition some of the engines do not include OFFSET 0. - self.assertTrue( - 'SELECT * FROM (SELECT * FROM outer_space;) AS inner_qry ' - 'LIMIT 100' in ' '.join(updated_select_query.split()), - ) - - select_query_no_semicolon = 'SELECT * FROM outer_space' - updated_select_query_no_semicolon = main_db.wrap_sql_limit( - select_query_no_semicolon, 100) - self.assertTrue( - 'SELECT * FROM (SELECT * FROM outer_space) AS inner_qry ' - 'LIMIT 100' in - ' '.join(updated_select_query_no_semicolon.split()), - ) - - multi_line_query = ( - "SELECT * FROM planets WHERE\n Luke_Father = 'Darth Vader';" - ) - updated_multi_line_query = main_db.wrap_sql_limit(multi_line_query, 100) - self.assertTrue( - 'SELECT * FROM (SELECT * FROM planets WHERE ' - "Luke_Father = 'Darth Vader';) AS inner_qry LIMIT 100" in - ' '.join(updated_multi_line_query.split()), - ) - def test_run_sync_query_dont_exist(self): main_db = self.get_main_database(db.session) db_id = main_db.id diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 71e6bc493bb8b..0726a68fbff0e 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -4,7 +4,8 @@ from __future__ import print_function from __future__ import unicode_literals -from superset.db_engine_specs import MssqlEngineSpec, HiveEngineSpec, MySQLEngineSpec +from superset.db_engine_specs import ( 
+ HiveEngineSpec, MssqlEngineSpec, MySQLEngineSpec) from superset.models.core import Database from .base_tests import SupersetTestCase @@ -82,43 +83,43 @@ def test_job_2_launched_stage_2_stages_progress(self): self.assertEquals(60, HiveEngineSpec.progress(log)) def test_wrapped_query(self): - sql = "SELECT * FROM a" - db = Database(sqlalchemy_uri="mysql://localhost") + sql = 'SELECT * FROM a' + db = Database(sqlalchemy_uri='mysql://localhost') limited = MssqlEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = """SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000""" + expected = 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000' self.assertEquals(expected, limited) def test_simple_limit_query(self): - sql = "SELECT * FROM a" - db = Database(sqlalchemy_uri="mysql://localhost") + sql = 'SELECT * FROM a' + db = Database(sqlalchemy_uri='mysql://localhost') limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = """SELECT * FROM a LIMIT 1000""" + expected = 'SELECT * FROM a LIMIT 1000' self.assertEquals(expected, limited) def test_modify_limit_query(self): - sql = "SELECT * FROM a LIMIT 9999" - db = Database(sqlalchemy_uri="mysql://localhost") + sql = 'SELECT * FROM a LIMIT 9999' + db = Database(sqlalchemy_uri='mysql://localhost') limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = """SELECT * FROM a LIMIT 1000""" + expected = 'SELECT * FROM a LIMIT 1000' self.assertEquals(expected, limited) def test_modify_newline_query(self): - sql = "SELECT * FROM a\nLIMIT 9999" - db = Database(sqlalchemy_uri="mysql://localhost") + sql = 'SELECT * FROM a\nLIMIT 9999' + db = Database(sqlalchemy_uri='mysql://localhost') limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = """SELECT * FROM a LIMIT 1000""" + expected = 'SELECT * FROM a LIMIT 1000' self.assertEquals(expected, limited) def test_modify_lcase_limit_query(self): - sql = "SELECT * FROM a\tlimit 9999" - db = 
Database(sqlalchemy_uri="mysql://localhost") + sql = 'SELECT * FROM a\tlimit 9999' + db = Database(sqlalchemy_uri='mysql://localhost') limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = """SELECT * FROM a LIMIT 1000""" + expected = 'SELECT * FROM a LIMIT 1000' self.assertEquals(expected, limited) def test_limit_query_with_limit_subquery(self): - sql = "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999" - db = Database(sqlalchemy_uri="mysql://localhost") + sql = 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999' + db = Database(sqlalchemy_uri='mysql://localhost') limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000" + expected = 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000' self.assertEquals(expected, limited) From 94b939d56c248be7e614ac1a2d3c1f8e6e387280 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Thu, 10 May 2018 14:02:21 -0700 Subject: [PATCH 3/5] Fix lint --- superset/db_engine_specs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 2cf4891d73cf8..957da21194fee 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -104,7 +104,7 @@ def apply_limit_to_sql(cls, sql, limit, database): return database.compile_sqla_query(qry) elif LimitMethod.FORCE_LIMIT: no_limit = re.sub(r'(?i)\s+LIMIT\s+\d+;?(\s|;)*$', '', sql) - return "{no_limit} LIMIT {limit}".format(**locals()) + return '{no_limit} LIMIT {limit}'.format(**locals()) return sql @staticmethod From 7e819135de285886ee345b72d9e4da753142427b Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Sun, 13 May 2018 13:31:40 -0500 Subject: [PATCH 4/5] Added more tests --- superset/db_engine_specs.py | 8 +++ superset/models/core.py | 2 +- superset/sql_lab.py | 2 +- tests/db_engine_specs_test.py | 117 +++++++++++++++++++++++++--------- 4 files changed, 97 insertions(+), 32 deletions(-) diff --git 
a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 957da21194fee..88c51eadbcd38 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -94,6 +94,7 @@ def extra_table_metadata(cls, database, table_name, schema_name): def apply_limit_to_sql(cls, sql, limit, database): """Alters the SQL statement to apply a LIMIT clause""" if cls.limit_method == LimitMethod.WRAP_SQL: + sql = sql.strip('\t\n ;') qry = ( select('*') .select_from( @@ -104,6 +105,13 @@ def apply_limit_to_sql(cls, sql, limit, database): return database.compile_sqla_query(qry) elif LimitMethod.FORCE_LIMIT: no_limit = re.sub(r'(?i)\s+LIMIT\s+\d+;?(\s|;)*$', '', sql) + no_limit = re.sub(r""" + (?ix) # case insensitive, verbose + \s+ # whitespace + LIMIT\s+\d+ # LIMIT $ROWS + ;? # optional semi-colon + (\s|;)*$ # remove trailing spaces tabs or semicolons + """, '', sql) return '{no_limit} LIMIT {limit}'.format(**locals()) return sql diff --git a/superset/models/core.py b/superset/models/core.py index ed33461f23c5e..8448c7ba54e49 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -719,7 +719,7 @@ def select_star( self, table_name, schema=schema, limit=limit, show_cols=show_cols, indent=indent, latest_partition=latest_partition, cols=cols) - def wrap_sql_limit(self, sql, limit=1000): + def apply_limit_to_sql(self, sql, limit=1000): return self.db_engine_spec.apply_limit_to_sql(sql, limit, self) def safe_sqlalchemy_uri(self): diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 881271a2281f9..75e88146b9eee 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -188,7 +188,7 @@ def handle_error(msg): executed_sql = superset_query.as_create_table(query.tmp_table_name) query.select_as_cta_used = True elif (query.limit and superset_query.is_select()): - executed_sql = database.wrap_sql_limit(executed_sql, query.limit) + executed_sql = database.apply_limit_to_sql(executed_sql, query.limit) query.limit_used = True # Hook to allow 
environment-specific mutation (usually comments) to the SQL diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 0726a68fbff0e..71c93ee328f7f 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -4,6 +4,9 @@ from __future__ import print_function from __future__ import unicode_literals +import textwrap + +from superset import db from superset.db_engine_specs import ( HiveEngineSpec, MssqlEngineSpec, MySQLEngineSpec) from superset.models.core import Database @@ -82,44 +85,98 @@ def test_job_2_launched_stage_2_stages_progress(self): """.split('\n') # noqa ignore: E501 self.assertEquals(60, HiveEngineSpec.progress(log)) + def sql_limit_regex( + self, sql, expected_sql, + engine_spec_class=MySQLEngineSpec, + limit=1000): + main = self.get_main_database(db.session) + limited = engine_spec_class.apply_limit_to_sql(sql, limit, main) + self.assertEquals(expected_sql, limited) + def test_wrapped_query(self): - sql = 'SELECT * FROM a' - db = Database(sqlalchemy_uri='mysql://localhost') - limited = MssqlEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000' - self.assertEquals(expected, limited) + self.sql_limit_regex( + 'SELECT * FROM a', + 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000', + MssqlEngineSpec, + ) + + def test_wrapped_semi(self): + self.sql_limit_regex( + 'SELECT * FROM a;', + 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000', + MssqlEngineSpec, + ) + + def test_wrapped_semi_tabs(self): + self.sql_limit_regex( + 'SELECT * FROM a \t \n ; \t \n ', + 'SELECT * \nFROM (SELECT * FROM a) AS inner_qry \n LIMIT 1000', + MssqlEngineSpec, + ) def test_simple_limit_query(self): - sql = 'SELECT * FROM a' - db = Database(sqlalchemy_uri='mysql://localhost') - limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = 'SELECT * FROM a LIMIT 1000' - self.assertEquals(expected, limited) + self.sql_limit_regex( + 
'SELECT * FROM a', + 'SELECT * FROM a LIMIT 1000', + ) def test_modify_limit_query(self): - sql = 'SELECT * FROM a LIMIT 9999' - db = Database(sqlalchemy_uri='mysql://localhost') - limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = 'SELECT * FROM a LIMIT 1000' - self.assertEquals(expected, limited) + self.sql_limit_regex( + 'SELECT * FROM a LIMIT 9999', + 'SELECT * FROM a LIMIT 1000', + ) def test_modify_newline_query(self): - sql = 'SELECT * FROM a\nLIMIT 9999' - db = Database(sqlalchemy_uri='mysql://localhost') - limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = 'SELECT * FROM a LIMIT 1000' - self.assertEquals(expected, limited) + self.sql_limit_regex( + 'SELECT * FROM a\nLIMIT 9999', + 'SELECT * FROM a LIMIT 1000', + ) def test_modify_lcase_limit_query(self): - sql = 'SELECT * FROM a\tlimit 9999' - db = Database(sqlalchemy_uri='mysql://localhost') - limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = 'SELECT * FROM a LIMIT 1000' - self.assertEquals(expected, limited) + self.sql_limit_regex( + 'SELECT * FROM a\tlimit 9999', + 'SELECT * FROM a LIMIT 1000', + ) def test_limit_query_with_limit_subquery(self): - sql = 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999' - db = Database(sqlalchemy_uri='mysql://localhost') - limited = MySQLEngineSpec.apply_limit_to_sql(sql, 1000, db) - expected = 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000' - self.assertEquals(expected, limited) + self.sql_limit_regex( + 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999', + 'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000', + ) + + def test_limit_with_expr(self): + self.sql_limit_regex( + textwrap.dedent( + """\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table + LIMIT + 99990"""), + textwrap.dedent("""\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table LIMIT 1000"""), + ) + + def test_limit_expr_and_semicolon(self): + self.sql_limit_regex( + textwrap.dedent( + """\ + SELECT + 'LIMIT 777' AS a + , b + FROM 
+ table + LIMIT 99990 ;"""), + textwrap.dedent("""\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table LIMIT 1000"""), + ) From 706d8608f6e28f282f029fb1c6c3fd5f87a84343 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Mon, 14 May 2018 13:36:45 -0500 Subject: [PATCH 5/5] Fix tests --- superset/db_engine_specs.py | 1 - tests/db_engine_specs_test.py | 48 +++++++++++++++++------------------ 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 88c51eadbcd38..b91d93f393aca 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -104,7 +104,6 @@ def apply_limit_to_sql(cls, sql, limit, database): ) return database.compile_sqla_query(qry) elif LimitMethod.FORCE_LIMIT: - no_limit = re.sub(r'(?i)\s+LIMIT\s+\d+;?(\s|;)*$', '', sql) no_limit = re.sub(r""" (?ix) # case insensitive, verbose \s+ # whitespace diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index 71c93ee328f7f..c38e4f569023a 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -6,7 +6,6 @@ import textwrap -from superset import db from superset.db_engine_specs import ( HiveEngineSpec, MssqlEngineSpec, MySQLEngineSpec) from superset.models.core import Database @@ -85,11 +84,14 @@ def test_job_2_launched_stage_2_stages_progress(self): """.split('\n') # noqa ignore: E501 self.assertEquals(60, HiveEngineSpec.progress(log)) + def get_generic_database(self): + return Database(sqlalchemy_uri='mysql://localhost') + def sql_limit_regex( self, sql, expected_sql, engine_spec_class=MySQLEngineSpec, limit=1000): - main = self.get_main_database(db.session) + main = self.get_generic_database() limited = engine_spec_class.apply_limit_to_sql(sql, limit, main) self.assertEquals(expected_sql, limited) @@ -146,15 +148,14 @@ def test_limit_query_with_limit_subquery(self): def test_limit_with_expr(self): self.sql_limit_regex( - textwrap.dedent( - """\ - SELECT - 'LIMIT 777' AS a - , b 
- FROM - table - LIMIT - 99990"""), + textwrap.dedent("""\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table + LIMIT + 99990"""), textwrap.dedent("""\ SELECT 'LIMIT 777' AS a @@ -165,18 +166,17 @@ def test_limit_with_expr(self): def test_limit_expr_and_semicolon(self): self.sql_limit_regex( - textwrap.dedent( - """\ - SELECT - 'LIMIT 777' AS a - , b - FROM - table - LIMIT 99990 ;"""), textwrap.dedent("""\ - SELECT - 'LIMIT 777' AS a - , b - FROM - table LIMIT 1000"""), + SELECT + 'LIMIT 777' AS a + , b + FROM + table + LIMIT 99990 ;"""), + textwrap.dedent("""\ + SELECT + 'LIMIT 777' AS a + , b + FROM + table LIMIT 1000"""), )