From 6cf681df6808c9b612cff1e53ddb6925a9b28ebf Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 25 Apr 2024 12:23:49 -0400 Subject: [PATCH] feat(SIP-95): new endpoint for table metadata (#28122) --- .../src/SqlLab/actions/sqlLab.test.js | 5 +- .../SqlEditorLeftBar.test.tsx | 4 +- .../TableElement/TableElement.test.tsx | 5 +- .../AddDataset/DatasetPanel/index.tsx | 4 +- .../src/hooks/apiResources/tables.ts | 6 +- superset/commands/database/tables.py | 2 + superset/commands/database/validate_sql.py | 3 +- superset/commands/dataset/create.py | 13 +- .../commands/dataset/importers/v1/utils.py | 10 +- superset/commands/dataset/update.py | 4 +- superset/commands/sql_lab/estimate.py | 25 +- superset/connectors/sqla/models.py | 51 ++-- superset/connectors/sqla/utils.py | 28 +-- superset/constants.py | 1 + superset/daos/dataset.py | 24 +- superset/databases/api.py | 97 +++++++- superset/databases/schemas.py | 47 +++- superset/databases/utils.py | 53 ++-- superset/db_engine_specs/README.md | 6 +- superset/db_engine_specs/base.py | 124 ++++++---- superset/db_engine_specs/bigquery.py | 103 ++++---- superset/db_engine_specs/db2.py | 10 +- superset/db_engine_specs/gsheets.py | 11 +- superset/db_engine_specs/hive.py | 44 ++-- superset/db_engine_specs/presto.py | 97 ++++---- superset/db_engine_specs/trino.py | 30 ++- superset/examples/bart_lines.py | 3 +- superset/examples/birth_names.py | 3 +- superset/examples/country_map.py | 3 +- superset/examples/energy.py | 3 +- superset/examples/flights.py | 3 +- superset/examples/long_lat.py | 3 +- superset/examples/multiformat_time_series.py | 3 +- superset/examples/paris.py | 3 +- superset/examples/random_time_series.py | 3 +- superset/examples/sf_population_polygons.py | 3 +- .../examples/supported_charts_dashboard.py | 3 +- superset/examples/world_bank.py | 3 +- superset/extensions/metadb.py | 5 +- ...1_15-41_5f57af97bc3f_add_catalog_column.py | 55 +++++ superset/models/core.py | 232 +++++++++--------- superset/models/dashboard.py | 8 - superset/models/sql_lab.py | 4 + superset/security/manager.py | 2 + superset/sql_lab.py | 14 +- superset/sql_validators/base.py | 17 +- superset/sql_validators/postgres.py | 9 +- superset/sql_validators/presto_db.py | 18 +- superset/utils/mock_data.py | 7 +- superset/views/datasource/views.py | 4 +- tests/integration_tests/celery_tests.py | 2 +- .../charts/data/api_tests.py | 4 +- tests/integration_tests/core_tests.py | 8 +- .../integration_tests/databases/api_tests.py | 2 +- .../databases/commands/upload_test.py | 4 +- tests/integration_tests/datasets/api_tests.py | 13 +- .../db_engine_specs/base_engine_spec_tests.py | 7 +- .../db_engine_specs/bigquery_tests.py | 9 +- .../db_engine_specs/hive_tests.py | 16 +- .../db_engine_specs/postgres_tests.py | 2 +- .../db_engine_specs/presto_tests.py | 33 ++- tests/integration_tests/model_tests.py | 9 +- .../integration_tests/sql_validator_tests.py | 18 +- tests/unit_tests/dao/dataset_test.py | 10 +- tests/unit_tests/databases/api_test.py | 164 +++++++++++++ tests/unit_tests/db_engine_specs/test_base.py | 6 +- .../db_engine_specs/test_bigquery.py | 4 +- tests/unit_tests/db_engine_specs/test_db2.py | 7 +- .../unit_tests/db_engine_specs/test_presto.py | 6 +- .../unit_tests/db_engine_specs/test_trino.py | 16 +- tests/unit_tests/models/core_test.py | 9 +- 71 files changed, 1051 insertions(+), 516 deletions(-) create mode 100644 superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.test.js 
b/superset-frontend/src/SqlLab/actions/sqlLab.test.js index ecf2c4d7e299c..871b3ff6f6b4f 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.test.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.test.js @@ -508,10 +508,11 @@ describe('async actions', () => { fetchMock.delete(updateTableSchemaEndpoint, {}); fetchMock.post(updateTableSchemaEndpoint, JSON.stringify({ id: 1 })); - const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table/*/*/'; + const getTableMetadataEndpoint = + 'glob:**/api/v1/database/*/table_metadata/*'; fetchMock.get(getTableMetadataEndpoint, {}); const getExtraTableMetadataEndpoint = - 'glob:**/api/v1/database/*/table_metadata/extra/'; + 'glob:**/api/v1/database/*/table_metadata/extra/*'; fetchMock.get(getExtraTableMetadataEndpoint, {}); let isFeatureEnabledMock; diff --git a/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx b/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx index f8c94468bf7f3..b5003b16f7b47 100644 --- a/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx +++ b/superset-frontend/src/SqlLab/components/SqlEditorLeftBar/SqlEditorLeftBar.test.tsx @@ -61,13 +61,13 @@ beforeEach(() => { }, ], }); - fetchMock.get('glob:*/api/v1/database/*/table/*/*', { + fetchMock.get('glob:*/api/v1/database/*/table_metadata/*', { status: 200, body: { columns: table.columns, }, }); - fetchMock.get('glob:*/api/v1/database/*/table_metadata/extra/', { + fetchMock.get('glob:*/api/v1/database/*/table_metadata/extra/*', { status: 200, body: {}, }); diff --git a/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx b/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx index a2fe88020aa45..1489f23a13a06 100644 --- a/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx +++ b/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx @@ -47,9 +47,10 @@ jest.mock(
{column.name}
), ); -const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table/*/*/'; +const getTableMetadataEndpoint = + /\/api\/v1\/database\/\d+\/table_metadata\/(?:\?.*)?$/; const getExtraTableMetadataEndpoint = - 'glob:**/api/v1/database/*/table_metadata/extra/*'; + /\/api\/v1\/database\/\d+\/table_metadata\/extra\/(?:\?.*)?$/; const updateTableSchemaEndpoint = 'glob:*/tableschemaview/*/expanded'; beforeEach(() => { diff --git a/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx b/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx index ef5797fb309c2..b3f8aec8f99a0 100644 --- a/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx +++ b/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx @@ -74,7 +74,9 @@ const DatasetPanelWrapper = ({ const { dbId, tableName, schema } = props; setLoading(true); setHasColumns?.(false); - const path = `/api/v1/database/${dbId}/table/${tableName}/${schema}/`; + const path = schema + ? `/api/v1/database/${dbId}/table_metadata/?name=${tableName}&schema=${schema}` + : `/api/v1/database/${dbId}/table_metadata/?name=${tableName}`; try { const response = await SupersetClient.get({ endpoint: path, diff --git a/superset-frontend/src/hooks/apiResources/tables.ts b/superset-frontend/src/hooks/apiResources/tables.ts index 164fe0f0ab19c..41be4c167c9c8 100644 --- a/superset-frontend/src/hooks/apiResources/tables.ts +++ b/superset-frontend/src/hooks/apiResources/tables.ts @@ -114,9 +114,9 @@ const tableApi = api.injectEndpoints({ }), tableMetadata: builder.query({ query: ({ dbId, schema, table }) => ({ - endpoint: `/api/v1/database/${dbId}/table/${encodeURIComponent( - table, - )}/${encodeURIComponent(schema)}/`, + endpoint: schema + ? `/api/v1/database/${dbId}/table_metadata/?name=${table}&schema=${schema}` + : `/api/v1/database/${dbId}/table_metadata/?name=${table}`, transformResponse: ({ json }: TableMetadataReponse) => json, }), }), diff --git a/superset/commands/database/tables.py b/superset/commands/database/tables.py index fa98bcbc7ec5a..055c0be9aea91 100644 --- a/superset/commands/database/tables.py +++ b/superset/commands/database/tables.py @@ -51,6 +51,7 @@ def run(self) -> dict[str, Any]: datasource_names=sorted( DatasourceName(*datasource_name) for datasource_name in self._model.get_all_table_names_in_schema( + catalog=None, schema=self._schema_name, force=self._force, cache=self._model.table_cache_enabled, @@ -65,6 +66,7 @@ def run(self) -> dict[str, Any]: datasource_names=sorted( DatasourceName(*datasource_name) for datasource_name in self._model.get_all_view_names_in_schema( + catalog=None, schema=self._schema_name, force=self._force, cache=self._model.table_cache_enabled, diff --git a/superset/commands/database/validate_sql.py b/superset/commands/database/validate_sql.py index 6ecc4f1626edb..6a93a01473acb 100644 --- a/superset/commands/database/validate_sql.py +++ b/superset/commands/database/validate_sql.py @@ -61,11 +61,12 @@ def run(self) -> list[dict[str, Any]]: raise ValidatorSQLUnexpectedError() sql = self._properties["sql"] schema = self._properties.get("schema") + catalog = self._properties.get("catalog") try: timeout = current_app.config["SQLLAB_VALIDATION_TIMEOUT"] timeout_msg = f"The query exceeded the {timeout} seconds timeout." 
with utils.timeout(seconds=timeout, error_message=timeout_msg): - errors = self._validator.validate(sql, schema, self._model) + errors = self._validator.validate(sql, catalog, schema, self._model) return [err.to_dict() for err in errors] except Exception as ex: logger.exception(ex) diff --git a/superset/commands/dataset/create.py b/superset/commands/dataset/create.py index 16b87a567a5f0..dace92f911bcf 100644 --- a/superset/commands/dataset/create.py +++ b/superset/commands/dataset/create.py @@ -34,6 +34,7 @@ from superset.daos.exceptions import DAOCreateFailedError from superset.exceptions import SupersetSecurityException from superset.extensions import db, security_manager +from superset.sql_parse import Table logger = logging.getLogger(__name__) @@ -61,12 +62,15 @@ def validate(self) -> None: exceptions: list[ValidationError] = [] database_id = self._properties["database"] table_name = self._properties["table_name"] - schema = self._properties.get("schema", None) - sql = self._properties.get("sql", None) + schema = self._properties.get("schema") + catalog = self._properties.get("catalog") + sql = self._properties.get("sql") owner_ids: Optional[list[int]] = self._properties.get("owners") + table = Table(table_name, schema, catalog) + # Validate uniqueness - if not DatasetDAO.validate_uniqueness(database_id, schema, table_name): + if not DatasetDAO.validate_uniqueness(database_id, table): exceptions.append(DatasetExistsValidationError(table_name)) # Validate/Populate database @@ -80,7 +84,7 @@ def validate(self) -> None: if ( database and not sql - and not DatasetDAO.validate_table_exists(database, table_name, schema) + and not DatasetDAO.validate_table_exists(database, table) ): exceptions.append(TableNotFoundValidationError(table_name)) @@ -89,6 +93,7 @@ def validate(self) -> None: security_manager.raise_for_access( database=database, sql=sql, + catalog=catalog, schema=schema, ) except SupersetSecurityException as ex: diff --git a/superset/commands/dataset/importers/v1/utils.py b/superset/commands/dataset/importers/v1/utils.py index 50bb916b07801..0d2226f724277 100644 --- a/superset/commands/dataset/importers/v1/utils.py +++ b/superset/commands/dataset/importers/v1/utils.py @@ -32,6 +32,7 @@ from superset.commands.exceptions import ImportFailedError from superset.connectors.sqla.models import SqlaTable from superset.models.core import Database +from superset.sql_parse import Table from superset.utils.core import get_user logger = logging.getLogger(__name__) @@ -164,7 +165,9 @@ def import_dataset( db.session.flush() try: - table_exists = dataset.database.has_table_by_name(dataset.table_name) + table_exists = dataset.database.has_table( + Table(dataset.table_name, dataset.schema), + ) except Exception: # pylint: disable=broad-except # MySQL doesn't play nice with GSheets table names logger.warning( @@ -217,7 +220,10 @@ def load_data(data_uri: str, dataset: SqlaTable, database: Database) -> None: ) else: logger.warning("Loading data outside the import transaction") - with database.get_sqla_engine() as engine: + with database.get_sqla_engine( + catalog=dataset.catalog, + schema=dataset.schema, + ) as engine: df.to_sql( dataset.table_name, con=engine, diff --git a/superset/commands/dataset/update.py b/superset/commands/dataset/update.py index 5c0d87b230eff..282c778eb432e 100644 --- a/superset/commands/dataset/update.py +++ b/superset/commands/dataset/update.py @@ -41,6 +41,7 @@ from superset.daos.dataset import DatasetDAO from superset.daos.exceptions import DAOUpdateFailedError from 
superset.exceptions import SupersetSecurityException +from superset.sql_parse import Table logger = logging.getLogger(__name__) @@ -90,9 +91,8 @@ def validate(self) -> None: # Validate uniqueness if not DatasetDAO.validate_update_uniqueness( self._model.database_id, - self._model.schema, + Table(table_name, self._model.schema, self._model.catalog), self._model_id, - table_name, ): exceptions.append(DatasetExistsValidationError(table_name)) # Validate/Populate database not allowed to change diff --git a/superset/commands/sql_lab/estimate.py b/superset/commands/sql_lab/estimate.py index bf1d6c4fa57d0..d3198815662ad 100644 --- a/superset/commands/sql_lab/estimate.py +++ b/superset/commands/sql_lab/estimate.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any +from typing import Any, TypedDict from flask_babel import gettext as __ @@ -27,7 +27,6 @@ from superset.exceptions import SupersetErrorException, SupersetTimeoutException from superset.jinja_context import get_template_processor from superset.models.core import Database -from superset.sqllab.schemas import EstimateQueryCostSchema from superset.utils import core as utils config = app.config @@ -37,18 +36,28 @@ logger = logging.getLogger(__name__) +class EstimateQueryCostType(TypedDict): + database_id: int + sql: str + template_params: dict[str, Any] + catalog: str | None + schema: str | None + + class QueryEstimationCommand(BaseCommand): _database_id: int _sql: str _template_params: dict[str, Any] _schema: str _database: Database + _catalog: str | None - def __init__(self, params: EstimateQueryCostSchema) -> None: - self._database_id = params.get("database_id") + def __init__(self, params: EstimateQueryCostType) -> None: + self._database_id = params["database_id"] self._sql = params.get("sql", "") self._template_params = params.get("template_params", {}) - self._schema = params.get("schema", "") + self._schema = params.get("schema") or "" + self._catalog = params.get("catalog") def validate(self) -> None: self._database = db.session.query(Database).get(self._database_id) @@ -77,7 +86,11 @@ def run( try: with utils.timeout(seconds=timeout, error_message=timeout_msg): cost = self._database.db_engine_spec.estimate_query_cost( - self._database, self._schema, sql, utils.QuerySource.SQL_LAB + self._database, + self._catalog, + self._schema, + sql, + utils.QuerySource.SQL_LAB, ) except SupersetTimeoutException as ex: logger.exception(ex) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 339be9d177e6b..719d5af588852 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -49,7 +49,7 @@ Integer, or_, String, - Table, + Table as DBTable, Text, update, ) @@ -108,7 +108,7 @@ validate_adhoc_subquery, ) from superset.models.slice import Slice -from superset.sql_parse import ParsedQuery, sanitize_clause +from superset.sql_parse import ParsedQuery, sanitize_clause, Table from superset.superset_typing import ( AdhocColumn, AdhocMetric, @@ -329,7 +329,7 @@ def short_data(self) -> dict[str, Any]: "edit_url": self.url, "id": self.id, "uid": self.uid, - "schema": self.schema, + "schema": self.schema or None, "name": self.name, "type": self.type, "connection": self.connection, @@ -383,7 +383,7 @@ def data(self) -> dict[str, Any]: "datasource_name": self.datasource_name, "table_name": self.datasource_name, "type": self.type, - "schema": self.schema, + "schema": self.schema or None, "offset": self.offset, "cache_timeout": self.cache_timeout, 
"params": self.params, @@ -1065,7 +1065,7 @@ def data(self) -> dict[str, Any]: return {s: getattr(self, s) for s in attrs} -sqlatable_user = Table( +sqlatable_user = DBTable( "sqlatable_user", metadata, Column("id", Integer, primary_key=True), @@ -1143,6 +1143,7 @@ class SqlaTable( foreign_keys=[database_id], ) schema = Column(String(255)) + catalog = Column(String(256), nullable=True, default=None) sql = Column(MediumText()) is_sqllab_view = Column(Boolean, default=False) template_params = Column(Text) @@ -1262,7 +1263,7 @@ def link(self) -> Markup: def get_schema_perm(self) -> str | None: """Returns schema permission if present, database one otherwise.""" - return security_manager.get_schema_perm(self.database, self.schema) + return security_manager.get_schema_perm(self.database, self.schema or None) def get_perm(self) -> str: """ @@ -1319,8 +1320,7 @@ def external_metadata(self) -> list[ResultSetColumnType]: return get_virtual_table_metadata(dataset=self) return get_physical_table_metadata( database=self.database, - table_name=self.table_name, - schema_name=self.schema, + table=Table(self.table_name, self.schema or None, self.catalog), normalize_columns=self.normalize_columns, ) @@ -1336,7 +1336,9 @@ def select_star(self) -> str | None: # show_cols and latest_partition set to false to avoid # the expensive cost of inspecting the DB return self.database.select_star( - self.table_name, schema=self.schema, show_cols=False, latest_partition=False + Table(self.table_name, self.schema or None, self.catalog), + show_cols=False, + latest_partition=False, ) @property @@ -1523,7 +1525,12 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals tbl, _ = self.get_from_clause(template_processor) qry = sa.select([sqla_column]).limit(1).select_from(tbl) sql = self.database.compile_sqla_query(qry) - col_desc = get_columns_description(self.database, self.schema, sql) + col_desc = get_columns_description( + self.database, + self.catalog, + self.schema or None, + sql, + ) if not col_desc: raise SupersetGenericDBErrorException("Column not found") is_dttm = col_desc[0]["is_dttm"] # type: ignore @@ -1728,7 +1735,9 @@ def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None: return df try: - df = self.database.get_df(sql, self.schema, mutator=assign_column_label) + df = self.database.get_df( + sql, self.schema or None, mutator=assign_column_label + ) except (SupersetErrorException, SupersetErrorsException) as ex: # SupersetError(s) exception should not be captured; instead, they should # bubble up to the Flask error handler so they are returned as proper SIP-40 @@ -1762,7 +1771,13 @@ def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None: ) def get_sqla_table_object(self) -> Table: - return self.database.get_table(self.table_name, schema=self.schema) + return self.database.get_table( + Table( + self.table_name, + self.schema or None, + self.catalog, + ) + ) def fetch_metadata(self, commit: bool = True) -> MetadataResult: """ @@ -1774,7 +1789,13 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult: new_columns = self.external_metadata() metrics = [ SqlMetric(**metric) - for metric in self.database.get_metrics(self.table_name, self.schema) + for metric in self.database.get_metrics( + Table( + self.table_name, + self.schema or None, + self.catalog, + ) + ) ] any_date_col = None db_engine_spec = self.db_engine_spec @@ -2021,7 +2042,7 @@ def load_database(self: SqlaTable) -> None: sa.event.listen(SqlMetric, "after_update", SqlaTable.update_column) sa.event.listen(TableColumn, 
"after_update", SqlaTable.update_column) -RLSFilterRoles = Table( +RLSFilterRoles = DBTable( "rls_filter_roles", metadata, Column("id", Integer, primary_key=True), @@ -2029,7 +2050,7 @@ def load_database(self: SqlaTable) -> None: Column("rls_filter_id", Integer, ForeignKey("row_level_security_filters.id")), ) -RLSFilterTables = Table( +RLSFilterTables = DBTable( "rls_filter_tables", metadata, Column("id", Integer, primary_key=True), diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index 4bc11aee42d80..87b3d5dd3a28a 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -38,7 +38,7 @@ ) from superset.models.core import Database from superset.result_set import SupersetResultSet -from superset.sql_parse import ParsedQuery +from superset.sql_parse import ParsedQuery, Table from superset.superset_typing import ResultSetColumnType if TYPE_CHECKING: @@ -47,24 +47,18 @@ def get_physical_table_metadata( database: Database, - table_name: str, + table: Table, normalize_columns: bool, - schema_name: str | None = None, ) -> list[ResultSetColumnType]: """Use SQLAlchemy inspector to get table metadata""" db_engine_spec = database.db_engine_spec db_dialect = database.get_dialect() - # ensure empty schema - _schema_name = schema_name if schema_name else None - # Table does not exist or is not visible to a connection. - if not ( - database.has_table_by_name(table_name=table_name, schema=_schema_name) - or database.has_view_by_name(view_name=table_name, schema=_schema_name) - ): - raise NoSuchTableError + # Table does not exist or is not visible to a connection. + if not (database.has_table(table) or database.has_view(table)): + raise NoSuchTableError(table) - cols = database.get_columns(table_name, schema=_schema_name) + cols = database.get_columns(table) for col in cols: try: if isinstance(col["type"], TypeEngine): @@ -129,11 +123,17 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]: level=ErrorLevel.ERROR, ) ) - return get_columns_description(dataset.database, dataset.schema, statements[0]) + return get_columns_description( + dataset.database, + dataset.catalog, + dataset.schema, + statements[0], + ) def get_columns_description( database: Database, + catalog: str | None, schema: str | None, query: str, ) -> list[ResultSetColumnType]: @@ -141,7 +141,7 @@ def get_columns_description( # sql_lab.py:execute_sql_statements db_engine_spec = database.db_engine_spec try: - with database.get_raw_connection(schema=schema) as conn: + with database.get_raw_connection(catalog=catalog, schema=schema) as conn: cursor = conn.cursor() query = database.apply_limit_to_sql(query, limit=1) cursor.execute(query) diff --git a/superset/constants.py b/superset/constants.py index e4d467bdd834e..28902ded6cf0d 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -134,6 +134,7 @@ class RouteMethod: # pylint: disable=too-few-public-methods "schemas": "read", "select_star": "read", "table_metadata": "read", + "table_metadata_deprecated": "read", "table_extra_metadata": "read", "table_extra_metadata_deprecated": "read", "test_connection": "write", diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py index 23b46e3329d99..21c5ae1d0faf0 100644 --- a/superset/daos/dataset.py +++ b/superset/daos/dataset.py @@ -30,6 +30,7 @@ from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.slice import Slice +from superset.sql_parse import Table from superset.utils.core 
import DatasourceType from superset.views.base import DatasourceFilter @@ -72,25 +73,26 @@ def get_related_objects(database_id: int) -> dict[str, Any]: @staticmethod def validate_table_exists( - database: Database, table_name: str, schema: str | None + database: Database, + table: Table, ) -> bool: try: - database.get_table(table_name, schema=schema) + database.get_table(table) return True except SQLAlchemyError as ex: # pragma: no cover - logger.warning("Got an error %s validating table: %s", str(ex), table_name) + logger.warning("Got an error %s validating table: %s", str(ex), table) return False @staticmethod def validate_uniqueness( database_id: int, - schema: str | None, - name: str, + table: Table, dataset_id: int | None = None, ) -> bool: dataset_query = db.session.query(SqlaTable).filter( - SqlaTable.table_name == name, - SqlaTable.schema == schema, + SqlaTable.table_name == table.table, + SqlaTable.schema == table.schema, + SqlaTable.catalog == table.catalog, SqlaTable.database_id == database_id, ) @@ -103,14 +105,14 @@ def validate_uniqueness( @staticmethod def validate_update_uniqueness( database_id: int, - schema: str | None, + table: Table, dataset_id: int, - name: str, ) -> bool: dataset_query = db.session.query(SqlaTable).filter( - SqlaTable.table_name == name, + SqlaTable.table_name == table.table, SqlaTable.database_id == database_id, - SqlaTable.schema == schema, + SqlaTable.schema == table.schema, + SqlaTable.catalog == table.catalog, SqlaTable.id != dataset_id, ) return not db.session.query(dataset_query.exists()).scalar() diff --git a/superset/databases/api.py b/superset/databases/api.py index 635a2da790b3d..a77019123b976 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -136,6 +136,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): RouteMethod.RELATED, "tables", "table_metadata", + "table_metadata_deprecated", "table_extra_metadata", "table_extra_metadata_deprecated", "select_star", @@ -722,10 +723,10 @@ def tables(self, pk: int, **kwargs: Any) -> FlaskResponse: @statsd_metrics @event_logger.log_this_with_context( action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" - f".table_metadata", + f".table_metadata_deprecated", log_to_statsd=False, ) - def table_metadata( + def table_metadata_deprecated( self, database: Database, table_name: str, schema_name: str ) -> FlaskResponse: """Get database table metadata. 
@@ -766,16 +767,16 @@ def table_metadata( 500: $ref: '#/components/responses/500' """ - self.incr_stats("init", self.table_metadata.__name__) + self.incr_stats("init", self.table_metadata_deprecated.__name__) try: - table_info = get_table_metadata(database, table_name, schema_name) + table_info = get_table_metadata(database, Table(table_name, schema_name)) except SQLAlchemyError as ex: - self.incr_stats("error", self.table_metadata.__name__) + self.incr_stats("error", self.table_metadata_deprecated.__name__) return self.response_422(error_msg_from_exception(ex)) except SupersetException as ex: return self.response(ex.status, message=ex.message) - self.incr_stats("success", self.table_metadata.__name__) + self.incr_stats("success", self.table_metadata_deprecated.__name__) return self.response(200, **table_info) @expose("//table_extra///", methods=("GET",)) @@ -844,7 +845,86 @@ def table_extra_metadata_deprecated( payload = database.db_engine_spec.get_extra_table_metadata(database, table) return self.response(200, **payload) - @expose("//table_metadata/extra/", methods=("GET",)) + @expose("//table_metadata/", methods=["GET"]) + @protect() + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".table_metadata", + log_to_statsd=False, + ) + def table_metadata(self, pk: int) -> FlaskResponse: + """ + Get metadata for a given table. + + Optionally, a schema and a catalog can be passed, if different from the default + ones. + --- + get: + summary: Get table metadata + description: >- + Metadata associated with the table (columns, indexes, etc.) + parameters: + - in: path + schema: + type: integer + name: pk + description: The database id + - in: query + schema: + type: string + name: table + required: true + description: Table name + - in: query + schema: + type: string + name: schema + description: >- + Optional table schema, if not passed default schema will be used + - in: query + schema: + type: string + name: catalog + description: >- + Optional table catalog, if not passed default catalog will be used + responses: + 200: + description: Table metadata information + content: + application/json: + schema: + $ref: "#/components/schemas/TableExtraMetadataResponseSchema" + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + self.incr_stats("init", self.table_metadata.__name__) + + database = DatabaseDAO.find_by_id(pk) + if database is None: + raise DatabaseNotFoundException("No such database") + + try: + parameters = QualifiedTableSchema().load(request.args) + except ValidationError as ex: + raise InvalidPayloadSchemaError(ex) from ex + + table = Table(parameters["name"], parameters["schema"], parameters["catalog"]) + try: + security_manager.raise_for_access(database=database, table=table) + except SupersetSecurityException as ex: + # instead of raising 403, raise 404 to hide table existence + raise TableNotFoundException("No such table") from ex + + payload = database.db_engine_spec.get_table_metadata(database, table) + + return self.response(200, **payload) + + @expose("//table_metadata/extra/", methods=["GET"]) @protect() @statsd_metrics @event_logger.log_this_with_context( @@ -978,7 +1058,8 @@ def select_star( self.incr_stats("init", self.select_star.__name__) try: result = database.select_star( - table_name, schema_name, latest_partition=True + Table(table_name, schema_name), + latest_partition=True, ) except NoSuchTableError: 
self.incr_stats("error", self.select_star.__name__) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 9a1fc9d6c1511..1bc0af7472c83 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -17,11 +17,13 @@ # pylint: disable=unused-argument, too-many-lines +from __future__ import annotations + import inspect import json import os import re -from typing import Any +from typing import Any, TypedDict from flask import current_app from flask_babel import lazy_gettext as _ @@ -581,6 +583,49 @@ class DatabaseTestConnectionSchema(DatabaseParametersSchemaMixin, Schema): ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True) +class TableMetadataOptionsResponse(TypedDict): + deferrable: bool + initially: bool + match: bool + ondelete: bool + onupdate: bool + + +class TableMetadataColumnsResponse(TypedDict, total=False): + keys: list[str] + longType: str + name: str + type: str + duplicates_constraint: str | None + comment: str | None + + +class TableMetadataForeignKeysIndexesResponse(TypedDict): + column_names: list[str] + name: str + options: TableMetadataOptionsResponse + referred_columns: list[str] + referred_schema: str + referred_table: str + type: str + + +class TableMetadataPrimaryKeyResponse(TypedDict): + column_names: list[str] + name: str + type: str + + +class TableMetadataResponse(TypedDict): + name: str + columns: list[TableMetadataColumnsResponse] + foreignKeys: list[TableMetadataForeignKeysIndexesResponse] + indexes: list[TableMetadataForeignKeysIndexesResponse] + primaryKey: TableMetadataPrimaryKeyResponse + selectStar: str + comment: str | None + + class TableMetadataOptionsResponseSchema(Schema): deferrable = fields.Bool() initially = fields.Bool() diff --git a/superset/databases/utils.py b/superset/databases/utils.py index 8de4bb6f2353d..dfd75eb2233f4 100644 --- a/superset/databases/utils.py +++ b/superset/databases/utils.py @@ -14,19 +14,29 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Optional, Union + +from __future__ import annotations + +from typing import Any, TYPE_CHECKING from sqlalchemy.engine.url import make_url, URL from superset.commands.database.exceptions import DatabaseInvalidError +from superset.sql_parse import Table + +if TYPE_CHECKING: + from superset.databases.schemas import ( + TableMetadataColumnsResponse, + TableMetadataForeignKeysIndexesResponse, + TableMetadataResponse, + ) def get_foreign_keys_metadata( database: Any, - table_name: str, - schema_name: Optional[str], -) -> list[dict[str, Any]]: - foreign_keys = database.get_foreign_keys(table_name, schema_name) + table: Table, +) -> list[TableMetadataForeignKeysIndexesResponse]: + foreign_keys = database.get_foreign_keys(table) for fk in foreign_keys: fk["column_names"] = fk.pop("constrained_columns") fk["type"] = "fk" @@ -34,9 +44,10 @@ def get_foreign_keys_metadata( def get_indexes_metadata( - database: Any, table_name: str, schema_name: Optional[str] -) -> list[dict[str, Any]]: - indexes = database.get_indexes(table_name, schema_name) + database: Any, + table: Table, +) -> list[TableMetadataForeignKeysIndexesResponse]: + indexes = database.get_indexes(table) for idx in indexes: idx["type"] = "index" return indexes @@ -51,30 +62,27 @@ def get_col_type(col: dict[Any, Any]) -> str: return dtype -def get_table_metadata( - database: Any, table_name: str, schema_name: Optional[str] -) -> dict[str, Any]: +def get_table_metadata(database: Any, table: Table) -> TableMetadataResponse: """ Get table metadata information, including type, pk, fks. This function raises SQLAlchemyError when a schema is not found. :param database: The database model - :param table_name: Table name - :param schema_name: schema name + :param table: Table instance :return: Dict table metadata ready for API response """ keys = [] - columns = database.get_columns(table_name, schema_name) - primary_key = database.get_pk_constraint(table_name, schema_name) + columns = database.get_columns(table) + primary_key = database.get_pk_constraint(table) if primary_key and primary_key.get("constrained_columns"): primary_key["column_names"] = primary_key.pop("constrained_columns") primary_key["type"] = "pk" keys += [primary_key] - foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name) - indexes = get_indexes_metadata(database, table_name, schema_name) + foreign_keys = get_foreign_keys_metadata(database, table) + indexes = get_indexes_metadata(database, table) keys += foreign_keys + indexes - payload_columns: list[dict[str, Any]] = [] - table_comment = database.get_table_comment(table_name, schema_name) + payload_columns: list[TableMetadataColumnsResponse] = [] + table_comment = database.get_table_comment(table) for col in columns: dtype = get_col_type(col) payload_columns.append( @@ -87,11 +95,10 @@ def get_table_metadata( } ) return { - "name": table_name, + "name": table.table, "columns": payload_columns, "selectStar": database.select_star( - table_name, - schema=schema_name, + table, indent=True, cols=columns, latest_partition=True, @@ -103,7 +110,7 @@ def get_table_metadata( } -def make_url_safe(raw_url: Union[str, URL]) -> URL: +def make_url_safe(raw_url: str | URL) -> URL: """ Wrapper for SQLAlchemy's make_url(), which tends to raise too detailed of errors, which inevitably find their way into server logs. 
ArgumentErrors diff --git a/superset/db_engine_specs/README.md b/superset/db_engine_specs/README.md index 11f0f90ab7d81..4a108be6587bd 100644 --- a/superset/db_engine_specs/README.md +++ b/superset/db_engine_specs/README.md @@ -660,7 +660,7 @@ This way, when a user selects a column that doesn't exist Superset can return a ### Dynamic schema -In SQL Lab it's possible to select a database, and then a schema in that database. Ideally, when running a query in SQL Lab, any unqualified table names (eg, `table`, instead of `schema.table`) should be in the selected schema. For example, if the user select `dev` as the schema and then runs the following query: +In SQL Lab it's possible to select a database, and then a schema in that database. Ideally, when running a query in SQL Lab, any unqualified table names (eg, `table`, instead of `schema.table`) should be in the selected schema. For example, if the user selects `dev` as the schema and then runs the following query: ```sql SELECT * FROM my_table @@ -674,7 +674,7 @@ Implementing this method is also important for usability. When the method is not ### Catalog -In general, databases support a hierarchy of concepts of one-to-many concepts: +In general, databases support a hierarchy of one-to-many concepts: 1. Database 2. Catalog @@ -692,7 +692,7 @@ These concepts have different names depending on the database. For example, Post BigQuery, on the other hand: -1. Bigquery (database) +1. BigQuery (database) 2. Project (catalog) 3. Schema (namespace) 4. Table diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 1fc6a40a3a29e..3cc1315129571 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -61,7 +61,7 @@ from superset import sql_parse from superset.constants import TimeGrain as TimeGrainConstants -from superset.databases.utils import make_url_safe +from superset.databases.utils import get_table_metadata, make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import OAuth2Error, OAuth2RedirectError from superset.sql_parse import ParsedQuery, SQLScript, Table @@ -80,6 +80,7 @@ if TYPE_CHECKING: from superset.connectors.sqla.models import TableColumn + from superset.databases.schemas import TableMetadataResponse from superset.models.core import Database from superset.models.sql_lab import Query @@ -638,11 +639,11 @@ def supports_backend(cls, backend: str, driver: str | None = None) -> bool: return driver in cls.drivers @classmethod - def get_default_schema(cls, database: Database) -> str | None: + def get_default_schema(cls, database: Database, catalog: str | None) -> str | None: """ Return the default schema in a given database. """ - with database.get_inspector_with_context() as inspector: + with database.get_inspector(catalog=catalog) as inspector: return inspector.default_schema_name @classmethod @@ -697,7 +698,7 @@ def get_default_schema_for_query( return schema # return the default schema of the database - return cls.get_default_schema(database) + return cls.get_default_schema(database, query.catalog) @classmethod def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: @@ -760,18 +761,19 @@ def get_text_clause(cls, clause: str) -> TextClause: def get_engine( cls, database: Database, + catalog: str | None = None, schema: str | None = None, source: utils.QuerySource | None = None, ) -> ContextManager[Engine]: """ Return an engine context manager. 
- >>> with DBEngineSpec.get_engine(database, schema, source) as engine: + >>> with DBEngineSpec.get_engine(database, catalog, schema, source) as engine: ... connection = engine.connect() ... connection.execute(sql) """ - return database.get_sqla_engine(schema=schema, source=source) + return database.get_sqla_engine(catalog=catalog, schema=schema, source=source) @classmethod def get_timestamp_expr( @@ -1033,6 +1035,21 @@ def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any] """ return indexes + @classmethod + def get_table_metadata( + cls, + database: Database, + table: Table, + ) -> TableMetadataResponse: + """ + Returns basic table metadata + + :param database: Database instance + :param table: A Table instance + :return: Basic table metadata + """ + return get_table_metadata(database, table) + @classmethod def get_extra_table_metadata( cls, @@ -1236,7 +1253,11 @@ def df_to_sql( # Only add schema when it is preset and non-empty. to_sql_kwargs["schema"] = table.schema - with cls.get_engine(database) as engine: + with cls.get_engine( + database, + catalog=table.catalog, + schema=table.schema, + ) as engine: if engine.dialect.supports_multivalues_insert: to_sql_kwargs["method"] = "multi" @@ -1471,36 +1492,34 @@ def get_indexes( cls, database: Database, # pylint: disable=unused-argument inspector: Inspector, - table_name: str, - schema: str | None, + table: Table, ) -> list[dict[str, Any]]: """ Get the indexes associated with the specified schema/table. :param database: The database to inspect :param inspector: The SQLAlchemy inspector - :param table_name: The table to inspect - :param schema: The schema to inspect + :param table: The table instance to inspect :returns: The indexes """ - return inspector.get_indexes(table_name, schema) + return inspector.get_indexes(table.table, table.schema) @classmethod def get_table_comment( - cls, inspector: Inspector, table_name: str, schema: str | None + cls, + inspector: Inspector, + table: Table, ) -> str | None: """ Get comment of table from a given schema and table - :param inspector: SqlAlchemy Inspector instance - :param table_name: Table name - :param schema: Schema name. If omitted, uses default schema for database + :param table: Table instance :return: comment of table """ comment = None try: - comment = inspector.get_table_comment(table_name, schema) + comment = inspector.get_table_comment(table.table, table.schema) comment = comment.get("text") if isinstance(comment, dict) else None except NotImplementedError: # It's expected that some dialects don't implement the comment method @@ -1514,22 +1533,25 @@ def get_table_comment( def get_columns( # pylint: disable=unused-argument cls, inspector: Inspector, - table_name: str, - schema: str | None, + table: Table, options: dict[str, Any] | None = None, ) -> list[ResultSetColumnType]: """ - Get all columns from a given schema and table + Get all columns from a given schema and table. + + The inspector will be bound to a catalog, if one was specified. :param inspector: SqlAlchemy Inspector instance - :param table_name: Table name - :param schema: Schema name. 
If omitted, uses default schema for database + :param table: Table instance :param options: Extra options to customise the display of columns in some databases :return: All columns in table """ return convert_inspector_columns( - cast(list[SQLAColumnType], inspector.get_columns(table_name, schema)) + cast( + list[SQLAColumnType], + inspector.get_columns(table.table, table.schema), + ) ) @classmethod @@ -1537,8 +1559,7 @@ def get_metrics( # pylint: disable=unused-argument cls, database: Database, inspector: Inspector, - table_name: str, - schema: str | None, + table: Table, ) -> list[MetricType]: """ Get all metrics from a given schema and table. @@ -1553,19 +1574,17 @@ def get_metrics( # pylint: disable=unused-argument ] @classmethod - def where_latest_partition( # pylint: disable=too-many-arguments,unused-argument + def where_latest_partition( # pylint: disable=unused-argument cls, - table_name: str, - schema: str | None, database: Database, + table: Table, query: Select, columns: list[ResultSetColumnType] | None = None, ) -> Select | None: """ Add a where clause to a query to reference only the most recent partition - :param table_name: Table name - :param schema: Schema name + :param table: Table instance :param database: Database instance :param query: SqlAlchemy query :param columns: List of TableColumns @@ -1588,9 +1607,8 @@ def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]: def select_star( # pylint: disable=too-many-arguments,too-many-locals cls, database: Database, - table_name: str, + table: Table, engine: Engine, - schema: str | None = None, limit: int = 100, show_cols: bool = False, indent: bool = True, @@ -1603,9 +1621,8 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals WARNING: expects only unquoted table and schema names. :param database: Database instance - :param table_name: Table name, unquoted + :param table: Table instance :param engine: SqlAlchemy Engine instance - :param schema: Schema, unquoted :param limit: limit to impose on query :param show_cols: Show columns in query; otherwise use "*" :param indent: Add indentation to query @@ -1617,16 +1634,18 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals fields: str | list[Any] = "*" cols = cols or [] if (show_cols or latest_partition) and not cols: - cols = database.get_columns(table_name, schema) + cols = database.get_columns(table) if show_cols: fields = cls._get_fields(cols) + quote = engine.dialect.identifier_preparer.quote quote_schema = engine.dialect.identifier_preparer.quote_schema - if schema: - full_table_name = quote_schema(schema) + "." + quote(table_name) - else: - full_table_name = quote(table_name) + full_table_name = ( + quote_schema(table.schema) + "." 
+ quote(table.table) + if table.schema + else quote(table.table) + ) qry = select(fields).select_from(text(full_table_name)) @@ -1634,7 +1653,10 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals qry = qry.limit(limit) if latest_partition: partition_query = cls.where_latest_partition( - table_name, schema, database, qry, columns=cols + database, + table, + qry, + columns=cols, ) if partition_query is not None: qry = partition_query @@ -1685,9 +1707,10 @@ def process_statement(cls, statement: str, database: Database) -> str: return database.mutate_sql_based_on_config(sql, is_split=True) @classmethod - def estimate_query_cost( + def estimate_query_cost( # pylint: disable=too-many-arguments cls, database: Database, + catalog: str | None, schema: str, sql: str, source: utils.QuerySource | None = None, @@ -1709,14 +1732,19 @@ def estimate_query_cost( parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) statements = parsed_query.get_statements() - costs = [] - with database.get_raw_connection(schema=schema, source=source) as conn: + with database.get_raw_connection( + catalog=catalog, + schema=schema, + source=source, + ) as conn: cursor = conn.cursor() - for statement in statements: - processed_statement = cls.process_statement(statement, database) - costs.append(cls.estimate_statement_cost(processed_statement, cursor)) - - return costs + return [ + cls.estimate_statement_cost( + cls.process_statement(statement, database), + cursor, + ) + for statement in statements + ] @classmethod def get_url_for_impersonation( diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 78d845450fc01..8a2612f5b0983 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -14,13 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from __future__ import annotations + import contextlib import json import re import urllib from datetime import datetime from re import Pattern -from typing import Any, Optional, TYPE_CHECKING, TypedDict +from typing import Any, TYPE_CHECKING, TypedDict import pandas as pd from apispec import APISpec @@ -220,8 +223,8 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, types.Date): return f"CAST('{dttm.date().isoformat()}' AS DATE)" @@ -234,9 +237,7 @@ def convert_dttm( return None @classmethod - def fetch_data( - cls, cursor: Any, limit: Optional[int] = None - ) -> list[tuple[Any, ...]]: + def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]: data = super().fetch_data(cursor, limit) # Support type BigQuery Row, introduced here PR #4071 # google.cloud.bigquery.table.Row @@ -302,30 +303,28 @@ def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any] @classmethod def get_indexes( cls, - database: "Database", + database: Database, inspector: Inspector, - table_name: str, - schema: Optional[str], + table: Table, ) -> list[dict[str, Any]]: """ Get the indexes associated with the specified schema/table. 
:param database: The database to inspect :param inspector: The SQLAlchemy inspector - :param table_name: The table to inspect - :param schema: The schema to inspect + :param table: The table instance to inspect :returns: The indexes """ - return cls.normalize_indexes(inspector.get_indexes(table_name, schema)) + return cls.normalize_indexes(inspector.get_indexes(table.table, table.schema)) @classmethod def get_extra_table_metadata( cls, - database: "Database", + database: Database, table: Table, ) -> dict[str, Any]: - indexes = database.get_indexes(table.table, table.schema) + indexes = database.get_indexes(table) if not indexes: return {} partitions_columns = [ @@ -354,7 +353,7 @@ def epoch_ms_to_dttm(cls) -> str: @classmethod def df_to_sql( cls, - database: "Database", + database: Database, table: Table, df: pd.DataFrame, to_sql_kwargs: dict[str, Any], @@ -380,7 +379,11 @@ def df_to_sql( raise SupersetException("The table schema must be defined") to_gbq_kwargs = {} - with cls.get_engine(database) as engine: + with cls.get_engine( + database, + catalog=table.catalog, + schema=table.schema, + ) as engine: to_gbq_kwargs = { "destination_table": str(table), "project_id": engine.url.host, @@ -403,7 +406,7 @@ def df_to_sql( pandas_gbq.to_gbq(df, **to_gbq_kwargs) @classmethod - def _get_client(cls, engine: Engine) -> Any: + def _get_client(cls, engine: Engine) -> bigquery.Client: """ Return the BigQuery client associated with an engine. """ @@ -418,17 +421,19 @@ def _get_client(cls, engine: Engine) -> Any: return bigquery.Client(credentials=credentials) @classmethod - def estimate_query_cost( + def estimate_query_cost( # pylint: disable=too-many-arguments cls, - database: "Database", + database: Database, + catalog: str | None, schema: str, sql: str, - source: Optional[utils.QuerySource] = None, + source: utils.QuerySource | None = None, ) -> list[dict[str, Any]]: """ Estimate the cost of a multiple statement SQL query. :param database: Database instance + :param catalog: Database project :param schema: Database schema :param sql: SQL query with possibly multiple statements :param source: Source of the query (eg, "sql_lab") @@ -439,17 +444,25 @@ def estimate_query_cost( parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) statements = parsed_query.get_statements() - costs = [] - for statement in statements: - processed_statement = cls.process_statement(statement, database) - costs.append(cls.estimate_statement_cost(processed_statement, database)) - return costs + with cls.get_engine( + database, + catalog=catalog, + schema=schema, + ) as engine: + client = cls._get_client(engine) + return [ + cls.custom_estimate_statement_cost( + cls.process_statement(statement, database), + client, + ) + for statement in statements + ] @classmethod def get_catalog_names( cls, - database: "Database", + database: Database, inspector: Inspector, ) -> list[str]: """ @@ -469,14 +482,16 @@ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: return True @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: - with cls.get_engine(cursor) as engine: - client = cls._get_client(engine) - job_config = bigquery.QueryJobConfig(dry_run=True) - query_job = client.query( - statement, - job_config=job_config, - ) # Make an API request. + def custom_estimate_statement_cost( + cls, + statement: str, + client: bigquery.Client, + ) -> dict[str, Any]: + """ + Custom version that receives a client instead of a cursor. 
+ """ + job_config = bigquery.QueryJobConfig(dry_run=True) + query_job = client.query(statement, job_config=job_config) # Format Bytes. # TODO: Humanize in case more db engine specs need to be added, @@ -514,7 +529,7 @@ def query_cost_formatter( def build_sqlalchemy_uri( cls, parameters: BigQueryParametersType, - encrypted_extra: Optional[dict[str, Any]] = None, + encrypted_extra: dict[str, Any] | None = None, ) -> str: query = parameters.get("query", {}) query_params = urllib.parse.urlencode(query) @@ -536,7 +551,7 @@ def build_sqlalchemy_uri( def get_parameters_from_uri( cls, uri: str, - encrypted_extra: Optional[dict[str, Any]] = None, + encrypted_extra: dict[str, Any] | None = None, ) -> Any: value = make_url_safe(uri) @@ -549,7 +564,7 @@ def get_parameters_from_uri( raise ValidationError("Invalid service credentials") @classmethod - def mask_encrypted_extra(cls, encrypted_extra: Optional[str]) -> Optional[str]: + def mask_encrypted_extra(cls, encrypted_extra: str | None) -> str | None: if encrypted_extra is None: return encrypted_extra @@ -563,9 +578,7 @@ def mask_encrypted_extra(cls, encrypted_extra: Optional[str]) -> Optional[str]: return json.dumps(config) @classmethod - def unmask_encrypted_extra( - cls, old: Optional[str], new: Optional[str] - ) -> Optional[str]: + def unmask_encrypted_extra(cls, old: str | None, new: str | None) -> str | None: """ Reuse ``private_key`` if available and unchanged. """ @@ -628,15 +641,14 @@ def parameters_json_schema(cls) -> Any: @classmethod def select_star( # pylint: disable=too-many-arguments cls, - database: "Database", - table_name: str, + database: Database, + table: Table, engine: Engine, - schema: Optional[str] = None, limit: int = 100, show_cols: bool = False, indent: bool = True, latest_partition: bool = True, - cols: Optional[list[ResultSetColumnType]] = None, + cols: list[ResultSetColumnType] | None = None, ) -> str: """ Remove array structures from `SELECT *`. @@ -690,9 +702,8 @@ def select_star( # pylint: disable=too-many-arguments return super().select_star( database, - table_name, + table, engine, - schema, limit, show_cols, indent, diff --git a/superset/db_engine_specs/db2.py b/superset/db_engine_specs/db2.py index db2e500b53d8f..b2151767d2d72 100644 --- a/superset/db_engine_specs/db2.py +++ b/superset/db_engine_specs/db2.py @@ -21,6 +21,7 @@ from superset.constants import TimeGrain from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.sql_parse import Table logger = logging.getLogger(__name__) @@ -64,7 +65,9 @@ def epoch_to_dttm(cls) -> str: @classmethod def get_table_comment( - cls, inspector: Inspector, table_name: str, schema: Union[str, None] + cls, + inspector: Inspector, + table: Table, ) -> Optional[str]: """ Get comment of table from a given schema @@ -72,13 +75,12 @@ def get_table_comment( Ibm Db2 return comments as tuples, so we need to get the first element :param inspector: SqlAlchemy Inspector instance - :param table_name: Table name - :param schema: Schema name. 
If omitted, uses default schema for database + :param table: Table instance :return: comment of table """ comment = None try: - table_comment = inspector.get_table_comment(table_name, schema) + table_comment = inspector.get_table_comment(table.table, table.schema) comment = table_comment.get("text") return comment[0] except IndexError: diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index 4aed2693f4676..7606e93b5009f 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -142,7 +142,10 @@ def get_extra_table_metadata( database: Database, table: Table, ) -> dict[str, Any]: - with database.get_raw_connection(schema=table.schema) as conn: + with database.get_raw_connection( + catalog=table.catalog, + schema=table.schema, + ) as conn: cursor = conn.cursor() cursor.execute(f'SELECT GET_METADATA("{table.table}")') results = cursor.fetchone()[0] @@ -395,7 +398,11 @@ def df_to_sql( # pylint: disable=too-many-locals pass # get the Google session from the Shillelagh adapter - with cls.get_engine(database) as engine: + with cls.get_engine( + database, + catalog=table.catalog, + schema=table.schema, + ) as engine: with engine.connect() as conn: # any GSheets URL will work to get a working session adapter = get_adapter_for_table_name( diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index a10f5f66bc9f4..80892b59877de 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -205,7 +205,11 @@ def df_to_sql( if table_exists: raise SupersetException("Table already exists") elif to_sql_kwargs["if_exists"] == "replace": - with cls.get_engine(database) as engine: + with cls.get_engine( + database, + catalog=table.catalog, + schema=table.schema, + ) as engine: engine.execute(f"DROP TABLE IF EXISTS {str(table)}") def _get_hive_type(dtype: np.dtype[Any]) -> str: @@ -227,7 +231,11 @@ def _get_hive_type(dtype: np.dtype[Any]) -> str: ) as file: pq.write_table(pa.Table.from_pandas(df), where=file.name) - with cls.get_engine(database) as engine: + with cls.get_engine( + database, + catalog=table.catalog, + schema=table.schema, + ) as engine: engine.execute( text( f""" @@ -410,24 +418,24 @@ def handle_cursor( # pylint: disable=too-many-locals def get_columns( cls, inspector: Inspector, - table_name: str, - schema: str | None, + table: Table, options: dict[str, Any] | None = None, ) -> list[ResultSetColumnType]: - return BaseEngineSpec.get_columns(inspector, table_name, schema, options) + return BaseEngineSpec.get_columns(inspector, table, options) @classmethod - def where_latest_partition( # pylint: disable=too-many-arguments + def where_latest_partition( cls, - table_name: str, - schema: str | None, database: Database, + table: Table, query: Select, columns: list[ResultSetColumnType] | None = None, ) -> Select | None: try: col_names, values = cls.latest_partition( - table_name, schema, database, show_first=True + database, + table, + show_first=True, ) except Exception: # pylint: disable=broad-except # table is not partitioned @@ -447,7 +455,10 @@ def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[ColumnClause]: @classmethod def latest_sub_partition( # type: ignore - cls, table_name: str, schema: str | None, database: Database, **kwargs: Any + cls, + database: Database, + table: Table, + **kwargs: Any, ) -> str: # TODO(bogdan): implement` pass @@ -465,24 +476,24 @@ def _latest_partition_from_df(cls, df: pd.DataFrame) -> list[str] | None: @classmethod def 
_partition_query( # pylint: disable=too-many-arguments cls, - table_name: str, - schema: str | None, + table: Table, indexes: list[dict[str, Any]], database: Database, limit: int = 0, order_by: list[tuple[str, bool]] | None = None, filters: dict[Any, Any] | None = None, ) -> str: - full_table_name = f"{schema}.{table_name}" if schema else table_name + full_table_name = ( + f"{table.schema}.{table.table}" if table.schema else table.table + ) return f"SHOW PARTITIONS {full_table_name}" @classmethod def select_star( # pylint: disable=too-many-arguments cls, database: Database, - table_name: str, + table: Table, engine: Engine, - schema: str | None = None, limit: int = 100, show_cols: bool = False, indent: bool = True, @@ -491,9 +502,8 @@ def select_star( # pylint: disable=too-many-arguments ) -> str: return super(PrestoEngineSpec, cls).select_star( database, - table_name, + table, engine, - schema, limit, show_cols, indent, diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 8a803d3f14018..34c47eb522c00 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -420,8 +420,7 @@ def get_function_names(cls, database: Database) -> list[str]: @classmethod def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unused-argument cls, - table_name: str, - schema: str | None, + table: Table, indexes: list[dict[str, Any]], database: Database, limit: int = 0, @@ -434,8 +433,7 @@ def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unus Note the unused arguments are exposed for sub-classing purposes where custom integrations may require the schema, indexes, etc. to build the partition query. - :param table_name: the name of the table to get partitions from - :param schema: the schema name + :param table: the table instance :param indexes: the indexes associated with the table :param database: the database the query will be run against :param limit: the number of partitions to be returned @@ -464,12 +462,16 @@ def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unus presto_version = database.get_extra().get("version") if presto_version and Version(presto_version) < Version("0.199"): - full_table_name = f"{schema}.{table_name}" if schema else table_name + full_table_name = ( + f"{table.schema}.{table.table}" if table.schema else table.table + ) partition_select_clause = f"SHOW PARTITIONS FROM {full_table_name}" else: - system_table_name = f'"{table_name}$partitions"' + system_table_name = f'"{table.table}$partitions"' full_table_name = ( - f"{schema}.{system_table_name}" if schema else system_table_name + f"{table.schema}.{system_table_name}" + if table.schema + else system_table_name ) partition_select_clause = f"SELECT * FROM {full_table_name}" @@ -484,18 +486,15 @@ def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unus return sql @classmethod - def where_latest_partition( # pylint: disable=too-many-arguments + def where_latest_partition( cls, - table_name: str, - schema: str | None, database: Database, + table: Table, query: Select, columns: list[ResultSetColumnType] | None = None, ) -> Select | None: try: - col_names, values = cls.latest_partition( - table_name, schema, database, show_first=True - ) + col_names, values = cls.latest_partition(database, table, show_first=True) except Exception: # pylint: disable=broad-except # table is not partitioned return None @@ -527,18 +526,16 @@ def _latest_partition_from_df(cls, df: pd.DataFrame) -> 
list[str] | None: @classmethod @cache_manager.data_cache.memoize(timeout=60) - def latest_partition( # pylint: disable=too-many-arguments + def latest_partition( cls, - table_name: str, - schema: str | None, database: Database, + table: Table, show_first: bool = False, indexes: list[dict[str, Any]] | None = None, ) -> tuple[list[str], list[str] | None]: """Returns col name and the latest (max) partition value for a table - :param table_name: the name of the table - :param schema: schema / database / namespace + :param table: the table instance :param database: database query will be run against :type database: models.Database :param show_first: displays the value for the first partitioning key @@ -550,11 +547,11 @@ def latest_partition( # pylint: disable=too-many-arguments (['ds'], ('2018-01-01',)) """ if indexes is None: - indexes = database.get_indexes(table_name, schema) + indexes = database.get_indexes(table) if not indexes: raise SupersetTemplateException( - f"Error getting partition for {schema}.{table_name}. " + f"Error getting partition for {table}. " "Verify that this table has a partition." ) @@ -575,20 +572,23 @@ def latest_partition( # pylint: disable=too-many-arguments return column_names, cls._latest_partition_from_df( df=database.get_df( sql=cls._partition_query( - table_name, - schema, + table, indexes, database, limit=1, order_by=[(column_name, True) for column_name in column_names], ), - schema=schema, + catalog=table.catalog, + schema=table.schema, ) ) @classmethod def latest_sub_partition( - cls, table_name: str, schema: str | None, database: Database, **kwargs: Any + cls, + database: Database, + table: Table, + **kwargs: Any, ) -> Any: """Returns the latest (max) partition value for a table @@ -601,12 +601,9 @@ def latest_sub_partition( ``latest_sub_partition('my_table', event_category='page', event_type='click')`` - :param table_name: the name of the table, can be just the table - name or a fully qualified table name as ``schema_name.table_name`` - :type table_name: str - :param schema: schema / database / namespace - :type schema: str :param database: database query will be run against + :param table: the table instance + :type table: Table :type database: models.Database :param kwargs: keyword arguments define the filtering criteria @@ -615,7 +612,7 @@ def latest_sub_partition( >>> latest_sub_partition('sub_partition_table', event_type='click') '2018-01-01' """ - indexes = database.get_indexes(table_name, schema) + indexes = database.get_indexes(table) part_fields = indexes[0]["column_names"] for k in kwargs.keys(): # pylint: disable=consider-iterating-dictionary if k not in k in part_fields: # pylint: disable=comparison-with-itself @@ -633,15 +630,14 @@ def latest_sub_partition( field_to_return = field sql = cls._partition_query( - table_name, - schema, + table, indexes, database, limit=1, order_by=[(field_to_return, True)], filters=kwargs, ) - df = database.get_df(sql, schema) + df = database.get_df(sql, table.catalog, table.schema) if df.empty: return "" return df.to_dict()[field_to_return][0] @@ -966,40 +962,39 @@ def _parse_structural_column( # pylint: disable=too-many-locals @classmethod def _show_columns( - cls, inspector: Inspector, table_name: str, schema: str | None + cls, + inspector: Inspector, + table: Table, ) -> list[ResultRow]: """ Show presto column names :param inspector: object that performs database schema inspection - :param table_name: table name - :param schema: schema name + :param table: table instance :return: list of column objects """ 
quote = inspector.engine.dialect.identifier_preparer.quote_identifier - full_table = quote(table_name) - if schema: - full_table = f"{quote(schema)}.{full_table}" + full_table = quote(table.table) + if table.schema: + full_table = f"{quote(table.schema)}.{full_table}" return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall() @classmethod def get_columns( cls, inspector: Inspector, - table_name: str, - schema: str | None, + table: Table, options: dict[str, Any] | None = None, ) -> list[ResultSetColumnType]: """ Get columns from a Presto data source. This includes handling row and array data types :param inspector: object that performs database schema inspection - :param table_name: table name - :param schema: schema name + :param table: table instance :param options: Extra configuration options, not used by this backend :return: a list of results that contain column info (i.e. column name and data type) """ - columns = cls._show_columns(inspector, table_name, schema) + columns = cls._show_columns(inspector, table) result: list[ResultSetColumnType] = [] for column in columns: # parse column if it is a row or array @@ -1077,9 +1072,8 @@ def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[ColumnClause]: def select_star( # pylint: disable=too-many-arguments cls, database: Database, - table_name: str, + table: Table, engine: Engine, - schema: str | None = None, limit: int = 100, show_cols: bool = False, indent: bool = True, @@ -1102,9 +1096,8 @@ def select_star( # pylint: disable=too-many-arguments ] return super().select_star( database, - table_name, + table, engine, - schema, limit, show_cols, indent, @@ -1232,11 +1225,10 @@ def get_extra_table_metadata( ) -> dict[str, Any]: metadata = {} - if indexes := database.get_indexes(table.table, table.schema): + if indexes := database.get_indexes(table): col_names, latest_parts = cls.latest_partition( - table.table, - table.schema, database, + table, show_first=True, indexes=indexes, ) @@ -1248,8 +1240,7 @@ def get_extra_table_metadata( "cols": sorted(indexes[0].get("column_names", [])), "latest": dict(zip(col_names, latest_parts)), "partitionQuery": cls._partition_query( - table_name=table.table, - schema=table.schema, + table=table, indexes=indexes, database=database, ), diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 32c119ec1f645..08a38894e6645 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -40,12 +40,12 @@ ) from superset.db_engine_specs.presto import PrestoBaseEngineSpec from superset.models.sql_lab import Query +from superset.sql_parse import Table from superset.superset_typing import ResultSetColumnType from superset.utils import core as utils if TYPE_CHECKING: from superset.models.core import Database - from superset.sql_parse import Table with contextlib.suppress(ImportError): # trino may not be installed from trino.dbapi import Cursor @@ -66,11 +66,10 @@ def get_extra_table_metadata( ) -> dict[str, Any]: metadata = {} - if indexes := database.get_indexes(table.table, table.schema): + if indexes := database.get_indexes(table): col_names, latest_parts = cls.latest_partition( - table.table, - table.schema, database, + table, show_first=True, indexes=indexes, ) @@ -91,15 +90,17 @@ def get_extra_table_metadata( ), "latest": dict(zip(col_names, latest_parts)), "partitionQuery": cls._partition_query( - table_name=table.table, - schema=table.schema, + table=table, indexes=indexes, database=database, ), } - if 
database.has_view_by_name(table.table, table.schema): - with database.get_inspector_with_context() as inspector: + if database.has_view(Table(table.table, table.schema)): + with database.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: metadata["view"] = inspector.get_view_definition( table.table, table.schema, @@ -414,8 +415,7 @@ def _expand_columns(cls, col: ResultSetColumnType) -> list[ResultSetColumnType]: def get_columns( cls, inspector: Inspector, - table_name: str, - schema: str | None, + table: Table, options: dict[str, Any] | None = None, ) -> list[ResultSetColumnType]: """ @@ -423,7 +423,7 @@ def get_columns( "schema_options", expand the schema definition out to show all subfields of nested ROWs as their appropriate dotted paths. """ - base_cols = super().get_columns(inspector, table_name, schema, options) + base_cols = super().get_columns(inspector, table, options) if not (options or {}).get("expand_rows"): return base_cols @@ -434,8 +434,7 @@ def get_indexes( cls, database: Database, inspector: Inspector, - table_name: str, - schema: str | None, + table: Table, ) -> list[dict[str, Any]]: """ Get the indexes associated with the specified schema/table. @@ -444,11 +443,10 @@ def get_indexes( :param database: The database to inspect :param inspector: The SQLAlchemy inspector - :param table_name: The table to inspect - :param schema: The schema to inspect + :param table: The table instance to inspect :returns: The indexes """ try: - return super().get_indexes(database, inspector, table_name, schema) + return super().get_indexes(database, inspector, table) except NoSuchTableError: return [] diff --git a/superset/examples/bart_lines.py b/superset/examples/bart_lines.py index 9ce27d4952304..efbb83020156b 100644 --- a/superset/examples/bart_lines.py +++ b/superset/examples/bart_lines.py @@ -21,6 +21,7 @@ from sqlalchemy import inspect, String, Text from superset import db +from superset.sql_parse import Table from ..utils.database import get_example_database from .helpers import get_example_url, get_table_connector_registry @@ -31,7 +32,7 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None: database = get_example_database() with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + table_exists = database.has_table(Table(tbl_name, schema)) if not only_metadata and (not table_exists or force): url = get_example_url("bart-lines.json.gz") diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index 2e711bef290fc..7b7928c53210e 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -27,6 +27,7 @@ from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.slice import Slice +from superset.sql_parse import Table from superset.utils.core import DatasourceType from ..utils.database import get_example_database @@ -95,7 +96,7 @@ def load_birth_names( schema = inspect(engine).default_schema_name tbl_name = "birth_names" - table_exists = database.has_table_by_name(tbl_name, schema=schema) + table_exists = database.has_table(Table(tbl_name, schema)) if not only_metadata and (not table_exists or force): load_data(tbl_name, database, sample=sample) diff --git a/superset/examples/country_map.py b/superset/examples/country_map.py index 59c257bc80b77..1741219470ac3 100644 --- a/superset/examples/country_map.py +++ b/superset/examples/country_map.py @@ 
-24,6 +24,7 @@ from superset import db from superset.connectors.sqla.models import SqlMetric from superset.models.slice import Slice +from superset.sql_parse import Table from superset.utils.core import DatasourceType from .helpers import ( @@ -42,7 +43,7 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + table_exists = database.has_table(Table(tbl_name, schema)) if not only_metadata and (not table_exists or force): url = get_example_url("birth_france_data_for_country_map.csv") diff --git a/superset/examples/energy.py b/superset/examples/energy.py index 16d4eea3741c9..98b444f9db2f6 100644 --- a/superset/examples/energy.py +++ b/superset/examples/energy.py @@ -26,6 +26,7 @@ from superset import db from superset.connectors.sqla.models import SqlMetric from superset.models.slice import Slice +from superset.sql_parse import Table from superset.utils.core import DatasourceType from .helpers import ( @@ -45,7 +46,7 @@ def load_energy( with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + table_exists = database.has_table(Table(tbl_name, schema)) if not only_metadata and (not table_exists or force): url = get_example_url("energy.json.gz") diff --git a/superset/examples/flights.py b/superset/examples/flights.py index 1e22fed468828..4db029519fd8b 100644 --- a/superset/examples/flights.py +++ b/superset/examples/flights.py @@ -19,6 +19,7 @@ import superset.utils.database as database_utils from superset import db +from superset.sql_parse import Table from .helpers import get_example_url, get_table_connector_registry @@ -29,7 +30,7 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None: database = database_utils.get_example_database() with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + table_exists = database.has_table(Table(tbl_name, schema)) if not only_metadata and (not table_exists or force): flight_data_url = get_example_url("flight_data.csv.gz") diff --git a/superset/examples/long_lat.py b/superset/examples/long_lat.py index 95cccadc24089..4f8de31453c18 100644 --- a/superset/examples/long_lat.py +++ b/superset/examples/long_lat.py @@ -24,6 +24,7 @@ import superset.utils.database as database_utils from superset import db from superset.models.slice import Slice +from superset.sql_parse import Table from superset.utils.core import DatasourceType from .helpers import ( @@ -41,7 +42,7 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None database = database_utils.get_example_database() with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + table_exists = database.has_table(Table(tbl_name, schema)) if not only_metadata and (not table_exists or force): url = get_example_url("san_francisco.csv.gz") diff --git a/superset/examples/multiformat_time_series.py b/superset/examples/multiformat_time_series.py index 91799b2c2cf61..979be10686f5a 100644 --- a/superset/examples/multiformat_time_series.py +++ b/superset/examples/multiformat_time_series.py @@ -21,6 +21,7 @@ from superset import app, db from superset.models.slice import Slice +from superset.sql_parse import Table from superset.utils.core import DatasourceType 
from ..utils.database import get_example_database @@ -41,7 +42,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals database = get_example_database() with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + table_exists = database.has_table(Table(tbl_name, schema)) if not only_metadata and (not table_exists or force): url = get_example_url("multiformat_time_series.json.gz") diff --git a/superset/examples/paris.py b/superset/examples/paris.py index cea784be7754d..1cd6c84d92d1d 100644 --- a/superset/examples/paris.py +++ b/superset/examples/paris.py @@ -21,6 +21,7 @@ import superset.utils.database as database_utils from superset import db +from superset.sql_parse import Table from .helpers import get_example_url, get_table_connector_registry @@ -30,7 +31,7 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> database = database_utils.get_example_database() with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + table_exists = database.has_table(Table(tbl_name, schema)) if not only_metadata and (not table_exists or force): url = get_example_url("paris_iris.json.gz") diff --git a/superset/examples/random_time_series.py b/superset/examples/random_time_series.py index 9b5306781dfc2..ec232995fa2e7 100644 --- a/superset/examples/random_time_series.py +++ b/superset/examples/random_time_series.py @@ -21,6 +21,7 @@ import superset.utils.database as database_utils from superset import app, db from superset.models.slice import Slice +from superset.sql_parse import Table from superset.utils.core import DatasourceType from .helpers import ( @@ -39,7 +40,7 @@ def load_random_time_series_data( database = database_utils.get_example_database() with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + table_exists = database.has_table(Table(tbl_name, schema)) if not only_metadata and (not table_exists or force): url = get_example_url("random_time_series.json.gz") diff --git a/superset/examples/sf_population_polygons.py b/superset/examples/sf_population_polygons.py index d97ffd3ae51ed..d4754887c72fa 100644 --- a/superset/examples/sf_population_polygons.py +++ b/superset/examples/sf_population_polygons.py @@ -21,6 +21,7 @@ import superset.utils.database as database_utils from superset import db +from superset.sql_parse import Table from .helpers import get_example_url, get_table_connector_registry @@ -32,7 +33,7 @@ def load_sf_population_polygons( database = database_utils.get_example_database() with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + table_exists = database.has_table(Table(tbl_name, schema)) if not only_metadata and (not table_exists or force): url = get_example_url("sf_population.json.gz") diff --git a/superset/examples/supported_charts_dashboard.py b/superset/examples/supported_charts_dashboard.py index 371f03d18b13a..ae0962fc17245 100644 --- a/superset/examples/supported_charts_dashboard.py +++ b/superset/examples/supported_charts_dashboard.py @@ -26,6 +26,7 @@ from superset.connectors.sqla.models import SqlaTable from superset.models.dashboard import Dashboard from superset.models.slice import Slice +from superset.sql_parse import Table from superset.utils.core import DatasourceType from 
..utils.database import get_example_database @@ -443,7 +444,7 @@ def load_supported_charts_dashboard() -> None: schema = inspect(engine).default_schema_name tbl_name = "birth_names" - table_exists = database.has_table_by_name(tbl_name, schema=schema) + table_exists = database.has_table(Table(tbl_name, schema)) if table_exists: table = get_table_connector_registry() diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py index c98c1fc11ceed..7b6b3749213ca 100644 --- a/superset/examples/world_bank.py +++ b/superset/examples/world_bank.py @@ -37,6 +37,7 @@ ) from superset.models.dashboard import Dashboard from superset.models.slice import Slice +from superset.sql_parse import Table from superset.utils import core as utils from superset.utils.core import DatasourceType @@ -51,7 +52,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s database = superset.utils.database.get_example_database() with database.get_sqla_engine() as engine: schema = inspect(engine).default_schema_name - table_exists = database.has_table_by_name(tbl_name) + table_exists = database.has_table(Table(tbl_name, schema)) if not only_metadata and (not table_exists or force): url = get_example_url("countries.json.gz") diff --git a/superset/extensions/metadb.py b/superset/extensions/metadb.py index 0d33ac97e2be1..2d8444cc9930a 100644 --- a/superset/extensions/metadb.py +++ b/superset/extensions/metadb.py @@ -271,6 +271,8 @@ def __init__( self.catalog = parts.pop(-1) if parts else None if self.catalog: + # TODO (betodealmeida): when SIP-95 is implemented we should check to see if + # the database has multi-catalog enabled, and if so, give access. raise NotImplementedError("Catalogs are not currently supported") # If the table has a single integer primary key we use that as the row ID in order @@ -314,7 +316,8 @@ def _set_columns(self) -> None: # store this callable for later whenever we need an engine self.engine_context = partial( database.get_sqla_engine, - self.schema, + catalog=self.catalog, + schema=self.schema, ) # fetch column names and types diff --git a/superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py b/superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py new file mode 100644 index 0000000000000..ec5733e151044 --- /dev/null +++ b/superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Add catalog column + +Revision ID: 5f57af97bc3f +Revises: d60591c5515f +Create Date: 2024-04-11 15:41:34.663989 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "5f57af97bc3f" +down_revision = "d60591c5515f" + + +def upgrade(): + op.add_column("tables", sa.Column("catalog", sa.String(length=256), nullable=True)) + op.add_column("query", sa.Column("catalog", sa.String(length=256), nullable=True)) + op.add_column( + "saved_query", + sa.Column("catalog", sa.String(length=256), nullable=True), + ) + op.add_column( + "tab_state", + sa.Column("catalog", sa.String(length=256), nullable=True), + ) + op.add_column( + "table_schema", + sa.Column("catalog", sa.String(length=256), nullable=True), + ) + + +def downgrade(): + op.drop_column("table_schema", "catalog") + op.drop_column("tab_state", "catalog") + op.drop_column("saved_query", "catalog") + op.drop_column("query", "catalog") + op.drop_column("tables", "catalog") diff --git a/superset/models/core.py b/superset/models/core.py index 42f6a78244f1d..9a4a1de40376c 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -# pylint: disable=too-many-lines +# pylint: disable=too-many-lines, too-many-arguments + """A collection of ORM sqlalchemy models for Superset""" from __future__ import annotations @@ -46,7 +47,7 @@ Integer, MetaData, String, - Table, + Table as SqlaTable, Text, ) from sqlalchemy.engine import Connection, Dialect, Engine @@ -73,6 +74,7 @@ ) from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.result_set import SupersetResultSet +from superset.sql_parse import Table from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType from superset.utils import cache as cache_util, core as utils from superset.utils.backports import StrEnum @@ -382,13 +384,22 @@ def get_effective_user(self, object_url: URL) -> str | None: ) @contextmanager - def get_sqla_engine( + def get_sqla_engine( # pylint: disable=too-many-arguments self, + catalog: str | None = None, schema: str | None = None, nullpool: bool = True, source: utils.QuerySource | None = None, override_ssh_tunnel: SSHTunnel | None = None, ) -> Engine: + """ + Context manager for a SQLAlchemy engine. + + This method will return a context manager for a SQLAlchemy engine. Using the + context manager (as opposed to the engine directly) is important because we need + to potentially establish SSH tunnels before the connection is created, and clean + them up once the engine is no longer used. + """ from superset.daos.database import ( # pylint: disable=import-outside-toplevel DatabaseDAO, ) @@ -403,7 +414,7 @@ def get_sqla_engine( # if ssh_tunnel is available build engine with information engine_context = ssh_manager_factory.instance.create_tunnel( ssh_tunnel=ssh_tunnel, - sqlalchemy_database_uri=self.sqlalchemy_uri_decrypted, + sqlalchemy_database_uri=sqlalchemy_uri, ) with engine_context as server_context: @@ -415,22 +426,21 @@ def get_sqla_engine( server_context.local_bind_address, ) sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url( - sqlalchemy_uri, server_context + sqlalchemy_uri, + server_context, ) + yield self._get_sqla_engine( + catalog=catalog, schema=schema, nullpool=nullpool, source=source, sqlalchemy_uri=sqlalchemy_uri, ) - # The `get_sqla_engine_with_context` was renamed to `get_sqla_engine`, but we kept a - # reference to the old method to prevent breaking third-party applications. 
- # TODO (betodealmeida): Remove in 5.0 - get_sqla_engine_with_context = get_sqla_engine - def _get_sqla_engine( self, + catalog: str | None = None, schema: str | None = None, nullpool: bool = True, source: utils.QuerySource | None = None, @@ -447,26 +457,10 @@ def _get_sqla_engine( params["poolclass"] = NullPool connect_args = params.get("connect_args", {}) - # The ``adjust_database_uri`` method was renamed to ``adjust_engine_params`` and - # had its signature changed in order to support more DB engine specs. Since DB - # engine specs can be released as 3rd party modules we want to make sure the old - # method is still supported so we don't introduce a breaking change. - if hasattr(self.db_engine_spec, "adjust_database_uri"): - sqlalchemy_url = self.db_engine_spec.adjust_database_uri( - sqlalchemy_url, - schema, - ) - logger.warning( - "DB engine spec %s implements the method `adjust_database_uri`, which is " - "deprecated and will be removed in version 3.0. Please update it to " - "implement `adjust_engine_params` instead.", - self.db_engine_spec, - ) - sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params( uri=sqlalchemy_url, connect_args=connect_args, - catalog=None, + catalog=catalog, schema=schema, ) @@ -532,17 +526,24 @@ def _get_sqla_engine( @contextmanager def get_raw_connection( self, + catalog: str | None = None, schema: str | None = None, nullpool: bool = True, source: utils.QuerySource | None = None, ) -> Connection: with self.get_sqla_engine( - schema=schema, nullpool=nullpool, source=source + catalog=catalog, + schema=schema, + nullpool=nullpool, + source=source, ) as engine: with closing(engine.raw_connection()) as conn: # pre-session queries are used to set the selected schema and, in the # future, the selected catalog - for prequery in self.db_engine_spec.get_prequeries(schema=schema): + for prequery in self.db_engine_spec.get_prequeries( + catalog=catalog, + schema=schema, + ): cursor = conn.cursor() cursor.execute(prequery) @@ -606,14 +607,15 @@ def mutate_sql_based_on_config(self, sql_: str, is_split: bool = False) -> str: ) return sql_ - def get_df( + def get_df( # pylint: disable=too-many-locals self, sql: str, + catalog: str | None = None, schema: str | None = None, mutator: Callable[[pd.DataFrame], None] | None = None, ) -> pd.DataFrame: sqls = self.db_engine_spec.parse_sql(sql) - with self.get_sqla_engine(schema) as engine: + with self.get_sqla_engine(catalog=catalog, schema=schema) as engine: engine_url = engine.url def _log_query(sql: str) -> None: @@ -626,7 +628,7 @@ def _log_query(sql: str) -> None: security_manager, ) - with self.get_raw_connection(schema=schema) as conn: + with self.get_raw_connection(catalog=catalog, schema=schema) as conn: cursor = conn.cursor() df = None for i, sql_ in enumerate(sqls): @@ -653,8 +655,13 @@ def _log_query(sql: str) -> None: return self.post_process_df(df) - def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str: - with self.get_sqla_engine(schema) as engine: + def compile_sqla_query( + self, + qry: Select, + catalog: str | None = None, + schema: str | None = None, + ) -> str: + with self.get_sqla_engine(catalog=catalog, schema=schema) as engine: sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) # pylint: disable=protected-access @@ -665,8 +672,7 @@ def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str: def select_star( # pylint: disable=too-many-arguments self, - table_name: str, - schema: str | None = None, + table: Table, limit: int = 100, 
show_cols: bool = False, indent: bool = True, @@ -674,11 +680,10 @@ def select_star( # pylint: disable=too-many-arguments cols: list[ResultSetColumnType] | None = None, ) -> str: """Generates a ``select *`` statement in the proper dialect""" - with self.get_sqla_engine(schema) as engine: + with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine: return self.db_engine_spec.select_star( self, - table_name, - schema=schema, + table, engine=engine, limit=limit, show_cols=show_cols, @@ -703,6 +708,7 @@ def safe_sqlalchemy_uri(self) -> str: ) def get_all_table_names_in_schema( # pylint: disable=unused-argument self, + catalog: str | None, schema: str, cache: bool = False, cache_timeout: int | None = None, @@ -720,7 +726,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument :return: The table/schema pairs """ try: - with self.get_inspector_with_context() as inspector: + with self.get_inspector(catalog=catalog, schema=schema) as inspector: return { (table, schema) for table in self.db_engine_spec.get_table_names( @@ -738,6 +744,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument ) def get_all_view_names_in_schema( # pylint: disable=unused-argument self, + catalog: str | None, schema: str, cache: bool = False, cache_timeout: int | None = None, @@ -755,7 +762,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument :return: set of views """ try: - with self.get_inspector_with_context() as inspector: + with self.get_inspector(catalog=catalog, schema=schema) as inspector: return { (view, schema) for view in self.db_engine_spec.get_view_names( @@ -768,10 +775,17 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument raise self.db_engine_spec.get_dbapi_mapped_exception(ex) @contextmanager - def get_inspector_with_context( - self, ssh_tunnel: SSHTunnel | None = None + def get_inspector( + self, + catalog: str | None = None, + schema: str | None = None, + ssh_tunnel: SSHTunnel | None = None, ) -> Inspector: - with self.get_sqla_engine(override_ssh_tunnel=ssh_tunnel) as engine: + with self.get_sqla_engine( + catalog=catalog, + schema=schema, + override_ssh_tunnel=ssh_tunnel, + ) as engine: yield sqla.inspect(engine) @cache_util.memoized_func( @@ -780,6 +794,7 @@ def get_inspector_with_context( ) def get_all_schema_names( # pylint: disable=unused-argument self, + catalog: str | None = None, cache: bool = False, cache_timeout: int | None = None, force: bool = False, @@ -796,7 +811,10 @@ def get_all_schema_names( # pylint: disable=unused-argument :return: schema list """ try: - with self.get_inspector_with_context(ssh_tunnel=ssh_tunnel) as inspector: + with self.get_inspector( + catalog=catalog, + ssh_tunnel=ssh_tunnel, + ) as inspector: return self.db_engine_spec.get_schema_names(inspector) except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @@ -848,51 +866,57 @@ def get_encrypted_extra(self) -> dict[str, Any]: def update_params_from_encrypted_extra(self, params: dict[str, Any]) -> None: self.db_engine_spec.update_params_from_encrypted_extra(self, params) - def get_table(self, table_name: str, schema: str | None = None) -> Table: + def get_table(self, table: Table) -> SqlaTable: extra = self.get_extra() meta = MetaData(**extra.get("metadata_params", {})) - with self.get_sqla_engine() as engine: - return Table( - table_name, + with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine: + return SqlaTable( + table.table, meta, - schema=schema or None, + 
schema=table.schema or None, autoload=True, autoload_with=engine, ) - def get_table_comment( - self, table_name: str, schema: str | None = None - ) -> str | None: - with self.get_inspector_with_context() as inspector: - return self.db_engine_spec.get_table_comment(inspector, table_name, schema) - - def get_columns( - self, table_name: str, schema: str | None = None - ) -> list[ResultSetColumnType]: - with self.get_inspector_with_context() as inspector: + def get_table_comment(self, table: Table) -> str | None: + with self.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: + return self.db_engine_spec.get_table_comment(inspector, table) + + def get_columns(self, table: Table) -> list[ResultSetColumnType]: + with self.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: return self.db_engine_spec.get_columns( - inspector, table_name, schema, self.schema_options + inspector, table, self.schema_options ) def get_metrics( self, - table_name: str, - schema: str | None = None, + table: Table, ) -> list[MetricType]: - with self.get_inspector_with_context() as inspector: - return self.db_engine_spec.get_metrics(self, inspector, table_name, schema) - - def get_indexes( - self, table_name: str, schema: str | None = None - ) -> list[dict[str, Any]]: - with self.get_inspector_with_context() as inspector: - return self.db_engine_spec.get_indexes(self, inspector, table_name, schema) - - def get_pk_constraint( - self, table_name: str, schema: str | None = None - ) -> dict[str, Any]: - with self.get_inspector_with_context() as inspector: - pk_constraint = inspector.get_pk_constraint(table_name, schema) or {} + with self.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: + return self.db_engine_spec.get_metrics(self, inspector, table) + + def get_indexes(self, table: Table) -> list[dict[str, Any]]: + with self.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: + return self.db_engine_spec.get_indexes(self, inspector, table) + + def get_pk_constraint(self, table: Table) -> dict[str, Any]: + with self.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: + pk_constraint = inspector.get_pk_constraint(table.table, table.schema) or {} def _convert(value: Any) -> Any: try: @@ -902,11 +926,12 @@ def _convert(value: Any) -> Any: return {key: _convert(value) for key, value in pk_constraint.items()} - def get_foreign_keys( - self, table_name: str, schema: str | None = None - ) -> list[dict[str, Any]]: - with self.get_inspector_with_context() as inspector: - return inspector.get_foreign_keys(table_name, schema) + def get_foreign_keys(self, table: Table) -> list[dict[str, Any]]: + with self.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: + return inspector.get_foreign_keys(table.table, table.schema) def get_schema_access_for_file_upload( # pylint: disable=invalid-name self, @@ -955,36 +980,23 @@ def get_perm(self) -> str: return self.perm # type: ignore def has_table(self, table: Table) -> bool: - with self.get_sqla_engine() as engine: - return engine.has_table(table.table_name, table.schema or None) + with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine: + # do not pass "" as an empty schema; force null + return engine.has_table(table.table, table.schema or None) - def has_table_by_name(self, table_name: str, schema: str | None = None) -> bool: - with self.get_sqla_engine() as engine: - return 
engine.has_table(table_name, schema) - - @classmethod - def _has_view( - cls, - conn: Connection, - dialect: Dialect, - view_name: str, - schema: str | None = None, - ) -> bool: - view_names: list[str] = [] - try: - view_names = dialect.get_view_names(connection=conn, schema=schema) - except Exception: # pylint: disable=broad-except - logger.warning("Has view failed", exc_info=True) - return view_name in view_names - - def has_view(self, view_name: str, schema: str | None = None) -> bool: - with self.get_sqla_engine(schema) as engine: - return engine.run_callable( - self._has_view, engine.dialect, view_name, schema - ) + def has_view(self, table: Table) -> bool: + with self.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine: + connection = engine.connect() + try: + views = engine.dialect.get_view_names( + connection=connection, + schema=table.schema, + ) + except Exception: # pylint: disable=broad-except + logger.warning("Has view failed", exc_info=True) + views = [] - def has_view_by_name(self, view_name: str, schema: str | None = None) -> bool: - return self.has_view(view_name=view_name, schema=schema) + return table.table in views def get_dialect(self) -> Dialect: sqla_url = make_url_safe(self.sqlalchemy_uri_decrypted) diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index caf9006700773..991d2d41a46f9 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -32,7 +32,6 @@ Column, ForeignKey, Integer, - MetaData, String, Table, Text, @@ -214,13 +213,6 @@ def datasources(self) -> set[BaseDatasource]: def charts(self) -> list[str]: return [slc.chart for slc in self.slices] - @property - def sqla_metadata(self) -> None: - # pylint: disable=no-member - with self.get_sqla_engine() as engine: - meta = MetaData(bind=engine) - meta.reflect() - @property def status(self) -> utils.DashboardStatus: if self.published: diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 18d0f2e1d0fda..40a5132c556c6 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -109,6 +109,7 @@ class Query( tab_name = Column(String(256)) sql_editor_id = Column(String(256)) schema = Column(String(256)) + catalog = Column(String(256), nullable=True, default=None) sql = Column(MediumText()) # Query to retrieve the results, # used only in case of select_as_cta_used is true. @@ -386,6 +387,7 @@ class SavedQuery( user_id = Column(Integer, ForeignKey("ab_user.id"), nullable=True) db_id = Column(Integer, ForeignKey("dbs.id"), nullable=True) schema = Column(String(128)) + catalog = Column(String(256), nullable=True, default=None) label = Column(String(256)) description = Column(Text) sql = Column(MediumText()) @@ -474,6 +476,7 @@ class TabState(AuditMixinNullable, ExtraJSONMixin, Model): database_id = Column(Integer, ForeignKey("dbs.id", ondelete="CASCADE")) database = relationship("Database", foreign_keys=[database_id]) schema = Column(String(256)) + catalog = Column(String(256), nullable=True, default=None) # tables that are open in the schema browser and their data previews table_schemas = relationship( @@ -535,6 +538,7 @@ class TableSchema(AuditMixinNullable, ExtraJSONMixin, Model): ) database = relationship("Database", foreign_keys=[database_id]) schema = Column(String(256)) + catalog = Column(String(256), nullable=True, default=None) table = Column(String(256)) # JSON describing the schema, partitions, latest partition, etc. 
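For orientation, the ``superset/models/core.py`` and ``superset/models/sql_lab.py`` changes above replace the separate ``table_name``/``schema`` arguments on the ``Database`` helpers with a single ``Table`` value and thread an optional ``catalog`` through the engine, connection, and inspector context managers. The snippet below is a minimal illustrative sketch of the resulting call pattern, not part of the patch; it assumes a running Superset app context, uses the example database, and the table, schema, and catalog names are only placeholders.

    # Illustrative sketch only (not part of this patch). Assumes a Superset
    # app context and that the example database contains the table used here.
    from superset.sql_parse import Table
    from superset.utils.database import get_example_database

    database = get_example_database()

    # A Table bundles the table name with an optional schema and catalog.
    table = Table("energy_usage")  # e.g. Table("t", "my_schema", "my_catalog")

    # Engine/connection context managers now accept catalog alongside schema.
    with database.get_sqla_engine(catalog=table.catalog, schema=table.schema) as engine:
        print(engine.url)

    # Metadata helpers take the Table instance directly instead of
    # (table_name, schema) pairs.
    if database.has_table(table):
        columns = database.get_columns(table)
        indexes = database.get_indexes(table)
        sql = database.select_star(table, limit=10, latest_partition=False)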
diff --git a/superset/security/manager.py b/superset/security/manager.py index 4da85b7d1f747..a84c0cec0d2a0 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1922,6 +1922,7 @@ def raise_for_access( table: Optional["Table"] = None, viz: Optional["BaseViz"] = None, sql: Optional[str] = None, + catalog: Optional[str] = None, # pylint: disable=unused-argument schema: Optional[str] = None, ) -> None: """ @@ -1934,6 +1935,7 @@ def raise_for_access( :param table: The Superset table (requires database) :param viz: The visualization :param sql: The SQL string (requires database) + :param catalog: Optional catalog name :param schema: Optional schema name :raises SupersetSecurityException: If the user cannot access the resource """ diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 9076136c64f81..3f8c1cc73709c 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -55,6 +55,7 @@ insert_rls_as_subquery, insert_rls_in_predicate, ParsedQuery, + Table, ) from superset.sqllab.limiting_factor import LimitingFactor from superset.sqllab.utils import write_ipc_buffer @@ -470,7 +471,11 @@ def execute_sql_statements( ) ) - with database.get_raw_connection(query.schema, source=QuerySource.SQL_LAB) as conn: + with database.get_raw_connection( + catalog=query.catalog, + schema=query.schema, + source=QuerySource.SQL_LAB, + ) as conn: # Sharing a single connection and cursor across the # execution of all statements (if many) cursor = conn.cursor() @@ -539,8 +544,7 @@ def execute_sql_statements( query.set_extra_json_key("columns", result_set.columns) if query.select_as_cta: query.select_sql = database.select_star( - query.tmp_table_name, - schema=query.tmp_schema_name, + Table(query.tmp_table_name, query.tmp_schema_name), limit=query.limit, show_cols=False, latest_partition=False, @@ -645,7 +649,9 @@ def cancel_query(query: Query) -> bool: return False with query.database.get_sqla_engine( - query.schema, source=QuerySource.SQL_LAB + catalog=query.catalog, + schema=query.schema, + source=QuerySource.SQL_LAB, ) as engine: with closing(engine.raw_connection()) as conn: with closing(conn.cursor()) as cursor: diff --git a/superset/sql_validators/base.py b/superset/sql_validators/base.py index 8344fc9264d64..25f73af0894a1 100644 --- a/superset/sql_validators/base.py +++ b/superset/sql_validators/base.py @@ -14,7 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import Any, Optional + +from __future__ import annotations + +from typing import Any from superset.models.core import Database @@ -25,9 +28,9 @@ class SQLValidationAnnotation: # pylint: disable=too-few-public-methods def __init__( self, message: str, - line_number: Optional[int], - start_column: Optional[int], - end_column: Optional[int], + line_number: int | None, + start_column: int | None, + end_column: int | None, ): self.message = message self.line_number = line_number @@ -52,7 +55,11 @@ class BaseSQLValidator: # pylint: disable=too-few-public-methods @classmethod def validate( - cls, sql: str, schema: Optional[str], database: Database + cls, + sql: str, + catalog: str | None, + schema: str | None, + database: Database, ) -> list[SQLValidationAnnotation]: """Check that the given SQL querystring is valid for the given engine""" raise NotImplementedError diff --git a/superset/sql_validators/postgres.py b/superset/sql_validators/postgres.py index 60c15ca034c27..279520292ea4c 100644 --- a/superset/sql_validators/postgres.py +++ b/superset/sql_validators/postgres.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import re -from typing import Optional from pgsanity.pgsanity import check_string @@ -31,7 +32,11 @@ class PostgreSQLValidator(BaseSQLValidator): # pylint: disable=too-few-public-m @classmethod def validate( - cls, sql: str, schema: Optional[str], database: Database + cls, + sql: str, + catalog: str | None, + schema: str | None, + database: Database, ) -> list[SQLValidationAnnotation]: annotations: list[SQLValidationAnnotation] = [] valid, error = check_string(sql, add_semicolon=True) diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index 4d4d898034ca2..06bee217cf22a 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations + import logging import time from contextlib import closing -from typing import Any, Optional +from typing import Any from superset import app from superset.models.core import Database @@ -47,7 +49,7 @@ def validate_statement( statement: str, database: Database, cursor: Any, - ) -> Optional[SQLValidationAnnotation]: + ) -> SQLValidationAnnotation | None: # pylint: disable=too-many-locals db_engine_spec = database.db_engine_spec parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine) @@ -140,7 +142,11 @@ def validate_statement( @classmethod def validate( - cls, sql: str, schema: Optional[str], database: Database + cls, + sql: str, + catalog: str | None, + schema: str | None, + database: Database, ) -> list[SQLValidationAnnotation]: """ Presto supports query-validation queries by running them with a @@ -155,7 +161,11 @@ def validate( logger.info("Validating %i statement(s)", len(statements)) # todo(hughhh): update this to use new database.get_raw_connection() # this function keeps stalling CI - with database.get_sqla_engine(schema, source=QuerySource.SQL_LAB) as engine: + with database.get_sqla_engine( + catalog=catalog, + schema=schema, + source=QuerySource.SQL_LAB, + ) as engine: # Sharing a single connection and cursor across the # execution of all statements (if many) annotations: list[SQLValidationAnnotation] = [] diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py index fc082ecb45663..5013d0954cf97 100644 --- a/superset/utils/mock_data.py +++ b/superset/utils/mock_data.py @@ -29,12 +29,13 @@ import sqlalchemy.sql.sqltypes import sqlalchemy_utils from flask_appbuilder import Model -from sqlalchemy import Column, inspect, MetaData, Table +from sqlalchemy import Column, inspect, MetaData, Table as DBTable from sqlalchemy.dialects import postgresql from sqlalchemy.sql import func from sqlalchemy.sql.visitors import VisitableType from superset import db +from superset.sql_parse import Table logger = logging.getLogger(__name__) @@ -182,7 +183,7 @@ def add_data( from superset.utils.database import get_example_database database = get_example_database() - table_exists = database.has_table_by_name(table_name) + table_exists = database.has_table(Table(table_name)) with database.get_sqla_engine() as engine: if columns is None: @@ -198,7 +199,7 @@ def add_data( # create table if needed column_objects = get_column_objects(columns) metadata = MetaData() - table = Table(table_name, metadata, *column_objects) + table = DBTable(table_name, metadata, *column_objects) metadata.create_all(engine) if not append: diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py index eba3acf36edd4..7f81081777538 100644 --- a/superset/views/datasource/views.py +++ b/superset/views/datasource/views.py @@ -37,6 +37,7 @@ from superset.daos.datasource import DatasourceDAO from superset.exceptions import SupersetException, SupersetSecurityException from superset.models.core import Database +from superset.sql_parse import Table from superset.superset_typing import FlaskResponse from superset.utils.core import DatasourceType from superset.views.base import ( @@ -180,8 +181,7 @@ def external_metadata_by_name(self, **kwargs: Any) -> FlaskResponse: ) external_metadata = get_physical_table_metadata( database=database, - table_name=params["table_name"], - schema_name=params["schema_name"], + table=Table(params["table_name"], params["schema_name"]), normalize_columns=params.get("normalize_columns") or False, ) except (NoResultFound, 
NoSuchTableError) as ex: diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index 48497b977949d..3bd82211e5da4 100644 --- a/tests/integration_tests/celery_tests.py +++ b/tests/integration_tests/celery_tests.py @@ -121,7 +121,7 @@ def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None: def quote_f(value: Optional[str]): if not value: return value - with get_example_database().get_inspector_with_context() as inspector: + with get_example_database().get_inspector() as inspector: return inspector.engine.dialect.identifier_preparer.quote_identifier(value) diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 8122eac9d4520..6d70a1cd75a3c 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -132,9 +132,7 @@ def get_expected_row_count(self, client_id: str) -> int: def quote_name(self, name: str): if get_main_database().backend in {"presto", "hive"}: - with ( - get_example_database().get_inspector_with_context() as inspector - ): # E: Ne + with get_example_database().get_inspector() as inspector: # E: Ne return inspector.engine.dialect.identifier_preparer.quote_identifier( name ) diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index bcb9aa32919a7..90873d49b9f9e 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -50,6 +50,7 @@ from superset.models.slice import Slice from superset.models.sql_lab import Query from superset.result_set import SupersetResultSet +from superset.sql_parse import Table from superset.utils import core as utils from superset.utils.core import backend from superset.utils.database import get_example_database @@ -1197,14 +1198,11 @@ def test_explore_redirect(self, mock_command: mock.Mock): ) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_has_table_by_name(self): + def test_has_table(self): if backend() in ("sqlite", "mysql"): return example_db = superset.utils.database.get_example_database() - assert ( - example_db.has_table_by_name(table_name="birth_names", schema="public") - is True - ) + assert example_db.has_table(Table("birth_names", "public")) is True @mock.patch("superset.views.core.request") @mock.patch( diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 016c897988476..ad7d71c768568 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -2031,7 +2031,7 @@ def test_database_tables(self): if database.backend == "postgresql": response = json.loads(rv.data.decode("utf-8")) schemas = [ - s[0] for s in database.get_all_table_names_in_schema(schema_name) + s[0] for s in database.get_all_table_names_in_schema(None, schema_name) ] self.assertEqual(response["count"], len(schemas)) for option in response["result"]: diff --git a/tests/integration_tests/databases/commands/upload_test.py b/tests/integration_tests/databases/commands/upload_test.py index 26379aa9769fb..1af85c3ab1fe1 100644 --- a/tests/integration_tests/databases/commands/upload_test.py +++ b/tests/integration_tests/databases/commands/upload_test.py @@ -73,7 +73,7 @@ def _setup_csv_upload(allowed_schemas: list[str] | None = None): yield upload_db = get_upload_db() - with upload_db.get_sqla_engine_with_context() as engine: + with upload_db.get_sqla_engine() as engine: engine.execute(f"DROP 
TABLE IF EXISTS {CSV_UPLOAD_TABLE}") engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}") db.session.delete(upload_db) @@ -107,7 +107,7 @@ def test_csv_upload_with_nulls(): None, CSVReader({"null_values": ["N/A", "None"]}), ).run() - with upload_database.get_sqla_engine_with_context() as engine: + with upload_database.get_sqla_engine() as engine: data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall() assert data == [ ("name1", None, "city1", "1-1-1980"), diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 8d24c2993b7db..c10d589d97fd9 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -773,14 +773,14 @@ def test_create_dataset_validate_tables_exists(self): assert rv.status_code == 422 @patch("superset.models.core.Database.get_columns") - @patch("superset.models.core.Database.has_table_by_name") - @patch("superset.models.core.Database.has_view_by_name") + @patch("superset.models.core.Database.has_table") + @patch("superset.models.core.Database.has_view") @patch("superset.models.core.Database.get_table") def test_create_dataset_validate_view_exists( self, mock_get_table, - mock_has_table_by_name, - mock_has_view_by_name, + mock_has_table, + mock_has_view, mock_get_columns, ): """ @@ -796,13 +796,12 @@ def test_create_dataset_validate_view_exists( } ] - mock_has_table_by_name.return_value = False - mock_has_view_by_name.return_value = True + mock_has_table.return_value = False + mock_has_view.return_value = True mock_get_table.return_value = None example_db = get_example_database() with example_db.get_sqla_engine() as engine: - engine = engine dialect = engine.dialect with patch.object( diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index d7498dc4fee84..c8db1f912ad21 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -30,7 +30,7 @@ from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.db_engine_specs.sqlite import SqliteEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.sql_parse import ParsedQuery +from superset.sql_parse import ParsedQuery, Table from superset.utils.database import get_example_database from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec from tests.integration_tests.test_app import app @@ -238,7 +238,7 @@ def test_get_table_names(self): @pytest.mark.usefixtures("load_energy_table_with_slice") def test_column_datatype_to_string(self): example_db = get_example_database() - sqla_table = example_db.get_table("energy_usage") + sqla_table = example_db.get_table(Table("energy_usage")) dialect = example_db.get_dialect() # TODO: fix column type conversion for presto. 
@@ -540,8 +540,7 @@ def test_get_indexes(): BaseEngineSpec.get_indexes( database=mock.Mock(), inspector=inspector, - table_name="bar", - schema="foo", + table=Table("bar", "foo"), ) == indexes ) diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py index ce184685db540..53f9137076bb8 100644 --- a/tests/integration_tests/db_engine_specs/bigquery_tests.py +++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py @@ -165,8 +165,7 @@ def test_get_indexes(self): BigQueryEngineSpec.get_indexes( database, inspector, - table_name, - schema, + Table(table_name, schema), ) == [] ) @@ -184,8 +183,7 @@ def test_get_indexes(self): assert BigQueryEngineSpec.get_indexes( database, inspector, - table_name, - schema, + Table(table_name, schema), ) == [ { "name": "partition", @@ -207,8 +205,7 @@ def test_get_indexes(self): assert BigQueryEngineSpec.get_indexes( database, inspector, - table_name, - schema, + Table(table_name, schema), ) == [ { "name": "partition", diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py index 39d2c30fd1162..4d1a84508167b 100644 --- a/tests/integration_tests/db_engine_specs/hive_tests.py +++ b/tests/integration_tests/db_engine_specs/hive_tests.py @@ -23,7 +23,7 @@ from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3 from superset.exceptions import SupersetException -from superset.sql_parse import Table, ParsedQuery +from superset.sql_parse import ParsedQuery, Table from tests.integration_tests.test_app import app @@ -328,7 +328,10 @@ def test_where_latest_partition(mock_method): columns = [{"name": "ds"}, {"name": "hour"}] with app.app_context(): result = HiveEngineSpec.where_latest_partition( - "test_table", "test_schema", database, select(), columns + database, + Table("test_table", "test_schema"), + select(), + columns, ) query_result = str(result.compile(compile_kwargs={"literal_binds": True})) assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result @@ -341,7 +344,10 @@ def test_where_latest_partition_super_method_exception(mock_method): columns = [{"name": "ds"}, {"name": "hour"}] with app.app_context(): result = HiveEngineSpec.where_latest_partition( - "test_table", "test_schema", database, select(), columns + database, + Table("test_table", "test_schema"), + select(), + columns, ) assert result is None mock_method.assert_called() @@ -353,7 +359,9 @@ def test_where_latest_partition_no_columns_no_values(mock_method): db = mock.Mock() with app.app_context(): result = HiveEngineSpec.where_latest_partition( - "test_table", "test_schema", db, select() + db, + Table("test_table", "test_schema"), + select(), ) assert result is None diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index 0f4841fb3563d..708b94987677d 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -530,7 +530,7 @@ def test_get_catalog_names(app_context: AppContext) -> None: if database.backend != "postgresql": return - with database.get_inspector_with_context() as inspector: + with database.get_inspector() as inspector: assert PostgresEngineSpec.get_catalog_names(database, inspector) == [ "postgres", "superset", diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index 
02669a162fd23..607afa6953fcd 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -82,7 +82,7 @@ def verify_presto_column(self, column, expected_results): row = mock.Mock() row.Column, row.Type, row.Null = column inspector.bind.execute.return_value.fetchall = mock.Mock(return_value=[row]) - results = PrestoEngineSpec.get_columns(inspector, "", "") + results = PrestoEngineSpec.get_columns(inspector, Table("", "")) self.assertEqual(len(expected_results), len(results)) for expected_result, result in zip(expected_results, results): self.assertEqual(expected_result[0], result["column_name"]) @@ -573,7 +573,10 @@ def test_presto_where_latest_partition(self): db.get_df = mock.Mock(return_value=df) columns = [{"name": "ds"}, {"name": "hour"}] result = PrestoEngineSpec.where_latest_partition( - "test_table", "test_schema", db, select(), columns + db, + Table("test_table", "test_schema"), + select(), + columns, ) query_result = str(result.compile(compile_kwargs={"literal_binds": True})) self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result) @@ -802,7 +805,7 @@ def test_show_columns(self): return_value=["a", "b"] ) table_name = "table_name" - result = PrestoEngineSpec._show_columns(inspector, table_name, None) + result = PrestoEngineSpec._show_columns(inspector, Table(table_name)) assert result == ["a", "b"] inspector.bind.execute.assert_called_once_with( f'SHOW COLUMNS FROM "{table_name}"' @@ -818,7 +821,7 @@ def test_show_columns_with_schema(self): ) table_name = "table_name" schema = "schema" - result = PrestoEngineSpec._show_columns(inspector, table_name, schema) + result = PrestoEngineSpec._show_columns(inspector, Table(table_name, schema)) assert result == ["a", "b"] inspector.bind.execute.assert_called_once_with( f'SHOW COLUMNS FROM "{schema}"."{table_name}"' @@ -846,9 +849,16 @@ def test_select_star_no_presto_expand_data(self, mock_select_star): {"col1": "val1"}, {"col2": "val2"}, ] - PrestoEngineSpec.select_star(database, table_name, engine, cols=cols) + PrestoEngineSpec.select_star(database, Table(table_name), engine, cols=cols) mock_select_star.assert_called_once_with( - database, table_name, engine, None, 100, False, True, True, cols + database, + Table(table_name), + engine, + 100, + False, + True, + True, + cols, ) @mock.patch("superset.db_engine_specs.presto.is_feature_enabled") @@ -869,13 +879,16 @@ def test_select_star_presto_expand_data( {"column_name": ".val2."}, ] PrestoEngineSpec.select_star( - database, table_name, engine, show_cols=True, cols=cols + database, + Table(table_name), + engine, + show_cols=True, + cols=cols, ) mock_select_star.assert_called_once_with( database, - table_name, + Table(table_name), engine, - None, 100, True, True, @@ -1172,7 +1185,7 @@ def test_get_catalog_names(app_context: AppContext) -> None: if database.backend != "presto": return - with database.get_inspector_with_context() as inspector: + with database.get_inspector() as inspector: assert PrestoEngineSpec.get_catalog_names(database, inspector) == [ "jmx", "memory", diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index 7c3bc15c39a56..b68cb7c05f559 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -39,6 +39,7 @@ from superset.common.db_query_status import QueryStatus from superset.models.core import Database from superset.models.slice import Slice +from superset.sql_parse import Table from 
superset.utils.database import get_example_database from .base_tests import SupersetTestCase @@ -294,14 +295,14 @@ def test_impersonate_user_hive(self, mocked_create_engine): def test_select_star(self): db = get_example_database() table_name = "energy_usage" - sql = db.select_star(table_name, show_cols=False, latest_partition=False) + sql = db.select_star(Table(table_name), show_cols=False, latest_partition=False) with db.get_sqla_engine() as engine: quote = engine.dialect.identifier_preparer.quote_identifier source = quote(table_name) if db.backend in {"presto", "hive"} else table_name expected = f"SELECT\n *\nFROM {source}\nLIMIT 100" assert expected in sql - sql = db.select_star(table_name, show_cols=True, latest_partition=False) + sql = db.select_star(Table(table_name), show_cols=True, latest_partition=False) # TODO(bkyryliuk): unify sql generation if db.backend == "presto": assert ( @@ -324,7 +325,9 @@ def test_select_star_fully_qualified_names(self): schema = "schema.name" table_name = "table/name" sql = db.select_star( - table_name, schema=schema, show_cols=False, latest_partition=False + Table(table_name, schema), + show_cols=False, + latest_partition=False, ) fully_qualified_names = { "sqlite": '"schema.name"."table/name"', diff --git a/tests/integration_tests/sql_validator_tests.py b/tests/integration_tests/sql_validator_tests.py index 12cb530582e3b..c286bf3a438bd 100644 --- a/tests/integration_tests/sql_validator_tests.py +++ b/tests/integration_tests/sql_validator_tests.py @@ -56,7 +56,7 @@ def test_validator_success(self, flask_g): sql = "SELECT 1 FROM default.notarealtable" schema = "default" - errors = self.validator.validate(sql, schema, self.database) + errors = self.validator.validate(sql, None, schema, self.database) self.assertEqual([], errors) @@ -70,7 +70,7 @@ def test_validator_db_error(self, flask_g): fetch_fn.side_effect = DatabaseError("dummy db error") with self.assertRaises(PrestoSQLValidationError): - self.validator.validate(sql, schema, self.database) + self.validator.validate(sql, None, schema, self.database) @patch("superset.utils.core.g") def test_validator_unexpected_error(self, flask_g): @@ -82,7 +82,7 @@ def test_validator_unexpected_error(self, flask_g): fetch_fn.side_effect = Exception("a mysterious failure") with self.assertRaises(Exception): - self.validator.validate(sql, schema, self.database) + self.validator.validate(sql, None, schema, self.database) @patch("superset.utils.core.g") def test_validator_query_error(self, flask_g): @@ -93,7 +93,7 @@ def test_validator_query_error(self, flask_g): fetch_fn = self.database.db_engine_spec.fetch_data fetch_fn.side_effect = DatabaseError(self.PRESTO_ERROR_TEMPLATE) - errors = self.validator.validate(sql, schema, self.database) + errors = self.validator.validate(sql, None, schema, self.database) self.assertEqual(1, len(errors)) @@ -105,7 +105,10 @@ def test_valid_syntax(self): mock_database = MagicMock() annotations = PostgreSQLValidator.validate( - sql='SELECT 1, "col" FROM "table"', schema="", database=mock_database + sql='SELECT 1, "col" FROM "table"', + catalog=None, + schema="", + database=mock_database, ) assert annotations == [] @@ -115,7 +118,10 @@ def test_invalid_syntax(self): mock_database = MagicMock() annotations = PostgreSQLValidator.validate( - sql='SELECT 1, "col"\nFROOM "table"', schema="", database=mock_database + sql='SELECT 1, "col"\nFROOM "table"', + catalog=None, + schema="", + database=mock_database, ) assert len(annotations) == 1 diff --git a/tests/unit_tests/dao/dataset_test.py 
b/tests/unit_tests/dao/dataset_test.py index 1e3d1ec975022..a2e2b2b39fba6 100644 --- a/tests/unit_tests/dao/dataset_test.py +++ b/tests/unit_tests/dao/dataset_test.py @@ -18,6 +18,7 @@ from sqlalchemy.orm.session import Session from superset.daos.dataset import DatasetDAO +from superset.sql_parse import Table def test_validate_update_uniqueness(session: Session) -> None: @@ -54,9 +55,8 @@ def test_validate_update_uniqueness(session: Session) -> None: assert ( DatasetDAO.validate_update_uniqueness( database_id=database.id, - schema=dataset1.schema, + table=Table(dataset1.table_name, dataset1.schema), dataset_id=dataset1.id, - name=dataset1.table_name, ) is True ) @@ -65,9 +65,8 @@ def test_validate_update_uniqueness(session: Session) -> None: assert ( DatasetDAO.validate_update_uniqueness( database_id=database.id, - schema=dataset2.schema, + table=Table(dataset1.table_name, dataset2.schema), dataset_id=dataset1.id, - name=dataset1.table_name, ) is False ) @@ -76,9 +75,8 @@ def test_validate_update_uniqueness(session: Session) -> None: assert ( DatasetDAO.validate_update_uniqueness( database_id=database.id, - schema=None, + table=Table(dataset1.table_name), dataset_id=dataset1.id, - name=dataset1.table_name, ) is True ) diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index 2f3c11f9a3521..154fd6501c5d1 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -1415,6 +1415,170 @@ def test_excel_upload_file_extension_invalid( assert response.json == {"message": {"file": ["File extension is not allowed."]}} +def test_table_metadata_happy_path( + mocker: MockFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Test the `table_metadata` endpoint. + """ + database = mocker.MagicMock() + database.db_engine_spec.get_table_metadata.return_value = {"hello": "world"} + mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database) + mocker.patch("superset.databases.api.security_manager.raise_for_access") + + response = client.get("/api/v1/database/1/table_metadata/?name=t") + assert response.json == {"hello": "world"} + database.db_engine_spec.get_table_metadata.assert_called_with( + database, + Table("t"), + ) + + response = client.get("/api/v1/database/1/table_metadata/?name=t&schema=s") + database.db_engine_spec.get_table_metadata.assert_called_with( + database, + Table("t", "s"), + ) + + response = client.get("/api/v1/database/1/table_metadata/?name=t&catalog=c") + database.db_engine_spec.get_table_metadata.assert_called_with( + database, + Table("t", None, "c"), + ) + + response = client.get( + "/api/v1/database/1/table_metadata/?name=t&schema=s&catalog=c" + ) + database.db_engine_spec.get_table_metadata.assert_called_with( + database, + Table("t", "s", "c"), + ) + + +def test_table_metadata_no_table( + mocker: MockFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Test the `table_metadata` endpoint when no table name is passed. 
+ """ + database = mocker.MagicMock() + mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database) + + response = client.get("/api/v1/database/1/table_metadata/?schema=s&catalog=c") + assert response.status_code == 422 + assert response.json == { + "errors": [ + { + "message": "An error happened when validating the request", + "error_type": "INVALID_PAYLOAD_SCHEMA_ERROR", + "level": "error", + "extra": { + "messages": {"name": ["Missing data for required field."]}, + "issue_codes": [ + { + "code": 1020, + "message": "Issue 1020 - The submitted payload has the incorrect schema.", + } + ], + }, + } + ] + } + + +def test_table_metadata_slashes( + mocker: MockFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Test the `table_metadata` endpoint with names that have slashes. + """ + database = mocker.MagicMock() + database.db_engine_spec.get_table_metadata.return_value = {"hello": "world"} + mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database) + mocker.patch("superset.databases.api.security_manager.raise_for_access") + + client.get("/api/v1/database/1/table_metadata/?name=foo/bar") + database.db_engine_spec.get_table_metadata.assert_called_with( + database, + Table("foo/bar"), + ) + + +def test_table_metadata_invalid_database( + mocker: MockFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Test the `table_metadata` endpoint when the database is invalid. + """ + mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=None) + + response = client.get("/api/v1/database/1/table_metadata/?name=t") + assert response.status_code == 404 + assert response.json == { + "errors": [ + { + "message": "No such database", + "error_type": "DATABASE_NOT_FOUND_ERROR", + "level": "error", + "extra": { + "issue_codes": [ + { + "code": 1011, + "message": "Issue 1011 - Superset encountered an unexpected error.", + }, + { + "code": 1036, + "message": "Issue 1036 - The database was deleted.", + }, + ] + }, + } + ] + } + + +def test_table_metadata_unauthorized( + mocker: MockFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Test the `table_metadata` endpoint when the user is unauthorized. 
+ """ + database = mocker.MagicMock() + mocker.patch("superset.databases.api.DatabaseDAO.find_by_id", return_value=database) + mocker.patch( + "superset.databases.api.security_manager.raise_for_access", + side_effect=SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.TABLE_SECURITY_ACCESS_ERROR, + message="You don't have access to the table", + level=ErrorLevel.ERROR, + ) + ), + ) + + response = client.get("/api/v1/database/1/table_metadata/?name=t") + assert response.status_code == 404 + assert response.json == { + "errors": [ + { + "message": "No such table", + "error_type": "TABLE_NOT_FOUND_ERROR", + "level": "error", + "extra": None, + } + ] + } + + def test_table_extra_metadata_happy_path( mocker: MockFixture, client: Any, diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index e17e0d2833db8..3bc05ee20eec0 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -232,9 +232,8 @@ class NoLimitDBEngineSpec(BaseEngineSpec): sql = BaseEngineSpec.select_star( database=database, - table_name="my_table", + table=Table("my_table"), engine=engine, - schema=None, limit=100, show_cols=True, indent=True, @@ -252,9 +251,8 @@ class NoLimitDBEngineSpec(BaseEngineSpec): sql = NoLimitDBEngineSpec.select_star( database=database, - table_name="my_table", + table=Table("my_table"), engine=engine, - schema=None, limit=100, show_cols=True, indent=True, diff --git a/tests/unit_tests/db_engine_specs/test_bigquery.py b/tests/unit_tests/db_engine_specs/test_bigquery.py index 663fd7cac8460..616ae668418ce 100644 --- a/tests/unit_tests/db_engine_specs/test_bigquery.py +++ b/tests/unit_tests/db_engine_specs/test_bigquery.py @@ -27,6 +27,7 @@ from sqlalchemy.sql import sqltypes from sqlalchemy_bigquery import BigQueryDialect +from superset.sql_parse import Table from superset.superset_typing import ResultSetColumnType from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm # noqa: F401 @@ -156,9 +157,8 @@ def test_select_star(mocker: MockFixture) -> None: sql = BigQueryEngineSpec.select_star( database=database, - table_name="my_table", + table=Table("my_table"), engine=engine, - schema=None, limit=100, show_cols=True, indent=True, diff --git a/tests/unit_tests/db_engine_specs/test_db2.py b/tests/unit_tests/db_engine_specs/test_db2.py index 6d0d604a25f0c..017fcd7b80e7e 100644 --- a/tests/unit_tests/db_engine_specs/test_db2.py +++ b/tests/unit_tests/db_engine_specs/test_db2.py @@ -18,6 +18,8 @@ import pytest # noqa: F401 from pytest_mock import MockerFixture +from superset.sql_parse import Table + def test_epoch_to_dttm() -> None: """ @@ -43,7 +45,7 @@ def test_get_table_comment(mocker: MockerFixture): } assert ( - Db2EngineSpec.get_table_comment(mock_inspector, "my_table", "my_schema") + Db2EngineSpec.get_table_comment(mock_inspector, Table("my_table", "my_schema")) == "This is a table comment" ) @@ -59,7 +61,8 @@ def test_get_table_comment_empty(mocker: MockerFixture): mock_inspector.get_table_comment.return_value = {} assert ( - Db2EngineSpec.get_table_comment(mock_inspector, "my_table", "my_schema") is None # noqa: E711 + Db2EngineSpec.get_table_comment(mock_inspector, Table("my_table", "my_schema")) + is None ) diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index 8d57d4ed1a8c3..638b377c82709 100644 --- 
a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -24,6 +24,7 @@ from sqlalchemy import sql, text, types from sqlalchemy.engine.url import make_url +from superset.sql_parse import Table from superset.superset_typing import ResultSetColumnType from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( @@ -143,7 +144,10 @@ def test_where_latest_partition( expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}""" result = spec.where_latest_partition( - "table", mock.MagicMock(), mock.MagicMock(), query, columns + mock.MagicMock(), + Table("table"), + query, + columns, ) assert result is not None actual = result.compile( diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 5353578850a39..5bd83828ed2c6 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -311,15 +311,15 @@ def test_convert_dttm( assert_convert_dttm(TrinoEngineSpec, target_type, expected_result, dttm) -def test_get_extra_table_metadata() -> None: +def test_get_extra_table_metadata(mocker: MockerFixture) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec - db_mock = Mock() + db_mock = mocker.MagicMock() db_mock.get_indexes = Mock( return_value=[{"column_names": ["ds", "hour"], "name": "partition"}] ) db_mock.get_extra = Mock(return_value={}) - db_mock.has_view_by_name = Mock(return_value=None) + db_mock.has_view = Mock(return_value=None) db_mock.get_df = Mock(return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})) result = TrinoEngineSpec.get_extra_table_metadata( db_mock, @@ -442,7 +442,7 @@ def test_get_columns(mocker: MockerFixture): mock_inspector = mocker.MagicMock() mock_inspector.get_columns.return_value = sqla_columns - actual = TrinoEngineSpec.get_columns(mock_inspector, "table", "schema") + actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema")) expected = [ ResultSetColumnType( name="field1", column_name="field1", type=field1_type, is_dttm=False @@ -475,7 +475,9 @@ def test_get_columns_expand_rows(mocker: MockerFixture): mock_inspector.get_columns.return_value = sqla_columns actual = TrinoEngineSpec.get_columns( - mock_inspector, "table", "schema", {"expand_rows": True} + mock_inspector, + Table("table", "schema"), + {"expand_rows": True}, ) expected = [ ResultSetColumnType( @@ -538,7 +540,9 @@ def test_get_indexes_no_table(): side_effect=NoSuchTableError("The specified table does not exist.") ) result = TrinoEngineSpec.get_indexes( - db_mock, inspector_mock, "test_table", "test_schema" + db_mock, + inspector_mock, + Table("test_table", "test_schema"), ) assert result == [] diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index beefd3ea3cc5a..ce3ad1822271f 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -18,7 +18,6 @@ # pylint: disable=import-outside-toplevel import json from datetime import datetime -from typing import Optional import pytest from pytest_mock import MockFixture @@ -26,6 +25,7 @@ from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.models.core import Database +from superset.sql_parse import Table def test_get_metrics(mocker: MockFixture) -> None: @@ -37,7 +37,7 @@ def test_get_metrics(mocker: MockFixture) -> None: from superset.models.core import Database database = 
Database(database_name="my_database", sqlalchemy_uri="sqlite://") - assert database.get_metrics("table") == [ + assert database.get_metrics(Table("table")) == [ { "expression": "COUNT(*)", "metric_name": "count", @@ -52,8 +52,7 @@ def get_metrics( cls, database: Database, inspector: Inspector, - table_name: str, - schema: Optional[str], + table: Table, ) -> list[MetricType]: return [ { @@ -65,7 +64,7 @@ def get_metrics( ] database.get_db_engine_spec = mocker.MagicMock(return_value=CustomSqliteEngineSpec) - assert database.get_metrics("table") == [ + assert database.get_metrics(Table("table")) == [ { "expression": "COUNT(DISTINCT user_id)", "metric_name": "count_distinct_user_id",
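The SQL-validator updates earlier in this section add a `catalog` argument in the same spirit; callers without a catalog simply pass `None`. A minimal sketch mirroring the updated Postgres validator test, assuming the optional pgsanity dependency that validator relies on is installed; the database object is only a stand-in, since the pure syntax check never queries it:

    from unittest.mock import MagicMock

    from superset.sql_validators.postgres import PostgreSQLValidator

    annotations = PostgreSQLValidator.validate(
        sql='SELECT 1, "col" FROM "table"',
        catalog=None,  # new argument; None when the backend has no catalogs
        schema="",
        database=MagicMock(),
    )
    assert annotations == []  # valid syntax produces no annotations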
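The table_metadata tests above pin down the new endpoint's contract: the table name, schema, and catalog travel as query-string parameters (which is why the `foo/bar` name in the slashes test round-trips cleanly), a missing `name` yields a 422 validation error, and an unknown database or an inaccessible table yields a 404. A rough client-side sketch; the host and authentication are hypothetical, and only the path and parameters come from the tests:

    import requests

    BASE_URL = "https://superset.example.org"  # hypothetical host
    session = requests.Session()
    session.headers["Authorization"] = "Bearer <access-token>"  # placeholder auth

    # `name` is required; `schema` and `catalog` are optional qualifiers.
    response = session.get(
        f"{BASE_URL}/api/v1/database/1/table_metadata/",
        params={"name": "foo/bar", "schema": "public", "catalog": "examples"},
    )

    if response.status_code == 422:
        # missing or invalid parameters, e.g. no `name` given
        print(response.json()["errors"][0]["extra"]["messages"])
    elif response.status_code == 404:
        # unknown database, or a table the current user is not allowed to see
        print(response.json()["errors"][0]["message"])
    else:
        metadata = response.json()  # whatever the database's engine spec returns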