diff --git a/superset/config.py b/superset/config.py
index 0f9fafb68e238..85ba10167f793 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -1583,6 +1583,26 @@ class ExtraRelatedQueryFilters(TypedDict, total=False):
 EXTRA_RELATED_QUERY_FILTERS: ExtraRelatedQueryFilters = {}
 
 
+# Extra dynamic query filters make it possible to limit which objects are shown
+# in the UI before any other filtering is applied. Useful, for example, when you
+# want to filter using feature flags along with the regular role filters that
+# are applied by default in our base_filters.
+# For example, to only show databases whose name starts with the letter "b"
+# in the "Database Connections" list, you could add the following to your config:
+# def initial_database_filter(query: Query) -> Query:
+#     from superset.models.core import Database
+#
+#     name_filter = Database.database_name.startswith("b")
+#     return query.filter(name_filter)
+#
+# EXTRA_DYNAMIC_QUERY_FILTERS = {"databases": initial_database_filter}
+class ExtraDynamicQueryFilters(TypedDict, total=False):
+    databases: Callable[[Query], Query]
+
+
+EXTRA_DYNAMIC_QUERY_FILTERS: ExtraDynamicQueryFilters = {}
+
+
 # -------------------------------------------------------------------
 # *                    WARNING:  STOP EDITING  HERE                  *
 # -------------------------------------------------------------------
diff --git a/superset/databases/filters.py b/superset/databases/filters.py
index 2ca77b77d1c40..384a62c9d3b6f 100644
--- a/superset/databases/filters.py
+++ b/superset/databases/filters.py
@@ -16,7 +16,7 @@
 # under the License.
 from typing import Any
 
-from flask import g
+from flask import current_app, g
 from flask_babel import lazy_gettext as _
 from sqlalchemy import or_
 from sqlalchemy.orm import Query
@@ -41,6 +41,19 @@ class DatabaseFilter(BaseFilter):  # pylint: disable=too-few-public-methods
 
     # TODO(bogdan): consider caching.
     def apply(self, query: Query, value: Any) -> Query:
+        """
+        Dynamic filters need to be applied to the query before any other
+        filtering. This way you can show/hide databases, e.g. using feature
+        flags, in conjunction with the regular role filtering. Otherwise, a
+        user with access to all databases would bypass this dynamic
+        filtering entirely.
+        """
+
+        if dynamic_filters := current_app.config["EXTRA_DYNAMIC_QUERY_FILTERS"]:
+            if dynamic_databases_filter := dynamic_filters.get("databases"):
+                query = dynamic_databases_filter(query)
+
+        # Now we can proceed with the default filtering
         if security_manager.can_access_all_databases():
             return query
         database_perms = security_manager.user_view_menu_names("database_access")
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index 568ba0593443d..ebf94219c3b0a 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -28,6 +28,8 @@
 import pytest
 import yaml
 
+from unittest.mock import Mock
+
 from sqlalchemy.engine.url import make_url
 from sqlalchemy.exc import DBAPIError
 from sqlalchemy.sql import func
@@ -3632,3 +3634,94 @@ def test_validate_sql_endpoint_failure(self, get_validator_by_name):
             return
         self.assertEqual(rv.status_code, 422)
         self.assertIn("Kaboom!", response["errors"][0]["message"])
+
+    def test_get_databases_with_extra_filters(self):
+        """
+        API: Test getting databases with an extra query filter.
+        First we test the default behaviour: all databases must be
+        returned when nothing is set in the config.
+        Then we patch the config to add the filter function and
+        verify that it is applied.
+        """
+        self.login(username="admin")
+        extra = {
+            "metadata_params": {},
+            "engine_params": {},
+            "metadata_cache_timeout": {},
+            "schemas_allowed_for_file_upload": [],
+        }
+        example_db = get_example_database()
+
+        if example_db.backend == "sqlite":
+            return
+        # Create our two databases
+        database_data = {
+            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
+            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
+            "server_cert": None,
+            "extra": json.dumps(extra),
+        }
+
+        uri = "api/v1/database/"
+        rv = self.client.post(
+            uri, json={**database_data, "database_name": "dyntest-create-database-1"}
+        )
+        first_response = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(rv.status_code, 201)
+
+        uri = "api/v1/database/"
+        rv = self.client.post(
+            uri, json={**database_data, "database_name": "create-database-2"}
+        )
+        second_response = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(rv.status_code, 201)
+
+        # The filter function
+        def _base_filter(query):
+            from superset.models.core import Database
+
+            return query.filter(Database.database_name.startswith("dyntest"))
+
+        # Wrap the filter in a Mock so we can assert on its calls
+        base_filter_mock = Mock(side_effect=_base_filter)
+        dbs = db.session.query(Database).all()
+        expected_names = [db.database_name for db in dbs]
+        expected_names.sort()
+
+        uri = "api/v1/database/"
+        # Get the list of databases without any filter in the config
+        rv = self.client.get(uri)
+        data = json.loads(rv.data.decode("utf-8"))
+        # All databases must be returned if no filter is present
+        self.assertEqual(data["count"], len(dbs))
+        database_names = [item["database_name"] for item in data["result"]]
+        database_names.sort()
+        # All databases, because we are an admin
+        self.assertEqual(database_names, expected_names)
+        assert rv.status_code == 200
+        # Our filter function hasn't been called yet
+        base_filter_mock.assert_not_called()
+
+        # Now we patch the config to include our filter function
+        with patch.dict(
+            "superset.views.filters.current_app.config",
+            {"EXTRA_DYNAMIC_QUERY_FILTERS": {"databases": base_filter_mock}},
+        ):
+            uri = "api/v1/database/"
+            rv = self.client.get(uri)
+            data = json.loads(rv.data.decode("utf-8"))
+            # Only one database starts with "dyntest"
+            self.assertEqual(data["count"], 1)
+            database_names = [item["database_name"] for item in data["result"]]
+            # Only the database that starts with "dyntest", even though we are an admin
+            self.assertEqual(database_names, ["dyntest-create-database-1"])
+            assert rv.status_code == 200
+            # The filter function is called now that it's defined in our config
+            base_filter_mock.assert_called()
+
+        # Cleanup
+        first_model = db.session.query(Database).get(first_response.get("id"))
+        second_model = db.session.query(Database).get(second_response.get("id"))
+        db.session.delete(first_model)
+        db.session.delete(second_model)
+        db.session.commit()
diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py
index 24fde88369594..899e2b0234571 100644
--- a/tests/unit_tests/databases/api_test.py
+++ b/tests/unit_tests/databases/api_test.py
@@ -20,9 +20,11 @@
 import json
 from io import BytesIO
 from typing import Any
+from unittest.mock import Mock
 from uuid import UUID
 
 import pytest
+from flask import current_app
 from pytest_mock import MockFixture
 from sqlalchemy.orm.session import Session
 
@@ -495,3 +497,100 @@ def test_delete_ssh_tunnel_not_found(
 
         response_tunnel = DatabaseDAO.get_ssh_tunnel(2)
         assert response_tunnel is None
+
+
+def test_apply_dynamic_database_filter(
+    mocker: MockFixture,
+    app: Any,
+    session: Session,
+    client: Any,
+    full_api_access: None,
+) -> None:
+    """
+    Test that we can filter the list of databases.
+    First we test the default behavior without a filter, then
+    we define a filter function and patch the config to get
+    the filtered results.
+    """
+    with app.app_context():
+        from superset.daos.database import DatabaseDAO
+        from superset.databases.api import DatabaseRestApi
+        from superset.databases.ssh_tunnel.models import SSHTunnel
+        from superset.models.core import Database
+
+        DatabaseRestApi.datamodel.session = session
+
+        # create table for databases
+        Database.metadata.create_all(session.get_bind())  # pylint: disable=no-member
+
+        # Create our first database
+        database = Database(
+            database_name="first-database",
+            sqlalchemy_uri="gsheets://",
+            encrypted_extra=json.dumps(
+                {
+                    "metadata_params": {},
+                    "engine_params": {},
+                    "metadata_cache_timeout": {},
+                    "schemas_allowed_for_file_upload": [],
+                }
+            ),
+        )
+        session.add(database)
+        session.commit()
+
+        # Create our second database
+        database = Database(
+            database_name="second-database",
+            sqlalchemy_uri="gsheets://",
+            encrypted_extra=json.dumps(
+                {
+                    "metadata_params": {},
+                    "engine_params": {},
+                    "metadata_cache_timeout": {},
+                    "schemas_allowed_for_file_upload": [],
+                }
+            ),
+        )
+        session.add(database)
+        session.commit()
+
+        # mock the lookup so that we don't need to include the driver
+        mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
+        mocker.patch("superset.utils.log.DBEventLogger.log")
+        mocker.patch(
+            "superset.databases.ssh_tunnel.commands.delete.is_feature_enabled",
+            return_value=False,
+        )
+
+        def _base_filter(query):
+            from superset.models.core import Database
+
+            return query.filter(Database.database_name.startswith("second"))
+
+        # Wrap the filter in a Mock so we can assert on its calls
+        base_filter_mock = Mock(side_effect=_base_filter)
+
+        # Get our recently created databases
+        response_databases = DatabaseDAO.find_all()
+        assert response_databases
+        expected_db_names = ["first-database", "second-database"]
+        actual_db_names = [db.database_name for db in response_databases]
+        assert actual_db_names == expected_db_names
+
+        # Ensure that the filter has not been called because it's not in our config
+        assert base_filter_mock.call_count == 0
+
+        original_config = current_app.config.copy()
+        original_config["EXTRA_DYNAMIC_QUERY_FILTERS"] = {"databases": base_filter_mock}
+
+        mocker.patch("superset.views.filters.current_app.config", new=original_config)
+        # Get the filtered list
+        response_databases = DatabaseDAO.find_all()
+        assert response_databases
+        expected_db_names = ["second-database"]
+        actual_db_names = [db.database_name for db in response_databases]
+        assert actual_db_names == expected_db_names
+
+        # Ensure that the filter has been called once
+        assert base_filter_mock.call_count == 1
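For context, here is a minimal sketch of how a deployment might wire the new hook up in `superset_config.py`, gating the database list behind a feature flag as the docstring above suggests. The flag name `RESTRICTED_DATABASES`, the `restricted-` name prefix, and the `superset.extensions.feature_flag_manager` import are illustrative assumptions, not part of this change:

```python
# superset_config.py -- illustrative sketch only (not part of the diff above).
# The flag name, the name prefix, and the import path are assumptions.
from sqlalchemy.orm import Query

FEATURE_FLAGS = {"RESTRICTED_DATABASES": True}


def hide_restricted_databases(query: Query) -> Query:
    # Imports are deferred so this config module can load before the app does.
    from superset.extensions import feature_flag_manager
    from superset.models.core import Database

    if feature_flag_manager.is_feature_enabled("RESTRICTED_DATABASES"):
        # Hide any database whose name starts with "restricted-".
        return query.filter(~Database.database_name.startswith("restricted-"))
    return query


EXTRA_DYNAMIC_QUERY_FILTERS = {"databases": hide_restricted_databases}
```

Because the hook runs inside `DatabaseFilter.apply` before the role-based checks, the restriction also applies to admins who can access all databases, which is the behaviour the integration test above exercises.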