Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Use RLS clause instead of ID for cache key #25229

Merged
26 changes: 14 additions & 12 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
if TYPE_CHECKING:
from superset.common.query_context import QueryContext
from superset.connectors.base.models import BaseDatasource
from superset.connectors.sqla.models import SqlaTable
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.sql_lab import Query
Expand Down Expand Up @@ -2083,28 +2083,30 @@ def get_rls_filters(self, table: "BaseDatasource") -> list[SqlaQuery]:
)
return query.all()

def get_rls_ids(self, table: "BaseDatasource") -> list[int]:
def get_rls_sorted(self, table: "BaseDatasource") -> list["RowLevelSecurityFilter"]:
"""
Retrieves the appropriate row level security filters IDs for the current user
and the passed table.
Retrieves a list RLS filters sorted by ID for
the current user and the passed table.

:param table: The table to check against
:returns: A list of IDs
:returns: A list RLS filters
"""
ids = [f.id for f in self.get_rls_filters(table)]
ids.sort() # Combinations rather than permutations
return ids
filters = self.get_rls_filters(table)
filters.sort(key=lambda f: f.id)
return filters

def get_guest_rls_filters_str(self, table: "BaseDatasource") -> list[str]:
return [f.get("clause", "") for f in self.get_guest_rls_filters(table)]

def get_rls_cache_key(self, datasource: "BaseDatasource") -> list[str]:
rls_ids = []
rls_clauses_with_group_key = []
if datasource.is_rls_supported:
rls_ids = self.get_rls_ids(datasource)
rls_str = [str(rls_id) for rls_id in rls_ids]
rls_clauses_with_group_key = [
f"{f.clause}-{f.group_key or ''}"
for f in self.get_rls_sorted(datasource)
]
guest_rls = self.get_guest_rls_filters_str(datasource)
return guest_rls + rls_str
return guest_rls + rls_clauses_with_group_key

@staticmethod
def _get_current_epoch_time() -> float:
Expand Down
15 changes: 15 additions & 0 deletions tests/integration_tests/security/row_level_security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,21 @@ def test_rls_filter_doesnt_alter_admin_birth_names_query(self):
assert not self.NAMES_Q_REGEX.search(sql)
assert not self.BASE_FILTER_REGEX.search(sql)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_rls_cache_key(self):
g.user = self.get_user(username="admin")
tbl = self.get_table(name="birth_names")
clauses = security_manager.get_rls_cache_key(tbl)
assert clauses == []

g.user = self.get_user(username="gamma")
clauses = security_manager.get_rls_cache_key(tbl)
assert clauses == [
"name like 'A%' or name like 'B%'-name",
"name like 'Q%'-name",
"gender = 'boy'-gender",
]


class TestRowLevelSecurityCreateAPI(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
Expand Down
Loading