diff --git a/superset/databases/commands/importers/v1/utils.py b/superset/databases/commands/importers/v1/utils.py index 8881f78a9c39c..c8c2847b9f673 100644 --- a/superset/databases/commands/importers/v1/utils.py +++ b/superset/databases/commands/importers/v1/utils.py @@ -20,10 +20,13 @@ from sqlalchemy.orm import Session -from superset import security_manager +from superset import app, security_manager from superset.commands.exceptions import ImportFailedError from superset.databases.ssh_tunnel.models import SSHTunnel +from superset.databases.utils import make_url_safe +from superset.exceptions import SupersetSecurityException from superset.models.core import Database +from superset.security.analytics_db_safety import check_sqlalchemy_uri def import_database( @@ -45,7 +48,12 @@ def import_database( raise ImportFailedError( "Database doesn't exist and user doesn't have permission to create databases" ) - + # Check if this URI is allowed + if app.config["PREVENT_UNSAFE_DB_CONNECTIONS"]: + try: + check_sqlalchemy_uri(make_url_safe(config["sqlalchemy_uri"])) + except SupersetSecurityException as exc: + raise ImportFailedError(exc.message) from exc # https://github.com/apache/superset/pull/16756 renamed ``csv`` to ``file``. config["allow_file_upload"] = config.pop("allow_csv_upload") if "schemas_allowed_for_csv_upload" in config["extra"]: diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py index 8ffb31b78215f..d5946d8b6d105 100644 --- a/tests/integration_tests/databases/commands_tests.py +++ b/tests/integration_tests/databases/commands_tests.py @@ -421,7 +421,7 @@ def test_import_v1_database(self, mock_g): assert database.database_name == "imported_database" assert database.expose_in_sqllab assert database.extra == "{}" - assert database.sqlalchemy_uri == "sqlite:///test.db" + assert database.sqlalchemy_uri == "someengine://user:pass@host1" db.session.delete(database) db.session.commit() @@ -460,7 +460,7 @@ def test_import_v1_database_broken_csv_fields(self, mock_g): assert database.database_name == "imported_database" assert database.expose_in_sqllab assert database.extra == '{"schemas_allowed_for_file_upload": ["upload"]}' - assert database.sqlalchemy_uri == "sqlite:///test.db" + assert database.sqlalchemy_uri == "someengine://user:pass@host1" db.session.delete(database) db.session.commit() @@ -716,7 +716,7 @@ def test_import_v1_database_with_ssh_tunnel_password( assert database.database_name == "imported_database" assert database.expose_in_sqllab assert database.extra == "{}" - assert database.sqlalchemy_uri == "sqlite:///test.db" + assert database.sqlalchemy_uri == "someengine://user:pass@host1" model_ssh_tunnel = ( db.session.query(SSHTunnel) @@ -761,7 +761,7 @@ def test_import_v1_database_with_ssh_tunnel_private_key_and_password( assert database.database_name == "imported_database" assert database.expose_in_sqllab assert database.extra == "{}" - assert database.sqlalchemy_uri == "sqlite:///test.db" + assert database.sqlalchemy_uri == "someengine://user:pass@host1" model_ssh_tunnel = ( db.session.query(SSHTunnel) diff --git a/tests/integration_tests/fixtures/importexport.py b/tests/integration_tests/fixtures/importexport.py index 74237f4a82cfb..d0fa04e97dfc7 100644 --- a/tests/integration_tests/fixtures/importexport.py +++ b/tests/integration_tests/fixtures/importexport.py @@ -346,7 +346,7 @@ "type": "SavedQuery", "timestamp": "2021-03-30T20:37:54.791187+00:00", } -database_config: dict[str, Any] = { +database_config_sqlite: dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, "allow_cvas": True, @@ -361,6 +361,21 @@ "version": "1.0.0", } +database_config: dict[str, Any] = { + "allow_csv_upload": True, + "allow_ctas": True, + "allow_cvas": True, + "allow_dml": True, + "allow_run_async": False, + "cache_timeout": None, + "database_name": "imported_database", + "expose_in_sqllab": True, + "extra": {}, + "sqlalchemy_uri": "someengine://user:pass@host1", + "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", + "version": "1.0.0", +} + database_with_ssh_tunnel_config_private_key: dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, @@ -371,7 +386,7 @@ "database_name": "imported_database", "expose_in_sqllab": True, "extra": {}, - "sqlalchemy_uri": "sqlite:///test.db", + "sqlalchemy_uri": "someengine://user:pass@host1", "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", "ssh_tunnel": { "server_address": "localhost", @@ -393,7 +408,7 @@ "database_name": "imported_database", "expose_in_sqllab": True, "extra": {}, - "sqlalchemy_uri": "sqlite:///test.db", + "sqlalchemy_uri": "someengine://user:pass@host1", "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", "ssh_tunnel": { "server_address": "localhost", diff --git a/tests/unit_tests/databases/commands/importers/v1/import_test.py b/tests/unit_tests/databases/commands/importers/v1/import_test.py index f9d2695f26072..b8bd24d94d187 100644 --- a/tests/unit_tests/databases/commands/importers/v1/import_test.py +++ b/tests/unit_tests/databases/commands/importers/v1/import_test.py @@ -42,7 +42,7 @@ def test_import_database(mocker: MockFixture, session: Session) -> None: config = copy.deepcopy(database_config) database = import_database(session, config) assert database.database_name == "imported_database" - assert database.sqlalchemy_uri == "sqlite:///test.db" + assert database.sqlalchemy_uri == "someengine://user:pass@host1" assert database.cache_timeout is None assert database.expose_in_sqllab is True assert database.allow_run_async is False @@ -65,6 +65,32 @@ def test_import_database(mocker: MockFixture, session: Session) -> None: assert database.allow_dml is False +def test_import_database_sqlite_invalid(mocker: MockFixture, session: Session) -> None: + """ + Test importing a database. + """ + from superset import app, security_manager + from superset.databases.commands.importers.v1.utils import import_database + from superset.models.core import Database + from tests.integration_tests.fixtures.importexport import database_config_sqlite + + app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True + mocker.patch.object(security_manager, "can_access", return_value=True) + + engine = session.get_bind() + Database.metadata.create_all(engine) # pylint: disable=no-member + + config = copy.deepcopy(database_config_sqlite) + with pytest.raises(ImportFailedError) as excinfo: + _ = import_database(session, config) + assert ( + str(excinfo.value) + == "SQLiteDialect_pysqlite cannot be used as a data source for security reasons." + ) + # restore app config + app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True + + def test_import_database_managed_externally( mocker: MockFixture, session: Session,