diff --git a/superset/config.py b/superset/config.py index 3511d67d192d6..6b33333546a28 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1061,6 +1061,18 @@ class CeleryConfig: # pylint: disable=too-few-public-methods "postgresql": "PostgreSQLValidator", } +# A list of preferred databases, in order. These databases will be +# displayed prominently in the "Add Database" dialog. You should +# use the "engine" attribute of the corresponding DB engine spec in +# `superset/db_engine_specs/`. +PREFERRED_DATABASES: List[str] = [ + # "postgresql", + # "presto", + # "mysql", + # "sqlite", + # etc. +] + # Do you want Talisman enabled? TALISMAN_ENABLED = False # If you want Talisman, how do you want it configured?? diff --git a/superset/databases/api.py b/superset/databases/api.py index 970f61accf0c8..cfc3697ce351f 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -18,7 +18,7 @@ import logging from datetime import datetime from io import BytesIO -from typing import Any, Optional +from typing import Any, Dict, List, Optional from zipfile import ZipFile from flask import g, request, Response, send_file @@ -27,7 +27,7 @@ from marshmallow import ValidationError from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError -from superset import event_logger +from superset import app, event_logger from superset.commands.exceptions import CommandInvalidError from superset.commands.importers.v1.utils import get_contents_from_bundle from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod @@ -63,6 +63,10 @@ TableMetadataResponseSchema, ) from superset.databases.utils import get_table_metadata +from superset.db_engine_specs import get_available_engine_specs +from superset.db_engine_specs.base import BaseParametersMixin +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import SupersetErrorException from superset.extensions import security_manager from superset.models.core import Database from superset.typing import FlaskResponse @@ -84,6 +88,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "test_connection", "related_objects", "function_names", + "available", } resource_name = "database" class_permission_name = "Database" @@ -821,7 +826,6 @@ def function_names(self, pk: int) -> Response: schema: type: integer responses: - 200: 200: description: Query result content: @@ -839,3 +843,67 @@ def function_names(self, pk: int) -> Response: if not database: return self.response_404() return self.response(200, function_names=database.function_names,) + + @expose("/available/", methods=["GET"]) + @protect() + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" f".available", + log_to_statsd=False, + ) + def available(self) -> Response: + """Return names of databases currently available + --- + get: + description: + Get names of databases currently available + responses: + 200: + description: Database names + content: + application/json: + schema: + type: array + items: + type: object + properties: + name: + description: Name of the database + type: string + preferred: + description: Is the database preferred? + type: bool + sqlalchemy_uri_placeholder: + description: Example placeholder for the SQLAlchemy URI + type: string + parameters: + description: JSON schema defining the needed parameters + 400: + $ref: '#/components/responses/400' + 500: + $ref: '#/components/responses/500' + """ + preferred_databases: List[str] = app.config.get("PREFERRED_DATABASES", []) + available_databases = [] + for engine_spec in get_available_engine_specs(): + payload: Dict[str, Any] = { + "name": engine_spec.engine_name, + "engine": engine_spec.engine, + "preferred": engine_spec.engine in preferred_databases, + } + + if issubclass(engine_spec, BaseParametersMixin): + payload["parameters"] = engine_spec.parameters_json_schema() + payload[ + "sqlalchemy_uri_placeholder" + ] = engine_spec.sqlalchemy_uri_placeholder + + available_databases.append(payload) + + available_databases.sort( + key=lambda payload: preferred_databases.index(payload["engine"]) + if payload["engine"] in preferred_databases + else len(preferred_databases) + ) + + return self.response(200, databases=available_databases) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 6c6acc73fb403..dcb257988e536 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -21,12 +21,14 @@ from flask import current_app from flask_babel import lazy_gettext as _ -from marshmallow import fields, Schema, validates_schema +from marshmallow import fields, pre_load, Schema, validates_schema from marshmallow.validate import Length, ValidationError from sqlalchemy import MetaData from sqlalchemy.engine.url import make_url from sqlalchemy.exc import ArgumentError +from superset.db_engine_specs import get_engine_specs +from superset.db_engine_specs.base import BaseParametersMixin from superset.exceptions import CertificateException, SupersetSecurityException from superset.models.core import PASSWORD_MASK from superset.security.analytics_db_safety import check_sqlalchemy_uri @@ -207,7 +209,72 @@ def extra_validator(value: str) -> str: return value -class DatabasePostSchema(Schema): +class DatabaseParametersSchemaMixin: + """ + Allow SQLAlchemy URI to be passed as separate parameters. + + This mixing is a first step in allowing the users to test, create and + edit databases without having to know how to write a SQLAlchemy URI. + Instead, each databases defines the parameters that it takes (eg, + username, password, host, etc.) and the SQLAlchemy URI is built from + these parameters. + + When using this mixin make sure that `sqlalchemy_uri` is not required. + """ + + parameters = fields.Dict( + keys=fields.Str(), + values=fields.Raw(), + description="DB-specific parameters for configuration", + ) + + # pylint: disable=no-self-use, unused-argument + @pre_load + def build_sqlalchemy_uri( + self, data: Dict[str, Any], **kwargs: Any + ) -> Dict[str, Any]: + """ + Build SQLAlchemy URI from separate parameters. + + This is used for databases that support being configured by individual + parameters (eg, username, password, host, etc.), instead of requiring + the constructed SQLAlchemy URI to be passed. + """ + parameters = data.pop("parameters", None) + if parameters: + if "engine" not in parameters: + raise ValidationError( + [ + _( + "An engine must be specified when passing " + "individual parameters to a database." + ) + ] + ) + engine = parameters["engine"] + + engine_specs = get_engine_specs() + if engine not in engine_specs: + raise ValidationError( + [_('Engine "%(engine)s" is not a valid engine.', engine=engine,)] + ) + engine_spec = engine_specs[engine] + if not issubclass(engine_spec, BaseParametersMixin): + raise ValidationError( + [ + _( + 'Engine spec "%(engine_spec)s" does not support ' + "being configured via individual parameters.", + engine_spec=engine_spec.__name__, + ) + ] + ) + + data["sqlalchemy_uri"] = engine_spec.build_sqlalchemy_url(parameters) + return data + + +class DatabasePostSchema(Schema, DatabaseParametersSchemaMixin): database_name = fields.String( description=database_name_description, required=True, validate=Length(1, 250), ) @@ -242,12 +309,11 @@ class DatabasePostSchema(Schema): ) sqlalchemy_uri = fields.String( description=sqlalchemy_uri_description, - required=True, validate=[Length(1, 1024), sqlalchemy_uri_validator], ) -class DatabasePutSchema(Schema): +class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin): database_name = fields.String( description=database_name_description, allow_none=True, validate=Length(1, 250), ) @@ -282,12 +348,11 @@ class DatabasePutSchema(Schema): ) sqlalchemy_uri = fields.String( description=sqlalchemy_uri_description, - allow_none=True, validate=[Length(0, 1024), sqlalchemy_uri_validator], ) -class DatabaseTestConnectionSchema(Schema): +class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin): database_name = fields.String( description=database_name_description, allow_none=True, validate=Length(1, 250), ) @@ -305,7 +370,6 @@ class DatabaseTestConnectionSchema(Schema): ) sqlalchemy_uri = fields.String( description=sqlalchemy_uri_description, - required=True, validate=[Length(1, 1024), sqlalchemy_uri_validator], ) diff --git a/superset/db_engine_specs/__init__.py b/superset/db_engine_specs/__init__.py index 747543f820e63..a4e083cf6ed00 100644 --- a/superset/db_engine_specs/__init__.py +++ b/superset/db_engine_specs/__init__.py @@ -32,8 +32,9 @@ import pkgutil from importlib import import_module from pathlib import Path -from typing import Any, Dict, List, Type +from typing import Any, Dict, List, Set, Type +import sqlalchemy.databases from pkg_resources import iter_entry_points from superset.db_engine_specs.base import BaseEngineSpec @@ -67,7 +68,7 @@ def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]: try: engine_spec = ep.load() except Exception: # pylint: disable=broad-except - logger.warning("Unable to load engine spec: %s", engine_spec) + logger.warning("Unable to load Superset DB engine spec: %s", engine_spec) continue engine_specs.append(engine_spec) @@ -82,3 +83,23 @@ def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]: engine_specs_map[name] = engine_spec return engine_specs_map + + +def get_available_engine_specs() -> List[Type[BaseEngineSpec]]: + # native SQLAlchemy dialects + backends: Set[str] = { + getattr(sqlalchemy.databases, attr).dialect.name + for attr in sqlalchemy.databases.__all__ + } + + # installed 3rd-party dialects + for ep in iter_entry_points("sqlalchemy.dialects"): + try: + dialect = ep.load() + except Exception: # pylint: disable=broad-except + logger.warning("Unable to load SQLAlchemy dialect: %s", dialect) + else: + backends.add(dialect.name) + + engine_specs = get_engine_specs() + return [engine_specs[backend] for backend in backends if backend in engine_specs] diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 5cdb2a0307059..259dba14f17f3 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -30,6 +30,7 @@ NamedTuple, Optional, Pattern, + Set, Tuple, Type, TYPE_CHECKING, @@ -38,18 +39,22 @@ import pandas as pd import sqlparse +from apispec import APISpec +from apispec.ext.marshmallow import MarshmallowPlugin from flask import g from flask_babel import gettext as __, lazy_gettext as _ +from marshmallow import fields, Schema from sqlalchemy import column, DateTime, select, types from sqlalchemy.engine.base import Engine from sqlalchemy.engine.interfaces import Compiled, Dialect from sqlalchemy.engine.reflection import Inspector -from sqlalchemy.engine.url import URL +from sqlalchemy.engine.url import make_url, URL from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm import Session from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom from sqlalchemy.types import String, TypeEngine, UnicodeText +from typing_extensions import TypedDict from superset import app, security_manager, sql_parse from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -150,7 +155,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ engine = "base" # str as defined in sqlalchemy.engine.engine - engine_aliases: Optional[Tuple[str]] = None + engine_aliases: Set[str] = set() engine_name: Optional[ str ] = None # used for user messages, overridden in child classes @@ -1293,3 +1298,90 @@ def get_column_spec( sqla_type=column_type, generic_type=generic_type, is_dttm=is_dttm ) return None + + +# schema for adding a database by providing parameters instead of the +# full SQLAlchemy URI +class BaseParametersSchema(Schema): + username = fields.String(allow_none=True, description=__("Username")) + password = fields.String(allow_none=True, description=__("Password")) + host = fields.String(required=True, description=__("Hostname or IP address")) + port = fields.Integer(required=True, description=__("Database port")) + database = fields.String(required=True, description=__("Database name")) + query = fields.Dict( + keys=fields.Str(), values=fields.Raw(), description=__("Additinal parameters") + ) + + +class BaseParametersType(TypedDict, total=False): + username: Optional[str] + password: Optional[str] + host: str + port: int + database: str + query: Dict[str, Any] + + +class BaseParametersMixin: + + """ + Mixin for configuring DB engine specs via a dictionary. + + With this mixin the SQLAlchemy engine can be configured through + individual parameters, instead of the full SQLAlchemy URI. This + mixin is for the most common pattern of URI: + + drivername://user:password@host:port/dbname[?key=value&key=value...] + + """ + + # schema describing the parameters used to configure the DB + parameters_schema = BaseParametersSchema() + + # recommended driver name for the DB engine spec + drivername = "" + + # placeholder with the SQLAlchemy URI template + sqlalchemy_uri_placeholder = ( + "drivername://user:password@host:port/dbname[?key=value&key=value...]" + ) + + @classmethod + def build_sqlalchemy_url(cls, parameters: BaseParametersType) -> str: + return str( + URL( + cls.drivername, + username=parameters.get("username"), + password=parameters.get("password"), + host=parameters["host"], + port=parameters["port"], + database=parameters["database"], + query=parameters.get("query", {}), + ) + ) + + @classmethod + def get_parameters_from_uri(cls, uri: str) -> BaseParametersType: + url = make_url(uri) + return { + "username": url.username, + "password": url.password, + "host": url.host, + "port": url.port, + "database": url.database, + "query": url.query, + } + + @classmethod + def parameters_json_schema(cls) -> Any: + """ + Return configuration parameters as OpenAPI. + """ + spec = APISpec( + title="Database Parameters", + version="1.0.0", + openapi_version="3.0.2", + plugins=[MarshmallowPlugin()], + ) + spec.components.schema(cls.__name__, schema=cls.parameters_schema) + return spec.to_dict()["components"]["schemas"][cls.__name__] diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 92c00013f2a47..c2e6776a54c11 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -37,7 +37,7 @@ from sqlalchemy.dialects.postgresql.base import PGInspector from sqlalchemy.types import String, TypeEngine -from superset.db_engine_specs.base import BaseEngineSpec +from superset.db_engine_specs.base import BaseEngineSpec, BaseParametersMixin from superset.errors import SupersetErrorType from superset.exceptions import SupersetException from superset.utils import core as utils @@ -143,9 +143,15 @@ def epoch_to_dttm(cls) -> str: return "(timestamp 'epoch' + {col} * interval '1 second')" -class PostgresEngineSpec(PostgresBaseEngineSpec): +class PostgresEngineSpec(PostgresBaseEngineSpec, BaseParametersMixin): engine = "postgresql" - engine_aliases = ("postgres",) + engine_aliases = {"postgres"} + + drivername = "postgresql+psycopg2" + sqlalchemy_uri_placeholder = ( + "postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]" + ) + max_column_name_length = 63 try_remove_schema_from_table_name = False diff --git a/tests/databases/api_tests.py b/tests/databases/api_tests.py index 39a73e6e61d14..fc12fe0a4acd7 100644 --- a/tests/databases/api_tests.py +++ b/tests/databases/api_tests.py @@ -33,6 +33,8 @@ from superset import db, security_manager from superset.connectors.sqla.models import SqlaTable +from superset.db_engine_specs.mysql import MySQLEngineSpec +from superset.db_engine_specs.postgres import PostgresEngineSpec from superset.errors import SupersetError from superset.models.core import Database from superset.models.reports import ReportSchedule, ReportScheduleType @@ -613,7 +615,8 @@ def test_info_security_database(self): assert "can_read" in data["permissions"] assert "can_write" in data["permissions"] assert "can_function_names" in data["permissions"] - assert len(data["permissions"]) == 3 + assert "can_available" in data["permissions"] + assert len(data["permissions"]) == 4 def test_get_invalid_database_table_metadata(self): """ @@ -1245,3 +1248,65 @@ def test_function_names(self, mock_get_function_names): assert rv.status_code == 200 assert response == {"function_names": ["AVG", "MAX", "SUM"]} + + @mock.patch("superset.databases.api.get_available_engine_specs") + @mock.patch("superset.databases.api.app") + def test_available(self, app, get_available_engine_specs): + app.config = {"PREFERRED_DATABASES": ["postgresql"]} + get_available_engine_specs.return_value = [ + MySQLEngineSpec, + PostgresEngineSpec, + ] + + self.login(username="admin") + uri = "api/v1/database/available/" + + rv = self.client.get(uri) + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + assert response == { + "databases": [ + { + "engine": "postgresql", + "name": "PostgreSQL", + "parameters": { + "properties": { + "database": { + "description": "Database name", + "type": "string", + }, + "host": { + "description": "Hostname or IP address", + "type": "string", + }, + "password": { + "description": "Password", + "nullable": True, + "type": "string", + }, + "port": { + "description": "Database port", + "format": "int32", + "type": "integer", + }, + "query": { + "additionalProperties": {}, + "description": "Additinal parameters", + "type": "object", + }, + "username": { + "description": "Username", + "nullable": True, + "type": "string", + }, + }, + "required": ["database", "host", "port"], + "type": "object", + }, + "preferred": True, + "sqlalchemy_uri_placeholder": "postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]", + }, + {"engine": "mysql", "name": "MySQL", "preferred": False}, + ] + } diff --git a/tests/databases/schema_tests.py b/tests/databases/schema_tests.py new file mode 100644 index 0000000000000..6d173cc515c74 --- /dev/null +++ b/tests/databases/schema_tests.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest import mock + +from marshmallow import fields, Schema, ValidationError + +from superset.databases.schemas import DatabaseParametersSchemaMixin +from superset.db_engine_specs.base import BaseParametersMixin + + +class DummySchema(Schema, DatabaseParametersSchemaMixin): + sqlalchemy_uri = fields.String() + + +class DummyEngine(BaseParametersMixin): + drivername = "dummy" + + +class InvalidEngine: + pass + + +@mock.patch("superset.databases.schemas.get_engine_specs") +def test_database_parameters_schema_mixin(get_engine_specs): + get_engine_specs.return_value = {"dummy_engine": DummyEngine} + payload = { + "parameters": { + "engine": "dummy_engine", + "username": "username", + "password": "password", + "host": "localhost", + "port": 12345, + "database": "dbname", + } + } + schema = DummySchema() + result = schema.load(payload) + assert result == { + "sqlalchemy_uri": "dummy://username:password@localhost:12345/dbname" + } + + +def test_database_parameters_schema_mixin_no_engine(): + payload = { + "parameters": { + "username": "username", + "password": "password", + "host": "localhost", + "port": 12345, + "dbname": "dbname", + } + } + schema = DummySchema() + try: + schema.load(payload) + except ValidationError as err: + assert err.messages == { + "_schema": [ + "An engine must be specified when passing individual parameters to a database." + ] + } + + +@mock.patch("superset.databases.schemas.get_engine_specs") +def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs): + get_engine_specs.return_value = {} + payload = { + "parameters": { + "engine": "dummy_engine", + "username": "username", + "password": "password", + "host": "localhost", + "port": 12345, + "dbname": "dbname", + } + } + schema = DummySchema() + try: + schema.load(payload) + except ValidationError as err: + assert err.messages == { + "_schema": ['Engine "dummy_engine" is not a valid engine.'] + } + + +@mock.patch("superset.databases.schemas.get_engine_specs") +def test_database_parameters_schema_no_mixin(get_engine_specs): + get_engine_specs.return_value = {"invalid_engine": InvalidEngine} + payload = { + "parameters": { + "engine": "invalid_engine", + "username": "username", + "password": "password", + "host": "localhost", + "port": 12345, + "database": "dbname", + } + } + schema = DummySchema() + try: + schema.load(payload) + except ValidationError as err: + assert err.messages == { + "_schema": [ + ( + 'Engine spec "InvalidEngine" does not support ' + "being configured via individual parameters." + ) + ] + } diff --git a/tests/db_engine_specs/postgres_tests.py b/tests/db_engine_specs/postgres_tests.py index b9fe0cadcc785..6c8b8e5840158 100644 --- a/tests/db_engine_specs/postgres_tests.py +++ b/tests/db_engine_specs/postgres_tests.py @@ -388,3 +388,44 @@ def test_extract_errors(self): }, ) ] + + +def test_base_parameters_mixin(): + parameters = { + "username": "username", + "password": "password", + "host": "localhost", + "port": 5432, + "database": "dbname", + "query": {"foo": "bar"}, + } + sqlalchemy_uri = PostgresEngineSpec.build_sqlalchemy_url(parameters) + assert ( + sqlalchemy_uri + == "postgresql+psycopg2://username:password@localhost:5432/dbname?foo=bar" + ) + + parameters_from_uri = PostgresEngineSpec.get_parameters_from_uri(sqlalchemy_uri) + assert parameters_from_uri == parameters + + json_schema = PostgresEngineSpec.parameters_json_schema() + assert json_schema == { + "type": "object", + "properties": { + "host": {"type": "string", "description": "Hostname or IP address"}, + "username": {"type": "string", "nullable": True, "description": "Username"}, + "password": {"type": "string", "nullable": True, "description": "Password"}, + "database": {"type": "string", "description": "Database name"}, + "query": { + "type": "object", + "description": "Additinal parameters", + "additionalProperties": {}, + }, + "port": { + "type": "integer", + "format": "int32", + "description": "Database port", + }, + }, + "required": ["database", "host", "port"], + }