Skip to content

Commit

Permalink
feat: safer insert RLS
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Jun 9, 2022
1 parent 07b4a71 commit 52597ad
Show file tree
Hide file tree
Showing 5 changed files with 269 additions and 11 deletions.
4 changes: 2 additions & 2 deletions superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery
from superset.sql_parse import has_table_query, insert_rls_in_predicate, ParsedQuery
from superset.superset_typing import ResultSetColumnType
from superset.utils.memoized import memoized

Expand Down Expand Up @@ -174,7 +174,7 @@ def validate_adhoc_subquery(
level=ErrorLevel.ERROR,
)
)
statement = insert_rls(statement, database_id, default_schema)
statement = insert_rls_in_predicate(statement, database_id, default_schema)
statements.append(statement)

return ";\n".join(str(statement) for statement in statements)
Expand Down
17 changes: 16 additions & 1 deletion superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@
from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql_parse import CtasMethod, insert_rls, ParsedQuery
from superset.sql_parse import (
CtasMethod,
insert_rls_as_subquery,
insert_rls_in_predicate,
ParsedQuery,
)
from superset.sqllab.limiting_factor import LimitingFactor
from superset.utils.celery import session_scope
from superset.utils.core import (
Expand Down Expand Up @@ -201,6 +206,16 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statem

parsed_query = ParsedQuery(sql_statement)
if is_feature_enabled("RLS_IN_SQLLAB"):
# There are two ways to insert RLS: either replacing the table with a subquery
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
# safer, but not supported in all databases.
insert_rls = (
insert_rls_as_subquery
if database.db_engine_spec.allows_subqueries
and database.db_engine_spec.allows_alias_in_select
else insert_rls_in_predicate
)

# Insert any applicable RLS predicates
parsed_query = ParsedQuery(
str(
Expand Down
99 changes: 97 additions & 2 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
Punctuation,
String,
Whitespace,
Wildcard,
)
from sqlparse.utils import imt

Expand Down Expand Up @@ -597,14 +598,106 @@ def get_rls_for_table(
return rls


def insert_rls(
def insert_rls_as_subquery(
    token_list: TokenList,
    database_id: int,
    default_schema: Optional[str],
    username: Optional[str] = None,
) -> TokenList:
    """
    Update a statement inplace applying any associated RLS predicates.

    The RLS predicate is applied as subquery replacing the original table:

        before: SELECT * FROM some_table WHERE 1=1
        after:  SELECT * FROM (
                  SELECT * FROM some_table WHERE some_table.id=42
                ) AS some_table
                WHERE 1=1

    This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
    databases.

    :param token_list: parsed statement (or nested token list) to rewrite in place
    :param database_id: forwarded to ``get_rls_for_table``
    :param default_schema: forwarded to ``get_rls_for_table``
    :param username: forwarded to ``get_rls_for_table``, both for this token list
        and for every nested token list visited recursively
    :returns: the same ``token_list`` object, with matching tables replaced
    """
    rls: Optional[TokenList] = None
    state = InsertRLSState.SCANNING

    # Only existing slots are overwritten (the list length never changes), so it is
    # safe to iterate with ``enumerate`` while replacing elements.
    for i, token in enumerate(token_list.tokens):

        # Recurse into child token list; ``username`` must be propagated so that
        # nested subqueries are protected with the same user's RLS rules.
        if isinstance(token, TokenList):
            token_list.tokens[i] = insert_rls_as_subquery(
                token, database_id, default_schema, username
            )

        # Found a source keyword (FROM/JOIN)
        if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
            state = InsertRLSState.SEEN_SOURCE

        # Found identifier/keyword after FROM/JOIN, test for table
        elif state == InsertRLSState.SEEN_SOURCE and (
            isinstance(token, Identifier) or token.ttype == Keyword
        ):
            rls = get_rls_for_table(token, database_id, default_schema, username)
            if rls:
                # The subquery is aliased with the last token of the identifier
                # (or the keyword value), so outer references keep resolving.
                # NOTE(review): the whole identifier is reused inside the subquery;
                # if it carries its own alias (``FROM some_table t``) that alias is
                # presumably repeated inside the subquery as well -- confirm this
                # is valid for the target databases.
                subquery_alias = (
                    token.tokens[-1].value
                    if isinstance(token, Identifier)
                    else token.value
                )

                # replace table with subquery
                token_list.tokens[i] = Identifier(
                    [
                        Parenthesis(
                            [
                                Token(Punctuation, "("),
                                Token(DML, "SELECT"),
                                Token(Whitespace, " "),
                                Token(Wildcard, "*"),
                                Token(Whitespace, " "),
                                Token(Keyword, "FROM"),
                                Token(Whitespace, " "),
                                token,
                                Token(Whitespace, " "),
                                Where(
                                    [
                                        Token(Keyword, "WHERE"),
                                        Token(Whitespace, " "),
                                        rls,
                                    ]
                                ),
                                Token(Punctuation, ")"),
                            ]
                        ),
                        Token(Whitespace, " "),
                        Token(Keyword, "AS"),
                        Token(Whitespace, " "),
                        Identifier([Token(Name, subquery_alias)]),
                    ]
                )
                state = InsertRLSState.SCANNING

        # Found nothing, leaving source
        elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
            state = InsertRLSState.SCANNING

    return token_list


def insert_rls_in_predicate(
token_list: TokenList,
database_id: int,
default_schema: Optional[str],
username: Optional[str] = None,
) -> TokenList:
"""
Update a statement inplace applying any associated RLS predicates.
The RLS predicate is ``AND``ed to any existing predicates:
before: SELECT * FROM some_table WHERE 1=1
after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
"""
rls: Optional[TokenList] = None
state = InsertRLSState.SCANNING
Expand All @@ -613,7 +706,9 @@ def insert_rls(
# Recurse into child token list
if isinstance(token, TokenList):
i = token_list.tokens.index(token)
token_list.tokens[i] = insert_rls(token, database_id, default_schema)
token_list.tokens[i] = insert_rls_in_predicate(
token, database_id, default_schema
)

# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/sql_lab_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_execute_sql_statement_with_rls(
cursor = mocker.MagicMock()
SupersetResultSet = mocker.patch("superset.sql_lab.SupersetResultSet")
mocker.patch(
"superset.sql_lab.insert_rls",
"superset.sql_lab.insert_rls_in_predicate",
return_value=sqlparse.parse("SELECT * FROM sales WHERE organization_id=42")[0],
)
mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)
Expand All @@ -115,13 +115,13 @@ def test_execute_sql_statement_with_rls(
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)


def test_sql_lab_insert_rls(
def test_sql_lab_insert_rls_in_predicate(
mocker: MockerFixture,
session: Session,
app_context: None,
) -> None:
"""
Integration test for `insert_rls`.
Integration test for `insert_rls_in_predicate`.
"""
from flask_appbuilder.security.sqla.models import Role, User

Expand Down
154 changes: 151 additions & 3 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
extract_table_references,
get_rls_for_table,
has_table_query,
insert_rls,
insert_rls_as_subquery,
insert_rls_in_predicate,
ParsedQuery,
sanitize_clause,
strip_comments_from_sql,
Expand Down Expand Up @@ -1221,6 +1222,151 @@ def test_has_table_query(sql: str, expected: bool) -> None:
assert has_table_query(statement) == expected


@pytest.mark.parametrize(
    "sql,table,rls,expected",
    [
        # Simplest case: a single table with an existing predicate
        (
            "SELECT * FROM some_table WHERE 1=1",
            "some_table",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM some_table WHERE some_table.id=42) "
                "AS some_table WHERE 1=1"
            ),
        ),
        # "table" is a reserved word; sqlparse classifies reserved words very
        # aggressively, so unquoted keywords must still be treated as table names.
        (
            "SELECT * FROM table WHERE 1=1",
            "table",
            "id=42",
            "SELECT * FROM (SELECT * FROM table WHERE table.id=42) AS table WHERE 1=1",
        ),
        # Queries that don't read from the protected table are left untouched
        (
            "SELECT * FROM table WHERE 1=1",
            "other_table",
            "id=42",
            "SELECT * FROM table WHERE 1=1",
        ),
        (
            "SELECT * FROM other_table WHERE 1=1",
            "table",
            "id=42",
            "SELECT * FROM other_table WHERE 1=1",
        ),
        # Tables introduced by a JOIN are wrapped as well
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
            "other_table",
            "id=42",
            (
                "SELECT * FROM table JOIN "
                "(SELECT * FROM other_table WHERE other_table.id=42) AS other_table "
                "ON table.id = other_table.id"
            ),
        ),
        # Tables referenced inside a subquery
        (
            "SELECT * FROM (SELECT * FROM other_table)",
            "other_table",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM ("
                "SELECT * FROM other_table WHERE other_table.id=42"
                ") AS other_table)"
            ),
        ),
        # Either side of a UNION
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            "table",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM table WHERE table.id=42) AS table "
                "UNION ALL SELECT * FROM other_table"
            ),
        ),
        (
            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
            "other_table",
            "id=42",
            (
                "SELECT * FROM table UNION ALL SELECT * FROM ("
                "SELECT * FROM other_table WHERE other_table.id=42) AS other_table"
            ),
        ),
        # Fully qualified names (eg, schema.table) versus simple names (eg, table):
        # the comparison is conservative and assumes the schema is the same, since
        # there is no information on the default schema.
        (
            "SELECT * FROM schema.table_name",
            "table_name",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM schema.table_name "
                "WHERE table_name.id=42) AS table_name"
            ),
        ),
        (
            "SELECT * FROM schema.table_name",
            "schema.table_name",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM schema.table_name "
                "WHERE schema.table_name.id=42) AS table_name"
            ),
        ),
        (
            "SELECT * FROM table_name",
            "schema.table_name",
            "id=42",
            (
                "SELECT * FROM (SELECT * FROM table_name WHERE "
                "schema.table_name.id=42) AS table_name"
            ),
        ),
    ],
)
def test_insert_rls_as_subquery(
    mocker: MockerFixture, sql: str, table: str, rls: str, expected: str
) -> None:
    """
    Check that the RLS condition of a table is injected by wrapping it in a subquery.
    """
    rls_condition = sqlparse.parse(rls)[0]
    add_table_name(rls_condition, table)

    # pylint: disable=unused-argument
    def fake_get_rls_for_table(
        candidate: Token,
        database_id: int,
        default_schema: str,
        username: Optional[str] = None,
    ) -> Optional[TokenList]:
        """
        Stub returning ``rls_condition`` only when ``candidate`` names ``table``.
        """
        # Compare right-to-left (table first, then schema); ``zip`` stops at the
        # shorter name, so a missing schema on either side is ignored.
        candidate_parts = str(candidate).split(".")[::-1]
        table_parts = table.split(".")[::-1]
        same_table = all(
            left == right for left, right in zip(candidate_parts, table_parts)
        )
        return rls_condition if same_table else None

    mocker.patch("superset.sql_parse.get_rls_for_table", new=fake_get_rls_for_table)

    statement = sqlparse.parse(sql)[0]
    rewritten = insert_rls_as_subquery(
        token_list=statement, database_id=1, default_schema="my_schema"
    )
    assert str(rewritten).strip() == expected.strip()


@pytest.mark.parametrize(
"sql,table,rls,expected",
[
Expand Down Expand Up @@ -1395,7 +1541,7 @@ def test_has_table_query(sql: str, expected: bool) -> None:
),
],
)
def test_insert_rls(
def test_insert_rls_in_predicate(
mocker: MockerFixture, sql: str, table: str, rls: str, expected: str
) -> None:
"""
Expand Down Expand Up @@ -1425,7 +1571,9 @@ def get_rls_for_table(
statement = sqlparse.parse(sql)[0]
assert (
str(
insert_rls(token_list=statement, database_id=1, default_schema="my_schema")
insert_rls_in_predicate(
token_list=statement, database_id=1, default_schema="my_schema"
)
).strip()
== expected.strip()
)
Expand Down

0 comments on commit 52597ad

Please sign in to comment.