From 8e04ad8a22dce50645199d5fad148dc9d15cbb3e Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Wed, 15 Mar 2023 11:34:11 -0700 Subject: [PATCH] fix: improve schema security --- superset/db_engine_specs/base.py | 64 ++++++++++++++++++- superset/db_engine_specs/drill.py | 17 ++++- superset/db_engine_specs/hive.py | 13 +++- superset/db_engine_specs/mysql.py | 19 +++++- superset/db_engine_specs/postgres.py | 36 +++++++++++ superset/db_engine_specs/presto.py | 23 ++++++- superset/db_engine_specs/snowflake.py | 18 +++++- superset/models/core.py | 17 +++++ superset/security/manager.py | 15 +---- .../unit_tests/db_engine_specs/test_drill.py | 16 +++++ tests/unit_tests/db_engine_specs/test_hive.py | 15 +++++ .../unit_tests/db_engine_specs/test_mysql.py | 14 ++++ .../db_engine_specs/test_postgres.py | 42 ++++++++++++ .../unit_tests/db_engine_specs/test_presto.py | 24 +++++++ .../db_engine_specs/test_snowflake.py | 32 ++++++++++ tests/unit_tests/explore/utils_test.py | 4 +- tests/unit_tests/security/manager_test.py | 3 +- 17 files changed, 346 insertions(+), 26 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 9d64ad8fb37ea..26dd169dc06ad 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -372,7 +372,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # Is the DB engine spec able to change the default schema? This requires implementing # a custom `adjust_database_uri` method. - dynamic_schema = False + supports_dynamic_schema = False @classmethod def supports_url(cls, url: URL) -> bool: @@ -426,6 +426,68 @@ def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool: return driver in cls.drivers + @classmethod + def get_default_schema(cls, database: Database) -> Optional[str]: + """ + Return the default schema in a given database. 
+ """ + with database.get_inspector_with_context() as inspector: + return inspector.default_schema_name + + @classmethod + def get_schema_from_engine_params( # pylint: disable=unused-argument + cls, + sqlalchemy_uri: URL, + connect_args: Dict[str, Any], + ) -> Optional[str]: + """ + Return the schema configured in a SQLAlchemy URI and connection arguments, if any. + """ + return None + + @classmethod + def get_default_schema_for_query( + cls, + database: Database, + query: Query, + ) -> Optional[str]: + """ + Return the default schema for a given query. + + This is used to determine the schema of tables that aren't fully qualified, eg: + + SELECT * FROM foo; + + In the example above, the schema where the `foo` table lives depends on a few + factors: + + 1. For DB engine specs that allow dynamically changing the schema based on the + query we should use the query schema. + 2. For DB engine specs that don't support dynamically changing the schema and + have the schema hardcoded in the SQLAlchemy URI we should use the schema + from the URI. + 3. For DB engine specs that don't connect to a specific schema and can't + change it dynamically we need to probe the database for the default schema. + + Determining the correct schema is crucial for managing access to data, so please + make sure you understand this logic when working on a new DB engine spec. 
+ """ + # default schema varies on a per-query basis + if cls.supports_dynamic_schema: + return query.schema + + # check if the schema is stored in the SQLAlchemy URI or connection arguments + try: + connect_args = database.get_extra()["engine_params"]["connect_args"] + except KeyError: + connect_args = {} + sqlalchemy_uri = make_url_safe(database.sqlalchemy_uri) + if schema := cls.get_schema_from_engine_params(sqlalchemy_uri, connect_args): + return schema + + # return the default schema of the database + return cls.get_default_schema(database) + @classmethod def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: """ diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py index f14bdd79f0b74..4ae5ae59b301e 100644 --- a/superset/db_engine_specs/drill.py +++ b/superset/db_engine_specs/drill.py @@ -32,7 +32,7 @@ class DrillEngineSpec(BaseEngineSpec): engine_name = "Apache Drill" default_driver = "sadrill" - dynamic_schema = True + supports_dynamic_schema = True _time_grain_expressions = { None: "{col}", @@ -73,10 +73,23 @@ def convert_dttm( @classmethod def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL: if selected_schema: - uri = uri.set(database=parse.quote(selected_schema, safe="")) + uri = uri.set( + database=parse.quote(selected_schema.replace(".", "/"), safe="") + ) return uri + @classmethod + def get_schema_from_engine_params( + cls, + sqlalchemy_uri: URL, + connect_args: Dict[str, Any], + ) -> Optional[str]: + """ + Return the configured schema. 
+ """ + return parse.unquote(sqlalchemy_uri.database).replace("/", ".") + @classmethod def get_url_for_impersonation( cls, url: URL, impersonate_user: bool, username: Optional[str] diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 8c69ab3fc7643..f90d889f8cc5a 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -98,7 +98,7 @@ class HiveEngineSpec(PrestoEngineSpec): allows_alias_to_source_column = True allows_hidden_orderby_agg = False - dynamic_schema = True + supports_dynamic_schema = True # When running `SHOW FUNCTIONS`, what is the name of the column with the # function names? @@ -268,6 +268,17 @@ def adjust_database_uri( return uri + @classmethod + def get_schema_from_engine_params( + cls, + sqlalchemy_uri: URL, + connect_args: Dict[str, Any], + ) -> Optional[str]: + """ + Return the configured schema. + """ + return parse.unquote(sqlalchemy_uri.database) + @classmethod def _extract_error_message(cls, ex: Exception) -> str: msg = str(ex) diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 75c1c697892f2..04b8c68dd7503 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -69,7 +69,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): ) encryption_parameters = {"ssl": "1"} - dynamic_schema = True + supports_dynamic_schema = True column_type_mappings = ( ( @@ -192,13 +192,28 @@ def convert_dttm( @classmethod def adjust_database_uri( - cls, uri: URL, selected_schema: Optional[str] = None + cls, + uri: URL, + selected_schema: Optional[str] = None, ) -> URL: if selected_schema: uri = uri.set(database=parse.quote(selected_schema, safe="")) return uri + @classmethod + def get_schema_from_engine_params( + cls, + sqlalchemy_uri: URL, + connect_args: Dict[str, Any], + ) -> Optional[str]: + """ + Return the configured schema. + + A MySQL database is a SQLAlchemy schema. 
+ """ + return parse.unquote(sqlalchemy_uri.database) + @classmethod def get_datatype(cls, type_code: Any) -> Optional[str]: if not cls.type_code_map: diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index cbe00ea58dfc6..84ddf56e10f00 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -23,6 +23,7 @@ from flask_babel import gettext as __ from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON from sqlalchemy.dialects.postgresql.base import PGInspector +from sqlalchemy.engine.url import URL from sqlalchemy.types import Date, DateTime, String from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin @@ -146,6 +147,41 @@ class PostgresBaseEngineSpec(BaseEngineSpec): ), } + @classmethod + def get_schema_from_engine_params( + cls, + sqlalchemy_uri: URL, + connect_args: Dict[str, Any], + ) -> Optional[str]: + """ + Return the configured schema. + + While Postgres doesn't support connecting directly to a given schema, it allows + users to specify a "search path" that is used to resolve non-qualified table + names; this can be specified in the database ``connect_args``. + + One important detail is that the search path can be a comma separated list of + schemas. While this is supported by the SQLAlchemy dialect, it shouldn't be used + in Superset because it breaks schema-level permissions, since it's impossible + to determine the schema for a non-qualified table in a query. In cases like + that we raise an exception. + """ + options = re.split(r"-c\s?", connect_args.get("options", "")) + for option in options: + if "=" not in option: + continue + key, value = option.strip().split("=", 1) + if key.strip() == "search_path": + if "," in value: + raise Exception( + "Multiple schemas are configured in the search path, which means " + "Superset is unable to determine the schema of unqualified table " + "names and enforce permissions." 
+ ) + return value.strip() + + return None + @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index cda946ec4db63..dd7bd88cdb478 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -165,7 +165,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): A base class that share common functions between Presto and Trino """ - dynamic_schema = True + supports_dynamic_schema = True column_type_mappings = ( ( @@ -315,6 +315,27 @@ def adjust_database_uri( return uri + @classmethod + def get_schema_from_engine_params( + cls, + sqlalchemy_uri: URL, + connect_args: Dict[str, Any], + ) -> Optional[str]: + """ + Return the configured schema. + + For Presto the SQLAlchemy URI looks like this: + + presto://localhost:8080/hive[/default] + + """ + database = sqlalchemy_uri.database.strip("/") + + if "/" not in database: + return None + + return parse.unquote(database.split("/")[1]) + @classmethod def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: """ diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 38addb6e350e4..ba15eea7fb508 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -83,7 +83,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): default_driver = "snowflake" sqlalchemy_uri_placeholder = "snowflake://" - dynamic_schema = True + supports_dynamic_schema = True _time_grain_expressions = { None: "{col}", @@ -147,6 +147,22 @@ def adjust_database_uri( return uri + @classmethod + def get_schema_from_engine_params( + cls, + sqlalchemy_uri: URL, + connect_args: Dict[str, Any], + ) -> Optional[str]: + """ + Return the configured schema. 
+ """ + database = sqlalchemy_uri.database.strip("/") + + if "/" not in database: + return None + + return parse.unquote(database.split("/")[1]) + @classmethod def epoch_to_dttm(cls) -> str: return "DATEADD(S, {col}, '1970-01-01')" diff --git a/superset/models/core.py b/superset/models/core.py index 9c67a2efa6d2b..5717726edca96 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -78,6 +78,7 @@ if TYPE_CHECKING: from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.sql_lab import Query DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"] @@ -483,6 +484,22 @@ def get_raw_connection( with closing(engine.raw_connection()) as conn: yield conn + def get_default_schema_for_query(self, query: "Query") -> Optional[str]: + """ + Return the default schema for a given query. + + This is used to determine if the user has access to a query that reads from table + names without a specific schema, eg: + + SELECT * FROM `foo` + + The schema of the `foo` table depends on the DB engine spec. Some DB engine specs + can change the default schema on a per-query basis; in other DB engine specs the + default schema is defined in the SQLAlchemy URI; and in others the default schema + might be determined by the database itself (like `public` for Postgres). 
+ """ + return self.db_engine_spec.get_default_schema_for_query(self, query) + @property def quote_identifier(self) -> Callable[[str], str]: """Add quotes to potential identifiter expressions if needed""" diff --git a/superset/security/manager.py b/superset/security/manager.py index 9c154c6498706..5aa5080294f73 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1787,7 +1787,7 @@ def get_exclude_users_from_lists() -> List[str]: return [] def raise_for_access( - # pylint: disable=too-many-arguments, too-many-locals, too-many-branches + # pylint: disable=too-many-arguments, too-many-locals self, database: Optional["Database"] = None, datasource: Optional["BaseDatasource"] = None, @@ -1823,18 +1823,7 @@ def raise_for_access( return if query: - # Some databases can change the default schema in which the query wil run, - # respecting the selection in SQL Lab. If that's the case, the query - # schema becomes the default one. - if database.db_engine_spec.dynamic_schema: - default_schema = query.schema - # For other databases, the selected schema in SQL Lab is used only for - # table discovery and autocomplete. In this case we need to use the - # database default schema for tables that don't have an explicit schema. 
- else: - with database.get_inspector_with_context() as inspector: - default_schema = inspector.default_schema_name - + default_schema = database.get_default_schema_for_query(query) tables = { Table(table_.table, table_.schema or default_schema) for table_ in sql_parse.ParsedQuery(query.sql).tables diff --git a/tests/unit_tests/db_engine_specs/test_drill.py b/tests/unit_tests/db_engine_specs/test_drill.py index e56df5d47cc4b..2930e0e81ea42 100644 --- a/tests/unit_tests/db_engine_specs/test_drill.py +++ b/tests/unit_tests/db_engine_specs/test_drill.py @@ -20,6 +20,7 @@ from typing import Optional import pytest +from sqlalchemy.engine.url import make_url from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm @@ -106,3 +107,18 @@ def test_convert_dttm( from superset.db_engine_specs.drill import DrillEngineSpec as spec assert_convert_dttm(spec, target_type, expected_result, dttm) + + +def test_get_schema_from_engine_params() -> None: + """ + Test ``get_schema_from_engine_params``. 
+ """ + from superset.db_engine_specs.drill import DrillEngineSpec + + assert ( + DrillEngineSpec.get_schema_from_engine_params( + make_url("drill+sadrill://localhost:8047/dfs/test?use_ssl=False"), + {}, + ) + == "dfs.test" + ) diff --git a/tests/unit_tests/db_engine_specs/test_hive.py b/tests/unit_tests/db_engine_specs/test_hive.py index 3a5cb91405bd4..ba6471b893a35 100644 --- a/tests/unit_tests/db_engine_specs/test_hive.py +++ b/tests/unit_tests/db_engine_specs/test_hive.py @@ -20,6 +20,7 @@ from typing import Optional import pytest +from sqlalchemy.engine.url import make_url from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm @@ -42,3 +43,17 @@ def test_convert_dttm( from superset.db_engine_specs.hive import HiveEngineSpec as spec assert_convert_dttm(spec, target_type, expected_result, dttm) + + +def test_get_schema_from_engine_params() -> None: + """ + Test the ``get_schema_from_engine_params`` method. + """ + from superset.db_engine_specs.hive import HiveEngineSpec + + assert ( + HiveEngineSpec.get_schema_from_engine_params( + make_url("hive://localhost:10000/default"), {} + ) + == "default" + ) diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py index a512e71a97f67..091cdb3b46305 100644 --- a/tests/unit_tests/db_engine_specs/test_mysql.py +++ b/tests/unit_tests/db_engine_specs/test_mysql.py @@ -148,3 +148,17 @@ def test_cancel_query_failed(engine_mock: Mock) -> None: query = Query() cursor_mock = engine_mock.raiseError.side_effect = Exception() assert MySQLEngineSpec.cancel_query(cursor_mock, query, "123") is False + + +def test_get_schema_from_engine_params() -> None: + """ + Test the ``get_schema_from_engine_params`` method. 
+ """ + from superset.db_engine_specs.mysql import MySQLEngineSpec + + assert ( + MySQLEngineSpec.get_schema_from_engine_params( + make_url("mysql://user:password@host/db1"), {} + ) + == "db1" + ) diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py index 088ce2747834d..e57e6a6f8e23e 100644 --- a/tests/unit_tests/db_engine_specs/test_postgres.py +++ b/tests/unit_tests/db_engine_specs/test_postgres.py @@ -21,6 +21,7 @@ import pytest from sqlalchemy import types from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON +from sqlalchemy.engine.url import make_url from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( @@ -89,3 +90,44 @@ def test_get_column_spec( from superset.db_engine_specs.postgres import PostgresEngineSpec as spec assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm) + + +def test_get_schema_from_engine_params() -> None: + """ + Test the ``get_schema_from_engine_params`` method. 
+ """ + from superset.db_engine_specs.postgres import PostgresEngineSpec + + assert ( + PostgresEngineSpec.get_schema_from_engine_params( + make_url("postgresql://user:password@host/db1"), {} + ) + is None + ) + + assert ( + PostgresEngineSpec.get_schema_from_engine_params( + make_url("postgresql://user:password@host/db1"), + {"options": "-csearch_path=secret"}, + ) + == "secret" + ) + + assert ( + PostgresEngineSpec.get_schema_from_engine_params( + make_url("postgresql://user:password@host/db1"), + {"options": "-c search_path = secret -cfoo=bar -c debug"}, + ) + == "secret" + ) + + with pytest.raises(Exception) as excinfo: + PostgresEngineSpec.get_schema_from_engine_params( + make_url("postgresql://user:password@host/db1"), + {"options": "-csearch_path=secret,public"}, + ) + assert str(excinfo.value) == ( + "Multiple schemas are configured in the search path, which means " + "Superset is unable to determine the schema of unqualified table " + "names and enforce permissions." + ) diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index a30fab94c9157..2684db0555523 100644 --- a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -20,6 +20,7 @@ import pytest import pytz from sqlalchemy import types +from sqlalchemy.engine.url import make_url from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( @@ -82,3 +83,26 @@ def test_get_column_spec( from superset.db_engine_specs.presto import PrestoEngineSpec as spec assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm) + + +def test_get_schema_from_engine_params() -> None: + """ + Test the ``get_schema_from_engine_params`` method. 
+ """ + from superset.db_engine_specs.presto import PrestoEngineSpec + + assert ( + PrestoEngineSpec.get_schema_from_engine_params( + make_url("presto://localhost:8080/hive/default"), + {}, + ) + == "default" + ) + + assert ( + PrestoEngineSpec.get_schema_from_engine_params( + make_url("presto://localhost:8080/hive"), + {}, + ) + is None + ) diff --git a/tests/unit_tests/db_engine_specs/test_snowflake.py b/tests/unit_tests/db_engine_specs/test_snowflake.py index 9689428d25653..5d560dd89ddd5 100644 --- a/tests/unit_tests/db_engine_specs/test_snowflake.py +++ b/tests/unit_tests/db_engine_specs/test_snowflake.py @@ -24,6 +24,7 @@ import pytest from pytest_mock import MockerFixture +from sqlalchemy.engine.url import make_url from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm @@ -169,3 +170,34 @@ def test_get_extra_params(mocker: MockerFixture) -> None: "connect_args": {"application": "Custom user agent", "foo": "bar"} } } + + +def test_get_schema_from_engine_params() -> None: + """ + Test the ``get_schema_from_engine_params`` method. 
+ """ + from superset.db_engine_specs.snowflake import SnowflakeEngineSpec + + assert ( + SnowflakeEngineSpec.get_schema_from_engine_params( + make_url("snowflake://user:pass@account/database_name/default"), + {}, + ) + == "default" + ) + + assert ( + SnowflakeEngineSpec.get_schema_from_engine_params( + make_url("snowflake://user:pass@account/database_name"), + {}, + ) + is None + ) + + assert ( + SnowflakeEngineSpec.get_schema_from_engine_params( + make_url("snowflake://user:pass@account/"), + {}, + ) + is None + ) diff --git a/tests/unit_tests/explore/utils_test.py b/tests/unit_tests/explore/utils_test.py index 9b75a13fbac2a..b2989b1244de3 100644 --- a/tests/unit_tests/explore/utils_test.py +++ b/tests/unit_tests/explore/utils_test.py @@ -274,10 +274,8 @@ def test_query_no_access(mocker: MockFixture, client) -> None: from superset.models.core import Database from superset.models.sql_lab import Query - inspect = mocker.patch("superset.security.manager.inspect") - inspect().default_schema_name = "public" - database = mocker.MagicMock() + database.get_default_schema_for_query.return_value = "public" mocker.patch( query_find_by_id, return_value=Query(database=database, sql="select * from foo"), diff --git a/tests/unit_tests/security/manager_test.py b/tests/unit_tests/security/manager_test.py index 6d0468c75c78d..1843e7261cc9a 100644 --- a/tests/unit_tests/security/manager_test.py +++ b/tests/unit_tests/security/manager_test.py @@ -52,8 +52,7 @@ def test_raise_for_access_query_default_schema( SqlaTable.query_datasources_by_name.return_value = [] database = mocker.MagicMock() - database.db_engine_spec.dynamic_schema = False - database.get_inspector_with_context().__enter__().default_schema_name = "public" + database.get_default_schema_for_query.return_value = "public" query = mocker.MagicMock() query.database = database query.sql = "SELECT * FROM ab_user"