Skip to content

Commit

Permalink
feat: allow create/update OAuth2 DB
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Aug 31, 2024
1 parent 7855261 commit 33ba31c
Show file tree
Hide file tree
Showing 14 changed files with 618 additions and 29 deletions.
24 changes: 16 additions & 8 deletions superset/commands/database/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.db_engine_specs.base import GenericDBException
from superset.exceptions import SupersetErrorsException
from superset.exceptions import OAuth2RedirectError, SupersetErrorsException
from superset.extensions import event_logger, security_manager
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
Expand All @@ -55,13 +55,21 @@ class CreateDatabaseCommand(BaseCommand):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()

@transaction(on_error=partial(on_error, reraise=DatabaseCreateFailedError))
@transaction(
on_error=partial(on_error, reraise=DatabaseCreateFailedError),
allowed=(OAuth2RedirectError,),
)
def run(self) -> Model:
self.validate()

try:
# Test connection before starting create transaction
TestConnectionDatabaseCommand(self._properties).run()
except OAuth2RedirectError:
# If we can't connect to the database due to an OAuth2 error we can still
# save the database. Later, the user can sync permissions when setting up
# data access rules.
return self._create_database()
except (
SupersetErrorsException,
SSHTunnelingNotEnabledError,
Expand All @@ -80,12 +88,6 @@ def run(self) -> Model:
)
raise DatabaseConnectionFailedError() from ex

# when creating a new database we don't need to unmask encrypted extra
self._properties["encrypted_extra"] = self._properties.pop(
"masked_encrypted_extra",
"{}",
)

ssh_tunnel: Optional[SSHTunnel] = None

try:
Expand Down Expand Up @@ -195,6 +197,12 @@ def validate(self) -> None:
raise exception

def _create_database(self) -> Database:
# when creating a new database we don't need to unmask encrypted extra
self._properties["encrypted_extra"] = self._properties.pop(
"masked_encrypted_extra",
"{}",
)

