Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(database): Database Filtering via custom configuration #24580

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1582,6 +1582,21 @@ class ExtraRelatedQueryFilters(TypedDict, total=False):

EXTRA_RELATED_QUERY_FILTERS: ExtraRelatedQueryFilters = {}

# Extra dynamic database filter make it possible to limit which databases are shown
# in the UI before any other filtering is applied. Useful for example when
# considering to filter using Feature Flags along with regular role filters
# that get applied by default in our base_filters.
# For example, to only show a database starting with the letter "b"
# in the "Database Connections" list, you could add the following in your config:
# def initial_database_filter(query: Query, *args, *kwargs):
# from superset.models.core import Database
#
# filter = Database.database_name.startswith('b')
# return query.filter(filter)
#
# EXTRA_DYNAMIC_DATABASE_FILTER = initial_database_filter
EXTRA_DYNAMIC_DATABASE_FILTER: Callable[[Query], Query] | None = None


# -------------------------------------------------------------------
# * WARNING: STOP EDITING HERE *
Expand Down
12 changes: 11 additions & 1 deletion superset/databases/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from typing import Any

from flask import g
from flask import current_app, g
from flask_babel import lazy_gettext as _
from sqlalchemy import or_
from sqlalchemy.orm import Query
Expand All @@ -41,6 +41,16 @@ class DatabaseFilter(BaseFilter): # pylint: disable=too-few-public-methods
# TODO(bogdan): consider caching.

def apply(self, query: Query, value: Any) -> Query:
# Dynamic Filters need to be applied to the Query before we filter
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels like a good place for a docstring comment.

# databases with anything else. This way you can show/hide databases using
# Feature Flags for example in conjuction with the regular role filtering.
# If not, if an user has access to all Databases it would skip this dynamic
# filtering.

if dynamic_filter := current_app.config["EXTRA_DYNAMIC_DATABASE_FILTER"]:
query = dynamic_filter(query)
Antonio-RiveroMartnez marked this conversation as resolved.
Show resolved Hide resolved

# We can proceed with default filtering now
if security_manager.can_access_all_databases():
return query
database_perms = security_manager.user_view_menu_names("database_access")
Expand Down
74 changes: 74 additions & 0 deletions tests/integration_tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3632,3 +3632,77 @@ def test_validate_sql_endpoint_failure(self, get_validator_by_name):
return
self.assertEqual(rv.status_code, 422)
self.assertIn("Kaboom!", response["errors"][0]["message"])

def test_get_databases_with_extra_filters(self):
"""
API: Test get database with extra query filter
"""
self.login(username="admin")
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
example_db = get_example_database()

if example_db.backend == "sqlite":
return
database_data = {
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
"server_cert": None,
"extra": json.dumps(extra),
}

uri = "api/v1/database/"
rv = self.client.post(
uri, json={**database_data, "database_name": "dyntest-create-database-1"}
)
first_response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)

uri = "api/v1/database/"
rv = self.client.post(
uri, json={**database_data, "database_name": "create-database-2"}
)
second_response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
dbs = db.session.query(Database).all()
expected_names = [db.database_name for db in dbs]
expected_names.sort()

uri = f"api/v1/database/"
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], len(dbs))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hughhhh here I am testing our current behavior (default) where all databases must be returned if nothing is being set in the config, so dynamic_filter is not defined. Then, I'm adding the patch for the config to add the filter function and testing it's being applied because dynamic_filter is defined.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added explicit assertions to check whether the filter method has been called when defined.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here I am testing our current behavior (default) where all databases must be returned if nothing is being set in the config, so dynamic_filter is not defined. Then, I'm adding the patch for the config to add the filter function and testing it's being applied because dynamic_filter is defined.

Can you write that down in the function docstring? :)

database_names = [item["database_name"] for item in data["result"]]
database_names.sort()
# All Databases because we are an admin
self.assertEqual(database_names, expected_names)
assert rv.status_code == 200

def _base_filter(query):
from superset.models.core import Database

return query.filter(Database.database_name.startswith("dyntest"))

with patch.dict(
"superset.views.filters.current_app.config",
{"EXTRA_DYNAMIC_DATABASE_FILTER": _base_filter},
):
uri = f"api/v1/database/"
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 1)
database_names = [item["database_name"] for item in data["result"]]
# Only the database that starts with tests, even if we are an admin
self.assertEqual(database_names, ["dyntest-create-database-1"])
assert rv.status_code == 200

# Cleanup
first_model = db.session.query(Database).get(first_response.get("id"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to figure out how to do this via fixture or after every test so we can always land back in a normal state

second_model = db.session.query(Database).get(second_response.get("id"))
db.session.delete(first_model)
db.session.delete(second_model)
db.session.commit()
86 changes: 86 additions & 0 deletions tests/unit_tests/databases/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from uuid import UUID

import pytest
from flask import current_app
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session

Expand Down Expand Up @@ -495,3 +496,88 @@ def test_delete_ssh_tunnel_not_found(

response_tunnel = DatabaseDAO.get_ssh_tunnel(2)
assert response_tunnel is None


def test_apply_dynamic_database_filter(
mocker: MockFixture,
app: Any,
session: Session,
client: Any,
full_api_access: None,
) -> None:
"""
Test that we canfilter the list of databases
"""
with app.app_context():
from superset.daos.database import DatabaseDAO
from superset.databases.api import DatabaseRestApi
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database

DatabaseRestApi.datamodel.session = session

# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member

# Create our First Database
database = Database(
database_name="first-database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
),
)
session.add(database)
session.commit()

# Create our Second Database
database = Database(
database_name="second-database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
),
)
session.add(database)
session.commit()

# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
mocker.patch(
"superset.databases.ssh_tunnel.commands.delete.is_feature_enabled",
return_value=False,
)

# Get our recently created Databases
response_databases = DatabaseDAO.find_all()
assert response_databases
expected_db_names = ["first-database", "second-database"]
actual_db_names = [db.database_name for db in response_databases]
assert actual_db_names == expected_db_names

def _base_filter(query):
from superset.models.core import Database

return query.filter(Database.database_name.startswith("second"))

original_config = current_app.config.copy()
original_config["EXTRA_DYNAMIC_DATABASE_FILTER"] = _base_filter

mocker.patch("superset.views.filters.current_app.config", new=original_config)
# Get filtered list
response_databases = DatabaseDAO.find_all()
assert response_databases
expected_db_names = ["second-database"]
actual_db_names = [db.database_name for db in response_databases]
assert actual_db_names == expected_db_names
Loading