diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 922c78f21f2c1..72265fbffbb76 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -80,11 +80,13 @@ from superset.common.utils.time_range_utils import get_since_until_from_time_range from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.connectors.sqla.utils import ( + find_cached_objects_in_session, get_columns_description, get_physical_table_metadata, get_virtual_table_metadata, validate_adhoc_subquery, ) +from superset.datasets.models import Dataset as NewDataset from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression from superset.exceptions import ( AdvancedDataTypeResponseError, @@ -2088,6 +2090,21 @@ def update_column( # pylint: disable=unused-argument # table is updated. This busts the cache key for all charts that use the table. session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id)) + # TODO: This shadow writing is deprecated + # if table itself has changed, shadow-writing will happen in `after_update` anyway + if target.table not in session.dirty: + dataset: NewDataset = ( + session.query(NewDataset) + .filter_by(uuid=target.table.uuid) + .one_or_none() + ) + # Update shadow dataset and columns + # did we find the dataset? + if not dataset: + # if dataset is not found create a new copy + target.table.write_shadow_dataset() + return + @staticmethod def after_insert( mapper: Mapper, @@ -2099,6 +2116,9 @@ def after_insert( """ security_manager.dataset_after_insert(mapper, connection, sqla_table) + # TODO: deprecated + sqla_table.write_shadow_dataset() + @staticmethod def after_delete( mapper: Mapper, @@ -2117,11 +2137,53 @@ def after_update( sqla_table: "SqlaTable", ) -> None: """ - Update dataset permissions after update + Update dataset permissions """ # set permissions security_manager.dataset_after_update(mapper, connection, sqla_table) + # TODO: the shadow writing is deprecated + inspector = inspect(sqla_table) + session = inspector.session + + # double-check that ``UPDATE``s are actually pending (this method is called even + # for instances that have no net changes to their column-based attributes) + if not session.is_modified(sqla_table, include_collections=True): + return + + # find the dataset from the known instance list first + # (it could be either from a previous query or newly created) + dataset = next( + find_cached_objects_in_session( + session, NewDataset, uuids=[sqla_table.uuid] + ), + None, + ) + # if not found, pull from database + if not dataset: + dataset = ( + session.query(NewDataset).filter_by(uuid=sqla_table.uuid).one_or_none() + ) + if not dataset: + sqla_table.write_shadow_dataset() + return + + def write_shadow_dataset( + self: "SqlaTable", + ) -> None: + """ + This method is deprecated + """ + session = inspect(self).session + # most of the write_shadow_dataset functionality has been removed + # but leaving this portion in + # to remove later because it is adding a Database relationship to the session + # and there is some functionality that depends on this + if self.database_id and ( + not self.database or self.database.id != self.database_id + ): + self.database = session.query(Database).filter_by(id=self.database_id).one() + sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update) sa.event.listen(SqlaTable, "after_update", SqlaTable.after_update) diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py index aed50574cb989..d260df3610002 100644 --- a/superset/datasets/dao.py +++ b/superset/datasets/dao.py @@ -15,11 +15,9 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional -from flask_appbuilder.models.sqla.interface import SQLAInterface from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import joinedload from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.dao.base import BaseDAO @@ -37,26 +35,6 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods model_cls = SqlaTable base_filter = DatasourceFilter - @classmethod - def find_by_ids(cls, model_ids: Union[List[str], List[int]]) -> List[SqlaTable]: - """ - Find a List of models by a list of ids, if defined applies `base_filter` - """ - id_col = getattr(SqlaTable, cls.id_column_name, None) - if id_col is None: - return [] - - # the joinedload option ensures that the database is - # available in the session later and not lazy loaded - query = ( - db.session.query(SqlaTable) - .options(joinedload(SqlaTable.database)) - .filter(id_col.in_(model_ids)) - ) - data_model = SQLAInterface(SqlaTable, db.session) - query = DatasourceFilter(cls.id_column_name, data_model).apply(query, None) - return query.all() - @staticmethod def get_database_by_id(database_id: int) -> Optional[Database]: try: