Skip to content

Commit

Permalink
fix: adds the ability to disallow SQL functions per engine (apache#28639
Browse files Browse the repository at this point in the history
)
  • Loading branch information
dpgaspar authored May 29, 2024
1 parent 6575cac commit 5dfbab5
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 15 deletions.
9 changes: 9 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,15 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
#
DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None

# Mapping of engine name -> SQL functions that may not be used on that engine.
# This is used to restrict the use of unsafe SQL functions in SQL Lab and Charts.
# Function names should be given in lowercase: names found in a query are
# lowercased before being compared against these sets.
DISALLOWED_SQL_FUNCTIONS: dict[str, set[str]] = {
    "postgresql": {
        "version",
        "query_to_xml",
        "inet_server_addr",
        "inet_client_addr",
    },
    "clickhouse": {
        "url",
    },
    "mysql": {
        "version",
    },
}


# A function that intercepts the SQL to be executed and can alter it.
# A common use case for this is around adding some sort of comment header to the SQL
Expand Down
7 changes: 6 additions & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
from superset.constants import TimeGrain as TimeGrainConstants
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
from superset.sql_parse import ParsedQuery, SQLScript, Table
from superset.superset_typing import (
OAuth2ClientConfig,
Expand Down Expand Up @@ -1818,6 +1818,11 @@ def execute( # pylint: disable=unused-argument
"""
if not cls.allows_sql_comments:
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
disallowed_functions = current_app.config["DISALLOWED_SQL_FUNCTIONS"].get(
cls.engine, set()
)
if sql_parse.check_sql_functions_exist(query, disallowed_functions, cls.engine):
raise DisallowedSQLFunction(disallowed_functions)

if cls.arraysize:
cursor.arraysize = cls.arraysize
Expand Down
11 changes: 7 additions & 4 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import time
from typing import Any, TYPE_CHECKING

from flask import current_app
from flask import current_app, Flask
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchTableError
Expand Down Expand Up @@ -218,19 +218,22 @@ def execute_with_cursor(
execute_result: dict[str, Any] = {}
execute_event = threading.Event()

def _execute(results: dict[str, Any], event: threading.Event) -> None:
def _execute(
results: dict[str, Any], event: threading.Event, app: Flask
) -> None:
logger.debug("Query %d: Running query: %s", query_id, sql)

try:
cls.execute(cursor, sql, query.database)
with app.app_context():
cls.execute(cursor, sql, query.database)
except Exception as ex: # pylint: disable=broad-except
results["error"] = ex
finally:
event.set()

execute_thread = threading.Thread(
target=_execute,
args=(execute_result, execute_event),
args=(execute_result, execute_event, current_app._get_current_object()), # pylint: disable=protected-access
)
execute_thread.start()

Expand Down
15 changes: 15 additions & 0 deletions superset/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,21 @@ def __init__(self, error: str):
)


class DisallowedSQLFunction(SupersetErrorException):
    """
    Raised when a SQL statement uses a function that has been disallowed.
    """

    def __init__(self, functions: set[str]):
        """
        Build the error payload for the given function names.

        :param functions: The set of function names to report; its ``repr``
            is embedded verbatim in the error message
        """
        error = SupersetError(
            message=f"SQL statement contains disallowed function(s): {functions}",
            error_type=SupersetErrorType.SYNTAX_ERROR,
            level=ErrorLevel.ERROR,
        )
        super().__init__(error)


class CreateKeyValueDistributedLockFailedException(Exception):
"""
Exception to signalize failure to acquire lock.
Expand Down
42 changes: 42 additions & 0 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
Function,
Identifier,
IdentifierList,
Parenthesis,
Expand Down Expand Up @@ -223,6 +224,19 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
return cte, remainder


def check_sql_functions_exist(
    sql: str, function_list: set[str], engine: str | None = None
) -> bool:
    """
    Tell whether a SQL statement calls any function from a given set.

    :param sql: The SQL statement to inspect
    :param function_list: The function names to search for
    :param engine: The engine handed to the parser when parsing ``sql``
    :return: Whatever ``ParsedQuery.check_functions_exist`` reports for
        ``function_list``
    """
    parsed_query = ParsedQuery(sql, engine=engine)
    return parsed_query.check_functions_exist(function_list)


def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
"""
Strips comments from a SQL statement, does a simple test first
Expand Down Expand Up @@ -743,6 +757,34 @@ def tables(self) -> set[Table]:
self._tables = self._extract_tables_from_sql()
return self._tables

def _check_functions_exist_in_token(
    self, token: Token, functions: set[str]
) -> bool:
    """
    Recursively look for a call to one of ``functions`` inside ``token``.

    :param token: The parsed token (or token group) to inspect
    :param functions: The function names to search for; compared against the
        lowercased name of every ``Function`` token encountered
    :return: True if ``token`` or any of its descendants is a matching
        function call
    """
    if isinstance(token, Function):
        name = token.get_name()
        if name is not None and name.lower() in functions:
            return True

    # Tokens without a ``tokens`` attribute have no children to descend into.
    if not hasattr(token, "tokens"):
        return False

    return any(
        self._check_functions_exist_in_token(child, functions)
        for child in token.tokens
    )

def check_functions_exist(self, functions: set[str]) -> bool:
    """
    Check if the SQL statement contains any of the specified functions.

    Matching is case-insensitive: the token walker lowercases the function
    names it finds in the query, so the requested names are lowercased here
    as well. Without this, an entry such as ``"VERSION"`` could never match.

    :param functions: A set of functions to search for
    :return: True if the statement contains any of the specified functions
    """
    # Nothing to look for: skip walking the token tree altogether.
    if not functions:
        return False

    # Normalise once up front rather than on every recursive comparison.
    normalized = {name.lower() for name in functions}

    return any(
        self._check_functions_exist_in_token(token, normalized)
        for statement in self._parsed
        for token in statement.tokens
    )

def _extract_tables_from_sql(self) -> set[Table]:
"""
Extract all table references in a query.
Expand Down
24 changes: 14 additions & 10 deletions tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def test_handle_cursor_early_cancel(
assert cancel_query_mock.call_args is None


def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
"""Test that `execute_with_cursor` fetches query ID from the cursor"""
from superset.db_engine_specs.trino import TrinoEngineSpec

Expand All @@ -416,16 +416,20 @@ def _mock_execute(*args, **kwargs):
mock_cursor.query_id = query_id

mock_cursor.execute.side_effect = _mock_execute
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)

TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)

mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)
mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)


