diff --git a/superset/config.py b/superset/config.py index 0f9fafb68e238..af2816fc06d40 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1582,6 +1582,21 @@ class ExtraRelatedQueryFilters(TypedDict, total=False): EXTRA_RELATED_QUERY_FILTERS: ExtraRelatedQueryFilters = {} +# Extra dynamic database filter make it possible to limit which databases are shown +# in the UI before any other filtering is applied. Useful for example when +# considering to filter using Feature Flags along with regular role filters +# that get applied by default in our base_filters. +# For example, to only show a database starting with the letter "b" +# in the "Database Connections" list, you could add the following in your config: +# def initial_database_filter(query: Query, *args, *kwargs): +# from superset.models.core import Database +# +# filter = Database.database_name.startswith('b') +# return query.filter(filter) +# +# EXTRA_DYNAMIC_DATABASE_FILTER = initial_database_filter +EXTRA_DYNAMIC_DATABASE_FILTER: Callable[[Query], Query] | None = None + # ------------------------------------------------------------------- # * WARNING: STOP EDITING HERE * diff --git a/superset/databases/filters.py b/superset/databases/filters.py index 2ca77b77d1c40..d5ce6aa50b135 100644 --- a/superset/databases/filters.py +++ b/superset/databases/filters.py @@ -16,7 +16,7 @@ # under the License. from typing import Any -from flask import g +from flask import current_app, g from flask_babel import lazy_gettext as _ from sqlalchemy import or_ from sqlalchemy.orm import Query @@ -41,6 +41,16 @@ class DatabaseFilter(BaseFilter): # pylint: disable=too-few-public-methods # TODO(bogdan): consider caching. def apply(self, query: Query, value: Any) -> Query: + # Dynamic Filters need to be applied to the Query before we filter + # databases with anything else. This way you can show/hide databases using + # Feature Flags for example in conjuction with the regular role filtering. + # If not, if an user has access to all Databases it would skip this dynamic + # filtering. + + if dynamic_filter := current_app.config["EXTRA_DYNAMIC_DATABASE_FILTER"]: + query = dynamic_filter(query) + + # We can proceed with default filtering now if security_manager.can_access_all_databases(): return query database_perms = security_manager.user_view_menu_names("database_access") diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 568ba0593443d..bc07ae13166bb 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -3632,3 +3632,77 @@ def test_validate_sql_endpoint_failure(self, get_validator_by_name): return self.assertEqual(rv.status_code, 422) self.assertIn("Kaboom!", response["errors"][0]["message"]) + + def test_get_databases_with_extra_filters(self): + """ + API: Test get database with extra query filter + """ + self.login(username="admin") + extra = { + "metadata_params": {}, + "engine_params": {}, + "metadata_cache_timeout": {}, + "schemas_allowed_for_file_upload": [], + } + example_db = get_example_database() + + if example_db.backend == "sqlite": + return + database_data = { + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM, + "server_cert": None, + "extra": json.dumps(extra), + } + + uri = "api/v1/database/" + rv = self.client.post( + uri, json={**database_data, "database_name": "dyntest-create-database-1"} + ) + first_response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + + uri = "api/v1/database/" + rv = self.client.post( + uri, json={**database_data, "database_name": "create-database-2"} + ) + second_response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + dbs = db.session.query(Database).all() + expected_names = [db.database_name for db in dbs] + expected_names.sort() + + uri = f"api/v1/database/" + rv = self.client.get(uri) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(data["count"], len(dbs)) + database_names = [item["database_name"] for item in data["result"]] + database_names.sort() + # All Databases because we are an admin + self.assertEqual(database_names, expected_names) + assert rv.status_code == 200 + + def _base_filter(query): + from superset.models.core import Database + + return query.filter(Database.database_name.startswith("dyntest")) + + with patch.dict( + "superset.views.filters.current_app.config", + {"EXTRA_DYNAMIC_DATABASE_FILTER": _base_filter}, + ): + uri = f"api/v1/database/" + rv = self.client.get(uri) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(data["count"], 1) + database_names = [item["database_name"] for item in data["result"]] + # Only the database that starts with tests, even if we are an admin + self.assertEqual(database_names, ["dyntest-create-database-1"]) + assert rv.status_code == 200 + + # Cleanup + first_model = db.session.query(Database).get(first_response.get("id")) + second_model = db.session.query(Database).get(second_response.get("id")) + db.session.delete(first_model) + db.session.delete(second_model) + db.session.commit() diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index 24fde88369594..6f34a3ff9eadb 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -23,6 +23,7 @@ from uuid import UUID import pytest +from flask import current_app from pytest_mock import MockFixture from sqlalchemy.orm.session import Session @@ -495,3 +496,88 @@ def test_delete_ssh_tunnel_not_found( response_tunnel = DatabaseDAO.get_ssh_tunnel(2) assert response_tunnel is None + + +def test_apply_dynamic_database_filter( + mocker: MockFixture, + app: Any, + session: Session, + client: Any, + full_api_access: None, +) -> None: + """ + Test that we canfilter the list of databases + """ + with app.app_context(): + from superset.daos.database import DatabaseDAO + from superset.databases.api import DatabaseRestApi + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + DatabaseRestApi.datamodel.session = session + + # create table for databases + Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member + + # Create our First Database + database = Database( + database_name="first-database", + sqlalchemy_uri="gsheets://", + encrypted_extra=json.dumps( + { + "metadata_params": {}, + "engine_params": {}, + "metadata_cache_timeout": {}, + "schemas_allowed_for_file_upload": [], + } + ), + ) + session.add(database) + session.commit() + + # Create our Second Database + database = Database( + database_name="second-database", + sqlalchemy_uri="gsheets://", + encrypted_extra=json.dumps( + { + "metadata_params": {}, + "engine_params": {}, + "metadata_cache_timeout": {}, + "schemas_allowed_for_file_upload": [], + } + ), + ) + session.add(database) + session.commit() + + # mock the lookup so that we don't need to include the driver + mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") + mocker.patch("superset.utils.log.DBEventLogger.log") + mocker.patch( + "superset.databases.ssh_tunnel.commands.delete.is_feature_enabled", + return_value=False, + ) + + # Get our recently created Databases + response_databases = DatabaseDAO.find_all() + assert response_databases + expected_db_names = ["first-database", "second-database"] + actual_db_names = [db.database_name for db in response_databases] + assert actual_db_names == expected_db_names + + def _base_filter(query): + from superset.models.core import Database + + return query.filter(Database.database_name.startswith("second")) + + original_config = current_app.config.copy() + original_config["EXTRA_DYNAMIC_DATABASE_FILTER"] = _base_filter + + mocker.patch("superset.views.filters.current_app.config", new=original_config) + # Get filtered list + response_databases = DatabaseDAO.find_all() + assert response_databases + expected_db_names = ["second-database"] + actual_db_names = [db.database_name for db in response_databases] + assert actual_db_names == expected_db_names