From b2e8c6640529f7ad6d460828fd5aa5391eb7ff0c Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 23 Aug 2022 19:29:12 -0300 Subject: [PATCH] fix: improve get_db_engine_spec_for_backend --- superset/databases/commands/validate.py | 26 ++-------- superset/databases/schemas.py | 39 ++++++--------- superset/db_engine_specs/__init__.py | 35 +++++++------ superset/db_engine_specs/base.py | 58 +++++++++++++++++++++- superset/db_engine_specs/databricks.py | 19 +++++--- superset/db_engine_specs/shillelagh.py | 6 ++- superset/models/core.py | 13 ++--- tests/unit_tests/models/core_test.py | 65 ++++++++++++++++++++++++- 8 files changed, 183 insertions(+), 78 deletions(-) diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py index a9f1633a18144..caddf48c4c129 100644 --- a/superset/databases/commands/validate.py +++ b/superset/databases/commands/validate.py @@ -29,7 +29,7 @@ ) from superset.databases.dao import DatabaseDAO from superset.databases.utils import make_url_safe -from superset.db_engine_specs import get_engine_specs +from superset.db_engine_specs import get_engine_spec from superset.db_engine_specs.base import BasicParametersMixin from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.extensions import event_logger @@ -45,25 +45,13 @@ def __init__(self, parameters: Dict[str, Any]): def run(self) -> None: engine = self._properties["engine"] - engine_specs = get_engine_specs() + driver = self._properties["driver"] if engine in BYPASS_VALIDATION_ENGINES: # Skip engines that are only validated onCreate return - if engine not in engine_specs: - raise InvalidEngineError( - SupersetError( - message=__( - 'Engine "%(engine)s" is not a valid engine.', - engine=engine, - ), - error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, - level=ErrorLevel.ERROR, - extra={"allowed": list(engine_specs), "provided": engine}, - ), - ) - engine_spec = engine_specs[engine] + engine_spec = get_engine_spec(engine, 
driver) if not hasattr(engine_spec, "parameters_schema"): raise InvalidEngineError( SupersetError( @@ -73,14 +61,6 @@ def run(self) -> None: ), error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, level=ErrorLevel.ERROR, - extra={ - "allowed": [ - name - for name, engine_spec in engine_specs.items() - if issubclass(engine_spec, BasicParametersMixin) - ], - "provided": engine, - }, ), ) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index aa88822a854df..3cad1c6f238ce 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -28,7 +28,7 @@ from superset import db from superset.databases.commands.exceptions import DatabaseInvalidError from superset.databases.utils import make_url_safe -from superset.db_engine_specs import BaseEngineSpec, get_engine_specs +from superset.db_engine_specs import BaseEngineSpec, get_engine_spec from superset.exceptions import CertificateException, SupersetSecurityException from superset.models.core import ConfigurationMethod, Database, PASSWORD_MASK from superset.security.analytics_db_safety import check_sqlalchemy_uri @@ -231,6 +231,7 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods """ engine = fields.String(allow_none=True, description="SQLAlchemy engine to use") + driver = fields.String(allow_none=True, description="SQLAlchemy driver to use") parameters = fields.Dict( keys=fields.String(), values=fields.Raw(), @@ -262,10 +263,20 @@ def build_sqlalchemy_uri( or parameters.pop("engine", None) or data.pop("backend", None) ) + if not engine: + raise ValidationError( + [ + _( + "An engine must be specified when passing " + "individual parameters to a database." 
+ ) + ] + ) + driver = data.pop("driver", None) configuration_method = data.get("configuration_method") if configuration_method == ConfigurationMethod.DYNAMIC_FORM: - engine_spec = get_engine_spec(engine) + engine_spec = get_engine_spec(engine, driver) if not hasattr(engine_spec, "build_sqlalchemy_uri") or not hasattr( engine_spec, "parameters_schema" @@ -295,34 +306,12 @@ def build_sqlalchemy_uri( return data -def get_engine_spec(engine: Optional[str]) -> Type[BaseEngineSpec]: - if not engine: - raise ValidationError( - [ - _( - "An engine must be specified when passing " - "individual parameters to a database." - ) - ] - ) - engine_specs = get_engine_specs() - if engine not in engine_specs: - raise ValidationError( - [ - _( - 'Engine "%(engine)s" is not a valid engine.', - engine=engine, - ) - ] - ) - return engine_specs[engine] - - class DatabaseValidateParametersSchema(Schema): class Meta: # pylint: disable=too-few-public-methods unknown = EXCLUDE engine = fields.String(required=True, description="SQLAlchemy engine to use") + driver = fields.String(allow_none=True, description="SQLAlchemy driver to use") parameters = fields.Dict( keys=fields.String(), values=fields.Raw(allow_none=True), diff --git a/superset/db_engine_specs/__init__.py b/superset/db_engine_specs/__init__.py index dac700199557c..257f6481bccba 100644 --- a/superset/db_engine_specs/__init__.py +++ b/superset/db_engine_specs/__init__.py @@ -33,27 +33,34 @@ from collections import defaultdict from importlib import import_module from pathlib import Path -from typing import Any, Dict, List, Set, Type +from typing import Any, Dict, List, Optional, Set, Type import sqlalchemy.databases import sqlalchemy.dialects from pkg_resources import iter_entry_points from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.engine.url import URL from superset.db_engine_specs.base import BaseEngineSpec logger = logging.getLogger(__name__) -def is_engine_spec(attr: Any) -> bool: +def is_engine_spec(obj: 
Any) -> bool: + """ + Return true if a given object is a DB engine spec. + """ return ( - inspect.isclass(attr) - and issubclass(attr, BaseEngineSpec) - and attr != BaseEngineSpec + inspect.isclass(obj) + and issubclass(obj, BaseEngineSpec) + and obj != BaseEngineSpec ) def load_engine_specs() -> List[Type[BaseEngineSpec]]: + """ + Load all engine specs, native and 3rd party. + """ engine_specs: List[Type[BaseEngineSpec]] = [] # load standard engines @@ -78,20 +85,18 @@ def load_engine_specs() -> List[Type[BaseEngineSpec]]: return engine_specs -def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]: +def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngineSpec]: + """ + Return the DB engine spec associated with a given SQLAlchemy backend/driver. + """ engine_specs = load_engine_specs() - # build map from name/alias -> spec - engine_specs_map: Dict[str, Type[BaseEngineSpec]] = {} for engine_spec in engine_specs: - names = [engine_spec.engine] - if engine_spec.engine_aliases: - names.extend(engine_spec.engine_aliases) - - for name in names: - engine_specs_map[name] = engine_spec + if engine_spec.supports_backend(backend, driver): + return engine_spec - return engine_specs_map + # default to the generic DB engine spec + return BaseEngineSpec # there's a mismatch between the dialect name reported by the driver in these diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 368770e2612f5..1ce802781999a 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -183,9 +183,39 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods having to add the same aggregation in SELECT. """ + engine_name: Optional[str] = None # for user messages, overridden in child classes + + # Associate the DB engine spec to one or more SQLAlchemy dialects/drivers.
For + # example, if a given DB engine spec has: + # + # class PostgresDBEngineSpec: + # engine = 'postgresql' + # engine_aliases = 'postgres' + # drivers = {'psycopg2', 'asyncpg'} + # + # It would be used for all the following SQLALchemy URIs: + # + # - postgres://user:password@host/db + # - postgresql://user:password@host/db + # - postgres+asyncpg://user:password@host/db + # - postgres+psycopg2://user:password@host/db + # - postgresql+asyncpg://user:password@host/db + # - postgresql+psycopg2://user:password@host/db + # + # Note that SQLAlchemy has a default driver when one is not specified: + # + # >>> from sqlalchemy.engine.url import make_url + # >>> make_url('postgres://').get_driver_name() + # 'psycopg2' + # + # The ``default_driver`` should point to the recomended driver, and is used by + # database creation modals where the user provides parameters to connect to the + # database, instead of providing the SQLAlchemy URI. engine = "base" # str as defined in sqlalchemy.engine.engine engine_aliases: Set[str] = set() - engine_name: Optional[str] = None # for user messages, overridden in child classes + drivers: Dict[str, str] = {} + default_driver: Optional[str] = None + _date_trunc_functions: Dict[str, str] = {} _time_grain_expressions: Dict[Optional[str], str] = {} column_type_mappings: Tuple[ColumnTypeMapping, ...] = ( @@ -355,6 +385,32 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]] ] = {} + @classmethod + def supports_url(cls, url: URL) -> bool: + """ + Returns true if the DB engine spec supports a given SQLAlchemy URL. + """ + backend = url.get_backend_name() + driver = url.get_driver_name() + return cls.supports_backend(backend, driver) + + @classmethod + def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool: + """ + Returns true if the DB engine spec supports a given SQLAlchemy backend/driver. 
+ """ + # check the backend first + if backend != cls.engine and backend not in cls.engine_aliases: + return False + + # originally DB engine specs didn't declare any drivers and the check was made + # only on the engine; if that's the case, ignore the driver for backwards + # compatibility + if not cls.drivers or driver is None: + return True + + return driver in cls.drivers + @classmethod def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: """ diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index 79718c93f664c..90d90b9448fa7 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -47,18 +47,23 @@ class DatabricksHiveEngineSpec(HiveEngineSpec): - engine = "databricks" engine_name = "Databricks Interactive Cluster" - driver = "pyhive" + + engine = "databricks" + drivers = {"pyhive": "Hive driver for Interactive Cluster"} + default_driver = "pyhive" + _show_functions_column = "function" _time_grain_expressions = time_grain_expressions class DatabricksODBCEngineSpec(BaseEngineSpec): - engine = "databricks" engine_name = "Databricks SQL Endpoint" - driver = "pyodbc" + + engine = "databricks" + drivers = {"pyodbc": "ODBC driver for SQL endpoint"} + default_driver = "pyodbc" _time_grain_expressions = time_grain_expressions @@ -74,9 +79,11 @@ def epoch_to_dttm(cls) -> str: class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec): - engine = "databricks" engine_name = "Databricks Native Connector" - driver = "connector" + + engine = "databricks" + drivers = {"connector": "Native all-purpose driver"} + default_driver = "connector" @staticmethod def get_extra_params(database: "Database") -> Dict[str, Any]: diff --git a/superset/db_engine_specs/shillelagh.py b/superset/db_engine_specs/shillelagh.py index c6e6f618c7251..37301224484b7 100644 --- a/superset/db_engine_specs/shillelagh.py +++ b/superset/db_engine_specs/shillelagh.py @@ -20,7 +20,11 @@ class 
ShillelaghEngineSpec(SqliteEngineSpec): """Engine for shillelagh""" - engine = "shillelagh" engine_name = "Shillelagh" + engine = "shillelagh" + drivers = {"apsw": "SQLite driver"} + default_driver = "apsw" + sqlalchemy_uri_placeholder = "shillelagh://" + allows_joins = True allows_subqueries = True diff --git a/superset/models/core.py b/superset/models/core.py index b5a4aa6537da2..2df8b14113c7a 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -635,15 +635,16 @@ def get_all_schema_names( # pylint: disable=unused-argument @property def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]: - return self.get_db_engine_spec_for_backend(self.backend) + url = make_url_safe(self.sqlalchemy_uri_decrypted) + return self.get_db_engine_spec(url) @classmethod @memoized - def get_db_engine_spec_for_backend( - cls, backend: str - ) -> Type[db_engine_specs.BaseEngineSpec]: - engines = db_engine_specs.get_engine_specs() - return engines.get(backend, db_engine_specs.BaseEngineSpec) + def get_db_engine_spec(cls, url: URL) -> Type[db_engine_specs.BaseEngineSpec]: + backend = url.get_backend_name() + driver = url.get_driver_name() + + return db_engine_specs.get_engine_spec(backend, driver) def grains(self) -> Tuple[TimeGrain, ...]: """Defines time granularity database-specific expressions. diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 3338ddcb61441..4ae429c677aab 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -59,7 +59,7 @@ def get_metrics( }, ] - database.get_db_engine_spec_for_backend = mocker.MagicMock( # type: ignore + database.get_db_engine_spec = mocker.MagicMock( # type: ignore return_value=CustomSqliteEngineSpec ) assert database.get_metrics("table") == [ @@ -70,3 +70,66 @@ def get_metrics( "verbose_name": "COUNT(DISTINCT user_id)", }, ] + + +def test_get_db_engine_spec(mocker: MockFixture) -> None: + """ + Tests for ``get_db_engine_spec``. 
+ """ + from superset.db_engine_specs import BaseEngineSpec + from superset.models.core import Database + + # pylint: disable=abstract-method + class PostgresDBEngineSpec(BaseEngineSpec): + """ + A DB engine spec with drivers and a default driver. + """ + + engine = "postgresql" + engine_aliases = {"postgres"} + drivers = { + "psycopg2": "The default Postgres driver", + "asyncpg": "An async Postgres driver", + } + default_driver = "psycopg2" + + # pylint: disable=abstract-method + class OldDBEngineSpec(BaseEngineSpec): + """ + An old DB engine spec without drivers nor a default driver. + """ + + engine = "mysql" + + load_engine_specs = mocker.patch("superset.db_engine_specs.load_engine_specs") + load_engine_specs.return_value = [ + PostgresDBEngineSpec, + OldDBEngineSpec, + ] + + assert ( + Database(database_name="db", sqlalchemy_uri="postgresql://").db_engine_spec + == PostgresDBEngineSpec + ) + assert ( + Database( + database_name="db", sqlalchemy_uri="postgresql+psycopg2://" + ).db_engine_spec + == PostgresDBEngineSpec + ) + assert ( + Database( + database_name="db", sqlalchemy_uri="postgresql+asyncpg://" + ).db_engine_spec + == PostgresDBEngineSpec + ) + assert ( + Database(database_name="db", sqlalchemy_uri="mysql://").db_engine_spec + == OldDBEngineSpec + ) + assert ( + Database( + database_name="db", sqlalchemy_uri="mysql+mysqlconnector://" + ).db_engine_spec + == OldDBEngineSpec + )