Skip to content

Commit

Permalink
fix: import database engine validation (#24697)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar authored Jul 20, 2023
1 parent 1a97245 commit cb9b865
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 10 deletions.
12 changes: 10 additions & 2 deletions superset/databases/commands/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@

from sqlalchemy.orm import Session

from superset import security_manager
from superset import app, security_manager
from superset.commands.exceptions import ImportFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.exceptions import SupersetSecurityException
from superset.models.core import Database
from superset.security.analytics_db_safety import check_sqlalchemy_uri


def import_database(
Expand All @@ -45,7 +48,12 @@ def import_database(
raise ImportFailedError(
"Database doesn't exist and user doesn't have permission to create databases"
)

# Check if this URI is allowed
if app.config["PREVENT_UNSAFE_DB_CONNECTIONS"]:
try:
check_sqlalchemy_uri(make_url_safe(config["sqlalchemy_uri"]))
except SupersetSecurityException as exc:
raise ImportFailedError(exc.message) from exc
# https://github.com/apache/superset/pull/16756 renamed ``csv`` to ``file``.
config["allow_file_upload"] = config.pop("allow_csv_upload")
if "schemas_allowed_for_csv_upload" in config["extra"]:
Expand Down
8 changes: 4 additions & 4 deletions tests/integration_tests/databases/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def test_import_v1_database(self, mock_g):
assert database.database_name == "imported_database"
assert database.expose_in_sqllab
assert database.extra == "{}"
assert database.sqlalchemy_uri == "sqlite:///test.db"
assert database.sqlalchemy_uri == "someengine://user:pass@host1"

db.session.delete(database)
db.session.commit()
Expand Down Expand Up @@ -460,7 +460,7 @@ def test_import_v1_database_broken_csv_fields(self, mock_g):
assert database.database_name == "imported_database"
assert database.expose_in_sqllab
assert database.extra == '{"schemas_allowed_for_file_upload": ["upload"]}'
assert database.sqlalchemy_uri == "sqlite:///test.db"
assert database.sqlalchemy_uri == "someengine://user:pass@host1"

db.session.delete(database)
db.session.commit()
Expand Down Expand Up @@ -716,7 +716,7 @@ def test_import_v1_database_with_ssh_tunnel_password(
assert database.database_name == "imported_database"
assert database.expose_in_sqllab
assert database.extra == "{}"
assert database.sqlalchemy_uri == "sqlite:///test.db"
assert database.sqlalchemy_uri == "someengine://user:pass@host1"

model_ssh_tunnel = (
db.session.query(SSHTunnel)
Expand Down Expand Up @@ -761,7 +761,7 @@ def test_import_v1_database_with_ssh_tunnel_private_key_and_password(
assert database.database_name == "imported_database"
assert database.expose_in_sqllab
assert database.extra == "{}"
assert database.sqlalchemy_uri == "sqlite:///test.db"
assert database.sqlalchemy_uri == "someengine://user:pass@host1"

model_ssh_tunnel = (
db.session.query(SSHTunnel)
Expand Down
21 changes: 18 additions & 3 deletions tests/integration_tests/fixtures/importexport.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@
"type": "SavedQuery",
"timestamp": "2021-03-30T20:37:54.791187+00:00",
}
database_config: dict[str, Any] = {
database_config_sqlite: dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
Expand All @@ -361,6 +361,21 @@
"version": "1.0.0",
}

database_config: dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
"allow_dml": True,
"allow_run_async": False,
"cache_timeout": None,
"database_name": "imported_database",
"expose_in_sqllab": True,
"extra": {},
"sqlalchemy_uri": "someengine://user:pass@host1",
"uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
"version": "1.0.0",
}

database_with_ssh_tunnel_config_private_key: dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
Expand All @@ -371,7 +386,7 @@
"database_name": "imported_database",
"expose_in_sqllab": True,
"extra": {},
"sqlalchemy_uri": "sqlite:///test.db",
"sqlalchemy_uri": "someengine://user:pass@host1",
"uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
"ssh_tunnel": {
"server_address": "localhost",
Expand All @@ -393,7 +408,7 @@
"database_name": "imported_database",
"expose_in_sqllab": True,
"extra": {},
"sqlalchemy_uri": "sqlite:///test.db",
"sqlalchemy_uri": "someengine://user:pass@host1",
"uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
"ssh_tunnel": {
"server_address": "localhost",
Expand Down
28 changes: 27 additions & 1 deletion tests/unit_tests/databases/commands/importers/v1/import_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_import_database(mocker: MockFixture, session: Session) -> None:
config = copy.deepcopy(database_config)
database = import_database(session, config)
assert database.database_name == "imported_database"
assert database.sqlalchemy_uri == "sqlite:///test.db"
assert database.sqlalchemy_uri == "someengine://user:pass@host1"
assert database.cache_timeout is None
assert database.expose_in_sqllab is True
assert database.allow_run_async is False
Expand All @@ -65,6 +65,32 @@ def test_import_database(mocker: MockFixture, session: Session) -> None:
assert database.allow_dml is False


def test_import_database_sqlite_invalid(mocker: MockFixture, session: Session) -> None:
"""
Test importing a database.
"""
from superset import app, security_manager
from superset.databases.commands.importers.v1.utils import import_database
from superset.models.core import Database
from tests.integration_tests.fixtures.importexport import database_config_sqlite

app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True
mocker.patch.object(security_manager, "can_access", return_value=True)

engine = session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member

config = copy.deepcopy(database_config_sqlite)
with pytest.raises(ImportFailedError) as excinfo:
_ = import_database(session, config)
assert (
str(excinfo.value)
== "SQLiteDialect_pysqlite cannot be used as a data source for security reasons."
)
# restore app config
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True


def test_import_database_managed_externally(
mocker: MockFixture,
session: Session,
Expand Down

0 comments on commit cb9b865

Please sign in to comment.