diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index d19a8f0282116..90b4dc063e46b 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -22,7 +22,7 @@ from six import string_types
 import sqlalchemy as sa
 from sqlalchemy import (
-    Boolean, Column, DateTime, ForeignKey, Integer, or_, String, Text,
+    Boolean, Column, DateTime, ForeignKey, Integer, or_, String, Text, UniqueConstraint,
 )
 from sqlalchemy.orm import backref, relationship
@@ -169,7 +169,7 @@ def refresh(self, datasource_names, merge_flag, refreshAll):
             if cols:
                 col_objs_list = (
                     session.query(DruidColumn)
-                    .filter(DruidColumn.datasource_name == datasource.datasource_name)
+                    .filter(DruidColumn.datasource_id == datasource.id)
                     .filter(or_(DruidColumn.column_name == col for col in cols))
                 )
                 col_objs = {col.column_name: col for col in col_objs_list}
@@ -179,7 +179,7 @@ def refresh(self, datasource_names, merge_flag, refreshAll):
                     col_obj = col_objs.get(col, None)
                     if not col_obj:
                         col_obj = DruidColumn(
-                            datasource_name=datasource.datasource_name,
+                            datasource_id=datasource.id,
                             column_name=col)
                         with session.no_autoflush:
                             session.add(col_obj)
@@ -220,9 +220,9 @@ class DruidColumn(Model, BaseColumn):
 
     __tablename__ = 'columns'
-    datasource_name = Column(
-        String(255),
-        ForeignKey('datasources.datasource_name'))
+    datasource_id = Column(
+        Integer,
+        ForeignKey('datasources.id'))
     # Setting enable_typechecks=False disables polymorphic inheritance.
     datasource = relationship(
         'DruidDatasource',
@@ -231,7 +231,7 @@ class DruidColumn(Model, BaseColumn):
     dimension_spec_json = Column(Text)
 
     export_fields = (
-        'datasource_name', 'column_name', 'is_active', 'type', 'groupby',
+        'datasource_id', 'column_name', 'is_active', 'type', 'groupby',
         'count_distinct', 'sum', 'avg', 'max', 'min', 'filterable',
         'description', 'dimension_spec_json',
     )
@@ -334,15 +334,14 @@ def generate_metrics(self):
         metrics = self.get_metrics()
         dbmetrics = (
             db.session.query(DruidMetric)
-            .filter(DruidCluster.cluster_name == self.datasource.cluster_name)
-            .filter(DruidMetric.datasource_name == self.datasource_name)
+            .filter(DruidMetric.datasource_id == self.datasource_id)
             .filter(or_(
                 DruidMetric.metric_name == m for m in metrics
             ))
         )
         dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
         for metric in metrics.values():
-            metric.datasource_name = self.datasource_name
+            metric.datasource_id = self.datasource_id
             if not dbmetrics.get(metric.metric_name, None):
                 db.session.add(metric)
@@ -350,7 +349,7 @@ def import_obj(cls, i_column):
         def lookup_obj(lookup_column):
             return db.session.query(DruidColumn).filter(
-                DruidColumn.datasource_name == lookup_column.datasource_name,
+                DruidColumn.datasource_id == lookup_column.datasource_id,
                 DruidColumn.column_name == lookup_column.column_name).first()
         return import_util.import_simple_obj(db.session, i_column, lookup_obj)
@@ -361,9 +360,9 @@ class DruidMetric(Model, BaseMetric):
     """ORM object referencing Druid metrics for a datasource"""
 
     __tablename__ = 'metrics'
-    datasource_name = Column(
-        String(255),
-        ForeignKey('datasources.datasource_name'))
+    datasource_id = Column(
+        Integer,
+        ForeignKey('datasources.id'))
     # Setting enable_typechecks=False disables polymorphic inheritance.
     datasource = relationship(
         'DruidDatasource',
@@ -372,7 +371,7 @@ class DruidMetric(Model, BaseMetric):
     json = Column(Text)
 
     export_fields = (
-        'metric_name', 'verbose_name', 'metric_type', 'datasource_name',
+        'metric_name', 'verbose_name', 'metric_type', 'datasource_id',
         'json', 'description', 'is_restricted', 'd3format',
     )
@@ -400,7 +399,7 @@ def perm(self):
     def import_obj(cls, i_metric):
         def lookup_obj(lookup_metric):
             return db.session.query(DruidMetric).filter(
-                DruidMetric.datasource_name == lookup_metric.datasource_name,
+                DruidMetric.datasource_id == lookup_metric.datasource_id,
                 DruidMetric.metric_name == lookup_metric.metric_name).first()
         return import_util.import_simple_obj(db.session, i_metric, lookup_obj)
@@ -420,7 +419,7 @@ class DruidDatasource(Model, BaseDatasource):
     baselink = 'druiddatasourcemodelview'
 
     # Columns
-    datasource_name = Column(String(255), unique=True)
+    datasource_name = Column(String(255))
     is_hidden = Column(Boolean, default=False)
     fetch_values_from = Column(String(100))
     cluster_name = Column(
@@ -432,6 +431,7 @@ class DruidDatasource(Model, BaseDatasource):
         sm.user_model,
         backref=backref('datasources', cascade='all, delete-orphan'),
         foreign_keys=[user_id])
+    UniqueConstraint('cluster_name', 'datasource_name')
 
     export_fields = (
         'datasource_name', 'is_hidden', 'description', 'default_endpoint',
@@ -519,7 +519,7 @@ def import_obj(cls, i_datasource, import_time=None):
         superset instances. Audit metadata isn't copies over.
         """
         def lookup_datasource(d):
-            return db.session.query(DruidDatasource).join(DruidCluster).filter(
+            return db.session.query(DruidDatasource).filter(
                 DruidDatasource.datasource_name == d.datasource_name,
                 DruidCluster.cluster_name == d.cluster_name,
             ).first()
@@ -620,13 +620,12 @@ def generate_metrics_for(self, columns):
             metrics.update(col.get_metrics())
         dbmetrics = (
             db.session.query(DruidMetric)
-            .filter(DruidCluster.cluster_name == self.cluster_name)
-            .filter(DruidMetric.datasource_name == self.datasource_name)
+            .filter(DruidMetric.datasource_id == self.id)
             .filter(or_(DruidMetric.metric_name == m for m in metrics))
         )
         dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
         for metric in metrics.values():
-            metric.datasource_name = self.datasource_name
+            metric.datasource_id = self.id
             if not dbmetrics.get(metric.metric_name, None):
                 with db.session.no_autoflush:
                     db.session.add(metric)
@@ -661,7 +660,7 @@ def sync_to_db_from_config(
         dimensions = druid_config['dimensions']
         col_objs = (
             session.query(DruidColumn)
-            .filter(DruidColumn.datasource_name == druid_config['name'])
+            .filter(DruidColumn.datasource_id == datasource.id)
            .filter(or_(DruidColumn.column_name == dim for dim in dimensions))
         )
         col_objs = {col.column_name: col for col in col_objs}
@@ -669,7 +668,7 @@ def sync_to_db_from_config(
             col_obj = col_objs.get(dim, None)
             if not col_obj:
                 col_obj = DruidColumn(
-                    datasource_name=druid_config['name'],
+                    datasource_id=datasource.id,
                     column_name=dim,
                     groupby=True,
                     filterable=True,
@@ -681,7 +680,7 @@ def sync_to_db_from_config(
         # Import Druid metrics
         metric_objs = (
             session.query(DruidMetric)
-            .filter(DruidMetric.datasource_name == druid_config['name'])
+            .filter(DruidMetric.datasource_id == datasource.id)
             .filter(or_(DruidMetric.metric_name == spec['name']
                     for spec in druid_config['metrics_spec']))
         )
diff --git a/superset/migrations/versions/4736ec66ce19_.py b/superset/migrations/versions/4736ec66ce19_.py
new file mode 100644
index 0000000000000..2d560d57dfd21
--- /dev/null
+++ b/superset/migrations/versions/4736ec66ce19_.py
@@ -0,0 +1,201 @@
+"""empty message
+
+Revision ID: 4736ec66ce19
+Revises: f959a6652acd
+Create Date: 2017-10-03 14:37:01.376578
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = '4736ec66ce19'
+down_revision = 'f959a6652acd'
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.exc import OperationalError
+
+from superset.utils import (
+    generic_find_fk_constraint_name,
+    generic_find_fk_constraint_names,
+    generic_find_uq_constraint_name,
+)
+
+
+conv = {
+    'fk': 'fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s',
+    'uq': 'uq_%(table_name)s_%(column_0_name)s',
+}
+
+# Helper table for database migrations using minimal schema.
+datasources = sa.Table(
+    'datasources',
+    sa.MetaData(),
+    sa.Column('id', sa.Integer, primary_key=True),
+    sa.Column('datasource_name', sa.String(255)),
+)
+
+bind = op.get_bind()
+insp = sa.engine.reflection.Inspector.from_engine(bind)
+
+
+def upgrade():
+
+    # Add the new less restrictive uniqueness constraint.
+    with op.batch_alter_table('datasources', naming_convention=conv) as batch_op:
+        batch_op.create_unique_constraint(
+            'uq_datasources_cluster_name',
+            ['cluster_name', 'datasource_name'],
+        )
+
+    # Augment the tables which have a foreign key constraint related to the
+    # datasources.datasource_name column.
+    for foreign in ['columns', 'metrics']:
+        with op.batch_alter_table(foreign, naming_convention=conv) as batch_op:
+
+            # Add the datasource_id column with the relevant constraints.
+            batch_op.add_column(sa.Column('datasource_id', sa.Integer))
+
+            batch_op.create_foreign_key(
+                'fk_{}_datasource_id_datasources'.format(foreign),
+                'datasources',
+                ['datasource_id'],
+                ['id'],
+            )
+
+        # Helper table for database migration using minimal schema.
+        table = sa.Table(
+            foreign,
+            sa.MetaData(),
+            sa.Column('id', sa.Integer, primary_key=True),
+            sa.Column('datasource_name', sa.String(255)),
+            sa.Column('datasource_id', sa.Integer),
+        )
+
+        # Migrate the existing data.
+        for datasource in bind.execute(datasources.select()):
+            bind.execute(
+                table.update().where(
+                    table.c.datasource_name == datasource.datasource_name,
+                ).values(
+                    datasource_id=datasource.id,
+                ),
+            )
+
+        with op.batch_alter_table(foreign, naming_convention=conv) as batch_op:
+
+            # Drop the datasource_name column and associated constraints. Note
+            # due to prior revisions (1226819ee0e3, 3b626e2a6783) there may
+            # incorrectly be multiple duplicate constraints.
+            names = generic_find_fk_constraint_names(
+                foreign,
+                {'datasource_name'},
+                'datasources',
+                insp,
+            )
+
+            for name in names:
+                batch_op.drop_constraint(
+                    name or 'fk_{}_datasource_name_datasources'.format(foreign),
+                    type_='foreignkey',
+                )
+
+            batch_op.drop_column('datasource_name')
+
+    # Drop the old more restrictive uniqueness constraint.
+    with op.batch_alter_table('datasources', naming_convention=conv) as batch_op:
+        batch_op.drop_constraint(
+            generic_find_uq_constraint_name(
+                'datasources',
+                {'datasource_name'},
+                insp,
+            ) or 'uq_datasources_datasource_name',
+            type_='unique',
+        )
+
+
+def downgrade():
+
+    # Add the new more restrictive uniqueness constraint which is required by
+    # the foreign key constraints. Note this operation will fail if the
+    # datasources.datasource_name column is no longer unique.
+    with op.batch_alter_table('datasources', naming_convention=conv) as batch_op:
+        batch_op.create_unique_constraint(
+            'uq_datasources_datasource_name',
+            ['datasource_name'],
+        )
+
+    # Augment the tables which have a foreign key constraint related to the
+    # datasources.datasource_id column.
+    for foreign in ['columns', 'metrics']:
+        with op.batch_alter_table(foreign, naming_convention=conv) as batch_op:
+
+            # Add the datasource_name column with the relevant constraints.
+            batch_op.add_column(sa.Column('datasource_name', sa.String(255)))
+
+            batch_op.create_foreign_key(
+                'fk_{}_datasource_name_datasources'.format(foreign),
+                'datasources',
+                ['datasource_name'],
+                ['datasource_name'],
+            )
+
+        # Helper table for database migration using minimal schema.
+        table = sa.Table(
+            foreign,
+            sa.MetaData(),
+            sa.Column('id', sa.Integer, primary_key=True),
+            sa.Column('datasource_name', sa.String(255)),
+            sa.Column('datasource_id', sa.Integer),
+        )
+
+        # Migrate the existing data.
+        for datasource in bind.execute(datasources.select()):
+            bind.execute(
+                table.update().where(
+                    table.c.datasource_id == datasource.id,
+                ).values(
+                    datasource_name=datasource.datasource_name,
+                ),
+            )
+
+        with op.batch_alter_table(foreign, naming_convention=conv) as batch_op:
+
+            # Drop the datasource_id column and associated constraint.
+            batch_op.drop_constraint(
+                'fk_{}_datasource_id_datasources'.format(foreign),
+                type_='foreignkey',
+            )
+
+            batch_op.drop_column('datasource_id')
+
+    with op.batch_alter_table('datasources', naming_convention=conv) as batch_op:
+
+        # Prior to dropping the uniqueness constraint, the foreign key
+        # associated with the cluster_name column needs to be dropped.
+        batch_op.drop_constraint(
+            generic_find_fk_constraint_name(
+                'datasources',
+                {'cluster_name'},
+                'clusters',
+                insp,
+            ) or 'fk_datasources_cluster_name_clusters',
+            type_='foreignkey',
+        )
+
+        # Drop the old less restrictive uniqueness constraint.
+        batch_op.drop_constraint(
+            generic_find_uq_constraint_name(
+                'datasources',
+                {'cluster_name', 'datasource_name'},
+                insp,
+            ) or 'uq_datasources_cluster_name',
+            type_='unique',
+        )
+
+        # Re-create the foreign key associated with the cluster_name column.
+        batch_op.create_foreign_key(
+            'fk_datasources_cluster_name_clusters',
+            'clusters',
+            ['cluster_name'],
+            ['cluster_name'],
+        )
diff --git a/superset/utils.py b/superset/utils.py
index 469bbc26cfc7a..bae330b4af279 100644
--- a/superset/utils.py
+++ b/superset/utils.py
@@ -377,11 +377,36 @@ def generic_find_constraint_name(table, columns, referenced, db):
     t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)
 
     for fk in t.foreign_key_constraints:
-        if (fk.referred_table.name == referenced and
-                set(fk.column_keys) == columns):
+        if fk.referred_table.name == referenced and set(fk.column_keys) == columns:
             return fk.name
 
 
+def generic_find_fk_constraint_name(table, columns, referenced, insp):
+    """Utility to find a foreign-key constraint name in alembic migrations"""
+    for fk in insp.get_foreign_keys(table):
+        if fk['referred_table'] == referenced and set(fk['referred_columns']) == columns:
+            return fk['name']
+
+
+def generic_find_fk_constraint_names(table, columns, referenced, insp):
+    """Utility to find foreign-key constraint names in alembic migrations"""
+    names = set()
+
+    for fk in insp.get_foreign_keys(table):
+        if fk['referred_table'] == referenced and set(fk['referred_columns']) == columns:
+            names.add(fk['name'])
+
+    return names
+
+
+def generic_find_uq_constraint_name(table, columns, insp):
+    """Utility to find a unique constraint name in alembic migrations"""
+
+    for uq in insp.get_unique_constraints(table):
+        if columns == set(uq['column_names']):
+            return uq['name']
+
+
 def get_datasource_full_name(database_name, datasource_name, schema=None):
     if not schema:
         return '[{}].[{}]'.format(database_name, datasource_name)
diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py
index e945630f36986..0710cacecec5d 100644
--- a/tests/import_export_tests.py
+++ b/tests/import_export_tests.py
@@ -485,13 +485,12 @@ def test_import_druid_2_col_2_met(self):
 
     def test_import_druid_override(self):
         datasource = self.create_druid_datasource(
-            'druid_override', id=10003, cols_names=['col1'],
+            'druid_override', id=10004, cols_names=['col1'],
             metric_names=['m1'])
         imported_id = DruidDatasource.import_obj(
             datasource, import_time=1991)
-
         table_over = self.create_druid_datasource(
-            'druid_override', id=10003,
+            'druid_override', id=10004,
             cols_names=['new_col1', 'col2', 'col3'],
             metric_names=['new_metric1'])
         imported_over_id = DruidDatasource.import_obj(
@@ -500,19 +499,19 @@ def test_import_druid_override(self):
         imported_over = self.get_datasource(imported_over_id)
         self.assertEquals(imported_id, imported_over.id)
         expected_datasource = self.create_druid_datasource(
-            'druid_override', id=10003, metric_names=['new_metric1', 'm1'],
+            'druid_override', id=10004, metric_names=['new_metric1', 'm1'],
             cols_names=['col1', 'new_col1', 'col2', 'col3'])
         self.assert_datasource_equals(expected_datasource, imported_over)
 
     def test_import_druid_override_idential(self):
         datasource = self.create_druid_datasource(
-            'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
+            'copy_cat', id=10005, cols_names=['new_col1', 'col2', 'col3'],
             metric_names=['new_metric1'])
         imported_id = DruidDatasource.import_obj(
             datasource, import_time=1993)
 
         copy_datasource = self.create_druid_datasource(
-            'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
+            'copy_cat', id=10005, cols_names=['new_col1', 'col2', 'col3'],
             metric_names=['new_metric1'])
         imported_id_copy = DruidDatasource.import_obj(
             copy_datasource, import_time=1994)
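
Note (not part of the patch): the migration resolves existing constraint names through SQLAlchemy's Inspector rather than hard-coding them, and only falls back to the conv-style names when reflection returns nothing. Below is a minimal, self-contained sketch of that lookup pattern. It uses an in-memory SQLite database and a local copy of generic_find_uq_constraint_name so it runs without a Superset checkout; the table layout simply mirrors the datasources table touched by the migration.

# Sketch only: illustrates the inspector-based constraint lookup used above.
import sqlalchemy as sa


def generic_find_uq_constraint_name(table, columns, insp):
    """Return the name of the unique constraint covering exactly `columns`."""
    for uq in insp.get_unique_constraints(table):
        if columns == set(uq['column_names']):
            return uq['name']


# Hypothetical stand-in for the real metadata database.
engine = sa.create_engine('sqlite://')
metadata = sa.MetaData()
sa.Table(
    'datasources',
    metadata,
    sa.Column('id', sa.Integer, primary_key=True),
    sa.Column('cluster_name', sa.String(250)),
    sa.Column('datasource_name', sa.String(255)),
    sa.UniqueConstraint(
        'cluster_name', 'datasource_name',
        name='uq_datasources_cluster_name'),
)
metadata.create_all(engine)

insp = sa.engine.reflection.Inspector.from_engine(engine)

# Mirrors the downgrade() fallback: use the reflected name if present,
# otherwise assume the conv-generated one.
name = (
    generic_find_uq_constraint_name(
        'datasources', {'cluster_name', 'datasource_name'}, insp) or
    'uq_datasources_cluster_name')
print(name)  # prints: uq_datasources_cluster_name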