Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Use RLS clause instead of ID for cache key #25229

Merged
20 changes: 10 additions & 10 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2083,26 +2083,26 @@ def get_rls_filters(self, table: "BaseDatasource") -> list[SqlaQuery]:
)
return query.all()

def get_rls_ids(self, table: "BaseDatasource") -> list[int]:
def get_rls_filters_str(self, table: "BaseDatasource") -> list[str]:
"""
Retrieves the appropriate row level security filters IDs for the current user
and the passed table.
Retrieves the appropriate row level security filters string representation
(id concat'ed with clause) for the current user and the passed table.

:param table: The table to check against
:returns: A list of IDs
:returns: A list of string representations of the user's RLS filters
"""
ids = [f.id for f in self.get_rls_filters(table)]
ids.sort() # Combinations rather than permutations
return ids
filters = self.get_rls_filters(table)
filters.sort(key=lambda f: f.id) # Combinations rather than permutations
str_reps = [f"{f.id}-{f.clause}" for f in filters]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably a totally unnecessary perf optimization, but do we need the ids here? I think just using the clause and sorting them should be sufficient, and would avoid a cache miss if two filters are swapped (= id stays the same, but the clauses are interchanged).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I included the ID here since I was thinking about an edge case where there could be two different RLS with the same clause but different type (regular/base) and we'd want to avoid a false positive cache hit.

However I just realized I was thinking about this wrong because the filter type only affects who the filter applies to, not how the clause is applied. I'll remove the id and just sort by clause

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, but now I've realized there's the group key which actually does affect how the clause is applied, so I think that does actually need to be considered for the cache key

return str_reps

def get_guest_rls_filters_str(self, table: "BaseDatasource") -> list[str]:
    """
    Collects the clause of every guest RLS filter that applies to the table.

    :param table: The table to check against
    :returns: One string per filter returned by ``get_guest_rls_filters``;
        a filter without a ``"clause"`` key contributes an empty string
    """
    guest_filters = self.get_guest_rls_filters(table)
    return [rule.get("clause", "") for rule in guest_filters]

def get_rls_cache_key(self, datasource: "BaseDatasource") -> list[str]:
rls_ids = []
rls_str = []
if datasource.is_rls_supported:
rls_ids = self.get_rls_ids(datasource)
rls_str = [str(rls_id) for rls_id in rls_ids]
rls_str = self.get_rls_filters_str(datasource)
guest_rls = self.get_guest_rls_filters_str(datasource)
return guest_rls + rls_str

Expand Down
11 changes: 11 additions & 0 deletions tests/integration_tests/security/row_level_security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,17 @@ def test_rls_filter_doesnt_alter_admin_birth_names_query(self):
assert not self.NAMES_Q_REGEX.search(sql)
assert not self.BASE_FILTER_REGEX.search(sql)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_rls_filter_clauses(self):
    """
    Checks the string representations returned by ``get_rls_filters_str``
    for the ``birth_names`` table: none for the admin user, and one
    ``<id>-<clause>`` entry per expected filter for the gamma user.
    """
    # The security manager reads the acting user from ``g.user``.
    g.user = self.get_user(username="admin")
    birth_names = self.get_table(name="birth_names")
    admin_filters = security_manager.get_rls_filters_str(birth_names)
    assert admin_filters == []

    g.user = self.get_user(username="gamma")
    gamma_filters = security_manager.get_rls_filters_str(birth_names)
    assert gamma_filters == [
        f"{self.rls_entry2.id}-name like 'A%' or name like 'B%'",
        f"{self.rls_entry3.id}-name like 'Q%'",
        f"{self.rls_entry4.id}-gender = 'boy'",
    ]


class TestRowLevelSecurityCreateAPI(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
Expand Down
Loading