diff --git a/superset/security/manager.py b/superset/security/manager.py index ef0f9c975a18a..2935e1eb98e22 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -79,7 +79,7 @@ if TYPE_CHECKING: from superset.common.query_context import QueryContext from superset.connectors.base.models import BaseDatasource - from superset.connectors.sqla.models import SqlaTable + from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.sql_lab import Query @@ -2083,28 +2083,30 @@ def get_rls_filters(self, table: "BaseDatasource") -> list[SqlaQuery]: ) return query.all() - def get_rls_ids(self, table: "BaseDatasource") -> list[int]: + def get_rls_sorted(self, table: "BaseDatasource") -> list["RowLevelSecurityFilter"]: """ - Retrieves the appropriate row level security filters IDs for the current user - and the passed table. + Retrieves a list RLS filters sorted by ID for + the current user and the passed table. :param table: The table to check against - :returns: A list of IDs + :returns: A list RLS filters """ - ids = [f.id for f in self.get_rls_filters(table)] - ids.sort() # Combinations rather than permutations - return ids + filters = self.get_rls_filters(table) + filters.sort(key=lambda f: f.id) + return filters def get_guest_rls_filters_str(self, table: "BaseDatasource") -> list[str]: return [f.get("clause", "") for f in self.get_guest_rls_filters(table)] def get_rls_cache_key(self, datasource: "BaseDatasource") -> list[str]: - rls_ids = [] + rls_clauses_with_group_key = [] if datasource.is_rls_supported: - rls_ids = self.get_rls_ids(datasource) - rls_str = [str(rls_id) for rls_id in rls_ids] + rls_clauses_with_group_key = [ + f"{f.clause}-{f.group_key or ''}" + for f in self.get_rls_sorted(datasource) + ] guest_rls = self.get_guest_rls_filters_str(datasource) - return guest_rls + rls_str + return guest_rls + rls_clauses_with_group_key @staticmethod def _get_current_epoch_time() -> float: diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index c29ebe9afef03..41ca0d5e798e9 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -305,6 +305,21 @@ def test_rls_filter_doesnt_alter_admin_birth_names_query(self): assert not self.NAMES_Q_REGEX.search(sql) assert not self.BASE_FILTER_REGEX.search(sql) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_get_rls_cache_key(self): + g.user = self.get_user(username="admin") + tbl = self.get_table(name="birth_names") + clauses = security_manager.get_rls_cache_key(tbl) + assert clauses == [] + + g.user = self.get_user(username="gamma") + clauses = security_manager.get_rls_cache_key(tbl) + assert clauses == [ + "name like 'A%' or name like 'B%'-name", + "name like 'Q%'-name", + "gender = 'boy'-gender", + ] + class TestRowLevelSecurityCreateAPI(SupersetTestCase): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")