From 33ba31c462ae342afd5e61a443887ebdc95f01eb Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 30 Aug 2024 10:44:52 -0400 Subject: [PATCH] feat: allow create/update OAuth2 DB --- superset/commands/database/create.py | 24 +- superset/commands/database/test_connection.py | 23 +- superset/commands/database/update.py | 17 +- superset/commands/database/validate.py | 9 + superset/databases/api.py | 11 +- superset/db_engine_specs/base.py | 4 +- superset/exceptions.py | 2 + superset/models/core.py | 16 ++ .../commands/databases/create_test.py | 31 +++ .../databases/test_connection_test.py | 91 ++++++++ .../commands/databases/update_test.py | 48 ++++ .../commands/databases/validate_test.py | 206 ++++++++++++++++++ tests/unit_tests/databases/api_test.py | 86 +++++++- tests/unit_tests/models/core_test.py | 79 ++++++- 14 files changed, 618 insertions(+), 29 deletions(-) create mode 100644 tests/unit_tests/commands/databases/test_connection_test.py create mode 100644 tests/unit_tests/commands/databases/validate_test.py diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index b8854010faa2e..6023f939a3704 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -42,7 +42,7 @@ from superset.daos.database import DatabaseDAO from superset.databases.ssh_tunnel.models import SSHTunnel from superset.db_engine_specs.base import GenericDBException -from superset.exceptions import SupersetErrorsException +from superset.exceptions import OAuth2RedirectError, SupersetErrorsException from superset.extensions import event_logger, security_manager from superset.models.core import Database from superset.utils.decorators import on_error, transaction @@ -55,13 +55,21 @@ class CreateDatabaseCommand(BaseCommand): def __init__(self, data: dict[str, Any]): self._properties = data.copy() - @transaction(on_error=partial(on_error, reraise=DatabaseCreateFailedError)) + @transaction( + on_error=partial(on_error, reraise=DatabaseCreateFailedError), + allowed=(OAuth2RedirectError,), + ) def run(self) -> Model: self.validate() try: # Test connection before starting create transaction TestConnectionDatabaseCommand(self._properties).run() + except OAuth2RedirectError: + # If we can't connect to the database due to an OAuth2 error we can still + # save the database. Later, the user can sync permissions when setting up + # data access rules. + return self._create_database() except ( SupersetErrorsException, SSHTunnelingNotEnabledError, @@ -80,12 +88,6 @@ def run(self) -> Model: ) raise DatabaseConnectionFailedError() from ex - # when creating a new database we don't need to unmask encrypted extra - self._properties["encrypted_extra"] = self._properties.pop( - "masked_encrypted_extra", - "{}", - ) - ssh_tunnel: Optional[SSHTunnel] = None try: @@ -195,6 +197,12 @@ def validate(self) -> None: raise exception def _create_database(self) -> Database: + # when creating a new database we don't need to unmask encrypted extra + self._properties["encrypted_extra"] = self._properties.pop( + "masked_encrypted_extra", + "{}", + ) + database = DatabaseDAO.create(attributes=self._properties) database.set_sqlalchemy_uri(database.sqlalchemy_uri) return database diff --git a/superset/commands/database/test_connection.py b/superset/commands/database/test_connection.py index 7979901bca0ac..8aef6c1359b5e 100644 --- a/superset/commands/database/test_connection.py +++ b/superset/commands/database/test_connection.py @@ -41,6 +41,7 @@ from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetErrorType from superset.exceptions import ( + OAuth2RedirectError, SupersetErrorsException, SupersetSecurityException, SupersetTimeoutException, @@ -162,6 +163,13 @@ def ping(engine: Engine) -> bool: extra={"sqlalchemy_uri": database.sqlalchemy_uri}, ) from ex except Exception as ex: # pylint: disable=broad-except + # If the connection failed because OAuth2 is needed, start the flow. + if ( + database.is_oauth2_enabled() + and database.db_engine_spec.needs_oauth2(ex) + ): + database.start_oauth2_dance() + alive = False # So we stop losing the original message if any ex_str = str(ex) @@ -197,6 +205,8 @@ def ping(engine: Engine) -> bool: # check for custom errors (wrong username, wrong password, etc) errors = database.db_engine_spec.extract_errors(ex, self._context) raise SupersetErrorsException(errors) from ex + except OAuth2RedirectError: + raise except SupersetSecurityException as ex: event_logger.log_with_context( action=get_log_connection_action( @@ -205,23 +215,14 @@ def ping(engine: Engine) -> bool: engine=database.db_engine_spec.__name__, ) raise DatabaseSecurityUnsafeError(message=str(ex)) from ex - except SupersetTimeoutException as ex: - event_logger.log_with_context( - action=get_log_connection_action( - "test_connection_error", ssh_tunnel, ex - ), - engine=database.db_engine_spec.__name__, - ) - # bubble up the exception to return a 408 - raise - except SSHTunnelingNotEnabledError as ex: + except (SupersetTimeoutException, SSHTunnelingNotEnabledError) as ex: event_logger.log_with_context( action=get_log_connection_action( "test_connection_error", ssh_tunnel, ex ), engine=database.db_engine_spec.__name__, ) - # bubble up the exception to return a 400 + # bubble up the exception to return proper status code raise except Exception as ex: event_logger.log_with_context( diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index 0fc31c096a063..88abae7cc6e0e 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -42,6 +42,7 @@ from superset.daos.dataset import DatasetDAO from superset.databases.ssh_tunnel.models import SSHTunnel from superset.db_engine_specs.base import GenericDBException +from superset.exceptions import OAuth2RedirectError from superset.models.core import Database from superset.utils.decorators import on_error, transaction @@ -56,7 +57,10 @@ def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._model: Database | None = None - @transaction(on_error=partial(on_error, reraise=DatabaseUpdateFailedError)) + @transaction( + on_error=partial(on_error, reraise=DatabaseUpdateFailedError), + allowed=(OAuth2RedirectError,), + ) def run(self) -> Model: self._model = DatabaseDAO.find_by_id(self._model_id) @@ -80,7 +84,10 @@ def run(self) -> Model: database = DatabaseDAO.update(self._model, self._properties) database.set_sqlalchemy_uri(database.sqlalchemy_uri) ssh_tunnel = self._handle_ssh_tunnel(database) - self._refresh_catalogs(database, original_database_name, ssh_tunnel) + try: + self._refresh_catalogs(database, original_database_name, ssh_tunnel) + except OAuth2RedirectError: + pass return database @@ -123,6 +130,9 @@ def _get_catalog_names( force=True, ssh_tunnel=ssh_tunnel, ) + except OAuth2RedirectError: + # raise OAuth2 exceptions as-is + raise except GenericDBException as ex: raise DatabaseConnectionFailedError() from ex @@ -141,6 +151,9 @@ def _get_schema_names( catalog=catalog, ssh_tunnel=ssh_tunnel, ) + except OAuth2RedirectError: + # raise OAuth2 exceptions as-is + raise except GenericDBException as ex: raise DatabaseConnectionFailedError() from ex diff --git a/superset/commands/database/validate.py b/superset/commands/database/validate.py index 65df1ed1e149b..c44190f354ce1 100644 --- a/superset/commands/database/validate.py +++ b/superset/commands/database/validate.py @@ -107,6 +107,15 @@ def run(self) -> None: with closing(engine.raw_connection()) as conn: alive = engine.dialect.do_ping(conn) except Exception as ex: + # If the connection failed because OAuth2 is needed, we can save the + # database and trigger the OAuth2 flow whenever a user tries to run a + # query. + if ( + database.is_oauth2_enabled() + and database.db_engine_spec.needs_oauth2(ex) + ): + return + url = make_url_safe(sqlalchemy_uri) context = { "hostname": url.host, diff --git a/superset/databases/api.py b/superset/databases/api.py index eb611837bc9b1..b58e46bf3fcbe 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -110,6 +110,7 @@ DatabaseNotFoundException, InvalidPayloadSchemaError, OAuth2Error, + OAuth2RedirectError, SupersetErrorsException, SupersetException, SupersetSecurityException, @@ -398,7 +399,6 @@ def get(self, pk: int, **kwargs: Any) -> Response: @expose("/", methods=("POST",)) @protect() - @safe @statsd_metrics @event_logger.log_this_with_context( action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post", @@ -462,6 +462,8 @@ def post(self) -> FlaskResponse: item["ssh_tunnel"] = mask_password_info(new_model.ssh_tunnel) return self.response(201, id=new_model.id, result=item) + except OAuth2RedirectError: + raise except DatabaseInvalidError as ex: return self.response_422(message=ex.normalized_messages()) except DatabaseConnectionFailedError as ex: @@ -621,7 +623,6 @@ def delete(self, pk: int) -> Response: @expose("//catalogs/") @protect() - @safe @rison(database_catalogs_query_schema) @statsd_metrics @event_logger.log_this_with_context( @@ -680,12 +681,13 @@ def catalogs(self, pk: int, **kwargs: Any) -> FlaskResponse: 500, message="There was an error connecting to the database", ) + except OAuth2RedirectError: + raise except SupersetException as ex: return self.response(ex.status, message=ex.message) @expose("//schemas/") @protect() - @safe @rison(database_schemas_query_schema) @statsd_metrics @event_logger.log_this_with_context( @@ -746,6 +748,8 @@ def schemas(self, pk: int, **kwargs: Any) -> FlaskResponse: return self.response( 500, message="There was an error connecting to the database" ) + except OAuth2RedirectError: + raise except SupersetException as ex: return self.response(ex.status, message=ex.message) @@ -2069,6 +2073,7 @@ def available(self) -> Response: "sqlalchemy_uri_placeholder": engine_spec.sqlalchemy_uri_placeholder, "preferred": engine_spec.engine_name in preferred_databases, "engine_information": engine_spec.get_public_information(), + "supports_oauth2": engine_spec.supports_oauth2, } if engine_spec.default_driver: diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 7f4d591a77984..b5f001971d3d8 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -429,8 +429,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # the user impersonation methods to handle personal tokens. supports_oauth2 = False oauth2_scope = "" - oauth2_authorization_request_uri = "" # pylint: disable=invalid-name - oauth2_token_request_uri = "" + oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name + oauth2_token_request_uri: str | None = None # Driver-specific exception that should be mapped to OAuth2RedirectError oauth2_exception = OAuth2RedirectError diff --git a/superset/exceptions.py b/superset/exceptions.py index dd669f5b72ae3..a000e08165c47 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -329,6 +329,8 @@ class OAuth2RedirectError(SupersetErrorException): See the `OAuth2RedirectMessage.tsx` component for more details of how this information is handled. + + TODO (betodealmeida): change status to 403. """ def __init__(self, url: str, tab_id: str, redirect_uri: str): diff --git a/superset/models/core.py b/superset/models/core.py index 305c8c37813c8..c528d3580f7b7 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -844,6 +844,9 @@ def get_all_schema_names( ) as inspector: return self.db_engine_spec.get_schema_names(inspector) except Exception as ex: + if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex): + self.start_oauth2_dance() + raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @cache_util.memoized_func( @@ -865,6 +868,9 @@ def get_all_catalog_names( with self.get_inspector(ssh_tunnel=ssh_tunnel) as inspector: return self.db_engine_spec.get_catalog_names(self, inspector) except Exception as ex: + if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex): + self.start_oauth2_dance() + raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @property @@ -1096,6 +1102,16 @@ def get_oauth2_config(self) -> OAuth2ClientConfig | None: return self.db_engine_spec.get_oauth2_config() + def start_oauth2_dance(self) -> None: + """ + Start the OAuth2 dance. + + This method is called when an OAuth2 error is encountered, and the database is + configured to use OAuth2 for authentication. It raises an exception that will + trigger the OAuth2 dance in the frontend. + """ + return self.db_engine_spec.start_oauth2_dance(self) + sqla.event.listen(Database, "after_insert", security_manager.database_after_insert) sqla.event.listen(Database, "after_update", security_manager.database_after_update) diff --git a/tests/unit_tests/commands/databases/create_test.py b/tests/unit_tests/commands/databases/create_test.py index 09d5744afd53b..61591274db633 100644 --- a/tests/unit_tests/commands/databases/create_test.py +++ b/tests/unit_tests/commands/databases/create_test.py @@ -21,6 +21,7 @@ from pytest_mock import MockerFixture from superset.commands.database.create import CreateDatabaseCommand +from superset.exceptions import OAuth2RedirectError from superset.extensions import security_manager @@ -124,3 +125,33 @@ def test_create_permissions_without_catalog( ], any_order=True, ) + + +def test_create_with_oauth2( + mocker: MockerFixture, + database_without_catalog: MockerFixture, +) -> None: + """ + Test that the database can be created even if OAuth2 is needed to connect. + """ + TestConnectionDatabaseCommand = mocker.patch( + "superset.commands.database.create.TestConnectionDatabaseCommand" + ) + TestConnectionDatabaseCommand().run.side_effect = OAuth2RedirectError( + "url", + "tab_id", + "redirect_uri", + ) + add_permission_view_menu = mocker.patch.object( + security_manager, + "add_permission_view_menu", + ) + + CreateDatabaseCommand( + { + "database_name": "test_database", + "sqlalchemy_uri": "sqlite://", + } + ).run() + + add_permission_view_menu.assert_not_called() diff --git a/tests/unit_tests/commands/databases/test_connection_test.py b/tests/unit_tests/commands/databases/test_connection_test.py new file mode 100644 index 0000000000000..eab2b466790a5 --- /dev/null +++ b/tests/unit_tests/commands/databases/test_connection_test.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +from pytest_mock import MockerFixture + +from superset.commands.database.test_connection import TestConnectionDatabaseCommand +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import OAuth2RedirectError + + +def test_command(mocker: MockerFixture) -> None: + """ + Test the happy path of the command. + """ + user = mocker.MagicMock() + user.email = "alice@example.org" + mocker.patch("superset.db_engine_specs.gsheets.g", user=user) + mocker.patch("superset.db_engine_specs.gsheets.create_engine") + + database = mocker.MagicMock() + database.db_engine_spec.__name__ = "GSheetsEngineSpec" + with database.get_sqla_engine() as engine: + engine.dialect.do_ping.return_value = True + + DatabaseDAO = mocker.patch("superset.commands.database.test_connection.DatabaseDAO") + DatabaseDAO.build_db_for_connection_test.return_value = database + + properties = { + "sqlalchemy_uri": "gsheets://", + "engine": "gsheets", + "driver": "gsheets", + "catalog": {"test": "https://example.org/"}, + } + command = TestConnectionDatabaseCommand(properties) + command.run() + + +def test_command_with_oauth2(mocker: MockerFixture) -> None: + """ + Test the command when OAuth2 is needed. + """ + user = mocker.MagicMock() + user.email = "alice@example.org" + mocker.patch("superset.db_engine_specs.gsheets.g", user=user) + mocker.patch("superset.db_engine_specs.gsheets.create_engine") + + database = mocker.MagicMock() + database.is_oauth2_enabled.return_value = True + database.db_engine_spec.needs_oauth2.return_value = True + database.start_oauth2_dance.side_effect = OAuth2RedirectError( + "url", + "tab_id", + "redirect_uri", + ) + database.db_engine_spec.__name__ = "GSheetsEngineSpec" + with database.get_sqla_engine() as engine: + engine.dialect.do_ping.side_effect = Exception("OAuth2 needed") + + DatabaseDAO = mocker.patch("superset.commands.database.test_connection.DatabaseDAO") + DatabaseDAO.build_db_for_connection_test.return_value = database + + properties = { + "sqlalchemy_uri": "gsheets://", + "engine": "gsheets", + "driver": "gsheets", + "catalog": {"test": "https://example.org/"}, + } + command = TestConnectionDatabaseCommand(properties) + with pytest.raises(OAuth2RedirectError) as excinfo: + command.run() + assert excinfo.value.error == SupersetError( + message="You don't have permission to access the data.", + error_type=SupersetErrorType.OAUTH2_REDIRECT, + level=ErrorLevel.WARNING, + extra={"url": "url", "tab_id": "tab_id", "redirect_uri": "redirect_uri"}, + ) diff --git a/tests/unit_tests/commands/databases/update_test.py b/tests/unit_tests/commands/databases/update_test.py index b1b5e6843f0c2..d7b60f85d110f 100644 --- a/tests/unit_tests/commands/databases/update_test.py +++ b/tests/unit_tests/commands/databases/update_test.py @@ -21,6 +21,7 @@ from pytest_mock import MockerFixture from superset.commands.database.update import UpdateDatabaseCommand +from superset.exceptions import OAuth2RedirectError from superset.extensions import security_manager @@ -57,6 +58,24 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock: return database +@pytest.fixture() +def database_needs_oauth2(mocker: MockerFixture) -> MagicMock: + """ + Mock a database without catalogs that needs OAuth2. + """ + database = mocker.MagicMock() + database.database_name = "my_db" + database.db_engine_spec.__name__ = "test_engine" + database.db_engine_spec.supports_catalog = False + database.get_all_schema_names.side_effect = OAuth2RedirectError( + "url", + "tab_id", + "redirect_uri", + ) + + return database + + def test_update_with_catalog( mocker: MockerFixture, database_with_catalog: MockerFixture, @@ -276,3 +295,32 @@ def test_rename_without_catalog( ) assert schema2_pvm.view_menu.name == "[my_other_db].[schema2]" + + +def test_update_with_oauth2( + mocker: MockerFixture, + database_needs_oauth2: MockerFixture, +) -> None: + """ + Test that the database can be updated even if OAuth2 is needed to connect. + """ + DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") + DatabaseDAO.find_by_id.return_value = database_needs_oauth2 + DatabaseDAO.update.return_value = database_needs_oauth2 + + find_permission_view_menu = mocker.patch.object( + security_manager, + "find_permission_view_menu", + ) + find_permission_view_menu.side_effect = [ + None, # schema1 has no permissions + "[my_db].[schema2]", # second schema already exists + ] + add_permission_view_menu = mocker.patch.object( + security_manager, + "add_permission_view_menu", + ) + + UpdateDatabaseCommand(1, {}).run() + + add_permission_view_menu.assert_not_called() diff --git a/tests/unit_tests/commands/databases/validate_test.py b/tests/unit_tests/commands/databases/validate_test.py new file mode 100644 index 0000000000000..fde462536a75a --- /dev/null +++ b/tests/unit_tests/commands/databases/validate_test.py @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +from pytest_mock import MockerFixture + +from superset.commands.database.exceptions import ( + DatabaseOfflineError, + DatabaseTestConnectionFailedError, + InvalidParametersError, +) +from superset.commands.database.validate import ValidateDatabaseParametersCommand +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType + + +def test_command(mocker: MockerFixture) -> None: + """ + Test the happy path of the command. + """ + user = mocker.MagicMock() + user.email = "alice@example.org" + mocker.patch("superset.db_engine_specs.gsheets.g", user=user) + mocker.patch("superset.db_engine_specs.gsheets.create_engine") + + database = mocker.MagicMock() + with database.get_sqla_engine() as engine: + engine.dialect.do_ping.return_value = True + + DatabaseDAO = mocker.patch("superset.commands.database.validate.DatabaseDAO") + DatabaseDAO.build_db_for_connection_test.return_value = database + + properties = { + "engine": "gsheets", + "driver": "gsheets", + "catalog": {"test": "https://example.org/"}, + } + command = ValidateDatabaseParametersCommand(properties) + command.run() + + +def test_command_invalid(mocker: MockerFixture) -> None: + """ + Test the command when the payload is invalid. + """ + user = mocker.MagicMock() + user.email = "alice@example.org" + mocker.patch("superset.db_engine_specs.gsheets.g", user=user) + mocker.patch("superset.db_engine_specs.gsheets.create_engine") + + database = mocker.MagicMock() + with database.get_sqla_engine() as engine: + engine.dialect.do_ping.return_value = True + + DatabaseDAO = mocker.patch("superset.commands.database.validate.DatabaseDAO") + DatabaseDAO.build_db_for_connection_test.return_value = database + + properties = { + "engine": "gsheets", + "driver": "gsheets", + } + command = ValidateDatabaseParametersCommand(properties) + with pytest.raises(InvalidParametersError) as excinfo: + command.run() + assert excinfo.value.errors == [ + SupersetError( + message="Sheet name is required", + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.WARNING, + extra={ + "catalog": {"idx": 0, "name": True}, + "issue_codes": [ + { + "code": 1018, + "message": ( + "Issue 1018 - One or more parameters needed to configure a " + "database are missing." + ), + } + ], + }, + ) + ] + + +def test_command_no_ping(mocker: MockerFixture) -> None: + """ + Test the command when it can't ping the database. + """ + user = mocker.MagicMock() + user.email = "alice@example.org" + mocker.patch("superset.db_engine_specs.gsheets.g", user=user) + mocker.patch("superset.db_engine_specs.gsheets.create_engine") + + database = mocker.MagicMock() + with database.get_sqla_engine() as engine: + engine.dialect.do_ping.return_value = False + + DatabaseDAO = mocker.patch("superset.commands.database.validate.DatabaseDAO") + DatabaseDAO.build_db_for_connection_test.return_value = database + + properties = { + "engine": "gsheets", + "driver": "gsheets", + "catalog": {"test": "https://example.org/"}, + } + command = ValidateDatabaseParametersCommand(properties) + with pytest.raises(DatabaseOfflineError) as excinfo: + command.run() + assert excinfo.value.error == SupersetError( + message="Database is offline.", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + extra={ + "issue_codes": [ + { + "code": 1002, + "message": "Issue 1002 - The database returned an unexpected error.", + } + ] + }, + ) + + +def test_command_with_oauth2(mocker: MockerFixture) -> None: + """ + Test the command when OAuth2 is needed. + """ + user = mocker.MagicMock() + user.email = "alice@example.org" + mocker.patch("superset.db_engine_specs.gsheets.g", user=user) + mocker.patch("superset.db_engine_specs.gsheets.create_engine") + + database = mocker.MagicMock() + database.is_oauth2_enabled.return_value = True + database.db_engine_spec.needs_oauth2.return_value = True + with database.get_sqla_engine() as engine: + engine.dialect.do_ping.side_effect = Exception("OAuth2 needed") + + DatabaseDAO = mocker.patch("superset.commands.database.validate.DatabaseDAO") + DatabaseDAO.build_db_for_connection_test.return_value = database + + properties = { + "engine": "gsheets", + "driver": "gsheets", + "catalog": {"test": "https://example.org/"}, + } + command = ValidateDatabaseParametersCommand(properties) + command.run() + + +def test_command_with_oauth2_not_configured(mocker: MockerFixture) -> None: + """ + Test the command when OAuth2 is needed but not configured in the DB. + """ + user = mocker.MagicMock() + user.email = "alice@example.org" + mocker.patch("superset.db_engine_specs.gsheets.g", user=user) + mocker.patch("superset.db_engine_specs.gsheets.create_engine") + + database = mocker.MagicMock() + database.is_oauth2_enabled.return_value = False + database.db_engine_spec.needs_oauth2.return_value = True + database.db_engine_spec.extract_errors.return_value = [ + SupersetError( + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + message="OAuth2 is needed but not configured.", + level=ErrorLevel.ERROR, + extra={"engine_name": "gsheets"}, + ) + ] + with database.get_sqla_engine() as engine: + engine.dialect.do_ping.side_effect = Exception("OAuth2 needed") + + DatabaseDAO = mocker.patch("superset.commands.database.validate.DatabaseDAO") + DatabaseDAO.build_db_for_connection_test.return_value = database + + properties = { + "engine": "gsheets", + "driver": "gsheets", + "catalog": {"test": "https://example.org/"}, + } + command = ValidateDatabaseParametersCommand(properties) + with pytest.raises(DatabaseTestConnectionFailedError) as excinfo: + command.run() + assert excinfo.value.errors == [ + SupersetError( + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + message="OAuth2 is needed but not configured.", + level=ErrorLevel.ERROR, + extra={"engine_name": "gsheets"}, + ) + ] diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index f4534d216b9b7..746bf04cd9098 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -38,7 +38,7 @@ from superset.commands.database.uploaders.excel_reader import ExcelReader from superset.db_engine_specs.sqlite import SqliteEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.exceptions import SupersetSecurityException +from superset.exceptions import OAuth2RedirectError, SupersetSecurityException from superset.sql_parse import Table from superset.utils import json from tests.unit_tests.fixtures.common import ( @@ -2112,6 +2112,47 @@ def test_catalogs( ) +def test_catalogs_with_oauth2( + mocker: MockerFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Test the `catalogs` endpoint when OAuth2 is needed. + """ + database = mocker.MagicMock() + database.get_all_catalog_names.side_effect = OAuth2RedirectError( + "url", + "tab_id", + "redirect_uri", + ) + DatabaseDAO = mocker.patch("superset.databases.api.DatabaseDAO") + DatabaseDAO.find_by_id.return_value = database + + security_manager = mocker.patch( + "superset.databases.api.security_manager", + new=mocker.MagicMock(), + ) + security_manager.get_catalogs_accessible_by_user.return_value = {"db2"} + + response = client.get("/api/v1/database/1/catalogs/") + assert response.status_code == 500 + assert response.json == { + "errors": [ + { + "message": "You don't have permission to access the data.", + "error_type": "OAUTH2_REDIRECT", + "level": "warning", + "extra": { + "url": "url", + "tab_id": "tab_id", + "redirect_uri": "redirect_uri", + }, + } + ] + } + + def test_schemas( mocker: MockerFixture, client: Any, @@ -2168,3 +2209,46 @@ def test_schemas( "catalog2", {"schema1", "schema2"}, ) + + +def test_schemas_with_oauth2( + mocker: MockerFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Test the `schemas` endpoint when OAuth2 is needed. + """ + from superset.databases.api import DatabaseRestApi + + database = mocker.MagicMock() + database.get_all_schema_names.side_effect = OAuth2RedirectError( + "url", + "tab_id", + "redirect_uri", + ) + datamodel = mocker.patch.object(DatabaseRestApi, "datamodel") + datamodel.get.return_value = database + + security_manager = mocker.patch( + "superset.databases.api.security_manager", + new=mocker.MagicMock(), + ) + security_manager.get_schemas_accessible_by_user.return_value = {"schema2"} + + response = client.get("/api/v1/database/1/schemas/") + assert response.status_code == 500 + assert response.json == { + "errors": [ + { + "message": "You don't have permission to access the data.", + "error_type": "OAUTH2_REDIRECT", + "level": "warning", + "extra": { + "url": "url", + "tab_id": "tab_id", + "redirect_uri": "redirect_uri", + }, + } + ] + } diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 6f588cde2408c..0346020c51f7c 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -25,6 +25,7 @@ from sqlalchemy.engine.url import make_url from superset.connectors.sqla.models import SqlaTable, TableColumn +from superset.errors import SupersetErrorType from superset.exceptions import OAuth2Error, OAuth2RedirectError from superset.models.core import Database from superset.sql_parse import Table @@ -38,7 +39,7 @@ "secret": "my_client_secret", "authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize", "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request", - "scope": "refresh_token session:role:SYSADMIN", + "scope": "refresh_token session:role:USERADMIN", } } @@ -306,6 +307,80 @@ def test_get_all_catalog_names(mocker: MockerFixture) -> None: get_inspector.assert_called_with(ssh_tunnel=None) +def test_get_all_schema_names_needs_oauth2(mocker: MockerFixture) -> None: + """ + Test the `get_all_schema_names` method when OAuth2 is needed. + """ + database = Database( + database_name="db", + sqlalchemy_uri="snowflake://:@abcd1234.snowflakecomputing.com/db", + encrypted_extra=json.dumps(oauth2_client_info), + ) + + class DriverSpecificError(Exception): + """ + A custom exception that is raised by the Snowflake driver. + """ + + mocker.patch.object( + database.db_engine_spec, + "oauth2_exception", + DriverSpecificError, + ) + mocker.patch.object( + database.db_engine_spec, + "get_schema_names", + side_effect=DriverSpecificError("User needs to authenticate"), + ) + mocker.patch.object(database, "get_inspector") + user = mocker.MagicMock() + user.id = 42 + mocker.patch("superset.db_engine_specs.base.g", user=user) + + with pytest.raises(OAuth2RedirectError) as excinfo: + database.get_all_schema_names() + + assert excinfo.value.message == "You don't have permission to access the data." + assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT + + +def test_get_all_catalog_names_needs_oauth2(mocker: MockerFixture) -> None: + """ + Test the `get_all_catalog_names` method when OAuth2 is needed. + """ + database = Database( + database_name="db", + sqlalchemy_uri="snowflake://:@abcd1234.snowflakecomputing.com/db", + encrypted_extra=json.dumps(oauth2_client_info), + ) + + class DriverSpecificError(Exception): + """ + A custom exception that is raised by the Snowflake driver. + """ + + mocker.patch.object( + database.db_engine_spec, + "oauth2_exception", + DriverSpecificError, + ) + mocker.patch.object( + database.db_engine_spec, + "get_catalog_names", + side_effect=DriverSpecificError("User needs to authenticate"), + ) + mocker.patch.object(database, "get_inspector") + user = mocker.MagicMock() + user.id = 42 + mocker.patch("superset.db_engine_specs.base.g", user=user) + + with pytest.raises(OAuth2RedirectError) as excinfo: + database.get_all_catalog_names() + + assert excinfo.value.message == "You don't have permission to access the data." + assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT + + def test_get_sqla_engine(mocker: MockerFixture) -> None: """ Test `_get_sqla_engine`. @@ -425,7 +500,7 @@ def test_get_oauth2_config(app_context: None) -> None: "secret": "my_client_secret", "authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize", "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request", - "scope": "refresh_token session:role:SYSADMIN", + "scope": "refresh_token session:role:USERADMIN", "redirect_uri": "http://example.com/api/v1/database/oauth2/", }