diff --git a/superset/app.py b/superset/app.py index c3731a4a16567..1ef5b30531aaf 100644 --- a/superset/app.py +++ b/superset/app.py @@ -261,8 +261,8 @@ def init_views(self) -> None: if self.config["ENABLE_ROW_LEVEL_SECURITY"]: appbuilder.add_view( RowLevelSecurityFiltersModelView, - "Row Level Security Filters", - label=__("Row level security filters"), + "Row Level Security", + label=__("Row level security"), category="Security", category_label=__("Security"), icon="fa-lock", diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 667c9d4a3ee7b..49663d21421ab 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -16,7 +16,7 @@ # under the License. import json import logging -from collections import OrderedDict +from collections import defaultdict, OrderedDict from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union @@ -35,6 +35,7 @@ Column, DateTime, desc, + Enum, ForeignKey, Integer, or_, @@ -92,8 +93,8 @@ class MetadataResult: class AnnotationDatasource(BaseDatasource): - """ Dummy object so we can query annotations using 'Viz' objects just like - regular datasources. + """Dummy object so we can query annotations using 'Viz' objects just like + regular datasources. """ cache_timeout = 0 @@ -798,11 +799,14 @@ def _get_sqla_row_level_filters( :returns: A list of SQL clauses to be ANDed together. :rtype: List[str] """ + filters_grouped: Dict[Union[int, str], List[str]] = defaultdict(list) try: - return [ - text("({})".format(template_processor.process_template(f.clause))) - for f in security_manager.get_rls_filters(self) - ] + for filter_ in security_manager.get_rls_filters(self): + clause = text( + f"({template_processor.process_template(filter_.clause)})" + ) + filters_grouped[filter_.group_key or filter_.id].append(clause) + return [or_(*clauses) for clauses in filters_grouped.values()] except TemplateError as ex: raise QueryObjectValidationError( _("Error in jinja expression in RLS filters: %(msg)s", msg=ex.message,) @@ -1371,9 +1375,9 @@ def import_obj( ) -> int: """Imports the datasource from the object to the database. - Metrics and columns and datasource will be overrided if exists. - This function can be used to import/export dashboards between multiple - superset instances. Audit metadata isn't copies over. + Metrics and columns and datasource will be overrided if exists. + This function can be used to import/export dashboards between multiple + superset instances. Audit metadata isn't copies over. """ def lookup_sqlatable(table_: "SqlaTable") -> "SqlaTable": @@ -1506,6 +1510,10 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable): __tablename__ = "row_level_security_filters" id = Column(Integer, primary_key=True) + filter_type = Column( + Enum(*[filter_type.value for filter_type in utils.RowLevelSecurityFilterType]) + ) + group_key = Column(String(255), nullable=True) roles = relationship( security_manager.role_model, secondary=RLSFilterRoles, diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index a1e28e492b5c9..4483018fdf7a2 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -18,9 +18,9 @@ import logging import re from dataclasses import dataclass, field -from typing import Dict, List, Union +from typing import Any, cast, Dict, List, Union -from flask import flash, Markup, redirect +from flask import current_app, flash, Markup, redirect from flask_appbuilder import CompactCRUDMixin, expose from flask_appbuilder.actions import action from flask_appbuilder.fieldwidgets import Select2Widget @@ -41,6 +41,7 @@ DatasourceFilter, DeleteMixin, ListWidgetWithCheckboxes, + SupersetListWidget, SupersetModelView, validate_sqlatable, YamlExportMixin, @@ -241,30 +242,73 @@ class SqlMetricInlineView( # pylint: disable=too-many-ancestors edit_form_extra_fields = add_form_extra_fields +class RowLevelSecurityListWidget( + SupersetListWidget +): # pylint: disable=too-few-public-methods + template = "superset/models/rls/list.html" + + def __init__(self, **kwargs: Any): + kwargs["appbuilder"] = current_app.appbuilder + super().__init__(**kwargs) + + class RowLevelSecurityFiltersModelView( # pylint: disable=too-many-ancestors SupersetModelView, DeleteMixin ): datamodel = SQLAInterface(models.RowLevelSecurityFilter) + list_widget = cast(SupersetListWidget, RowLevelSecurityListWidget) + list_title = _("Row level security filter") show_title = _("Show Row level security filter") add_title = _("Add Row level security filter") edit_title = _("Edit Row level security filter") - list_columns = ["tables", "roles", "clause", "creator", "modified"] - order_columns = ["tables", "clause", "modified"] - edit_columns = ["tables", "roles", "clause"] + list_columns = [ + "filter_type", + "tables", + "roles", + "group_key", + "clause", + "creator", + "modified", + ] + order_columns = ["filter_type", "group_key", "clause", "modified"] + edit_columns = ["filter_type", "tables", "roles", "group_key", "clause"] show_columns = edit_columns - search_columns = ("tables", "roles", "clause") + search_columns = ("filter_type", "tables", "roles", "group_key", "clause") add_columns = edit_columns base_order = ("changed_on", "desc") description_columns = { + "filter_type": _( + "Regular filters add where clauses to queries if a user belongs to a " + "role referenced in the filter. Base filters apply filters to all queries " + "except the roles defined in the filter, and can be used to define what " + "users can see if no RLS filters within a filter group apply to them." + ), "tables": _("These are the tables this filter will be applied to."), - "roles": _("These are the roles this filter will be applied to."), + "roles": _( + "For regular filters, these are the roles this filter will be " + "applied to. For base filters, these are the roles that the " + "filter DOES NOT apply to, e.g. Admin if admin should see all " + "data." + ), + "group_key": _( + "Filters with the same group key will be ORed together within the group, " + "while different filter groups will be ANDed together. Undefined group " + "keys are treated as unique groups, i.e. are not grouped together. " + "For example, if a table has three filters, of which two are for " + "departments Finance and Marketing (group key = 'department'), and one " + "refers to the region Europe (group key = 'region'), the filter clause " + "would apply the filter (department = 'Finance' OR department = " + "'Marketing') AND (region = 'Europe')." + ), "clause": _( "This is the condition that will be added to the WHERE clause. " "For example, to only return rows for a particular client, " - "you might put in: client_id = 9" + "you might define a regular filter with the clause `client_id = 9`. To " + "display no rows unless a user belongs to a RLS filter role, a base " + "filter can be created with the clause `1 = 0` (always false)." ), } label_columns = { diff --git a/superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py b/superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py new file mode 100644 index 0000000000000..01fcf60e93357 --- /dev/null +++ b/superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""add rls filter type and grouping key + +Revision ID: e5ef6828ac4e +Revises: ae19b4ee3692 +Create Date: 2020-09-15 18:22:40.130985 + +""" + +# revision identifiers, used by Alembic. +revision = "e5ef6828ac4e" +down_revision = "ae19b4ee3692" + +import sqlalchemy as sa +from alembic import op + +from superset.utils import core as utils + + +def upgrade(): + with op.batch_alter_table("row_level_security_filters") as batch_op: + batch_op.add_column(sa.Column("filter_type", sa.VARCHAR(255), nullable=True)), + batch_op.add_column(sa.Column("group_key", sa.VARCHAR(255), nullable=True)), + batch_op.create_index( + op.f("ix_row_level_security_filters_filter_type"), + ["filter_type"], + unique=False, + ) + + bind = op.get_bind() + metadata = sa.MetaData(bind=bind) + filters = sa.Table("row_level_security_filters", metadata, autoload=True) + statement = filters.update().values( + filter_type=utils.RowLevelSecurityFilterType.REGULAR.value + ) + bind.execute(statement) + + +def downgrade(): + with op.batch_alter_table("row_level_security_filters") as batch_op: + batch_op.drop_index(op.f("ix_row_level_security_filters_filter_type"),) + batch_op.drop_column("filter_type") + batch_op.drop_column("group_key") diff --git a/superset/security/manager.py b/superset/security/manager.py index f5376d8f6d215..858ecdc157e9e 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -36,7 +36,7 @@ ViewMenuModelView, ) from flask_appbuilder.widgets import ListWidget -from sqlalchemy import or_ +from sqlalchemy import and_, or_ from sqlalchemy.engine.base import Connection from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import Query as SqlaQuery @@ -46,7 +46,7 @@ from superset.constants import RouteMethod from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetSecurityException -from superset.utils.core import DatasourceName +from superset.utils.core import DatasourceName, RowLevelSecurityFilterType if TYPE_CHECKING: from superset.common.query_context import QueryContext @@ -62,7 +62,7 @@ class SupersetSecurityListWidget(ListWidget): """ - Redeclaring to avoid circular imports + Redeclaring to avoid circular imports """ template = "superset/fab_overrides/list.html" @@ -70,8 +70,8 @@ class SupersetSecurityListWidget(ListWidget): class SupersetRoleListWidget(ListWidget): """ - Role model view from FAB already uses a custom list widget override - So we override the override + Role model view from FAB already uses a custom list widget override + So we override the override """ template = "superset/fab_overrides/list_role.html" @@ -1012,8 +1012,23 @@ def get_rls_filters( # pylint: disable=no-self-use .filter(assoc_user_role.c.user_id == g.user.id) .subquery() ) - filter_roles = ( + regular_filter_roles = ( self.get_session.query(RLSFilterRoles.c.rls_filter_id) + .join(RowLevelSecurityFilter) + .filter( + RowLevelSecurityFilter.filter_type + == RowLevelSecurityFilterType.REGULAR + ) + .filter(RLSFilterRoles.c.role_id.in_(user_roles)) + .subquery() + ) + base_filter_roles = ( + self.get_session.query(RLSFilterRoles.c.rls_filter_id) + .join(RowLevelSecurityFilter) + .filter( + RowLevelSecurityFilter.filter_type + == RowLevelSecurityFilterType.BASE + ) .filter(RLSFilterRoles.c.role_id.in_(user_roles)) .subquery() ) @@ -1024,10 +1039,25 @@ def get_rls_filters( # pylint: disable=no-self-use ) query = ( self.get_session.query( - RowLevelSecurityFilter.id, RowLevelSecurityFilter.clause + RowLevelSecurityFilter.id, + RowLevelSecurityFilter.group_key, + RowLevelSecurityFilter.clause, ) .filter(RowLevelSecurityFilter.id.in_(filter_tables)) - .filter(RowLevelSecurityFilter.id.in_(filter_roles)) + .filter( + or_( + and_( + RowLevelSecurityFilter.filter_type + == RowLevelSecurityFilterType.REGULAR, + RowLevelSecurityFilter.id.in_(regular_filter_roles), + ), + and_( + RowLevelSecurityFilter.filter_type + == RowLevelSecurityFilterType.BASE, + RowLevelSecurityFilter.id.notin_(base_filter_roles), + ), + ) + ) ) return query.all() return [] diff --git a/superset/templates/superset/models/rls/list.html b/superset/templates/superset/models/rls/list.html new file mode 100644 index 0000000000000..905ee6d305b1a --- /dev/null +++ b/superset/templates/superset/models/rls/list.html @@ -0,0 +1,96 @@ +{# + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +#} +{% extends 'appbuilder/general/widgets/base_list.html' %} +{% import 'appbuilder/general/lib.html' as lib %} + + {% block begin_content scoped %} +
+ + {% endblock %} + + {% block begin_loop_header scoped %} + + + {% if actions %} + + {% endif %} + + {% if can_show or can_edit or can_delete %} + + {% endif %} + + {% for item in include_columns %} + {% if item in order_columns %} + {% set res = item | get_link_order(modelview_name) %} + {% if res == 2 %} + + {% elif res == 1 %} + + {% else %} + + {% endif %} + {% else %} + + {% endif %} + {% endfor %} + + + {% endblock %} + + {% block begin_loop_values %} + {% for item in value_columns %} + {% set pk = pks[loop.index-1] %} + + {% if actions %} + + {% endif %} + {% if can_show or can_edit or can_delete %} + + {% endif %} + {% for value in include_columns %} + + {% endfor %} + + {% endfor %} + {% endblock %} + + {% block end_content scoped %} +
+ + {{label_columns.get(item)}} + {{label_columns.get(item)}} + {{label_columns.get(item)}} + {{label_columns.get(item)}}
+ +
+ {{ lib.btn_crud(can_show, can_edit, can_delete, pk, modelview_name, filters) }} +
+ {% if value == "roles" and item["filter_type"] == "Base" and not item[value] %} + All + {% elif value == "roles" and item["filter_type"] == 'Base' %} + Not {{ item[value] }} + {% elif value == "roles" and item["filter_type"] == 'Regular' and not item[value] %} + None + {% elif value == "group_key" and item[value] == None %} + {% else %} + {{ item[value] }} + {% endif %} +
+
+ {% endblock %} diff --git a/superset/utils/core.py b/superset/utils/core.py index 86d7e9f7af6ca..bfd9742c846ed 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1547,3 +1547,8 @@ class PostProcessingContributionOrientation(str, Enum): class AdhocMetricExpressionType(str, Enum): SIMPLE = "SIMPLE" SQL = "SQL" + + +class RowLevelSecurityFilterType(str, Enum): + REGULAR = "Regular" + BASE = "Base" diff --git a/tests/security_tests.py b/tests/security_tests.py index 50548fdeb4071..33136e8226787 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -16,6 +16,7 @@ # under the License. # isort:skip_file import inspect +import re import unittest from unittest.mock import Mock, patch @@ -1009,70 +1010,116 @@ class TestRowLevelSecurity(SupersetTestCase): """ rls_entry = None + query_obj = dict( + groupby=[], + metrics=[], + filter=[], + is_timeseries=False, + columns=["value"], + granularity=None, + from_dttm=None, + to_dttm=None, + extras={}, + ) + NAME_AB_ROLE = "NameAB" + NAME_Q_ROLE = "NameQ" + NAMES_A_REGEX = re.compile(r"name like 'A%'") + NAMES_B_REGEX = re.compile(r"name like 'B%'") + NAMES_Q_REGEX = re.compile(r"name like 'Q%'") + BASE_FILTER_REGEX = re.compile(r"gender = 'boy'") def setUp(self): session = db.session - # Create the RowLevelSecurityFilter - self.rls_entry = RowLevelSecurityFilter() - self.rls_entry.tables.extend( + # Create roles + security_manager.add_role(self.NAME_AB_ROLE) + security_manager.add_role(self.NAME_Q_ROLE) + gamma_user = security_manager.find_user(username="gamma") + gamma_user.roles.append(security_manager.find_role(self.NAME_AB_ROLE)) + gamma_user.roles.append(security_manager.find_role(self.NAME_Q_ROLE)) + self.create_user_with_roles("NoRlsRoleUser", ["Gamma"]) + session.commit() + + # Create regular RowLevelSecurityFilter (energy_usage, unicode_test) + self.rls_entry1 = RowLevelSecurityFilter() + self.rls_entry1.tables.extend( session.query(SqlaTable) .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"])) .all() ) - self.rls_entry.clause = "value > {{ cache_key_wrapper(1) }}" - self.rls_entry.roles.append( - security_manager.find_role("Gamma") - ) # db.session.query(Role).filter_by(name="Gamma").first()) - self.rls_entry.roles.append(security_manager.find_role("Alpha")) - db.session.add(self.rls_entry) + self.rls_entry1.filter_type = "Regular" + self.rls_entry1.clause = "value > {{ cache_key_wrapper(1) }}" + self.rls_entry1.group_key = None + self.rls_entry1.roles.append(security_manager.find_role("Gamma")) + self.rls_entry1.roles.append(security_manager.find_role("Alpha")) + db.session.add(self.rls_entry1) + + # Create regular RowLevelSecurityFilter (birth_names name starts with A or B) + self.rls_entry2 = RowLevelSecurityFilter() + self.rls_entry2.tables.extend( + session.query(SqlaTable) + .filter(SqlaTable.table_name.in_(["birth_names"])) + .all() + ) + self.rls_entry2.filter_type = "Regular" + self.rls_entry2.clause = "name like 'A%' or name like 'B%'" + self.rls_entry2.group_key = "name" + self.rls_entry2.roles.append(security_manager.find_role("NameAB")) + db.session.add(self.rls_entry2) + + # Create Regular RowLevelSecurityFilter (birth_names name starts with Q) + self.rls_entry3 = RowLevelSecurityFilter() + self.rls_entry3.tables.extend( + session.query(SqlaTable) + .filter(SqlaTable.table_name.in_(["birth_names"])) + .all() + ) + self.rls_entry3.filter_type = "Regular" + self.rls_entry3.clause = "name like 'Q%'" + self.rls_entry3.group_key = "name" + self.rls_entry3.roles.append(security_manager.find_role("NameQ")) + db.session.add(self.rls_entry3) + + # Create Base RowLevelSecurityFilter (birth_names boys) + self.rls_entry4 = RowLevelSecurityFilter() + self.rls_entry4.tables.extend( + session.query(SqlaTable) + .filter(SqlaTable.table_name.in_(["birth_names"])) + .all() + ) + self.rls_entry4.filter_type = "Base" + self.rls_entry4.clause = "gender = 'boy'" + self.rls_entry4.group_key = "gender" + self.rls_entry4.roles.append(security_manager.find_role("Admin")) + db.session.add(self.rls_entry4) db.session.commit() def tearDown(self): session = db.session - session.delete(self.rls_entry) + session.delete(self.rls_entry1) + session.delete(self.rls_entry2) + session.delete(self.rls_entry3) + session.delete(self.rls_entry4) + session.delete(security_manager.find_role("NameAB")) + session.delete(security_manager.find_role("NameQ")) + session.delete(self.get_user("NoRlsRoleUser")) session.commit() - # Do another test to make sure it doesn't alter another query - def test_rls_filter_alters_query(self): - g.user = self.get_user( - username="alpha" - ) # self.login() doesn't actually set the user + def test_rls_filter_alters_energy_query(self): + g.user = self.get_user(username="alpha") tbl = self.get_table_by_name("energy_usage") - query_obj = dict( - groupby=[], - metrics=[], - filter=[], - is_timeseries=False, - columns=["value"], - granularity=None, - from_dttm=None, - to_dttm=None, - extras={}, - ) - sql = tbl.get_query_str(query_obj) - assert tbl.get_extra_cache_keys(query_obj) == [1] + sql = tbl.get_query_str(self.query_obj) + assert tbl.get_extra_cache_keys(self.query_obj) == [1] assert "value > 1" in sql - def test_rls_filter_doesnt_alter_query(self): + def test_rls_filter_doesnt_alter_energy_query(self): g.user = self.get_user( username="admin" ) # self.login() doesn't actually set the user tbl = self.get_table_by_name("energy_usage") - query_obj = dict( - groupby=[], - metrics=[], - filter=[], - is_timeseries=False, - columns=["value"], - granularity=None, - from_dttm=None, - to_dttm=None, - extras={}, - ) - sql = tbl.get_query_str(query_obj) - assert tbl.get_extra_cache_keys(query_obj) == [] + sql = tbl.get_query_str(self.query_obj) + assert tbl.get_extra_cache_keys(self.query_obj) == [] assert "value > 1" not in sql def test_multiple_table_filter_alters_another_tables_query(self): @@ -1080,17 +1127,41 @@ def test_multiple_table_filter_alters_another_tables_query(self): username="alpha" ) # self.login() doesn't actually set the user tbl = self.get_table_by_name("unicode_test") - query_obj = dict( - groupby=[], - metrics=[], - filter=[], - is_timeseries=False, - columns=["value"], - granularity=None, - from_dttm=None, - to_dttm=None, - extras={}, - ) - sql = tbl.get_query_str(query_obj) - assert tbl.get_extra_cache_keys(query_obj) == [1] + sql = tbl.get_query_str(self.query_obj) + assert tbl.get_extra_cache_keys(self.query_obj) == [1] assert "value > 1" in sql + + def test_rls_filter_alters_gamma_birth_names_query(self): + g.user = self.get_user(username="gamma") + tbl = self.get_table_by_name("birth_names") + sql = tbl.get_query_str(self.query_obj) + + # establish that the filters are grouped together correctly with + # ANDs, ORs and parens in the correct place + assert ( + "WHERE ((name like 'A%'\n or name like 'B%')\n OR (name like 'Q%'))\n AND (gender = 'boy');" + in sql + ) + + def test_rls_filter_alters_no_role_user_birth_names_query(self): + g.user = self.get_user(username="NoRlsRoleUser") + tbl = self.get_table_by_name("birth_names") + sql = tbl.get_query_str(self.query_obj) + + # gamma's filters should not be present query + assert not self.NAMES_A_REGEX.search(sql) + assert not self.NAMES_B_REGEX.search(sql) + assert not self.NAMES_Q_REGEX.search(sql) + # base query should be present + assert self.BASE_FILTER_REGEX.search(sql) + + def test_rls_filter_doesnt_alter_admin_birth_names_query(self): + g.user = self.get_user(username="admin") + tbl = self.get_table_by_name("birth_names") + sql = tbl.get_query_str(self.query_obj) + + # no filters are applied for admin user + assert not self.NAMES_A_REGEX.search(sql) + assert not self.NAMES_B_REGEX.search(sql) + assert not self.NAMES_Q_REGEX.search(sql) + assert not self.BASE_FILTER_REGEX.search(sql)