Skip to content

Commit

Permalink
fix: search_path in RDS (apache#24739)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Jul 20, 2023
1 parent 24f6360 commit ff5373d
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 85 deletions.
33 changes: 28 additions & 5 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,22 +1082,45 @@ def adjust_engine_params( # pylint: disable=unused-argument
For some databases (like MySQL, Presto, Snowflake) this requires modifying the
SQLAlchemy URI before creating the connection. For others (like Postgres), it
requires additional parameters in ``connect_args``.
requires additional parameters in ``connect_args`` or running pre-session
queries with ``set`` parameters.
When a DB engine spec implements this method it should also have the attribute
``supports_dynamic_schema`` set to true, so that Superset knows in which schema a
given query is running in order to enforce permissions (see #23385 and #23401).
When a DB engine spec implements this method or ``get_prequeries`` (see below) it
should also have the attribute ``supports_dynamic_schema`` set to true, so that
Superset knows in which schema a given query is running in order to enforce
permissions (see #23385 and #23401).
Currently, changing the catalog is not supported. The method accepts a catalog so
that when catalog support is added to Superset the interface remains the same.
This is important because DB engine specs can be installed from 3rd party
packages.
packages, so we want to keep these methods as stable as possible.
"""
return uri, {
**connect_args,
**cls.enforce_uri_query_params.get(uri.get_driver_name(), {}),
}

@classmethod
def get_prequeries(
cls,
catalog: str | None = None, # pylint: disable=unused-argument
schema: str | None = None, # pylint: disable=unused-argument
) -> list[str]:
"""
Return pre-session queries.
These are currently used as an alternative to ``adjust_engine_params`` for
databases where the selected schema cannot be specified in the SQLAlchemy URI or
connection arguments.
For example, in order to specify a default schema in RDS we need to run a query
at the beginning of the session:
sql> set search_path = my_schema;
This base implementation ignores ``catalog`` and ``schema`` and returns an empty
list, i.e. no pre-session queries.
"""
return []

@classmethod
def patch(cls) -> None:
"""
Expand Down
88 changes: 60 additions & 28 deletions superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import json
import logging
import re
from datetime import datetime
from re import Pattern
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, TYPE_CHECKING

import sqlparse
from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.dialects.postgresql.base import PGInspector
Expand All @@ -30,8 +34,8 @@

from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
from superset.errors import SupersetErrorType
from superset.exceptions import SupersetException
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.models.sql_lab import Query
from superset.utils import core as utils
from superset.utils.core import GenericDataType
Expand Down Expand Up @@ -169,9 +173,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
}

@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
) -> list[tuple[Any, ...]]:
def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]:
if not cursor.description:
return []
return super().fetch_data(cursor, limit)
Expand Down Expand Up @@ -224,7 +226,7 @@ def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
connect_args: dict[str, Any],
) -> Optional[str]:
) -> str | None:
"""
Return the configured schema.
Expand All @@ -237,6 +239,9 @@ def get_schema_from_engine_params(
in Superset because it breaks schema-level permissions, since it's impossible
to determine the schema for a non-qualified table in a query. In cases like
that we raise an exception.
Note that because the DB engine supports dynamic schema this method is never
called. It's left here as an implementation reference.
"""
options = parse_options(connect_args)
if search_path := options.get("search_path"):
Expand All @@ -252,23 +257,50 @@ def get_schema_from_engine_params(
return None

@classmethod
def get_default_schema_for_query(
    cls,
    database: Database,
    query: Query,
) -> str | None:
    """
    Return the default schema for a given query.

    This method simply uses the parent method after checking that there are no
    malicious path settings in the query.

    The check is a best-effort textual match performed on the SQL with comments
    stripped. It rejects every spelling PostgreSQL accepts for changing the
    search path:

        SET search_path = ...
        SET search_path TO ...
        SET SESSION search_path ... / SET LOCAL search_path ...
        SELECT set_config('search_path', ...)

    :param database: the database the query runs against; passed to the parent
    :param query: the query whose ``sql`` attribute is inspected
    :return: whatever the parent implementation returns
    :raises SupersetSecurityException: if the SQL tries to set the search path
    """
    # Strip comments first so the statement cannot be hidden behind, or split
    # by, comment markers.
    sql = sqlparse.format(query.sql, strip_comments=True)

    # ``SET`` accepts both ``=`` and ``TO`` and an optional SESSION/LOCAL
    # modifier; ``set_config`` is the function form of the same operation.
    search_path_override = re.compile(
        r"set\s+(?:(?:session|local)\s+)?search_path(?:\s*=|\s+to\b)"
        r"|set_config\s*\(\s*'search_path'",
        re.IGNORECASE,
    )
    if search_path_override.search(sql):
        raise SupersetSecurityException(
            SupersetError(
                error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
                message=__(
                    "Users are not allowed to set a search path for security reasons."
                ),
                level=ErrorLevel.ERROR,
            )
        )

    return super().get_default_schema_for_query(database, query)

return uri, connect_args
@classmethod
def get_prequeries(
    cls,
    catalog: str | None = None,
    schema: str | None = None,
) -> list[str]:
    """
    Set the search path to the specified schema.

    This is important for two reasons: in SQL Lab it will allow queries to run in
    the schema selected in the dropdown, resolving unqualified table names to the
    expected schema.

    But more importantly, in SQL Lab this is used to check if the user has access to
    any tables with unqualified names. If the schema is not set by SQL Lab it could
    be anything, and we would have to block users from running any queries
    referencing tables without an explicit schema.

    :param catalog: accepted for interface compatibility; not used here
    :param schema: schema to place on the search path; falsy means "no change"
    :return: a single ``set search_path`` statement, or an empty list when no
        schema was given
    """
    if not schema:
        return []

    # The schema is emitted as a double-quoted identifier. Inside such an
    # identifier a literal double quote must be written as two double quotes;
    # without this a schema name containing ``"`` could terminate the identifier
    # early and smuggle extra SQL into the pre-session query.
    quoted_schema = schema.replace('"', '""')
    return [f'set search_path = "{quoted_schema}"']

@classmethod
def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
Expand Down Expand Up @@ -298,7 +330,7 @@ def query_cost_formatter(
@classmethod
def get_catalog_names(
cls,
database: "Database",
database: Database,
inspector: Inspector,
) -> list[str]:
"""
Expand All @@ -318,7 +350,7 @@ def get_catalog_names(

@classmethod
def get_table_names(
cls, database: "Database", inspector: PGInspector, schema: Optional[str]
cls, database: Database, inspector: PGInspector, schema: str | None
) -> set[str]:
"""Need to consider foreign tables for PostgreSQL"""
return set(inspector.get_table_names(schema)) | set(
Expand All @@ -327,8 +359,8 @@ def get_table_names(

@classmethod
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, Date):
Expand All @@ -339,7 +371,7 @@ def convert_dttm(
return None

@staticmethod
def get_extra_params(database: "Database") -> dict[str, Any]:
def get_extra_params(database: Database) -> dict[str, Any]:
"""
For Postgres, the path to a SSL certificate is placed in `connect_args`.
Expand All @@ -363,7 +395,7 @@ def get_extra_params(database: "Database") -> dict[str, Any]:
return extra

@classmethod
def get_datatype(cls, type_code: Any) -> Optional[str]:
def get_datatype(cls, type_code: Any) -> str | None:
# pylint: disable=import-outside-toplevel
from psycopg2.extensions import binary_types, string_types

Expand All @@ -374,7 +406,7 @@ def get_datatype(cls, type_code: Any) -> Optional[str]:
return None

@classmethod
def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
def get_cancel_query_id(cls, cursor: Any, query: Query) -> str | None:
"""
Get Postgres PID that will be used to cancel all other running
queries in the same session.
Expand Down
Loading

0 comments on commit ff5373d

Please sign in to comment.