database = DatabaseDAO.create(attributes=self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
return database
23 changes: 12 additions & 11 deletions superset/commands/database/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from superset.databases.utils import make_url_safe
from superset.errors import ErrorLevel, SupersetErrorType
from superset.exceptions import (
OAuth2RedirectError,
SupersetErrorsException,
SupersetSecurityException,
SupersetTimeoutException,
Expand Down Expand Up @@ -162,6 +163,13 @@ def ping(engine: Engine) -> bool:
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex
except Exception as ex: # pylint: disable=broad-except
# If the connection failed because OAuth2 is needed, start the flow.
if (
database.is_oauth2_enabled()
and database.db_engine_spec.needs_oauth2(ex)
):
database.start_oauth2_dance()

alive = False
# So we stop losing the original message if any
ex_str = str(ex)
Expand Down Expand Up @@ -197,6 +205,8 @@ def ping(engine: Engine) -> bool:
# check for custom errors (wrong username, wrong password, etc)
errors = database.db_engine_spec.extract_errors(ex, self._context)
raise SupersetErrorsException(errors) from ex
except OAuth2RedirectError:
raise
except SupersetSecurityException as ex:
event_logger.log_with_context(
action=get_log_connection_action(
Expand All @@ -205,23 +215,14 @@ def ping(engine: Engine) -> bool:
engine=database.db_engine_spec.__name__,
)
raise DatabaseSecurityUnsafeError(message=str(ex)) from ex
except SupersetTimeoutException as ex:
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
),
engine=database.db_engine_spec.__name__,
)
# bubble up the exception to return a 408
raise
except SSHTunnelingNotEnabledError as ex:
except (SupersetTimeoutException, SSHTunnelingNotEnabledError) as ex:
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
),
engine=database.db_engine_spec.__name__,
)
# bubble up the exception to return a 400
# bubble up the exception to return proper status code
raise
except Exception as ex:
event_logger.log_with_context(
Expand Down
17 changes: 15 additions & 2 deletions superset/commands/database/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from superset.daos.dataset import DatasetDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.db_engine_specs.base import GenericDBException
from superset.exceptions import OAuth2RedirectError
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction

Expand All @@ -56,7 +57,10 @@ def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._model: Database | None = None

@transaction(on_error=partial(on_error, reraise=DatabaseUpdateFailedError))
@transaction(
on_error=partial(on_error, reraise=DatabaseUpdateFailedError),
allowed=(OAuth2RedirectError,),
)
def run(self) -> Model:
self._model = DatabaseDAO.find_by_id(self._model_id)

Expand All @@ -80,7 +84,10 @@ def run(self) -> Model:
database = DatabaseDAO.update(self._model, self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
try:
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
except OAuth2RedirectError:
pass

return database

Expand Down Expand Up @@ -123,6 +130,9 @@ def _get_catalog_names(
force=True,
ssh_tunnel=ssh_tunnel,
)
except OAuth2RedirectError:
# raise OAuth2 exceptions as-is
raise
except GenericDBException as ex:
raise DatabaseConnectionFailedError() from ex

Expand All @@ -141,6 +151,9 @@ def _get_schema_names(
catalog=catalog,
ssh_tunnel=ssh_tunnel,
)
except OAuth2RedirectError:
# raise OAuth2 exceptions as-is
raise
except GenericDBException as ex:
raise DatabaseConnectionFailedError() from ex

Expand Down
9 changes: 9 additions & 0 deletions superset/commands/database/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ def run(self) -> None:
with closing(engine.raw_connection()) as conn:
alive = engine.dialect.do_ping(conn)
except Exception as ex:
# If the connection failed because OAuth2 is needed, we can save the
# database and trigger the OAuth2 flow whenever a user tries to run a
# query.
if (
database.is_oauth2_enabled()
and database.db_engine_spec.needs_oauth2(ex)
):
return

url = make_url_safe(sqlalchemy_uri)
context = {
"hostname": url.host,
Expand Down
11 changes: 8 additions & 3 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
DatabaseNotFoundException,
InvalidPayloadSchemaError,
OAuth2Error,
OAuth2RedirectError,
SupersetErrorsException,
SupersetException,
SupersetSecurityException,
Expand Down Expand Up @@ -398,7 +399,6 @@ def get(self, pk: int, **kwargs: Any) -> Response:

@expose("/", methods=("POST",))
@protect()
@safe
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post",
Expand Down Expand Up @@ -462,6 +462,8 @@ def post(self) -> FlaskResponse:
item["ssh_tunnel"] = mask_password_info(new_model.ssh_tunnel)

return self.response(201, id=new_model.id, result=item)
except OAuth2RedirectError:
raise
except DatabaseInvalidError as ex:
return self.response_422(message=ex.normalized_messages())
except DatabaseConnectionFailedError as ex:
Expand Down Expand Up @@ -621,7 +623,6 @@ def delete(self, pk: int) -> Response:

@expose("/<int:pk>/catalogs/")
@protect()
@safe
@rison(database_catalogs_query_schema)
@statsd_metrics
@event_logger.log_this_with_context(
Expand Down Expand Up @@ -680,12 +681,13 @@ def catalogs(self, pk: int, **kwargs: Any) -> FlaskResponse:
500,
message="There was an error connecting to the database",
)
except OAuth2RedirectError:
raise
except SupersetException as ex:
return self.response(ex.status, message=ex.message)

@expose("/<int:pk>/schemas/")
@protect()
@safe
@rison(database_schemas_query_schema)
@statsd_metrics
@event_logger.log_this_with_context(
Expand Down Expand Up @@ -746,6 +748,8 @@ def schemas(self, pk: int, **kwargs: Any) -> FlaskResponse:
return self.response(
500, message="There was an error connecting to the database"
)
except OAuth2RedirectError:
raise
except SupersetException as ex:
return self.response(ex.status, message=ex.message)

Expand Down Expand Up @@ -2069,6 +2073,7 @@ def available(self) -> Response:
"sqlalchemy_uri_placeholder": engine_spec.sqlalchemy_uri_placeholder,
"preferred": engine_spec.engine_name in preferred_databases,
"engine_information": engine_spec.get_public_information(),
"supports_oauth2": engine_spec.supports_oauth2,
}

if engine_spec.default_driver:
Expand Down
4 changes: 2 additions & 2 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# the user impersonation methods to handle personal tokens.
supports_oauth2 = False
oauth2_scope = ""
oauth2_authorization_request_uri = "" # pylint: disable=invalid-name
oauth2_token_request_uri = ""
oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
oauth2_token_request_uri: str | None = None

# Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError
Expand Down
2 changes: 2 additions & 0 deletions superset/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,8 @@ class OAuth2RedirectError(SupersetErrorException):
See the `OAuth2RedirectMessage.tsx` component for more details of how this
information is handled.
TODO (betodealmeida): change status to 403.
"""

def __init__(self, url: str, tab_id: str, redirect_uri: str):
Expand Down
16 changes: 16 additions & 0 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,9 @@ def get_all_schema_names(
) as inspector:
return self.db_engine_spec.get_schema_names(inspector)
except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.start_oauth2_dance()

raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex

@cache_util.memoized_func(
Expand All @@ -865,6 +868,9 @@ def get_all_catalog_names(
with self.get_inspector(ssh_tunnel=ssh_tunnel) as inspector:
return self.db_engine_spec.get_catalog_names(self, inspector)
except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.start_oauth2_dance()

raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex

@property
Expand Down Expand Up @@ -1096,6 +1102,16 @@ def get_oauth2_config(self) -> OAuth2ClientConfig | None:

return self.db_engine_spec.get_oauth2_config()

def start_oauth2_dance(self) -> None:
"""
Start the OAuth2 dance.
This method is called when an OAuth2 error is encountered, and the database is
configured to use OAuth2 for authentication. It raises an exception that will
trigger the OAuth2 dance in the frontend.
"""
return self.db_engine_spec.start_oauth2_dance(self)


sqla.event.listen(Database, "after_insert", security_manager.database_after_insert)
sqla.event.listen(Database, "after_update", security_manager.database_after_update)
Expand Down
31 changes: 31 additions & 0 deletions tests/unit_tests/commands/databases/create_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pytest_mock import MockerFixture

from superset.commands.database.create import CreateDatabaseCommand
from superset.exceptions import OAuth2RedirectError
from superset.extensions import security_manager


Expand Down Expand Up @@ -124,3 +125,33 @@ def test_create_permissions_without_catalog(
],
any_order=True,
)


def test_create_with_oauth2(
mocker: MockerFixture,
database_without_catalog: MockerFixture,
) -> None:
"""
Test that the database can be created even if OAuth2 is needed to connect.
"""
TestConnectionDatabaseCommand = mocker.patch(
"superset.commands.database.create.TestConnectionDatabaseCommand"
)
TestConnectionDatabaseCommand().run.side_effect = OAuth2RedirectError(
"url",
"tab_id",
"redirect_uri",
)
add_permission_view_menu = mocker.patch.object(
security_manager,
"add_permission_view_menu",
)

CreateDatabaseCommand(
{
"database_name": "test_database",
"sqlalchemy_uri": "sqlite://",
}
).run()

add_permission_view_menu.assert_not_called()
Loading

0 comments on commit 33ba31c

Please sign in to comment.