From c2b2644948d6766c5d94d6070a096139a88766af Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 16 Mar 2023 15:04:46 -0700 Subject: [PATCH 1/4] feat(postgresql): dynamic schema --- superset/db_engine_specs/base.py | 11 ++-- superset/db_engine_specs/drill.py | 17 +++--- superset/db_engine_specs/hive.py | 15 +++-- superset/db_engine_specs/mysql.py | 13 +++-- superset/db_engine_specs/postgres.py | 56 ++++++++++++++----- superset/db_engine_specs/presto.py | 19 ++++--- superset/db_engine_specs/snowflake.py | 21 ++++--- superset/models/core.py | 32 +++++++---- .../db_engine_specs/test_postgres.py | 23 ++++++++ 9 files changed, 143 insertions(+), 64 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 26dd169dc06ad..efe7d68da2182 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -371,7 +371,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods supports_file_upload = True # Is the DB engine spec able to change the default schema? This requires implementing - # a custom `adjust_database_uri` method. + # a custom `adjust_engine_params` method. supports_dynamic_schema = False @classmethod @@ -1057,11 +1057,12 @@ def extract_errors( ] @classmethod - def adjust_database_uri( # pylint: disable=unused-argument + def adjust_engine_params( # pylint: disable=unused-argument cls, uri: URL, - selected_schema: Optional[str], - ) -> URL: + connect_args: Dict[str, Any], + schema: Optional[str], + ) -> Tuple[URL, Dict[str, Any]]: """ Return a modified URL with a new database component. @@ -1080,7 +1081,7 @@ def adjust_database_uri( # pylint: disable=unused-argument Some database drivers like Presto accept '{catalog}/{schema}' in the database component of the URL, that can be handled here. 
""" - return uri + return uri, connect_args @classmethod def patch(cls) -> None: diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py index 4ae5ae59b301e..5f7e9baf51170 100644 --- a/superset/db_engine_specs/drill.py +++ b/superset/db_engine_specs/drill.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from urllib import parse from sqlalchemy import types @@ -71,13 +71,16 @@ def convert_dttm( return None @classmethod - def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL: - if selected_schema: - uri = uri.set( - database=parse.quote(selected_schema.replace(".", "/"), safe="") - ) + def adjust_engine_params( + cls, + uri: URL, + connect_args: Dict[str, Any], + schema: Optional[str], + ) -> Tuple[URL, Dict[str, Any]]: + if schema: + uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe="")) - return uri + return uri, connect_args @classmethod def get_schema_from_engine_params( diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index f90d889f8cc5a..f70ff2db03dfd 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -260,13 +260,16 @@ def convert_dttm( return None @classmethod - def adjust_database_uri( - cls, uri: URL, selected_schema: Optional[str] = None - ) -> URL: - if selected_schema: - uri = uri.set(database=parse.quote(selected_schema, safe="")) + def adjust_engine_params( + cls, + uri: URL, + connect_args: Dict[str, Any], + schema: Optional[str] = None, + ) -> Tuple[URL, Dict[str, Any]]: + if schema: + uri = uri.set(database=parse.quote(schema, safe="")) - return uri + return uri, connect_args @classmethod def get_schema_from_engine_params( diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 04b8c68dd7503..572db3aab14a4 100644 --- 
a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -191,15 +191,16 @@ def convert_dttm( return None @classmethod - def adjust_database_uri( + def adjust_engine_params( cls, uri: URL, - selected_schema: Optional[str] = None, - ) -> URL: - if selected_schema: - uri = uri.set(database=parse.quote(selected_schema, safe="")) + connect_args: Dict[str, Any], + schema: Optional[str] = None, + ) -> Tuple[URL, Dict[str, Any]]: + if schema: + uri = uri.set(database=parse.quote(schema, safe="")) - return uri + return uri, connect_args @classmethod def get_schema_from_engine_params( diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 84ddf56e10f00..3999c98e64967 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -78,6 +78,8 @@ class PostgresBaseEngineSpec(BaseEngineSpec): engine = "" engine_name = "PostgreSQL" + supports_dynamic_schema = True + _time_grain_expressions = { None: "{col}", "PT1S": "DATE_TRUNC('second', {col})", @@ -147,6 +149,30 @@ class PostgresBaseEngineSpec(BaseEngineSpec): ), } + @classmethod + def adjust_engine_params( + cls, + uri: URL, + connect_args: Dict[str, Any], + schema: Optional[str] = None, + ) -> Tuple[URL, Dict[str, Any]]: + if not schema: + return uri, connect_args + + options = dict( + [ + tuple(token.strip() for token in option.strip().split("=", 1)) + for option in re.split(r"-c\s?", connect_args.get("options", "")) + if "=" in option + ] + ) + options["search_path"] = schema + connect_args["options"] = " ".join( + f"-c{key}={value}" for key, value in options.items() + ) + + return uri, connect_args + @classmethod def get_schema_from_engine_params( cls, @@ -166,19 +192,23 @@ def get_schema_from_engine_params( to determine the schema for a non-qualified table in a query. In cases like that we raise an exception. 
""" - options = re.split(r"-c\s?", connect_args.get("options", "")) - for option in options: - if "=" not in option: - continue - key, value = option.strip().split("=", 1) - if key.strip() == "search_path": - if "," in value: - raise Exception( - "Multiple schemas are configured in the search path, which means " - "Superset is unable to determine the schema of unqualified table " - "names and enforce permissions." - ) - return value.strip() + options = dict( + [ + tuple(token.strip() for token in option.strip().split("=", 1)) + for option in re.split(r"-c\s?", connect_args.get("options", "")) + if "=" in option + ] + ) + + if search_path := options.get("search_path"): + schemas = search_path.split(",") + if len(schemas) > 1: + raise Exception( + "Multiple schemas are configured in the search path, which means " + "Superset is unable to determine the schema of unqualified table " + "names and enforce permissions." + ) + return schemas[0] return None diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index dd7bd88cdb478..4004483abf52f 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -301,19 +301,22 @@ def epoch_to_dttm(cls) -> str: return "from_unixtime({col})" @classmethod - def adjust_database_uri( - cls, uri: URL, selected_schema: Optional[str] = None - ) -> URL: + def adjust_engine_params( + cls, + uri: URL, + connect_args: Dict[str, Any], + schema: Optional[str] = None, + ) -> Tuple[URL, Dict[str, Any]]: database = uri.database - if selected_schema and database: - selected_schema = parse.quote(selected_schema, safe="") + if schema and database: + schema = parse.quote(schema, safe="") if "/" in database: - database = database.split("/")[0] + "/" + selected_schema + database = database.split("/")[0] + "/" + schema else: - database += "/" + selected_schema + database += "/" + schema uri = uri.set(database=database) - return uri + return uri, connect_args @classmethod def 
get_schema_from_engine_params( diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index ba15eea7fb508..f3bff7021cf6b 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -135,17 +135,20 @@ def get_extra_params(database: "Database") -> Dict[str, Any]: return extra @classmethod - def adjust_database_uri( - cls, uri: URL, selected_schema: Optional[str] = None - ) -> URL: + def adjust_engine_params( + cls, + uri: URL, + connect_args: Dict[str, Any], + schema: Optional[str] = None, + ) -> Tuple[URL, Dict[str, Any]]: database = uri.database - if "/" in uri.database: - database = uri.database.split("/")[0] - if selected_schema: - selected_schema = parse.quote(selected_schema, safe="") - uri = uri.set(database=f"{database}/{selected_schema}") + if "/" in database: + database = database.split("/")[0] + if schema: + schema = parse.quote(schema, safe="") + uri = uri.set(database=f"{database}/{schema}") - return uri + return uri, connect_args @classmethod def get_schema_from_engine_params( diff --git a/superset/models/core.py b/superset/models/core.py index 5717726edca96..45973e44d9061 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -421,32 +421,40 @@ def _get_sqla_engine( source: Optional[utils.QuerySource] = None, sqlalchemy_uri: Optional[str] = None, ) -> Engine: - extra = self.get_extra() sqlalchemy_url = make_url_safe( sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted ) self.db_engine_spec.validate_database_uri(sqlalchemy_url) - sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema) + extra = self.get_extra() + params = extra.get("engine_params", {}) + if nullpool: + params["poolclass"] = NullPool + connect_args = params.get("connect_args", {}) + + sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params( + sqlalchemy_url, + connect_args, + schema, + ) effective_username = 
self.get_effective_user(sqlalchemy_url) # If using MySQL or Presto for example, will set url.username # If using Hive, will not do anything yet since that relies on a # configuration parameter instead. sqlalchemy_url = self.db_engine_spec.get_url_for_impersonation( - sqlalchemy_url, self.impersonate_user, effective_username + sqlalchemy_url, + self.impersonate_user, + effective_username, ) masked_url = self.get_password_masked_url(sqlalchemy_url) logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url)) - params = extra.get("engine_params", {}) - if nullpool: - params["poolclass"] = NullPool - - connect_args = params.get("connect_args", {}) if self.impersonate_user: self.db_engine_spec.update_impersonation_config( - connect_args, str(sqlalchemy_url), effective_username + connect_args, + str(sqlalchemy_url), + effective_username, ) if connect_args: @@ -464,7 +472,11 @@ def _get_sqla_engine( source = utils.QuerySource.SQL_LAB sqlalchemy_url, params = DB_CONNECTION_MUTATOR( - sqlalchemy_url, params, effective_username, security_manager, source + sqlalchemy_url, + params, + effective_username, + security_manager, + source, ) try: return create_engine(sqlalchemy_url, **params) diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py index e57e6a6f8e23e..828f155cc9208 100644 --- a/tests/unit_tests/db_engine_specs/test_postgres.py +++ b/tests/unit_tests/db_engine_specs/test_postgres.py @@ -131,3 +131,26 @@ def test_get_schema_from_engine_params() -> None: "Superset is unable to determine the schema of unqualified table " "names and enforce permissions." ) + + +def test_adjust_engine_params() -> None: + """ + Test the ``adjust_engine_params`` method. 
+ """ + from superset.db_engine_specs.postgres import PostgresEngineSpec + + uri = make_url("postgres://user:password@host/catalog") + + assert PostgresEngineSpec.adjust_engine_params(uri, {}, "secret") == ( + uri, + {"options": "-csearch_path=secret"}, + ) + + assert PostgresEngineSpec.adjust_engine_params( + uri, + {"foo": "bar", "options": "-csearch_path=default -c debug=1"}, + "secret", + ) == ( + uri, + {"foo": "bar", "options": "-csearch_path=secret -cdebug=1"}, + ) From c4c78202122fb5fe89eafb2ab74a8753e726540f Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 17 Mar 2023 08:39:30 -0700 Subject: [PATCH 2/4] Add catalog --- superset/db_engine_specs/base.py | 30 ++++++++++--------- superset/db_engine_specs/drill.py | 1 + superset/db_engine_specs/hive.py | 1 + superset/db_engine_specs/mysql.py | 1 + superset/db_engine_specs/postgres.py | 1 + superset/db_engine_specs/presto.py | 1 + superset/db_engine_specs/snowflake.py | 1 + superset/models/core.py | 24 +++++++++++++-- .../db_engine_specs/test_postgres.py | 3 +- 9 files changed, 45 insertions(+), 18 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index efe7d68da2182..1e57996998183 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -472,7 +472,7 @@ def get_default_schema_for_query( Determining the correct schema is crucial for managing access to data, so please make sure you understand this logic when working on a new DB engine spec. """ - # default schema varies on a per-query basis + # dynamic schema varies on a per-query basis if cls.supports_dynamic_schema: return query.schema @@ -1061,25 +1061,27 @@ def adjust_engine_params( # pylint: disable=unused-argument cls, uri: URL, connect_args: Dict[str, Any], + catalog: Optional[str], schema: Optional[str], ) -> Tuple[URL, Dict[str, Any]]: """ - Return a modified URL with a new database component. + Return a new URL and ``connect_args`` for a specific catalog/schema. 
- The URI here represents the URI as entered when saving the database, - ``selected_schema`` is the schema currently active presumably in - the SQL Lab dropdown. Based on that, for some database engine, - we can return a new altered URI that connects straight to the - active schema, meaning the users won't have to prefix the object - names by the schema name. + This is used in SQL Lab, allowing users to select a schema from the list of + schemas available in a given database, and have the query run with that schema as + the default one. - Some databases engines have 2 level of namespacing: database and - schema (postgres, oracle, mssql, ...) - For those it's probably better to not alter the database - component of the URI with the schema name, it won't work. + For some databases (like MySQL, Presto, Snowflake) this requires modifying the + SQLAlchemy URI before creating the connection. For others (like Postgres), it + requires additional parameters in ``connect_args``. - Some database drivers like Presto accept '{catalog}/{schema}' in - the database component of the URL, that can be handled here. + When a DB engine spec implements this method it should also have the attribute + ``supports_dynamic_schema`` set to true, so that Superset knows in which schema a + given query is running in order to enforce permissions (see #23385 and #23401). + + Currently, changing the catalog is not supported. The method accepts a catalog so + that when catalog support is added to Superset the interface remains the same. This + is important because DB engine specs can be installed from 3rd party packages. 
""" return uri, connect_args diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py index 5f7e9baf51170..ec7788ead2d07 100644 --- a/superset/db_engine_specs/drill.py +++ b/superset/db_engine_specs/drill.py @@ -75,6 +75,7 @@ def adjust_engine_params( cls, uri: URL, connect_args: Dict[str, Any], + catalog: Optional[str], schema: Optional[str], ) -> Tuple[URL, Dict[str, Any]]: if schema: diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index f70ff2db03dfd..792ef947350a1 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -264,6 +264,7 @@ def adjust_engine_params( cls, uri: URL, connect_args: Dict[str, Any], + catalog: Optional[str] = None, schema: Optional[str] = None, ) -> Tuple[URL, Dict[str, Any]]: if schema: diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 572db3aab14a4..e5ff964f868da 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -195,6 +195,7 @@ def adjust_engine_params( cls, uri: URL, connect_args: Dict[str, Any], + catalog: Optional[str] = None, schema: Optional[str] = None, ) -> Tuple[URL, Dict[str, Any]]: if schema: diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 3999c98e64967..7555b660116b5 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -154,6 +154,7 @@ def adjust_engine_params( cls, uri: URL, connect_args: Dict[str, Any], + catalog: Optional[str] = None, schema: Optional[str] = None, ) -> Tuple[URL, Dict[str, Any]]: if not schema: diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 4004483abf52f..c0b4f2c6dd881 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -305,6 +305,7 @@ def adjust_engine_params( cls, uri: URL, connect_args: Dict[str, Any], + catalog: Optional[str] = None, schema: Optional[str] = None, 
) -> Tuple[URL, Dict[str, Any]]: database = uri.database diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index f3bff7021cf6b..033b637e48d1d 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -139,6 +139,7 @@ def adjust_engine_params( cls, uri: URL, connect_args: Dict[str, Any], + catalog: Optional[str] = None, schema: Optional[str] = None, ) -> Tuple[URL, Dict[str, Any]]: database = uri.database diff --git a/superset/models/core.py b/superset/models/core.py index 45973e44d9061..d7a38cdc033a5 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -432,11 +432,29 @@ def _get_sqla_engine( params["poolclass"] = NullPool connect_args = params.get("connect_args", {}) + # The ``adjust_database_uri`` method was renamed to ``adjust_engine_params`` and + # had its signature changed in order to support more DB engine specs. Since DB + # engine specs can be released as 3rd party modules we want to make sure the old + # method is still supported so we don't introduce a breaking change. + if hasattr(self.db_engine_spec, "adjust_database_uri"): + sqlalchemy_url = self.db_engine_spec.adjust_database_uri( + sqlalchemy_url, + schema, + ) + logger.warning( + "DB engine spec %s implements the method `adjust_database_uri`, which is " + "deprecated and will be removed in version 3.0. 
Please update it to " + "implement `adjust_engine_params` instead.", + self.db_engine_spec, + ) + sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params( - sqlalchemy_url, - connect_args, - schema, + uri=sqlalchemy_url, + connect_args=connect_args, + catalog=None, + schema=schema, ) + effective_username = self.get_effective_user(sqlalchemy_url) # If using MySQL or Presto for example, will set url.username # If using Hive, will not do anything yet since that relies on a diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py index 828f155cc9208..fef864795962e 100644 --- a/tests/unit_tests/db_engine_specs/test_postgres.py +++ b/tests/unit_tests/db_engine_specs/test_postgres.py @@ -141,7 +141,7 @@ def test_adjust_engine_params() -> None: uri = make_url("postgres://user:password@host/catalog") - assert PostgresEngineSpec.adjust_engine_params(uri, {}, "secret") == ( + assert PostgresEngineSpec.adjust_engine_params(uri, {}, None, "secret") == ( uri, {"options": "-csearch_path=secret"}, ) @@ -149,6 +149,7 @@ def test_adjust_engine_params() -> None: assert PostgresEngineSpec.adjust_engine_params( uri, {"foo": "bar", "options": "-csearch_path=default -c debug=1"}, + None, "secret", ) == ( uri, From 0835da1db5cc7787fa6027fdbeaaec9e14f6efe9 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 17 Mar 2023 09:27:11 -0700 Subject: [PATCH 3/4] DRY options parsing --- superset/db_engine_specs/postgres.py | 33 +++++++++++++++------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 7555b660116b5..fac0b1b1d0483 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -72,6 +72,22 @@ SYNTAX_ERROR_REGEX = re.compile('syntax error at or near "(?P.*?)"') +def parse_options(connect_args: Dict[str, Any]) -> Dict[str, str]: + """ + Parse ``options`` from ``connect_args`` 
into a dictionary. + """ + if not isinstance(connect_args.get("options"), str): + return {} + + tokens = ( + tuple(token.strip() for token in option.strip().split("=", 1)) + for option in re.split(r"-c\s?", connect_args["options"]) + if "=" in option + ) + + return {token[0]: token[1] for token in tokens} + + class PostgresBaseEngineSpec(BaseEngineSpec): """Abstract class for Postgres 'like' databases""" @@ -160,13 +176,7 @@ def adjust_engine_params( if not schema: return uri, connect_args - options = dict( - [ - tuple(token.strip() for token in option.strip().split("=", 1)) - for option in re.split(r"-c\s?", connect_args.get("options", "")) - if "=" in option - ] - ) + options = parse_options(connect_args) options["search_path"] = schema connect_args["options"] = " ".join( f"-c{key}={value}" for key, value in options.items() @@ -193,14 +203,7 @@ def get_schema_from_engine_params( to determine the schema for a non-qualified table in a query. In cases like that we raise an exception. """ - options = dict( - [ - tuple(token.strip() for token in option.strip().split("=", 1)) - for option in re.split(r"-c\s?", connect_args.get("options", "")) - if "=" in option - ] - ) - + options = parse_options(connect_args) if search_path := options.get("search_path"): schemas = search_path.split(",") if len(schemas) > 1: From 2db69f2694bc487afc3e251f2f6284d7250e79c2 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 17 Mar 2023 16:52:29 -0700 Subject: [PATCH 4/4] Add default values to schema/catalog --- superset/db_engine_specs/base.py | 4 ++-- superset/db_engine_specs/drill.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 1e57996998183..b8b1662057e3f 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1061,8 +1061,8 @@ def adjust_engine_params( # pylint: disable=unused-argument cls, uri: URL, connect_args: Dict[str, Any], - catalog: 
Optional[str], - schema: Optional[str], + catalog: Optional[str] = None, + schema: Optional[str] = None, ) -> Tuple[URL, Dict[str, Any]]: """ Return a new URL and ``connect_args`` for a specific catalog/schema. diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py index ec7788ead2d07..16ac89212ad13 100644 --- a/superset/db_engine_specs/drill.py +++ b/superset/db_engine_specs/drill.py @@ -75,8 +75,8 @@ def adjust_engine_params( cls, uri: URL, connect_args: Dict[str, Any], - catalog: Optional[str], - schema: Optional[str], + catalog: Optional[str] = None, + schema: Optional[str] = None, ) -> Tuple[URL, Dict[str, Any]]: if schema: uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe=""))