diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 0d778de43987a..af24e54790b9c 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1082,22 +1082,45 @@ def adjust_engine_params( # pylint: disable=unused-argument For some databases (like MySQL, Presto, Snowflake) this requires modifying the SQLAlchemy URI before creating the connection. For others (like Postgres), it - requires additional parameters in ``connect_args``. + requires additional parameters in ``connect_args`` or running pre-session + queries with ``set`` parameters. - When a DB engine spec implements this method it should also have the attribute - ``supports_dynamic_schema`` set to true, so that Superset knows in which schema a - given query is running in order to enforce permissions (see #23385 and #23401). + When a DB engine spec implements this method or ``get_prequeries`` (see below) it + should also have the attribute ``supports_dynamic_schema`` set to true, so that + Superset knows in which schema a given query is running in order to enforce + permissions (see #23385 and #23401). Currently, changing the catalog is not supported. The method accepts a catalog so that when catalog support is added to Superset the interface remains the same. This is important because DB engine specs can be installed from 3rd party - packages. + packages, so we want to keep these methods as stable as possible. """ return uri, { **connect_args, **cls.enforce_uri_query_params.get(uri.get_driver_name(), {}), } + @classmethod + def get_prequeries( + cls, + catalog: str | None = None, # pylint: disable=unused-argument + schema: str | None = None, # pylint: disable=unused-argument + ) -> list[str]: + """ + Return pre-session queries. + + These are currently used as an alternative to ``adjust_engine_params`` for + databases where the selected schema cannot be specified in the SQLAlchemy URI or + connection arguments. + + For example, in order to specify a default schema in RDS we need to run a query + at the beggining of the session: + + sql> set search_path = my_schema; + + """ + return [] + @classmethod def patch(cls) -> None: """ diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index cdd71fdfccbcc..642f84f58cd61 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -14,13 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from __future__ import annotations + import json import logging import re from datetime import datetime from re import Pattern -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING +import sqlparse from flask_babel import gettext as __ from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON from sqlalchemy.dialects.postgresql.base import PGInspector @@ -30,8 +34,8 @@ from superset.constants import TimeGrain from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin -from superset.errors import SupersetErrorType -from superset.exceptions import SupersetException +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import SupersetException, SupersetSecurityException from superset.models.sql_lab import Query from superset.utils import core as utils from superset.utils.core import GenericDataType @@ -169,9 +173,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec): } @classmethod - def fetch_data( - cls, cursor: Any, limit: Optional[int] = None - ) -> list[tuple[Any, ...]]: + def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]: if not cursor.description: return [] return super().fetch_data(cursor, limit) @@ -224,7 +226,7 @@ def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, connect_args: dict[str, Any], - ) -> Optional[str]: + ) -> str | None: """ Return the configured schema. @@ -237,6 +239,9 @@ def get_schema_from_engine_params( in Superset because it breaks schema-level permissions, since it's impossible to determine the schema for a non-qualified table in a query. In cases like that we raise an exception. + + Note that because the DB engine supports dynamic schema this method is never + called. It's left here as an implementation reference. """ options = parse_options(connect_args) if search_path := options.get("search_path"): @@ -252,23 +257,50 @@ def get_schema_from_engine_params( return None @classmethod - def adjust_engine_params( + def get_default_schema_for_query( cls, - uri: URL, - connect_args: dict[str, Any], - catalog: Optional[str] = None, - schema: Optional[str] = None, - ) -> tuple[URL, dict[str, Any]]: - if not schema: - return uri, connect_args + database: Database, + query: Query, + ) -> str | None: + """ + Return the default schema for a given query. - options = parse_options(connect_args) - options["search_path"] = schema - connect_args["options"] = " ".join( - f"-c{key}={value}" for key, value in options.items() - ) + This method simply uses the parent method after checking that there are no + malicious path setting in the query. + """ + sql = sqlparse.format(query.sql, strip_comments=True) + if re.search(r"set\s+search_path\s*=", sql, re.IGNORECASE): + raise SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR, + message=__( + "Users are not allowed to set a search path for security reasons." + ), + level=ErrorLevel.ERROR, + ) + ) + + return super().get_default_schema_for_query(database, query) - return uri, connect_args + @classmethod + def get_prequeries( + cls, + catalog: str | None = None, + schema: str | None = None, + ) -> list[str]: + """ + Set the search path to the specified schema. + + This is important for two reasons: in SQL Lab it will allow queries to run in + the schema selected in the dropdown, resolving unqualified table names to the + expected schema. + + But more importantly, in SQL Lab this is used to check if the user has access to + any tables with unqualified names. If the schema is not set by SQL Lab it could + be anything, and we would have to block users from running any queries + referencing tables without an explicit schema. + """ + return [f'set search_path = "{schema}"'] if schema else [] @classmethod def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: @@ -298,7 +330,7 @@ def query_cost_formatter( @classmethod def get_catalog_names( cls, - database: "Database", + database: Database, inspector: Inspector, ) -> list[str]: """ @@ -318,7 +350,7 @@ def get_catalog_names( @classmethod def get_table_names( - cls, database: "Database", inspector: PGInspector, schema: Optional[str] + cls, database: Database, inspector: PGInspector, schema: str | None ) -> set[str]: """Need to consider foreign tables for PostgreSQL""" return set(inspector.get_table_names(schema)) | set( @@ -327,8 +359,8 @@ def get_table_names( @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, Date): @@ -339,7 +371,7 @@ def convert_dttm( return None @staticmethod - def get_extra_params(database: "Database") -> dict[str, Any]: + def get_extra_params(database: Database) -> dict[str, Any]: """ For Postgres, the path to a SSL certificate is placed in `connect_args`. @@ -363,7 +395,7 @@ def get_extra_params(database: "Database") -> dict[str, Any]: return extra @classmethod - def get_datatype(cls, type_code: Any) -> Optional[str]: + def get_datatype(cls, type_code: Any) -> str | None: # pylint: disable=import-outside-toplevel from psycopg2.extensions import binary_types, string_types @@ -374,7 +406,7 @@ def get_datatype(cls, type_code: Any) -> Optional[str]: return None @classmethod - def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]: + def get_cancel_query_id(cls, cursor: Any, query: Query) -> str | None: """ Get Postgres PID that will be used to cancel all other running queries in the same session. diff --git a/superset/models/core.py b/superset/models/core.py index 4ff56145e1bd5..a8fe8b541156c 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -16,6 +16,9 @@ # under the License. # pylint: disable=line-too-long,too-many-lines """A collection of ORM sqlalchemy models for Superset""" + +from __future__ import annotations + import builtins import enum import json @@ -26,7 +29,7 @@ from copy import deepcopy from datetime import datetime from functools import lru_cache -from typing import Any, Callable, Optional, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING import numpy import pandas as pd @@ -270,7 +273,7 @@ def driver(self) -> str: return self.url_object.get_driver_name() @property - def masked_encrypted_extra(self) -> Optional[str]: + def masked_encrypted_extra(self) -> str | None: return self.db_engine_spec.mask_encrypted_extra(self.encrypted_extra) @property @@ -315,7 +318,7 @@ def schema_cache_enabled(self) -> bool: return "schema_cache_timeout" in self.metadata_cache_timeout @property - def schema_cache_timeout(self) -> Optional[int]: + def schema_cache_timeout(self) -> int | None: return self.metadata_cache_timeout.get("schema_cache_timeout") @property @@ -323,7 +326,7 @@ def table_cache_enabled(self) -> bool: return "table_cache_timeout" in self.metadata_cache_timeout @property - def table_cache_timeout(self) -> Optional[int]: + def table_cache_timeout(self) -> int | None: return self.metadata_cache_timeout.get("table_cache_timeout") @property @@ -364,7 +367,7 @@ def set_sqlalchemy_uri(self, uri: str) -> None: conn = conn.set(password=PASSWORD_MASK if conn.password else None) self.sqlalchemy_uri = str(conn) # hides the password - def get_effective_user(self, object_url: URL) -> Optional[str]: + def get_effective_user(self, object_url: URL) -> str | None: """ Get the effective user, especially during impersonation. @@ -383,10 +386,10 @@ def get_effective_user(self, object_url: URL) -> Optional[str]: @contextmanager def get_sqla_engine_with_context( self, - schema: Optional[str] = None, + schema: str | None = None, nullpool: bool = True, - source: Optional[utils.QuerySource] = None, - override_ssh_tunnel: Optional["SSHTunnel"] = None, + source: utils.QuerySource | None = None, + override_ssh_tunnel: SSHTunnel | None = None, ) -> Engine: from superset.daos.database import ( # pylint: disable=import-outside-toplevel DatabaseDAO, @@ -425,10 +428,10 @@ def get_sqla_engine_with_context( def _get_sqla_engine( self, - schema: Optional[str] = None, + schema: str | None = None, nullpool: bool = True, - source: Optional[utils.QuerySource] = None, - sqlalchemy_uri: Optional[str] = None, + source: utils.QuerySource | None = None, + sqlalchemy_uri: str | None = None, ) -> Engine: sqlalchemy_url = make_url_safe( sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted @@ -513,17 +516,23 @@ def _get_sqla_engine( @contextmanager def get_raw_connection( self, - schema: Optional[str] = None, + schema: str | None = None, nullpool: bool = True, - source: Optional[utils.QuerySource] = None, + source: utils.QuerySource | None = None, ) -> Connection: with self.get_sqla_engine_with_context( schema=schema, nullpool=nullpool, source=source ) as engine: with closing(engine.raw_connection()) as conn: + # pre-session queries are used to set the selected schema and, in the + # future, the selected catalog + for prequery in self.db_engine_spec.get_prequeries(schema=schema): + cursor = conn.cursor() + cursor.execute(prequery) + yield conn - def get_default_schema_for_query(self, query: "Query") -> Optional[str]: + def get_default_schema_for_query(self, query: Query) -> str | None: """ Return the default schema for a given query. @@ -550,8 +559,8 @@ def get_reserved_words(self) -> set[str]: def get_df( # pylint: disable=too-many-locals self, sql: str, - schema: Optional[str] = None, - mutator: Optional[Callable[[pd.DataFrame], None]] = None, + schema: str | None = None, + mutator: Callable[[pd.DataFrame], None] | None = None, ) -> pd.DataFrame: sqls = self.db_engine_spec.parse_sql(sql) engine = self._get_sqla_engine(schema) @@ -614,7 +623,7 @@ def _log_query(sql: str) -> None: return df - def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str: + def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str: engine = self._get_sqla_engine(schema=schema) sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) @@ -628,12 +637,12 @@ def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str: def select_star( # pylint: disable=too-many-arguments self, table_name: str, - schema: Optional[str] = None, + schema: str | None = None, limit: int = 100, show_cols: bool = False, indent: bool = True, latest_partition: bool = False, - cols: Optional[list[ResultSetColumnType]] = None, + cols: list[ResultSetColumnType] | None = None, ) -> str: """Generates a ``select *`` statement in the proper dialect""" eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB) @@ -672,7 +681,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument self, schema: str, cache: bool = False, - cache_timeout: Optional[int] = None, + cache_timeout: int | None = None, force: bool = False, ) -> set[tuple[str, str]]: """Parameters need to be passed as keyword arguments. @@ -708,7 +717,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument self, schema: str, cache: bool = False, - cache_timeout: Optional[int] = None, + cache_timeout: int | None = None, force: bool = False, ) -> set[tuple[str, str]]: """Parameters need to be passed as keyword arguments. @@ -737,7 +746,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument @contextmanager def get_inspector_with_context( - self, ssh_tunnel: Optional["SSHTunnel"] = None + self, ssh_tunnel: SSHTunnel | None = None ) -> Inspector: with self.get_sqla_engine_with_context( override_ssh_tunnel=ssh_tunnel @@ -751,9 +760,9 @@ def get_inspector_with_context( def get_all_schema_names( # pylint: disable=unused-argument self, cache: bool = False, - cache_timeout: Optional[int] = None, + cache_timeout: int | None = None, force: bool = False, - ssh_tunnel: Optional["SSHTunnel"] = None, + ssh_tunnel: SSHTunnel | None = None, ) -> list[str]: """Parameters need to be passed as keyword arguments. @@ -818,7 +827,7 @@ def get_encrypted_extra(self) -> dict[str, Any]: def update_params_from_encrypted_extra(self, params: dict[str, Any]) -> None: self.db_engine_spec.update_params_from_encrypted_extra(self, params) - def get_table(self, table_name: str, schema: Optional[str] = None) -> Table: + def get_table(self, table_name: str, schema: str | None = None) -> Table: extra = self.get_extra() meta = MetaData(**extra.get("metadata_params", {})) with self.get_sqla_engine_with_context() as engine: @@ -831,13 +840,13 @@ def get_table(self, table_name: str, schema: Optional[str] = None) -> Table: ) def get_table_comment( - self, table_name: str, schema: Optional[str] = None - ) -> Optional[str]: + self, table_name: str, schema: str | None = None + ) -> str | None: with self.get_inspector_with_context() as inspector: return self.db_engine_spec.get_table_comment(inspector, table_name, schema) def get_columns( - self, table_name: str, schema: Optional[str] = None + self, table_name: str, schema: str | None = None ) -> list[ResultSetColumnType]: with self.get_inspector_with_context() as inspector: return self.db_engine_spec.get_columns(inspector, table_name, schema) @@ -845,19 +854,19 @@ def get_columns( def get_metrics( self, table_name: str, - schema: Optional[str] = None, + schema: str | None = None, ) -> list[MetricType]: with self.get_inspector_with_context() as inspector: return self.db_engine_spec.get_metrics(self, inspector, table_name, schema) def get_indexes( - self, table_name: str, schema: Optional[str] = None + self, table_name: str, schema: str | None = None ) -> list[dict[str, Any]]: with self.get_inspector_with_context() as inspector: return self.db_engine_spec.get_indexes(self, inspector, table_name, schema) def get_pk_constraint( - self, table_name: str, schema: Optional[str] = None + self, table_name: str, schema: str | None = None ) -> dict[str, Any]: with self.get_inspector_with_context() as inspector: pk_constraint = inspector.get_pk_constraint(table_name, schema) or {} @@ -871,7 +880,7 @@ def _convert(value: Any) -> Any: return {key: _convert(value) for key, value in pk_constraint.items()} def get_foreign_keys( - self, table_name: str, schema: Optional[str] = None + self, table_name: str, schema: str | None = None ) -> list[dict[str, Any]]: with self.get_inspector_with_context() as inspector: return inspector.get_foreign_keys(table_name, schema) @@ -926,7 +935,7 @@ def has_table(self, table: Table) -> bool: with self.get_sqla_engine_with_context() as engine: return engine.has_table(table.table_name, table.schema or None) - def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool: + def has_table_by_name(self, table_name: str, schema: str | None = None) -> bool: with self.get_sqla_engine_with_context() as engine: return engine.has_table(table_name, schema) @@ -936,7 +945,7 @@ def _has_view( conn: Connection, dialect: Dialect, view_name: str, - schema: Optional[str] = None, + schema: str | None = None, ) -> bool: view_names: list[str] = [] try: @@ -945,11 +954,11 @@ def _has_view( logger.warning("Has view failed", exc_info=True) return view_name in view_names - def has_view(self, view_name: str, schema: Optional[str] = None) -> bool: + def has_view(self, view_name: str, schema: str | None = None) -> bool: engine = self._get_sqla_engine() return engine.run_callable(self._has_view, engine.dialect, view_name, schema) - def has_view_by_name(self, view_name: str, schema: Optional[str] = None) -> bool: + def has_view_by_name(self, view_name: str, schema: str | None = None) -> bool: return self.has_view(view_name=view_name, schema=schema) def get_dialect(self) -> Dialect: @@ -957,7 +966,7 @@ def get_dialect(self) -> Dialect: return sqla_url.get_dialect()() def make_sqla_column_compatible( - self, sqla_col: ColumnElement, label: Optional[str] = None + self, sqla_col: ColumnElement, label: str | None = None ) -> ColumnElement: """Takes a sqlalchemy column object and adds label info if supported by engine. :param sqla_col: sqlalchemy column instance diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py index 145d398898d13..59d1829f142d1 100644 --- a/tests/unit_tests/db_engine_specs/test_postgres.py +++ b/tests/unit_tests/db_engine_specs/test_postgres.py @@ -19,10 +19,12 @@ from typing import Any, Optional import pytest +from pytest_mock import MockFixture from sqlalchemy import types from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON from sqlalchemy.engine.url import make_url +from superset.exceptions import SupersetSecurityException from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( assert_column_spec, @@ -133,25 +135,41 @@ def test_get_schema_from_engine_params() -> None: ) -def test_adjust_engine_params() -> None: +def test_get_prequeries() -> None: """ - Test the ``adjust_engine_params`` method. + Test the ``get_prequeries`` method. """ from superset.db_engine_specs.postgres import PostgresEngineSpec - uri = make_url("postgres://user:password@host/catalog") + assert PostgresEngineSpec.get_prequeries() == [] + assert PostgresEngineSpec.get_prequeries(schema="test") == [ + 'set search_path = "test"' + ] - assert PostgresEngineSpec.adjust_engine_params(uri, {}, None, "secret") == ( - uri, - {"options": "-csearch_path=secret"}, - ) - assert PostgresEngineSpec.adjust_engine_params( - uri, - {"foo": "bar", "options": "-csearch_path=default -c debug=1"}, - None, - "secret", - ) == ( - uri, - {"foo": "bar", "options": "-csearch_path=secret -cdebug=1"}, +def test_get_default_schema_for_query(mocker: MockFixture) -> None: + """ + Test the ``get_default_schema_for_query`` method. + """ + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = mocker.MagicMock() + query = mocker.MagicMock() + + query.sql = "SELECT * FROM some_table" + query.schema = "foo" + assert PostgresEngineSpec.get_default_schema_for_query(database, query) == "foo" + + query.sql = """ +set +-- this is a tricky comment +search_path -- another one += bar; +SELECT * FROM some_table; + """ + with pytest.raises(SupersetSecurityException) as excinfo: + PostgresEngineSpec.get_default_schema_for_query(database, query) + assert ( + str(excinfo.value) + == "Users are not allowed to set a search path for security reasons." ) diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 267b7c024aae5..5d6c1fcbccbc5 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -212,3 +212,21 @@ def test_dttm_sql_literal( def test_table_column_database() -> None: database = Database(database_name="db") assert TableColumn(database=database).database is database + + +def test_get_prequeries(mocker: MockFixture) -> None: + """ + Tests for ``get_prequeries``. + """ + mocker.patch.object( + Database, + "get_sqla_engine_with_context", + ) + db_engine_spec = mocker.patch.object(Database, "db_engine_spec") + db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"] + + database = Database(database_name="db") + with database.get_raw_connection() as conn: + conn.cursor().execute.assert_has_calls( + [mocker.call("set a=1"), mocker.call("set b=2")] + )