From 20dd9e5109aa90336abaa5378e2107d3d22ffa4a Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 16 Sep 2020 10:17:12 +0300 Subject: [PATCH 1/4] feat(row-level-security): add filter type and group key --- superset/app.py | 4 +- superset/connectors/sqla/models.py | 16 +- superset/connectors/sqla/views.py | 43 ++++- ...4e_add_rls_filter_type_and_grouping_key.py | 48 ++++++ superset/security/manager.py | 36 +++- tests/security_tests.py | 158 ++++++++++++------ 6 files changed, 230 insertions(+), 75 deletions(-) create mode 100644 superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py diff --git a/superset/app.py b/superset/app.py index c3731a4a16567..1ef5b30531aaf 100644 --- a/superset/app.py +++ b/superset/app.py @@ -261,8 +261,8 @@ def init_views(self) -> None: if self.config["ENABLE_ROW_LEVEL_SECURITY"]: appbuilder.add_view( RowLevelSecurityFiltersModelView, - "Row Level Security Filters", - label=__("Row level security filters"), + "Row Level Security", + label=__("Row level security"), category="Security", category_label=__("Security"), icon="fa-lock", diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 667c9d4a3ee7b..e70d96377bf84 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -16,7 +16,7 @@ # under the License. import json import logging -from collections import OrderedDict +from collections import defaultdict, OrderedDict from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union @@ -35,6 +35,7 @@ Column, DateTime, desc, + Enum, ForeignKey, Integer, or_, @@ -798,11 +799,14 @@ def _get_sqla_row_level_filters( :returns: A list of SQL clauses to be ANDed together. :rtype: List[str] """ + filters_grouped: Dict[Union[int, str], List[str]] = defaultdict(list) try: - return [ - text("({})".format(template_processor.process_template(f.clause))) - for f in security_manager.get_rls_filters(self) - ] + for filter_ in security_manager.get_rls_filters(self): + clause = text( + f"({template_processor.process_template(filter_.clause)})" + ) + filters_grouped[filter_.group_key or filter_.id].append(clause) + return [or_(*clauses) for clauses in filters_grouped.values()] except TemplateError as ex: raise QueryObjectValidationError( _("Error in jinja expression in RLS filters: %(msg)s", msg=ex.message,) @@ -1506,6 +1510,8 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable): __tablename__ = "row_level_security_filters" id = Column(Integer, primary_key=True) + filter_type = Column(Enum("Regular", "Base")) + group_key = Column(String(255), nullable=True) roles = relationship( security_manager.role_model, secondary=RLSFilterRoles, diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index a1e28e492b5c9..e07878d793882 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -251,20 +251,51 @@ class RowLevelSecurityFiltersModelView( # pylint: disable=too-many-ancestors add_title = _("Add Row level security filter") edit_title = _("Edit Row level security filter") - list_columns = ["tables", "roles", "clause", "creator", "modified"] - order_columns = ["tables", "clause", "modified"] - edit_columns = ["tables", "roles", "clause"] + list_columns = [ + "filter_type", + "tables", + "roles", + "group_key", + "clause", + "creator", + "modified", + ] + order_columns = ["filter_type", "tables", "group_key", "clause", "modified"] + edit_columns = ["filter_type", "tables", "roles", "group_key", "clause"] show_columns = edit_columns - search_columns = ("tables", "roles", "clause") + search_columns = ("filter_type", "tables", "roles", "group_key", "clause") add_columns = edit_columns base_order = ("changed_on", "desc") description_columns = { + "filter_type": _( + "Regular filters add where clauses to queries if a user belongs to a " + "role referenced in the filter. Base filters apply filters to all queries " + "except the roles defined in the filter, and can be used to define what " + "users can see if no RLS filters within a filter group apply to them." + ), "tables": _("These are the tables this filter will be applied to."), - "roles": _("These are the roles this filter will be applied to."), + "roles": _( + "For regular filters, these are the roles this filter will be " + "applied to. For base filters, these are the roles that the " + "filter DOES NOT apply to, e.g. Admin if admin should see all " + "data." + ), + "group_key": _( + "Filters with the same group key will be ORed together within the group, " + "while different filter groups will be ANDed together. Undefined group " + "keys are treated as unique groups, i.e. are not grouped together. " + "For example, if a table has three filters, of which two are for " + "departments Finance and Marketing (group key = 'department'), and one " + "refers to the region Europe (group key = 'region'), the filter clause " + "would apply the filter (department = 'Finance' OR department = " + "'Marketing') AND (region = 'Europe')." + ), "clause": _( "This is the condition that will be added to the WHERE clause. " "For example, to only return rows for a particular client, " - "you might put in: client_id = 9" + "you might define a regular filter with the clause `client_id = 9`. To " + "display no rows unless a user belongs to a RLS filter role, a base " + "filter can be created with the clause `1 = 0` (always false)." ), } label_columns = { diff --git a/superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py b/superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py new file mode 100644 index 0000000000000..f66d5b9e6ad69 --- /dev/null +++ b/superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""add rls filter type and grouping key + +Revision ID: e5ef6828ac4e +Revises: ae19b4ee3692 +Create Date: 2020-09-15 18:22:40.130985 + +""" + +# revision identifiers, used by Alembic. +revision = "e5ef6828ac4e" +down_revision = "ae19b4ee3692" + +import sqlalchemy as sa +from alembic import op + + +def upgrade(): + with op.batch_alter_table("row_level_security_filters") as batch_op: + batch_op.add_column(sa.Column("filter_type", sa.VARCHAR(255), nullable=True)), + batch_op.add_column(sa.Column("group_key", sa.VARCHAR(255), nullable=True)), + + bind = op.get_bind() + metadata = sa.MetaData(bind=bind) + filters = sa.Table("row_level_security_filters", metadata, autoload=True) + statement = filters.update().values(filter_type="Regular") + bind.execute(statement) + + +def downgrade(): + with op.batch_alter_table("row_level_security_filters") as batch_op: + batch_op.drop_column("filter_type") + batch_op.drop_column("group_key") diff --git a/superset/security/manager.py b/superset/security/manager.py index f5376d8f6d215..a30ec0ae5bddd 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -36,7 +36,7 @@ ViewMenuModelView, ) from flask_appbuilder.widgets import ListWidget -from sqlalchemy import or_ +from sqlalchemy import and_, or_ from sqlalchemy.engine.base import Connection from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import Query as SqlaQuery @@ -62,7 +62,7 @@ class SupersetSecurityListWidget(ListWidget): """ - Redeclaring to avoid circular imports + Redeclaring to avoid circular imports """ template = "superset/fab_overrides/list.html" @@ -70,8 +70,8 @@ class SupersetSecurityListWidget(ListWidget): class SupersetRoleListWidget(ListWidget): """ - Role model view from FAB already uses a custom list widget override - So we override the override + Role model view from FAB already uses a custom list widget override + So we override the override """ template = "superset/fab_overrides/list_role.html" @@ -1012,8 +1012,17 @@ def get_rls_filters( # pylint: disable=no-self-use .filter(assoc_user_role.c.user_id == g.user.id) .subquery() ) - filter_roles = ( + regular_filter_roles = ( self.get_session.query(RLSFilterRoles.c.rls_filter_id) + .join(RowLevelSecurityFilter) + .filter(RowLevelSecurityFilter.filter_type == "Regular") + .filter(RLSFilterRoles.c.role_id.in_(user_roles)) + .subquery() + ) + base_filter_roles = ( + self.get_session.query(RLSFilterRoles.c.rls_filter_id) + .join(RowLevelSecurityFilter) + .filter(RowLevelSecurityFilter.filter_type == "Base") .filter(RLSFilterRoles.c.role_id.in_(user_roles)) .subquery() ) @@ -1024,10 +1033,23 @@ def get_rls_filters( # pylint: disable=no-self-use ) query = ( self.get_session.query( - RowLevelSecurityFilter.id, RowLevelSecurityFilter.clause + RowLevelSecurityFilter.id, + RowLevelSecurityFilter.group_key, + RowLevelSecurityFilter.clause, ) .filter(RowLevelSecurityFilter.id.in_(filter_tables)) - .filter(RowLevelSecurityFilter.id.in_(filter_roles)) + .filter( + or_( + and_( + RowLevelSecurityFilter.filter_type == "Regular", + RowLevelSecurityFilter.id.in_(regular_filter_roles), + ), + and_( + RowLevelSecurityFilter.filter_type == "Base", + RowLevelSecurityFilter.id.notin_(base_filter_roles), + ), + ) + ) ) return query.all() return [] diff --git a/tests/security_tests.py b/tests/security_tests.py index 50548fdeb4071..11a6b215d43f3 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -16,6 +16,7 @@ # under the License. # isort:skip_file import inspect +import re import unittest from unittest.mock import Mock, patch @@ -1009,70 +1010,98 @@ class TestRowLevelSecurity(SupersetTestCase): """ rls_entry = None + query_obj = dict( + groupby=[], + metrics=[], + filter=[], + is_timeseries=False, + columns=["value"], + granularity=None, + from_dttm=None, + to_dttm=None, + extras={}, + ) def setUp(self): session = db.session - # Create the RowLevelSecurityFilter - self.rls_entry = RowLevelSecurityFilter() - self.rls_entry.tables.extend( + # Create regular RowLevelSecurityFilter (energy_usage, unicode_test) + self.rls_entry1 = RowLevelSecurityFilter() + self.rls_entry1.tables.extend( session.query(SqlaTable) .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"])) .all() ) - self.rls_entry.clause = "value > {{ cache_key_wrapper(1) }}" - self.rls_entry.roles.append( - security_manager.find_role("Gamma") - ) # db.session.query(Role).filter_by(name="Gamma").first()) - self.rls_entry.roles.append(security_manager.find_role("Alpha")) - db.session.add(self.rls_entry) + self.rls_entry1.filter_type = "Regular" + self.rls_entry1.clause = "value > {{ cache_key_wrapper(1) }}" + self.rls_entry1.group_key = None + self.rls_entry1.roles.append(security_manager.find_role("Gamma")) + self.rls_entry1.roles.append(security_manager.find_role("Alpha")) + db.session.add(self.rls_entry1) + + # Create regular RowLevelSecurityFilter (birth_names name starts with A or B) + self.rls_entry2 = RowLevelSecurityFilter() + self.rls_entry2.tables.extend( + session.query(SqlaTable) + .filter(SqlaTable.table_name.in_(["birth_names"])) + .all() + ) + self.rls_entry2.filter_type = "Regular" + self.rls_entry2.clause = "name like 'A%' or name like 'B%'" + self.rls_entry2.group_key = "name" + self.rls_entry2.roles.append(security_manager.find_role("Gamma")) + db.session.add(self.rls_entry2) + + # Create Regular RowLevelSecurityFilter (birth_names name starts with Q) + self.rls_entry3 = RowLevelSecurityFilter() + self.rls_entry3.tables.extend( + session.query(SqlaTable) + .filter(SqlaTable.table_name.in_(["birth_names"])) + .all() + ) + self.rls_entry3.filter_type = "Regular" + self.rls_entry3.clause = "name like 'Q%'" + self.rls_entry3.group_key = "name" + self.rls_entry3.roles.append(security_manager.find_role("Gamma")) + db.session.add(self.rls_entry3) + + # Create Base RowLevelSecurityFilter (birth_names boys) + self.rls_entry4 = RowLevelSecurityFilter() + self.rls_entry4.tables.extend( + session.query(SqlaTable) + .filter(SqlaTable.table_name.in_(["birth_names"])) + .all() + ) + self.rls_entry4.filter_type = "Base" + self.rls_entry4.clause = "gender = 'boy'" + self.rls_entry4.group_key = "gender" + self.rls_entry4.roles.append(security_manager.find_role("Admin")) + db.session.add(self.rls_entry4) db.session.commit() def tearDown(self): session = db.session - session.delete(self.rls_entry) + session.delete(self.rls_entry1) + session.delete(self.rls_entry2) + session.delete(self.rls_entry3) + session.delete(self.rls_entry4) session.commit() - # Do another test to make sure it doesn't alter another query - def test_rls_filter_alters_query(self): - g.user = self.get_user( - username="alpha" - ) # self.login() doesn't actually set the user + def test_rls_filter_alters_energy_query(self): + g.user = self.get_user(username="alpha") tbl = self.get_table_by_name("energy_usage") - query_obj = dict( - groupby=[], - metrics=[], - filter=[], - is_timeseries=False, - columns=["value"], - granularity=None, - from_dttm=None, - to_dttm=None, - extras={}, - ) - sql = tbl.get_query_str(query_obj) - assert tbl.get_extra_cache_keys(query_obj) == [1] + sql = tbl.get_query_str(self.query_obj) + assert tbl.get_extra_cache_keys(self.query_obj) == [1] assert "value > 1" in sql - def test_rls_filter_doesnt_alter_query(self): + def test_rls_filter_doesnt_alter_energy_query(self): g.user = self.get_user( username="admin" ) # self.login() doesn't actually set the user tbl = self.get_table_by_name("energy_usage") - query_obj = dict( - groupby=[], - metrics=[], - filter=[], - is_timeseries=False, - columns=["value"], - granularity=None, - from_dttm=None, - to_dttm=None, - extras={}, - ) - sql = tbl.get_query_str(query_obj) - assert tbl.get_extra_cache_keys(query_obj) == [] + sql = tbl.get_query_str(self.query_obj) + assert tbl.get_extra_cache_keys(self.query_obj) == [] assert "value > 1" not in sql def test_multiple_table_filter_alters_another_tables_query(self): @@ -1080,17 +1109,36 @@ def test_multiple_table_filter_alters_another_tables_query(self): username="alpha" ) # self.login() doesn't actually set the user tbl = self.get_table_by_name("unicode_test") - query_obj = dict( - groupby=[], - metrics=[], - filter=[], - is_timeseries=False, - columns=["value"], - granularity=None, - from_dttm=None, - to_dttm=None, - extras={}, - ) - sql = tbl.get_query_str(query_obj) - assert tbl.get_extra_cache_keys(query_obj) == [1] + sql = tbl.get_query_str(self.query_obj) + assert tbl.get_extra_cache_keys(self.query_obj) == [1] assert "value > 1" in sql + + def test_rls_filter_alters_gamma_birth_names_query(self): + g.user = self.get_user(username="gamma") + tbl = self.get_table_by_name("birth_names") + sql = tbl.get_query_str(self.query_obj) + + # establish that groupings are properly applied + assert re.search( + r"\(\s*\(\s*name\s+like\s+'A%'\s+or\s+name\s+like\s+'B%'\s*\)\s+" + r"OR\s+\(\s*name\s+like\s+'Q%'\)\s*\)\s+AND\s+\(gender\s+=\s+'boy'\);", + sql, + re.IGNORECASE, + ) + + def test_rls_filter_alters_alpha_birth_names_query(self): + g.user = self.get_user(username="alpha") + tbl = self.get_table_by_name("birth_names") + sql = tbl.get_query_str(self.query_obj) + assert "name like 'A%' or name like 'B%'" not in sql + assert "name like 'Q%'" not in sql + # base query should be present + assert "gender = 'boy'" in sql + + def test_rls_filter_doesnt_alter_admin_birth_names_query(self): + g.user = self.get_user(username="admin") + tbl = self.get_table_by_name("birth_names") + sql = tbl.get_query_str(self.query_obj) + assert "name like 'A%' or name like 'B%'" not in sql + assert "name like 'Q%'" not in sql + assert "gender = 'boy'" not in sql From f99744e747d5e30e1030b12fadad25c08aa33a74 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Fri, 18 Sep 2020 09:31:21 +0300 Subject: [PATCH 2/4] simplify tests and add custom list widget --- superset/connectors/sqla/views.py | 16 +++- .../templates/superset/models/rls/list.html | 96 +++++++++++++++++++ tests/security_tests.py | 30 +++--- 3 files changed, 128 insertions(+), 14 deletions(-) create mode 100644 superset/templates/superset/models/rls/list.html diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index e07878d793882..1783658da1eb6 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -18,14 +18,15 @@ import logging import re from dataclasses import dataclass, field -from typing import Dict, List, Union +from typing import Any, Dict, List, Union -from flask import flash, Markup, redirect +from flask import current_app, flash, Markup, redirect from flask_appbuilder import CompactCRUDMixin, expose from flask_appbuilder.actions import action from flask_appbuilder.fieldwidgets import Select2Widget from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.security.decorators import has_access +from flask_appbuilder.widgets import ListWidget from flask_babel import gettext as __, lazy_gettext as _ from wtforms.ext.sqlalchemy.fields import QuerySelectField from wtforms.validators import Regexp @@ -241,10 +242,19 @@ class SqlMetricInlineView( # pylint: disable=too-many-ancestors edit_form_extra_fields = add_form_extra_fields +class RowLevelSecurityListWidget(ListWidget): + template = "superset/models/rls/list.html" + + def __init__(self, **kwargs: Any): + kwargs["appbuilder"] = current_app.appbuilder + super().__init__(**kwargs) + + class RowLevelSecurityFiltersModelView( # pylint: disable=too-many-ancestors SupersetModelView, DeleteMixin ): datamodel = SQLAInterface(models.RowLevelSecurityFilter) + list_widget = RowLevelSecurityListWidget list_title = _("Row level security filter") show_title = _("Show Row level security filter") @@ -260,7 +270,7 @@ class RowLevelSecurityFiltersModelView( # pylint: disable=too-many-ancestors "creator", "modified", ] - order_columns = ["filter_type", "tables", "group_key", "clause", "modified"] + order_columns = ["filter_type", "group_key", "clause", "modified"] edit_columns = ["filter_type", "tables", "roles", "group_key", "clause"] show_columns = edit_columns search_columns = ("filter_type", "tables", "roles", "group_key", "clause") diff --git a/superset/templates/superset/models/rls/list.html b/superset/templates/superset/models/rls/list.html new file mode 100644 index 0000000000000..905ee6d305b1a --- /dev/null +++ b/superset/templates/superset/models/rls/list.html @@ -0,0 +1,96 @@ +{# + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +#} +{% extends 'appbuilder/general/widgets/base_list.html' %} +{% import 'appbuilder/general/lib.html' as lib %} + + {% block begin_content scoped %} +
+ + {% endblock %} + + {% block begin_loop_header scoped %} + + + {% if actions %} + + {% endif %} + + {% if can_show or can_edit or can_delete %} + + {% endif %} + + {% for item in include_columns %} + {% if item in order_columns %} + {% set res = item | get_link_order(modelview_name) %} + {% if res == 2 %} + + {% elif res == 1 %} + + {% else %} + + {% endif %} + {% else %} + + {% endif %} + {% endfor %} + + + {% endblock %} + + {% block begin_loop_values %} + {% for item in value_columns %} + {% set pk = pks[loop.index-1] %} + + {% if actions %} + + {% endif %} + {% if can_show or can_edit or can_delete %} + + {% endif %} + {% for value in include_columns %} + + {% endfor %} + + {% endfor %} + {% endblock %} + + {% block end_content scoped %} +
+ + {{label_columns.get(item)}} + {{label_columns.get(item)}} + {{label_columns.get(item)}} + {{label_columns.get(item)}}
+ +
+ {{ lib.btn_crud(can_show, can_edit, can_delete, pk, modelview_name, filters) }} +
+ {% if value == "roles" and item["filter_type"] == "Base" and not item[value] %} + All + {% elif value == "roles" and item["filter_type"] == 'Base' %} + Not {{ item[value] }} + {% elif value == "roles" and item["filter_type"] == 'Regular' and not item[value] %} + None + {% elif value == "group_key" and item[value] == None %} + {% else %} + {{ item[value] }} + {% endif %} +
+
+ {% endblock %} diff --git a/tests/security_tests.py b/tests/security_tests.py index 11a6b215d43f3..c9abd4e09e942 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -1021,6 +1021,8 @@ class TestRowLevelSecurity(SupersetTestCase): to_dttm=None, extras={}, ) + GAMMA_FILTER_REGEX = re.compile(r"'[A,B,Q]%'") + BASE_FILTER_REGEX = re.compile(r"'boy'") def setUp(self): session = db.session @@ -1118,27 +1120,33 @@ def test_rls_filter_alters_gamma_birth_names_query(self): tbl = self.get_table_by_name("birth_names") sql = tbl.get_query_str(self.query_obj) - # establish that groupings are properly applied + # establish that both regular and base filters are present + assert self.GAMMA_FILTER_REGEX.search(sql) + assert self.BASE_FILTER_REGEX.search(sql) + + # establish that they are grouped together correctly with ANDs, ORs + # and parens in the correct place (only look for unique bits in the + # filters to make the regex simpler) assert re.search( - r"\(\s*\(\s*name\s+like\s+'A%'\s+or\s+name\s+like\s+'B%'\s*\)\s+" - r"OR\s+\(\s*name\s+like\s+'Q%'\)\s*\)\s+AND\s+\(gender\s+=\s+'boy'\);", - sql, - re.IGNORECASE, + r"\(\s*\(.*'A%'.*\).*OR.*'Q%'\s*\)\s*\)\s+AND\s+\(.*'boy'\s*\)", + sql.replace("\n", " "), # remove newlines to make simpler regex ) def test_rls_filter_alters_alpha_birth_names_query(self): g.user = self.get_user(username="alpha") tbl = self.get_table_by_name("birth_names") sql = tbl.get_query_str(self.query_obj) - assert "name like 'A%' or name like 'B%'" not in sql - assert "name like 'Q%'" not in sql + + # gamma's filters should not be present query + assert not self.GAMMA_FILTER_REGEX.search(sql) # base query should be present - assert "gender = 'boy'" in sql + assert self.BASE_FILTER_REGEX.search(sql) def test_rls_filter_doesnt_alter_admin_birth_names_query(self): g.user = self.get_user(username="admin") tbl = self.get_table_by_name("birth_names") sql = tbl.get_query_str(self.query_obj) - assert "name like 'A%' or name like 'B%'" not in sql - assert "name like 'Q%'" not in sql - assert "gender = 'boy'" not in sql + + # no filters are applied for admin user + assert not self.GAMMA_FILTER_REGEX.search(sql) + assert not self.BASE_FILTER_REGEX.search(sql) From 9526ef289542b8e448a3351c74818041216b88ac Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Mon, 21 Sep 2020 10:53:59 +0300 Subject: [PATCH 3/4] address comments --- superset/connectors/sqla/models.py | 2 +- superset/connectors/sqla/views.py | 11 ++-- ...4e_add_rls_filter_type_and_grouping_key.py | 12 ++++- superset/security/manager.py | 18 +++++-- superset/utils/core.py | 5 ++ tests/security_tests.py | 51 ++++++++++++------- 6 files changed, 70 insertions(+), 29 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index e70d96377bf84..293d07110e2d4 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1510,7 +1510,7 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable): __tablename__ = "row_level_security_filters" id = Column(Integer, primary_key=True) - filter_type = Column(Enum("Regular", "Base")) + filter_type = Column(Enum(utils.RowLevelSecurityFilterType)) group_key = Column(String(255), nullable=True) roles = relationship( security_manager.role_model, diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index 1783658da1eb6..4483018fdf7a2 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -18,7 +18,7 @@ import logging import re from dataclasses import dataclass, field -from typing import Any, Dict, List, Union +from typing import Any, cast, Dict, List, Union from flask import current_app, flash, Markup, redirect from flask_appbuilder import CompactCRUDMixin, expose @@ -26,7 +26,6 @@ from flask_appbuilder.fieldwidgets import Select2Widget from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.security.decorators import has_access -from flask_appbuilder.widgets import ListWidget from flask_babel import gettext as __, lazy_gettext as _ from wtforms.ext.sqlalchemy.fields import QuerySelectField from wtforms.validators import Regexp @@ -42,6 +41,7 @@ DatasourceFilter, DeleteMixin, ListWidgetWithCheckboxes, + SupersetListWidget, SupersetModelView, validate_sqlatable, YamlExportMixin, @@ -242,7 +242,9 @@ class SqlMetricInlineView( # pylint: disable=too-many-ancestors edit_form_extra_fields = add_form_extra_fields -class RowLevelSecurityListWidget(ListWidget): +class RowLevelSecurityListWidget( + SupersetListWidget +): # pylint: disable=too-few-public-methods template = "superset/models/rls/list.html" def __init__(self, **kwargs: Any): @@ -254,7 +256,8 @@ class RowLevelSecurityFiltersModelView( # pylint: disable=too-many-ancestors SupersetModelView, DeleteMixin ): datamodel = SQLAInterface(models.RowLevelSecurityFilter) - list_widget = RowLevelSecurityListWidget + + list_widget = cast(SupersetListWidget, RowLevelSecurityListWidget) list_title = _("Row level security filter") show_title = _("Show Row level security filter") diff --git a/superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py b/superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py index f66d5b9e6ad69..01fcf60e93357 100644 --- a/superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py +++ b/superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py @@ -29,20 +29,30 @@ import sqlalchemy as sa from alembic import op +from superset.utils import core as utils + def upgrade(): with op.batch_alter_table("row_level_security_filters") as batch_op: batch_op.add_column(sa.Column("filter_type", sa.VARCHAR(255), nullable=True)), batch_op.add_column(sa.Column("group_key", sa.VARCHAR(255), nullable=True)), + batch_op.create_index( + op.f("ix_row_level_security_filters_filter_type"), + ["filter_type"], + unique=False, + ) bind = op.get_bind() metadata = sa.MetaData(bind=bind) filters = sa.Table("row_level_security_filters", metadata, autoload=True) - statement = filters.update().values(filter_type="Regular") + statement = filters.update().values( + filter_type=utils.RowLevelSecurityFilterType.REGULAR.value + ) bind.execute(statement) def downgrade(): with op.batch_alter_table("row_level_security_filters") as batch_op: + batch_op.drop_index(op.f("ix_row_level_security_filters_filter_type"),) batch_op.drop_column("filter_type") batch_op.drop_column("group_key") diff --git a/superset/security/manager.py b/superset/security/manager.py index a30ec0ae5bddd..858ecdc157e9e 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -46,7 +46,7 @@ from superset.constants import RouteMethod from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetSecurityException -from superset.utils.core import DatasourceName +from superset.utils.core import DatasourceName, RowLevelSecurityFilterType if TYPE_CHECKING: from superset.common.query_context import QueryContext @@ -1015,14 +1015,20 @@ def get_rls_filters( # pylint: disable=no-self-use regular_filter_roles = ( self.get_session.query(RLSFilterRoles.c.rls_filter_id) .join(RowLevelSecurityFilter) - .filter(RowLevelSecurityFilter.filter_type == "Regular") + .filter( + RowLevelSecurityFilter.filter_type + == RowLevelSecurityFilterType.REGULAR + ) .filter(RLSFilterRoles.c.role_id.in_(user_roles)) .subquery() ) base_filter_roles = ( self.get_session.query(RLSFilterRoles.c.rls_filter_id) .join(RowLevelSecurityFilter) - .filter(RowLevelSecurityFilter.filter_type == "Base") + .filter( + RowLevelSecurityFilter.filter_type + == RowLevelSecurityFilterType.BASE + ) .filter(RLSFilterRoles.c.role_id.in_(user_roles)) .subquery() ) @@ -1041,11 +1047,13 @@ def get_rls_filters( # pylint: disable=no-self-use .filter( or_( and_( - RowLevelSecurityFilter.filter_type == "Regular", + RowLevelSecurityFilter.filter_type + == RowLevelSecurityFilterType.REGULAR, RowLevelSecurityFilter.id.in_(regular_filter_roles), ), and_( - RowLevelSecurityFilter.filter_type == "Base", + RowLevelSecurityFilter.filter_type + == RowLevelSecurityFilterType.BASE, RowLevelSecurityFilter.id.notin_(base_filter_roles), ), ) diff --git a/superset/utils/core.py b/superset/utils/core.py index 86d7e9f7af6ca..bfd9742c846ed 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1547,3 +1547,8 @@ class PostProcessingContributionOrientation(str, Enum): class AdhocMetricExpressionType(str, Enum): SIMPLE = "SIMPLE" SQL = "SQL" + + +class RowLevelSecurityFilterType(str, Enum): + REGULAR = "Regular" + BASE = "Base" diff --git a/tests/security_tests.py b/tests/security_tests.py index c9abd4e09e942..33136e8226787 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -1021,12 +1021,25 @@ class TestRowLevelSecurity(SupersetTestCase): to_dttm=None, extras={}, ) - GAMMA_FILTER_REGEX = re.compile(r"'[A,B,Q]%'") - BASE_FILTER_REGEX = re.compile(r"'boy'") + NAME_AB_ROLE = "NameAB" + NAME_Q_ROLE = "NameQ" + NAMES_A_REGEX = re.compile(r"name like 'A%'") + NAMES_B_REGEX = re.compile(r"name like 'B%'") + NAMES_Q_REGEX = re.compile(r"name like 'Q%'") + BASE_FILTER_REGEX = re.compile(r"gender = 'boy'") def setUp(self): session = db.session + # Create roles + security_manager.add_role(self.NAME_AB_ROLE) + security_manager.add_role(self.NAME_Q_ROLE) + gamma_user = security_manager.find_user(username="gamma") + gamma_user.roles.append(security_manager.find_role(self.NAME_AB_ROLE)) + gamma_user.roles.append(security_manager.find_role(self.NAME_Q_ROLE)) + self.create_user_with_roles("NoRlsRoleUser", ["Gamma"]) + session.commit() + # Create regular RowLevelSecurityFilter (energy_usage, unicode_test) self.rls_entry1 = RowLevelSecurityFilter() self.rls_entry1.tables.extend( @@ -1051,7 +1064,7 @@ def setUp(self): self.rls_entry2.filter_type = "Regular" self.rls_entry2.clause = "name like 'A%' or name like 'B%'" self.rls_entry2.group_key = "name" - self.rls_entry2.roles.append(security_manager.find_role("Gamma")) + self.rls_entry2.roles.append(security_manager.find_role("NameAB")) db.session.add(self.rls_entry2) # Create Regular RowLevelSecurityFilter (birth_names name starts with Q) @@ -1064,7 +1077,7 @@ def setUp(self): self.rls_entry3.filter_type = "Regular" self.rls_entry3.clause = "name like 'Q%'" self.rls_entry3.group_key = "name" - self.rls_entry3.roles.append(security_manager.find_role("Gamma")) + self.rls_entry3.roles.append(security_manager.find_role("NameQ")) db.session.add(self.rls_entry3) # Create Base RowLevelSecurityFilter (birth_names boys) @@ -1088,6 +1101,9 @@ def tearDown(self): session.delete(self.rls_entry2) session.delete(self.rls_entry3) session.delete(self.rls_entry4) + session.delete(security_manager.find_role("NameAB")) + session.delete(security_manager.find_role("NameQ")) + session.delete(self.get_user("NoRlsRoleUser")) session.commit() def test_rls_filter_alters_energy_query(self): @@ -1120,25 +1136,22 @@ def test_rls_filter_alters_gamma_birth_names_query(self): tbl = self.get_table_by_name("birth_names") sql = tbl.get_query_str(self.query_obj) - # establish that both regular and base filters are present - assert self.GAMMA_FILTER_REGEX.search(sql) - assert self.BASE_FILTER_REGEX.search(sql) - - # establish that they are grouped together correctly with ANDs, ORs - # and parens in the correct place (only look for unique bits in the - # filters to make the regex simpler) - assert re.search( - r"\(\s*\(.*'A%'.*\).*OR.*'Q%'\s*\)\s*\)\s+AND\s+\(.*'boy'\s*\)", - sql.replace("\n", " "), # remove newlines to make simpler regex + # establish that the filters are grouped together correctly with + # ANDs, ORs and parens in the correct place + assert ( + "WHERE ((name like 'A%'\n or name like 'B%')\n OR (name like 'Q%'))\n AND (gender = 'boy');" + in sql ) - def test_rls_filter_alters_alpha_birth_names_query(self): - g.user = self.get_user(username="alpha") + def test_rls_filter_alters_no_role_user_birth_names_query(self): + g.user = self.get_user(username="NoRlsRoleUser") tbl = self.get_table_by_name("birth_names") sql = tbl.get_query_str(self.query_obj) # gamma's filters should not be present query - assert not self.GAMMA_FILTER_REGEX.search(sql) + assert not self.NAMES_A_REGEX.search(sql) + assert not self.NAMES_B_REGEX.search(sql) + assert not self.NAMES_Q_REGEX.search(sql) # base query should be present assert self.BASE_FILTER_REGEX.search(sql) @@ -1148,5 +1161,7 @@ def test_rls_filter_doesnt_alter_admin_birth_names_query(self): sql = tbl.get_query_str(self.query_obj) # no filters are applied for admin user - assert not self.GAMMA_FILTER_REGEX.search(sql) + assert not self.NAMES_A_REGEX.search(sql) + assert not self.NAMES_B_REGEX.search(sql) + assert not self.NAMES_Q_REGEX.search(sql) assert not self.BASE_FILTER_REGEX.search(sql) From 8aced2a86aae4f3a1628e90b317c5c78eb64c1ea Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Mon, 21 Sep 2020 19:49:41 +0300 Subject: [PATCH 4/4] use enum value to ensure case sensitive value is used --- superset/connectors/sqla/models.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 293d07110e2d4..49663d21421ab 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -93,8 +93,8 @@ class MetadataResult: class AnnotationDatasource(BaseDatasource): - """ Dummy object so we can query annotations using 'Viz' objects just like - regular datasources. + """Dummy object so we can query annotations using 'Viz' objects just like + regular datasources. """ cache_timeout = 0 @@ -1375,9 +1375,9 @@ def import_obj( ) -> int: """Imports the datasource from the object to the database. - Metrics and columns and datasource will be overrided if exists. - This function can be used to import/export dashboards between multiple - superset instances. Audit metadata isn't copies over. + Metrics and columns and datasource will be overrided if exists. + This function can be used to import/export dashboards between multiple + superset instances. Audit metadata isn't copies over. """ def lookup_sqlatable(table_: "SqlaTable") -> "SqlaTable": @@ -1510,7 +1510,9 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable): __tablename__ = "row_level_security_filters" id = Column(Integer, primary_key=True) - filter_type = Column(Enum(utils.RowLevelSecurityFilterType)) + filter_type = Column( + Enum(*[filter_type.value for filter_type in utils.RowLevelSecurityFilterType]) + ) group_key = Column(String(255), nullable=True) roles = relationship( security_manager.role_model,