From e10b514ee916673c24877de0bae973b0fd3d4117 Mon Sep 17 00:00:00 2001 From: John Bodley Date: Thu, 18 May 2023 12:08:17 -0700 Subject: [PATCH 1/2] fix: Address regression introduced in #22853 --- superset/connectors/sqla/models.py | 98 ++++++++++++-------------- superset/models/core.py | 20 +++++- superset/models/sql_lab.py | 21 +++--- tests/integration_tests/model_tests.py | 5 ++ tests/unit_tests/models/core_test.py | 5 ++ 5 files changed, 82 insertions(+), 67 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index e339f7b1f4c9e..514df85bbac6d 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -69,6 +69,7 @@ backref, Mapped, Query, + reconstructor, relationship, RelationshipProperty, Session, @@ -229,6 +230,30 @@ class TableColumn(Model, BaseColumn, CertificationMixin): update_from_object_fields = [s for s in export_fields if s not in ("table_id",)] export_parent = "table" + def __init__(self, **kwargs: Any) -> None: + """ + Construct a TableColumn object. + + Historically a TableColumn object (from an ORM perspective) was tightly bound to + a SqlaTable object, however with the introduction of the Query datasource this + is no longer true, i.e., the SqlaTable relationship is optional. + + Now the TableColumn is either directly associated with the Database object ( + which is unknown to the ORM) or indirectly via the SqlaTable object (courtesy of + the ORM) depending on the context. + """ + + self._database: Optional[Database] = kwargs.pop("database", None) + super().__init__(**kwargs) + + @reconstructor + def init_on_load(self) -> None: + """ + Construct a TableColumn object when invoked via the SQLAlchemy ORM. 
+ """ + + self._database = None + @property def is_boolean(self) -> bool: """ @@ -262,51 +287,33 @@ def is_temporal(self) -> bool: return self.is_dttm return self.type_generic == GenericDataType.TEMPORAL + @property + def database(self) -> Database: + return self.table.database if self.table else self._database + @property def db_engine_spec(self) -> Type[BaseEngineSpec]: - return self.table.db_engine_spec + return self.database.db_engine_spec @property def db_extra(self) -> Dict[str, Any]: - return self.table.database.get_extra() + return self.database.get_extra() @property def type_generic(self) -> Optional[utils.GenericDataType]: if self.is_dttm: return GenericDataType.TEMPORAL - bool_types = ("BOOL",) - num_types = ( - "DOUBLE", - "FLOAT", - "INT", - "BIGINT", - "NUMBER", - "LONG", - "REAL", - "NUMERIC", - "DECIMAL", - "MONEY", - ) - date_types = ("DATE", "TIME") - str_types = ("VARCHAR", "STRING", "CHAR") - - if self.table is None: - # Query.TableColumns don't have a reference to a table.db_engine_spec - # reference so this logic will manage rendering types - if self.type and any(map(lambda t: t in self.type.upper(), str_types)): - return GenericDataType.STRING - if self.type and any(map(lambda t: t in self.type.upper(), bool_types)): - return GenericDataType.BOOLEAN - if self.type and any(map(lambda t: t in self.type.upper(), num_types)): - return GenericDataType.NUMERIC - if self.type and any(map(lambda t: t in self.type.upper(), date_types)): - return GenericDataType.TEMPORAL - - column_spec = self.db_engine_spec.get_column_spec( - self.type, db_extra=self.db_extra + return ( + column_spec.generic_type # pylint: disable=used-before-assignment + if ( + column_spec := self.db_engine_spec.get_column_spec( + self.type, + db_extra=self.db_extra, + ) + ) + else None ) - return column_spec.generic_type if column_spec else None def get_sqla_col( self, @@ -323,7 +330,7 @@ def get_sqla_col( col = literal_column(expression, type_=type_) else: col = 
column(self.column_name, type_=type_) - col = self.table.make_sqla_column_compatible(col, label) + col = self.database.make_sqla_column_compatible(col, label) return col @property @@ -354,7 +361,7 @@ def get_timestamp_expression( type_ = column_spec.sqla_type if column_spec else DateTime if not self.expression and not time_grain and not is_epoch: sqla_col = column(self.column_name, type_=type_) - return self.table.make_sqla_column_compatible(sqla_col, label) + return self.database.make_sqla_column_compatible(sqla_col, label) if expression := self.expression: if template_processor: expression = template_processor.process_template(expression) @@ -362,7 +369,7 @@ def get_timestamp_expression( else: col = column(self.column_name, type_=type_) time_expr = self.db_engine_spec.get_timestamp_expr(col, pdf, time_grain) - return self.table.make_sqla_column_compatible(time_expr, label) + return self.database.make_sqla_column_compatible(time_expr, label) @property def data(self) -> Dict[str, Any]: @@ -434,7 +441,7 @@ def get_sqla_col( expression = template_processor.process_template(expression) sqla_col: ColumnClause = literal_column(expression) - return self.table.make_sqla_column_compatible(sqla_col, label) + return self.table.database.make_sqla_column_compatible(sqla_col, label) @property def perm(self) -> Optional[str]: @@ -1023,23 +1030,6 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals ) return self.make_sqla_column_compatible(sqla_column, label) - def make_sqla_column_compatible( - self, sqla_col: ColumnElement, label: Optional[str] = None - ) -> ColumnElement: - """Takes a sqlalchemy column object and adds label info if supported by engine. 
- :param sqla_col: sqlalchemy column instance - :param label: alias/label that column is expected to have - :return: either a sql alchemy column or label instance if supported by engine - """ - label_expected = label or sqla_col.name - db_engine_spec = self.db_engine_spec - # add quotes to tables - if db_engine_spec.allows_alias_in_select: - label = db_engine_spec.make_label_compatible(label_expected) - sqla_col = sqla_col.label(label) - sqla_col.key = label_expected - return sqla_col - def make_orderby_compatible( self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement] ) -> None: diff --git a/superset/models/core.py b/superset/models/core.py index 43d12900e613d..0b1859f01d5df 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=line-too-long +# pylint: disable=line-too-long,too-many-lines """A collection of ORM sqlalchemy models for Superset""" import enum import json @@ -52,7 +52,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.pool import NullPool from sqlalchemy.schema import UniqueConstraint -from sqlalchemy.sql import expression, Select +from sqlalchemy.sql import ColumnElement, expression, Select from superset import app, db_engine_specs from superset.constants import LRU_CACHE_MAX_SIZE, PASSWORD_MASK @@ -955,6 +955,22 @@ def get_dialect(self) -> Dialect: sqla_url = make_url_safe(self.sqlalchemy_uri_decrypted) return sqla_url.get_dialect()() + def make_sqla_column_compatible( + self, sqla_col: ColumnElement, label: Optional[str] = None + ) -> ColumnElement: + """Takes a sqlalchemy column object and adds label info if supported by engine. 
+ :param sqla_col: sqlalchemy column instance + :param label: alias/label that column is expected to have + :return: either a sql alchemy column or label instance if supported by engine + """ + label_expected = label or sqla_col.name + # add quotes to tables + if self.db_engine_spec.allows_alias_in_select: + label = self.db_engine_spec.make_label_compatible(label_expected) + sqla_col = sqla_col.label(label) + sqla_col.key = label_expected + return sqla_col + sqla.event.listen(Database, "after_insert", security_manager.database_after_insert) sqla.event.listen(Database, "after_update", security_manager.database_after_update) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index b2f0c8c1ed2f1..b93a2128f3822 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -190,18 +190,17 @@ def columns(self) -> List["TableColumn"]: TableColumn, ) - columns = [] - for col in self.extra.get("columns", []): - columns.append( - TableColumn( - column_name=col["name"], - type=col["type"], - is_dttm=col["is_dttm"], - groupby=True, - filterable=True, - ) + return [ + TableColumn( + column_name=col["name"], + database=self.database, + is_dttm=col["is_dttm"], + filterable=True, + groupby=True, + type=col["type"], ) - return columns + for col in self.extra.get("columns", []) + ] @property def db_extra(self) -> Optional[Dict[str, Any]]: diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index d5684b1b62109..8d0cd8777bd93 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -671,3 +671,8 @@ def test_data_for_slices_with_adhoc_column(self): # clean up and auto commit metadata_db.session.delete(slc) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_table_column_database(self) -> None: + tbl = self.get_table(name="birth_names") + assert tbl.get_column("ds").database is tbl.database # type: ignore diff --git 
a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index bf8f589913086..7f0b443c85770 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -207,3 +207,8 @@ def test_dttm_sql_literal( result: str, ) -> None: assert SqlaTable(database=database).dttm_sql_literal(dttm, col) == result + + +def test_table_column_database() -> None: + database = Database(database_name="db") + assert TableColumn(database=database).database is database From 60cb66a9ef16743eb894b7cafc50acdd797ccbc6 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Mon, 12 Jun 2023 13:04:35 -0700 Subject: [PATCH 2/2] Fix lint --- superset/connectors/sqla/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 253a994ea3040..edde2232056c6 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -232,7 +232,7 @@ def __init__(self, **kwargs: Any) -> None: the ORM) depending on the context. """ - self._database: Optional[Database] = kwargs.pop("database", None) + self._database: Database | None = kwargs.pop("database", None) super().__init__(**kwargs) @reconstructor