Skip to content

Commit

Permalink
feat(postgresql): dynamic schema
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Mar 17, 2023
1 parent 42e8d1b commit c2b2644
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 64 deletions.
11 changes: 6 additions & 5 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
supports_file_upload = True

# Is the DB engine spec able to change the default schema? This requires implementing
# a custom `adjust_database_uri` method.
# a custom `adjust_engine_params` method.
supports_dynamic_schema = False

@classmethod
Expand Down Expand Up @@ -1057,11 +1057,12 @@ def extract_errors(
]

@classmethod
def adjust_database_uri( # pylint: disable=unused-argument
def adjust_engine_params( # pylint: disable=unused-argument
cls,
uri: URL,
selected_schema: Optional[str],
) -> URL:
connect_args: Dict[str, Any],
schema: Optional[str],
) -> Tuple[URL, Dict[str, Any]]:
"""
Return a modified URL with a new database component.
Expand All @@ -1080,7 +1081,7 @@ def adjust_database_uri( # pylint: disable=unused-argument
Some database drivers like Presto accept '{catalog}/{schema}' in
the database component of the URL, that can be handled here.
"""
return uri
return uri, connect_args

@classmethod
def patch(cls) -> None:
Expand Down
17 changes: 10 additions & 7 deletions superset/db_engine_specs/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple
from urllib import parse

from sqlalchemy import types
Expand Down Expand Up @@ -71,13 +71,16 @@ def convert_dttm(
return None

@classmethod
def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL:
if selected_schema:
uri = uri.set(
database=parse.quote(selected_schema.replace(".", "/"), safe="")
)
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
schema: Optional[str],
) -> Tuple[URL, Dict[str, Any]]:
if schema:
uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe=""))

return uri
return uri, connect_args

