diff --git a/superset/common/request_contexed_based.py b/superset/common/request_contexed_based.py index 0b06a0ccbe1d5..5d8405e36cf05 100644 --- a/superset/common/request_contexed_based.py +++ b/superset/common/request_contexed_based.py @@ -16,24 +16,10 @@ # under the License. from __future__ import annotations -from typing import List, TYPE_CHECKING - -from flask import g - from superset import conf, security_manager -if TYPE_CHECKING: - from flask_appbuilder.security.sqla.models import Role - - -def get_user_roles() -> List[Role]: - if g.user.is_anonymous: - public_role = conf.get("AUTH_ROLE_PUBLIC") - return [security_manager.get_public_role()] if public_role else [] - return g.user.roles - def is_user_admin() -> bool: - user_roles = [role.name.lower() for role in get_user_roles()] + user_roles = [role.name.lower() for role in security_manager.get_user_roles()] admin_role = conf.get("AUTH_ROLE_ADMIN").lower() return admin_role in user_roles diff --git a/superset/config.py b/superset/config.py index e84f366d268db..4d00d7153f8b7 100644 --- a/superset/config.py +++ b/superset/config.py @@ -199,7 +199,11 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: WTF_CSRF_ENABLED = True # Add endpoints that need to be exempt from CSRF protection -WTF_CSRF_EXEMPT_LIST = ["superset.views.core.log", "superset.charts.data.api.data"] +WTF_CSRF_EXEMPT_LIST = [ + "superset.views.core.log", + "superset.views.core.explore_json", + "superset.charts.data.api.data", +] # Whether to run the web server in debug mode or not DEBUG = os.environ.get("FLASK_ENV") == "development" @@ -401,7 +405,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: # a custom security config could potentially give access to setting filters on # tables that users do not have access to. "ROW_LEVEL_SECURITY": True, - "EMBEDDED_SUPERSET": False, + "EMBEDDED_SUPERSET": False, # This requires that the public role be available # Enables Alerts and reports new implementation "ALERT_REPORTS": False, # Enable experimental feature to search for other dashboards diff --git a/superset/dashboards/filters.py b/superset/dashboards/filters.py index 96584784f449b..e398af97b744a 100644 --- a/superset/dashboards/filters.py +++ b/superset/dashboards/filters.py @@ -16,6 +16,7 @@ # under the License. from typing import Any, Optional +from flask import g from flask_appbuilder.security.sqla.models import Role from flask_babel import lazy_gettext as _ from sqlalchemy import and_, or_ @@ -25,7 +26,8 @@ from superset.models.core import FavStar from superset.models.dashboard import Dashboard from superset.models.slice import Slice -from superset.views.base import BaseFilter, get_user_roles, is_user_admin +from superset.security.guest_token import GuestTokenResourceType, GuestUser +from superset.views.base import BaseFilter, is_user_admin from superset.views.base_api import BaseFavoriteFilter @@ -112,7 +114,7 @@ def apply(self, query: Query, value: Any) -> Query: ) ) - dashboard_rbac_or_filters = [] + feature_flagged_filters = [] if is_feature_enabled("DASHBOARD_RBAC"): roles_based_query = ( db.session.query(Dashboard.id) @@ -121,19 +123,31 @@ def apply(self, query: Query, value: Any) -> Query: and_( Dashboard.published.is_(True), dashboard_has_roles, - Role.id.in_([x.id for x in get_user_roles()]), + Role.id.in_([x.id for x in security_manager.get_user_roles()]), ), ) ) - dashboard_rbac_or_filters.append(Dashboard.id.in_(roles_based_query)) + feature_flagged_filters.append(Dashboard.id.in_(roles_based_query)) + + if is_feature_enabled("EMBEDDED_SUPERSET") and security_manager.is_guest_user( + g.user + ): + guest_user: GuestUser = g.user + embedded_dashboard_ids = [ + r["id"] + for r in guest_user.resources + if r["type"] == GuestTokenResourceType.DASHBOARD.value + ] + if len(embedded_dashboard_ids) != 0: + feature_flagged_filters.append(Dashboard.id.in_(embedded_dashboard_ids)) query = query.filter( or_( Dashboard.id.in_(owner_ids_query), Dashboard.id.in_(datasource_perm_query), Dashboard.id.in_(users_favorite_dash_query), - *dashboard_rbac_or_filters, + *feature_flagged_filters, ) ) diff --git a/superset/security/api.py b/superset/security/api.py index c0a4a77b24225..54efcd07e0dbd 100644 --- a/superset/security/api.py +++ b/superset/security/api.py @@ -35,7 +35,7 @@ class UserSchema(Schema): class ResourceSchema(Schema): - type = fields.String(required=True) + type = fields.String(required=True) # todo figure out how to make this an enum id = fields.String(required=True) rls = fields.String() diff --git a/superset/security/guest_token.py b/superset/security/guest_token.py index cbef52f008e34..60add8175400d 100644 --- a/superset/security/guest_token.py +++ b/superset/security/guest_token.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from enum import Enum from typing import List, Optional, TypedDict, Union from flask_appbuilder.security.sqla.models import Role @@ -26,17 +27,24 @@ class GuestTokenUser(TypedDict, total=False): last_name: str +class GuestTokenResourceType(Enum): + DASHBOARD = "dashboard" + + class GuestTokenResource(TypedDict): - type: str + type: GuestTokenResourceType id: Union[str, int] rls: Optional[str] +GuestTokenResources = List[GuestTokenResource] + + class GuestToken(TypedDict): iat: float exp: float user: GuestTokenUser - resources: List[GuestTokenResource] + resources: GuestTokenResources class GuestUser(AnonymousUserMixin): @@ -50,7 +58,7 @@ class GuestUser(AnonymousUserMixin): def is_authenticated(self) -> bool: """ This is set to true because guest users should be considered authenticated, - at least in most places. The treatment of this flag is pretty inconsistent. + at least in most places. The treatment of this flag is kind of inconsistent. """ return True diff --git a/superset/security/manager.py b/superset/security/manager.py index 1a0a4532f5cd2..5993f952f3473 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -67,7 +67,8 @@ from superset.exceptions import SupersetSecurityException from superset.security.guest_token import ( GuestToken, - GuestTokenResource, + GuestTokenResources, + GuestTokenResourceType, GuestTokenUser, GuestUser, ) @@ -1067,11 +1068,16 @@ def raise_for_access( assert datasource + should_check_dashboard_access = ( + feature_flag_manager.is_feature_enabled("DASHBOARD_RBAC") + or self.is_guest_user() + ) + if not ( self.can_access_schema(datasource) or self.can_access("datasource_access", datasource.perm or "") or ( - feature_flag_manager.is_feature_enabled("DASHBOARD_RBAC") + should_check_dashboard_access and self.can_access_based_on_dashboard(datasource) ) ): @@ -1097,6 +1103,14 @@ def get_user_by_username( def get_anonymous_user(self) -> User: # pylint: disable=no-self-use return AnonymousUserMixin() + def get_user_roles(self, user: Optional[User] = None) -> List[Role]: + if not user: + user = g.user + if user.is_anonymous: + public_role = current_app.config.get("AUTH_ROLE_PUBLIC") + return [self.get_public_role()] if public_role else [] + return user.roles + def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]: """ Retrieves the appropriate row level security filters for the current user and @@ -1195,10 +1209,11 @@ def raise_for_user_activity_access(user_id: int) -> None: ) ) - @staticmethod - def raise_for_dashboard_access(dashboard: "Dashboard") -> None: + def raise_for_dashboard_access(self, dashboard: "Dashboard") -> None: """ Raise an exception if the user cannot access the dashboard. + This does not check for the required role/permission pairs, + it only concerns itself with entity relationships. :param dashboard: Dashboard the user wants access to :raises DashboardAccessDeniedError: If the user cannot access the resource @@ -1206,23 +1221,27 @@ def raise_for_dashboard_access(dashboard: "Dashboard") -> None: # pylint: disable=import-outside-toplevel from superset import is_feature_enabled from superset.dashboards.commands.exceptions import DashboardAccessDeniedError - from superset.views.base import get_user_roles, is_user_admin + from superset.views.base import is_user_admin from superset.views.utils import is_owner - has_rbac_access = True - - if is_feature_enabled("DASHBOARD_RBAC"): - has_rbac_access = any( - dashboard_role.id in [user_role.id for user_role in get_user_roles()] + def has_rbac_access() -> bool: + return (not is_feature_enabled("DASHBOARD_RBAC")) or any( + dashboard_role.id + in [user_role.id for user_role in self.get_user_roles()] for dashboard_role in dashboard.roles ) - can_access = ( - is_user_admin() - or is_owner(dashboard, g.user) - or (dashboard.published and has_rbac_access) - or (not dashboard.published and not dashboard.roles) - ) + if self.is_guest_user(): + can_access = self.has_guest_access( + GuestTokenResourceType.DASHBOARD, dashboard.id + ) + else: + can_access = ( + is_user_admin() + or is_owner(dashboard, g.user) + or (dashboard.published and has_rbac_access()) + or (not dashboard.published and not dashboard.roles) + ) if not can_access: raise DashboardAccessDeniedError() @@ -1255,7 +1274,7 @@ def _get_current_epoch_time() -> float: return time.time() def create_guest_access_token( - self, user: GuestTokenUser, resources: List[GuestTokenResource] + self, user: GuestTokenUser, resources: GuestTokenResources ) -> bytes: secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] @@ -1289,33 +1308,60 @@ def get_guest_user_from_request(self, req: Request) -> Optional[GuestUser]: try: token = self.parse_jwt_guest_token(raw_token) + if token.get("user") is None: + raise ValueError("Guest token does not contain a user claim") + if token.get("resources") is None: + raise ValueError("Guest token does not contain a resources claim") except Exception: # pylint: disable=broad-except # The login manager will handle sending 401s. # We don't need to send a special error message. logger.warning("Invalid guest token", exc_info=True) return None else: - return self.guest_user_cls( - token=token, - roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])], - ) + return self.get_guest_user_from_token(cast(GuestToken, token)) + + def get_guest_user_from_token(self, token: GuestToken) -> GuestUser: + return self.guest_user_cls( + token=token, roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])], + ) @staticmethod - def parse_jwt_guest_token(raw_token: str) -> GuestToken: + def parse_jwt_guest_token(raw_token: str) -> Dict[str, Any]: """ - Parses and validates a guest token. - Raises an error if the jwt is invalid: - if it is not signed with our secret, - or if required claims are not present. + Parses a guest token. Raises an error if the jwt fails standard claims checks. :param raw_token: the token gotten from the request :return: the same token that was passed in, tested but unchanged """ secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] + return jwt.decode(raw_token, secret, algorithms=[algo]) + + @staticmethod + def is_guest_user(user: Optional[Any] = None) -> bool: + # pylint: disable=import-outside-toplevel + from superset import is_feature_enabled + + if not is_feature_enabled("EMBEDDED_SUPERSET"): + return False + if not user: + user = g.user + return hasattr(user, "is_guest_user") and user.is_guest_user + + def get_current_guest_user_if_guest(self) -> Optional[GuestUser]: + + if self.is_guest_user(): + return g.user + return None + + def has_guest_access( + self, resource_type: GuestTokenResourceType, resource_id: Union[str, int] + ) -> bool: + user = self.get_current_guest_user_if_guest() + if not user: + return False - token = jwt.decode(raw_token, secret, algorithms=[algo]) - if token.get("user") is None: - raise ValueError("Guest token does not contain a user claim") - if token.get("resources") is None: - raise ValueError("Guest token does not contain a resources claim") - return cast(GuestToken, token) + strid = str(resource_id) + for resource in user.resources: + if resource["type"] == resource_type.value and str(resource["id"]) == strid: + return True + return False diff --git a/superset/views/base.py b/superset/views/base.py index d3cb477b9fd9a..4244c66131f28 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -38,7 +38,7 @@ from flask_appbuilder.actions import action from flask_appbuilder.forms import DynamicForm from flask_appbuilder.models.sqla.filters import BaseFilter -from flask_appbuilder.security.sqla.models import Role, User +from flask_appbuilder.security.sqla.models import User from flask_appbuilder.widgets import ListWidget from flask_babel import get_locale, gettext as __, lazy_gettext as _ from flask_jwt_extended.exceptions import NoAuthorizationError @@ -264,15 +264,8 @@ def create_table_permissions(table: models.SqlaTable) -> None: security_manager.add_permission_view_menu("schema_access", table.schema_perm) -def get_user_roles() -> List[Role]: - if g.user.is_anonymous: - public_role = conf.get("AUTH_ROLE_PUBLIC") - return [security_manager.find_role(public_role)] if public_role else [] - return g.user.roles - - def is_user_admin() -> bool: - user_roles = [role.name.lower() for role in list(get_user_roles())] + user_roles = [role.name.lower() for role in list(security_manager.get_user_roles())] return "admin" in user_roles diff --git a/superset/views/core.py b/superset/views/core.py index ca7dae3ee51e7..e557314cfbe4d 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -135,7 +135,6 @@ data_payload_response, generate_download_headers, get_error_msg, - get_user_roles, handle_api_exception, json_error_response, json_errors_response, @@ -1888,7 +1887,9 @@ def publish( # pylint: disable=no-self-use f"ERROR: cannot find dashboard {dashboard_id}", status=404 ) - edit_perm = is_owner(dash, g.user) or admin_role in get_user_roles() + edit_perm = ( + is_owner(dash, g.user) or admin_role in security_manager.get_user_roles() + ) if not edit_perm: username = g.user.username if hasattr(g.user, "username") else "user" return json_error_response( diff --git a/tests/integration_tests/security/guest_token_security_tests.py b/tests/integration_tests/security/guest_token_security_tests.py new file mode 100644 index 0000000000000..9ca34198dbdf2 --- /dev/null +++ b/tests/integration_tests/security/guest_token_security_tests.py @@ -0,0 +1,210 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for Superset""" +from unittest import mock + +import pytest +from flask import g + +from superset import db, security_manager +from superset.dashboards.commands.exceptions import DashboardAccessDeniedError +from superset.exceptions import SupersetSecurityException +from superset.models.dashboard import Dashboard +from superset.security.guest_token import GuestTokenResourceType +from superset.sql_parse import Table +from tests.integration_tests.base_tests import SupersetTestCase +from tests.integration_tests.fixtures.birth_names_dashboard import ( + load_birth_names_dashboard_with_slices, + load_birth_names_data, +) + + +@mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", EMBEDDED_SUPERSET=True, +) +class TestGuestUserSecurity(SupersetTestCase): + # This test doesn't use a dashboard fixture, the next test does. + # That way tests are faster. + + resource_id = 42 + + def authorized_guest(self): + return security_manager.get_guest_user_from_token( + {"user": {}, "resources": [{"type": "dashboard", "id": self.resource_id}]} + ) + + def test_is_guest_user__regular_user(self): + is_guest = security_manager.is_guest_user(security_manager.find_user("admin")) + self.assertFalse(is_guest) + + def test_is_guest_user__anonymous(self): + is_guest = security_manager.is_guest_user(security_manager.get_anonymous_user()) + self.assertFalse(is_guest) + + def test_is_guest_user__guest_user(self): + is_guest = security_manager.is_guest_user(self.authorized_guest()) + self.assertTrue(is_guest) + + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + EMBEDDED_SUPERSET=False, + ) + def test_is_guest_user__flag_off(self): + is_guest = security_manager.is_guest_user(self.authorized_guest()) + self.assertFalse(is_guest) + + def test_get_guest_user__regular_user(self): + g.user = security_manager.find_user("admin") + guest_user = security_manager.get_current_guest_user_if_guest() + self.assertIsNone(guest_user) + + def test_get_guest_user__anonymous_user(self): + g.user = security_manager.get_anonymous_user() + guest_user = security_manager.get_current_guest_user_if_guest() + self.assertIsNone(guest_user) + + def test_get_guest_user__guest_user(self): + g.user = self.authorized_guest() + guest_user = security_manager.get_current_guest_user_if_guest() + self.assertEqual(guest_user, g.user) + + def test_has_guest_access__regular_user(self): + g.user = security_manager.find_user("admin") + has_guest_access = security_manager.has_guest_access( + GuestTokenResourceType.DASHBOARD, self.resource_id + ) + self.assertFalse(has_guest_access) + + def test_has_guest_access__anonymous_user(self): + g.user = security_manager.get_anonymous_user() + has_guest_access = security_manager.has_guest_access( + GuestTokenResourceType.DASHBOARD, self.resource_id + ) + self.assertFalse(has_guest_access) + + def test_has_guest_access__authorized_guest_user(self): + g.user = self.authorized_guest() + has_guest_access = security_manager.has_guest_access( + GuestTokenResourceType.DASHBOARD, self.resource_id + ) + self.assertTrue(has_guest_access) + + def test_has_guest_access__authorized_guest_user__non_zero_resource_index(self): + guest = self.authorized_guest() + guest.resources = [ + {"type": "dashboard", "id": self.resource_id - 1} + ] + guest.resources + g.user = guest + + has_guest_access = security_manager.has_guest_access( + GuestTokenResourceType.DASHBOARD, self.resource_id + ) + self.assertTrue(has_guest_access) + + def test_has_guest_access__unauthorized_guest_user__different_resource_id(self): + g.user = security_manager.get_guest_user_from_token( + { + "user": {}, + "resources": [{"type": "dashboard", "id": self.resource_id - 1}], + } + ) + has_guest_access = security_manager.has_guest_access( + GuestTokenResourceType.DASHBOARD, self.resource_id + ) + self.assertFalse(has_guest_access) + + def test_has_guest_access__unauthorized_guest_user__different_resource_type(self): + g.user = security_manager.get_guest_user_from_token( + {"user": {}, "resources": [{"type": "dirt", "id": self.resource_id}]} + ) + has_guest_access = security_manager.has_guest_access( + GuestTokenResourceType.DASHBOARD, self.resource_id + ) + self.assertFalse(has_guest_access) + + def test_get_guest_user_roles_explicit(self): + guest = self.authorized_guest() + roles = security_manager.get_user_roles(guest) + self.assertEqual(guest.roles, roles) + + def test_get_guest_user_roles_implicit(self): + guest = self.authorized_guest() + g.user = guest + + roles = security_manager.get_user_roles() + self.assertEqual(guest.roles, roles) + + +@mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", EMBEDDED_SUPERSET=True, +) +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") +class TestGuestUserDashboardAccess(SupersetTestCase): + def setUp(self) -> None: + self.dash = db.session.query(Dashboard).filter_by(slug="births").first() + self.authorized_guest = security_manager.get_guest_user_from_token( + {"user": {}, "resources": [{"type": "dashboard", "id": self.dash.id}]} + ) + self.unauthorized_guest = security_manager.get_guest_user_from_token( + {"user": {}, "resources": [{"type": "dashboard", "id": self.dash.id + 1}]} + ) + + def test_chart_raise_for_access_as_guest(self): + chart = self.dash.slices[0] + g.user = self.authorized_guest + + security_manager.raise_for_access(viz=chart) + + def test_chart_raise_for_access_as_unauthorized_guest(self): + chart = self.dash.slices[0] + g.user = self.unauthorized_guest + + with self.assertRaises(SupersetSecurityException): + security_manager.raise_for_access(viz=chart) + + def test_dataset_raise_for_access_as_guest(self): + dataset = self.dash.slices[0].datasource + g.user = self.authorized_guest + + security_manager.raise_for_access(datasource=dataset) + + def test_dataset_raise_for_access_as_unauthorized_guest(self): + dataset = self.dash.slices[0].datasource + g.user = self.unauthorized_guest + + with self.assertRaises(SupersetSecurityException): + security_manager.raise_for_access(datasource=dataset) + + def test_guest_token_does_not_grant_access_to_underlying_table(self): + sqla_table = self.dash.slices[0].table + table = Table(table=sqla_table.table_name) + + g.user = self.authorized_guest + + with self.assertRaises(Exception): + security_manager.raise_for_access(table=table, database=sqla_table.database) + + def test_raise_for_dashboard_access_as_guest(self): + g.user = self.authorized_guest + + security_manager.raise_for_dashboard_access(self.dash) + + def test_raise_for_dashboard_access_as_unauthorized_guest(self): + g.user = self.unauthorized_guest + + with self.assertRaises(DashboardAccessDeniedError): + security_manager.raise_for_dashboard_access(self.dash) diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py new file mode 100644 index 0000000000000..665666cb61f5b --- /dev/null +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -0,0 +1,207 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# isort:skip_file +import re +from typing import Any, Dict + +import pytest +from flask import g + +from superset import db, security_manager +from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable +from ..base_tests import SupersetTestCase +from tests.integration_tests.fixtures.birth_names_dashboard import ( + load_birth_names_dashboard_with_slices, + load_birth_names_data, +) +from tests.integration_tests.fixtures.energy_dashboard import ( + load_energy_table_with_slice, + load_energy_table_data, +) +from tests.integration_tests.fixtures.unicode_dashboard import ( + load_unicode_dashboard_with_slice, + load_unicode_data, +) + + +class TestRowLevelSecurity(SupersetTestCase): + """ + Testing Row Level Security + """ + + rls_entry = None + query_obj: Dict[str, Any] = dict( + groupby=[], + metrics=None, + filter=[], + is_timeseries=False, + columns=["value"], + granularity=None, + from_dttm=None, + to_dttm=None, + extras={}, + ) + NAME_AB_ROLE = "NameAB" + NAME_Q_ROLE = "NameQ" + NAMES_A_REGEX = re.compile(r"name like 'A%'") + NAMES_B_REGEX = re.compile(r"name like 'B%'") + NAMES_Q_REGEX = re.compile(r"name like 'Q%'") + BASE_FILTER_REGEX = re.compile(r"gender = 'boy'") + + def setUp(self): + session = db.session + + # Create roles + security_manager.add_role(self.NAME_AB_ROLE) + security_manager.add_role(self.NAME_Q_ROLE) + gamma_user = security_manager.find_user(username="gamma") + gamma_user.roles.append(security_manager.find_role(self.NAME_AB_ROLE)) + gamma_user.roles.append(security_manager.find_role(self.NAME_Q_ROLE)) + self.create_user_with_roles("NoRlsRoleUser", ["Gamma"]) + session.commit() + + # Create regular RowLevelSecurityFilter (energy_usage, unicode_test) + self.rls_entry1 = RowLevelSecurityFilter() + self.rls_entry1.tables.extend( + session.query(SqlaTable) + .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"])) + .all() + ) + self.rls_entry1.filter_type = "Regular" + self.rls_entry1.clause = "value > {{ cache_key_wrapper(1) }}" + self.rls_entry1.group_key = None + self.rls_entry1.roles.append(security_manager.find_role("Gamma")) + self.rls_entry1.roles.append(security_manager.find_role("Alpha")) + db.session.add(self.rls_entry1) + + # Create regular RowLevelSecurityFilter (birth_names name starts with A or B) + self.rls_entry2 = RowLevelSecurityFilter() + self.rls_entry2.tables.extend( + session.query(SqlaTable) + .filter(SqlaTable.table_name.in_(["birth_names"])) + .all() + ) + self.rls_entry2.filter_type = "Regular" + self.rls_entry2.clause = "name like 'A%' or name like 'B%'" + self.rls_entry2.group_key = "name" + self.rls_entry2.roles.append(security_manager.find_role("NameAB")) + db.session.add(self.rls_entry2) + + # Create Regular RowLevelSecurityFilter (birth_names name starts with Q) + self.rls_entry3 = RowLevelSecurityFilter() + self.rls_entry3.tables.extend( + session.query(SqlaTable) + .filter(SqlaTable.table_name.in_(["birth_names"])) + .all() + ) + self.rls_entry3.filter_type = "Regular" + self.rls_entry3.clause = "name like 'Q%'" + self.rls_entry3.group_key = "name" + self.rls_entry3.roles.append(security_manager.find_role("NameQ")) + db.session.add(self.rls_entry3) + + # Create Base RowLevelSecurityFilter (birth_names boys) + self.rls_entry4 = RowLevelSecurityFilter() + self.rls_entry4.tables.extend( + session.query(SqlaTable) + .filter(SqlaTable.table_name.in_(["birth_names"])) + .all() + ) + self.rls_entry4.filter_type = "Base" + self.rls_entry4.clause = "gender = 'boy'" + self.rls_entry4.group_key = "gender" + self.rls_entry4.roles.append(security_manager.find_role("Admin")) + db.session.add(self.rls_entry4) + + db.session.commit() + + def tearDown(self): + session = db.session + session.delete(self.rls_entry1) + session.delete(self.rls_entry2) + session.delete(self.rls_entry3) + session.delete(self.rls_entry4) + session.delete(security_manager.find_role("NameAB")) + session.delete(security_manager.find_role("NameQ")) + session.delete(self.get_user("NoRlsRoleUser")) + session.commit() + + @pytest.mark.usefixtures("load_energy_table_with_slice") + def test_rls_filter_alters_energy_query(self): + g.user = self.get_user(username="alpha") + tbl = self.get_table(name="energy_usage") + sql = tbl.get_query_str(self.query_obj) + assert tbl.get_extra_cache_keys(self.query_obj) == [1] + assert "value > 1" in sql + + @pytest.mark.usefixtures("load_energy_table_with_slice") + def test_rls_filter_doesnt_alter_energy_query(self): + g.user = self.get_user( + username="admin" + ) # self.login() doesn't actually set the user + tbl = self.get_table(name="energy_usage") + sql = tbl.get_query_str(self.query_obj) + assert tbl.get_extra_cache_keys(self.query_obj) == [] + assert "value > 1" not in sql + + @pytest.mark.usefixtures("load_unicode_dashboard_with_slice") + def test_multiple_table_filter_alters_another_tables_query(self): + g.user = self.get_user( + username="alpha" + ) # self.login() doesn't actually set the user + tbl = self.get_table(name="unicode_test") + sql = tbl.get_query_str(self.query_obj) + assert tbl.get_extra_cache_keys(self.query_obj) == [1] + assert "value > 1" in sql + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_rls_filter_alters_gamma_birth_names_query(self): + g.user = self.get_user(username="gamma") + tbl = self.get_table(name="birth_names") + sql = tbl.get_query_str(self.query_obj) + + # establish that the filters are grouped together correctly with + # ANDs, ORs and parens in the correct place + assert ( + "WHERE ((name like 'A%'\n or name like 'B%')\n OR (name like 'Q%'))\n AND (gender = 'boy');" + in sql + ) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_rls_filter_alters_no_role_user_birth_names_query(self): + g.user = self.get_user(username="NoRlsRoleUser") + tbl = self.get_table(name="birth_names") + sql = tbl.get_query_str(self.query_obj) + + # gamma's filters should not be present query + assert not self.NAMES_A_REGEX.search(sql) + assert not self.NAMES_B_REGEX.search(sql) + assert not self.NAMES_Q_REGEX.search(sql) + # base query should be present + assert self.BASE_FILTER_REGEX.search(sql) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_rls_filter_doesnt_alter_admin_birth_names_query(self): + g.user = self.get_user(username="admin") + tbl = self.get_table(name="birth_names") + sql = tbl.get_query_str(self.query_obj) + + # no filters are applied for admin user + assert not self.NAMES_A_REGEX.search(sql) + assert not self.NAMES_B_REGEX.search(sql) + assert not self.NAMES_Q_REGEX.search(sql) + assert not self.BASE_FILTER_REGEX.search(sql) diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index 82bedb3da46d0..2f5ad65aaea6a 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -15,27 +15,25 @@ # specific language governing permissions and limitations # under the License. # isort:skip_file -import dataclasses import inspect -import re import time import unittest from collections import namedtuple from unittest import mock from unittest.mock import Mock, patch -from typing import Any, Dict +from typing import Any import jwt import prison import pytest -from flask import current_app, g +from flask import current_app from superset.models.dashboard import Dashboard from superset import app, appbuilder, db, security_manager, viz, ConnectorRegistry from superset.connectors.druid.models import DruidCluster, DruidDatasource -from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable +from superset.connectors.sqla.models import SqlaTable from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetSecurityException from superset.models.core import Database @@ -49,22 +47,10 @@ from superset.views.access_requests import AccessRequestsModelView from .base_tests import SupersetTestCase -from tests.integration_tests.fixtures.birth_names_dashboard import ( - load_birth_names_dashboard_with_slices, - load_birth_names_data, -) -from tests.integration_tests.fixtures.energy_dashboard import ( - load_energy_table_with_slice, - load_energy_table_data, -) from tests.integration_tests.fixtures.public_role import ( public_role_like_gamma, public_role_like_test_role, ) -from tests.integration_tests.fixtures.unicode_dashboard import ( - load_unicode_dashboard_with_slice, - load_unicode_data, -) from tests.integration_tests.fixtures.world_bank_dashboard import ( load_world_bank_dashboard_with_slices, load_world_bank_data, @@ -1056,174 +1042,18 @@ def test_raise_for_access_viz(self, mock_can_access_schema, mock_can_access): with self.assertRaises(SupersetSecurityException): security_manager.raise_for_access(viz=test_viz) + @patch("superset.security.manager.g") + def test_get_user_roles(self, mock_g): + admin = security_manager.find_user("admin") + mock_g.user = admin + roles = security_manager.get_user_roles() + self.assertEqual(admin.roles, roles) -class TestRowLevelSecurity(SupersetTestCase): - """ - Testing Row Level Security - """ - - rls_entry = None - query_obj: Dict[str, Any] = dict( - groupby=[], - metrics=None, - filter=[], - is_timeseries=False, - columns=["value"], - granularity=None, - from_dttm=None, - to_dttm=None, - extras={}, - ) - NAME_AB_ROLE = "NameAB" - NAME_Q_ROLE = "NameQ" - NAMES_A_REGEX = re.compile(r"name like 'A%'") - NAMES_B_REGEX = re.compile(r"name like 'B%'") - NAMES_Q_REGEX = re.compile(r"name like 'Q%'") - BASE_FILTER_REGEX = re.compile(r"gender = 'boy'") - - def setUp(self): - session = db.session - - # Create roles - security_manager.add_role(self.NAME_AB_ROLE) - security_manager.add_role(self.NAME_Q_ROLE) - gamma_user = security_manager.find_user(username="gamma") - gamma_user.roles.append(security_manager.find_role(self.NAME_AB_ROLE)) - gamma_user.roles.append(security_manager.find_role(self.NAME_Q_ROLE)) - self.create_user_with_roles("NoRlsRoleUser", ["Gamma"]) - session.commit() - - # Create regular RowLevelSecurityFilter (energy_usage, unicode_test) - self.rls_entry1 = RowLevelSecurityFilter() - self.rls_entry1.tables.extend( - session.query(SqlaTable) - .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"])) - .all() - ) - self.rls_entry1.filter_type = "Regular" - self.rls_entry1.clause = "value > {{ cache_key_wrapper(1) }}" - self.rls_entry1.group_key = None - self.rls_entry1.roles.append(security_manager.find_role("Gamma")) - self.rls_entry1.roles.append(security_manager.find_role("Alpha")) - db.session.add(self.rls_entry1) - - # Create regular RowLevelSecurityFilter (birth_names name starts with A or B) - self.rls_entry2 = RowLevelSecurityFilter() - self.rls_entry2.tables.extend( - session.query(SqlaTable) - .filter(SqlaTable.table_name.in_(["birth_names"])) - .all() - ) - self.rls_entry2.filter_type = "Regular" - self.rls_entry2.clause = "name like 'A%' or name like 'B%'" - self.rls_entry2.group_key = "name" - self.rls_entry2.roles.append(security_manager.find_role("NameAB")) - db.session.add(self.rls_entry2) - - # Create Regular RowLevelSecurityFilter (birth_names name starts with Q) - self.rls_entry3 = RowLevelSecurityFilter() - self.rls_entry3.tables.extend( - session.query(SqlaTable) - .filter(SqlaTable.table_name.in_(["birth_names"])) - .all() - ) - self.rls_entry3.filter_type = "Regular" - self.rls_entry3.clause = "name like 'Q%'" - self.rls_entry3.group_key = "name" - self.rls_entry3.roles.append(security_manager.find_role("NameQ")) - db.session.add(self.rls_entry3) - - # Create Base RowLevelSecurityFilter (birth_names boys) - self.rls_entry4 = RowLevelSecurityFilter() - self.rls_entry4.tables.extend( - session.query(SqlaTable) - .filter(SqlaTable.table_name.in_(["birth_names"])) - .all() - ) - self.rls_entry4.filter_type = "Base" - self.rls_entry4.clause = "gender = 'boy'" - self.rls_entry4.group_key = "gender" - self.rls_entry4.roles.append(security_manager.find_role("Admin")) - db.session.add(self.rls_entry4) - - db.session.commit() - - def tearDown(self): - session = db.session - session.delete(self.rls_entry1) - session.delete(self.rls_entry2) - session.delete(self.rls_entry3) - session.delete(self.rls_entry4) - session.delete(security_manager.find_role("NameAB")) - session.delete(security_manager.find_role("NameQ")) - session.delete(self.get_user("NoRlsRoleUser")) - session.commit() - - @pytest.mark.usefixtures("load_energy_table_with_slice") - def test_rls_filter_alters_energy_query(self): - g.user = self.get_user(username="alpha") - tbl = self.get_table(name="energy_usage") - sql = tbl.get_query_str(self.query_obj) - assert tbl.get_extra_cache_keys(self.query_obj) == [1] - assert "value > 1" in sql - - @pytest.mark.usefixtures("load_energy_table_with_slice") - def test_rls_filter_doesnt_alter_energy_query(self): - g.user = self.get_user( - username="admin" - ) # self.login() doesn't actually set the user - tbl = self.get_table(name="energy_usage") - sql = tbl.get_query_str(self.query_obj) - assert tbl.get_extra_cache_keys(self.query_obj) == [] - assert "value > 1" not in sql - - @pytest.mark.usefixtures("load_unicode_dashboard_with_slice") - def test_multiple_table_filter_alters_another_tables_query(self): - g.user = self.get_user( - username="alpha" - ) # self.login() doesn't actually set the user - tbl = self.get_table(name="unicode_test") - sql = tbl.get_query_str(self.query_obj) - assert tbl.get_extra_cache_keys(self.query_obj) == [1] - assert "value > 1" in sql - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_rls_filter_alters_gamma_birth_names_query(self): - g.user = self.get_user(username="gamma") - tbl = self.get_table(name="birth_names") - sql = tbl.get_query_str(self.query_obj) - - # establish that the filters are grouped together correctly with - # ANDs, ORs and parens in the correct place - assert ( - "WHERE ((name like 'A%'\n or name like 'B%')\n OR (name like 'Q%'))\n AND (gender = 'boy');" - in sql - ) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_rls_filter_alters_no_role_user_birth_names_query(self): - g.user = self.get_user(username="NoRlsRoleUser") - tbl = self.get_table(name="birth_names") - sql = tbl.get_query_str(self.query_obj) - - # gamma's filters should not be present query - assert not self.NAMES_A_REGEX.search(sql) - assert not self.NAMES_B_REGEX.search(sql) - assert not self.NAMES_Q_REGEX.search(sql) - # base query should be present - assert self.BASE_FILTER_REGEX.search(sql) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_rls_filter_doesnt_alter_admin_birth_names_query(self): - g.user = self.get_user(username="admin") - tbl = self.get_table(name="birth_names") - sql = tbl.get_query_str(self.query_obj) - - # no filters are applied for admin user - assert not self.NAMES_A_REGEX.search(sql) - assert not self.NAMES_B_REGEX.search(sql) - assert not self.NAMES_Q_REGEX.search(sql) - assert not self.BASE_FILTER_REGEX.search(sql) + @patch("superset.security.manager.g") + def test_get_anonymous_roles(self, mock_g): + mock_g.user = security_manager.get_anonymous_user() + roles = security_manager.get_user_roles() + self.assertEqual([security_manager.get_public_role()], roles) class TestAccessRequestEndpoints(SupersetTestCase): @@ -1341,7 +1171,11 @@ def test_create_guest_access_token(self, get_time_mock): token = security_manager.create_guest_access_token(user, resources) # unfortunately we cannot mock time in the jwt lib - decoded_token = jwt.decode(token, self.app.config["GUEST_TOKEN_JWT_SECRET"]) + decoded_token = jwt.decode( + token, + self.app.config["GUEST_TOKEN_JWT_SECRET"], + algorithms=[self.app.config["GUEST_TOKEN_JWT_ALGO"]], + ) self.assertEqual(user, decoded_token["user"]) self.assertEqual(resources, decoded_token["resources"]) @@ -1378,3 +1212,26 @@ def test_get_guest_user_expired_token(self, get_time_mock): guest_user = security_manager.get_guest_user_from_request(fake_request) self.assertIsNone(guest_user) + + def test_get_guest_user_no_user(self): + user = None + resources = [{"type": "dashboard", "id": 1}] + token = security_manager.create_guest_access_token(user, resources) + fake_request = FakeRequest() + fake_request.headers[current_app.config["GUEST_TOKEN_HEADER_NAME"]] = token + guest_user = security_manager.get_guest_user_from_request(fake_request) + + self.assertIsNone(guest_user) + self.assertRaisesRegex(ValueError, "Guest token does not contain a user claim") + + def test_get_guest_user_no_resource(self): + user = {"username": "test_guest"} + resources = [] + token = security_manager.create_guest_access_token(user, resources) + fake_request = FakeRequest() + fake_request.headers[current_app.config["GUEST_TOKEN_HEADER_NAME"]] = token + guest_user = security_manager.get_guest_user_from_request(fake_request) + + self.assertRaisesRegex( + ValueError, "Guest token does not contain a resources claim" + )