diff --git a/superset/config.py b/superset/config.py index 59a54fee3cfef..d7c8d5ef86802 100644 --- a/superset/config.py +++ b/superset/config.py @@ -343,6 +343,7 @@ def _try_json_readsha( # pylint: disable=unused-argument "ALERT_REPORTS": False, # Enable experimental feature to search for other dashboards "OMNIBAR": False, + "DASHBOARD_RBAC": False, } # Set the default view to card/grid view if thumbnail support is enabled. diff --git a/superset/dashboards/filters.py b/superset/dashboards/filters.py index a2421ab007ba6..7ca466e917186 100644 --- a/superset/dashboards/filters.py +++ b/superset/dashboards/filters.py @@ -16,15 +16,16 @@ # under the License. from typing import Any +from flask_appbuilder.security.sqla.models import Role from flask_babel import lazy_gettext as _ from sqlalchemy import and_, or_ from sqlalchemy.orm.query import Query -from superset import db, security_manager +from superset import db, is_feature_enabled, security_manager from superset.models.core import FavStar from superset.models.dashboard import Dashboard from superset.models.slice import Slice -from superset.views.base import BaseFilter, is_user_admin +from superset.views.base import BaseFilter, get_user_roles, is_user_admin from superset.views.base_api import BaseFavoriteFilter @@ -74,12 +75,19 @@ def apply(self, query: Query, value: Any) -> Query: datasource_perms = security_manager.user_view_menu_names("datasource_access") schema_perms = security_manager.user_view_menu_names("schema_access") - published_dash_query = ( + + is_rbac_disabled_filter = [] + dashboard_has_roles = Dashboard.roles.any() + if is_feature_enabled("DASHBOARD_RBAC"): + is_rbac_disabled_filter.append(~dashboard_has_roles) + + datasource_perm_query = ( db.session.query(Dashboard.id) .join(Dashboard.slices) .filter( and_( - Dashboard.published == True, # pylint: disable=singleton-comparison + Dashboard.published.is_(True), + *is_rbac_disabled_filter, or_( Slice.perm.in_(datasource_perms), Slice.schema_perm.in_(schema_perms), @@ -104,11 +112,28 @@ def apply(self, query: Query, value: Any) -> Query: ) ) + dashboard_rbac_or_filters = [] + if is_feature_enabled("DASHBOARD_RBAC"): + roles_based_query = ( + db.session.query(Dashboard.id) + .join(Dashboard.roles) + .filter( + and_( + Dashboard.published.is_(True), + dashboard_has_roles, + Role.id.in_([x.id for x in get_user_roles()]), + ), + ) + ) + + dashboard_rbac_or_filters.append(Dashboard.id.in_(roles_based_query)) + query = query.filter( or_( Dashboard.id.in_(owner_ids_query), - Dashboard.id.in_(published_dash_query), + Dashboard.id.in_(datasource_perm_query), Dashboard.id.in_(users_favorite_dash_query), + *dashboard_rbac_or_filters, ) ) diff --git a/superset/migrations/versions/11ccdd12658_add_roles_relationship_to_dashboard.py b/superset/migrations/versions/11ccdd12658_add_roles_relationship_to_dashboard.py new file mode 100644 index 0000000000000..b5576caa5f18b --- /dev/null +++ b/superset/migrations/versions/11ccdd12658_add_roles_relationship_to_dashboard.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""add roles relationship to dashboard +Revision ID: e11ccdd12658 +Revises: 260bf0649a77 +Create Date: 2021-01-14 19:12:43.406230 +""" +# revision identifiers, used by Alembic. +revision = "e11ccdd12658" +down_revision = "260bf0649a77" +import sqlalchemy as sa +from alembic import op + + +def upgrade(): + op.create_table( + "dashboard_roles", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("role_id", sa.Integer(), nullable=False), + sa.Column("dashboard_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["dashboard_id"], ["dashboards.id"]), + sa.ForeignKeyConstraint(["role_id"], ["ab_role.id"]), + sa.PrimaryKeyConstraint("id"), + ) + + +def downgrade(): + op.drop_table("dashboard_roles") diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index d7c2a6978a480..3606da8c5cb80 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -115,6 +115,15 @@ def copy_dashboard( ) +DashboardRoles = Table( + "dashboard_roles", + metadata, + Column("id", Integer, primary_key=True), + Column("dashboard_id", Integer, ForeignKey("dashboards.id"), nullable=False), + Column("role_id", Integer, ForeignKey("ab_role.id"), nullable=False), +) + + class Dashboard( # pylint: disable=too-many-instance-attributes Model, AuditMixinNullable, ImportExportMixin ): @@ -132,7 +141,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes slices = relationship(Slice, secondary=dashboard_slices, backref="dashboards") owners = relationship(security_manager.user_model, secondary=dashboard_user) published = Column(Boolean, default=False) - + roles = relationship(security_manager.role_model, secondary=DashboardRoles) export_fields = [ "dashboard_title", "position_json", diff --git a/superset/views/dashboard/mixin.py b/superset/views/dashboard/mixin.py index 89472acf74a00..273cfdf0b3c2f 100644 --- a/superset/views/dashboard/mixin.py +++ b/superset/views/dashboard/mixin.py @@ -33,6 +33,7 @@ class DashboardMixin: # pylint: disable=too-few-public-methods "dashboard_title", "slug", "owners", + "roles", "position_json", "css", "json_metadata", @@ -62,6 +63,12 @@ class DashboardMixin: # pylint: disable=too-few-public-methods "want to alter specific parameters." ), "owners": _("Owners is a list of users who can alter the dashboard."), + "roles": _( + "Roles is a list which defines access to the dashboard. " + "These roles are always applied in addition to restrictions on dataset " + "level access. " + "If no roles defined then the dashboard is available to all roles." + ), "published": _( "Determines whether or not this dashboard is " "visible in the list of all dashboards" @@ -74,6 +81,7 @@ class DashboardMixin: # pylint: disable=too-few-public-methods "slug": _("Slug"), "charts": _("Charts"), "owners": _("Owners"), + "roles": _("Roles"), "published": _("Published"), "creator": _("Creator"), "modified": _("Modified"), diff --git a/tests/base_tests.py b/tests/base_tests.py index 81e218b5c69e7..a5fd9d5382da3 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -111,7 +111,6 @@ def logged_in_admin(): class SupersetTestCase(TestCase): - default_schema_backend_map = { "sqlite": "main", "mysql": "superset", @@ -135,7 +134,9 @@ def get_birth_names_dataset(): ) @staticmethod - def create_user_with_roles(username: str, roles: List[str]): + def create_user_with_roles( + username: str, roles: List[str], should_create_roles: bool = False + ): user_to_create = security_manager.find_user(username) if not user_to_create: security_manager.add_user( @@ -149,7 +150,12 @@ def create_user_with_roles(username: str, roles: List[str]): db.session.commit() user_to_create = security_manager.find_user(username) assert user_to_create - user_to_create.roles = [security_manager.find_role(r) for r in roles] + user_to_create.roles = [] + for chosen_user_role in roles: + if should_create_roles: + ## copy role from gamma but without data permissions + security_manager.copy_role("Gamma", chosen_user_role, merge=False) + user_to_create.roles.append(security_manager.find_role(chosen_user_role)) db.session.commit() return user_to_create @@ -290,7 +296,11 @@ def logout(self): self.client.get("/logout/", follow_redirects=True) def grant_public_access_to_table(self, table): - public_role = security_manager.find_role("Public") + role_name = "Public" + self.grant_role_access_to_table(table, role_name) + + def grant_role_access_to_table(self, table, role_name): + role = security_manager.find_role(role_name) perms = db.session.query(ab_models.PermissionView).all() for perm in perms: if ( @@ -298,10 +308,14 @@ def grant_public_access_to_table(self, table): and perm.view_menu and table.perm in perm.view_menu.name ): - security_manager.add_permission_role(public_role, perm) + security_manager.add_permission_role(role, perm) def revoke_public_access_to_table(self, table): - public_role = security_manager.find_role("Public") + role_name = "Public" + self.revoke_role_access_to_table(role_name, table) + + def revoke_role_access_to_table(self, role_name, table): + public_role = security_manager.find_role(role_name) perms = db.session.query(ab_models.PermissionView).all() for perm in perms: if ( diff --git a/tests/dashboard_tests.py b/tests/dashboard_tests.py index be3e0a9be55e0..9780b1608c917 100644 --- a/tests/dashboard_tests.py +++ b/tests/dashboard_tests.py @@ -453,24 +453,6 @@ def test_only_owners_can_save(self): db.session.commit() self.test_save_dash("alpha") - def test_owners_can_view_empty_dashboard(self): - dash = db.session.query(Dashboard).filter_by(slug="empty_dashboard").first() - if not dash: - dash = Dashboard() - dash.dashboard_title = "Empty Dashboard" - dash.slug = "empty_dashboard" - else: - dash.slices = [] - dash.owners = [] - db.session.merge(dash) - db.session.commit() - - gamma_user = security_manager.find_user("gamma") - self.login(gamma_user.username) - - resp = self.get_resp("/api/v1/dashboard/") - self.assertNotIn("/superset/dashboard/empty_dashboard/", resp) - @pytest.mark.usefixtures("load_energy_table_with_slice", "load_dashboard") def test_users_can_view_published_dashboard(self): resp = self.get_resp("/api/v1/dashboard/") diff --git a/tests/dashboards/base_case.py b/tests/dashboards/base_case.py new file mode 100644 index 0000000000000..42cd87bfd93e1 --- /dev/null +++ b/tests/dashboards/base_case.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +from typing import Any, Dict, Union + +import prison +from flask import Response + +from superset import app, security_manager +from tests.base_tests import SupersetTestCase +from tests.dashboards.consts import * +from tests.dashboards.dashboard_test_utils import build_save_dash_parts +from tests.dashboards.superset_factory_util import delete_all_inserted_objects + + +class DashboardTestCase(SupersetTestCase): + def get_dashboard_via_api_by_id(self, dashboard_id: int) -> Response: + uri = DASHBOARD_API_URL_FORMAT.format(dashboard_id) + return self.get_assert_metric(uri, "get") + + def get_dashboard_view_response(self, dashboard_to_access) -> Response: + return self.client.get(dashboard_to_access.url) + + def get_dashboard_api_response(self, dashboard_to_access) -> Response: + return self.client.get(DASHBOARD_API_URL_FORMAT.format(dashboard_to_access.id)) + + def get_dashboards_list_response(self) -> Response: + return self.client.get(GET_DASHBOARDS_LIST_VIEW) + + def get_dashboards_api_response(self) -> Response: + return self.client.get(DASHBOARDS_API_URL) + + def save_dashboard_via_view( + self, dashboard_id: Union[str, int], dashboard_data: Dict[str, Any] + ) -> Response: + save_dash_url = SAVE_DASHBOARD_URL_FORMAT.format(dashboard_id) + return self.get_resp(save_dash_url, data=dict(data=json.dumps(dashboard_data))) + + def save_dashboard( + self, dashboard_id: Union[str, int], dashboard_data: Dict[str, Any] + ) -> Response: + return self.save_dashboard_via_view(dashboard_id, dashboard_data) + + def delete_dashboard_via_view(self, dashboard_id: int) -> Response: + delete_dashboard_url = DELETE_DASHBOARD_VIEW_URL_FORMAT.format(dashboard_id) + return self.get_resp(delete_dashboard_url, {}) + + def delete_dashboard_via_api(self, dashboard_id): + uri = DASHBOARD_API_URL_FORMAT.format(dashboard_id) + return self.delete_assert_metric(uri, "delete") + + def bulk_delete_dashboard_via_api(self, dashboard_ids): + uri = DASHBOARDS_API_URL_WITH_QUERY_FORMAT.format(prison.dumps(dashboard_ids)) + return self.delete_assert_metric(uri, "bulk_delete") + + def delete_dashboard(self, dashboard_id: int) -> Response: + return self.delete_dashboard_via_view(dashboard_id) + + def assert_permission_was_created(self, dashboard): + view_menu = security_manager.find_view_menu(dashboard.view_name) + self.assertIsNotNone(view_menu) + self.assertEqual(len(security_manager.find_permissions_view_menu(view_menu)), 1) + + def assert_permission_kept_and_changed(self, updated_dashboard, excepted_view_id): + view_menu_after_title_changed = security_manager.find_view_menu( + updated_dashboard.view_name + ) + self.assertIsNotNone(view_menu_after_title_changed) + self.assertEqual(view_menu_after_title_changed.id, excepted_view_id) + + def assert_permissions_were_deleted(self, deleted_dashboard): + view_menu = security_manager.find_view_menu(deleted_dashboard.view_name) + self.assertIsNone(view_menu) + + def save_dash_basic_case(self, username=ADMIN_USERNAME): + # arrange + self.login(username=username) + ( + dashboard_to_save, + data_before_change, + data_after_change, + ) = build_save_dash_parts() + + # act + save_dash_response = self.save_dashboard_via_view( + dashboard_to_save.id, data_after_change + ) + + # assert + self.assertIn("SUCCESS", save_dash_response) + + # post test + self.save_dashboard(dashboard_to_save.id, data_before_change) + + def clean_created_objects(self): + with app.test_request_context(): + self.logout() + self.login("admin") + delete_all_inserted_objects() + self.logout() diff --git a/tests/dashboards/consts.py b/tests/dashboards/consts.py new file mode 100644 index 0000000000000..a6e36839be9ed --- /dev/null +++ b/tests/dashboards/consts.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +QUERY_FORMAT = "?q={}" + +DASHBOARDS_API_URL = "api/v1/dashboard/" +DASHBOARDS_API_URL_WITH_QUERY_FORMAT = DASHBOARDS_API_URL + QUERY_FORMAT +DASHBOARD_API_URL_FORMAT = DASHBOARDS_API_URL + "{}" +EXPORT_DASHBOARDS_API_URL = DASHBOARDS_API_URL + "export/" +EXPORT_DASHBOARDS_API_URL_WITH_QUERY_FORMAT = EXPORT_DASHBOARDS_API_URL + QUERY_FORMAT + +GET_DASHBOARD_VIEW_URL_FORMAT = "/superset/dashboard/{}/" +SAVE_DASHBOARD_URL_FORMAT = "/superset/save_dash/{}/" +COPY_DASHBOARD_URL_FORMAT = "/superset/copy_dash/{}/" +ADD_SLICES_URL_FORMAT = "/superset/add_slices/{}/" + +DELETE_DASHBOARD_VIEW_URL_FORMAT = "/dashboard/delete/{}" +GET_DASHBOARDS_LIST_VIEW = "/dashboard/list/" +NEW_DASHBOARD_URL = "/dashboard/new/" +GET_CHARTS_API_URL = "/api/v1/chart/" + +GAMMA_ROLE_NAME = "Gamma" + +ADMIN_USERNAME = "admin" +GAMMA_USERNAME = "gamma" + +DASHBOARD_SLUG_OF_ACCESSIBLE_TABLE = "births" +DEFAULT_DASHBOARD_SLUG_TO_TEST = "births" +WORLD_HEALTH_SLUG = "world_health" diff --git a/tests/dashboards/dashboard_test_utils.py b/tests/dashboards/dashboard_test_utils.py new file mode 100644 index 0000000000000..c7032f87e6481 --- /dev/null +++ b/tests/dashboards/dashboard_test_utils.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +import random +import string +from typing import Any, Dict, List, Optional, Tuple + +from sqlalchemy import func + +from superset import appbuilder, db, security_manager +from superset.connectors.sqla.models import SqlaTable +from superset.models.dashboard import Dashboard +from superset.models.slice import Slice +from tests.dashboards.consts import DEFAULT_DASHBOARD_SLUG_TO_TEST + +logger = logging.getLogger(__name__) + +session = appbuilder.get_session + + +def get_mock_positions(dashboard: Dashboard) -> Dict[str, Any]: + positions = {"DASHBOARD_VERSION_KEY": "v2"} + for i, slc in enumerate(dashboard.slices): + id_ = "DASHBOARD_CHART_TYPE-{}".format(i) + position_data: Any = { + "type": "CHART", + "id": id_, + "children": [], + "meta": {"width": 4, "height": 50, "chartId": slc.id}, + } + positions[id_] = position_data + return positions + + +def build_save_dash_parts( + dashboard_slug: Optional[str] = None, dashboard_to_edit: Optional[Dashboard] = None +) -> Tuple[Dashboard, Dict[str, Any], Dict[str, Any]]: + if not dashboard_to_edit: + dashboard_slug = ( + dashboard_slug if dashboard_slug else DEFAULT_DASHBOARD_SLUG_TO_TEST + ) + dashboard_to_edit = get_dashboard_by_slug(dashboard_slug) + + data_before_change = { + "positions": dashboard_to_edit.position, + "dashboard_title": dashboard_to_edit.dashboard_title, + } + data_after_change = { + "css": "", + "expanded_slices": {}, + "positions": get_mock_positions(dashboard_to_edit), + "dashboard_title": dashboard_to_edit.dashboard_title, + } + return dashboard_to_edit, data_before_change, data_after_change + + +def get_all_dashboards() -> List[Dashboard]: + return db.session.query(Dashboard).all() + + +def get_dashboard_by_slug(dashboard_slug: str) -> Dashboard: + return db.session.query(Dashboard).filter_by(slug=dashboard_slug).first() + + +def get_slice_by_name(slice_name: str) -> Slice: + return db.session.query(Slice).filter_by(slice_name=slice_name).first() + + +def get_sql_table_by_name(table_name: str): + return db.session.query(SqlaTable).filter_by(table_name=table_name).one() + + +def count_dashboards() -> int: + return db.session.query(func.count(Dashboard.id)).first()[0] + + +def random_title(): + return f"title{random_str()}" + + +def random_slug(): + return f"slug{random_str()}" + + +def get_random_string(length): + letters = string.ascii_lowercase + result_str = "".join(random.choice(letters) for i in range(length)) + print("Random string of length", length, "is:", result_str) + return result_str + + +def random_str(): + return get_random_string(8) + + +def grant_access_to_dashboard(dashboard, role_name): + role = security_manager.find_role(role_name) + dashboard.roles.append(role) + db.session.merge(dashboard) + db.session.commit() + + +def revoke_access_to_dashboard(dashboard, role_name): + role = security_manager.find_role(role_name) + dashboard.roles.remove(role) + db.session.merge(dashboard) + db.session.commit() diff --git a/tests/dashboards/security/__init__.py b/tests/dashboards/security/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/dashboards/security/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/dashboards/security/base_case.py b/tests/dashboards/security/base_case.py new file mode 100644 index 0000000000000..ab24734ce7d25 --- /dev/null +++ b/tests/dashboards/security/base_case.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, Optional + +from flask import escape, Response + +from superset.models.dashboard import Dashboard +from tests.dashboards.base_case import DashboardTestCase + + +class BaseTestDashboardSecurity(DashboardTestCase): + def tearDown(self) -> None: + self.clean_created_objects() + + def assert_dashboard_view_response( + self, response: Response, dashboard_to_access: Dashboard + ) -> None: + self.assert200(response) + assert escape(dashboard_to_access.dashboard_title) in response.data.decode( + "utf-8" + ) + + def assert_dashboard_api_response( + self, response: Response, dashboard_to_access: Dashboard + ) -> None: + self.assert200(response) + assert response.json["id"] == dashboard_to_access.id + + def assert_dashboards_list_view_response( + self, + response: Response, + expected_counts: int, + expected_dashboards: Optional[List[Dashboard]] = None, + not_expected_dashboards: Optional[List[Dashboard]] = None, + ) -> None: + self.assert200(response) + response_html = response.data.decode("utf-8") + if expected_counts == 0: + assert "No records found" in response_html + else: + # # a way to parse number of dashboards returns + # in the list view as an html response + assert ( + "Record Count: {count}".format(count=str(expected_counts)) + in response_html + ) + expected_dashboards = expected_dashboards or [] + for dashboard in expected_dashboards: + assert dashboard.url in response_html + not_expected_dashboards = not_expected_dashboards or [] + for dashboard in not_expected_dashboards: + assert dashboard.url not in response_html + + def assert_dashboards_api_response( + self, + response: Response, + expected_counts: int, + expected_dashboards: Optional[List[Dashboard]] = None, + not_expected_dashboards: Optional[List[Dashboard]] = None, + ) -> None: + self.assert200(response) + response_data = response.json + assert response_data["count"] == expected_counts + response_dashboards_url = set( + map(lambda dash: dash["url"], response_data["result"]) + ) + expected_dashboards = expected_dashboards or [] + for dashboard in expected_dashboards: + assert dashboard.url in response_dashboards_url + not_expected_dashboards = not_expected_dashboards or [] + for dashboard in not_expected_dashboards: + assert dashboard.url not in response_dashboards_url diff --git a/tests/dashboards/security/security_dataset_tests.py b/tests/dashboards/security/security_dataset_tests.py new file mode 100644 index 0000000000000..842f5c09f3968 --- /dev/null +++ b/tests/dashboards/security/security_dataset_tests.py @@ -0,0 +1,241 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for Superset""" +import json + +import prison +import pytest +from flask import escape + +from superset import app +from superset.models import core as models +from tests.dashboards.base_case import DashboardTestCase +from tests.dashboards.consts import * +from tests.dashboards.dashboard_test_utils import * +from tests.dashboards.superset_factory_util import * +from tests.fixtures.energy_dashboard import load_energy_table_with_slice + + +class TestDashboardDatasetSecurity(DashboardTestCase): + @pytest.fixture + def load_dashboard(self): + with app.app_context(): + table = ( + db.session.query(SqlaTable).filter_by(table_name="energy_usage").one() + ) + # get a slice from the allowed table + slice = db.session.query(Slice).filter_by(slice_name="Energy Sankey").one() + + self.grant_public_access_to_table(table) + + pytest.hidden_dash_slug = f"hidden_dash_{random_slug()}" + pytest.published_dash_slug = f"published_dash_{random_slug()}" + + # Create a published and hidden dashboard and add them to the database + published_dash = Dashboard() + published_dash.dashboard_title = "Published Dashboard" + published_dash.slug = pytest.published_dash_slug + published_dash.slices = [slice] + published_dash.published = True + + hidden_dash = Dashboard() + hidden_dash.dashboard_title = "Hidden Dashboard" + hidden_dash.slug = pytest.hidden_dash_slug + hidden_dash.slices = [slice] + hidden_dash.published = False + + db.session.merge(published_dash) + db.session.merge(hidden_dash) + yield db.session.commit() + + self.revoke_public_access_to_table(table) + db.session.delete(published_dash) + db.session.delete(hidden_dash) + db.session.commit() + + def test_dashboard_access__admin_can_access_all(self): + # arrange + self.login(username=ADMIN_USERNAME) + dashboard_title_by_url = { + dash.url: dash.dashboard_title for dash in get_all_dashboards() + } + + # act + responses_by_url = { + url: self.client.get(url).data.decode("utf-8") + for url in dashboard_title_by_url.keys() + } + + # assert + for dashboard_url, get_dashboard_response in responses_by_url.items(): + assert ( + escape(dashboard_title_by_url[dashboard_url]) in get_dashboard_response + ) + + def test_get_dashboards__users_are_dashboards_owners(self): + # arrange + username = "gamma" + user = security_manager.find_user(username) + my_owned_dashboard = create_dashboard_to_db( + dashboard_title="My Dashboard", published=False, owners=[user], + ) + + not_my_owned_dashboard = create_dashboard_to_db( + dashboard_title="Not My Dashboard", published=False, + ) + + self.login(user.username) + + # act + get_dashboards_response = self.get_resp(DASHBOARDS_API_URL) + + # assert + self.assertIn(my_owned_dashboard.url, get_dashboards_response) + self.assertNotIn(not_my_owned_dashboard.url, get_dashboards_response) + + def test_get_dashboards__owners_can_view_empty_dashboard(self): + # arrange + dash = create_dashboard_to_db("Empty Dashboard", slug="empty_dashboard") + dashboard_url = dash.url + gamma_user = security_manager.find_user("gamma") + self.login(gamma_user.username) + + # act + get_dashboards_response = self.get_resp(DASHBOARDS_API_URL) + + # assert + self.assertNotIn(dashboard_url, get_dashboards_response) + + def test_get_dashboards__users_can_view_favorites_dashboards(self): + # arrange + user = security_manager.find_user("gamma") + fav_dash_slug = f"my_favorite_dash_{random_slug()}" + regular_dash_slug = f"regular_dash_{random_slug()}" + + favorite_dash = Dashboard() + favorite_dash.dashboard_title = "My Favorite Dashboard" + favorite_dash.slug = fav_dash_slug + + regular_dash = Dashboard() + regular_dash.dashboard_title = "A Plain Ol Dashboard" + regular_dash.slug = regular_dash_slug + + db.session.merge(favorite_dash) + db.session.merge(regular_dash) + db.session.commit() + + dash = db.session.query(Dashboard).filter_by(slug=fav_dash_slug).first() + + favorites = models.FavStar() + favorites.obj_id = dash.id + favorites.class_name = "Dashboard" + favorites.user_id = user.id + + db.session.merge(favorites) + db.session.commit() + + self.login(user.username) + + # act + get_dashboards_response = self.get_resp(DASHBOARDS_API_URL) + + # assert + self.assertIn(f"/superset/dashboard/{fav_dash_slug}/", get_dashboards_response) + + def test_get_dashboards__user_can_not_view_unpublished_dash(self): + # arrange + admin_user = security_manager.find_user(ADMIN_USERNAME) + gamma_user = security_manager.find_user(GAMMA_USERNAME) + admin_and_not_published_dashboard = create_dashboard_to_db( + dashboard_title="admin_owned_unpublished_dash", owners=[admin_user] + ) + + self.login(gamma_user.username) + + # act - list dashboards as a gamma user + get_dashboards_response_as_gamma = self.get_resp(DASHBOARDS_API_URL) + + # assert + self.assertNotIn( + admin_and_not_published_dashboard.url, get_dashboards_response_as_gamma + ) + + @pytest.mark.usefixtures("load_energy_table_with_slice", "load_dashboard") + def test_get_dashboards__users_can_view_permitted_dashboard(self): + # arrange + username = random_str() + new_role = f"role_{random_str()}" + self.create_user_with_roles(username, [new_role], should_create_roles=True) + accessed_table = get_sql_table_by_name("energy_usage") + self.grant_role_access_to_table(accessed_table, new_role) + # get a slice from the allowed table + slice_to_add_to_dashboards = get_slice_by_name("Energy Sankey") + # Create a published and hidden dashboard and add them to the database + first_dash = create_dashboard_to_db( + dashboard_title="Published Dashboard", + published=True, + slices=[slice_to_add_to_dashboards], + ) + + second_dash = create_dashboard_to_db( + dashboard_title="Hidden Dashboard", + published=True, + slices=[slice_to_add_to_dashboards], + ) + + try: + self.login(username) + # act + get_dashboards_response = self.get_resp(DASHBOARDS_API_URL) + + # assert + self.assertIn(second_dash.url, get_dashboards_response) + self.assertIn(first_dash.url, get_dashboards_response) + finally: + self.revoke_public_access_to_table(accessed_table) + + def test_get_dashboard_api_no_data_access(self): + """ + Dashboard API: Test get dashboard without data access + """ + admin = self.get_user("admin") + dashboard = create_dashboard_to_db( + random_title(), random_slug(), owners=[admin] + ) + + self.login(username="gamma") + uri = DASHBOARD_API_URL_FORMAT.format(dashboard.id) + rv = self.client.get(uri) + self.assert404(rv) + + def test_get_dashboards_api_no_data_access(self): + """ + Dashboard API: Test get dashboards no data access + """ + admin = self.get_user("admin") + title = f"title{random_str()}" + create_dashboard_to_db(title, "slug1", owners=[admin]) + + self.login(username="gamma") + arguments = { + "filters": [{"col": "dashboard_title", "opr": "sw", "value": title[0:8]}] + } + uri = DASHBOARDS_API_URL_WITH_QUERY_FORMAT.format(prison.dumps(arguments)) + rv = self.client.get(uri) + self.assert200(rv) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(0, data["count"]) diff --git a/tests/dashboards/security/security_rbac_tests.py b/tests/dashboards/security/security_rbac_tests.py new file mode 100644 index 0000000000000..efd2c990f314c --- /dev/null +++ b/tests/dashboards/security/security_rbac_tests.py @@ -0,0 +1,308 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for Superset""" +from unittest import mock + +from tests.dashboards.dashboard_test_utils import * +from tests.dashboards.security.base_case import BaseTestDashboardSecurity +from tests.dashboards.superset_factory_util import ( + create_dashboard_to_db, + create_database_to_db, + create_datasource_table_to_db, + create_slice_to_db, +) + + +@mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", DASHBOARD_RBAC=True, +) +class TestDashboardRoleBasedSecurity(BaseTestDashboardSecurity): + def test_get_dashboards_list__admin_get_all_dashboards(self): + # arrange + create_dashboard_to_db( + owners=[], slices=[create_slice_to_db()], published=False + ) + dashboard_counts = count_dashboards() + + self.login("admin") + + # act + response = self.get_dashboards_list_response() + + # assert + self.assert_dashboards_list_view_response(response, dashboard_counts) + + def test_get_dashboards_list__owner_get_all_owned_dashboards(self): + # arrange + username = random_str() + new_role = f"role_{random_str()}" + owner = self.create_user_with_roles( + username, [new_role], should_create_roles=True + ) + database = create_database_to_db() + table = create_datasource_table_to_db(db_id=database.id, owners=[owner]) + first_dash = create_dashboard_to_db( + owners=[owner], slices=[create_slice_to_db(datasource_id=table.id)] + ) + second_dash = create_dashboard_to_db( + owners=[owner], slices=[create_slice_to_db(datasource_id=table.id)] + ) + owned_dashboards = [first_dash, second_dash] + not_owned_dashboards = [ + create_dashboard_to_db( + slices=[create_slice_to_db(datasource_id=table.id)], published=True + ) + ] + + self.login(username) + + # act + response = self.get_dashboards_list_response() + + # assert + self.assert_dashboards_list_view_response( + response, 2, owned_dashboards, not_owned_dashboards + ) + + def test_get_dashboards_list__user_without_any_permissions_get_empty_list(self): + + # arrange + username = random_str() + new_role = f"role_{random_str()}" + self.create_user_with_roles(username, [new_role], should_create_roles=True) + + create_dashboard_to_db(published=True) + self.login(username) + + # act + response = self.get_dashboards_list_response() + + # assert + self.assert_dashboards_list_view_response(response, 0) + + def test_get_dashboards_list__user_get_only_published_permitted_dashboards(self): + # arrange + username = random_str() + new_role = f"role_{random_str()}" + self.create_user_with_roles(username, [new_role], should_create_roles=True) + + published_dashboards = [ + create_dashboard_to_db(published=True), + create_dashboard_to_db(published=True), + ] + not_published_dashboards = [ + create_dashboard_to_db(published=False), + create_dashboard_to_db(published=False), + ] + + for dash in published_dashboards + not_published_dashboards: + grant_access_to_dashboard(dash, new_role) + + self.login(username) + + # act + response = self.get_dashboards_list_response() + + # assert + self.assert_dashboards_list_view_response( + response, + len(published_dashboards), + published_dashboards, + not_published_dashboards, + ) + + # post + for dash in published_dashboards + not_published_dashboards: + revoke_access_to_dashboard(dash, new_role) + + def test_get_dashboards_list__public_user_without_any_permissions_get_empty_list( + self, + ): + create_dashboard_to_db(published=True) + + # act + response = self.get_dashboards_list_response() + + # assert + self.assert_dashboards_list_view_response(response, 0) + + def test_get_dashboards_list__public_user_get_only_published_permitted_dashboards( + self, + ): + # arrange + published_dashboards = [ + create_dashboard_to_db(published=True), + create_dashboard_to_db(published=True), + ] + not_published_dashboards = [ + create_dashboard_to_db(published=False), + create_dashboard_to_db(published=False), + ] + + for dash in published_dashboards + not_published_dashboards: + grant_access_to_dashboard(dash, "Public") + + # act + response = self.get_dashboards_list_response() + + # assert + self.assert_dashboards_list_view_response( + response, + len(published_dashboards), + published_dashboards, + not_published_dashboards, + ) + + # post + for dash in published_dashboards + not_published_dashboards: + revoke_access_to_dashboard(dash, "Public") + + def test_get_dashboards_api__admin_get_all_dashboards(self): + # arrange + create_dashboard_to_db( + owners=[], slices=[create_slice_to_db()], published=False + ) + dashboard_counts = count_dashboards() + + self.login("admin") + + # act + response = self.get_dashboards_api_response() + + # assert + self.assert_dashboards_api_response(response, dashboard_counts) + + def test_get_dashboards_api__owner_get_all_owned_dashboards(self): + # arrange + username = random_str() + new_role = f"role_{random_str()}" + owner = self.create_user_with_roles( + username, [new_role], should_create_roles=True + ) + database = create_database_to_db() + table = create_datasource_table_to_db(db_id=database.id, owners=[owner]) + first_dash = create_dashboard_to_db( + owners=[owner], slices=[create_slice_to_db(datasource_id=table.id)] + ) + second_dash = create_dashboard_to_db( + owners=[owner], slices=[create_slice_to_db(datasource_id=table.id)] + ) + owned_dashboards = [first_dash, second_dash] + not_owned_dashboards = [ + create_dashboard_to_db( + slices=[create_slice_to_db(datasource_id=table.id)], published=True + ) + ] + + self.login(username) + + # act + response = self.get_dashboards_api_response() + + # assert + self.assert_dashboards_api_response( + response, 2, owned_dashboards, not_owned_dashboards + ) + + def test_get_dashboards_api__user_without_any_permissions_get_empty_list(self): + username = random_str() + new_role = f"role_{random_str()}" + self.create_user_with_roles(username, [new_role], should_create_roles=True) + create_dashboard_to_db(published=True) + self.login(username) + + # act + response = self.get_dashboards_api_response() + + # assert + self.assert_dashboards_api_response(response, 0) + + def test_get_dashboards_api__user_get_only_published_permitted_dashboards(self): + username = random_str() + new_role = f"role_{random_str()}" + self.create_user_with_roles(username, [new_role], should_create_roles=True) + # arrange + published_dashboards = [ + create_dashboard_to_db(published=True), + create_dashboard_to_db(published=True), + ] + not_published_dashboards = [ + create_dashboard_to_db(published=False), + create_dashboard_to_db(published=False), + ] + + for dash in published_dashboards + not_published_dashboards: + grant_access_to_dashboard(dash, new_role) + + self.login(username) + + # act + response = self.get_dashboards_api_response() + + # assert + self.assert_dashboards_api_response( + response, + len(published_dashboards), + published_dashboards, + not_published_dashboards, + ) + + # post + for dash in published_dashboards + not_published_dashboards: + revoke_access_to_dashboard(dash, new_role) + + def test_get_dashboards_api__public_user_without_any_permissions_get_empty_list( + self, + ): + create_dashboard_to_db(published=True) + + # act + response = self.get_dashboards_api_response() + + # assert + self.assert_dashboards_api_response(response, 0) + + def test_get_dashboards_api__public_user_get_only_published_permitted_dashboards( + self, + ): + # arrange + published_dashboards = [ + create_dashboard_to_db(published=True), + create_dashboard_to_db(published=True), + ] + not_published_dashboards = [ + create_dashboard_to_db(published=False), + create_dashboard_to_db(published=False), + ] + + for dash in published_dashboards + not_published_dashboards: + grant_access_to_dashboard(dash, "Public") + + # act + response = self.get_dashboards_api_response() + + # assert + self.assert_dashboards_api_response( + response, + len(published_dashboards), + published_dashboards, + not_published_dashboards, + ) + + # post + for dash in published_dashboards + not_published_dashboards: + revoke_access_to_dashboard(dash, "Public") diff --git a/tests/dashboards/superset_factory_util.py b/tests/dashboards/superset_factory_util.py new file mode 100644 index 0000000000000..f62f4ac2cddf7 --- /dev/null +++ b/tests/dashboards/superset_factory_util.py @@ -0,0 +1,305 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import List, Optional + +from flask_appbuilder import Model +from flask_appbuilder.security.sqla.models import User + +from superset import appbuilder +from superset.connectors.sqla.models import SqlaTable, sqlatable_user +from superset.models.core import Database +from superset.models.dashboard import ( + Dashboard, + dashboard_slices, + dashboard_user, + DashboardRoles, +) +from superset.models.slice import Slice, slice_user +from tests.dashboards.dashboard_test_utils import random_slug, random_str, random_title + +logger = logging.getLogger(__name__) + +session = appbuilder.get_session + +inserted_dashboards_ids = [] +inserted_databases_ids = [] +inserted_sqltables_ids = [] +inserted_slices_ids = [] + + +def create_dashboard_to_db( + dashboard_title: Optional[str] = None, + slug: Optional[str] = None, + published: bool = False, + owners: Optional[List[User]] = None, + slices: Optional[List[Slice]] = None, + css: str = "", + json_metadata: str = "", + position_json: str = "", +) -> Dashboard: + dashboard = create_dashboard( + dashboard_title, + slug, + published, + owners, + slices, + css, + json_metadata, + position_json, + ) + + insert_model(dashboard) + inserted_dashboards_ids.append(dashboard.id) + return dashboard + + +def create_dashboard( + dashboard_title: Optional[str] = None, + slug: Optional[str] = None, + published: bool = False, + owners: Optional[List[User]] = None, + slices: Optional[List[Slice]] = None, + css: str = "", + json_metadata: str = "", + position_json: str = "", +) -> Dashboard: + dashboard_title = dashboard_title or random_title() + slug = slug or random_slug() + owners = owners or [] + slices = slices or [] + return Dashboard( + dashboard_title=dashboard_title, + slug=slug, + published=published, + owners=owners, + css=css, + position_json=position_json, + json_metadata=json_metadata, + slices=slices, + ) + + +def insert_model(dashboard: Model) -> None: + session.add(dashboard) + session.commit() + session.refresh(dashboard) + + +def create_slice_to_db( + name: Optional[str] = None, + datasource_id: Optional[int] = None, + owners: Optional[List[User]] = None, +) -> Slice: + slice_ = create_slice(datasource_id, name, owners) + insert_model(slice_) + inserted_slices_ids.append(slice_.id) + return slice_ + + +def create_slice( + datasource_id: Optional[int], name: Optional[str], owners: Optional[List[User]] +) -> Slice: + name = name or random_str() + owners = owners or [] + datasource_id = ( + datasource_id or create_datasource_table_to_db(name=name + "_table").id + ) + return Slice( + slice_name=name, + datasource_id=datasource_id, + owners=owners, + datasource_type="table", + ) + + +def create_datasource_table_to_db( + name: Optional[str] = None, + db_id: Optional[int] = None, + owners: Optional[List[User]] = None, +) -> SqlaTable: + sqltable = create_datasource_table(name, db_id, owners) + insert_model(sqltable) + inserted_sqltables_ids.append(sqltable.id) + return sqltable + + +def create_datasource_table( + name: Optional[str] = None, + db_id: Optional[int] = None, + owners: Optional[List[User]] = None, +) -> SqlaTable: + name = name or random_str() + owners = owners or [] + db_id = db_id or create_database_to_db(name=name + "_db").id + return SqlaTable(table_name=name, database_id=db_id, owners=owners) + + +def create_database_to_db(name: Optional[str] = None) -> Database: + database = create_database(name) + insert_model(database) + inserted_databases_ids.append(database.id) + return database + + +def create_database(name: Optional[str] = None) -> Database: + name = name or random_str() + return Database(database_name=name, sqlalchemy_uri="sqlite:///:memory:") + + +def delete_all_inserted_objects() -> None: + delete_all_inserted_dashboards() + delete_all_inserted_slices() + delete_all_inserted_tables() + delete_all_inserted_dbs() + + +def delete_all_inserted_dashboards(): + try: + dashboards_to_delete: List[Dashboard] = session.query(Dashboard).filter( + Dashboard.id.in_(inserted_dashboards_ids) + ).all() + for dashboard in dashboards_to_delete: + try: + delete_dashboard(dashboard, False) + except Exception as ex: + logger.error(f"failed to delete {dashboard.id}", exc_info=True) + raise ex + if len(inserted_dashboards_ids) > 0: + session.commit() + inserted_dashboards_ids.clear() + except Exception as ex2: + logger.error("delete_all_inserted_dashboards failed", exc_info=True) + raise ex2 + + +def delete_dashboard(dashboard: Dashboard, do_commit: bool = False) -> None: + logger.info(f"deleting dashboard{dashboard.id}") + delete_dashboard_roles_associations(dashboard) + delete_dashboard_users_associations(dashboard) + delete_dashboard_slices_associations(dashboard) + session.delete(dashboard) + if do_commit: + session.commit() + + +def delete_dashboard_users_associations(dashboard: Dashboard) -> None: + session.execute( + dashboard_user.delete().where(dashboard_user.c.dashboard_id == dashboard.id) + ) + + +def delete_dashboard_roles_associations(dashboard: Dashboard) -> None: + session.execute( + DashboardRoles.delete().where(DashboardRoles.c.dashboard_id == dashboard.id) + ) + + +def delete_dashboard_slices_associations(dashboard: Dashboard) -> None: + session.execute( + dashboard_slices.delete().where(dashboard_slices.c.dashboard_id == dashboard.id) + ) + + +def delete_all_inserted_slices(): + try: + slices_to_delete: List[Slice] = session.query(Slice).filter( + Slice.id.in_(inserted_slices_ids) + ).all() + for slice in slices_to_delete: + try: + delete_slice(slice, False) + except Exception as ex: + logger.error(f"failed to delete {slice.id}", exc_info=True) + raise ex + if len(inserted_slices_ids) > 0: + session.commit() + inserted_slices_ids.clear() + except Exception as ex2: + logger.error("delete_all_inserted_slices failed", exc_info=True) + raise ex2 + + +def delete_slice(slice_: Slice, do_commit: bool = False) -> None: + logger.info(f"deleting slice{slice_.id}") + delete_slice_users_associations(slice_) + session.delete(slice_) + if do_commit: + session.commit() + + +def delete_slice_users_associations(slice_: Slice) -> None: + session.execute(slice_user.delete().where(slice_user.c.slice_id == slice_.id)) + + +def delete_all_inserted_tables(): + try: + tables_to_delete: List[SqlaTable] = session.query(SqlaTable).filter( + SqlaTable.id.in_(inserted_sqltables_ids) + ).all() + for table in tables_to_delete: + try: + delete_sqltable(table, False) + except Exception as ex: + logger.error(f"failed to delete {table.id}", exc_info=True) + raise ex + if len(inserted_sqltables_ids) > 0: + session.commit() + inserted_sqltables_ids.clear() + except Exception as ex2: + logger.error("delete_all_inserted_tables failed", exc_info=True) + raise ex2 + + +def delete_sqltable(table: SqlaTable, do_commit: bool = False) -> None: + logger.info(f"deleting table{table.id}") + delete_table_users_associations(table) + session.delete(table) + if do_commit: + session.commit() + + +def delete_table_users_associations(table: SqlaTable) -> None: + session.execute( + sqlatable_user.delete().where(sqlatable_user.c.table_id == table.id) + ) + + +def delete_all_inserted_dbs(): + try: + dbs_to_delete: List[Database] = session.query(Database).filter( + Database.id.in_(inserted_databases_ids) + ).all() + for db in dbs_to_delete: + try: + delete_database(db, False) + except Exception as ex: + logger.error(f"failed to delete {db.id}", exc_info=True) + raise ex + if len(inserted_databases_ids) > 0: + session.commit() + inserted_databases_ids.clear() + except Exception as ex2: + logger.error("delete_all_inserted_databases failed", exc_info=True) + raise ex2 + + +def delete_database(database: Database, do_commit: bool = False) -> None: + logger.info(f"deleting database{database.id}") + session.delete(database) + if do_commit: + session.commit()