def test_get_columns(mocker: MockerFixture):
Expand Down
26 changes: 26 additions & 0 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from superset.sql_parse import (
add_table_name,
check_sql_functions_exist,
extract_table_references,
extract_tables_from_jinja_sql,
get_rls_for_table,
Expand Down Expand Up @@ -1215,6 +1216,31 @@ def test_strip_comments_from_sql() -> None:
)


def test_check_sql_functions_exist() -> None:
    """
    Test that disallowed function calls are detected, and that tables or
    columns sharing a function's name are not mistaken for calls.
    """
    engine = "postgresql"
    disallowed = {"version"}

    # ``version`` is only used as a table name here, so nothing is flagged.
    assert not (
        check_sql_functions_exist("select a, b from version", disallowed, engine)
    )

    # Each of these statements calls ``version()`` somewhere: directly, in the
    # FROM clause, or nested inside a subquery.
    statements_with_call = (
        "select version()",
        "select version from version()",
        "select 1, a.version from (select version from version()) as a",
        "select 1, a.version from (select version()) as a",
    )
    for sql in statements_with_call:
        assert check_sql_functions_exist(sql, disallowed, engine)


def test_sanitize_clause_valid():
# regular clauses
assert sanitize_clause("col = 1") == "col = 1"
Expand Down

0 comments on commit 5dfbab5

Please sign in to comment.