Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(postgresql): dynamic schema #23401

Merged
merged 4 commits into from
Mar 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 22 additions & 19 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
supports_file_upload = True

# Is the DB engine spec able to change the default schema? This requires implementing
# a custom `adjust_database_uri` method.
# a custom `adjust_engine_params` method.
supports_dynamic_schema = False

@classmethod
Expand Down Expand Up @@ -472,7 +472,7 @@ def get_default_schema_for_query(
Determining the correct schema is crucial for managing access to data, so please
make sure you understand this logic when working on a new DB engine spec.
"""
# default schema varies on a per-query basis
# dynamic schema varies on a per-query basis
if cls.supports_dynamic_schema:
return query.schema

Expand Down Expand Up @@ -1057,30 +1057,33 @@ def extract_errors(
]

@classmethod
def adjust_database_uri( # pylint: disable=unused-argument
def adjust_engine_params( # pylint: disable=unused-argument
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@betodealmeida what do you think will be the impact of this change on any custom db engine specs that have been added to Superset instances?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, nm, I see how you covered this below.

cls,
uri: URL,
selected_schema: Optional[str],
) -> URL:
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
"""
Return a modified URL with a new database component.
Return a new URL and ``connect_args`` for a specific catalog/schema.

This is used in SQL Lab, allowing users to select a schema from the list of
schemas available in a given database, and have the query run with that schema as
the default one.

The URI here represents the URI as entered when saving the database,
``selected_schema`` is the schema currently active presumably in
the SQL Lab dropdown. Based on that, for some database engine,
we can return a new altered URI that connects straight to the
active schema, meaning the users won't have to prefix the object
names by the schema name.
For some databases (like MySQL, Presto, Snowflake) this requires modifying the
SQLAlchemy URI before creating the connection. For others (like Postgres), it
requires additional parameters in ``connect_args``.

Some databases engines have 2 level of namespacing: database and
schema (postgres, oracle, mssql, ...)
For those it's probably better to not alter the database
component of the URI with the schema name, it won't work.
When a DB engine spec implements this method it should also have the attribute
``supports_dynamic_schema`` set to true, so that Superset knows in which schema a
given query is running in order to enforce permissions (see #23385 and #23401).

Some database drivers like Presto accept '{catalog}/{schema}' in
the database component of the URL, that can be handled here.
Currently, changing the catalog is not supported. The method acceps a catalog so
that when catalog support is added to Superse the interface remains the same. This
is important because DB engine specs can be installed from 3rd party packages.
"""
return uri
return uri, connect_args

@classmethod
def patch(cls) -> None:
Expand Down
18 changes: 11 additions & 7 deletions superset/db_engine_specs/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple
from urllib import parse

from sqlalchemy import types
Expand Down Expand Up @@ -71,13 +71,17 @@ def convert_dttm(
return None

@classmethod
def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL:
if selected_schema:
uri = uri.set(
database=parse.quote(selected_schema.replace(".", "/"), safe="")
)
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
if schema:
uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe=""))

return uri
return uri, connect_args

@classmethod
def get_schema_from_engine_params(
Expand Down
16 changes: 10 additions & 6 deletions superset/db_engine_specs/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,13 +260,17 @@ def convert_dttm(
return None

@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> URL:
if selected_schema:
uri = uri.set(database=parse.quote(selected_schema, safe=""))
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
if schema:
uri = uri.set(database=parse.quote(schema, safe=""))

return uri
return uri, connect_args

@classmethod
def get_schema_from_engine_params(
Expand Down
14 changes: 8 additions & 6 deletions superset/db_engine_specs/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,17 @@ def convert_dttm(
return None

@classmethod
def adjust_database_uri(
def adjust_engine_params(
cls,
uri: URL,
selected_schema: Optional[str] = None,
) -> URL:
if selected_schema:
uri = uri.set(database=parse.quote(selected_schema, safe=""))
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
if schema:
uri = uri.set(database=parse.quote(schema, safe=""))

return uri
return uri, connect_args

@classmethod
def get_schema_from_engine_params(
Expand Down
60 changes: 47 additions & 13 deletions superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,30 @@
SYNTAX_ERROR_REGEX = re.compile('syntax error at or near "(?P<syntax_error>.*?)"')


def parse_options(connect_args: Dict[str, Any]) -> Dict[str, str]:
"""
Parse ``options`` from ``connect_args`` into a dictionary.
"""
if not isinstance(connect_args.get("options"), str):
return {}

tokens = (
tuple(token.strip() for token in option.strip().split("=", 1))
for option in re.split(r"-c\s?", connect_args["options"])
if "=" in option
)

return {token[0]: token[1] for token in tokens}


class PostgresBaseEngineSpec(BaseEngineSpec):
"""Abstract class for Postgres 'like' databases"""

engine = ""
engine_name = "PostgreSQL"

supports_dynamic_schema = True

_time_grain_expressions = {
None: "{col}",
"PT1S": "DATE_TRUNC('second', {col})",
Expand Down Expand Up @@ -147,6 +165,25 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
),
}

@classmethod
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
if not schema:
return uri, connect_args

options = parse_options(connect_args)
options["search_path"] = schema
connect_args["options"] = " ".join(
f"-c{key}={value}" for key, value in options.items()
)

return uri, connect_args

@classmethod
def get_schema_from_engine_params(
cls,
Expand All @@ -166,19 +203,16 @@ def get_schema_from_engine_params(
to determine the schema for a non-qualified table in a query. In cases like
that we raise an exception.
"""
options = re.split(r"-c\s?", connect_args.get("options", ""))
for option in options:
if "=" not in option:
continue
key, value = option.strip().split("=", 1)
if key.strip() == "search_path":
if "," in value:
raise Exception(
"Multiple schemas are configured in the search path, which means "
"Superset is unable to determine the schema of unqualified table "
"names and enforce permissions."
)
return value.strip()
options = parse_options(connect_args)
if search_path := options.get("search_path"):
schemas = search_path.split(",")
if len(schemas) > 1:
raise Exception(
"Multiple schemas are configured in the search path, which means "
"Superset is unable to determine the schema of unqualified table "
"names and enforce permissions."
)
return schemas[0]

return None

Expand Down
20 changes: 12 additions & 8 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,19 +301,23 @@ def epoch_to_dttm(cls) -> str:
return "from_unixtime({col})"

@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> URL:
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
database = uri.database
if selected_schema and database:
selected_schema = parse.quote(selected_schema, safe="")
if schema and database:
schema = parse.quote(schema, safe="")
if "/" in database:
database = database.split("/")[0] + "/" + selected_schema
database = database.split("/")[0] + "/" + schema
else:
database += "/" + selected_schema
database += "/" + schema
uri = uri.set(database=database)

return uri
return uri, connect_args

@classmethod
def get_schema_from_engine_params(
Expand Down
22 changes: 13 additions & 9 deletions superset/db_engine_specs/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,21 @@ def get_extra_params(database: "Database") -> Dict[str, Any]:
return extra

@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> URL:
def adjust_engine_params(
cls,
uri: URL,
connect_args: Dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
database = uri.database
if "/" in uri.database:
database = uri.database.split("/")[0]
if selected_schema:
selected_schema = parse.quote(selected_schema, safe="")
uri = uri.set(database=f"{database}/{selected_schema}")
if "/" in database:
database = database.split("/")[0]
if schema:
schema = parse.quote(schema, safe="")
uri = uri.set(database=f"{database}/{schema}")

return uri
return uri, connect_args

@classmethod
def get_schema_from_engine_params(
Expand Down
50 changes: 40 additions & 10 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,32 +421,58 @@ def _get_sqla_engine(
source: Optional[utils.QuerySource] = None,
sqlalchemy_uri: Optional[str] = None,
) -> Engine:
extra = self.get_extra()
sqlalchemy_url = make_url_safe(
sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
)
self.db_engine_spec.validate_database_uri(sqlalchemy_url)

sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
extra = self.get_extra()
params = extra.get("engine_params", {})
if nullpool:
params["poolclass"] = NullPool
connect_args = params.get("connect_args", {})

# The ``adjust_database_uri`` method was renamed to ``adjust_engine_params`` and
# had its signature changed in order to support more DB engine specs. Since DB
# engine specs can be released as 3rd party modules we want to make sure the old
# method is still supported so we don't introduce a breaking change.
if hasattr(self.db_engine_spec, "adjust_database_uri"):
sqlalchemy_url = self.db_engine_spec.adjust_database_uri(
sqlalchemy_url,
schema,
)
logger.warning(
"DB engine spec %s implements the method `adjust_database_uri`, which is "
"deprecated and will be removed in version 3.0. Please update it to "
"implement `adjust_engine_params` instead.",
self.db_engine_spec,
)

sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params(
uri=sqlalchemy_url,
connect_args=connect_args,
catalog=None,
schema=schema,
)

effective_username = self.get_effective_user(sqlalchemy_url)
# If using MySQL or Presto for example, will set url.username
# If using Hive, will not do anything yet since that relies on a
# configuration parameter instead.
sqlalchemy_url = self.db_engine_spec.get_url_for_impersonation(
sqlalchemy_url, self.impersonate_user, effective_username
sqlalchemy_url,
self.impersonate_user,
effective_username,
)

masked_url = self.get_password_masked_url(sqlalchemy_url)
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))

params = extra.get("engine_params", {})
if nullpool:
params["poolclass"] = NullPool

connect_args = params.get("connect_args", {})
if self.impersonate_user:
self.db_engine_spec.update_impersonation_config(
connect_args, str(sqlalchemy_url), effective_username
connect_args,
str(sqlalchemy_url),
effective_username,
)

if connect_args:
Expand All @@ -464,7 +490,11 @@ def _get_sqla_engine(
source = utils.QuerySource.SQL_LAB

sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
sqlalchemy_url, params, effective_username, security_manager, source
sqlalchemy_url,
params,
effective_username,
security_manager,
source,
)
try:
return create_engine(sqlalchemy_url, **params)
Expand Down
Loading