@classmethod
def get_schema_from_engine_params(
Expand Down
15 changes: 9 additions & 6 deletions superset/db_engine_specs/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,13 +260,16 @@ def convert_dttm(
return None

@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> URL:
if selected_schema:
uri = uri.set(database=parse.quote(selected_schema, safe=""))
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
if schema:
uri = uri.set(database=parse.quote(schema, safe=""))

return uri
return uri, connect_args

@classmethod
def get_schema_from_engine_params(
Expand Down
13 changes: 7 additions & 6 deletions superset/db_engine_specs/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,16 @@ def convert_dttm(
return None

@classmethod
def adjust_database_uri(
def adjust_engine_params(
cls,
uri: URL,
selected_schema: Optional[str] = None,
) -> URL:
if selected_schema:
uri = uri.set(database=parse.quote(selected_schema, safe=""))
connect_args: Dict[str, Any],
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
if schema:
uri = uri.set(database=parse.quote(schema, safe=""))

return uri
return uri, connect_args

@classmethod
def get_schema_from_engine_params(
Expand Down
56 changes: 43 additions & 13 deletions superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
engine = ""
engine_name = "PostgreSQL"

supports_dynamic_schema = True

_time_grain_expressions = {
None: "{col}",
"PT1S": "DATE_TRUNC('second', {col})",
Expand Down Expand Up @@ -147,6 +149,30 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
),
}

@classmethod
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
if not schema:
return uri, connect_args

options = dict(
[
tuple(token.strip() for token in option.strip().split("=", 1))
for option in re.split(r"-c\s?", connect_args.get("options", ""))
if "=" in option
]
)
options["search_path"] = schema
connect_args["options"] = " ".join(
f"-c{key}={value}" for key, value in options.items()
)

return uri, connect_args

@classmethod
def get_schema_from_engine_params(
cls,
Expand All @@ -166,19 +192,23 @@ def get_schema_from_engine_params(
to determine the schema for a non-qualified table in a query. In cases like
that we raise an exception.
"""
options = re.split(r"-c\s?", connect_args.get("options", ""))
for option in options:
if "=" not in option:
continue
key, value = option.strip().split("=", 1)
if key.strip() == "search_path":
if "," in value:
raise Exception(
"Multiple schemas are configured in the search path, which means "
"Superset is unable to determine the schema of unqualified table "
"names and enforce permissions."
)
return value.strip()
options = dict(
[
tuple(token.strip() for token in option.strip().split("=", 1))
for option in re.split(r"-c\s?", connect_args.get("options", ""))
if "=" in option
]
)

if search_path := options.get("search_path"):
schemas = search_path.split(",")
if len(schemas) > 1:
raise Exception(
"Multiple schemas are configured in the search path, which means "
"Superset is unable to determine the schema of unqualified table "
"names and enforce permissions."
)
return schemas[0]

return None

Expand Down
19 changes: 11 additions & 8 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,19 +301,22 @@ def epoch_to_dttm(cls) -> str:
return "from_unixtime({col})"

@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> URL:
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
database = uri.database
if selected_schema and database:
selected_schema = parse.quote(selected_schema, safe="")
if schema and database:
schema = parse.quote(schema, safe="")
if "/" in database:
database = database.split("/")[0] + "/" + selected_schema
database = database.split("/")[0] + "/" + schema
else:
database += "/" + selected_schema
database += "/" + schema
uri = uri.set(database=database)

return uri
return uri, connect_args

@classmethod
def get_schema_from_engine_params(
Expand Down
21 changes: 12 additions & 9 deletions superset/db_engine_specs/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,20 @@ def get_extra_params(database: "Database") -> Dict[str, Any]:
return extra

@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> URL:
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
database = uri.database
if "/" in uri.database:
database = uri.database.split("/")[0]
if selected_schema:
selected_schema = parse.quote(selected_schema, safe="")
uri = uri.set(database=f"{database}/{selected_schema}")
if "/" in database:
database = database.split("/")[0]
if schema:
schema = parse.quote(schema, safe="")
uri = uri.set(database=f"{database}/{schema}")

return uri
return uri, connect_args

@classmethod
def get_schema_from_engine_params(
Expand Down
32 changes: 22 additions & 10 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,32 +421,40 @@ def _get_sqla_engine(
source: Optional[utils.QuerySource] = None,
sqlalchemy_uri: Optional[str] = None,
) -> Engine:
extra = self.get_extra()
sqlalchemy_url = make_url_safe(
sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
)
self.db_engine_spec.validate_database_uri(sqlalchemy_url)

sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
extra = self.get_extra()
params = extra.get("engine_params", {})
if nullpool:
params["poolclass"] = NullPool
connect_args = params.get("connect_args", {})

sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params(
sqlalchemy_url,
connect_args,
schema,
)
effective_username = self.get_effective_user(sqlalchemy_url)
# If using MySQL or Presto for example, will set url.username
# If using Hive, will not do anything yet since that relies on a
# configuration parameter instead.
sqlalchemy_url = self.db_engine_spec.get_url_for_impersonation(
sqlalchemy_url, self.impersonate_user, effective_username
sqlalchemy_url,
self.impersonate_user,
effective_username,
)

masked_url = self.get_password_masked_url(sqlalchemy_url)
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))

params = extra.get("engine_params", {})
if nullpool:
params["poolclass"] = NullPool

connect_args = params.get("connect_args", {})
if self.impersonate_user:
self.db_engine_spec.update_impersonation_config(
connect_args, str(sqlalchemy_url), effective_username
connect_args,
str(sqlalchemy_url),
effective_username,
)

if connect_args:
Expand All @@ -464,7 +472,11 @@ def _get_sqla_engine(
source = utils.QuerySource.SQL_LAB

sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
sqlalchemy_url, params, effective_username, security_manager, source
sqlalchemy_url,
params,
effective_username,
security_manager,
source,
)
try:
return create_engine(sqlalchemy_url, **params)
Expand Down
23 changes: 23 additions & 0 deletions tests/unit_tests/db_engine_specs/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,26 @@ def test_get_schema_from_engine_params() -> None:
"Superset is unable to determine the schema of unqualified table "
"names and enforce permissions."
)


def test_adjust_engine_params() -> None:
"""
Test the ``adjust_engine_params`` method.
"""
from superset.db_engine_specs.postgres import PostgresEngineSpec

uri = make_url("postgres://user:password@host/catalog")

assert PostgresEngineSpec.adjust_engine_params(uri, {}, "secret") == (
uri,
{"options": "-csearch_path=secret"},
)

assert PostgresEngineSpec.adjust_engine_params(
uri,
{"foo": "bar", "options": "-csearch_path=default -c debug=1"},
"secret",
) == (
uri,
{"foo": "bar", "options": "-csearch_path=secret -cdebug=1"},
)

0 comments on commit c2b2644

Please sign in to comment.