From 7cc71e31b6e51bfb226909b2d3eed671f4324781 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Mon, 6 Mar 2017 15:49:51 -0800 Subject: [PATCH 1/8] Formalizing the Connector interface --- superset/connectors/__init__.py | 0 superset/connectors/base.py | 77 +++++++++++++++++++++++++++++++++ superset/models.py | 76 -------------------------------- 3 files changed, 77 insertions(+), 76 deletions(-) create mode 100644 superset/connectors/__init__.py create mode 100644 superset/connectors/base.py diff --git a/superset/connectors/__init__.py b/superset/connectors/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/superset/connectors/base.py b/superset/connectors/base.py new file mode 100644 index 0000000000000..ae8ebf48696ea --- /dev/null +++ b/superset/connectors/base.py @@ -0,0 +1,77 @@ + +class Datasource(object): + + """A common interface to objects that are queryable (tables and datasources)""" + + # Used to do code highlighting when displaying the query in the UI + query_language = None + + @property + def column_names(self): + return sorted([c.column_name for c in self.columns]) + + @property + def main_dttm_col(self): + return "timestamp" + + @property + def groupby_column_names(self): + return sorted([c.column_name for c in self.columns if c.groupby]) + + @property + def filterable_column_names(self): + return sorted([c.column_name for c in self.columns if c.filterable]) + + @property + def dttm_cols(self): + return [] + + @property + def url(self): + return '/{}/edit/{}'.format(self.baselink, self.id) + + @property + def explore_url(self): + if self.default_endpoint: + return self.default_endpoint + else: + return "/superset/explore/{obj.type}/{obj.id}/".format(obj=self) + + @property + def column_formats(self): + return { + m.metric_name: m.d3format + for m in self.metrics + if m.d3format + } + + @property + def data(self): + """data representation of the datasource sent to the frontend""" + order_by_choices = [] + for s in sorted(self.column_names): + order_by_choices.append((json.dumps([s, True]), s + ' [asc]')) + order_by_choices.append((json.dumps([s, False]), s + ' [desc]')) + + d = { + 'all_cols': utils.choicify(self.column_names), + 'column_formats': self.column_formats, + 'edit_url' : self.url, + 'filter_select': self.filter_select_enabled, + 'filterable_cols': utils.choicify(self.filterable_column_names), + 'gb_cols': utils.choicify(self.groupby_column_names), + 'id': self.id, + 'metrics_combo': self.metrics_combo, + 'name': self.name, + 'order_by_choices': order_by_choices, + 'type': self.type, + } + if self.type == 'table': + grains = self.database.grains() or [] + if grains: + grains = [(g.name, g.name) for g in grains] + d['granularity_sqla'] = utils.choicify(self.dttm_cols) + d['time_grain_sqla'] = grains + return d + + diff --git a/superset/models.py b/superset/models.py index cd732fb4832ef..ffbcb95400a9d 100644 --- a/superset/models.py +++ b/superset/models.py @@ -648,82 +648,6 @@ def export_dashboards(cls, dashboard_ids): }) -class Datasource(object): - - """A common interface to objects that are queryable (tables and datasources)""" - - # Used to do code highlighting when displaying the query in the UI - query_language = None - - @property - def column_names(self): - return sorted([c.column_name for c in self.columns]) - - @property - def main_dttm_col(self): - return "timestamp" - - @property - def groupby_column_names(self): - return sorted([c.column_name for c in self.columns if c.groupby]) - - @property - def 
filterable_column_names(self): - return sorted([c.column_name for c in self.columns if c.filterable]) - - @property - def dttm_cols(self): - return [] - - @property - def url(self): - return '/{}/edit/{}'.format(self.baselink, self.id) - - @property - def explore_url(self): - if self.default_endpoint: - return self.default_endpoint - else: - return "/superset/explore/{obj.type}/{obj.id}/".format(obj=self) - - @property - def column_formats(self): - return { - m.metric_name: m.d3format - for m in self.metrics - if m.d3format - } - - @property - def data(self): - """data representation of the datasource sent to the frontend""" - order_by_choices = [] - for s in sorted(self.column_names): - order_by_choices.append((json.dumps([s, True]), s + ' [asc]')) - order_by_choices.append((json.dumps([s, False]), s + ' [desc]')) - - d = { - 'all_cols': utils.choicify(self.column_names), - 'column_formats': self.column_formats, - 'edit_url' : self.url, - 'filter_select': self.filter_select_enabled, - 'filterable_cols': utils.choicify(self.filterable_column_names), - 'gb_cols': utils.choicify(self.groupby_column_names), - 'id': self.id, - 'metrics_combo': self.metrics_combo, - 'name': self.name, - 'order_by_choices': order_by_choices, - 'type': self.type, - } - if self.type == 'table': - grains = self.database.grains() or [] - if grains: - grains = [(g.name, g.name) for g in grains] - d['granularity_sqla'] = utils.choicify(self.dttm_cols) - d['time_grain_sqla'] = grains - return d - - class Database(Model, AuditMixinNullable): """An ORM object that stores Database related information""" From 661e0fca5b461d6f898dcda95fa8499ae260d622 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Mon, 6 Mar 2017 23:45:42 -0800 Subject: [PATCH 2/8] Checkpoint --- superset/__init__.py | 8 +- superset/cli.py | 3 +- superset/config.py | 5 +- superset/connectors/base.py | 4 + .../connector_registry.py} | 12 +- superset/connectors/druid/__init__.py | 2 + superset/connectors/druid/models.py | 1048 ++++++ superset/connectors/druid/views.py | 203 ++ superset/connectors/sqla/__init__.py | 2 + superset/connectors/sqla/models.py | 701 +++++ superset/connectors/sqla/views.py | 213 ++ superset/data/__init__.py | 8 +- superset/models.py | 2800 ----------------- superset/models/__init__.py | 1 + superset/models/core.py | 952 ++++++ superset/models/helpers.py | 127 + superset/security.py | 6 +- superset/sql_lab.py | 2 +- superset/utils.py | 2 +- superset/views/__init__.py | 2 + superset/views/base.py | 201 ++ superset/{views.py => views/core.py} | 619 +--- 22 files changed, 3510 insertions(+), 3411 deletions(-) rename superset/{source_registry.py => connectors/connector_registry.py} (85%) create mode 100644 superset/connectors/druid/__init__.py create mode 100644 superset/connectors/druid/models.py create mode 100644 superset/connectors/druid/views.py create mode 100644 superset/connectors/sqla/__init__.py create mode 100644 superset/connectors/sqla/models.py create mode 100644 superset/connectors/sqla/views.py delete mode 100644 superset/models.py create mode 100644 superset/models/__init__.py create mode 100644 superset/models/core.py create mode 100644 superset/models/helpers.py create mode 100644 superset/views/__init__.py create mode 100644 superset/views/base.py rename superset/{views.py => views/core.py} (78%) diff --git a/superset/__init__.py b/superset/__init__.py index bad241e3fccb0..cce4830f4c504 100644 --- a/superset/__init__.py +++ b/superset/__init__.py @@ -13,9 +13,9 @@ from flask_appbuilder import SQLA, AppBuilder, 
IndexView from flask_appbuilder.baseviews import expose from flask_migrate import Migrate -from superset.source_registry import SourceRegistry +from superset.connectors.connector_registry import ConnectorRegistry from werkzeug.contrib.fixers import ProxyFix -from superset import utils +from superset import utils, config # noqa APP_DIR = os.path.dirname(__file__) @@ -104,6 +104,6 @@ def index(self): # Registering sources module_datasource_map = app.config.get("DEFAULT_MODULE_DS_MAP") module_datasource_map.update(app.config.get("ADDITIONAL_MODULE_DS_MAP")) -SourceRegistry.register_sources(module_datasource_map) +ConnectorRegistry.register_sources(module_datasource_map) -from superset import views, config # noqa +from superset import views # noqa diff --git a/superset/cli.py b/superset/cli.py index f56faef5548d4..dfeb3ae8aa5f4 100755 --- a/superset/cli.py +++ b/superset/cli.py @@ -13,7 +13,7 @@ from flask_migrate import MigrateCommand from flask_script import Manager -from superset import app, db, data, security +from superset import app, db, security config = app.config @@ -89,6 +89,7 @@ def version(verbose): help="Load additional test data") def load_examples(load_test_data): """Loads a set of Slices and Dashboards and a supporting dataset """ + from superset import data print("Loading examples into {}".format(db)) data.load_css_templates() diff --git a/superset/config.py b/superset/config.py index b2bc460c4f430..ce3a3aad4c024 100644 --- a/superset/config.py +++ b/superset/config.py @@ -178,7 +178,10 @@ # -------------------------------------------------- # Modules, datasources and middleware to be registered # -------------------------------------------------- -DEFAULT_MODULE_DS_MAP = {'superset.models': ['DruidDatasource', 'SqlaTable']} +DEFAULT_MODULE_DS_MAP = { + 'superset.connectors.druid.models': ['DruidDatasource'], + 'superset.connectors.sqla.models': ['SqlaTable'], +} ADDITIONAL_MODULE_DS_MAP = {} ADDITIONAL_MIDDLEWARE = [] diff --git a/superset/connectors/base.py b/superset/connectors/base.py index ae8ebf48696ea..45dd3360f8807 100644 --- a/superset/connectors/base.py +++ b/superset/connectors/base.py @@ -1,3 +1,7 @@ +import json + +from superset import utils + class Datasource(object): diff --git a/superset/source_registry.py b/superset/connectors/connector_registry.py similarity index 85% rename from superset/source_registry.py rename to superset/connectors/connector_registry.py index df91762564bbd..c3662ef14e009 100644 --- a/superset/source_registry.py +++ b/superset/connectors/connector_registry.py @@ -1,7 +1,7 @@ from sqlalchemy.orm import subqueryload -class SourceRegistry(object): +class ConnectorRegistry(object): """ Central Registry for all available datasource engines""" sources = {} @@ -26,15 +26,15 @@ def get_datasource(cls, datasource_type, datasource_id, session): @classmethod def get_all_datasources(cls, session): datasources = [] - for source_type in SourceRegistry.sources: + for source_type in ConnectorRegistry.sources: datasources.extend( - session.query(SourceRegistry.sources[source_type]).all()) + session.query(ConnectorRegistry.sources[source_type]).all()) return datasources @classmethod def get_datasource_by_name(cls, session, datasource_type, datasource_name, schema, database_name): - datasource_class = SourceRegistry.sources[datasource_type] + datasource_class = ConnectorRegistry.sources[datasource_type] datasources = session.query(datasource_class).all() # Filter datasoures that don't have database. 
@@ -45,7 +45,7 @@ def get_datasource_by_name(cls, session, datasource_type, datasource_name, @classmethod def query_datasources_by_permissions(cls, session, database, permissions): - datasource_class = SourceRegistry.sources[database.type] + datasource_class = ConnectorRegistry.sources[database.type] return ( session.query(datasource_class) .filter_by(database_id=database.id) @@ -56,7 +56,7 @@ def query_datasources_by_permissions(cls, session, database, permissions): @classmethod def get_eager_datasource(cls, session, datasource_type, datasource_id): """Returns datasource with columns and metrics.""" - datasource_class = SourceRegistry.sources[datasource_type] + datasource_class = ConnectorRegistry.sources[datasource_type] return ( session.query(datasource_class) .options( diff --git a/superset/connectors/druid/__init__.py b/superset/connectors/druid/__init__.py new file mode 100644 index 0000000000000..b2df79851f224 --- /dev/null +++ b/superset/connectors/druid/__init__.py @@ -0,0 +1,2 @@ +from . import models # noqa +from . import views # noqa diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py new file mode 100644 index 0000000000000..f1d70ed42d73b --- /dev/null +++ b/superset/connectors/druid/models.py @@ -0,0 +1,1048 @@ +from collections import OrderedDict +import json +import logging +from copy import deepcopy +from datetime import datetime, timedelta +from six import string_types + +import requests +from sqlalchemy import ( + Column, Integer, String, ForeignKey, Text, Boolean, + DateTime, +) +from sqlalchemy.orm import backref, relationship +from dateutil.parser import parse as dparse + +from pydruid.client import PyDruid +from pydruid.utils.aggregators import count +from pydruid.utils.filters import Dimension, Filter +from pydruid.utils.postaggregator import ( + Postaggregator, Quantile, Quantiles, Field, Const, HyperUniqueCardinality, +) +from pydruid.utils.having import Aggregation + +from flask import Markup, escape +from flask_appbuilder.models.decorators import renders +from flask_appbuilder import Model + +from flask_babel import lazy_gettext as _ + +from superset import config, db, import_util, utils, sm, get_session +from superset.utils import ( + flasher, MetricPermException, DimSelector, DTTM_ALIAS +) +from superset.connectors.base import Datasource +from superset.models.helpers import ( + AuditMixinNullable, ImportMixin, QueryResult) + + +class JavascriptPostAggregator(Postaggregator): + def __init__(self, name, field_names, function): + self.post_aggregator = { + 'type': 'javascript', + 'fieldNames': field_names, + 'name': name, + 'function': function, + } + self.name = name + + +class DruidCluster(Model, AuditMixinNullable): + + """ORM object referencing the Druid clusters""" + + __tablename__ = 'clusters' + type = "druid" + + id = Column(Integer, primary_key=True) + cluster_name = Column(String(250), unique=True) + coordinator_host = Column(String(255)) + coordinator_port = Column(Integer) + coordinator_endpoint = Column( + String(255), default='druid/coordinator/v1/metadata') + broker_host = Column(String(255)) + broker_port = Column(Integer) + broker_endpoint = Column(String(255), default='druid/v2') + metadata_last_refreshed = Column(DateTime) + cache_timeout = Column(Integer) + + def __repr__(self): + return self.cluster_name + + def get_pydruid_client(self): + cli = PyDruid( + "http://{0}:{1}/".format(self.broker_host, self.broker_port), + self.broker_endpoint) + return cli + + def get_datasources(self): + endpoint = ( + 
"http://{obj.coordinator_host}:{obj.coordinator_port}/" + "{obj.coordinator_endpoint}/datasources" + ).format(obj=self) + + return json.loads(requests.get(endpoint).text) + + def get_druid_version(self): + endpoint = ( + "http://{obj.coordinator_host}:{obj.coordinator_port}/status" + ).format(obj=self) + return json.loads(requests.get(endpoint).text)['version'] + + def refresh_datasources(self, datasource_name=None, merge_flag=False): + """Refresh metadata of all datasources in the cluster + If ``datasource_name`` is specified, only that datasource is updated + """ + self.druid_version = self.get_druid_version() + for datasource in self.get_datasources(): + if datasource not in config.get('DRUID_DATA_SOURCE_BLACKLIST'): + if not datasource_name or datasource_name == datasource: + DruidDatasource.sync_to_db(datasource, self, merge_flag) + + @property + def perm(self): + return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self) + + @property + def name(self): + return self.cluster_name + + +class DruidColumn(Model, AuditMixinNullable, ImportMixin): + """ORM model for storing Druid datasource column metadata""" + + __tablename__ = 'columns' + id = Column(Integer, primary_key=True) + datasource_name = Column( + String(255), + ForeignKey('datasources.datasource_name')) + # Setting enable_typechecks=False disables polymorphic inheritance. + datasource = relationship( + 'DruidDatasource', + backref=backref('columns', cascade='all, delete-orphan'), + enable_typechecks=False) + column_name = Column(String(255)) + is_active = Column(Boolean, default=True) + type = Column(String(32)) + groupby = Column(Boolean, default=False) + count_distinct = Column(Boolean, default=False) + sum = Column(Boolean, default=False) + avg = Column(Boolean, default=False) + max = Column(Boolean, default=False) + min = Column(Boolean, default=False) + filterable = Column(Boolean, default=False) + description = Column(Text) + dimension_spec_json = Column(Text) + + export_fields = ( + 'datasource_name', 'column_name', 'is_active', 'type', 'groupby', + 'count_distinct', 'sum', 'avg', 'max', 'min', 'filterable', + 'description', 'dimension_spec_json' + ) + + def __repr__(self): + return self.column_name + + @property + def is_num(self): + return self.type in ('LONG', 'DOUBLE', 'FLOAT', 'INT') + + @property + def dimension_spec(self): + if self.dimension_spec_json: + return json.loads(self.dimension_spec_json) + + def generate_metrics(self): + """Generate metrics based on the column metadata""" + M = DruidMetric # noqa + metrics = [] + metrics.append(DruidMetric( + metric_name='count', + verbose_name='COUNT(*)', + metric_type='count', + json=json.dumps({'type': 'count', 'name': 'count'}) + )) + # Somehow we need to reassign this for UDAFs + if self.type in ('DOUBLE', 'FLOAT'): + corrected_type = 'DOUBLE' + else: + corrected_type = self.type + + if self.sum and self.is_num: + mt = corrected_type.lower() + 'Sum' + name = 'sum__' + self.column_name + metrics.append(DruidMetric( + metric_name=name, + metric_type='sum', + verbose_name='SUM({})'.format(self.column_name), + json=json.dumps({ + 'type': mt, 'name': name, 'fieldName': self.column_name}) + )) + + if self.avg and self.is_num: + mt = corrected_type.lower() + 'Avg' + name = 'avg__' + self.column_name + metrics.append(DruidMetric( + metric_name=name, + metric_type='avg', + verbose_name='AVG({})'.format(self.column_name), + json=json.dumps({ + 'type': mt, 'name': name, 'fieldName': self.column_name}) + )) + + if self.min and self.is_num: + mt = corrected_type.lower() + 
'Min' + name = 'min__' + self.column_name + metrics.append(DruidMetric( + metric_name=name, + metric_type='min', + verbose_name='MIN({})'.format(self.column_name), + json=json.dumps({ + 'type': mt, 'name': name, 'fieldName': self.column_name}) + )) + if self.max and self.is_num: + mt = corrected_type.lower() + 'Max' + name = 'max__' + self.column_name + metrics.append(DruidMetric( + metric_name=name, + metric_type='max', + verbose_name='MAX({})'.format(self.column_name), + json=json.dumps({ + 'type': mt, 'name': name, 'fieldName': self.column_name}) + )) + if self.count_distinct: + name = 'count_distinct__' + self.column_name + if self.type == 'hyperUnique' or self.type == 'thetaSketch': + metrics.append(DruidMetric( + metric_name=name, + verbose_name='COUNT(DISTINCT {})'.format(self.column_name), + metric_type=self.type, + json=json.dumps({ + 'type': self.type, + 'name': name, + 'fieldName': self.column_name + }) + )) + else: + mt = 'count_distinct' + metrics.append(DruidMetric( + metric_name=name, + verbose_name='COUNT(DISTINCT {})'.format(self.column_name), + metric_type='count_distinct', + json=json.dumps({ + 'type': 'cardinality', + 'name': name, + 'fieldNames': [self.column_name]}) + )) + session = get_session() + new_metrics = [] + for metric in metrics: + m = ( + session.query(M) + .filter(M.metric_name == metric.metric_name) + .filter(M.datasource_name == self.datasource_name) + .filter(DruidCluster.cluster_name == self.datasource.cluster_name) + .first() + ) + metric.datasource_name = self.datasource_name + if not m: + new_metrics.append(metric) + session.add(metric) + session.flush() + + @classmethod + def import_obj(cls, i_column): + def lookup_obj(lookup_column): + return db.session.query(DruidColumn).filter( + DruidColumn.datasource_name == lookup_column.datasource_name, + DruidColumn.column_name == lookup_column.column_name).first() + + return import_util.import_simple_obj(db.session, i_column, lookup_obj) + + +class DruidMetric(Model, AuditMixinNullable, ImportMixin): + + """ORM object referencing Druid metrics for a datasource""" + + __tablename__ = 'metrics' + id = Column(Integer, primary_key=True) + metric_name = Column(String(512)) + verbose_name = Column(String(1024)) + metric_type = Column(String(32)) + datasource_name = Column( + String(255), + ForeignKey('datasources.datasource_name')) + # Setting enable_typechecks=False disables polymorphic inheritance. 
+ datasource = relationship( + 'DruidDatasource', + backref=backref('metrics', cascade='all, delete-orphan'), + enable_typechecks=False) + json = Column(Text) + description = Column(Text) + is_restricted = Column(Boolean, default=False, nullable=True) + d3format = Column(String(128)) + + def refresh_datasources(self, datasource_name=None, merge_flag=False): + """Refresh metadata of all datasources in the cluster + + If ``datasource_name`` is specified, only that datasource is updated + """ + self.druid_version = self.get_druid_version() + for datasource in self.get_datasources(): + if datasource not in config.get('DRUID_DATA_SOURCE_BLACKLIST'): + if not datasource_name or datasource_name == datasource: + DruidDatasource.sync_to_db(datasource, self, merge_flag) + export_fields = ( + 'metric_name', 'verbose_name', 'metric_type', 'datasource_name', + 'json', 'description', 'is_restricted', 'd3format' + ) + + @property + def json_obj(self): + try: + obj = json.loads(self.json) + except Exception: + obj = {} + return obj + + @property + def perm(self): + return ( + "{parent_name}.[{obj.metric_name}](id:{obj.id})" + ).format(obj=self, + parent_name=self.datasource.full_name + ) if self.datasource else None + + @classmethod + def import_obj(cls, i_metric): + def lookup_obj(lookup_metric): + return db.session.query(DruidMetric).filter( + DruidMetric.datasource_name == lookup_metric.datasource_name, + DruidMetric.metric_name == lookup_metric.metric_name).first() + return import_util.import_simple_obj(db.session, i_metric, lookup_obj) + + +class DruidDatasource(Model, AuditMixinNullable, Datasource, ImportMixin): + + """ORM object referencing Druid datasources (tables)""" + + type = "druid" + query_langtage = "json" + + baselink = "druiddatasourcemodelview" + + __tablename__ = 'datasources' + id = Column(Integer, primary_key=True) + datasource_name = Column(String(255), unique=True) + is_featured = Column(Boolean, default=False) + is_hidden = Column(Boolean, default=False) + filter_select_enabled = Column(Boolean, default=False) + description = Column(Text) + default_endpoint = Column(Text) + user_id = Column(Integer, ForeignKey('ab_user.id')) + owner = relationship( + 'User', + backref=backref('datasources', cascade='all, delete-orphan'), + foreign_keys=[user_id]) + cluster_name = Column( + String(250), ForeignKey('clusters.cluster_name')) + cluster = relationship( + 'DruidCluster', backref='datasources', foreign_keys=[cluster_name]) + offset = Column(Integer, default=0) + cache_timeout = Column(Integer) + params = Column(String(1000)) + perm = Column(String(1000)) + + metric_cls = DruidMetric + column_cls = DruidColumn + + export_fields = ( + 'datasource_name', 'is_hidden', 'description', 'default_endpoint', + 'cluster_name', 'is_featured', 'offset', 'cache_timeout', 'params' + ) + + @property + def metrics_combo(self): + return sorted( + [(m.metric_name, m.verbose_name) for m in self.metrics], + key=lambda x: x[1]) + + @property + def database(self): + return self.cluster + + @property + def num_cols(self): + return [c.column_name for c in self.columns if c.is_num] + + @property + def name(self): + return self.datasource_name + + @property + def schema(self): + name_pieces = self.datasource_name.split('.') + if len(name_pieces) > 1: + return name_pieces[0] + else: + return None + + @property + def schema_perm(self): + """Returns schema permission if present, cluster one otherwise.""" + return utils.get_schema_perm(self.cluster, self.schema) + + def get_perm(self): + return ( + 
"[{obj.cluster_name}].[{obj.datasource_name}]" + "(id:{obj.id})").format(obj=self) + + @property + def link(self): + name = escape(self.datasource_name) + return Markup('{name}').format(**locals()) + + @property + def full_name(self): + return utils.get_datasource_full_name( + self.cluster_name, self.datasource_name) + + @property + def time_column_grains(self): + return { + "time_columns": [ + 'all', '5 seconds', '30 seconds', '1 minute', + '5 minutes', '1 hour', '6 hour', '1 day', '7 days', + 'week', 'week_starting_sunday', 'week_ending_saturday', + 'month', + ], + "time_grains": ['now'] + } + + def __repr__(self): + return self.datasource_name + + @renders('datasource_name') + def datasource_link(self): + url = "/superset/explore/{obj.type}/{obj.id}/".format(obj=self) + name = escape(self.datasource_name) + return Markup('{name}'.format(**locals())) + + def get_metric_obj(self, metric_name): + return [ + m.json_obj for m in self.metrics + if m.metric_name == metric_name + ][0] + + @classmethod + def import_obj(cls, i_datasource, import_time=None): + """Imports the datasource from the object to the database. + + Metrics and columns and datasource will be overridden if exists. + This function can be used to import/export dashboards between multiple + superset instances. Audit metadata isn't copies over. + """ + def lookup_datasource(d): + return db.session.query(DruidDatasource).join(DruidCluster).filter( + DruidDatasource.datasource_name == d.datasource_name, + DruidCluster.cluster_name == d.cluster_name, + ).first() + + def lookup_cluster(d): + return db.session.query(DruidCluster).filter_by( + cluster_name=d.cluster_name).one() + return import_util.import_datasource( + db.session, i_datasource, lookup_cluster, lookup_datasource, + import_time) + + @staticmethod + def version_higher(v1, v2): + """is v1 higher than v2 + + >>> DruidDatasource.version_higher('0.8.2', '0.9.1') + False + >>> DruidDatasource.version_higher('0.8.2', '0.6.1') + True + >>> DruidDatasource.version_higher('0.8.2', '0.8.2') + False + >>> DruidDatasource.version_higher('0.8.2', '0.9.BETA') + False + >>> DruidDatasource.version_higher('0.8.2', '0.9') + False + """ + def int_or_0(v): + try: + v = int(v) + except (TypeError, ValueError): + v = 0 + return v + v1nums = [int_or_0(n) for n in v1.split('.')] + v2nums = [int_or_0(n) for n in v2.split('.')] + v1nums = (v1nums + [0, 0, 0])[:3] + v2nums = (v2nums + [0, 0, 0])[:3] + return v1nums[0] > v2nums[0] or \ + (v1nums[0] == v2nums[0] and v1nums[1] > v2nums[1]) or \ + (v1nums[0] == v2nums[0] and v1nums[1] == v2nums[1] and v1nums[2] > v2nums[2]) + + def latest_metadata(self): + """Returns segment metadata from the latest segment""" + client = self.cluster.get_pydruid_client() + results = client.time_boundary(datasource=self.datasource_name) + if not results: + return + max_time = results[0]['result']['maxTime'] + max_time = dparse(max_time) + # Query segmentMetadata for 7 days back. However, due to a bug, + # we need to set this interval to more than 1 day ago to exclude + # realtime segments, which triggered a bug (fixed in druid 0.8.2). 
+ # https://groups.google.com/forum/#!topic/druid-user/gVCqqspHqOQ + lbound = (max_time - timedelta(days=7)).isoformat() + rbound = max_time.isoformat() + if not self.version_higher(self.cluster.druid_version, '0.8.2'): + rbound = (max_time - timedelta(1)).isoformat() + segment_metadata = None + try: + segment_metadata = client.segment_metadata( + datasource=self.datasource_name, + intervals=lbound + '/' + rbound, + merge=self.merge_flag, + analysisTypes=config.get('DRUID_ANALYSIS_TYPES')) + except Exception as e: + logging.warning("Failed first attempt to get latest segment") + logging.exception(e) + if not segment_metadata: + # if no segments in the past 7 days, look at all segments + lbound = datetime(1901, 1, 1).isoformat()[:10] + rbound = datetime(2050, 1, 1).isoformat()[:10] + if not self.version_higher(self.cluster.druid_version, '0.8.2'): + rbound = datetime.now().isoformat()[:10] + try: + segment_metadata = client.segment_metadata( + datasource=self.datasource_name, + intervals=lbound + '/' + rbound, + merge=self.merge_flag, + analysisTypes=config.get('DRUID_ANALYSIS_TYPES')) + except Exception as e: + logging.warning("Failed 2nd attempt to get latest segment") + logging.exception(e) + if segment_metadata: + return segment_metadata[-1]['columns'] + + def generate_metrics(self): + for col in self.columns: + col.generate_metrics() + + @classmethod + def sync_to_db_from_config(cls, druid_config, user, cluster): + """Merges the ds config from druid_config into one stored in the db.""" + session = db.session() + datasource = ( + session.query(DruidDatasource) + .filter_by( + datasource_name=druid_config['name']) + ).first() + # Create a new datasource. + if not datasource: + datasource = DruidDatasource( + datasource_name=druid_config['name'], + cluster=cluster, + owner=user, + changed_by_fk=user.id, + created_by_fk=user.id, + ) + session.add(datasource) + + dimensions = druid_config['dimensions'] + for dim in dimensions: + col_obj = ( + session.query(DruidColumn) + .filter_by( + datasource_name=druid_config['name'], + column_name=dim) + ).first() + if not col_obj: + col_obj = DruidColumn( + datasource_name=druid_config['name'], + column_name=dim, + groupby=True, + filterable=True, + # TODO: fetch type from Hive. 
+ type="STRING", + datasource=datasource + ) + session.add(col_obj) + # Import Druid metrics + for metric_spec in druid_config["metrics_spec"]: + metric_name = metric_spec["name"] + metric_type = metric_spec["type"] + metric_json = json.dumps(metric_spec) + + if metric_type == "count": + metric_type = "longSum" + metric_json = json.dumps({ + "type": "longSum", + "name": metric_name, + "fieldName": metric_name, + }) + + metric_obj = ( + session.query(DruidMetric) + .filter_by( + datasource_name=druid_config['name'], + metric_name=metric_name) + ).first() + if not metric_obj: + metric_obj = DruidMetric( + metric_name=metric_name, + metric_type=metric_type, + verbose_name="%s(%s)" % (metric_type, metric_name), + datasource=datasource, + json=metric_json, + description=( + "Imported from the airolap config dir for %s" % + druid_config['name']), + ) + session.add(metric_obj) + session.commit() + + @classmethod + def sync_to_db(cls, name, cluster, merge): + """Fetches metadata for that datasource and merges the Superset db""" + logging.info("Syncing Druid datasource [{}]".format(name)) + session = get_session() + datasource = session.query(cls).filter_by(datasource_name=name).first() + if not datasource: + datasource = cls(datasource_name=name) + session.add(datasource) + flasher("Adding new datasource [{}]".format(name), "success") + else: + flasher("Refreshing datasource [{}]".format(name), "info") + session.flush() + datasource.cluster = cluster + datasource.merge_flag = merge + session.flush() + + cols = datasource.latest_metadata() + if not cols: + logging.error("Failed at fetching the latest segment") + return + for col in cols: + col_obj = ( + session + .query(DruidColumn) + .filter_by(datasource_name=name, column_name=col) + .first() + ) + datatype = cols[col]['type'] + if not col_obj: + col_obj = DruidColumn(datasource_name=name, column_name=col) + session.add(col_obj) + if datatype == "STRING": + col_obj.groupby = True + col_obj.filterable = True + if datatype == "hyperUnique" or datatype == "thetaSketch": + col_obj.count_distinct = True + if col_obj: + col_obj.type = cols[col]['type'] + session.flush() + col_obj.datasource = datasource + col_obj.generate_metrics() + session.flush() + + @staticmethod + def time_offset(granularity): + if granularity == 'week_ending_saturday': + return 6 * 24 * 3600 * 1000 # 6 days + return 0 + + # uses https://en.wikipedia.org/wiki/ISO_8601 + # http://druid.io/docs/0.8.0/querying/granularities.html + # TODO: pass origin from the UI + @staticmethod + def granularity(period_name, timezone=None, origin=None): + if not period_name or period_name == 'all': + return 'all' + iso_8601_dict = { + '5 seconds': 'PT5S', + '30 seconds': 'PT30S', + '1 minute': 'PT1M', + '5 minutes': 'PT5M', + '1 hour': 'PT1H', + '6 hour': 'PT6H', + 'one day': 'P1D', + '1 day': 'P1D', + '7 days': 'P7D', + 'week': 'P1W', + 'week_starting_sunday': 'P1W', + 'week_ending_saturday': 'P1W', + 'month': 'P1M', + } + + granularity = {'type': 'period'} + if timezone: + granularity['timeZone'] = timezone + + if origin: + dttm = utils.parse_human_datetime(origin) + granularity['origin'] = dttm.isoformat() + + if period_name in iso_8601_dict: + granularity['period'] = iso_8601_dict[period_name] + if period_name in ('week_ending_saturday', 'week_starting_sunday'): + # use Sunday as start of the week + granularity['origin'] = '2016-01-03T00:00:00' + elif not isinstance(period_name, string_types): + granularity['type'] = 'duration' + granularity['duration'] = period_name + elif 
period_name.startswith('P'): + # identify if the string is the iso_8601 period + granularity['period'] = period_name + else: + granularity['type'] = 'duration' + granularity['duration'] = utils.parse_human_timedelta( + period_name).total_seconds() * 1000 + return granularity + + def values_for_column(self, + column_name, + from_dttm, + to_dttm, + limit=500): + """Retrieve some values for the given column""" + # TODO: Use Lexicographic TopNMetricSpec once supported by PyDruid + from_dttm = from_dttm.replace(tzinfo=config.get("DRUID_TZ")) + to_dttm = to_dttm.replace(tzinfo=config.get("DRUID_TZ")) + + qry = dict( + datasource=self.datasource_name, + granularity="all", + intervals=from_dttm.isoformat() + '/' + to_dttm.isoformat(), + aggregations=dict(count=count("count")), + dimension=column_name, + metric="count", + threshold=limit, + ) + + client = self.cluster.get_pydruid_client() + client.topn(**qry) + df = client.export_pandas() + + if df is None or df.size == 0: + raise Exception(_("No data was returned.")) + + return df + + def get_query_str( # noqa / druid + self, client, qry_start_dttm, + groupby, metrics, + granularity, + from_dttm, to_dttm, + filter=None, # noqa + is_timeseries=True, + timeseries_limit=None, + timeseries_limit_metric=None, + row_limit=None, + inner_from_dttm=None, inner_to_dttm=None, + orderby=None, + extras=None, # noqa + select=None, # noqa + columns=None, phase=2): + """Runs a query against Druid and returns a dataframe. + + This query interface is common to SqlAlchemy and Druid + """ + # TODO refactor into using a TBD Query object + if not is_timeseries: + granularity = 'all' + inner_from_dttm = inner_from_dttm or from_dttm + inner_to_dttm = inner_to_dttm or to_dttm + + # add tzinfo to native datetime with config + from_dttm = from_dttm.replace(tzinfo=config.get("DRUID_TZ")) + to_dttm = to_dttm.replace(tzinfo=config.get("DRUID_TZ")) + timezone = from_dttm.tzname() + + query_str = "" + metrics_dict = {m.metric_name: m for m in self.metrics} + all_metrics = [] + post_aggs = {} + + columns_dict = {c.column_name: c for c in self.columns} + + def recursive_get_fields(_conf): + _fields = _conf.get('fields', []) + field_names = [] + for _f in _fields: + _type = _f.get('type') + if _type in ['fieldAccess', 'hyperUniqueCardinality']: + field_names.append(_f.get('fieldName')) + elif _type == 'arithmetic': + field_names += recursive_get_fields(_f) + return list(set(field_names)) + + for metric_name in metrics: + metric = metrics_dict[metric_name] + if metric.metric_type != 'postagg': + all_metrics.append(metric_name) + else: + conf = metric.json_obj + all_metrics += recursive_get_fields(conf) + all_metrics += conf.get('fieldNames', []) + if conf.get('type') == 'javascript': + post_aggs[metric_name] = JavascriptPostAggregator( + name=conf.get('name', ''), + field_names=conf.get('fieldNames', []), + function=conf.get('function', '')) + elif conf.get('type') == 'quantile': + post_aggs[metric_name] = Quantile( + conf.get('name', ''), + conf.get('probability', ''), + ) + elif conf.get('type') == 'quantiles': + post_aggs[metric_name] = Quantiles( + conf.get('name', ''), + conf.get('probabilities', ''), + ) + elif conf.get('type') == 'fieldAccess': + post_aggs[metric_name] = Field(conf.get('name'), '') + elif conf.get('type') == 'constant': + post_aggs[metric_name] = Const( + conf.get('value'), + output_name=conf.get('name', '') + ) + elif conf.get('type') == 'hyperUniqueCardinality': + post_aggs[metric_name] = HyperUniqueCardinality( + conf.get('name'), '' + ) + else: + 
post_aggs[metric_name] = Postaggregator( + conf.get('fn', "/"), + conf.get('fields', []), + conf.get('name', '')) + + aggregations = OrderedDict() + for m in self.metrics: + if m.metric_name in all_metrics: + aggregations[m.metric_name] = m.json_obj + + rejected_metrics = [ + m.metric_name for m in self.metrics + if m.is_restricted and + m.metric_name in aggregations.keys() and + not sm.has_access('metric_access', m.perm) + ] + + if rejected_metrics: + raise MetricPermException( + "Access to the metrics denied: " + ', '.join(rejected_metrics) + ) + + # the dimensions list with dimensionSpecs expanded + dimensions = [] + groupby = [gb for gb in groupby if gb in columns_dict] + for column_name in groupby: + col = columns_dict.get(column_name) + dim_spec = col.dimension_spec + if dim_spec: + dimensions.append(dim_spec) + else: + dimensions.append(column_name) + qry = dict( + datasource=self.datasource_name, + dimensions=dimensions, + aggregations=aggregations, + granularity=DruidDatasource.granularity( + granularity, + timezone=timezone, + origin=extras.get('druid_time_origin'), + ), + post_aggregations=post_aggs, + intervals=from_dttm.isoformat() + '/' + to_dttm.isoformat(), + ) + + filters = self.get_filters(filter) + if filters: + qry['filter'] = filters + + having_filters = self.get_having_filters(extras.get('having_druid')) + if having_filters: + qry['having'] = having_filters + + orig_filters = filters + if len(groupby) == 0: + del qry['dimensions'] + client.timeseries(**qry) + if not having_filters and len(groupby) == 1: + qry['threshold'] = timeseries_limit or 1000 + if row_limit and granularity == 'all': + qry['threshold'] = row_limit + qry['dimension'] = list(qry.get('dimensions'))[0] + del qry['dimensions'] + qry['metric'] = list(qry['aggregations'].keys())[0] + client.topn(**qry) + elif len(groupby) > 1 or having_filters: + # If grouping on multiple fields or using a having filter + # we have to force a groupby query + if timeseries_limit and is_timeseries: + order_by = metrics[0] if metrics else self.metrics[0] + if timeseries_limit_metric: + order_by = timeseries_limit_metric + # Limit on the number of timeseries, doing a two-phases query + pre_qry = deepcopy(qry) + pre_qry['granularity'] = "all" + pre_qry['limit_spec'] = { + "type": "default", + "limit": timeseries_limit, + 'intervals': ( + inner_from_dttm.isoformat() + '/' + + inner_to_dttm.isoformat()), + "columns": [{ + "dimension": order_by, + "direction": "descending", + }], + } + client.groupby(**pre_qry) + query_str += "// Two phase query\n// Phase 1\n" + query_str += json.dumps( + client.query_builder.last_query.query_dict, indent=2) + query_str += "\n" + if phase == 1: + return query_str + query_str += ( + "//\nPhase 2 (built based on phase one's results)\n") + df = client.export_pandas() + if df is not None and not df.empty: + dims = qry['dimensions'] + filters = [] + for unused, row in df.iterrows(): + fields = [] + for dim in dims: + f = Dimension(dim) == row[dim] + fields.append(f) + if len(fields) > 1: + filt = Filter(type="and", fields=fields) + filters.append(filt) + elif fields: + filters.append(fields[0]) + + if filters: + ff = Filter(type="or", fields=filters) + if not orig_filters: + qry['filter'] = ff + else: + qry['filter'] = Filter(type="and", fields=[ + ff, + orig_filters]) + qry['limit_spec'] = None + if row_limit: + qry['limit_spec'] = { + "type": "default", + "limit": row_limit, + "columns": [{ + "dimension": ( + metrics[0] if metrics else self.metrics[0]), + "direction": "descending", + }], + } + 
client.groupby(**qry) + query_str += json.dumps( + client.query_builder.last_query.query_dict, indent=2) + return query_str + + def query(self, query_obj): + qry_start_dttm = datetime.now() + client = self.cluster.get_pydruid_client() + query_str = self.get_query_str(client, qry_start_dttm, **query_obj) + df = client.export_pandas() + + if df is None or df.size == 0: + raise Exception(_("No data was returned.")) + df.columns = [ + DTTM_ALIAS if c == 'timestamp' else c for c in df.columns] + + is_timeseries = query_obj['is_timeseries'] \ + if 'is_timeseries' in query_obj else True + if ( + not is_timeseries and + query_obj['granularity'] == "all" and + DTTM_ALIAS in df.columns): + del df[DTTM_ALIAS] + + # Reordering columns + cols = [] + if DTTM_ALIAS in df.columns: + cols += [DTTM_ALIAS] + cols += [col for col in query_obj['groupby'] if col in df.columns] + cols += [col for col in query_obj['metrics'] if col in df.columns] + df = df[cols] + + time_offset = DruidDatasource.time_offset(query_obj['granularity']) + + def increment_timestamp(ts): + dt = utils.parse_human_datetime(ts).replace( + tzinfo=config.get("DRUID_TZ")) + return dt + timedelta(milliseconds=time_offset) + if DTTM_ALIAS in df.columns and time_offset: + df[DTTM_ALIAS] = df[DTTM_ALIAS].apply(increment_timestamp) + + return QueryResult( + df=df, + query=query_str, + duration=datetime.now() - qry_start_dttm) + + def get_filters(self, raw_filters): # noqa + filters = None + for flt in raw_filters: + if not all(f in flt for f in ['col', 'op', 'val']): + continue + col = flt['col'] + op = flt['op'] + eq = flt['val'] + cond = None + if op in ('in', 'not in'): + eq = [types.replace("'", '').strip() for types in eq] + elif not isinstance(flt['val'], basestring): + eq = eq[0] if len(eq) > 0 else '' + if col in self.num_cols: + if op in ('in', 'not in'): + eq = [utils.js_string_to_num(v) for v in eq] + else: + eq = utils.js_string_to_num(eq) + if op == '==': + cond = Dimension(col) == eq + elif op == '!=': + cond = ~(Dimension(col) == eq) + elif op in ('in', 'not in'): + fields = [] + if len(eq) > 1: + for s in eq: + fields.append(Dimension(col) == s) + cond = Filter(type="or", fields=fields) + elif len(eq) == 1: + cond = Dimension(col) == eq[0] + if op == 'not in': + cond = ~cond + elif op == 'regex': + cond = Filter(type="regex", pattern=eq, dimension=col) + if filters: + filters = Filter(type="and", fields=[ + cond, + filters + ]) + else: + filters = cond + return filters + + def _get_having_obj(self, col, op, eq): + cond = None + if op == '==': + if col in self.column_names: + cond = DimSelector(dimension=col, value=eq) + else: + cond = Aggregation(col) == eq + elif op == '>': + cond = Aggregation(col) > eq + elif op == '<': + cond = Aggregation(col) < eq + + return cond diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py new file mode 100644 index 0000000000000..d86e380295269 --- /dev/null +++ b/superset/connectors/druid/views.py @@ -0,0 +1,203 @@ +import sqlalchemy as sqla + +from flask import Markup +from flask_appbuilder import CompactCRUDMixin +from flask_appbuilder.models.sqla.interface import SQLAInterface + +from flask_babel import lazy_gettext as _ +from flask_babel import gettext as __ + +import superset +from superset import db, utils, appbuilder, sm, security +from superset.views.base import ( + SupersetModelView, validate_json, DeleteMixin, ListWidgetWithCheckboxes, + DatasourceFilter, get_datasource_exist_error_mgs) + +from . 
import models + + +class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa + datamodel = SQLAInterface(models.DruidColumn) + edit_columns = [ + 'column_name', 'description', 'dimension_spec_json', 'datasource', + 'groupby', 'count_distinct', 'sum', 'min', 'max'] + add_columns = edit_columns + list_columns = [ + 'column_name', 'type', 'groupby', 'filterable', 'count_distinct', + 'sum', 'min', 'max'] + can_delete = False + page_size = 500 + label_columns = { + 'column_name': _("Column"), + 'type': _("Type"), + 'datasource': _("Datasource"), + 'groupby': _("Groupable"), + 'filterable': _("Filterable"), + 'count_distinct': _("Count Distinct"), + 'sum': _("Sum"), + 'min': _("Min"), + 'max': _("Max"), + } + description_columns = { + 'dimension_spec_json': utils.markdown( + "this field can be used to specify " + "a `dimensionSpec` as documented [here]" + "(http://druid.io/docs/latest/querying/dimensionspecs.html). " + "Make sure to input valid JSON and that the " + "`outputName` matches the `column_name` defined " + "above.", + True), + } + + def post_update(self, col): + col.generate_metrics() + utils.validate_json(col.dimension_spec_json) + + def post_add(self, col): + self.post_update(col) + +appbuilder.add_view_no_menu(DruidColumnInlineView) + + +class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa + datamodel = SQLAInterface(models.DruidMetric) + list_columns = ['metric_name', 'verbose_name', 'metric_type'] + edit_columns = [ + 'metric_name', 'description', 'verbose_name', 'metric_type', 'json', + 'datasource', 'd3format', 'is_restricted'] + add_columns = edit_columns + page_size = 500 + validators_columns = { + 'json': [validate_json], + } + description_columns = { + 'metric_type': utils.markdown( + "use `postagg` as the metric type if you are defining a " + "[Druid Post Aggregation]" + "(http://druid.io/docs/latest/querying/post-aggregations.html)", + True), + 'is_restricted': _("Whether the access to this metric is restricted " + "to certain roles. 
Only roles with the permission " + "'metric access on XXX (the name of this metric)' " + "are allowed to access this metric"), + } + label_columns = { + 'metric_name': _("Metric"), + 'description': _("Description"), + 'verbose_name': _("Verbose Name"), + 'metric_type': _("Type"), + 'json': _("JSON"), + 'datasource': _("Druid Datasource"), + } + + def post_add(self, metric): + utils.init_metrics_perm(superset, [metric]) + + def post_update(self, metric): + utils.init_metrics_perm(superset, [metric]) + + +appbuilder.add_view_no_menu(DruidMetricInlineView) + + +class DruidClusterModelView(SupersetModelView, DeleteMixin): # noqa + datamodel = SQLAInterface(models.DruidCluster) + add_columns = [ + 'cluster_name', + 'coordinator_host', 'coordinator_port', 'coordinator_endpoint', + 'broker_host', 'broker_port', 'broker_endpoint', 'cache_timeout', + ] + edit_columns = add_columns + list_columns = ['cluster_name', 'metadata_last_refreshed'] + label_columns = { + 'cluster_name': _("Cluster"), + 'coordinator_host': _("Coordinator Host"), + 'coordinator_port': _("Coordinator Port"), + 'coordinator_endpoint': _("Coordinator Endpoint"), + 'broker_host': _("Broker Host"), + 'broker_port': _("Broker Port"), + 'broker_endpoint': _("Broker Endpoint"), + } + + def pre_add(self, cluster): + security.merge_perm(sm, 'database_access', cluster.perm) + + def pre_update(self, cluster): + self.pre_add(cluster) + + +appbuilder.add_view( + DruidClusterModelView, + name="Druid Clusters", + label=__("Druid Clusters"), + icon="fa-cubes", + category="Sources", + category_label=__("Sources"), + category_icon='fa-database',) + + +class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa + datamodel = SQLAInterface(models.DruidDatasource) + list_widget = ListWidgetWithCheckboxes + list_columns = [ + 'datasource_link', 'cluster', 'changed_by_', 'changed_on_', 'offset'] + order_columns = [ + 'datasource_link', 'changed_on_', 'offset'] + related_views = [DruidColumnInlineView, DruidMetricInlineView] + edit_columns = [ + 'datasource_name', 'cluster', 'description', 'owner', + 'is_featured', 'is_hidden', 'filter_select_enabled', + 'default_endpoint', 'offset', 'cache_timeout'] + add_columns = edit_columns + show_columns = add_columns + ['perm'] + page_size = 500 + base_order = ('datasource_name', 'asc') + description_columns = { + 'offset': _("Timezone offset (in hours) for this datasource"), + 'description': Markup( + "Supports markdown"), + } + base_filters = [['id', DatasourceFilter, lambda: []]] + label_columns = { + 'datasource_link': _("Data Source"), + 'cluster': _("Cluster"), + 'description': _("Description"), + 'owner': _("Owner"), + 'is_featured': _("Is Featured"), + 'is_hidden': _("Is Hidden"), + 'filter_select_enabled': _("Enable Filter Select"), + 'default_endpoint': _("Default Endpoint"), + 'offset': _("Time Offset"), + 'cache_timeout': _("Cache Timeout"), + } + + def pre_add(self, datasource): + number_of_existing_datasources = db.session.query( + sqla.func.count('*')).filter( + models.DruidDatasource.datasource_name == + datasource.datasource_name, + models.DruidDatasource.cluster_name == datasource.cluster.id + ).scalar() + + # table object is already added to the session + if number_of_existing_datasources > 1: + raise Exception(get_datasource_exist_error_mgs( + datasource.full_name)) + + def post_add(self, datasource): + datasource.generate_metrics() + security.merge_perm(sm, 'datasource_access', datasource.get_perm()) + if datasource.schema: + security.merge_perm(sm, 'schema_access', 
datasource.schema_perm) + + def post_update(self, datasource): + self.post_add(datasource) + +appbuilder.add_view( + DruidDatasourceModelView, + "Druid Datasources", + label=__("Druid Datasources"), + category="Sources", + category_label=__("Sources"), + icon="fa-cube") diff --git a/superset/connectors/sqla/__init__.py b/superset/connectors/sqla/__init__.py new file mode 100644 index 0000000000000..b2df79851f224 --- /dev/null +++ b/superset/connectors/sqla/__init__.py @@ -0,0 +1,2 @@ +from . import models # noqa +from . import views # noqa diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py new file mode 100644 index 0000000000000..6ec0e3fff443b --- /dev/null +++ b/superset/connectors/sqla/models.py @@ -0,0 +1,701 @@ +from datetime import datetime +import logging +import sqlparse + +import pandas as pd + +from sqlalchemy import ( + Column, Integer, String, ForeignKey, Text, Boolean, + DateTime, +) +import sqlalchemy as sa +from sqlalchemy import asc, and_, desc, select +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.expression import ColumnClause, TextAsFrom +from sqlalchemy.orm import backref, relationship +from sqlalchemy.sql import table, literal_column, text, column + +from flask import escape, Markup +from flask_appbuilder import Model +from flask_babel import lazy_gettext as _ + +from superset import db, utils, import_util +from superset.connectors.base import Datasource +from superset.utils import ( + wrap_clause_in_parens, + DTTM_ALIAS, QueryStatus +) +from superset.models.helpers import QueryResult +from superset.models.core import Database +from superset.jinja_context import get_template_processor +from superset.models.helpers import AuditMixinNullable, ImportMixin, set_perm + + +class TableColumn(Model, AuditMixinNullable, ImportMixin): + + """ORM object for table columns, each table can have multiple columns""" + + __tablename__ = 'table_columns' + id = Column(Integer, primary_key=True) + table_id = Column(Integer, ForeignKey('tables.id')) + table = relationship( + 'SqlaTable', + backref=backref('columns', cascade='all, delete-orphan'), + foreign_keys=[table_id]) + column_name = Column(String(255)) + verbose_name = Column(String(1024)) + is_dttm = Column(Boolean, default=False) + is_active = Column(Boolean, default=True) + type = Column(String(32), default='') + groupby = Column(Boolean, default=False) + count_distinct = Column(Boolean, default=False) + sum = Column(Boolean, default=False) + avg = Column(Boolean, default=False) + max = Column(Boolean, default=False) + min = Column(Boolean, default=False) + filterable = Column(Boolean, default=False) + expression = Column(Text, default='') + description = Column(Text, default='') + python_date_format = Column(String(255)) + database_expression = Column(String(255)) + + num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG', 'REAL', 'NUMERIC') + date_types = ('DATE', 'TIME') + str_types = ('VARCHAR', 'STRING', 'CHAR') + export_fields = ( + 'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active', + 'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min', + 'filterable', 'expression', 'description', 'python_date_format', + 'database_expression' + ) + + def __repr__(self): + return self.column_name + + @property + def is_num(self): + return any([t in self.type.upper() for t in self.num_types]) + + @property + def is_time(self): + return any([t in self.type.upper() for t in self.date_types]) + + @property + def is_string(self): + return any([t in self.type.upper() 
for t in self.str_types]) + + @property + def sqla_col(self): + name = self.column_name + if not self.expression: + col = column(self.column_name).label(name) + else: + col = literal_column(self.expression).label(name) + return col + + def get_time_filter(self, start_dttm, end_dttm): + col = self.sqla_col.label('__time') + return and_( + col >= text(self.dttm_sql_literal(start_dttm)), + col <= text(self.dttm_sql_literal(end_dttm)), + ) + + def get_timestamp_expression(self, time_grain): + """Getting the time component of the query""" + expr = self.expression or self.column_name + if not self.expression and not time_grain: + return column(expr, type_=DateTime).label(DTTM_ALIAS) + if time_grain: + pdf = self.python_date_format + if pdf in ('epoch_s', 'epoch_ms'): + # if epoch, translate to DATE using db specific conf + db_spec = self.table.database.db_engine_spec + if pdf == 'epoch_s': + expr = db_spec.epoch_to_dttm().format(col=expr) + elif pdf == 'epoch_ms': + expr = db_spec.epoch_ms_to_dttm().format(col=expr) + grain = self.table.database.grains_dict().get(time_grain, '{col}') + expr = grain.function.format(col=expr) + return literal_column(expr, type_=DateTime).label(DTTM_ALIAS) + + @classmethod + def import_obj(cls, i_column): + def lookup_obj(lookup_column): + return db.session.query(TableColumn).filter( + TableColumn.table_id == lookup_column.table_id, + TableColumn.column_name == lookup_column.column_name).first() + return import_util.import_simple_obj(db.session, i_column, lookup_obj) + + def dttm_sql_literal(self, dttm): + """Convert datetime object to a SQL expression string + + If database_expression is empty, the internal dttm + will be parsed as the string with the pattern that + the user inputted (python_date_format) + If database_expression is not empty, the internal dttm + will be parsed as the sql sentence for the database to convert + """ + + tf = self.python_date_format or '%Y-%m-%d %H:%M:%S.%f' + if self.database_expression: + return self.database_expression.format(dttm.strftime('%Y-%m-%d %H:%M:%S')) + elif tf == 'epoch_s': + return str((dttm - datetime(1970, 1, 1)).total_seconds()) + elif tf == 'epoch_ms': + return str((dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0) + else: + s = self.table.database.db_engine_spec.convert_dttm( + self.type, dttm) + return s or "'{}'".format(dttm.strftime(tf)) + + +class SqlMetric(Model, AuditMixinNullable, ImportMixin): + + """ORM object for metrics, each table can have multiple metrics""" + + __tablename__ = 'sql_metrics' + id = Column(Integer, primary_key=True) + metric_name = Column(String(512)) + verbose_name = Column(String(1024)) + metric_type = Column(String(32)) + table_id = Column(Integer, ForeignKey('tables.id')) + table = relationship( + 'SqlaTable', + backref=backref('metrics', cascade='all, delete-orphan'), + foreign_keys=[table_id]) + expression = Column(Text) + description = Column(Text) + is_restricted = Column(Boolean, default=False, nullable=True) + d3format = Column(String(128)) + + export_fields = ( + 'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression', + 'description', 'is_restricted', 'd3format') + + @property + def sqla_col(self): + name = self.metric_name + return literal_column(self.expression).label(name) + + @property + def perm(self): + return ( + "{parent_name}.[{obj.metric_name}](id:{obj.id})" + ).format(obj=self, + parent_name=self.table.full_name) if self.table else None + + @classmethod + def import_obj(cls, i_metric): + def lookup_obj(lookup_metric): + return 
db.session.query(SqlMetric).filter( + SqlMetric.table_id == lookup_metric.table_id, + SqlMetric.metric_name == lookup_metric.metric_name).first() + return import_util.import_simple_obj(db.session, i_metric, lookup_obj) + + +class SqlaTable(Model, Datasource, AuditMixinNullable, ImportMixin): + + """An ORM object for SqlAlchemy table references""" + + type = "table" + query_language = 'sql' + + __tablename__ = 'tables' + id = Column(Integer, primary_key=True) + table_name = Column(String(250)) + main_dttm_col = Column(String(250)) + description = Column(Text) + default_endpoint = Column(Text) + database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False) + is_featured = Column(Boolean, default=False) + filter_select_enabled = Column(Boolean, default=False) + user_id = Column(Integer, ForeignKey('ab_user.id')) + owner = relationship('User', backref='tables', foreign_keys=[user_id]) + database = relationship( + 'Database', + backref=backref('tables', cascade='all, delete-orphan'), + foreign_keys=[database_id]) + offset = Column(Integer, default=0) + cache_timeout = Column(Integer) + schema = Column(String(255)) + sql = Column(Text) + params = Column(Text) + perm = Column(String(1000)) + + baselink = "tablemodelview" + column_cls = TableColumn + metric_cls = SqlMetric + export_fields = ( + 'table_name', 'main_dttm_col', 'description', 'default_endpoint', + 'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema', + 'sql', 'params') + + __table_args__ = ( + sa.UniqueConstraint( + 'database_id', 'schema', 'table_name', + name='_customer_location_uc'),) + + def __repr__(self): + return self.name + + @property + def description_markeddown(self): + return utils.markdown(self.description) + + @property + def link(self): + name = escape(self.name) + return Markup( + '{name}'.format(**locals())) + + @property + def schema_perm(self): + """Returns schema permission if present, database one otherwise.""" + return utils.get_schema_perm(self.database, self.schema) + + def get_perm(self): + return ( + "[{obj.database}].[{obj.table_name}]" + "(id:{obj.id})").format(obj=self) + + @property + def name(self): + if not self.schema: + return self.table_name + return "{}.{}".format(self.schema, self.table_name) + + @property + def full_name(self): + return utils.get_datasource_full_name( + self.database, self.table_name, schema=self.schema) + + @property + def dttm_cols(self): + l = [c.column_name for c in self.columns if c.is_dttm] + if self.main_dttm_col and self.main_dttm_col not in l: + l.append(self.main_dttm_col) + return l + + @property + def num_cols(self): + return [c.column_name for c in self.columns if c.is_num] + + @property + def any_dttm_col(self): + cols = self.dttm_cols + if cols: + return cols[0] + + @property + def html(self): + t = ((c.column_name, c.type) for c in self.columns) + df = pd.DataFrame(t) + df.columns = ['field', 'type'] + return df.to_html( + index=False, + classes=( + "dataframe table table-striped table-bordered " + "table-condensed")) + + @property + def metrics_combo(self): + return sorted( + [ + (m.metric_name, m.verbose_name or m.metric_name) + for m in self.metrics], + key=lambda x: x[1]) + + @property + def sql_url(self): + return self.database.sql_url + "?table_name=" + str(self.table_name) + + @property + def time_column_grains(self): + return { + "time_columns": self.dttm_cols, + "time_grains": [grain.name for grain in self.database.grains()] + } + + def get_col(self, col_name): + columns = self.columns + for col in columns: + if col_name == 
col.column_name: + return col + + def values_for_column(self, + column_name, + from_dttm, + to_dttm, + limit=500): + """Runs query against sqla to retrieve some + sample values for the given column. + """ + granularity = self.main_dttm_col + + cols = {col.column_name: col for col in self.columns} + target_col = cols[column_name] + + tbl = table(self.table_name) + qry = sa.select([target_col.sqla_col]) + qry = qry.select_from(tbl) + qry = qry.distinct(column_name) + qry = qry.limit(limit) + + if granularity: + dttm_col = cols[granularity] + timestamp = dttm_col.sqla_col.label('timestamp') + time_filter = [ + timestamp >= text(dttm_col.dttm_sql_literal(from_dttm)), + timestamp <= text(dttm_col.dttm_sql_literal(to_dttm)), + ] + qry = qry.where(and_(*time_filter)) + + engine = self.database.get_sqla_engine() + sql = "{}".format( + qry.compile( + engine, compile_kwargs={"literal_binds": True}, ), + ) + + return pd.read_sql_query( + sql=sql, + con=engine + ) + + def get_query_str( # sqla + self, engine, qry_start_dttm, + groupby, metrics, + granularity, + from_dttm, to_dttm, + filter=None, # noqa + is_timeseries=True, + timeseries_limit=15, + timeseries_limit_metric=None, + row_limit=None, + inner_from_dttm=None, + inner_to_dttm=None, + orderby=None, + extras=None, + columns=None): + """Querying any sqla table from this common interface""" + template_processor = get_template_processor( + table=self, database=self.database) + + # For backward compatibility + if granularity not in self.dttm_cols: + granularity = self.main_dttm_col + + cols = {col.column_name: col for col in self.columns} + metrics_dict = {m.metric_name: m for m in self.metrics} + + if not granularity and is_timeseries: + raise Exception(_( + "Datetime column not provided as part table configuration " + "and is required by this type of chart")) + for m in metrics: + if m not in metrics_dict: + raise Exception(_("Metric '{}' is not valid".format(m))) + metrics_exprs = [metrics_dict.get(m).sqla_col for m in metrics] + timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric) + timeseries_limit_metric_expr = None + if timeseries_limit_metric: + timeseries_limit_metric_expr = \ + timeseries_limit_metric.sqla_col + if metrics: + main_metric_expr = metrics_exprs[0] + else: + main_metric_expr = literal_column("COUNT(*)").label("ccount") + + select_exprs = [] + groupby_exprs = [] + + if groupby: + select_exprs = [] + inner_select_exprs = [] + inner_groupby_exprs = [] + for s in groupby: + col = cols[s] + outer = col.sqla_col + inner = col.sqla_col.label(col.column_name + '__') + + groupby_exprs.append(outer) + select_exprs.append(outer) + inner_groupby_exprs.append(inner) + inner_select_exprs.append(inner) + elif columns: + for s in columns: + select_exprs.append(cols[s].sqla_col) + metrics_exprs = [] + + if granularity: + @compiles(ColumnClause) + def visit_column(element, compiler, **kw): + """Patch for sqlalchemy bug + + TODO: sqlalchemy 1.2 release should be doing this on its own. + Patch only if the column clause is specific for DateTime + set and granularity is selected. 
+ """ + text = compiler.visit_column(element, **kw) + try: + if ( + element.is_literal and + hasattr(element.type, 'python_type') and + type(element.type) is DateTime + ): + text = text.replace('%%', '%') + except NotImplementedError: + # Some elements raise NotImplementedError for python_type + pass + return text + + dttm_col = cols[granularity] + time_grain = extras.get('time_grain_sqla') + + if is_timeseries: + timestamp = dttm_col.get_timestamp_expression(time_grain) + select_exprs += [timestamp] + groupby_exprs += [timestamp] + + time_filter = dttm_col.get_time_filter(from_dttm, to_dttm) + + select_exprs += metrics_exprs + qry = sa.select(select_exprs) + + tbl = table(self.table_name) + if self.schema: + tbl.schema = self.schema + + # Supporting arbitrary SQL statements in place of tables + if self.sql: + tbl = TextAsFrom(sa.text(self.sql), []).alias('expr_qry') + + if not columns: + qry = qry.group_by(*groupby_exprs) + + where_clause_and = [] + having_clause_and = [] + for flt in filter: + if not all([flt.get(s) for s in ['col', 'op', 'val']]): + continue + col = flt['col'] + op = flt['op'] + eq = flt['val'] + col_obj = cols.get(col) + if col_obj and op in ('in', 'not in'): + values = [types.strip("'").strip('"') for types in eq] + if col_obj.is_num: + values = [utils.js_string_to_num(s) for s in values] + cond = col_obj.sqla_col.in_(values) + if op == 'not in': + cond = ~cond + where_clause_and.append(cond) + if extras: + where = extras.get('where') + if where: + where_clause_and += [wrap_clause_in_parens( + template_processor.process_template(where))] + having = extras.get('having') + if having: + having_clause_and += [wrap_clause_in_parens( + template_processor.process_template(having))] + if granularity: + qry = qry.where(and_(*([time_filter] + where_clause_and))) + else: + qry = qry.where(and_(*where_clause_and)) + qry = qry.having(and_(*having_clause_and)) + if groupby: + qry = qry.order_by(desc(main_metric_expr)) + elif orderby: + for col, ascending in orderby: + direction = asc if ascending else desc + qry = qry.order_by(direction(col)) + + qry = qry.limit(row_limit) + + if is_timeseries and timeseries_limit and groupby: + # some sql dialects require for order by expressions + # to also be in the select clause -- others, e.g. 
vertica, + # require a unique inner alias + inner_main_metric_expr = main_metric_expr.label('mme_inner__') + inner_select_exprs += [inner_main_metric_expr] + subq = select(inner_select_exprs) + subq = subq.select_from(tbl) + inner_time_filter = dttm_col.get_time_filter( + inner_from_dttm or from_dttm, + inner_to_dttm or to_dttm, + ) + subq = subq.where(and_(*(where_clause_and + [inner_time_filter]))) + subq = subq.group_by(*inner_groupby_exprs) + ob = inner_main_metric_expr + if timeseries_limit_metric_expr is not None: + ob = timeseries_limit_metric_expr + subq = subq.order_by(desc(ob)) + subq = subq.limit(timeseries_limit) + on_clause = [] + for i, gb in enumerate(groupby): + on_clause.append( + groupby_exprs[i] == column(gb + '__')) + + tbl = tbl.join(subq.alias(), and_(*on_clause)) + + qry = qry.select_from(tbl) + + sql = "{}".format( + qry.compile( + engine, compile_kwargs={"literal_binds": True},), + ) + logging.info(sql) + sql = sqlparse.format(sql, reindent=True) + return sql + + def query(self, query_obj): + qry_start_dttm = datetime.now() + engine = self.database.get_sqla_engine() + sql = self.get_query_str(engine, qry_start_dttm, **query_obj) + status = QueryStatus.SUCCESS + error_message = None + df = None + try: + df = pd.read_sql_query(sql, con=engine) + except Exception as e: + status = QueryStatus.FAILED + error_message = str(e) + + return QueryResult( + status=status, + df=df, + duration=datetime.now() - qry_start_dttm, + query=sql, + error_message=error_message) + + def get_sqla_table_object(self): + return self.database.get_table(self.table_name, schema=self.schema) + + def fetch_metadata(self): + """Fetches the metadata for the table and merges it in""" + try: + table = self.get_sqla_table_object() + except Exception: + raise Exception( + "Table doesn't seem to exist in the specified database, " + "couldn't fetch column information") + + TC = TableColumn # noqa shortcut to class + M = SqlMetric # noqa + metrics = [] + any_date_col = None + for col in table.columns: + try: + datatype = "{}".format(col.type).upper() + except Exception as e: + datatype = "UNKNOWN" + logging.error( + "Unrecognized data type in {}.{}".format(table, col.name)) + logging.exception(e) + dbcol = ( + db.session + .query(TC) + .filter(TC.table == self) + .filter(TC.column_name == col.name) + .first() + ) + db.session.flush() + if not dbcol: + dbcol = TableColumn(column_name=col.name, type=datatype) + dbcol.groupby = dbcol.is_string + dbcol.filterable = dbcol.is_string + dbcol.sum = dbcol.is_num + dbcol.avg = dbcol.is_num + dbcol.is_dttm = dbcol.is_time + + db.session.merge(self) + self.columns.append(dbcol) + + if not any_date_col and dbcol.is_time: + any_date_col = col.name + + quoted = "{}".format( + column(dbcol.column_name).compile(dialect=db.engine.dialect)) + if dbcol.sum: + metrics.append(M( + metric_name='sum__' + dbcol.column_name, + verbose_name='sum__' + dbcol.column_name, + metric_type='sum', + expression="SUM({})".format(quoted) + )) + if dbcol.avg: + metrics.append(M( + metric_name='avg__' + dbcol.column_name, + verbose_name='avg__' + dbcol.column_name, + metric_type='avg', + expression="AVG({})".format(quoted) + )) + if dbcol.max: + metrics.append(M( + metric_name='max__' + dbcol.column_name, + verbose_name='max__' + dbcol.column_name, + metric_type='max', + expression="MAX({})".format(quoted) + )) + if dbcol.min: + metrics.append(M( + metric_name='min__' + dbcol.column_name, + verbose_name='min__' + dbcol.column_name, + metric_type='min', + expression="MIN({})".format(quoted) + )) 
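A hedged illustration of the naming convention produced by the auto-generated metrics in fetch_metadata() above (editor's sketch, not part of the diff): for a hypothetical numeric column named num_users with the sum/avg/max/min flags set, the method would queue SqlMetric rows along these lines. The exact quoting of the column name depends on the metadata database's dialect, so it is omitted here.

    # Sketch only: expected (metric_name, metric_type, expression) triples
    # for a hypothetical column "num_users"; quoting is dialect-dependent.
    expected_metrics = [
        ('sum__num_users', 'sum', 'SUM(num_users)'),
        ('avg__num_users', 'avg', 'AVG(num_users)'),
        ('max__num_users', 'max', 'MAX(num_users)'),
        ('min__num_users', 'min', 'MIN(num_users)'),
    ]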
+ if dbcol.count_distinct: + metrics.append(M( + metric_name='count_distinct__' + dbcol.column_name, + verbose_name='count_distinct__' + dbcol.column_name, + metric_type='count_distinct', + expression="COUNT(DISTINCT {})".format(quoted) + )) + dbcol.type = datatype + db.session.merge(self) + db.session.commit() + + metrics.append(M( + metric_name='count', + verbose_name='COUNT(*)', + metric_type='count', + expression="COUNT(*)" + )) + for metric in metrics: + m = ( + db.session.query(M) + .filter(M.metric_name == metric.metric_name) + .filter(M.table_id == self.id) + .first() + ) + metric.table_id = self.id + if not m: + db.session.add(metric) + db.session.commit() + if not self.main_dttm_col: + self.main_dttm_col = any_date_col + + @classmethod + def import_obj(cls, i_datasource, import_time=None): + """Imports the datasource from the object to the database. + + Metrics and columns and datasource will be overrided if exists. + This function can be used to import/export dashboards between multiple + superset instances. Audit metadata isn't copies over. + """ + def lookup_sqlatable(table): + return db.session.query(SqlaTable).join(Database).filter( + SqlaTable.table_name == table.table_name, + SqlaTable.schema == table.schema, + Database.id == table.database_id, + ).first() + + def lookup_database(table): + return db.session.query(Database).filter_by( + database_name=table.params_dict['database_name']).one() + return import_util.import_datasource( + db.session, i_datasource, lookup_database, lookup_sqlatable, + import_time) + +sa.event.listen(SqlaTable, 'after_insert', set_perm) +sa.event.listen(SqlaTable, 'after_update', set_perm) diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py new file mode 100644 index 0000000000000..f5dccc76ce200 --- /dev/null +++ b/superset/connectors/sqla/views.py @@ -0,0 +1,213 @@ +import logging + +from flask import Markup, flash +from flask_appbuilder import CompactCRUDMixin +from flask_appbuilder.models.sqla.interface import SQLAInterface +import sqlalchemy as sa + +from flask_babel import lazy_gettext as _ +from flask_babel import gettext as __ + +from superset import appbuilder, db, utils, security, sm +from superset.views.base import ( + SupersetModelView, ListWidgetWithCheckboxes, DeleteMixin, DatasourceFilter, + get_datasource_exist_error_mgs, +) + +from . import models + + +class TableColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa + datamodel = SQLAInterface(models.TableColumn) + can_delete = False + list_widget = ListWidgetWithCheckboxes + edit_columns = [ + 'column_name', 'verbose_name', 'description', 'groupby', 'filterable', + 'table', 'count_distinct', 'sum', 'min', 'max', 'expression', + 'is_dttm', 'python_date_format', 'database_expression'] + add_columns = edit_columns + list_columns = [ + 'column_name', 'type', 'groupby', 'filterable', 'count_distinct', + 'sum', 'min', 'max', 'is_dttm'] + page_size = 500 + description_columns = { + 'is_dttm': (_( + "Whether to make this column available as a " + "[Time Granularity] option, column has to be DATETIME or " + "DATETIME-like")), + 'expression': utils.markdown( + "a valid SQL expression as supported by the underlying backend. " + "Example: `substr(name, 1, 1)`", True), + 'python_date_format': utils.markdown(Markup( + "The pattern of timestamp format, use " + "" + "python datetime string pattern " + "expression. If time is stored in epoch " + "format, put `epoch_s` or `epoch_ms`. 
Leave `Database Expression` " + "below empty if timestamp is stored in " + "String or Integer(epoch) type"), True), + 'database_expression': utils.markdown( + "The database expression to cast internal datetime " + "constants to database date/timestamp type according to the DBAPI. " + "The expression should follow the pattern of " + "%Y-%m-%d %H:%M:%S, based on different DBAPI. " + "The string should be a python string formatter \n" + "`Ex: TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')` for Oracle" + "Superset uses default expression based on DB URI if this " + "field is blank.", True), + } + label_columns = { + 'column_name': _("Column"), + 'verbose_name': _("Verbose Name"), + 'description': _("Description"), + 'groupby': _("Groupable"), + 'filterable': _("Filterable"), + 'table': _("Table"), + 'count_distinct': _("Count Distinct"), + 'sum': _("Sum"), + 'min': _("Min"), + 'max': _("Max"), + 'expression': _("Expression"), + 'is_dttm': _("Is temporal"), + 'python_date_format': _("Datetime Format"), + 'database_expression': _("Database Expression") + } +appbuilder.add_view_no_menu(TableColumnInlineView) + + +class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa + datamodel = SQLAInterface(models.SqlMetric) + list_columns = ['metric_name', 'verbose_name', 'metric_type'] + edit_columns = [ + 'metric_name', 'description', 'verbose_name', 'metric_type', + 'expression', 'table', 'd3format', 'is_restricted'] + description_columns = { + 'expression': utils.markdown( + "a valid SQL expression as supported by the underlying backend. " + "Example: `count(DISTINCT userid)`", True), + 'is_restricted': _("Whether the access to this metric is restricted " + "to certain roles. Only roles with the permission " + "'metric access on XXX (the name of this metric)' " + "are allowed to access this metric"), + 'd3format': utils.markdown( + "d3 formatting string as defined [here]" + "(https://github.com/d3/d3-format/blob/master/README.md#format). 
" + "For instance, this default formatting applies in the Table " + "visualization and allow for different metric to use different " + "formats", True + ), + } + add_columns = edit_columns + page_size = 500 + label_columns = { + 'metric_name': _("Metric"), + 'description': _("Description"), + 'verbose_name': _("Verbose Name"), + 'metric_type': _("Type"), + 'expression': _("SQL Expression"), + 'table': _("Table"), + } + + def post_add(self, metric): + if metric.is_restricted: + security.merge_perm(sm, 'metric_access', metric.get_perm()) + + def post_update(self, metric): + if metric.is_restricted: + security.merge_perm(sm, 'metric_access', metric.get_perm()) + +appbuilder.add_view_no_menu(SqlMetricInlineView) + + +class TableModelView(SupersetModelView, DeleteMixin): # noqa + datamodel = SQLAInterface(models.SqlaTable) + list_columns = [ + 'link', 'database', 'is_featured', + 'changed_by_', 'changed_on_'] + order_columns = [ + 'link', 'database', 'is_featured', 'changed_on_'] + add_columns = ['database', 'schema', 'table_name'] + edit_columns = [ + 'table_name', 'sql', 'is_featured', 'filter_select_enabled', + 'database', 'schema', + 'description', 'owner', + 'main_dttm_col', 'default_endpoint', 'offset', 'cache_timeout'] + show_columns = edit_columns + ['perm'] + related_views = [TableColumnInlineView, SqlMetricInlineView] + base_order = ('changed_on', 'desc') + description_columns = { + 'offset': _("Timezone offset (in hours) for this datasource"), + 'table_name': _( + "Name of the table that exists in the source database"), + 'schema': _( + "Schema, as used only in some databases like Postgres, Redshift " + "and DB2"), + 'description': Markup( + "Supports " + "markdown"), + 'sql': _( + "This fields acts a Superset view, meaning that Superset will " + "run a query against this string as a subquery." + ), + } + base_filters = [['id', DatasourceFilter, lambda: []]] + label_columns = { + 'link': _("Table"), + 'changed_by_': _("Changed By"), + 'database': _("Database"), + 'changed_on_': _("Last Changed"), + 'is_featured': _("Is Featured"), + 'filter_select_enabled': _("Enable Filter Select"), + 'schema': _("Schema"), + 'default_endpoint': _("Default Endpoint"), + 'offset': _("Offset"), + 'cache_timeout': _("Cache Timeout"), + } + + def pre_add(self, table): + number_of_existing_tables = db.session.query( + sa.func.count('*')).filter( + models.SqlaTable.table_name == table.table_name, + models.SqlaTable.schema == table.schema, + models.SqlaTable.database_id == table.database.id + ).scalar() + # table object is already added to the session + if number_of_existing_tables > 1: + raise Exception(get_datasource_exist_error_mgs(table.full_name)) + + # Fail before adding if the table can't be found + try: + table.get_sqla_table_object() + except Exception as e: + logging.exception(e) + raise Exception( + "Table [{}] could not be found, " + "please double check your " + "database connection, schema, and " + "table name".format(table.name)) + + def post_add(self, table): + table.fetch_metadata() + security.merge_perm(sm, 'datasource_access', table.get_perm()) + if table.schema: + security.merge_perm(sm, 'schema_access', table.schema_perm) + + flash(_( + "The table was created. 
As part of this two phase configuration " + "process, you should now click the edit button by " + "the new table to configure it."), + "info") + + def post_update(self, table): + self.post_add(table) + +appbuilder.add_view( + TableModelView, + "Tables", + label=__("Tables"), + category="Sources", + category_label=__("Sources"), + icon='fa-table',) + +appbuilder.add_separator("Sources") diff --git a/superset/data/__init__.py b/superset/data/__init__.py index f75b56b0158b7..f061b629bd9ae 100644 --- a/superset/data/__init__.py +++ b/superset/data/__init__.py @@ -14,15 +14,19 @@ import pandas as pd from sqlalchemy import String, DateTime, Date, Float, BigInteger -from superset import app, db, models, utils +from superset import app, db, utils +from superset.models import core as models from superset.security import get_or_create_main_db +from superset.connectors.connector_registry import ConnectorRegistry + # Shortcuts DB = models.Database Slice = models.Slice -TBL = models.SqlaTable Dash = models.Dashboard +TBL = ConnectorRegistry.sources['sqla'] + config = app.config DATA_FOLDER = os.path.join(config.get("BASE_DIR"), 'data') diff --git a/superset/models.py b/superset/models.py deleted file mode 100644 index ffbcb95400a9d..0000000000000 --- a/superset/models.py +++ /dev/null @@ -1,2800 +0,0 @@ -"""A collection of ORM sqlalchemy models for Superset""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -from collections import OrderedDict -import functools -import json -import logging -import numpy -import pickle -import re -import textwrap -from future.standard_library import install_aliases -install_aliases() -from urllib import parse -from copy import deepcopy, copy -from datetime import timedelta, datetime, date - -import humanize -import pandas as pd -import requests -import sqlalchemy as sqla -from sqlalchemy.engine.url import make_url -from sqlalchemy.orm import subqueryload - -import sqlparse -from dateutil.parser import parse as dparse - -from flask import escape, g, Markup, request -from flask_appbuilder import Model -from flask_appbuilder.models.mixins import AuditMixin -from flask_appbuilder.models.decorators import renders -from flask_babel import lazy_gettext as _ - -from pydruid.client import PyDruid -from pydruid.utils.aggregators import count -from pydruid.utils.filters import Dimension, Filter -from pydruid.utils.postaggregator import ( - Postaggregator, Quantile, Quantiles, Field, Const, HyperUniqueCardinality, -) -from pydruid.utils.having import Aggregation -from six import string_types - -from sqlalchemy import ( - Column, Integer, String, ForeignKey, Text, Boolean, - DateTime, Date, Table, Numeric, - create_engine, MetaData, desc, asc, select, and_ -) -from sqlalchemy.ext.compiler import compiles -from sqlalchemy.ext.declarative import declared_attr -from sqlalchemy.orm import backref, relationship -from sqlalchemy.orm.session import make_transient -from sqlalchemy.sql import table, literal_column, text, column -from sqlalchemy.sql.expression import ColumnClause, TextAsFrom -from sqlalchemy_utils import EncryptedType - -from superset import ( - app, db, db_engine_specs, get_session, utils, sm, import_util, -) -from superset.legacy import cast_form_data -from superset.source_registry import SourceRegistry -from superset.viz import viz_types -from superset.jinja_context import get_template_processor -from superset.utils import ( - flasher, MetricPermException, 
DimSelector, wrap_clause_in_parens, - DTTM_ALIAS, QueryStatus, -) - -config = app.config - - -class QueryResult(object): - - """Object returned by the query interface""" - - def __init__( # noqa - self, - df, - query, - duration, - status=QueryStatus.SUCCESS, - error_message=None): - self.df = df - self.query = query - self.duration = duration - self.status = status - self.error_message = error_message - - -def set_perm(mapper, connection, target): # noqa - if target.perm != target.get_perm(): - link_table = target.__table__ - connection.execute( - link_table.update() - .where(link_table.c.id == target.id) - .values(perm=target.get_perm()) - ) - - -def set_related_perm(mapper, connection, target): # noqa - src_class = target.cls_model - id_ = target.datasource_id - ds = db.session.query(src_class).filter_by(id=int(id_)).first() - target.perm = ds.perm - - -class JavascriptPostAggregator(Postaggregator): - def __init__(self, name, field_names, function): - self.post_aggregator = { - 'type': 'javascript', - 'fieldNames': field_names, - 'name': name, - 'function': function, - } - self.name = name - - -class ImportMixin(object): - def override(self, obj): - """Overrides the plain fields of the dashboard.""" - for field in obj.__class__.export_fields: - setattr(self, field, getattr(obj, field)) - - def copy(self): - """Creates a copy of the dashboard without relationships.""" - new_obj = self.__class__() - new_obj.override(self) - return new_obj - - def alter_params(self, **kwargs): - d = self.params_dict - d.update(kwargs) - self.params = json.dumps(d) - - @property - def params_dict(self): - if self.params: - params = re.sub(",[ \t\r\n]+}", "}", self.params) - params = re.sub(",[ \t\r\n]+\]", "]", params) - return json.loads(params) - else: - return {} - - -class AuditMixinNullable(AuditMixin): - - """Altering the AuditMixin to use nullable fields - - Allows creating objects programmatically outside of CRUD - """ - - created_on = Column(DateTime, default=datetime.now, nullable=True) - changed_on = Column( - DateTime, default=datetime.now, - onupdate=datetime.now, nullable=True) - - @declared_attr - def created_by_fk(cls): # noqa - return Column(Integer, ForeignKey('ab_user.id'), - default=cls.get_user_id, nullable=True) - - @declared_attr - def changed_by_fk(cls): # noqa - return Column( - Integer, ForeignKey('ab_user.id'), - default=cls.get_user_id, onupdate=cls.get_user_id, nullable=True) - - def _user_link(self, user): - if not user: - return '' - url = '/superset/profile/{}/'.format(user.username) - return Markup('{}'.format(url, escape(user) or '')) - - @renders('created_by') - def creator(self): # noqa - return self._user_link(self.created_by) - - @property - def changed_by_(self): - return self._user_link(self.changed_by) - - @renders('changed_on') - def changed_on_(self): - return Markup( - '{}'.format(self.changed_on)) - - @renders('changed_on') - def modified(self): - s = humanize.naturaltime(datetime.now() - self.changed_on) - return Markup('{}'.format(s)) - - @property - def icons(self): - return """ - - - - """.format(**locals()) - - -class Url(Model, AuditMixinNullable): - - """Used for the short url feature""" - - __tablename__ = 'url' - id = Column(Integer, primary_key=True) - url = Column(Text) - - -class KeyValue(Model): - - """Used for any type of key-value store""" - - __tablename__ = 'keyvalue' - id = Column(Integer, primary_key=True) - value = Column(Text, nullable=False) - - -class CssTemplate(Model, AuditMixinNullable): - - """CSS templates for dashboards""" - - 
__tablename__ = 'css_templates' - id = Column(Integer, primary_key=True) - template_name = Column(String(250)) - css = Column(Text, default='') - - -slice_user = Table('slice_user', Model.metadata, - Column('id', Integer, primary_key=True), - Column('user_id', Integer, ForeignKey('ab_user.id')), - Column('slice_id', Integer, ForeignKey('slices.id')) - ) - - -class Slice(Model, AuditMixinNullable, ImportMixin): - - """A slice is essentially a report or a view on data""" - - __tablename__ = 'slices' - id = Column(Integer, primary_key=True) - slice_name = Column(String(250)) - datasource_id = Column(Integer) - datasource_type = Column(String(200)) - datasource_name = Column(String(2000)) - viz_type = Column(String(250)) - params = Column(Text) - description = Column(Text) - cache_timeout = Column(Integer) - perm = Column(String(1000)) - owners = relationship("User", secondary=slice_user) - - export_fields = ('slice_name', 'datasource_type', 'datasource_name', - 'viz_type', 'params', 'cache_timeout') - - def __repr__(self): - return self.slice_name - - @property - def cls_model(self): - return SourceRegistry.sources[self.datasource_type] - - @property - def datasource(self): - return self.get_datasource - - @datasource.getter - @utils.memoized - def get_datasource(self): - ds = db.session.query( - self.cls_model).filter_by( - id=self.datasource_id).first() - return ds - - @renders('datasource_name') - def datasource_link(self): - datasource = self.datasource - if datasource: - return self.datasource.link - - @property - def datasource_edit_url(self): - self.datasource.url - - @property - @utils.memoized - def viz(self): - d = json.loads(self.params) - viz_class = viz_types[self.viz_type] - return viz_class(self.datasource, form_data=d) - - @property - def description_markeddown(self): - return utils.markdown(self.description) - - @property - def data(self): - """Data used to render slice in templates""" - d = {} - self.token = '' - try: - d = self.viz.data - self.token = d.get('token') - except Exception as e: - logging.exception(e) - d['error'] = str(e) - return { - 'datasource': self.datasource_name, - 'description': self.description, - 'description_markeddown': self.description_markeddown, - 'edit_url': self.edit_url, - 'form_data': self.form_data, - 'slice_id': self.id, - 'slice_name': self.slice_name, - 'slice_url': self.slice_url, - } - - @property - def json_data(self): - return json.dumps(self.data) - - @property - def form_data(self): - form_data = json.loads(self.params) - form_data['slice_id'] = self.id - form_data['viz_type'] = self.viz_type - form_data['datasource'] = ( - str(self.datasource_id) + '__' + self.datasource_type) - return form_data - - @property - def slice_url(self): - """Defines the url to access the slice""" - return ( - "/superset/explore/{obj.datasource_type}/" - "{obj.datasource_id}/?form_data={params}".format( - obj=self, params=parse.quote(json.dumps(self.form_data)))) - - @property - def slice_id_url(self): - return ( - "/superset/{slc.datasource_type}/{slc.datasource_id}/{slc.id}/" - ).format(slc=self) - - @property - def edit_url(self): - return "/slicemodelview/edit/{}".format(self.id) - - @property - def slice_link(self): - url = self.slice_url - name = escape(self.slice_name) - return Markup('{name}'.format(**locals())) - - def get_viz(self, url_params_multidict=None): - """Creates :py:class:viz.BaseViz object from the url_params_multidict. 
- - :param werkzeug.datastructures.MultiDict url_params_multidict: - Contains the visualization params, they override the self.params - stored in the database - :return: object of the 'viz_type' type that is taken from the - url_params_multidict or self.params. - :rtype: :py:class:viz.BaseViz - """ - slice_params = json.loads(self.params) - slice_params['slice_id'] = self.id - slice_params['json'] = "false" - slice_params['slice_name'] = self.slice_name - slice_params['viz_type'] = self.viz_type if self.viz_type else "table" - - return viz_types[slice_params.get('viz_type')]( - self.datasource, - form_data=slice_params, - slice_=self - ) - - @classmethod - def import_obj(cls, slc_to_import, import_time=None): - """Inserts or overrides slc in the database. - - remote_id and import_time fields in params_dict are set to track the - slice origin and ensure correct overrides for multiple imports. - Slice.perm is used to find the datasources and connect them. - """ - session = db.session - make_transient(slc_to_import) - slc_to_import.dashboards = [] - slc_to_import.alter_params( - remote_id=slc_to_import.id, import_time=import_time) - - # find if the slice was already imported - slc_to_override = None - for slc in session.query(Slice).all(): - if ('remote_id' in slc.params_dict and - slc.params_dict['remote_id'] == slc_to_import.id): - slc_to_override = slc - - slc_to_import = slc_to_import.copy() - params = slc_to_import.params_dict - slc_to_import.datasource_id = SourceRegistry.get_datasource_by_name( - session, slc_to_import.datasource_type, params['datasource_name'], - params['schema'], params['database_name']).id - if slc_to_override: - slc_to_override.override(slc_to_import) - session.flush() - return slc_to_override.id - session.add(slc_to_import) - logging.info('Final slice: {}'.format(slc_to_import.to_json())) - session.flush() - return slc_to_import.id - - -sqla.event.listen(Slice, 'before_insert', set_related_perm) -sqla.event.listen(Slice, 'before_update', set_related_perm) - - -dashboard_slices = Table( - 'dashboard_slices', Model.metadata, - Column('id', Integer, primary_key=True), - Column('dashboard_id', Integer, ForeignKey('dashboards.id')), - Column('slice_id', Integer, ForeignKey('slices.id')), -) - -dashboard_user = Table( - 'dashboard_user', Model.metadata, - Column('id', Integer, primary_key=True), - Column('user_id', Integer, ForeignKey('ab_user.id')), - Column('dashboard_id', Integer, ForeignKey('dashboards.id')) -) - - -class Dashboard(Model, AuditMixinNullable, ImportMixin): - - """The dashboard object!""" - - __tablename__ = 'dashboards' - id = Column(Integer, primary_key=True) - dashboard_title = Column(String(500)) - position_json = Column(Text) - description = Column(Text) - css = Column(Text) - json_metadata = Column(Text) - slug = Column(String(255), unique=True) - slices = relationship( - 'Slice', secondary=dashboard_slices, backref='dashboards') - owners = relationship("User", secondary=dashboard_user) - - export_fields = ('dashboard_title', 'position_json', 'json_metadata', - 'description', 'css', 'slug') - - def __repr__(self): - return self.dashboard_title - - @property - def table_names(self): - return ", ".join( - {"{}".format(s.datasource.name) for s in self.slices}) - - @property - def url(self): - return "/superset/dashboard/{}/".format(self.slug or self.id) - - @property - def datasources(self): - return {slc.datasource for slc in self.slices} - - @property - def sqla_metadata(self): - metadata = MetaData(bind=self.get_sqla_engine()) - return 
metadata.reflect() - - def dashboard_link(self): - title = escape(self.dashboard_title) - return Markup( - '{title}'.format(**locals())) - - @property - def json_data(self): - positions = self.position_json - if positions: - positions = json.loads(positions) - d = { - 'id': self.id, - 'metadata': self.params_dict, - 'css': self.css, - 'dashboard_title': self.dashboard_title, - 'slug': self.slug, - 'slices': [slc.data for slc in self.slices], - 'position_json': positions, - } - return json.dumps(d) - - @property - def params(self): - return self.json_metadata - - @params.setter - def params(self, value): - self.json_metadata = value - - @property - def position_array(self): - if self.position_json: - return json.loads(self.position_json) - return [] - - @classmethod - def import_obj(cls, dashboard_to_import, import_time=None): - """Imports the dashboard from the object to the database. - - Once dashboard is imported, json_metadata field is extended and stores - remote_id and import_time. It helps to decide if the dashboard has to - be overridden or just copies over. Slices that belong to this - dashboard will be wired to existing tables. This function can be used - to import/export dashboards between multiple superset instances. - Audit metadata isn't copies over. - """ - def alter_positions(dashboard, old_to_new_slc_id_dict): - """ Updates slice_ids in the position json. - - Sample position json: - [{ - "col": 5, - "row": 10, - "size_x": 4, - "size_y": 2, - "slice_id": "3610" - }] - """ - position_array = dashboard.position_array - for position in position_array: - if 'slice_id' not in position: - continue - old_slice_id = int(position['slice_id']) - if old_slice_id in old_to_new_slc_id_dict: - position['slice_id'] = '{}'.format( - old_to_new_slc_id_dict[old_slice_id]) - dashboard.position_json = json.dumps(position_array) - - logging.info('Started import of the dashboard: {}' - .format(dashboard_to_import.to_json())) - session = db.session - logging.info('Dashboard has {} slices' - .format(len(dashboard_to_import.slices))) - # copy slices object as Slice.import_slice will mutate the slice - # and will remove the existing dashboard - slice association - slices = copy(dashboard_to_import.slices) - old_to_new_slc_id_dict = {} - new_filter_immune_slices = [] - new_expanded_slices = {} - i_params_dict = dashboard_to_import.params_dict - for slc in slices: - logging.info('Importing slice {} from the dashboard: {}'.format( - slc.to_json(), dashboard_to_import.dashboard_title)) - new_slc_id = Slice.import_obj(slc, import_time=import_time) - old_to_new_slc_id_dict[slc.id] = new_slc_id - # update json metadata that deals with slice ids - new_slc_id_str = '{}'.format(new_slc_id) - old_slc_id_str = '{}'.format(slc.id) - if ('filter_immune_slices' in i_params_dict and - old_slc_id_str in i_params_dict['filter_immune_slices']): - new_filter_immune_slices.append(new_slc_id_str) - if ('expanded_slices' in i_params_dict and - old_slc_id_str in i_params_dict['expanded_slices']): - new_expanded_slices[new_slc_id_str] = ( - i_params_dict['expanded_slices'][old_slc_id_str]) - - # override the dashboard - existing_dashboard = None - for dash in session.query(Dashboard).all(): - if ('remote_id' in dash.params_dict and - dash.params_dict['remote_id'] == - dashboard_to_import.id): - existing_dashboard = dash - - dashboard_to_import.id = None - alter_positions(dashboard_to_import, old_to_new_slc_id_dict) - dashboard_to_import.alter_params(import_time=import_time) - if new_expanded_slices: - 
dashboard_to_import.alter_params( - expanded_slices=new_expanded_slices) - if new_filter_immune_slices: - dashboard_to_import.alter_params( - filter_immune_slices=new_filter_immune_slices) - - new_slices = session.query(Slice).filter( - Slice.id.in_(old_to_new_slc_id_dict.values())).all() - - if existing_dashboard: - existing_dashboard.override(dashboard_to_import) - existing_dashboard.slices = new_slices - session.flush() - return existing_dashboard.id - else: - # session.add(dashboard_to_import) causes sqlachemy failures - # related to the attached users / slices. Creating new object - # allows to avoid conflicts in the sql alchemy state. - copied_dash = dashboard_to_import.copy() - copied_dash.slices = new_slices - session.add(copied_dash) - session.flush() - return copied_dash.id - - @classmethod - def export_dashboards(cls, dashboard_ids): - copied_dashboards = [] - datasource_ids = set() - for dashboard_id in dashboard_ids: - # make sure that dashboard_id is an integer - dashboard_id = int(dashboard_id) - copied_dashboard = ( - db.session.query(Dashboard) - .options(subqueryload(Dashboard.slices)) - .filter_by(id=dashboard_id).first() - ) - make_transient(copied_dashboard) - for slc in copied_dashboard.slices: - datasource_ids.add((slc.datasource_id, slc.datasource_type)) - # add extra params for the import - slc.alter_params( - remote_id=slc.id, - datasource_name=slc.datasource.name, - schema=slc.datasource.name, - database_name=slc.datasource.database.name, - ) - copied_dashboard.alter_params(remote_id=dashboard_id) - copied_dashboards.append(copied_dashboard) - - eager_datasources = [] - for dashboard_id, dashboard_type in datasource_ids: - eager_datasource = SourceRegistry.get_eager_datasource( - db.session, dashboard_type, dashboard_id) - eager_datasource.alter_params( - remote_id=eager_datasource.id, - database_name=eager_datasource.database.name, - ) - make_transient(eager_datasource) - eager_datasources.append(eager_datasource) - - return pickle.dumps({ - 'dashboards': copied_dashboards, - 'datasources': eager_datasources, - }) - - -class Database(Model, AuditMixinNullable): - - """An ORM object that stores Database related information""" - - __tablename__ = 'dbs' - type = "table" - - id = Column(Integer, primary_key=True) - database_name = Column(String(250), unique=True) - sqlalchemy_uri = Column(String(1024)) - password = Column(EncryptedType(String(1024), config.get('SECRET_KEY'))) - cache_timeout = Column(Integer) - select_as_create_table_as = Column(Boolean, default=False) - expose_in_sqllab = Column(Boolean, default=False) - allow_run_sync = Column(Boolean, default=True) - allow_run_async = Column(Boolean, default=False) - allow_ctas = Column(Boolean, default=False) - allow_dml = Column(Boolean, default=False) - force_ctas_schema = Column(String(250)) - extra = Column(Text, default=textwrap.dedent("""\ - { - "metadata_params": {}, - "engine_params": {} - } - """)) - perm = Column(String(1000)) - - def __repr__(self): - return self.database_name - - @property - def name(self): - return self.database_name - - @property - def backend(self): - url = make_url(self.sqlalchemy_uri_decrypted) - return url.get_backend_name() - - def set_sqlalchemy_uri(self, uri): - password_mask = "X" * 10 - conn = sqla.engine.url.make_url(uri) - if conn.password != password_mask: - # do not over-write the password with the password mask - self.password = conn.password - conn.password = password_mask if conn.password else None - self.sqlalchemy_uri = str(conn) # hides the password - - def 
get_sqla_engine(self, schema=None): - extra = self.get_extra() - url = make_url(self.sqlalchemy_uri_decrypted) - params = extra.get('engine_params', {}) - url.database = self.get_database_for_various_backend(url, schema) - return create_engine(url, **params) - - def get_database_for_various_backend(self, uri, default_database=None): - database = uri.database - if self.backend == 'presto' and default_database: - if '/' in database: - database = database.split('/')[0] + '/' + default_database - else: - database += '/' + default_database - # Postgres and Redshift use the concept of schema as a logical entity - # on top of the database, so the database should not be changed - # even if passed default_database - elif self.backend == 'redshift' or self.backend == 'postgresql': - pass - elif default_database: - database = default_database - return database - - def get_reserved_words(self): - return self.get_sqla_engine().dialect.preparer.reserved_words - - def get_quoter(self): - return self.get_sqla_engine().dialect.identifier_preparer.quote - - def get_df(self, sql, schema): - sql = sql.strip().strip(';') - eng = self.get_sqla_engine(schema=schema) - cur = eng.execute(sql, schema=schema) - cols = [col[0] for col in cur.cursor.description] - df = pd.DataFrame(cur.fetchall(), columns=cols) - - def needs_conversion(df_series): - if df_series.empty: - return False - for df_type in [list, dict]: - if isinstance(df_series[0], df_type): - return True - return False - - for k, v in df.dtypes.iteritems(): - if v.type == numpy.object_ and needs_conversion(df[k]): - df[k] = df[k].apply(utils.json_dumps_w_dates) - return df - - def compile_sqla_query(self, qry, schema=None): - eng = self.get_sqla_engine(schema=schema) - compiled = qry.compile(eng, compile_kwargs={"literal_binds": True}) - return '{}'.format(compiled) - - def select_star( - self, table_name, schema=None, limit=100, show_cols=False, - indent=True): - """Generates a ``select *`` statement in the proper dialect""" - return self.db_engine_spec.select_star( - self, table_name, schema=schema, limit=limit, show_cols=show_cols, - indent=indent) - - def wrap_sql_limit(self, sql, limit=1000): - qry = ( - select('*') - .select_from(TextAsFrom(text(sql), ['*']) - .alias('inner_qry')).limit(limit) - ) - return self.compile_sqla_query(qry) - - def safe_sqlalchemy_uri(self): - return self.sqlalchemy_uri - - @property - def inspector(self): - engine = self.get_sqla_engine() - return sqla.inspect(engine) - - def all_table_names(self, schema=None, force=False): - if not schema: - tables_dict = self.db_engine_spec.fetch_result_sets( - self, 'table', force=force) - return tables_dict.get("", []) - return sorted(self.inspector.get_table_names(schema)) - - def all_view_names(self, schema=None, force=False): - if not schema: - views_dict = self.db_engine_spec.fetch_result_sets( - self, 'view', force=force) - return views_dict.get("", []) - views = [] - try: - views = self.inspector.get_view_names(schema) - except Exception: - pass - return views - - def all_schema_names(self): - return sorted(self.inspector.get_schema_names()) - - @property - def db_engine_spec(self): - engine_name = self.get_sqla_engine().name or 'base' - return db_engine_specs.engines.get( - engine_name, db_engine_specs.BaseEngineSpec) - - def grains(self): - """Defines time granularity database-specific expressions. 
- - The idea here is to make it easy for users to change the time grain - form a datetime (maybe the source grain is arbitrary timestamps, daily - or 5 minutes increments) to another, "truncated" datetime. Since - each database has slightly different but similar datetime functions, - this allows a mapping between database engines and actual functions. - """ - return self.db_engine_spec.time_grains - - def grains_dict(self): - return {grain.name: grain for grain in self.grains()} - - def get_extra(self): - extra = {} - if self.extra: - try: - extra = json.loads(self.extra) - except Exception as e: - logging.error(e) - return extra - - def get_table(self, table_name, schema=None): - extra = self.get_extra() - meta = MetaData(**extra.get('metadata_params', {})) - return Table( - table_name, meta, - schema=schema or None, - autoload=True, - autoload_with=self.get_sqla_engine()) - - def get_columns(self, table_name, schema=None): - return self.inspector.get_columns(table_name, schema) - - def get_indexes(self, table_name, schema=None): - return self.inspector.get_indexes(table_name, schema) - - def get_pk_constraint(self, table_name, schema=None): - return self.inspector.get_pk_constraint(table_name, schema) - - def get_foreign_keys(self, table_name, schema=None): - return self.inspector.get_foreign_keys(table_name, schema) - - @property - def sqlalchemy_uri_decrypted(self): - conn = sqla.engine.url.make_url(self.sqlalchemy_uri) - conn.password = self.password - return str(conn) - - @property - def sql_url(self): - return '/superset/sql/{}/'.format(self.id) - - def get_perm(self): - return ( - "[{obj.database_name}].(id:{obj.id})").format(obj=self) - -sqla.event.listen(Database, 'after_insert', set_perm) -sqla.event.listen(Database, 'after_update', set_perm) - - -class TableColumn(Model, AuditMixinNullable, ImportMixin): - - """ORM object for table columns, each table can have multiple columns""" - - __tablename__ = 'table_columns' - id = Column(Integer, primary_key=True) - table_id = Column(Integer, ForeignKey('tables.id')) - table = relationship( - 'SqlaTable', - backref=backref('columns', cascade='all, delete-orphan'), - foreign_keys=[table_id]) - column_name = Column(String(255)) - verbose_name = Column(String(1024)) - is_dttm = Column(Boolean, default=False) - is_active = Column(Boolean, default=True) - type = Column(String(32), default='') - groupby = Column(Boolean, default=False) - count_distinct = Column(Boolean, default=False) - sum = Column(Boolean, default=False) - avg = Column(Boolean, default=False) - max = Column(Boolean, default=False) - min = Column(Boolean, default=False) - filterable = Column(Boolean, default=False) - expression = Column(Text, default='') - description = Column(Text, default='') - python_date_format = Column(String(255)) - database_expression = Column(String(255)) - - num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG', 'REAL', 'NUMERIC') - date_types = ('DATE', 'TIME') - str_types = ('VARCHAR', 'STRING', 'CHAR') - export_fields = ( - 'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active', - 'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min', - 'filterable', 'expression', 'description', 'python_date_format', - 'database_expression' - ) - - def __repr__(self): - return self.column_name - - @property - def is_num(self): - return any([t in self.type.upper() for t in self.num_types]) - - @property - def is_time(self): - return any([t in self.type.upper() for t in self.date_types]) - - @property - def is_string(self): - return any([t in 
self.type.upper() for t in self.str_types]) - - @property - def sqla_col(self): - name = self.column_name - if not self.expression: - col = column(self.column_name).label(name) - else: - col = literal_column(self.expression).label(name) - return col - - def get_time_filter(self, start_dttm, end_dttm): - col = self.sqla_col.label('__time') - return and_( - col >= text(self.dttm_sql_literal(start_dttm)), - col <= text(self.dttm_sql_literal(end_dttm)), - ) - - def get_timestamp_expression(self, time_grain): - """Getting the time component of the query""" - expr = self.expression or self.column_name - if not self.expression and not time_grain: - return column(expr, type_=DateTime).label(DTTM_ALIAS) - if time_grain: - pdf = self.python_date_format - if pdf in ('epoch_s', 'epoch_ms'): - # if epoch, translate to DATE using db specific conf - db_spec = self.table.database.db_engine_spec - if pdf == 'epoch_s': - expr = db_spec.epoch_to_dttm().format(col=expr) - elif pdf == 'epoch_ms': - expr = db_spec.epoch_ms_to_dttm().format(col=expr) - grain = self.table.database.grains_dict().get(time_grain, '{col}') - expr = grain.function.format(col=expr) - return literal_column(expr, type_=DateTime).label(DTTM_ALIAS) - - @classmethod - def import_obj(cls, i_column): - def lookup_obj(lookup_column): - return db.session.query(TableColumn).filter( - TableColumn.table_id == lookup_column.table_id, - TableColumn.column_name == lookup_column.column_name).first() - return import_util.import_simple_obj(db.session, i_column, lookup_obj) - - def dttm_sql_literal(self, dttm): - """Convert datetime object to a SQL expression string - - If database_expression is empty, the internal dttm - will be parsed as the string with the pattern that - the user inputted (python_date_format) - If database_expression is not empty, the internal dttm - will be parsed as the sql sentence for the database to convert - """ - - tf = self.python_date_format or '%Y-%m-%d %H:%M:%S.%f' - if self.database_expression: - return self.database_expression.format(dttm.strftime('%Y-%m-%d %H:%M:%S')) - elif tf == 'epoch_s': - return str((dttm - datetime(1970, 1, 1)).total_seconds()) - elif tf == 'epoch_ms': - return str((dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0) - else: - s = self.table.database.db_engine_spec.convert_dttm( - self.type, dttm) - return s or "'{}'".format(dttm.strftime(tf)) - - -class SqlMetric(Model, AuditMixinNullable, ImportMixin): - - """ORM object for metrics, each table can have multiple metrics""" - - __tablename__ = 'sql_metrics' - id = Column(Integer, primary_key=True) - metric_name = Column(String(512)) - verbose_name = Column(String(1024)) - metric_type = Column(String(32)) - table_id = Column(Integer, ForeignKey('tables.id')) - table = relationship( - 'SqlaTable', - backref=backref('metrics', cascade='all, delete-orphan'), - foreign_keys=[table_id]) - expression = Column(Text) - description = Column(Text) - is_restricted = Column(Boolean, default=False, nullable=True) - d3format = Column(String(128)) - - export_fields = ( - 'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression', - 'description', 'is_restricted', 'd3format') - - @property - def sqla_col(self): - name = self.metric_name - return literal_column(self.expression).label(name) - - @property - def perm(self): - return ( - "{parent_name}.[{obj.metric_name}](id:{obj.id})" - ).format(obj=self, - parent_name=self.table.full_name) if self.table else None - - @classmethod - def import_obj(cls, i_metric): - def lookup_obj(lookup_metric): - return 
db.session.query(SqlMetric).filter( - SqlMetric.table_id == lookup_metric.table_id, - SqlMetric.metric_name == lookup_metric.metric_name).first() - return import_util.import_simple_obj(db.session, i_metric, lookup_obj) - - -class SqlaTable(Model, Datasource, AuditMixinNullable, ImportMixin): - - """An ORM object for SqlAlchemy table references""" - - type = "table" - query_language = 'sql' - - __tablename__ = 'tables' - id = Column(Integer, primary_key=True) - table_name = Column(String(250)) - main_dttm_col = Column(String(250)) - description = Column(Text) - default_endpoint = Column(Text) - database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False) - is_featured = Column(Boolean, default=False) - filter_select_enabled = Column(Boolean, default=False) - user_id = Column(Integer, ForeignKey('ab_user.id')) - owner = relationship('User', backref='tables', foreign_keys=[user_id]) - database = relationship( - 'Database', - backref=backref('tables', cascade='all, delete-orphan'), - foreign_keys=[database_id]) - offset = Column(Integer, default=0) - cache_timeout = Column(Integer) - schema = Column(String(255)) - sql = Column(Text) - params = Column(Text) - perm = Column(String(1000)) - - baselink = "tablemodelview" - column_cls = TableColumn - metric_cls = SqlMetric - export_fields = ( - 'table_name', 'main_dttm_col', 'description', 'default_endpoint', - 'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema', - 'sql', 'params') - - __table_args__ = ( - sqla.UniqueConstraint( - 'database_id', 'schema', 'table_name', - name='_customer_location_uc'),) - - def __repr__(self): - return self.name - - @property - def description_markeddown(self): - return utils.markdown(self.description) - - @property - def link(self): - name = escape(self.name) - return Markup( - '{name}'.format(**locals())) - - @property - def schema_perm(self): - """Returns schema permission if present, database one otherwise.""" - return utils.get_schema_perm(self.database, self.schema) - - def get_perm(self): - return ( - "[{obj.database}].[{obj.table_name}]" - "(id:{obj.id})").format(obj=self) - - @property - def name(self): - if not self.schema: - return self.table_name - return "{}.{}".format(self.schema, self.table_name) - - @property - def full_name(self): - return utils.get_datasource_full_name( - self.database, self.table_name, schema=self.schema) - - @property - def dttm_cols(self): - l = [c.column_name for c in self.columns if c.is_dttm] - if self.main_dttm_col and self.main_dttm_col not in l: - l.append(self.main_dttm_col) - return l - - @property - def num_cols(self): - return [c.column_name for c in self.columns if c.is_num] - - @property - def any_dttm_col(self): - cols = self.dttm_cols - if cols: - return cols[0] - - @property - def html(self): - t = ((c.column_name, c.type) for c in self.columns) - df = pd.DataFrame(t) - df.columns = ['field', 'type'] - return df.to_html( - index=False, - classes=( - "dataframe table table-striped table-bordered " - "table-condensed")) - - @property - def metrics_combo(self): - return sorted( - [ - (m.metric_name, m.verbose_name or m.metric_name) - for m in self.metrics], - key=lambda x: x[1]) - - @property - def sql_url(self): - return self.database.sql_url + "?table_name=" + str(self.table_name) - - @property - def time_column_grains(self): - return { - "time_columns": self.dttm_cols, - "time_grains": [grain.name for grain in self.database.grains()] - } - - def get_col(self, col_name): - columns = self.columns - for col in columns: - if col_name == 
col.column_name: - return col - - def values_for_column(self, - column_name, - from_dttm, - to_dttm, - limit=500): - """Runs query against sqla to retrieve some - sample values for the given column. - """ - granularity = self.main_dttm_col - - cols = {col.column_name: col for col in self.columns} - target_col = cols[column_name] - - tbl = table(self.table_name) - qry = select([target_col.sqla_col]) - qry = qry.select_from(tbl) - qry = qry.distinct(column_name) - qry = qry.limit(limit) - - if granularity: - dttm_col = cols[granularity] - timestamp = dttm_col.sqla_col.label('timestamp') - time_filter = [ - timestamp >= text(dttm_col.dttm_sql_literal(from_dttm)), - timestamp <= text(dttm_col.dttm_sql_literal(to_dttm)), - ] - qry = qry.where(and_(*time_filter)) - - engine = self.database.get_sqla_engine() - sql = "{}".format( - qry.compile( - engine, compile_kwargs={"literal_binds": True}, ), - ) - - return pd.read_sql_query( - sql=sql, - con=engine - ) - - def get_query_str( # sqla - self, engine, qry_start_dttm, - groupby, metrics, - granularity, - from_dttm, to_dttm, - filter=None, # noqa - is_timeseries=True, - timeseries_limit=15, - timeseries_limit_metric=None, - row_limit=None, - inner_from_dttm=None, - inner_to_dttm=None, - orderby=None, - extras=None, - columns=None): - """Querying any sqla table from this common interface""" - template_processor = get_template_processor( - table=self, database=self.database) - - # For backward compatibility - if granularity not in self.dttm_cols: - granularity = self.main_dttm_col - - cols = {col.column_name: col for col in self.columns} - metrics_dict = {m.metric_name: m for m in self.metrics} - - if not granularity and is_timeseries: - raise Exception(_( - "Datetime column not provided as part table configuration " - "and is required by this type of chart")) - for m in metrics: - if m not in metrics_dict: - raise Exception(_("Metric '{}' is not valid".format(m))) - metrics_exprs = [metrics_dict.get(m).sqla_col for m in metrics] - timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric) - timeseries_limit_metric_expr = None - if timeseries_limit_metric: - timeseries_limit_metric_expr = \ - timeseries_limit_metric.sqla_col - if metrics: - main_metric_expr = metrics_exprs[0] - else: - main_metric_expr = literal_column("COUNT(*)").label("ccount") - - select_exprs = [] - groupby_exprs = [] - - if groupby: - select_exprs = [] - inner_select_exprs = [] - inner_groupby_exprs = [] - for s in groupby: - col = cols[s] - outer = col.sqla_col - inner = col.sqla_col.label(col.column_name + '__') - - groupby_exprs.append(outer) - select_exprs.append(outer) - inner_groupby_exprs.append(inner) - inner_select_exprs.append(inner) - elif columns: - for s in columns: - select_exprs.append(cols[s].sqla_col) - metrics_exprs = [] - - if granularity: - @compiles(ColumnClause) - def visit_column(element, compiler, **kw): - """Patch for sqlalchemy bug - - TODO: sqlalchemy 1.2 release should be doing this on its own. - Patch only if the column clause is specific for DateTime - set and granularity is selected. 
- """ - text = compiler.visit_column(element, **kw) - try: - if ( - element.is_literal and - hasattr(element.type, 'python_type') and - type(element.type) is DateTime - ): - text = text.replace('%%', '%') - except NotImplementedError: - # Some elements raise NotImplementedError for python_type - pass - return text - - dttm_col = cols[granularity] - time_grain = extras.get('time_grain_sqla') - - if is_timeseries: - timestamp = dttm_col.get_timestamp_expression(time_grain) - select_exprs += [timestamp] - groupby_exprs += [timestamp] - - time_filter = dttm_col.get_time_filter(from_dttm, to_dttm) - - select_exprs += metrics_exprs - qry = select(select_exprs) - - tbl = table(self.table_name) - if self.schema: - tbl.schema = self.schema - - # Supporting arbitrary SQL statements in place of tables - if self.sql: - tbl = TextAsFrom(sqla.text(self.sql), []).alias('expr_qry') - - if not columns: - qry = qry.group_by(*groupby_exprs) - - where_clause_and = [] - having_clause_and = [] - for flt in filter: - if not all([flt.get(s) for s in ['col', 'op', 'val']]): - continue - col = flt['col'] - op = flt['op'] - eq = flt['val'] - col_obj = cols.get(col) - if col_obj and op in ('in', 'not in'): - values = [types.strip("'").strip('"') for types in eq] - if col_obj.is_num: - values = [utils.js_string_to_num(s) for s in values] - cond = col_obj.sqla_col.in_(values) - if op == 'not in': - cond = ~cond - where_clause_and.append(cond) - if extras: - where = extras.get('where') - if where: - where_clause_and += [wrap_clause_in_parens( - template_processor.process_template(where))] - having = extras.get('having') - if having: - having_clause_and += [wrap_clause_in_parens( - template_processor.process_template(having))] - if granularity: - qry = qry.where(and_(*([time_filter] + where_clause_and))) - else: - qry = qry.where(and_(*where_clause_and)) - qry = qry.having(and_(*having_clause_and)) - if groupby: - qry = qry.order_by(desc(main_metric_expr)) - elif orderby: - for col, ascending in orderby: - direction = asc if ascending else desc - qry = qry.order_by(direction(col)) - - qry = qry.limit(row_limit) - - if is_timeseries and timeseries_limit and groupby: - # some sql dialects require for order by expressions - # to also be in the select clause -- others, e.g. 
vertica, - # require a unique inner alias - inner_main_metric_expr = main_metric_expr.label('mme_inner__') - inner_select_exprs += [inner_main_metric_expr] - subq = select(inner_select_exprs) - subq = subq.select_from(tbl) - inner_time_filter = dttm_col.get_time_filter( - inner_from_dttm or from_dttm, - inner_to_dttm or to_dttm, - ) - subq = subq.where(and_(*(where_clause_and + [inner_time_filter]))) - subq = subq.group_by(*inner_groupby_exprs) - ob = inner_main_metric_expr - if timeseries_limit_metric_expr is not None: - ob = timeseries_limit_metric_expr - subq = subq.order_by(desc(ob)) - subq = subq.limit(timeseries_limit) - on_clause = [] - for i, gb in enumerate(groupby): - on_clause.append( - groupby_exprs[i] == column(gb + '__')) - - tbl = tbl.join(subq.alias(), and_(*on_clause)) - - qry = qry.select_from(tbl) - - sql = "{}".format( - qry.compile( - engine, compile_kwargs={"literal_binds": True},), - ) - logging.info(sql) - sql = sqlparse.format(sql, reindent=True) - return sql - - def query(self, query_obj): - qry_start_dttm = datetime.now() - engine = self.database.get_sqla_engine() - sql = self.get_query_str(engine, qry_start_dttm, **query_obj) - status = QueryStatus.SUCCESS - error_message = None - df = None - try: - df = pd.read_sql_query(sql, con=engine) - except Exception as e: - status = QueryStatus.FAILED - error_message = str(e) - - return QueryResult( - status=status, - df=df, - duration=datetime.now() - qry_start_dttm, - query=sql, - error_message=error_message) - - def get_sqla_table_object(self): - return self.database.get_table(self.table_name, schema=self.schema) - - def fetch_metadata(self): - """Fetches the metadata for the table and merges it in""" - try: - table = self.get_sqla_table_object() - except Exception: - raise Exception( - "Table doesn't seem to exist in the specified database, " - "couldn't fetch column information") - - TC = TableColumn # noqa shortcut to class - M = SqlMetric # noqa - metrics = [] - any_date_col = None - for col in table.columns: - try: - datatype = "{}".format(col.type).upper() - except Exception as e: - datatype = "UNKNOWN" - logging.error( - "Unrecognized data type in {}.{}".format(table, col.name)) - logging.exception(e) - dbcol = ( - db.session - .query(TC) - .filter(TC.table == self) - .filter(TC.column_name == col.name) - .first() - ) - db.session.flush() - if not dbcol: - dbcol = TableColumn(column_name=col.name, type=datatype) - dbcol.groupby = dbcol.is_string - dbcol.filterable = dbcol.is_string - dbcol.sum = dbcol.is_num - dbcol.avg = dbcol.is_num - dbcol.is_dttm = dbcol.is_time - - db.session.merge(self) - self.columns.append(dbcol) - - if not any_date_col and dbcol.is_time: - any_date_col = col.name - - quoted = "{}".format( - column(dbcol.column_name).compile(dialect=db.engine.dialect)) - if dbcol.sum: - metrics.append(M( - metric_name='sum__' + dbcol.column_name, - verbose_name='sum__' + dbcol.column_name, - metric_type='sum', - expression="SUM({})".format(quoted) - )) - if dbcol.avg: - metrics.append(M( - metric_name='avg__' + dbcol.column_name, - verbose_name='avg__' + dbcol.column_name, - metric_type='avg', - expression="AVG({})".format(quoted) - )) - if dbcol.max: - metrics.append(M( - metric_name='max__' + dbcol.column_name, - verbose_name='max__' + dbcol.column_name, - metric_type='max', - expression="MAX({})".format(quoted) - )) - if dbcol.min: - metrics.append(M( - metric_name='min__' + dbcol.column_name, - verbose_name='min__' + dbcol.column_name, - metric_type='min', - expression="MIN({})".format(quoted) - )) 
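# Illustration (not part of the patch): the metric auto-generation in this
# hunk quotes the physical column name with the database dialect before
# embedding it in an aggregate expression, so reserved words and mixed-case
# names still compile to valid SQL. A minimal, self-contained sketch of that
# pattern; the helper name and the Postgres dialect are assumptions, not
# Superset API:
from sqlalchemy import column
from sqlalchemy.dialects import postgresql


def aggregate_expression(column_name, agg='SUM', dialect=None):
    """Return e.g. SUM("select") with dialect-aware identifier quoting."""
    dialect = dialect or postgresql.dialect()
    quoted = str(column(column_name).compile(dialect=dialect))
    return '{}({})'.format(agg, quoted)


# aggregate_expression('num_girls')  -> 'SUM(num_girls)'
# aggregate_expression('select')     -> 'SUM("select")'  (reserved word gets quoted)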
- if dbcol.count_distinct: - metrics.append(M( - metric_name='count_distinct__' + dbcol.column_name, - verbose_name='count_distinct__' + dbcol.column_name, - metric_type='count_distinct', - expression="COUNT(DISTINCT {})".format(quoted) - )) - dbcol.type = datatype - db.session.merge(self) - db.session.commit() - - metrics.append(M( - metric_name='count', - verbose_name='COUNT(*)', - metric_type='count', - expression="COUNT(*)" - )) - for metric in metrics: - m = ( - db.session.query(M) - .filter(M.metric_name == metric.metric_name) - .filter(M.table_id == self.id) - .first() - ) - metric.table_id = self.id - if not m: - db.session.add(metric) - db.session.commit() - if not self.main_dttm_col: - self.main_dttm_col = any_date_col - - @classmethod - def import_obj(cls, i_datasource, import_time=None): - """Imports the datasource from the object to the database. - - Metrics and columns and datasource will be overrided if exists. - This function can be used to import/export dashboards between multiple - superset instances. Audit metadata isn't copies over. - """ - def lookup_sqlatable(table): - return db.session.query(SqlaTable).join(Database).filter( - SqlaTable.table_name == table.table_name, - SqlaTable.schema == table.schema, - Database.id == table.database_id, - ).first() - - def lookup_database(table): - return db.session.query(Database).filter_by( - database_name=table.params_dict['database_name']).one() - return import_util.import_datasource( - db.session, i_datasource, lookup_database, lookup_sqlatable, - import_time) - -sqla.event.listen(SqlaTable, 'after_insert', set_perm) -sqla.event.listen(SqlaTable, 'after_update', set_perm) - - -class DruidCluster(Model, AuditMixinNullable): - - """ORM object referencing the Druid clusters""" - - __tablename__ = 'clusters' - type = "druid" - - id = Column(Integer, primary_key=True) - cluster_name = Column(String(250), unique=True) - coordinator_host = Column(String(255)) - coordinator_port = Column(Integer) - coordinator_endpoint = Column( - String(255), default='druid/coordinator/v1/metadata') - broker_host = Column(String(255)) - broker_port = Column(Integer) - broker_endpoint = Column(String(255), default='druid/v2') - metadata_last_refreshed = Column(DateTime) - cache_timeout = Column(Integer) - - def __repr__(self): - return self.cluster_name - - def get_pydruid_client(self): - cli = PyDruid( - "http://{0}:{1}/".format(self.broker_host, self.broker_port), - self.broker_endpoint) - return cli - - def get_datasources(self): - endpoint = ( - "http://{obj.coordinator_host}:{obj.coordinator_port}/" - "{obj.coordinator_endpoint}/datasources" - ).format(obj=self) - - return json.loads(requests.get(endpoint).text) - - def get_druid_version(self): - endpoint = ( - "http://{obj.coordinator_host}:{obj.coordinator_port}/status" - ).format(obj=self) - return json.loads(requests.get(endpoint).text)['version'] - - def refresh_datasources(self, datasource_name=None, merge_flag=False): - """Refresh metadata of all datasources in the cluster - If ``datasource_name`` is specified, only that datasource is updated - """ - self.druid_version = self.get_druid_version() - for datasource in self.get_datasources(): - if datasource not in config.get('DRUID_DATA_SOURCE_BLACKLIST'): - if not datasource_name or datasource_name == datasource: - DruidDatasource.sync_to_db(datasource, self, merge_flag) - - @property - def perm(self): - return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self) - - @property - def name(self): - return self.cluster_name - - -class 
DruidColumn(Model, AuditMixinNullable, ImportMixin): - """ORM model for storing Druid datasource column metadata""" - - __tablename__ = 'columns' - id = Column(Integer, primary_key=True) - datasource_name = Column( - String(255), - ForeignKey('datasources.datasource_name')) - # Setting enable_typechecks=False disables polymorphic inheritance. - datasource = relationship( - 'DruidDatasource', - backref=backref('columns', cascade='all, delete-orphan'), - enable_typechecks=False) - column_name = Column(String(255)) - is_active = Column(Boolean, default=True) - type = Column(String(32)) - groupby = Column(Boolean, default=False) - count_distinct = Column(Boolean, default=False) - sum = Column(Boolean, default=False) - avg = Column(Boolean, default=False) - max = Column(Boolean, default=False) - min = Column(Boolean, default=False) - filterable = Column(Boolean, default=False) - description = Column(Text) - dimension_spec_json = Column(Text) - - export_fields = ( - 'datasource_name', 'column_name', 'is_active', 'type', 'groupby', - 'count_distinct', 'sum', 'avg', 'max', 'min', 'filterable', - 'description', 'dimension_spec_json' - ) - - def __repr__(self): - return self.column_name - - @property - def is_num(self): - return self.type in ('LONG', 'DOUBLE', 'FLOAT', 'INT') - - @property - def dimension_spec(self): - if self.dimension_spec_json: - return json.loads(self.dimension_spec_json) - - def generate_metrics(self): - """Generate metrics based on the column metadata""" - M = DruidMetric # noqa - metrics = [] - metrics.append(DruidMetric( - metric_name='count', - verbose_name='COUNT(*)', - metric_type='count', - json=json.dumps({'type': 'count', 'name': 'count'}) - )) - # Somehow we need to reassign this for UDAFs - if self.type in ('DOUBLE', 'FLOAT'): - corrected_type = 'DOUBLE' - else: - corrected_type = self.type - - if self.sum and self.is_num: - mt = corrected_type.lower() + 'Sum' - name = 'sum__' + self.column_name - metrics.append(DruidMetric( - metric_name=name, - metric_type='sum', - verbose_name='SUM({})'.format(self.column_name), - json=json.dumps({ - 'type': mt, 'name': name, 'fieldName': self.column_name}) - )) - - if self.avg and self.is_num: - mt = corrected_type.lower() + 'Avg' - name = 'avg__' + self.column_name - metrics.append(DruidMetric( - metric_name=name, - metric_type='avg', - verbose_name='AVG({})'.format(self.column_name), - json=json.dumps({ - 'type': mt, 'name': name, 'fieldName': self.column_name}) - )) - - if self.min and self.is_num: - mt = corrected_type.lower() + 'Min' - name = 'min__' + self.column_name - metrics.append(DruidMetric( - metric_name=name, - metric_type='min', - verbose_name='MIN({})'.format(self.column_name), - json=json.dumps({ - 'type': mt, 'name': name, 'fieldName': self.column_name}) - )) - if self.max and self.is_num: - mt = corrected_type.lower() + 'Max' - name = 'max__' + self.column_name - metrics.append(DruidMetric( - metric_name=name, - metric_type='max', - verbose_name='MAX({})'.format(self.column_name), - json=json.dumps({ - 'type': mt, 'name': name, 'fieldName': self.column_name}) - )) - if self.count_distinct: - name = 'count_distinct__' + self.column_name - if self.type == 'hyperUnique' or self.type == 'thetaSketch': - metrics.append(DruidMetric( - metric_name=name, - verbose_name='COUNT(DISTINCT {})'.format(self.column_name), - metric_type=self.type, - json=json.dumps({ - 'type': self.type, - 'name': name, - 'fieldName': self.column_name - }) - )) - else: - mt = 'count_distinct' - metrics.append(DruidMetric( - metric_name=name, - 
verbose_name='COUNT(DISTINCT {})'.format(self.column_name), - metric_type='count_distinct', - json=json.dumps({ - 'type': 'cardinality', - 'name': name, - 'fieldNames': [self.column_name]}) - )) - session = get_session() - new_metrics = [] - for metric in metrics: - m = ( - session.query(M) - .filter(M.metric_name == metric.metric_name) - .filter(M.datasource_name == self.datasource_name) - .filter(DruidCluster.cluster_name == self.datasource.cluster_name) - .first() - ) - metric.datasource_name = self.datasource_name - if not m: - new_metrics.append(metric) - session.add(metric) - session.flush() - - @classmethod - def import_obj(cls, i_column): - def lookup_obj(lookup_column): - return db.session.query(DruidColumn).filter( - DruidColumn.datasource_name == lookup_column.datasource_name, - DruidColumn.column_name == lookup_column.column_name).first() - - return import_util.import_simple_obj(db.session, i_column, lookup_obj) - - -class DruidMetric(Model, AuditMixinNullable, ImportMixin): - - """ORM object referencing Druid metrics for a datasource""" - - __tablename__ = 'metrics' - id = Column(Integer, primary_key=True) - metric_name = Column(String(512)) - verbose_name = Column(String(1024)) - metric_type = Column(String(32)) - datasource_name = Column( - String(255), - ForeignKey('datasources.datasource_name')) - # Setting enable_typechecks=False disables polymorphic inheritance. - datasource = relationship( - 'DruidDatasource', - backref=backref('metrics', cascade='all, delete-orphan'), - enable_typechecks=False) - json = Column(Text) - description = Column(Text) - is_restricted = Column(Boolean, default=False, nullable=True) - d3format = Column(String(128)) - - def refresh_datasources(self, datasource_name=None, merge_flag=False): - """Refresh metadata of all datasources in the cluster - - If ``datasource_name`` is specified, only that datasource is updated - """ - self.druid_version = self.get_druid_version() - for datasource in self.get_datasources(): - if datasource not in config.get('DRUID_DATA_SOURCE_BLACKLIST'): - if not datasource_name or datasource_name == datasource: - DruidDatasource.sync_to_db(datasource, self, merge_flag) - export_fields = ( - 'metric_name', 'verbose_name', 'metric_type', 'datasource_name', - 'json', 'description', 'is_restricted', 'd3format' - ) - - @property - def json_obj(self): - try: - obj = json.loads(self.json) - except Exception: - obj = {} - return obj - - @property - def perm(self): - return ( - "{parent_name}.[{obj.metric_name}](id:{obj.id})" - ).format(obj=self, - parent_name=self.datasource.full_name - ) if self.datasource else None - - @classmethod - def import_obj(cls, i_metric): - def lookup_obj(lookup_metric): - return db.session.query(DruidMetric).filter( - DruidMetric.datasource_name == lookup_metric.datasource_name, - DruidMetric.metric_name == lookup_metric.metric_name).first() - return import_util.import_simple_obj(db.session, i_metric, lookup_obj) - - -class DruidDatasource(Model, AuditMixinNullable, Datasource, ImportMixin): - - """ORM object referencing Druid datasources (tables)""" - - type = "druid" - query_langtage = "json" - - baselink = "druiddatasourcemodelview" - - __tablename__ = 'datasources' - id = Column(Integer, primary_key=True) - datasource_name = Column(String(255), unique=True) - is_featured = Column(Boolean, default=False) - is_hidden = Column(Boolean, default=False) - filter_select_enabled = Column(Boolean, default=False) - description = Column(Text) - default_endpoint = Column(Text) - user_id = Column(Integer, 
ForeignKey('ab_user.id')) - owner = relationship( - 'User', - backref=backref('datasources', cascade='all, delete-orphan'), - foreign_keys=[user_id]) - cluster_name = Column( - String(250), ForeignKey('clusters.cluster_name')) - cluster = relationship( - 'DruidCluster', backref='datasources', foreign_keys=[cluster_name]) - offset = Column(Integer, default=0) - cache_timeout = Column(Integer) - params = Column(String(1000)) - perm = Column(String(1000)) - - metric_cls = DruidMetric - column_cls = DruidColumn - - export_fields = ( - 'datasource_name', 'is_hidden', 'description', 'default_endpoint', - 'cluster_name', 'is_featured', 'offset', 'cache_timeout', 'params' - ) - - @property - def metrics_combo(self): - return sorted( - [(m.metric_name, m.verbose_name) for m in self.metrics], - key=lambda x: x[1]) - - @property - def database(self): - return self.cluster - - @property - def num_cols(self): - return [c.column_name for c in self.columns if c.is_num] - - @property - def name(self): - return self.datasource_name - - @property - def schema(self): - name_pieces = self.datasource_name.split('.') - if len(name_pieces) > 1: - return name_pieces[0] - else: - return None - - @property - def schema_perm(self): - """Returns schema permission if present, cluster one otherwise.""" - return utils.get_schema_perm(self.cluster, self.schema) - - def get_perm(self): - return ( - "[{obj.cluster_name}].[{obj.datasource_name}]" - "(id:{obj.id})").format(obj=self) - - @property - def link(self): - name = escape(self.datasource_name) - return Markup('{name}').format(**locals()) - - @property - def full_name(self): - return utils.get_datasource_full_name( - self.cluster_name, self.datasource_name) - - @property - def time_column_grains(self): - return { - "time_columns": [ - 'all', '5 seconds', '30 seconds', '1 minute', - '5 minutes', '1 hour', '6 hour', '1 day', '7 days', - 'week', 'week_starting_sunday', 'week_ending_saturday', - 'month', - ], - "time_grains": ['now'] - } - - def __repr__(self): - return self.datasource_name - - @renders('datasource_name') - def datasource_link(self): - url = "/superset/explore/{obj.type}/{obj.id}/".format(obj=self) - name = escape(self.datasource_name) - return Markup('{name}'.format(**locals())) - - def get_metric_obj(self, metric_name): - return [ - m.json_obj for m in self.metrics - if m.metric_name == metric_name - ][0] - - @classmethod - def import_obj(cls, i_datasource, import_time=None): - """Imports the datasource from the object to the database. - - Metrics and columns and datasource will be overridden if exists. - This function can be used to import/export dashboards between multiple - superset instances. Audit metadata isn't copies over. 
- """ - def lookup_datasource(d): - return db.session.query(DruidDatasource).join(DruidCluster).filter( - DruidDatasource.datasource_name == d.datasource_name, - DruidCluster.cluster_name == d.cluster_name, - ).first() - - def lookup_cluster(d): - return db.session.query(DruidCluster).filter_by( - cluster_name=d.cluster_name).one() - return import_util.import_datasource( - db.session, i_datasource, lookup_cluster, lookup_datasource, - import_time) - - @staticmethod - def version_higher(v1, v2): - """is v1 higher than v2 - - >>> DruidDatasource.version_higher('0.8.2', '0.9.1') - False - >>> DruidDatasource.version_higher('0.8.2', '0.6.1') - True - >>> DruidDatasource.version_higher('0.8.2', '0.8.2') - False - >>> DruidDatasource.version_higher('0.8.2', '0.9.BETA') - False - >>> DruidDatasource.version_higher('0.8.2', '0.9') - False - """ - def int_or_0(v): - try: - v = int(v) - except (TypeError, ValueError): - v = 0 - return v - v1nums = [int_or_0(n) for n in v1.split('.')] - v2nums = [int_or_0(n) for n in v2.split('.')] - v1nums = (v1nums + [0, 0, 0])[:3] - v2nums = (v2nums + [0, 0, 0])[:3] - return v1nums[0] > v2nums[0] or \ - (v1nums[0] == v2nums[0] and v1nums[1] > v2nums[1]) or \ - (v1nums[0] == v2nums[0] and v1nums[1] == v2nums[1] and v1nums[2] > v2nums[2]) - - def latest_metadata(self): - """Returns segment metadata from the latest segment""" - client = self.cluster.get_pydruid_client() - results = client.time_boundary(datasource=self.datasource_name) - if not results: - return - max_time = results[0]['result']['maxTime'] - max_time = dparse(max_time) - # Query segmentMetadata for 7 days back. However, due to a bug, - # we need to set this interval to more than 1 day ago to exclude - # realtime segments, which triggered a bug (fixed in druid 0.8.2). - # https://groups.google.com/forum/#!topic/druid-user/gVCqqspHqOQ - lbound = (max_time - timedelta(days=7)).isoformat() - rbound = max_time.isoformat() - if not self.version_higher(self.cluster.druid_version, '0.8.2'): - rbound = (max_time - timedelta(1)).isoformat() - segment_metadata = None - try: - segment_metadata = client.segment_metadata( - datasource=self.datasource_name, - intervals=lbound + '/' + rbound, - merge=self.merge_flag, - analysisTypes=config.get('DRUID_ANALYSIS_TYPES')) - except Exception as e: - logging.warning("Failed first attempt to get latest segment") - logging.exception(e) - if not segment_metadata: - # if no segments in the past 7 days, look at all segments - lbound = datetime(1901, 1, 1).isoformat()[:10] - rbound = datetime(2050, 1, 1).isoformat()[:10] - if not self.version_higher(self.cluster.druid_version, '0.8.2'): - rbound = datetime.now().isoformat()[:10] - try: - segment_metadata = client.segment_metadata( - datasource=self.datasource_name, - intervals=lbound + '/' + rbound, - merge=self.merge_flag, - analysisTypes=config.get('DRUID_ANALYSIS_TYPES')) - except Exception as e: - logging.warning("Failed 2nd attempt to get latest segment") - logging.exception(e) - if segment_metadata: - return segment_metadata[-1]['columns'] - - def generate_metrics(self): - for col in self.columns: - col.generate_metrics() - - @classmethod - def sync_to_db_from_config(cls, druid_config, user, cluster): - """Merges the ds config from druid_config into one stored in the db.""" - session = db.session() - datasource = ( - session.query(DruidDatasource) - .filter_by( - datasource_name=druid_config['name']) - ).first() - # Create a new datasource. 
- if not datasource: - datasource = DruidDatasource( - datasource_name=druid_config['name'], - cluster=cluster, - owner=user, - changed_by_fk=user.id, - created_by_fk=user.id, - ) - session.add(datasource) - - dimensions = druid_config['dimensions'] - for dim in dimensions: - col_obj = ( - session.query(DruidColumn) - .filter_by( - datasource_name=druid_config['name'], - column_name=dim) - ).first() - if not col_obj: - col_obj = DruidColumn( - datasource_name=druid_config['name'], - column_name=dim, - groupby=True, - filterable=True, - # TODO: fetch type from Hive. - type="STRING", - datasource=datasource - ) - session.add(col_obj) - # Import Druid metrics - for metric_spec in druid_config["metrics_spec"]: - metric_name = metric_spec["name"] - metric_type = metric_spec["type"] - metric_json = json.dumps(metric_spec) - - if metric_type == "count": - metric_type = "longSum" - metric_json = json.dumps({ - "type": "longSum", - "name": metric_name, - "fieldName": metric_name, - }) - - metric_obj = ( - session.query(DruidMetric) - .filter_by( - datasource_name=druid_config['name'], - metric_name=metric_name) - ).first() - if not metric_obj: - metric_obj = DruidMetric( - metric_name=metric_name, - metric_type=metric_type, - verbose_name="%s(%s)" % (metric_type, metric_name), - datasource=datasource, - json=metric_json, - description=( - "Imported from the airolap config dir for %s" % - druid_config['name']), - ) - session.add(metric_obj) - session.commit() - - @classmethod - def sync_to_db(cls, name, cluster, merge): - """Fetches metadata for that datasource and merges the Superset db""" - logging.info("Syncing Druid datasource [{}]".format(name)) - session = get_session() - datasource = session.query(cls).filter_by(datasource_name=name).first() - if not datasource: - datasource = cls(datasource_name=name) - session.add(datasource) - flasher("Adding new datasource [{}]".format(name), "success") - else: - flasher("Refreshing datasource [{}]".format(name), "info") - session.flush() - datasource.cluster = cluster - datasource.merge_flag = merge - session.flush() - - cols = datasource.latest_metadata() - if not cols: - logging.error("Failed at fetching the latest segment") - return - for col in cols: - col_obj = ( - session - .query(DruidColumn) - .filter_by(datasource_name=name, column_name=col) - .first() - ) - datatype = cols[col]['type'] - if not col_obj: - col_obj = DruidColumn(datasource_name=name, column_name=col) - session.add(col_obj) - if datatype == "STRING": - col_obj.groupby = True - col_obj.filterable = True - if datatype == "hyperUnique" or datatype == "thetaSketch": - col_obj.count_distinct = True - if col_obj: - col_obj.type = cols[col]['type'] - session.flush() - col_obj.datasource = datasource - col_obj.generate_metrics() - session.flush() - - @staticmethod - def time_offset(granularity): - if granularity == 'week_ending_saturday': - return 6 * 24 * 3600 * 1000 # 6 days - return 0 - - # uses https://en.wikipedia.org/wiki/ISO_8601 - # http://druid.io/docs/0.8.0/querying/granularities.html - # TODO: pass origin from the UI - @staticmethod - def granularity(period_name, timezone=None, origin=None): - if not period_name or period_name == 'all': - return 'all' - iso_8601_dict = { - '5 seconds': 'PT5S', - '30 seconds': 'PT30S', - '1 minute': 'PT1M', - '5 minutes': 'PT5M', - '1 hour': 'PT1H', - '6 hour': 'PT6H', - 'one day': 'P1D', - '1 day': 'P1D', - '7 days': 'P7D', - 'week': 'P1W', - 'week_starting_sunday': 'P1W', - 'week_ending_saturday': 'P1W', - 'month': 'P1M', - } - - granularity = 
{'type': 'period'} - if timezone: - granularity['timeZone'] = timezone - - if origin: - dttm = utils.parse_human_datetime(origin) - granularity['origin'] = dttm.isoformat() - - if period_name in iso_8601_dict: - granularity['period'] = iso_8601_dict[period_name] - if period_name in ('week_ending_saturday', 'week_starting_sunday'): - # use Sunday as start of the week - granularity['origin'] = '2016-01-03T00:00:00' - elif not isinstance(period_name, string_types): - granularity['type'] = 'duration' - granularity['duration'] = period_name - elif period_name.startswith('P'): - # identify if the string is the iso_8601 period - granularity['period'] = period_name - else: - granularity['type'] = 'duration' - granularity['duration'] = utils.parse_human_timedelta( - period_name).total_seconds() * 1000 - return granularity - - def values_for_column(self, - column_name, - from_dttm, - to_dttm, - limit=500): - """Retrieve some values for the given column""" - # TODO: Use Lexicographic TopNMetricSpec once supported by PyDruid - from_dttm = from_dttm.replace(tzinfo=config.get("DRUID_TZ")) - to_dttm = to_dttm.replace(tzinfo=config.get("DRUID_TZ")) - - qry = dict( - datasource=self.datasource_name, - granularity="all", - intervals=from_dttm.isoformat() + '/' + to_dttm.isoformat(), - aggregations=dict(count=count("count")), - dimension=column_name, - metric="count", - threshold=limit, - ) - - client = self.cluster.get_pydruid_client() - client.topn(**qry) - df = client.export_pandas() - - if df is None or df.size == 0: - raise Exception(_("No data was returned.")) - - return df - - def get_query_str( # druid - self, client, qry_start_dttm, - groupby, metrics, - granularity, - from_dttm, to_dttm, - filter=None, # noqa - is_timeseries=True, - timeseries_limit=None, - timeseries_limit_metric=None, - row_limit=None, - inner_from_dttm=None, inner_to_dttm=None, - orderby=None, - extras=None, # noqa - select=None, # noqa - columns=None, phase=2): - """Runs a query against Druid and returns a dataframe. 
- - This query interface is common to SqlAlchemy and Druid - """ - # TODO refactor into using a TBD Query object - if not is_timeseries: - granularity = 'all' - inner_from_dttm = inner_from_dttm or from_dttm - inner_to_dttm = inner_to_dttm or to_dttm - - # add tzinfo to native datetime with config - from_dttm = from_dttm.replace(tzinfo=config.get("DRUID_TZ")) - to_dttm = to_dttm.replace(tzinfo=config.get("DRUID_TZ")) - timezone = from_dttm.tzname() - - query_str = "" - metrics_dict = {m.metric_name: m for m in self.metrics} - all_metrics = [] - post_aggs = {} - - columns_dict = {c.column_name: c for c in self.columns} - - def recursive_get_fields(_conf): - _fields = _conf.get('fields', []) - field_names = [] - for _f in _fields: - _type = _f.get('type') - if _type in ['fieldAccess', 'hyperUniqueCardinality']: - field_names.append(_f.get('fieldName')) - elif _type == 'arithmetic': - field_names += recursive_get_fields(_f) - return list(set(field_names)) - - for metric_name in metrics: - metric = metrics_dict[metric_name] - if metric.metric_type != 'postagg': - all_metrics.append(metric_name) - else: - conf = metric.json_obj - all_metrics += recursive_get_fields(conf) - all_metrics += conf.get('fieldNames', []) - if conf.get('type') == 'javascript': - post_aggs[metric_name] = JavascriptPostAggregator( - name=conf.get('name', ''), - field_names=conf.get('fieldNames', []), - function=conf.get('function', '')) - elif conf.get('type') == 'quantile': - post_aggs[metric_name] = Quantile( - conf.get('name', ''), - conf.get('probability', ''), - ) - elif conf.get('type') == 'quantiles': - post_aggs[metric_name] = Quantiles( - conf.get('name', ''), - conf.get('probabilities', ''), - ) - elif conf.get('type') == 'fieldAccess': - post_aggs[metric_name] = Field(conf.get('name'), '') - elif conf.get('type') == 'constant': - post_aggs[metric_name] = Const( - conf.get('value'), - output_name=conf.get('name', '') - ) - elif conf.get('type') == 'hyperUniqueCardinality': - post_aggs[metric_name] = HyperUniqueCardinality( - conf.get('name'), '' - ) - else: - post_aggs[metric_name] = Postaggregator( - conf.get('fn', "/"), - conf.get('fields', []), - conf.get('name', '')) - - aggregations = OrderedDict() - for m in self.metrics: - if m.metric_name in all_metrics: - aggregations[m.metric_name] = m.json_obj - - rejected_metrics = [ - m.metric_name for m in self.metrics - if m.is_restricted and - m.metric_name in aggregations.keys() and - not sm.has_access('metric_access', m.perm) - ] - - if rejected_metrics: - raise MetricPermException( - "Access to the metrics denied: " + ', '.join(rejected_metrics) - ) - - # the dimensions list with dimensionSpecs expanded - dimensions = [] - groupby = [gb for gb in groupby if gb in columns_dict] - for column_name in groupby: - col = columns_dict.get(column_name) - dim_spec = col.dimension_spec - if dim_spec: - dimensions.append(dim_spec) - else: - dimensions.append(column_name) - qry = dict( - datasource=self.datasource_name, - dimensions=dimensions, - aggregations=aggregations, - granularity=DruidDatasource.granularity( - granularity, - timezone=timezone, - origin=extras.get('druid_time_origin'), - ), - post_aggregations=post_aggs, - intervals=from_dttm.isoformat() + '/' + to_dttm.isoformat(), - ) - - filters = self.get_filters(filter) - if filters: - qry['filter'] = filters - - having_filters = self.get_having_filters(extras.get('having_druid')) - if having_filters: - qry['having'] = having_filters - - orig_filters = filters - if len(groupby) == 0: - del qry['dimensions'] - 
client.timeseries(**qry) - if not having_filters and len(groupby) == 1: - qry['threshold'] = timeseries_limit or 1000 - if row_limit and granularity == 'all': - qry['threshold'] = row_limit - qry['dimension'] = list(qry.get('dimensions'))[0] - del qry['dimensions'] - qry['metric'] = list(qry['aggregations'].keys())[0] - client.topn(**qry) - elif len(groupby) > 1 or having_filters: - # If grouping on multiple fields or using a having filter - # we have to force a groupby query - if timeseries_limit and is_timeseries: - order_by = metrics[0] if metrics else self.metrics[0] - if timeseries_limit_metric: - order_by = timeseries_limit_metric - # Limit on the number of timeseries, doing a two-phases query - pre_qry = deepcopy(qry) - pre_qry['granularity'] = "all" - pre_qry['limit_spec'] = { - "type": "default", - "limit": timeseries_limit, - 'intervals': ( - inner_from_dttm.isoformat() + '/' + - inner_to_dttm.isoformat()), - "columns": [{ - "dimension": order_by, - "direction": "descending", - }], - } - client.groupby(**pre_qry) - query_str += "// Two phase query\n// Phase 1\n" - query_str += json.dumps( - client.query_builder.last_query.query_dict, indent=2) - query_str += "\n" - if phase == 1: - return query_str - query_str += ( - "//\nPhase 2 (built based on phase one's results)\n") - df = client.export_pandas() - if df is not None and not df.empty: - dims = qry['dimensions'] - filters = [] - for unused, row in df.iterrows(): - fields = [] - for dim in dims: - f = Dimension(dim) == row[dim] - fields.append(f) - if len(fields) > 1: - filt = Filter(type="and", fields=fields) - filters.append(filt) - elif fields: - filters.append(fields[0]) - - if filters: - ff = Filter(type="or", fields=filters) - if not orig_filters: - qry['filter'] = ff - else: - qry['filter'] = Filter(type="and", fields=[ - ff, - orig_filters]) - qry['limit_spec'] = None - if row_limit: - qry['limit_spec'] = { - "type": "default", - "limit": row_limit, - "columns": [{ - "dimension": ( - metrics[0] if metrics else self.metrics[0]), - "direction": "descending", - }], - } - client.groupby(**qry) - query_str += json.dumps( - client.query_builder.last_query.query_dict, indent=2) - return query_str - - def query(self, query_obj): - qry_start_dttm = datetime.now() - client = self.cluster.get_pydruid_client() - query_str = self.get_query_str(client, qry_start_dttm, **query_obj) - df = client.export_pandas() - - if df is None or df.size == 0: - raise Exception(_("No data was returned.")) - df.columns = [ - DTTM_ALIAS if c == 'timestamp' else c for c in df.columns] - - is_timeseries = query_obj['is_timeseries'] \ - if 'is_timeseries' in query_obj else True - if ( - not is_timeseries and - query_obj['granularity'] == "all" and - DTTM_ALIAS in df.columns): - del df[DTTM_ALIAS] - - # Reordering columns - cols = [] - if DTTM_ALIAS in df.columns: - cols += [DTTM_ALIAS] - cols += [col for col in query_obj['groupby'] if col in df.columns] - cols += [col for col in query_obj['metrics'] if col in df.columns] - df = df[cols] - - time_offset = DruidDatasource.time_offset(query_obj['granularity']) - - def increment_timestamp(ts): - dt = utils.parse_human_datetime(ts).replace( - tzinfo=config.get("DRUID_TZ")) - return dt + timedelta(milliseconds=time_offset) - if DTTM_ALIAS in df.columns and time_offset: - df[DTTM_ALIAS] = df[DTTM_ALIAS].apply(increment_timestamp) - - return QueryResult( - df=df, - query=query_str, - duration=datetime.now() - qry_start_dttm) - - def get_filters(self, raw_filters): - filters = None - for flt in raw_filters: - if 
not all(f in flt for f in ['col', 'op', 'val']): - continue - col = flt['col'] - op = flt['op'] - eq = flt['val'] - cond = None - if op in ('in', 'not in'): - eq = [types.replace("'", '').strip() for types in eq] - elif not isinstance(flt['val'], basestring): - eq = eq[0] if len(eq) > 0 else '' - if col in self.num_cols: - if op in ('in', 'not in'): - eq = [utils.js_string_to_num(v) for v in eq] - else: - eq = utils.js_string_to_num(eq) - if op == '==': - cond = Dimension(col) == eq - elif op == '!=': - cond = ~(Dimension(col) == eq) - elif op in ('in', 'not in'): - fields = [] - if len(eq) > 1: - for s in eq: - fields.append(Dimension(col) == s) - cond = Filter(type="or", fields=fields) - elif len(eq) == 1: - cond = Dimension(col) == eq[0] - if op == 'not in': - cond = ~cond - elif op == 'regex': - cond = Filter(type="regex", pattern=eq, dimension=col) - if filters: - filters = Filter(type="and", fields=[ - cond, - filters - ]) - else: - filters = cond - return filters - - def _get_having_obj(self, col, op, eq): - cond = None - if op == '==': - if col in self.column_names: - cond = DimSelector(dimension=col, value=eq) - else: - cond = Aggregation(col) == eq - elif op == '>': - cond = Aggregation(col) > eq - elif op == '<': - cond = Aggregation(col) < eq - - return cond - - def get_having_filters(self, raw_filters): - filters = None - reversed_op_map = { - '!=': '==', - '>=': '<', - '<=': '>' - } - - for flt in raw_filters: - if not all(f in flt for f in ['col', 'op', 'val']): - continue - col = flt['col'] - op = flt['op'] - eq = flt['val'] - cond = None - if op in ['==', '>', '<']: - cond = self._get_having_obj(col, op, eq) - elif op in reversed_op_map: - cond = ~self._get_having_obj(col, reversed_op_map[op], eq) - - if filters: - filters = filters & cond - else: - filters = cond - return filters - -sqla.event.listen(DruidDatasource, 'after_insert', set_perm) -sqla.event.listen(DruidDatasource, 'after_update', set_perm) - - -class Log(Model): - - """ORM object used to log Superset actions to the database""" - - __tablename__ = 'logs' - - id = Column(Integer, primary_key=True) - action = Column(String(512)) - user_id = Column(Integer, ForeignKey('ab_user.id')) - dashboard_id = Column(Integer) - slice_id = Column(Integer) - json = Column(Text) - user = relationship('User', backref='logs', foreign_keys=[user_id]) - dttm = Column(DateTime, default=datetime.utcnow) - dt = Column(Date, default=date.today()) - duration_ms = Column(Integer) - referrer = Column(String(1024)) - - @classmethod - def log_this(cls, f): - """Decorator to log user actions""" - @functools.wraps(f) - def wrapper(*args, **kwargs): - start_dttm = datetime.now() - user_id = None - if g.user: - user_id = g.user.get_id() - d = request.args.to_dict() - post_data = request.form or {} - d.update(post_data) - d.update(kwargs) - slice_id = d.get('slice_id', 0) - try: - slice_id = int(slice_id) if slice_id else 0 - except ValueError: - slice_id = 0 - params = "" - try: - params = json.dumps(d) - except: - pass - value = f(*args, **kwargs) - - sesh = db.session() - log = cls( - action=f.__name__, - json=params, - dashboard_id=d.get('dashboard_id') or None, - slice_id=slice_id, - duration_ms=( - datetime.now() - start_dttm).total_seconds() * 1000, - referrer=request.referrer[:1000] if request.referrer else None, - user_id=user_id) - sesh.add(log) - sesh.commit() - return value - return wrapper - - -class FavStar(Model): - __tablename__ = 'favstar' - - id = Column(Integer, primary_key=True) - user_id = Column(Integer, 
ForeignKey('ab_user.id')) - class_name = Column(String(50)) - obj_id = Column(Integer) - dttm = Column(DateTime, default=datetime.utcnow) - - -class Query(Model): - - """ORM model for SQL query""" - - __tablename__ = 'query' - id = Column(Integer, primary_key=True) - client_id = Column(String(11), unique=True, nullable=False) - - database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False) - - # Store the tmp table into the DB only if the user asks for it. - tmp_table_name = Column(String(256)) - user_id = Column( - Integer, ForeignKey('ab_user.id'), nullable=True) - status = Column(String(16), default=QueryStatus.PENDING) - tab_name = Column(String(256)) - sql_editor_id = Column(String(256)) - schema = Column(String(256)) - sql = Column(Text) - # Query to retrieve the results, - # used only in case of select_as_cta_used is true. - select_sql = Column(Text) - executed_sql = Column(Text) - # Could be configured in the superset config. - limit = Column(Integer) - limit_used = Column(Boolean, default=False) - limit_reached = Column(Boolean, default=False) - select_as_cta = Column(Boolean) - select_as_cta_used = Column(Boolean, default=False) - - progress = Column(Integer, default=0) # 1..100 - # # of rows in the result set or rows modified. - rows = Column(Integer) - error_message = Column(Text) - # key used to store the results in the results backend - results_key = Column(String(64), index=True) - - # Using Numeric in place of DateTime for sub-second precision - # stored as seconds since epoch, allowing for milliseconds - start_time = Column(Numeric(precision=3)) - end_time = Column(Numeric(precision=3)) - changed_on = Column( - DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=True) - - database = relationship( - 'Database', - foreign_keys=[database_id], - backref=backref('queries', cascade='all, delete-orphan') - ) - user = relationship( - 'User', - backref=backref('queries', cascade='all, delete-orphan'), - foreign_keys=[user_id]) - - __table_args__ = ( - sqla.Index('ti_user_id_changed_on', user_id, changed_on), - ) - - @property - def limit_reached(self): - return self.rows == self.limit if self.limit_used else False - - def to_dict(self): - return { - 'changedOn': self.changed_on, - 'changed_on': self.changed_on.isoformat(), - 'dbId': self.database_id, - 'db': self.database.database_name, - 'endDttm': self.end_time, - 'errorMessage': self.error_message, - 'executedSql': self.executed_sql, - 'id': self.client_id, - 'limit': self.limit, - 'progress': self.progress, - 'rows': self.rows, - 'schema': self.schema, - 'ctas': self.select_as_cta, - 'serverId': self.id, - 'sql': self.sql, - 'sqlEditorId': self.sql_editor_id, - 'startDttm': self.start_time, - 'state': self.status.lower(), - 'tab': self.tab_name, - 'tempTable': self.tmp_table_name, - 'userId': self.user_id, - 'user': self.user.username, - 'limit_reached': self.limit_reached, - 'resultsKey': self.results_key, - } - - @property - def name(self): - ts = datetime.now().isoformat() - ts = ts.replace('-', '').replace(':', '').split('.')[0] - tab = self.tab_name.replace(' ', '_').lower() if self.tab_name else 'notab' - tab = re.sub(r'\W+', '', tab) - return "sqllab_{tab}_{ts}".format(**locals()) - - -class DatasourceAccessRequest(Model, AuditMixinNullable): - """ORM model for the access requests for datasources and dbs.""" - __tablename__ = 'access_request' - id = Column(Integer, primary_key=True) - - datasource_id = Column(Integer) - datasource_type = Column(String(200)) - - ROLES_BLACKLIST = 
set(config.get('ROBOT_PERMISSION_ROLES', []))
-
-    @property
-    def cls_model(self):
-        return SourceRegistry.sources[self.datasource_type]
-
-    @property
-    def username(self):
-        return self.creator()
-
-    @property
-    def datasource(self):
-        return self.get_datasource
-
-    @datasource.getter
-    @utils.memoized
-    def get_datasource(self):
-        ds = db.session.query(self.cls_model).filter_by(
-            id=self.datasource_id).first()
-        return ds
-
-    @property
-    def datasource_link(self):
-        return self.datasource.link
-
-    @property
-    def roles_with_datasource(self):
-        action_list = ''
-        pv = sm.find_permission_view_menu(
-            'datasource_access', self.datasource.perm)
-        for r in pv.role:
-            if r.name in self.ROLES_BLACKLIST:
-                continue
-            url = (
-                '/superset/approve?datasource_type={self.datasource_type}&'
-                'datasource_id={self.datasource_id}&'
-                'created_by={self.created_by.username}&role_to_grant={r.name}'
-                .format(**locals())
-            )
-            href = '<a href="{}">Grant {} Role</a>'.format(url, r.name)
-            action_list = action_list + '<li>' + href + '</li>'
-        return '<ul>' + action_list + '</ul>'
-
-    @property
-    def user_roles(self):
-        action_list = ''
-        for r in self.created_by.roles:
-            url = (
-                '/superset/approve?datasource_type={self.datasource_type}&'
-                'datasource_id={self.datasource_id}&'
-                'created_by={self.created_by.username}&role_to_extend={r.name}'
-                .format(**locals())
-            )
-            href = '<a href="{}">Extend {} Role</a>'.format(url, r.name)
-            if r.name in self.ROLES_BLACKLIST:
-                href = "{} Role".format(r.name)
-            action_list = action_list + '<li>' + href + '</li>'
  • ' - return '' diff --git a/superset/models/__init__.py b/superset/models/__init__.py new file mode 100644 index 0000000000000..2a1dbbc06ff9e --- /dev/null +++ b/superset/models/__init__.py @@ -0,0 +1 @@ +from . import core # noqa diff --git a/superset/models/core.py b/superset/models/core.py new file mode 100644 index 0000000000000..e8f3f2bc8218b --- /dev/null +++ b/superset/models/core.py @@ -0,0 +1,952 @@ +"""A collection of ORM sqlalchemy models for Superset""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import functools +import json +import logging +import numpy +import pickle +import re +import sqlparse +import textwrap +from future.standard_library import install_aliases +from copy import copy +from datetime import datetime, date + +import pandas as pd +import sqlalchemy as sqla +from sqlalchemy.engine.url import make_url +from sqlalchemy.orm import subqueryload + +from flask import escape, g, Markup, request +from flask_appbuilder import Model +from flask_appbuilder.models.decorators import renders + +from sqlalchemy import ( + Column, Integer, String, ForeignKey, Text, Boolean, + DateTime, Date, Table, Numeric, + create_engine, MetaData, select +) +from sqlalchemy.orm import backref, relationship +from sqlalchemy.orm.session import make_transient +from sqlalchemy.sql import text +from sqlalchemy.sql.expression import TextAsFrom +from sqlalchemy_utils import EncryptedType + +from superset import app, db, db_engine_specs, utils, sm +from superset.connectors.connector_registry import ConnectorRegistry +from superset.viz import viz_types +from superset.utils import QueryStatus +from superset.models.helpers import AuditMixinNullable, ImportMixin, set_perm +install_aliases() +from urllib import parse # noqa + +config = app.config + + +def set_related_perm(mapper, connection, target): # noqa + src_class = target.cls_model + id_ = target.datasource_id + ds = db.session.query(src_class).filter_by(id=int(id_)).first() + target.perm = ds.perm + + +class Url(Model, AuditMixinNullable): + """Used for the short url feature""" + + __tablename__ = 'url' + id = Column(Integer, primary_key=True) + url = Column(Text) + + +class KeyValue(Model): + + """Used for any type of key-value store""" + + __tablename__ = 'keyvalue' + id = Column(Integer, primary_key=True) + value = Column(Text, nullable=False) + + +class CssTemplate(Model, AuditMixinNullable): + + """CSS templates for dashboards""" + + __tablename__ = 'css_templates' + id = Column(Integer, primary_key=True) + template_name = Column(String(250)) + css = Column(Text, default='') + + +slice_user = Table('slice_user', Model.metadata, + Column('id', Integer, primary_key=True), + Column('user_id', Integer, ForeignKey('ab_user.id')), + Column('slice_id', Integer, ForeignKey('slices.id')) + ) + + +class Slice(Model, AuditMixinNullable, ImportMixin): + + """A slice is essentially a report or a view on data""" + + __tablename__ = 'slices' + id = Column(Integer, primary_key=True) + slice_name = Column(String(250)) + datasource_id = Column(Integer) + datasource_type = Column(String(200)) + datasource_name = Column(String(2000)) + viz_type = Column(String(250)) + params = Column(Text) + description = Column(Text) + cache_timeout = Column(Integer) + perm = Column(String(1000)) + owners = relationship("User", secondary=slice_user) + + export_fields = ('slice_name', 'datasource_type', 'datasource_name', + 'viz_type', 'params', 'cache_timeout') + 
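# Illustration (not part of the patch): a Slice points at its datasource
# through the generic (datasource_type, datasource_id) pair rather than a
# foreign key, and the concrete ORM class is resolved at runtime from a
# registry keyed by the type string -- the same shape ConnectorRegistry.sources
# takes in the methods just below. A minimal, self-contained sketch of that
# lookup pattern; DemoTable, DATASOURCE_REGISTRY and resolve_datasource are
# hypothetical stand-ins, not Superset code:
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

Base = declarative_base()


class DemoTable(Base):
    __tablename__ = 'demo_tables'
    id = Column(Integer, primary_key=True)
    name = Column(String(100))


DATASOURCE_REGISTRY = {'table': DemoTable}


def resolve_datasource(session, datasource_type, datasource_id):
    # Pick the ORM class by its type string, then load the row by id;
    # returns None when no such row exists.
    model = DATASOURCE_REGISTRY[datasource_type]
    return session.query(model).filter_by(id=datasource_id).first()


if __name__ == '__main__':
    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()
    session.add(DemoTable(id=1, name='example'))
    session.commit()
    print(resolve_datasource(session, 'table', 1).name)  # prints 'example'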
+ def __repr__(self): + return self.slice_name + + @property + def cls_model(self): + return ConnectorRegistry.sources[self.datasource_type] + + @property + def datasource(self): + return self.get_datasource + + @datasource.getter + @utils.memoized + def get_datasource(self): + ds = db.session.query( + self.cls_model).filter_by( + id=self.datasource_id).first() + return ds + + @renders('datasource_name') + def datasource_link(self): + datasource = self.datasource + if datasource: + return self.datasource.link + + @property + def datasource_edit_url(self): + self.datasource.url + + @property + @utils.memoized + def viz(self): + d = json.loads(self.params) + viz_class = viz_types[self.viz_type] + return viz_class(self.datasource, form_data=d) + + @property + def description_markeddown(self): + return utils.markdown(self.description) + + @property + def data(self): + """Data used to render slice in templates""" + d = {} + self.token = '' + try: + d = self.viz.data + self.token = d.get('token') + except Exception as e: + logging.exception(e) + d['error'] = str(e) + return { + 'datasource': self.datasource_name, + 'description': self.description, + 'description_markeddown': self.description_markeddown, + 'edit_url': self.edit_url, + 'form_data': self.form_data, + 'slice_id': self.id, + 'slice_name': self.slice_name, + 'slice_url': self.slice_url, + } + + @property + def json_data(self): + return json.dumps(self.data) + + @property + def form_data(self): + form_data = json.loads(self.params) + form_data['slice_id'] = self.id + form_data['viz_type'] = self.viz_type + form_data['datasource'] = ( + str(self.datasource_id) + '__' + self.datasource_type) + return form_data + + @property + def slice_url(self): + """Defines the url to access the slice""" + return ( + "/superset/explore/{obj.datasource_type}/" + "{obj.datasource_id}/?form_data={params}".format( + obj=self, params=parse.quote(json.dumps(self.form_data)))) + + @property + def slice_id_url(self): + return ( + "/superset/{slc.datasource_type}/{slc.datasource_id}/{slc.id}/" + ).format(slc=self) + + @property + def edit_url(self): + return "/slicemodelview/edit/{}".format(self.id) + + @property + def slice_link(self): + url = self.slice_url + name = escape(self.slice_name) + return Markup('{name}'.format(**locals())) + + def get_viz(self, url_params_multidict=None): + """Creates :py:class:viz.BaseViz object from the url_params_multidict. + + :param werkzeug.datastructures.MultiDict url_params_multidict: + Contains the visualization params, they override the self.params + stored in the database + :return: object of the 'viz_type' type that is taken from the + url_params_multidict or self.params. + :rtype: :py:class:viz.BaseViz + """ + slice_params = json.loads(self.params) + slice_params['slice_id'] = self.id + slice_params['json'] = "false" + slice_params['slice_name'] = self.slice_name + slice_params['viz_type'] = self.viz_type if self.viz_type else "table" + + return viz_types[slice_params.get('viz_type')]( + self.datasource, + form_data=slice_params, + slice_=self + ) + + @classmethod + def import_obj(cls, slc_to_import, import_time=None): + """Inserts or overrides slc in the database. + + remote_id and import_time fields in params_dict are set to track the + slice origin and ensure correct overrides for multiple imports. + Slice.perm is used to find the datasources and connect them. 
+ """ + session = db.session + make_transient(slc_to_import) + slc_to_import.dashboards = [] + slc_to_import.alter_params( + remote_id=slc_to_import.id, import_time=import_time) + + # find if the slice was already imported + slc_to_override = None + for slc in session.query(Slice).all(): + if ('remote_id' in slc.params_dict and + slc.params_dict['remote_id'] == slc_to_import.id): + slc_to_override = slc + + slc_to_import = slc_to_import.copy() + params = slc_to_import.params_dict + slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name( + session, slc_to_import.datasource_type, params['datasource_name'], + params['schema'], params['database_name']).id + if slc_to_override: + slc_to_override.override(slc_to_import) + session.flush() + return slc_to_override.id + session.add(slc_to_import) + logging.info('Final slice: {}'.format(slc_to_import.to_json())) + session.flush() + return slc_to_import.id + + +sqla.event.listen(Slice, 'before_insert', set_related_perm) +sqla.event.listen(Slice, 'before_update', set_related_perm) + + +dashboard_slices = Table( + 'dashboard_slices', Model.metadata, + Column('id', Integer, primary_key=True), + Column('dashboard_id', Integer, ForeignKey('dashboards.id')), + Column('slice_id', Integer, ForeignKey('slices.id')), +) + +dashboard_user = Table( + 'dashboard_user', Model.metadata, + Column('id', Integer, primary_key=True), + Column('user_id', Integer, ForeignKey('ab_user.id')), + Column('dashboard_id', Integer, ForeignKey('dashboards.id')) +) + + +class Dashboard(Model, AuditMixinNullable, ImportMixin): + + """The dashboard object!""" + + __tablename__ = 'dashboards' + id = Column(Integer, primary_key=True) + dashboard_title = Column(String(500)) + position_json = Column(Text) + description = Column(Text) + css = Column(Text) + json_metadata = Column(Text) + slug = Column(String(255), unique=True) + slices = relationship( + 'Slice', secondary=dashboard_slices, backref='dashboards') + owners = relationship("User", secondary=dashboard_user) + + export_fields = ('dashboard_title', 'position_json', 'json_metadata', + 'description', 'css', 'slug') + + def __repr__(self): + return self.dashboard_title + + @property + def table_names(self): + return ", ".join( + {"{}".format(s.datasource.name) for s in self.slices}) + + @property + def url(self): + return "/superset/dashboard/{}/".format(self.slug or self.id) + + @property + def datasources(self): + return {slc.datasource for slc in self.slices} + + @property + def sqla_metadata(self): + metadata = MetaData(bind=self.get_sqla_engine()) + return metadata.reflect() + + def dashboard_link(self): + title = escape(self.dashboard_title) + return Markup( + '{title}'.format(**locals())) + + @property + def json_data(self): + positions = self.position_json + if positions: + positions = json.loads(positions) + d = { + 'id': self.id, + 'metadata': self.params_dict, + 'css': self.css, + 'dashboard_title': self.dashboard_title, + 'slug': self.slug, + 'slices': [slc.data for slc in self.slices], + 'position_json': positions, + } + return json.dumps(d) + + @property + def params(self): + return self.json_metadata + + @params.setter + def params(self, value): + self.json_metadata = value + + @property + def position_array(self): + if self.position_json: + return json.loads(self.position_json) + return [] + + @classmethod + def import_obj(cls, dashboard_to_import, import_time=None): + """Imports the dashboard from the object to the database. 
+ + Once dashboard is imported, json_metadata field is extended and stores + remote_id and import_time. It helps to decide if the dashboard has to + be overridden or just copies over. Slices that belong to this + dashboard will be wired to existing tables. This function can be used + to import/export dashboards between multiple superset instances. + Audit metadata isn't copies over. + """ + def alter_positions(dashboard, old_to_new_slc_id_dict): + """ Updates slice_ids in the position json. + + Sample position json: + [{ + "col": 5, + "row": 10, + "size_x": 4, + "size_y": 2, + "slice_id": "3610" + }] + """ + position_array = dashboard.position_array + for position in position_array: + if 'slice_id' not in position: + continue + old_slice_id = int(position['slice_id']) + if old_slice_id in old_to_new_slc_id_dict: + position['slice_id'] = '{}'.format( + old_to_new_slc_id_dict[old_slice_id]) + dashboard.position_json = json.dumps(position_array) + + logging.info('Started import of the dashboard: {}' + .format(dashboard_to_import.to_json())) + session = db.session + logging.info('Dashboard has {} slices' + .format(len(dashboard_to_import.slices))) + # copy slices object as Slice.import_slice will mutate the slice + # and will remove the existing dashboard - slice association + slices = copy(dashboard_to_import.slices) + old_to_new_slc_id_dict = {} + new_filter_immune_slices = [] + new_expanded_slices = {} + i_params_dict = dashboard_to_import.params_dict + for slc in slices: + logging.info('Importing slice {} from the dashboard: {}'.format( + slc.to_json(), dashboard_to_import.dashboard_title)) + new_slc_id = Slice.import_obj(slc, import_time=import_time) + old_to_new_slc_id_dict[slc.id] = new_slc_id + # update json metadata that deals with slice ids + new_slc_id_str = '{}'.format(new_slc_id) + old_slc_id_str = '{}'.format(slc.id) + if ('filter_immune_slices' in i_params_dict and + old_slc_id_str in i_params_dict['filter_immune_slices']): + new_filter_immune_slices.append(new_slc_id_str) + if ('expanded_slices' in i_params_dict and + old_slc_id_str in i_params_dict['expanded_slices']): + new_expanded_slices[new_slc_id_str] = ( + i_params_dict['expanded_slices'][old_slc_id_str]) + + # override the dashboard + existing_dashboard = None + for dash in session.query(Dashboard).all(): + if ('remote_id' in dash.params_dict and + dash.params_dict['remote_id'] == + dashboard_to_import.id): + existing_dashboard = dash + + dashboard_to_import.id = None + alter_positions(dashboard_to_import, old_to_new_slc_id_dict) + dashboard_to_import.alter_params(import_time=import_time) + if new_expanded_slices: + dashboard_to_import.alter_params( + expanded_slices=new_expanded_slices) + if new_filter_immune_slices: + dashboard_to_import.alter_params( + filter_immune_slices=new_filter_immune_slices) + + new_slices = session.query(Slice).filter( + Slice.id.in_(old_to_new_slc_id_dict.values())).all() + + if existing_dashboard: + existing_dashboard.override(dashboard_to_import) + existing_dashboard.slices = new_slices + session.flush() + return existing_dashboard.id + else: + # session.add(dashboard_to_import) causes sqlachemy failures + # related to the attached users / slices. Creating new object + # allows to avoid conflicts in the sql alchemy state. 
+ copied_dash = dashboard_to_import.copy() + copied_dash.slices = new_slices + session.add(copied_dash) + session.flush() + return copied_dash.id + + @classmethod + def export_dashboards(cls, dashboard_ids): + copied_dashboards = [] + datasource_ids = set() + for dashboard_id in dashboard_ids: + # make sure that dashboard_id is an integer + dashboard_id = int(dashboard_id) + copied_dashboard = ( + db.session.query(Dashboard) + .options(subqueryload(Dashboard.slices)) + .filter_by(id=dashboard_id).first() + ) + make_transient(copied_dashboard) + for slc in copied_dashboard.slices: + datasource_ids.add((slc.datasource_id, slc.datasource_type)) + # add extra params for the import + slc.alter_params( + remote_id=slc.id, + datasource_name=slc.datasource.name, + schema=slc.datasource.name, + database_name=slc.datasource.database.name, + ) + copied_dashboard.alter_params(remote_id=dashboard_id) + copied_dashboards.append(copied_dashboard) + + eager_datasources = [] + for dashboard_id, dashboard_type in datasource_ids: + eager_datasource = ConnectorRegistry.get_eager_datasource( + db.session, dashboard_type, dashboard_id) + eager_datasource.alter_params( + remote_id=eager_datasource.id, + database_name=eager_datasource.database.name, + ) + make_transient(eager_datasource) + eager_datasources.append(eager_datasource) + + return pickle.dumps({ + 'dashboards': copied_dashboards, + 'datasources': eager_datasources, + }) + + +class Database(Model, AuditMixinNullable): + + """An ORM object that stores Database related information""" + + __tablename__ = 'dbs' + type = "table" + + id = Column(Integer, primary_key=True) + database_name = Column(String(250), unique=True) + sqlalchemy_uri = Column(String(1024)) + password = Column(EncryptedType(String(1024), config.get('SECRET_KEY'))) + cache_timeout = Column(Integer) + select_as_create_table_as = Column(Boolean, default=False) + expose_in_sqllab = Column(Boolean, default=False) + allow_run_sync = Column(Boolean, default=True) + allow_run_async = Column(Boolean, default=False) + allow_ctas = Column(Boolean, default=False) + allow_dml = Column(Boolean, default=False) + force_ctas_schema = Column(String(250)) + extra = Column(Text, default=textwrap.dedent("""\ + { + "metadata_params": {}, + "engine_params": {} + } + """)) + perm = Column(String(1000)) + + def __repr__(self): + return self.database_name + + @property + def name(self): + return self.database_name + + @property + def backend(self): + url = make_url(self.sqlalchemy_uri_decrypted) + return url.get_backend_name() + + def set_sqlalchemy_uri(self, uri): + password_mask = "X" * 10 + conn = sqla.engine.url.make_url(uri) + if conn.password != password_mask: + # do not over-write the password with the password mask + self.password = conn.password + conn.password = password_mask if conn.password else None + self.sqlalchemy_uri = str(conn) # hides the password + + def get_sqla_engine(self, schema=None): + extra = self.get_extra() + url = make_url(self.sqlalchemy_uri_decrypted) + params = extra.get('engine_params', {}) + url.database = self.get_database_for_various_backend(url, schema) + return create_engine(url, **params) + + def get_database_for_various_backend(self, uri, default_database=None): + database = uri.database + if self.backend == 'presto' and default_database: + if '/' in database: + database = database.split('/')[0] + '/' + default_database + else: + database += '/' + default_database + # Postgres and Redshift use the concept of schema as a logical entity + # on top of the database, so the 
database should not be changed + # even if passed default_database + elif self.backend == 'redshift' or self.backend == 'postgresql': + pass + elif default_database: + database = default_database + return database + + def get_reserved_words(self): + return self.get_sqla_engine().dialect.preparer.reserved_words + + def get_quoter(self): + return self.get_sqla_engine().dialect.identifier_preparer.quote + + def get_df(self, sql, schema): + sql = sql.strip().strip(';') + eng = self.get_sqla_engine(schema=schema) + cur = eng.execute(sql, schema=schema) + cols = [col[0] for col in cur.cursor.description] + df = pd.DataFrame(cur.fetchall(), columns=cols) + + def needs_conversion(df_series): + if df_series.empty: + return False + for df_type in [list, dict]: + if isinstance(df_series[0], df_type): + return True + return False + + for k, v in df.dtypes.iteritems(): + if v.type == numpy.object_ and needs_conversion(df[k]): + df[k] = df[k].apply(utils.json_dumps_w_dates) + return df + + def compile_sqla_query(self, qry, schema=None): + eng = self.get_sqla_engine(schema=schema) + compiled = qry.compile(eng, compile_kwargs={"literal_binds": True}) + return '{}'.format(compiled) + + def select_star( + self, table_name, schema=None, limit=100, show_cols=False, + indent=True): + """Generates a ``select *`` statement in the proper dialect""" + return self.db_engine_spec.select_star( + self, table_name, schema=schema, limit=limit, show_cols=show_cols, + indent=indent) + + def wrap_sql_limit(self, sql, limit=1000): + qry = ( + select('*') + .select_from( + TextAsFrom(text(sql), ['*']) + .alias('inner_qry') + ).limit(limit) + ) + return self.compile_sqla_query(qry) + + def safe_sqlalchemy_uri(self): + return self.sqlalchemy_uri + + @property + def inspector(self): + engine = self.get_sqla_engine() + return sqla.inspect(engine) + + def all_table_names(self, schema=None, force=False): + if not schema: + tables_dict = self.db_engine_spec.fetch_result_sets( + self, 'table', force=force) + return tables_dict.get("", []) + return sorted(self.inspector.get_table_names(schema)) + + def all_view_names(self, schema=None, force=False): + if not schema: + views_dict = self.db_engine_spec.fetch_result_sets( + self, 'view', force=force) + return views_dict.get("", []) + views = [] + try: + views = self.inspector.get_view_names(schema) + except Exception: + pass + return views + + def all_schema_names(self): + return sorted(self.inspector.get_schema_names()) + + @property + def db_engine_spec(self): + engine_name = self.get_sqla_engine().name or 'base' + return db_engine_specs.engines.get( + engine_name, db_engine_specs.BaseEngineSpec) + + def grains(self): + """Defines time granularity database-specific expressions. + + The idea here is to make it easy for users to change the time grain + form a datetime (maybe the source grain is arbitrary timestamps, daily + or 5 minutes increments) to another, "truncated" datetime. Since + each database has slightly different but similar datetime functions, + this allows a mapping between database engines and actual functions. 
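        As a purely illustrative example (the real expressions live in
        db_engine_specs, not here), a Postgres-backed spec might expose a
        'day' grain that truncates with something along the lines of
        DATE_TRUNC('day', dttm_col), and an 'hour' grain that maps to
        DATE_TRUNC('hour', dttm_col).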
+ """ + return self.db_engine_spec.time_grains + + def grains_dict(self): + return {grain.name: grain for grain in self.grains()} + + def get_extra(self): + extra = {} + if self.extra: + try: + extra = json.loads(self.extra) + except Exception as e: + logging.error(e) + return extra + + def get_table(self, table_name, schema=None): + extra = self.get_extra() + meta = MetaData(**extra.get('metadata_params', {})) + return Table( + table_name, meta, + schema=schema or None, + autoload=True, + autoload_with=self.get_sqla_engine()) + + def get_columns(self, table_name, schema=None): + return self.inspector.get_columns(table_name, schema) + + def get_indexes(self, table_name, schema=None): + return self.inspector.get_indexes(table_name, schema) + + def get_pk_constraint(self, table_name, schema=None): + return self.inspector.get_pk_constraint(table_name, schema) + + def get_foreign_keys(self, table_name, schema=None): + return self.inspector.get_foreign_keys(table_name, schema) + + @property + def sqlalchemy_uri_decrypted(self): + conn = sqla.engine.url.make_url(self.sqlalchemy_uri) + conn.password = self.password + return str(conn) + + @property + def sql_url(self): + return '/superset/sql/{}/'.format(self.id) + + def get_perm(self): + return ( + "[{obj.database_name}].(id:{obj.id})").format(obj=self) + +sqla.event.listen(Database, 'after_insert', set_perm) +sqla.event.listen(Database, 'after_update', set_perm) + + +class Log(Model): + + """ORM object used to log Superset actions to the database""" + + __tablename__ = 'logs' + + id = Column(Integer, primary_key=True) + action = Column(String(512)) + user_id = Column(Integer, ForeignKey('ab_user.id')) + dashboard_id = Column(Integer) + slice_id = Column(Integer) + json = Column(Text) + user = relationship('User', backref='logs', foreign_keys=[user_id]) + dttm = Column(DateTime, default=datetime.utcnow) + dt = Column(Date, default=date.today()) + duration_ms = Column(Integer) + referrer = Column(String(1024)) + + @classmethod + def log_this(cls, f): + """Decorator to log user actions""" + @functools.wraps(f) + def wrapper(*args, **kwargs): + start_dttm = datetime.now() + user_id = None + if g.user: + user_id = g.user.get_id() + d = request.args.to_dict() + post_data = request.form or {} + d.update(post_data) + d.update(kwargs) + slice_id = d.get('slice_id', 0) + try: + slice_id = int(slice_id) if slice_id else 0 + except ValueError: + slice_id = 0 + params = "" + try: + params = json.dumps(d) + except: + pass + value = f(*args, **kwargs) + + sesh = db.session() + log = cls( + action=f.__name__, + json=params, + dashboard_id=d.get('dashboard_id') or None, + slice_id=slice_id, + duration_ms=( + datetime.now() - start_dttm).total_seconds() * 1000, + referrer=request.referrer[:1000] if request.referrer else None, + user_id=user_id) + sesh.add(log) + sesh.commit() + return value + return wrapper + + +class FavStar(Model): + __tablename__ = 'favstar' + + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey('ab_user.id')) + class_name = Column(String(50)) + obj_id = Column(Integer) + dttm = Column(DateTime, default=datetime.utcnow) + + +class Query(Model): + + """ORM model for SQL query""" + + __tablename__ = 'query' + id = Column(Integer, primary_key=True) + client_id = Column(String(11), unique=True, nullable=False) + + database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False) + + # Store the tmp table into the DB only if the user asks for it. 
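    # Illustrative note, not part of the patch: when a user opts into
    # "create table as" (CTAS), the executed statement is typically rewritten
    # to something like CREATE TABLE <tmp_table_name> AS <original query>,
    # and select_sql below is then what reads the results back out of it.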
+ tmp_table_name = Column(String(256)) + user_id = Column( + Integer, ForeignKey('ab_user.id'), nullable=True) + status = Column(String(16), default=QueryStatus.PENDING) + tab_name = Column(String(256)) + sql_editor_id = Column(String(256)) + schema = Column(String(256)) + sql = Column(Text) + # Query to retrieve the results, + # used only in case of select_as_cta_used is true. + select_sql = Column(Text) + executed_sql = Column(Text) + # Could be configured in the superset config. + limit = Column(Integer) + limit_used = Column(Boolean, default=False) + limit_reached = Column(Boolean, default=False) + select_as_cta = Column(Boolean) + select_as_cta_used = Column(Boolean, default=False) + + progress = Column(Integer, default=0) # 1..100 + # # of rows in the result set or rows modified. + rows = Column(Integer) + error_message = Column(Text) + # key used to store the results in the results backend + results_key = Column(String(64), index=True) + + # Using Numeric in place of DateTime for sub-second precision + # stored as seconds since epoch, allowing for milliseconds + start_time = Column(Numeric(precision=3)) + end_time = Column(Numeric(precision=3)) + changed_on = Column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=True) + + database = relationship( + 'Database', + foreign_keys=[database_id], + backref=backref('queries', cascade='all, delete-orphan') + ) + user = relationship( + 'User', + backref=backref('queries', cascade='all, delete-orphan'), + foreign_keys=[user_id]) + + __table_args__ = ( + sqla.Index('ti_user_id_changed_on', user_id, changed_on), + ) + + @property + def limit_reached(self): + return self.rows == self.limit if self.limit_used else False + + def to_dict(self): + return { + 'changedOn': self.changed_on, + 'changed_on': self.changed_on.isoformat(), + 'dbId': self.database_id, + 'db': self.database.database_name, + 'endDttm': self.end_time, + 'errorMessage': self.error_message, + 'executedSql': self.executed_sql, + 'id': self.client_id, + 'limit': self.limit, + 'progress': self.progress, + 'rows': self.rows, + 'schema': self.schema, + 'ctas': self.select_as_cta, + 'serverId': self.id, + 'sql': self.sql, + 'sqlEditorId': self.sql_editor_id, + 'startDttm': self.start_time, + 'state': self.status.lower(), + 'tab': self.tab_name, + 'tempTable': self.tmp_table_name, + 'userId': self.user_id, + 'user': self.user.username, + 'limit_reached': self.limit_reached, + 'resultsKey': self.results_key, + } + + @property + def name(self): + ts = datetime.now().isoformat() + ts = ts.replace('-', '').replace(':', '').split('.')[0] + tab = self.tab_name.replace(' ', '_').lower() if self.tab_name else 'notab' + tab = re.sub(r'\W+', '', tab) + return "sqllab_{tab}_{ts}".format(**locals()) + + +class DatasourceAccessRequest(Model, AuditMixinNullable): + """ORM model for the access requests for datasources and dbs.""" + __tablename__ = 'access_request' + id = Column(Integer, primary_key=True) + + datasource_id = Column(Integer) + datasource_type = Column(String(200)) + + ROLES_BLACKLIST = set(config.get('ROBOT_PERMISSION_ROLES', [])) + + @property + def cls_model(self): + return ConnectorRegistry.sources[self.datasource_type] + + @property + def username(self): + return self.creator() + + @property + def datasource(self): + return self.get_datasource + + @datasource.getter + @utils.memoized + def get_datasource(self): + ds = db.session.query(self.cls_model).filter_by( + id=self.datasource_id).first() + return ds + + @property + def datasource_link(self): + return 
self.datasource.link + + @property + def roles_with_datasource(self): + action_list = '' + pv = sm.find_permission_view_menu( + 'datasource_access', self.datasource.perm) + for r in pv.role: + if r.name in self.ROLES_BLACKLIST: + continue + url = ( + '/superset/approve?datasource_type={self.datasource_type}&' + 'datasource_id={self.datasource_id}&' + 'created_by={self.created_by.username}&role_to_grant={r.name}' + .format(**locals()) + ) + href = 'Grant {} Role'.format(url, r.name) + action_list = action_list + '
<li>' + href + '</li>'
+        return '<ul>' + action_list + '</ul>'
+
+    @property
+    def user_roles(self):
+        action_list = ''
+        for r in self.created_by.roles:
+            url = (
+                '/superset/approve?datasource_type={self.datasource_type}&'
+                'datasource_id={self.datasource_id}&'
+                'created_by={self.created_by.username}&role_to_extend={r.name}'
+                .format(**locals())
+            )
+            href = '<a href="{}">Extend {} Role</a>'.format(url, r.name)
+            if r.name in self.ROLES_BLACKLIST:
+                href = "{} Role".format(r.name)
+            action_list = action_list + '<li>' + href + '</li>'
+        return '<ul>' + action_list + '</ul>'
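Illustrative sketch, not part of the patch, of the markup the two properties above build; the role names, datasource id and username are made up:

    ROLES_BLACKLIST = {'Admin'}

    action_list = ''
    for role_name in ['Admin', 'Gamma']:
        if role_name in ROLES_BLACKLIST:
            continue
        url = ('/superset/approve?datasource_type=table&datasource_id=12&'
               'created_by=jdoe&role_to_grant={}'.format(role_name))
        href = '<a href="{}">Grant {} Role</a>'.format(url, role_name)
        action_list = action_list + '<li>' + href + '</li>'
    markup = '<ul>' + action_list + '</ul>'
    # markup is now a single <ul> with one "Grant Gamma Role" link pointing at
    # the /superset/approve endpoint; the blacklisted Admin role is skipped.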
  • ' + return '' diff --git a/superset/models/helpers.py b/superset/models/helpers.py new file mode 100644 index 0000000000000..6082ed1923a46 --- /dev/null +++ b/superset/models/helpers.py @@ -0,0 +1,127 @@ +from datetime import datetime +import humanize +import json +import re +import sqlalchemy as sa + +from sqlalchemy.ext.declarative import declared_attr + +from flask import escape, Markup +from flask_appbuilder.models.mixins import AuditMixin +from flask_appbuilder.models.decorators import renders +from superset.utils import QueryStatus + + +class ImportMixin(object): + def override(self, obj): + """Overrides the plain fields of the dashboard.""" + for field in obj.__class__.export_fields: + setattr(self, field, getattr(obj, field)) + + def copy(self): + """Creates a copy of the dashboard without relationships.""" + new_obj = self.__class__() + new_obj.override(self) + return new_obj + + def alter_params(self, **kwargs): + d = self.params_dict + d.update(kwargs) + self.params = json.dumps(d) + + @property + def params_dict(self): + if self.params: + params = re.sub(",[ \t\r\n]+}", "}", self.params) + params = re.sub(",[ \t\r\n]+\]", "]", params) + return json.loads(params) + else: + return {} + + +class AuditMixinNullable(AuditMixin): + + """Altering the AuditMixin to use nullable fields + + Allows creating objects programmatically outside of CRUD + """ + + created_on = sa.Column(sa.DateTime, default=datetime.now, nullable=True) + changed_on = sa.Column( + sa.DateTime, default=datetime.now, + onupdate=datetime.now, nullable=True) + + @declared_attr + def created_by_fk(cls): # noqa + return sa.Column( + sa.Integer, sa.ForeignKey('ab_user.id'), + default=cls.get_user_id, nullable=True) + + @declared_attr + def changed_by_fk(cls): # noqa + return sa.Column( + sa.Integer, sa.ForeignKey('ab_user.id'), + default=cls.get_user_id, onupdate=cls.get_user_id, nullable=True) + + def _user_link(self, user): + if not user: + return '' + url = '/superset/profile/{}/'.format(user.username) + return Markup('{}'.format(url, escape(user) or '')) + + @renders('created_by') + def creator(self): # noqa + return self._user_link(self.created_by) + + @property + def changed_by_(self): + return self._user_link(self.changed_by) + + @renders('changed_on') + def changed_on_(self): + return Markup( + '{}'.format(self.changed_on)) + + @renders('changed_on') + def modified(self): + s = humanize.naturaltime(datetime.now() - self.changed_on) + return Markup('{}'.format(s)) + + @property + def icons(self): + return """ + + + + """.format(**locals()) + + +class QueryResult(object): + + """Object returned by the query interface""" + + def __init__( # noqa + self, + df, + query, + duration, + status=QueryStatus.SUCCESS, + error_message=None): + self.df = df + self.query = query + self.duration = duration + self.status = status + self.error_message = error_message + + +def set_perm(mapper, connection, target): # noqa + if target.perm != target.get_perm(): + link_table = target.__table__ + connection.execute( + link_table.update() + .where(link_table.c.id == target.id) + .values(perm=target.get_perm()) + ) diff --git a/superset/security.py b/superset/security.py index 1834213f040fd..795f5f45322ff 100644 --- a/superset/security.py +++ b/superset/security.py @@ -6,7 +6,9 @@ import logging from flask_appbuilder.security.sqla import models as ab_models -from superset import conf, db, models, sm, source_registry +from superset import conf, db, sm +from superset.models import core as models +from 
superset.connectors.connector_registry import ConnectorRegistry READ_ONLY_MODEL_VIEWS = { @@ -155,7 +157,7 @@ def create_custom_permissions(): def create_missing_datasource_perms(view_menu_set): logging.info("Creating missing datasource permissions.") - datasources = source_registry.SourceRegistry.get_all_datasources( + datasources = ConnectorRegistry.get_all_datasources( db.session) for datasource in datasources: if datasource and datasource.perm not in view_menu_set: diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 13a787145d082..4c52adc2474b4 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -16,7 +16,7 @@ from superset.sql_parse import SupersetQuery from superset.db_engine_specs import LimitMethod from superset.jinja_context import get_template_processor -QueryStatus = models.QueryStatus +from superset.utils import QueryStatus celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG')) diff --git a/superset/utils.py b/superset/utils.py index 9dfd3296ec75d..9d7f50d646974 100644 --- a/superset/utils.py +++ b/superset/utils.py @@ -438,7 +438,7 @@ def ping_connection(dbapi_connection, connection_record, connection_proxy): cursor.close() -class QueryStatus: +class QueryStatus(object): """Enum-type class for query statuses""" diff --git a/superset/views/__init__.py b/superset/views/__init__.py new file mode 100644 index 0000000000000..6a410e5e906f7 --- /dev/null +++ b/superset/views/__init__.py @@ -0,0 +1,2 @@ +from . import base # noqa +from . import core # noqa diff --git a/superset/views/base.py b/superset/views/base.py new file mode 100644 index 0000000000000..46dbf99f1ef05 --- /dev/null +++ b/superset/views/base.py @@ -0,0 +1,201 @@ +import logging +import json + +from flask import g, redirect +from flask_babel import gettext as __ + +from flask_appbuilder import BaseView +from flask_appbuilder import ModelView +from flask_appbuilder.widgets import ListWidget +from flask_appbuilder.actions import action +from flask_appbuilder.models.sqla.filters import BaseFilter +from flask_appbuilder.security.sqla import models as ab_models + +from superset import appbuilder, config, db, utils, sm, sql_parse +from superset.connectors.connector_registry import ConnectorRegistry + + +def get_datasource_exist_error_mgs(full_name): + return __("Datasource %(name)s already exists", name=full_name) + + +def get_user_roles(): + if g.user.is_anonymous(): + public_role = config.get('AUTH_ROLE_PUBLIC') + return [appbuilder.sm.find_role(public_role)] if public_role else [] + return g.user.roles + + +class BaseSupersetView(BaseView): + def can_access(self, permission_name, view_name, user=None): + if not user: + user = g.user + return utils.can_access( + appbuilder.sm, permission_name, view_name, user) + + def all_datasource_access(self, user=None): + return self.can_access( + "all_datasource_access", "all_datasource_access", user=user) + + def database_access(self, database, user=None): + return ( + self.can_access( + "all_database_access", "all_database_access", user=user) or + self.can_access("database_access", database.perm, user=user) + ) + + def schema_access(self, datasource, user=None): + return ( + self.database_access(datasource.database, user=user) or + self.all_datasource_access(user=user) or + self.can_access("schema_access", datasource.schema_perm, user=user) + ) + + def datasource_access(self, datasource, user=None): + return ( + self.schema_access(datasource, user=user) or + self.can_access("datasource_access", datasource.perm, user=user) + ) + + def 
datasource_access_by_name( + self, database, datasource_name, schema=None): + if self.database_access(database) or self.all_datasource_access(): + return True + + schema_perm = utils.get_schema_perm(database, schema) + if schema and utils.can_access( + sm, 'schema_access', schema_perm, g.user): + return True + + datasources = ConnectorRegistry.query_datasources_by_name( + db.session, database, datasource_name, schema=schema) + for datasource in datasources: + if self.can_access("datasource_access", datasource.perm): + return True + return False + + def datasource_access_by_fullname( + self, database, full_table_name, schema): + table_name_pieces = full_table_name.split(".") + if len(table_name_pieces) == 2: + table_schema = table_name_pieces[0] + table_name = table_name_pieces[1] + else: + table_schema = schema + table_name = table_name_pieces[0] + return self.datasource_access_by_name( + database, table_name, schema=table_schema) + + def rejected_datasources(self, sql, database, schema): + superset_query = sql_parse.SupersetQuery(sql) + return [ + t for t in superset_query.tables if not + self.datasource_access_by_fullname(database, t, schema)] + + def accessible_by_user(self, database, datasource_names, schema=None): + if self.database_access(database) or self.all_datasource_access(): + return datasource_names + + schema_perm = utils.get_schema_perm(database, schema) + if schema and utils.can_access( + sm, 'schema_access', schema_perm, g.user): + return datasource_names + + role_ids = set([role.id for role in g.user.roles]) + # TODO: cache user_perms or user_datasources + user_pvms = ( + db.session.query(ab_models.PermissionView) + .join(ab_models.Permission) + .filter(ab_models.Permission.name == 'datasource_access') + .filter(ab_models.PermissionView.role.any( + ab_models.Role.id.in_(role_ids))) + .all() + ) + user_perms = set([pvm.view_menu.name for pvm in user_pvms]) + user_datasources = ConnectorRegistry.query_datasources_by_permissions( + db.session, database, user_perms) + full_names = set([d.full_name for d in user_datasources]) + return [d for d in datasource_names if d in full_names] + + +class SupersetModelView(ModelView): + page_size = 500 + + +class ListWidgetWithCheckboxes(ListWidget): + """An alternative to list view that renders Boolean fields as checkboxes + + Works in conjunction with the `checkbox` view.""" + template = 'superset/fab_overrides/list_with_checkboxes.html' + + +def validate_json(form, field): # noqa + try: + json.loads(field.data) + except Exception as e: + logging.exception(e) + raise Exception("json isn't valid") + + +class DeleteMixin(object): + @action( + "muldelete", "Delete", "Delete all Really?", "fa-trash", single=False) + def muldelete(self, items): + self.datamodel.delete_all(items) + self.update_redirect() + return redirect(self.get_redirect()) + + +class SupersetFilter(BaseFilter): + + """Add utility function to make BaseFilter easy and fast + + These utility function exist in the SecurityManager, but would do + a database round trip at every check. 
Here we cache the role objects + to be able to make multiple checks but query the db only once + """ + + def get_user_roles(self): + return get_user_roles() + + def get_all_permissions(self): + """Returns a set of tuples with the perm name and view menu name""" + perms = set() + for role in self.get_user_roles(): + for perm_view in role.permissions: + t = (perm_view.permission.name, perm_view.view_menu.name) + perms.add(t) + return perms + + def has_role(self, role_name_or_list): + """Whether the user has this role name""" + if not isinstance(role_name_or_list, list): + role_name_or_list = [role_name_or_list] + return any( + [r.name in role_name_or_list for r in self.get_user_roles()]) + + def has_perm(self, permission_name, view_menu_name): + """Whether the user has this perm""" + return (permission_name, view_menu_name) in self.get_all_permissions() + + def get_view_menus(self, permission_name): + """Returns the details of view_menus for a perm name""" + vm = set() + for perm_name, vm_name in self.get_all_permissions(): + if perm_name == permission_name: + vm.add(vm_name) + return vm + + def has_all_datasource_access(self): + return ( + self.has_role(['Admin', 'Alpha']) or + self.has_perm('all_datasource_access', 'all_datasource_access')) + + +class DatasourceFilter(SupersetFilter): + def apply(self, query, func): # noqa + if self.has_all_datasource_access(): + return query + perms = self.get_view_menus('datasource_access') + # TODO(bogdan): add `schema_access` support here + return query.filter(self.model.perm.in_(perms)) diff --git a/superset/views.py b/superset/views/core.py similarity index 78% rename from superset/views.py rename to superset/views/core.py index 6e3eab0a98197..fa254d979b003 100755 --- a/superset/views.py +++ b/superset/views/core.py @@ -19,12 +19,10 @@ from flask import ( g, request, redirect, flash, Response, render_template, Markup) -from flask_appbuilder import ModelView, CompactCRUDMixin, BaseView, expose +from flask_appbuilder import expose from flask_appbuilder.actions import action from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.security.decorators import has_access_api -from flask_appbuilder.widgets import ListWidget -from flask_appbuilder.models.sqla.filters import BaseFilter from flask_appbuilder.security.sqla import models as ab_models from flask_babel import gettext as __ @@ -33,120 +31,26 @@ from sqlalchemy import create_engine from werkzeug.routing import BaseConverter -import superset from superset import ( - appbuilder, cache, db, models, viz, utils, app, - sm, sql_lab, sql_parse, results_backend, security, + appbuilder, cache, db, viz, utils, app, + sm, sql_lab, results_backend, security, ) from superset.legacy import cast_form_data from superset.utils import has_access -from superset.source_registry import SourceRegistry -from superset.models import DatasourceAccessRequest as DAR +from superset.connectors.connector_registry import ConnectorRegistry +import superset.models.core as models from superset.sql_parse import SupersetQuery +from .base import ( + SupersetModelView, BaseSupersetView, DeleteMixin, + SupersetFilter, get_user_roles +) + config = app.config log_this = models.Log.log_this can_access = utils.can_access QueryStatus = models.QueryStatus - - -class BaseSupersetView(BaseView): - def can_access(self, permission_name, view_name, user=None): - if not user: - user = g.user - return utils.can_access( - appbuilder.sm, permission_name, view_name, user) - - def all_datasource_access(self, user=None): - 
return self.can_access( - "all_datasource_access", "all_datasource_access", user=user) - - def database_access(self, database, user=None): - return ( - self.can_access( - "all_database_access", "all_database_access", user=user) or - self.can_access("database_access", database.perm, user=user) - ) - - def schema_access(self, datasource, user=None): - return ( - self.database_access(datasource.database, user=user) or - self.all_datasource_access(user=user) or - self.can_access("schema_access", datasource.schema_perm, user=user) - ) - - def datasource_access(self, datasource, user=None): - return ( - self.schema_access(datasource, user=user) or - self.can_access("datasource_access", datasource.perm, user=user) - ) - - def datasource_access_by_name( - self, database, datasource_name, schema=None): - if self.database_access(database) or self.all_datasource_access(): - return True - - schema_perm = utils.get_schema_perm(database, schema) - if schema and utils.can_access( - sm, 'schema_access', schema_perm, g.user): - return True - - datasources = SourceRegistry.query_datasources_by_name( - db.session, database, datasource_name, schema=schema) - for datasource in datasources: - if self.can_access("datasource_access", datasource.perm): - return True - return False - - def datasource_access_by_fullname( - self, database, full_table_name, schema): - table_name_pieces = full_table_name.split(".") - if len(table_name_pieces) == 2: - table_schema = table_name_pieces[0] - table_name = table_name_pieces[1] - else: - table_schema = schema - table_name = table_name_pieces[0] - return self.datasource_access_by_name( - database, table_name, schema=table_schema) - - def rejected_datasources(self, sql, database, schema): - superset_query = sql_parse.SupersetQuery(sql) - return [ - t for t in superset_query.tables if not - self.datasource_access_by_fullname(database, t, schema)] - - def accessible_by_user(self, database, datasource_names, schema=None): - if self.database_access(database) or self.all_datasource_access(): - return datasource_names - - schema_perm = utils.get_schema_perm(database, schema) - if schema and utils.can_access( - sm, 'schema_access', schema_perm, g.user): - return datasource_names - - role_ids = set([role.id for role in g.user.roles]) - # TODO: cache user_perms or user_datasources - user_pvms = ( - db.session.query(ab_models.PermissionView) - .join(ab_models.Permission) - .filter(ab_models.Permission.name == 'datasource_access') - .filter(ab_models.PermissionView.role.any( - ab_models.Role.id.in_(role_ids))) - .all() - ) - user_perms = set([pvm.view_menu.name for pvm in user_pvms]) - user_datasources = SourceRegistry.query_datasources_by_permissions( - db.session, database, user_perms) - full_names = set([d.full_name for d in user_datasources]) - return [d for d in datasource_names if d in full_names] - - -class ListWidgetWithCheckboxes(ListWidget): - """An alternative to list view that renders Boolean fields as checkboxes - - Works in conjunction with the `checkbox` view.""" - template = 'superset/fab_overrides/list_with_checkboxes.html' +DAR = models.DatasourceAccessRequest ALL_DATASOURCE_ACCESS_ERR = __( @@ -168,10 +72,6 @@ def get_datasource_access_error_msg(datasource_name): "`all_datasource_access` permission", name=datasource_name) -def get_datasource_exist_error_mgs(full_name): - return __("Datasource %(name)s already exists", name=full_name) - - def get_error_msg(): if config.get("SHOW_STACKTRACE"): error_msg = traceback.format_exc() @@ -211,10 +111,12 @@ def wraps(self, *args, 
**kwargs): return functools.update_wrapper(wraps, f) + def is_owner(obj, user): """ Check if user is owner of the slice """ return obj and obj.owners and user in obj.owners + def check_ownership(obj, raise_if_false=True): """Meant to be used in `pre_update` hooks on models to enforce ownership @@ -257,68 +159,6 @@ def check_ownership(obj, raise_if_false=True): return False -def get_user_roles(): - if g.user.is_anonymous(): - public_role = config.get('AUTH_ROLE_PUBLIC') - return [appbuilder.sm.find_role(public_role)] if public_role else [] - return g.user.roles - - -class SupersetFilter(BaseFilter): - - """Add utility function to make BaseFilter easy and fast - - These utility function exist in the SecurityManager, but would do - a database round trip at every check. Here we cache the role objects - to be able to make multiple checks but query the db only once - """ - - def get_user_roles(self): - return get_user_roles() - - def get_all_permissions(self): - """Returns a set of tuples with the perm name and view menu name""" - perms = set() - for role in get_user_roles(): - for perm_view in role.permissions: - t = (perm_view.permission.name, perm_view.view_menu.name) - perms.add(t) - return perms - - def has_role(self, role_name_or_list): - """Whether the user has this role name""" - if not isinstance(role_name_or_list, list): - role_name_or_list = [role_name_or_list] - return any( - [r.name in role_name_or_list for r in self.get_user_roles()]) - - def has_perm(self, permission_name, view_menu_name): - """Whether the user has this perm""" - return (permission_name, view_menu_name) in self.get_all_permissions() - - def get_view_menus(self, permission_name): - """Returns the details of view_menus for a perm name""" - vm = set() - for perm_name, vm_name in self.get_all_permissions(): - if perm_name == permission_name: - vm.add(vm_name) - return vm - - def has_all_datasource_access(self): - return ( - self.has_role(['Admin', 'Alpha']) or - self.has_perm('all_datasource_access', 'all_datasource_access')) - - -class DatasourceFilter(SupersetFilter): - def apply(self, query, func): # noqa - if self.has_all_datasource_access(): - return query - perms = self.get_view_menus('datasource_access') - # TODO(bogdan): add `schema_access` support here - return query.filter(self.model.perm.in_(perms)) - - class SliceFilter(SupersetFilter): def apply(self, query, func): # noqa if self.has_all_datasource_access(): @@ -355,14 +195,6 @@ def apply(self, query, func): # noqa return query -def validate_json(form, field): # noqa - try: - json.loads(field.data) - except Exception as e: - logging.exception(e) - raise Exception("json isn't valid") - - def generate_download_headers(extension): filename = datetime.now().strftime("%Y%m%d_%H%M%S") content_disp = "attachment; filename={}.{}".format(filename, extension) @@ -372,206 +204,6 @@ def generate_download_headers(extension): return headers -class DeleteMixin(object): - @action( - "muldelete", "Delete", "Delete all Really?", "fa-trash", single=False) - def muldelete(self, items): - self.datamodel.delete_all(items) - self.update_redirect() - return redirect(self.get_redirect()) - - -class SupersetModelView(ModelView): - page_size = 500 - - -class TableColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa - datamodel = SQLAInterface(models.TableColumn) - can_delete = False - list_widget = ListWidgetWithCheckboxes - edit_columns = [ - 'column_name', 'verbose_name', 'description', 'groupby', 'filterable', - 'table', 'count_distinct', 'sum', 'min', 'max', 
'expression', - 'is_dttm', 'python_date_format', 'database_expression'] - add_columns = edit_columns - list_columns = [ - 'column_name', 'type', 'groupby', 'filterable', 'count_distinct', - 'sum', 'min', 'max', 'is_dttm'] - page_size = 500 - description_columns = { - 'is_dttm': (_( - "Whether to make this column available as a " - "[Time Granularity] option, column has to be DATETIME or " - "DATETIME-like")), - 'expression': utils.markdown( - "a valid SQL expression as supported by the underlying backend. " - "Example: `substr(name, 1, 1)`", True), - 'python_date_format': utils.markdown(Markup( - "The pattern of timestamp format, use " - "" - "python datetime string pattern " - "expression. If time is stored in epoch " - "format, put `epoch_s` or `epoch_ms`. Leave `Database Expression` " - "below empty if timestamp is stored in " - "String or Integer(epoch) type"), True), - 'database_expression': utils.markdown( - "The database expression to cast internal datetime " - "constants to database date/timestamp type according to the DBAPI. " - "The expression should follow the pattern of " - "%Y-%m-%d %H:%M:%S, based on different DBAPI. " - "The string should be a python string formatter \n" - "`Ex: TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')` for Oracle" - "Superset uses default expression based on DB URI if this " - "field is blank.", True), - } - label_columns = { - 'column_name': _("Column"), - 'verbose_name': _("Verbose Name"), - 'description': _("Description"), - 'groupby': _("Groupable"), - 'filterable': _("Filterable"), - 'table': _("Table"), - 'count_distinct': _("Count Distinct"), - 'sum': _("Sum"), - 'min': _("Min"), - 'max': _("Max"), - 'expression': _("Expression"), - 'is_dttm': _("Is temporal"), - 'python_date_format': _("Datetime Format"), - 'database_expression': _("Database Expression") - } -appbuilder.add_view_no_menu(TableColumnInlineView) - - -class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): # noqa - datamodel = SQLAInterface(models.DruidColumn) - edit_columns = [ - 'column_name', 'description', 'dimension_spec_json', 'datasource', - 'groupby', 'count_distinct', 'sum', 'min', 'max'] - add_columns = edit_columns - list_columns = [ - 'column_name', 'type', 'groupby', 'filterable', 'count_distinct', - 'sum', 'min', 'max'] - can_delete = False - page_size = 500 - label_columns = { - 'column_name': _("Column"), - 'type': _("Type"), - 'datasource': _("Datasource"), - 'groupby': _("Groupable"), - 'filterable': _("Filterable"), - 'count_distinct': _("Count Distinct"), - 'sum': _("Sum"), - 'min': _("Min"), - 'max': _("Max"), - } - description_columns = { - 'dimension_spec_json': utils.markdown( - "this field can be used to specify " - "a `dimensionSpec` as documented [here]" - "(http://druid.io/docs/latest/querying/dimensionspecs.html). 
" - "Make sure to input valid JSON and that the " - "`outputName` matches the `column_name` defined " - "above.", - True), - } - - def post_update(self, col): - col.generate_metrics() - utils.validate_json(col.dimension_spec_json) - - def post_add(self, col): - self.post_update(col) - -appbuilder.add_view_no_menu(DruidColumnInlineView) - - -class SqlMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa - datamodel = SQLAInterface(models.SqlMetric) - list_columns = ['metric_name', 'verbose_name', 'metric_type'] - edit_columns = [ - 'metric_name', 'description', 'verbose_name', 'metric_type', - 'expression', 'table', 'd3format', 'is_restricted'] - description_columns = { - 'expression': utils.markdown( - "a valid SQL expression as supported by the underlying backend. " - "Example: `count(DISTINCT userid)`", True), - 'is_restricted': _("Whether the access to this metric is restricted " - "to certain roles. Only roles with the permission " - "'metric access on XXX (the name of this metric)' " - "are allowed to access this metric"), - 'd3format': utils.markdown( - "d3 formatting string as defined [here]" - "(https://github.com/d3/d3-format/blob/master/README.md#format). " - "For instance, this default formatting applies in the Table " - "visualization and allow for different metric to use different " - "formats", True - ), - } - add_columns = edit_columns - page_size = 500 - label_columns = { - 'metric_name': _("Metric"), - 'description': _("Description"), - 'verbose_name': _("Verbose Name"), - 'metric_type': _("Type"), - 'expression': _("SQL Expression"), - 'table': _("Table"), - } - - def post_add(self, metric): - if metric.is_restricted: - security.merge_perm(sm, 'metric_access', metric.get_perm()) - - def post_update(self, metric): - if metric.is_restricted: - security.merge_perm(sm, 'metric_access', metric.get_perm()) - -appbuilder.add_view_no_menu(SqlMetricInlineView) - - -class DruidMetricInlineView(CompactCRUDMixin, SupersetModelView): # noqa - datamodel = SQLAInterface(models.DruidMetric) - list_columns = ['metric_name', 'verbose_name', 'metric_type'] - edit_columns = [ - 'metric_name', 'description', 'verbose_name', 'metric_type', 'json', - 'datasource', 'd3format', 'is_restricted'] - add_columns = edit_columns - page_size = 500 - validators_columns = { - 'json': [validate_json], - } - description_columns = { - 'metric_type': utils.markdown( - "use `postagg` as the metric type if you are defining a " - "[Druid Post Aggregation]" - "(http://druid.io/docs/latest/querying/post-aggregations.html)", - True), - 'is_restricted': _("Whether the access to this metric is restricted " - "to certain roles. 
Only roles with the permission " - "'metric access on XXX (the name of this metric)' " - "are allowed to access this metric"), - } - label_columns = { - 'metric_name': _("Metric"), - 'description': _("Description"), - 'verbose_name': _("Verbose Name"), - 'metric_type': _("Type"), - 'json': _("JSON"), - 'datasource': _("Druid Datasource"), - } - - def post_add(self, metric): - utils.init_metrics_perm(superset, [metric]) - - def post_update(self, metric): - utils.init_metrics_perm(superset, [metric]) - - -appbuilder.add_view_no_menu(DruidMetricInlineView) - - class DatabaseView(SupersetModelView, DeleteMixin): # noqa datamodel = SQLAInterface(models.Database) list_columns = [ @@ -692,99 +324,6 @@ class DatabaseTablesAsync(DatabaseView): appbuilder.add_view_no_menu(DatabaseTablesAsync) -class TableModelView(SupersetModelView, DeleteMixin): # noqa - datamodel = SQLAInterface(models.SqlaTable) - list_columns = [ - 'link', 'database', 'is_featured', - 'changed_by_', 'changed_on_'] - order_columns = [ - 'link', 'database', 'is_featured', 'changed_on_'] - add_columns = ['database', 'schema', 'table_name'] - edit_columns = [ - 'table_name', 'sql', 'is_featured', 'filter_select_enabled', - 'database', 'schema', - 'description', 'owner', - 'main_dttm_col', 'default_endpoint', 'offset', 'cache_timeout'] - show_columns = edit_columns + ['perm'] - related_views = [TableColumnInlineView, SqlMetricInlineView] - base_order = ('changed_on', 'desc') - description_columns = { - 'offset': _("Timezone offset (in hours) for this datasource"), - 'table_name': _( - "Name of the table that exists in the source database"), - 'schema': _( - "Schema, as used only in some databases like Postgres, Redshift " - "and DB2"), - 'description': Markup( - "Supports " - "markdown"), - 'sql': _( - "This fields acts a Superset view, meaning that Superset will " - "run a query against this string as a subquery." - ), - } - base_filters = [['id', DatasourceFilter, lambda: []]] - label_columns = { - 'link': _("Table"), - 'changed_by_': _("Changed By"), - 'database': _("Database"), - 'changed_on_': _("Last Changed"), - 'is_featured': _("Is Featured"), - 'filter_select_enabled': _("Enable Filter Select"), - 'schema': _("Schema"), - 'default_endpoint': _("Default Endpoint"), - 'offset': _("Offset"), - 'cache_timeout': _("Cache Timeout"), - } - - def pre_add(self, table): - number_of_existing_tables = db.session.query( - sqla.func.count('*')).filter( - models.SqlaTable.table_name == table.table_name, - models.SqlaTable.schema == table.schema, - models.SqlaTable.database_id == table.database.id - ).scalar() - # table object is already added to the session - if number_of_existing_tables > 1: - raise Exception(get_datasource_exist_error_mgs(table.full_name)) - - # Fail before adding if the table can't be found - try: - table.get_sqla_table_object() - except Exception as e: - logging.exception(e) - raise Exception( - "Table [{}] could not be found, " - "please double check your " - "database connection, schema, and " - "table name".format(table.name)) - - def post_add(self, table): - table.fetch_metadata() - security.merge_perm(sm, 'datasource_access', table.get_perm()) - if table.schema: - security.merge_perm(sm, 'schema_access', table.schema_perm) - - flash(_( - "The table was created. 
As part of this two phase configuration " - "process, you should now click the edit button by " - "the new table to configure it."), - "info") - - def post_update(self, table): - self.post_add(table) - -appbuilder.add_view( - TableModelView, - "Tables", - label=__("Tables"), - category="Sources", - category_label=__("Sources"), - icon='fa-table',) - -appbuilder.add_separator("Sources") - - class AccessRequestsModelView(SupersetModelView, DeleteMixin): datamodel = SQLAInterface(DAR) list_columns = [ @@ -810,43 +349,6 @@ class AccessRequestsModelView(SupersetModelView, DeleteMixin): icon='fa-table',) -class DruidClusterModelView(SupersetModelView, DeleteMixin): # noqa - datamodel = SQLAInterface(models.DruidCluster) - add_columns = [ - 'cluster_name', - 'coordinator_host', 'coordinator_port', 'coordinator_endpoint', - 'broker_host', 'broker_port', 'broker_endpoint', 'cache_timeout', - ] - edit_columns = add_columns - list_columns = ['cluster_name', 'metadata_last_refreshed'] - label_columns = { - 'cluster_name': _("Cluster"), - 'coordinator_host': _("Coordinator Host"), - 'coordinator_port': _("Coordinator Port"), - 'coordinator_endpoint': _("Coordinator Endpoint"), - 'broker_host': _("Broker Host"), - 'broker_port': _("Broker Port"), - 'broker_endpoint': _("Broker Endpoint"), - } - - def pre_add(self, cluster): - security.merge_perm(sm, 'database_access', cluster.perm) - - def pre_update(self, cluster): - self.pre_add(cluster) - - -if config['DRUID_IS_ACTIVE']: - appbuilder.add_view( - DruidClusterModelView, - name="Druid Clusters", - label=__("Druid Clusters"), - icon="fa-cubes", - category="Sources", - category_label=__("Sources"), - category_icon='fa-database',) - - class SliceModelView(SupersetModelView, DeleteMixin): # noqa datamodel = SQLAInterface(models.Slice) can_add = False @@ -903,9 +405,9 @@ def add(self): if not widget: return redirect(self.get_redirect()) - sources = SourceRegistry.sources + sources = ConnectorRegistry.sources for source in sources: - ds = db.session.query(SourceRegistry.sources[source]).first() + ds = db.session.query(ConnectorRegistry.sources[source]).first() if ds is not None: url = "/{}/list/".format(ds.baselink) msg = _("Click on a {} link to create a Slice".format(source)) @@ -1078,74 +580,6 @@ class QueryView(SupersetModelView): icon="fa-search") -class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa - datamodel = SQLAInterface(models.DruidDatasource) - list_widget = ListWidgetWithCheckboxes - list_columns = [ - 'datasource_link', 'cluster', 'changed_by_', 'changed_on_', 'offset'] - order_columns = [ - 'datasource_link', 'changed_on_', 'offset'] - related_views = [DruidColumnInlineView, DruidMetricInlineView] - edit_columns = [ - 'datasource_name', 'cluster', 'description', 'owner', - 'is_featured', 'is_hidden', 'filter_select_enabled', - 'default_endpoint', 'offset', 'cache_timeout'] - add_columns = edit_columns - show_columns = add_columns + ['perm'] - page_size = 500 - base_order = ('datasource_name', 'asc') - description_columns = { - 'offset': _("Timezone offset (in hours) for this datasource"), - 'description': Markup( - "Supports markdown"), - } - base_filters = [['id', DatasourceFilter, lambda: []]] - label_columns = { - 'datasource_link': _("Data Source"), - 'cluster': _("Cluster"), - 'description': _("Description"), - 'owner': _("Owner"), - 'is_featured': _("Is Featured"), - 'is_hidden': _("Is Hidden"), - 'filter_select_enabled': _("Enable Filter Select"), - 'default_endpoint': _("Default Endpoint"), - 'offset': _("Time 
Offset"), - 'cache_timeout': _("Cache Timeout"), - } - - def pre_add(self, datasource): - number_of_existing_datasources = db.session.query( - sqla.func.count('*')).filter( - models.DruidDatasource.datasource_name == - datasource.datasource_name, - models.DruidDatasource.cluster_name == datasource.cluster.id - ).scalar() - - # table object is already added to the session - if number_of_existing_datasources > 1: - raise Exception(get_datasource_exist_error_mgs( - datasource.full_name)) - - def post_add(self, datasource): - datasource.generate_metrics() - security.merge_perm(sm, 'datasource_access', datasource.get_perm()) - if datasource.schema: - security.merge_perm(sm, 'schema_access', datasource.schema_perm) - - def post_update(self, datasource): - self.post_add(datasource) - -if config['DRUID_IS_ACTIVE']: - appbuilder.add_view( - DruidDatasourceModelView, - "Druid Datasources", - label=__("Druid Datasources"), - category="Sources", - category_label=__("Sources"), - icon="fa-cube") - - @app.route('/health') def health(): return "OK" @@ -1283,7 +717,7 @@ def json_response(self, obj, status=200): @has_access_api @expose("/datasources/") def datasources(self): - datasources = SourceRegistry.get_all_datasources(db.session) + datasources = ConnectorRegistry.get_all_datasources(db.session) datasources = [(str(o.id) + '__' + o.type, repr(o)) for o in datasources] return self.json_response(datasources) @@ -1318,7 +752,7 @@ def override_role_permissions(self): dbs['name'], ds_name, schema=schema['name']) db_ds_names.add(fullname) - existing_datasources = SourceRegistry.get_all_datasources(db.session) + existing_datasources = ConnectorRegistry.get_all_datasources(db.session) datasources = [ d for d in existing_datasources if d.full_name in db_ds_names] role = sm.find_role(role_name) @@ -1356,7 +790,7 @@ def request_access(self): datasource_id = request.args.get('datasource_id') datasource_type = request.args.get('datasource_type') if datasource_id: - ds_class = SourceRegistry.sources.get(datasource_type) + ds_class = ConnectorRegistry.sources.get(datasource_type) datasource = ( db.session.query(ds_class) .filter_by(id=int(datasource_id)) @@ -1385,7 +819,7 @@ def request_access(self): def approve(self): def clean_fulfilled_requests(session): for r in session.query(DAR).all(): - datasource = SourceRegistry.get_datasource( + datasource = ConnectorRegistry.get_datasource( r.datasource_type, r.datasource_id, session) user = sm.get_user_by_id(r.created_by_fk) if not datasource or \ @@ -1400,7 +834,7 @@ def clean_fulfilled_requests(session): role_to_extend = request.args.get('role_to_extend') session = db.session - datasource = SourceRegistry.get_datasource( + datasource = ConnectorRegistry.get_datasource( datasource_type, datasource_id, session) if not datasource: @@ -1501,9 +935,9 @@ def get_viz( ) return slc.get_viz() else: - form_data=self.get_form_data() + form_data = self.get_form_data() viz_type = form_data.get('viz_type', 'table') - datasource = SourceRegistry.get_datasource( + datasource = ConnectorRegistry.get_datasource( datasource_type, datasource_id, db.session) viz_obj = viz.viz_types[viz_type]( datasource, @@ -1542,7 +976,6 @@ def explore_json(self, datasource_type, datasource_id): utils.error_msg_from_exception(e), stacktrace=traceback.format_exc()) - if not self.datasource_access(viz_obj.datasource): return json_error_response(DATASOURCE_ACCESS_ERR, status=404) @@ -1628,7 +1061,7 @@ def explore(self, datasource_type, datasource_id): error_redirect = '/slicemodelview/list/' datasource = ( 
- db.session.query(SourceRegistry.sources[datasource_type]) + db.session.query(ConnectorRegistry.sources[datasource_type]) .filter_by(id=datasource_id) .one() ) @@ -2216,8 +1649,8 @@ def warm_up_cache(self): except Exception as e: return json_error_response(utils.error_msg_from_exception(e)) return json_success(json.dumps( - [{"slice_id": session.id, "slice_name": session.slice_name} - for session in slices])) + [{"slice_id": slc.id, "slice_name": slc.slice_name} + for slc in slices])) @expose("/favstar////") def favstar(self, class_name, obj_id, action): @@ -2642,7 +2075,7 @@ def csv(self, client_id): def fetch_datasource_metadata(self): datasource_id, datasource_type = ( request.args.get('datasourceKey').split('__')) - datasource_class = SourceRegistry.sources[datasource_type] + datasource_class = ConnectorRegistry.sources[datasource_type] datasource = ( db.session.query(datasource_class) .filter_by(id=int(datasource_id)) From 10358210bc22ea32163ac5853de8b7894965dfa1 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Tue, 7 Mar 2017 08:20:46 -0800 Subject: [PATCH 3/8] Fixing views --- superset/views/core.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/superset/views/core.py b/superset/views/core.py index fa254d979b003..9752a96cd2248 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1031,11 +1031,8 @@ def import_dashboards(self): data = pickle.load(f) # TODO: import DRUID datasources for table in data['datasources']: - if table.type == 'table': - models.SqlaTable.import_obj(table, import_time=current_tt) - else: - models.DruidDatasource.import_obj( - table, import_time=current_tt) + ds_class = ConnectorRegistry.sources.get(table.type) + ds_class.import_obj(table, import_time=current_tt) db.session.commit() for dashboard in data['dashboards']: models.Dashboard.import_obj( @@ -1138,8 +1135,7 @@ def filter(self, datasource_type, datasource_id, column): """ # TODO: Cache endpoint by user, datasource and column error_redirect = '/slicemodelview/list/' - datasource_class = models.SqlaTable \ - if datasource_type == "table" else models.DruidDatasource + datasource_class = ConnectorRegistry.sources[datasource_type] datasource = db.session.query( datasource_class).filter_by(id=datasource_id).first() @@ -1627,12 +1623,13 @@ def warm_up_cache(self): return json_error_response(__( "Slice %(id)s not found", id=slice_id), status=404) elif table_name and db_name: + SqlaTable = ConnectorRegistry.sources['table'] table = ( - session.query(models.SqlaTable) + session.query(SqlaTable) .join(models.Database) .filter( models.Database.database_name == db_name or - models.SqlaTable.table_name == table_name) + SqlaTable.table_name == table_name) ).first() if not table: return json_error_response(__( @@ -1642,9 +1639,9 @@ def warm_up_cache(self): datasource_id=table.id, datasource_type=table.type).all() - for slice in slices: + for slc in slices: try: - obj = slice.get_viz() + obj = slc.get_viz() obj.get_json(force=True) except Exception as e: return json_error_response(utils.error_msg_from_exception(e)) @@ -1755,12 +1752,13 @@ def sync_druid_source(self): cluster_name = payload['cluster'] user = sm.find_user(username=user_name) + DruidCluster = ConnectorRegistry.sources['druid'] if not user: err_msg = __("Can't find User '%(name)s', please ask your admin " "to create one.", name=user_name) logging.error(err_msg) return json_error_response(err_msg) - cluster = db.session.query(models.DruidCluster).filter_by( + cluster = 
db.session.query(DruidCluster).filter_by( cluster_name=cluster_name).first() if not cluster: err_msg = __("Can't find DruidCluster with cluster_name = " @@ -1782,13 +1780,14 @@ def sqllab_viz(self): data = json.loads(request.form.get('data')) table_name = data.get('datasourceName') viz_type = data.get('chartType') + SqlaTable = ConnectorRegistry.sources['table'] table = ( - db.session.query(models.SqlaTable) + db.session.query(SqlaTable) .filter_by(table_name=table_name) .first() ) if not table: - table = models.SqlaTable(table_name=table_name) + table = SqlaTable(table_name=table_name) table.database_id = data.get('dbId') q = SupersetQuery(data.get('sql')) table.sql = q.stripped() @@ -2173,7 +2172,8 @@ def search_queries(self): def refresh_datasources(self): """endpoint that refreshes druid datasources metadata""" session = db.session() - for cluster in session.query(models.DruidCluster).all(): + DruidCluster = ConnectorRegistry.sources['druid'] + for cluster in session.query(DruidCluster).all(): cluster_name = cluster.cluster_name try: cluster.refresh_datasources() From 7f8800c108b44657c89af14491357031228ed0c8 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Tue, 7 Mar 2017 17:37:14 -0800 Subject: [PATCH 4/8] Fixing tests --- superset/config.py | 13 ++- superset/connectors/base.py | 81 ++++++++++++++- superset/connectors/druid/models.py | 150 +++++++++++++++------------- superset/connectors/sqla/models.py | 49 ++------- superset/data/__init__.py | 2 +- superset/models/core.py | 1 - superset/security.py | 8 +- superset/sql_lab.py | 3 +- superset/views/base.py | 4 +- superset/views/core.py | 5 +- tests/access_tests.py | 20 ++-- tests/base_tests.py | 25 +++-- tests/celery_tests.py | 6 +- tests/core_tests.py | 18 ++-- tests/druid_tests.py | 18 +++- tests/import_export_tests.py | 54 +++++----- tests/model_tests.py | 2 +- tests/sqllab_tests.py | 4 +- 18 files changed, 275 insertions(+), 188 deletions(-) diff --git a/superset/config.py b/superset/config.py index ce3a3aad4c024..8f55f2cbd332b 100644 --- a/superset/config.py +++ b/superset/config.py @@ -295,14 +295,17 @@ class CeleryConfig(object): BLUEPRINTS = [] try: + if CONFIG_PATH_ENV_VAR in os.environ: # Explicitly import config module that is not in pythonpath; useful # for case where app is being executed via pex. 
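Illustrative aside, not part of the patch: this try block gives deployments two ways to supply overrides. Either set the environment variable named by CONFIG_PATH_ENV_VAR to a config file's absolute path (the pex case the comment above mentions), or drop a superset_config module on the pythonpath for the star-import fallback below. Either way the file simply assigns settings; the keys in this sketch are existing options, the values are made up:

    # superset_config.py -- example overrides only
    ROW_LIMIT = 5000
    SUPERSET_WEBSERVER_PORT = 8089
    SQLALCHEMY_DATABASE_URI = 'postgresql://superset:superset@localhost/superset'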
+ print('Loaded your LOCAL configuration at [{}]'.format( + os.environ[CONFIG_PATH_ENV_VAR])) imp.load_source('superset_config', os.environ[CONFIG_PATH_ENV_VAR]) - - from superset_config import * # noqa - import superset_config - print('Loaded your LOCAL configuration at [{}]'.format( - superset_config.__file__)) + else: + from superset_config import * # noqa + import superset_config + print('Loaded your LOCAL configuration at [{}]'.format( + superset_config.__file__)) except ImportError: pass diff --git a/superset/connectors/base.py b/superset/connectors/base.py index 45dd3360f8807..4cfeac496e68f 100644 --- a/superset/connectors/base.py +++ b/superset/connectors/base.py @@ -1,12 +1,17 @@ import json +from sqlalchemy import Column, Integer, String, Text, Boolean + from superset import utils +from superset.models.helpers import AuditMixinNullable, ImportMixin -class Datasource(object): +class BaseDatasource(AuditMixinNullable, ImportMixin): """A common interface to objects that are queryable (tables and datasources)""" + __tablename__ = None # {connector_name}_datasource + # Used to do code highlighting when displaying the query in the UI query_language = None @@ -60,7 +65,7 @@ def data(self): d = { 'all_cols': utils.choicify(self.column_names), 'column_formats': self.column_formats, - 'edit_url' : self.url, + 'edit_url': self.url, 'filter_select': self.filter_select_enabled, 'filterable_cols': utils.choicify(self.filterable_column_names), 'gb_cols': utils.choicify(self.groupby_column_names), @@ -79,3 +84,75 @@ def data(self): return d +class BaseColumn(AuditMixinNullable, ImportMixin): + """Interface for column""" + + __tablename__ = None # {connector_name}_column + + id = Column(Integer, primary_key=True) + column_name = Column(String(255)) + verbose_name = Column(String(1024)) + is_active = Column(Boolean, default=True) + type = Column(String(32)) + groupby = Column(Boolean, default=False) + count_distinct = Column(Boolean, default=False) + sum = Column(Boolean, default=False) + avg = Column(Boolean, default=False) + max = Column(Boolean, default=False) + min = Column(Boolean, default=False) + filterable = Column(Boolean, default=False) + description = Column(Text) + + # [optional] Set this to support import/export functionality + export_fields = [] + + def __repr__(self): + return self.column_name + + num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG', 'REAL', 'NUMERIC') + date_types = ('DATE', 'TIME', 'DATETIME') + str_types = ('VARCHAR', 'STRING', 'CHAR') + + @property + def is_num(self): + return any([t in self.type.upper() for t in self.num_types]) + + @property + def is_time(self): + return any([t in self.type.upper() for t in self.date_types]) + + @property + def is_string(self): + return any([t in self.type.upper() for t in self.str_types]) + + +class BaseMetric(AuditMixinNullable, ImportMixin): + + """Interface for Metrics""" + + __tablename__ = None # {connector_name}_metric + + id = Column(Integer, primary_key=True) + metric_name = Column(String(512)) + verbose_name = Column(String(1024)) + metric_type = Column(String(32)) + description = Column(Text) + is_restricted = Column(Boolean, default=False, nullable=True) + d3format = Column(String(128)) + + """ + The interface should also declare a datasource relationship pointing + to a derivative of BaseDatasource, along with a FK + + datasource_name = Column( + String(255), + ForeignKey('datasources.datasource_name')) + datasource = relationship( + # needs to be altered to point to {Connector}Datasource + 'BaseDatasource', + 
backref=backref('metrics', cascade='all, delete-orphan'), + enable_typechecks=False) + """ + @property + def perm(self): + raise NotImplementedError() diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index f1d70ed42d73b..33683508f52bb 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -6,6 +6,7 @@ from six import string_types import requests +import sqlalchemy as sa from sqlalchemy import ( Column, Integer, String, ForeignKey, Text, Boolean, DateTime, @@ -27,13 +28,14 @@ from flask_babel import lazy_gettext as _ -from superset import config, db, import_util, utils, sm, get_session +from superset import conf, db, import_util, utils, sm, get_session from superset.utils import ( flasher, MetricPermException, DimSelector, DTTM_ALIAS ) -from superset.connectors.base import Datasource -from superset.models.helpers import ( - AuditMixinNullable, ImportMixin, QueryResult) +from superset.connectors.base import BaseDatasource, BaseColumn, BaseMetric +from superset.models.helpers import AuditMixinNullable, QueryResult, set_perm + +DRUID_TZ = conf.get("DRUID_TZ") class JavascriptPostAggregator(Postaggregator): @@ -95,7 +97,7 @@ def refresh_datasources(self, datasource_name=None, merge_flag=False): """ self.druid_version = self.get_druid_version() for datasource in self.get_datasources(): - if datasource not in config.get('DRUID_DATA_SOURCE_BLACKLIST'): + if datasource not in conf.get('DRUID_DATA_SOURCE_BLACKLIST', []): if not datasource_name or datasource_name == datasource: DruidDatasource.sync_to_db(datasource, self, merge_flag) @@ -108,11 +110,11 @@ def name(self): return self.cluster_name -class DruidColumn(Model, AuditMixinNullable, ImportMixin): +class DruidColumn(Model, BaseColumn): """ORM model for storing Druid datasource column metadata""" __tablename__ = 'columns' - id = Column(Integer, primary_key=True) + datasource_name = Column( String(255), ForeignKey('datasources.datasource_name')) @@ -121,17 +123,6 @@ class DruidColumn(Model, AuditMixinNullable, ImportMixin): 'DruidDatasource', backref=backref('columns', cascade='all, delete-orphan'), enable_typechecks=False) - column_name = Column(String(255)) - is_active = Column(Boolean, default=True) - type = Column(String(32)) - groupby = Column(Boolean, default=False) - count_distinct = Column(Boolean, default=False) - sum = Column(Boolean, default=False) - avg = Column(Boolean, default=False) - max = Column(Boolean, default=False) - min = Column(Boolean, default=False) - filterable = Column(Boolean, default=False) - description = Column(Text) dimension_spec_json = Column(Text) export_fields = ( @@ -143,10 +134,6 @@ class DruidColumn(Model, AuditMixinNullable, ImportMixin): def __repr__(self): return self.column_name - @property - def is_num(self): - return self.type in ('LONG', 'DOUBLE', 'FLOAT', 'INT') - @property def dimension_spec(self): if self.dimension_spec_json: @@ -260,15 +247,11 @@ def lookup_obj(lookup_column): return import_util.import_simple_obj(db.session, i_column, lookup_obj) -class DruidMetric(Model, AuditMixinNullable, ImportMixin): +class DruidMetric(Model, BaseMetric): """ORM object referencing Druid metrics for a datasource""" __tablename__ = 'metrics' - id = Column(Integer, primary_key=True) - metric_name = Column(String(512)) - verbose_name = Column(String(1024)) - metric_type = Column(String(32)) datasource_name = Column( String(255), ForeignKey('datasources.datasource_name')) @@ -278,9 +261,6 @@ class DruidMetric(Model, AuditMixinNullable, 
ImportMixin): backref=backref('metrics', cascade='all, delete-orphan'), enable_typechecks=False) json = Column(Text) - description = Column(Text) - is_restricted = Column(Boolean, default=False, nullable=True) - d3format = Column(String(128)) def refresh_datasources(self, datasource_name=None, merge_flag=False): """Refresh metadata of all datasources in the cluster @@ -289,7 +269,7 @@ def refresh_datasources(self, datasource_name=None, merge_flag=False): """ self.druid_version = self.get_druid_version() for datasource in self.get_datasources(): - if datasource not in config.get('DRUID_DATA_SOURCE_BLACKLIST'): + if datasource not in conf.get('DRUID_DATA_SOURCE_BLACKLIST'): if not datasource_name or datasource_name == datasource: DruidDatasource.sync_to_db(datasource, self, merge_flag) export_fields = ( @@ -322,12 +302,14 @@ def lookup_obj(lookup_metric): return import_util.import_simple_obj(db.session, i_metric, lookup_obj) -class DruidDatasource(Model, AuditMixinNullable, Datasource, ImportMixin): +class DruidDatasource(Model, BaseDatasource): """ORM object referencing Druid datasources (tables)""" type = "druid" query_langtage = "json" + metric_class = DruidMetric + cluster_class = DruidCluster baselink = "druiddatasourcemodelview" @@ -381,7 +363,8 @@ def name(self): @property def schema(self): - name_pieces = self.datasource_name.split('.') + ds_name = self.datasource_name or '' + name_pieces = ds_name.split('.') if len(name_pieces) > 1: return name_pieces[0] else: @@ -506,7 +489,7 @@ def latest_metadata(self): datasource=self.datasource_name, intervals=lbound + '/' + rbound, merge=self.merge_flag, - analysisTypes=config.get('DRUID_ANALYSIS_TYPES')) + analysisTypes=conf.get('DRUID_ANALYSIS_TYPES')) except Exception as e: logging.warning("Failed first attempt to get latest segment") logging.exception(e) @@ -521,7 +504,7 @@ def latest_metadata(self): datasource=self.datasource_name, intervals=lbound + '/' + rbound, merge=self.merge_flag, - analysisTypes=config.get('DRUID_ANALYSIS_TYPES')) + analysisTypes=conf.get('DRUID_ANALYSIS_TYPES')) except Exception as e: logging.warning("Failed 2nd attempt to get latest segment") logging.exception(e) @@ -537,13 +520,14 @@ def sync_to_db_from_config(cls, druid_config, user, cluster): """Merges the ds config from druid_config into one stored in the db.""" session = db.session() datasource = ( - session.query(DruidDatasource) + session.query(cls) .filter_by( datasource_name=druid_config['name']) - ).first() + .first() + ) # Create a new datasource. if not datasource: - datasource = DruidDatasource( + datasource = cls( datasource_name=druid_config['name'], cluster=cluster, owner=user, @@ -559,7 +543,8 @@ def sync_to_db_from_config(cls, druid_config, user, cluster): .filter_by( datasource_name=druid_config['name'], column_name=dim) - ).first() + .first() + ) if not col_obj: col_obj = DruidColumn( datasource_name=druid_config['name'], @@ -568,7 +553,7 @@ def sync_to_db_from_config(cls, druid_config, user, cluster): filterable=True, # TODO: fetch type from Hive. 
type="STRING", - datasource=datasource + datasource=datasource, ) session.add(col_obj) # Import Druid metrics @@ -710,8 +695,8 @@ def values_for_column(self, limit=500): """Retrieve some values for the given column""" # TODO: Use Lexicographic TopNMetricSpec once supported by PyDruid - from_dttm = from_dttm.replace(tzinfo=config.get("DRUID_TZ")) - to_dttm = to_dttm.replace(tzinfo=config.get("DRUID_TZ")) + from_dttm = from_dttm.replace(tzinfo=DRUID_TZ) + to_dttm = to_dttm.replace(tzinfo=DRUID_TZ) qry = dict( datasource=self.datasource_name, @@ -758,8 +743,8 @@ def get_query_str( # noqa / druid inner_to_dttm = inner_to_dttm or to_dttm # add tzinfo to native datetime with config - from_dttm = from_dttm.replace(tzinfo=config.get("DRUID_TZ")) - to_dttm = to_dttm.replace(tzinfo=config.get("DRUID_TZ")) + from_dttm = from_dttm.replace(tzinfo=DRUID_TZ) + to_dttm = to_dttm.replace(tzinfo=DRUID_TZ) timezone = from_dttm.tzname() query_str = "" @@ -785,40 +770,40 @@ def recursive_get_fields(_conf): if metric.metric_type != 'postagg': all_metrics.append(metric_name) else: - conf = metric.json_obj - all_metrics += recursive_get_fields(conf) - all_metrics += conf.get('fieldNames', []) - if conf.get('type') == 'javascript': + mconf = metric.json_obj + all_metrics += recursive_get_fields(mconf) + all_metrics += mconf.get('fieldNames', []) + if mconf.get('type') == 'javascript': post_aggs[metric_name] = JavascriptPostAggregator( - name=conf.get('name', ''), - field_names=conf.get('fieldNames', []), - function=conf.get('function', '')) - elif conf.get('type') == 'quantile': + name=mconf.get('name', ''), + field_names=mconf.get('fieldNames', []), + function=mconf.get('function', '')) + elif mconf.get('type') == 'quantile': post_aggs[metric_name] = Quantile( - conf.get('name', ''), - conf.get('probability', ''), + mconf.get('name', ''), + mconf.get('probability', ''), ) - elif conf.get('type') == 'quantiles': + elif mconf.get('type') == 'quantiles': post_aggs[metric_name] = Quantiles( - conf.get('name', ''), - conf.get('probabilities', ''), + mconf.get('name', ''), + mconf.get('probabilities', ''), ) - elif conf.get('type') == 'fieldAccess': - post_aggs[metric_name] = Field(conf.get('name'), '') - elif conf.get('type') == 'constant': + elif mconf.get('type') == 'fieldAccess': + post_aggs[metric_name] = Field(mconf.get('name'), '') + elif mconf.get('type') == 'constant': post_aggs[metric_name] = Const( - conf.get('value'), - output_name=conf.get('name', '') + mconf.get('value'), + output_name=mconf.get('name', '') ) - elif conf.get('type') == 'hyperUniqueCardinality': + elif mconf.get('type') == 'hyperUniqueCardinality': post_aggs[metric_name] = HyperUniqueCardinality( - conf.get('name'), '' + mconf.get('name'), '' ) else: post_aggs[metric_name] = Postaggregator( - conf.get('fn', "/"), - conf.get('fields', []), - conf.get('name', '')) + mconf.get('fn', "/"), + mconf.get('fields', []), + mconf.get('name', '')) aggregations = OrderedDict() for m in self.metrics: @@ -980,7 +965,7 @@ def query(self, query_obj): def increment_timestamp(ts): dt = utils.parse_human_datetime(ts).replace( - tzinfo=config.get("DRUID_TZ")) + tzinfo=DRUID_TZ) return dt + timedelta(milliseconds=time_offset) if DTTM_ALIAS in df.columns and time_offset: df[DTTM_ALIAS] = df[DTTM_ALIAS].apply(increment_timestamp) @@ -1046,3 +1031,32 @@ def _get_having_obj(self, col, op, eq): cond = Aggregation(col) < eq return cond + + def get_having_filters(self, raw_filters): + filters = None + reversed_op_map = { + '!=': '==', + '>=': '<', + '<=': '>' + } + + 
for flt in raw_filters: + if not all(f in flt for f in ['col', 'op', 'val']): + continue + col = flt['col'] + op = flt['op'] + eq = flt['val'] + cond = None + if op in ['==', '>', '<']: + cond = self._get_having_obj(col, op, eq) + elif op in reversed_op_map: + cond = ~self._get_having_obj(col, reversed_op_map[op], eq) + + if filters: + filters = filters & cond + else: + filters = cond + return filters + +sa.event.listen(DruidDatasource, 'after_insert', set_perm) +sa.event.listen(DruidDatasource, 'after_update', set_perm) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 6ec0e3fff443b..fec34b20eab3c 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -20,7 +20,7 @@ from flask_babel import lazy_gettext as _ from superset import db, utils, import_util -from superset.connectors.base import Datasource +from superset.connectors.base import BaseDatasource, BaseColumn, BaseMetric from superset.utils import ( wrap_clause_in_parens, DTTM_ALIAS, QueryStatus @@ -28,40 +28,24 @@ from superset.models.helpers import QueryResult from superset.models.core import Database from superset.jinja_context import get_template_processor -from superset.models.helpers import AuditMixinNullable, ImportMixin, set_perm +from superset.models.helpers import set_perm -class TableColumn(Model, AuditMixinNullable, ImportMixin): +class TableColumn(Model, BaseColumn): """ORM object for table columns, each table can have multiple columns""" __tablename__ = 'table_columns' - id = Column(Integer, primary_key=True) table_id = Column(Integer, ForeignKey('tables.id')) table = relationship( 'SqlaTable', backref=backref('columns', cascade='all, delete-orphan'), foreign_keys=[table_id]) - column_name = Column(String(255)) - verbose_name = Column(String(1024)) is_dttm = Column(Boolean, default=False) - is_active = Column(Boolean, default=True) - type = Column(String(32), default='') - groupby = Column(Boolean, default=False) - count_distinct = Column(Boolean, default=False) - sum = Column(Boolean, default=False) - avg = Column(Boolean, default=False) - max = Column(Boolean, default=False) - min = Column(Boolean, default=False) - filterable = Column(Boolean, default=False) expression = Column(Text, default='') - description = Column(Text, default='') python_date_format = Column(String(255)) database_expression = Column(String(255)) - num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG', 'REAL', 'NUMERIC') - date_types = ('DATE', 'TIME') - str_types = ('VARCHAR', 'STRING', 'CHAR') export_fields = ( 'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active', 'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min', @@ -69,21 +53,6 @@ class TableColumn(Model, AuditMixinNullable, ImportMixin): 'database_expression' ) - def __repr__(self): - return self.column_name - - @property - def is_num(self): - return any([t in self.type.upper() for t in self.num_types]) - - @property - def is_time(self): - return any([t in self.type.upper() for t in self.date_types]) - - @property - def is_string(self): - return any([t in self.type.upper() for t in self.str_types]) - @property def sqla_col(self): name = self.column_name @@ -149,24 +118,17 @@ def dttm_sql_literal(self, dttm): return s or "'{}'".format(dttm.strftime(tf)) -class SqlMetric(Model, AuditMixinNullable, ImportMixin): +class SqlMetric(Model, BaseMetric): """ORM object for metrics, each table can have multiple metrics""" __tablename__ = 'sql_metrics' - id = Column(Integer, primary_key=True) - 
metric_name = Column(String(512)) - verbose_name = Column(String(1024)) - metric_type = Column(String(32)) table_id = Column(Integer, ForeignKey('tables.id')) table = relationship( 'SqlaTable', backref=backref('metrics', cascade='all, delete-orphan'), foreign_keys=[table_id]) expression = Column(Text) - description = Column(Text) - is_restricted = Column(Boolean, default=False, nullable=True) - d3format = Column(String(128)) export_fields = ( 'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression', @@ -193,12 +155,13 @@ def lookup_obj(lookup_metric): return import_util.import_simple_obj(db.session, i_metric, lookup_obj) -class SqlaTable(Model, Datasource, AuditMixinNullable, ImportMixin): +class SqlaTable(Model, BaseDatasource): """An ORM object for SqlAlchemy table references""" type = "table" query_language = 'sql' + metric_class = SqlMetric __tablename__ = 'tables' id = Column(Integer, primary_key=True) diff --git a/superset/data/__init__.py b/superset/data/__init__.py index f061b629bd9ae..6717f1f17b0d7 100644 --- a/superset/data/__init__.py +++ b/superset/data/__init__.py @@ -25,7 +25,7 @@ Slice = models.Slice Dash = models.Dashboard -TBL = ConnectorRegistry.sources['sqla'] +TBL = ConnectorRegistry.sources['table'] config = app.config diff --git a/superset/models/core.py b/superset/models/core.py index e8f3f2bc8218b..5157e4ac7fd58 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -10,7 +10,6 @@ import numpy import pickle import re -import sqlparse import textwrap from future.standard_library import install_aliases from copy import copy diff --git a/superset/security.py b/superset/security.py index 795f5f45322ff..435f23d4634fe 100644 --- a/superset/security.py +++ b/superset/security.py @@ -183,8 +183,8 @@ def create_missing_metrics_perm(view_menu_set): """ logging.info("Creating missing metrics permissions") metrics = [] - for model in [models.SqlMetric, models.DruidMetric]: - metrics += list(db.session.query(model).all()) + for datasource_class in ConnectorRegistry.sources.values(): + metrics += list(db.session.query(datasource_class.metric_class).all()) for metric in metrics: if (metric.is_restricted and metric.perm and @@ -218,7 +218,9 @@ def sync_role_definitions(): if conf.get('PUBLIC_ROLE_LIKE_GAMMA', False): set_role('Public', pvms, is_gamma_pvm) - view_menu_set = db.session.query(models.SqlaTable).all() + view_menu_set = [] + for datasource_class in ConnectorRegistry.sources.values(): + view_menu_set += list(db.session.query(datasource_class).all()) create_missing_datasource_perms(view_menu_set) create_missing_database_perms(view_menu_set) create_missing_metrics_perm(view_menu_set) diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 4c52adc2474b4..94d26a4991b9e 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -12,7 +12,8 @@ from sqlalchemy.orm import sessionmaker from superset import ( - app, db, models, utils, dataframe, results_backend) + app, db, utils, dataframe, results_backend) +from superset.models import core as models from superset.sql_parse import SupersetQuery from superset.db_engine_specs import LimitMethod from superset.jinja_context import get_template_processor diff --git a/superset/views/base.py b/superset/views/base.py index 46dbf99f1ef05..32ae3694f6d1c 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -11,7 +11,7 @@ from flask_appbuilder.models.sqla.filters import BaseFilter from flask_appbuilder.security.sqla import models as ab_models -from superset import appbuilder, config, db, 
utils, sm, sql_parse +from superset import appbuilder, conf, db, utils, sm, sql_parse from superset.connectors.connector_registry import ConnectorRegistry @@ -21,7 +21,7 @@ def get_datasource_exist_error_mgs(full_name): def get_user_roles(): if g.user.is_anonymous(): - public_role = config.get('AUTH_ROLE_PUBLIC') + public_role = conf.get('AUTH_ROLE_PUBLIC') return [appbuilder.sm.find_role(public_role)] if public_role else [] return g.user.roles diff --git a/superset/views/core.py b/superset/views/core.py index 9752a96cd2248..4813d9b3ed6d9 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1752,7 +1752,8 @@ def sync_druid_source(self): cluster_name = payload['cluster'] user = sm.find_user(username=user_name) - DruidCluster = ConnectorRegistry.sources['druid'] + DruidDatasource = ConnectorRegistry.sources['druid'] + DruidCluster = DruidDatasource.cluster_class if not user: err_msg = __("Can't find User '%(name)s', please ask your admin " "to create one.", name=user_name) @@ -1766,7 +1767,7 @@ def sync_druid_source(self): logging.error(err_msg) return json_error_response(err_msg) try: - models.DruidDatasource.sync_to_db_from_config( + DruidDatasource.sync_to_db_from_config( druid_config, user, cluster) except Exception as e: logging.exception(utils.error_msg_from_exception(e)) diff --git a/tests/access_tests.py b/tests/access_tests.py index 1d0b2ac1e74a0..54235b3a7f6cb 100644 --- a/tests/access_tests.py +++ b/tests/access_tests.py @@ -9,7 +9,11 @@ import unittest from superset import db, models, sm, security -from superset.source_registry import SourceRegistry +from superset.connector_registry import ConnectorRegistry + +from superset.models import core as models +from superset.connectors.sqla.models import SqlaTable +from superset.connectors.druid.models import DruidDatasource from .base_tests import SupersetTestCase @@ -58,7 +62,7 @@ def create_access_request(session, ds_type, ds_name, role_name, user_name): - ds_class = SourceRegistry.sources[ds_type] + ds_class = ConnectorRegistry.sources[ds_type] # TODO: generalize datasource names if ds_type == 'table': ds = session.query(ds_class).filter( @@ -293,7 +297,7 @@ def test_clean_requests_after_schema_grant(self): access_request2 = create_access_request( session, 'table', 'wb_health_population', TEST_ROLE_2, 'gamma2') ds_1_id = access_request1.datasource_id - ds = session.query(models.SqlaTable).filter_by( + ds = session.query(SqlaTable).filter_by( table_name='wb_health_population').first() @@ -314,7 +318,7 @@ def test_clean_requests_after_schema_grant(self): gamma_user = sm.find_user(username='gamma') gamma_user.roles.remove(sm.find_role(SCHEMA_ACCESS_ROLE)) - ds = session.query(models.SqlaTable).filter_by( + ds = session.query(SqlaTable).filter_by( table_name='wb_health_population').first() ds.schema = None @@ -441,7 +445,7 @@ def test_request_access(self): # Request table access, there are no roles have this table. - table1 = session.query(models.SqlaTable).filter_by( + table1 = session.query(SqlaTable).filter_by( table_name='random_time_series').first() table_1_id = table1.id @@ -454,7 +458,7 @@ def test_request_access(self): # Request access, roles exist that contains the table. # add table to the existing roles - table3 = session.query(models.SqlaTable).filter_by( + table3 = session.query(SqlaTable).filter_by( table_name='energy_usage').first() table_3_id = table3.id table3_perm = table3.perm @@ -479,7 +483,7 @@ def test_request_access(self): '
    • {}
    '.format(approve_link_3)) # Request druid access, there are no roles have this table. - druid_ds_4 = session.query(models.DruidDatasource).filter_by( + druid_ds_4 = session.query(DruidDatasource).filter_by( datasource_name='druid_ds_1').first() druid_ds_4_id = druid_ds_4.id @@ -493,7 +497,7 @@ def test_request_access(self): # Case 5. Roles exist that contains the druid datasource. # add druid ds to the existing roles - druid_ds_5 = session.query(models.DruidDatasource).filter_by( + druid_ds_5 = session.query(DruidDatasource).filter_by( datasource_name='druid_ds_2').first() druid_ds_5_id = druid_ds_5.id druid_ds_5_perm = druid_ds_5.perm diff --git a/tests/base_tests.py b/tests/base_tests.py index fb218b9c538db..6586802cc4155 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -11,8 +11,11 @@ from flask_appbuilder.security.sqla import models as ab_models -from superset import app, cli, db, models, appbuilder, security, sm +from superset import app, cli, db, appbuilder, security, sm +from superset.models import core as models from superset.security import sync_role_definitions +from superset.connectors.sqla.models import SqlaTable +from superset.connectors.druid.models import DruidCluster, DruidDatasource os.environ['SUPERSET_CONFIG'] = 'tests.superset_test_config' @@ -85,30 +88,34 @@ def __init__(self, *args, **kwargs): appbuilder.sm.find_role('Alpha'), password='general') sm.get_session.commit() - # create druid cluster and druid datasources session = db.session - cluster = session.query(models.DruidCluster).filter_by( - cluster_name="druid_test").first() + cluster = ( + session.query(DruidCluster) + .filter_by(cluster_name="druid_test") + .first() + ) if not cluster: - cluster = models.DruidCluster(cluster_name="druid_test") + cluster = DruidCluster(cluster_name="druid_test") session.add(cluster) session.commit() - druid_datasource1 = models.DruidDatasource( + druid_datasource1 = DruidDatasource( datasource_name='druid_ds_1', cluster_name='druid_test' ) session.add(druid_datasource1) - druid_datasource2 = models.DruidDatasource( + druid_datasource2 = DruidDatasource( datasource_name='druid_ds_2', cluster_name='druid_test' ) session.add(druid_datasource2) session.commit() + + def get_table(self, table_id): - return db.session.query(models.SqlaTable).filter_by( + return db.session.query(SqlaTable).filter_by( id=table_id).first() def get_or_create(self, cls, criteria, session): @@ -149,7 +156,7 @@ def get_slice(self, slice_name, session): return slc def get_table_by_name(self, name): - return db.session.query(models.SqlaTable).filter_by( + return db.session.query(SqlaTable).filter_by( table_name=name).first() def get_druid_ds_by_name(self, name): diff --git a/tests/celery_tests.py b/tests/celery_tests.py index a4e48f4b191ea..6a3802c7d258a 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -12,14 +12,14 @@ import pandas as pd -from superset import app, appbuilder, cli, db, models, dataframe +from superset import app, appbuilder, cli, db, dataframe +from superset.models import core as models +from superset.models.helpers import QueryStatus from superset.security import sync_role_definitions from superset.sql_parse import SupersetQuery from .base_tests import SupersetTestCase -QueryStatus = models.QueryStatus - BASE_DIR = app.config.get('BASE_DIR') diff --git a/tests/core_tests.py b/tests/core_tests.py index 2634b9eb0d29f..bc1543f66b93d 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -7,17 +7,19 @@ import csv import doctest import json +import logging 
import io import random import unittest from flask import escape -from superset import db, models, utils, appbuilder, sm, jinja_context, sql_lab -from superset.views import DatabaseView +from superset import db, utils, appbuilder, sm, jinja_context, sql_lab +from superset.models import core as models +from superset.views.core import DatabaseView +from superset.connectors.sqla.models import SqlaTable from .base_tests import SupersetTestCase -import logging class CoreTests(SupersetTestCase): @@ -31,7 +33,7 @@ def __init__(self, *args, **kwargs): def setUpClass(cls): cls.table_ids = {tbl.table_name: tbl.id for tbl in ( db.session - .query(models.SqlaTable) + .query(SqlaTable) .all() )} @@ -186,7 +188,7 @@ def test_filter_endpoint(self): slice_id = self.get_slice(slice_name, db.session).id db.session.commit() tbl_id = self.table_ids.get('energy_usage') - table = db.session.query(models.SqlaTable).filter(models.SqlaTable.id == tbl_id) + table = db.session.query(SqlaTable).filter(SqlaTable.id == tbl_id) table.filter_select_enabled = True url = ( "/superset/filter/table/{}/target/?viz_type=sankey&groupby=source" @@ -220,7 +222,7 @@ def test_add_slice(self): url = '/tablemodelview/list/' resp = self.get_resp(url) - table = db.session.query(models.SqlaTable).first() + table = db.session.query(SqlaTable).first() assert table.name in resp assert '/superset/explore/table/{}'.format(table.id) in resp @@ -459,7 +461,7 @@ def test_csv_endpoint(self): def test_public_user_dashboard_access(self): table = ( db.session - .query(models.SqlaTable) + .query(SqlaTable) .filter_by(table_name='birth_names') .one() ) @@ -494,7 +496,7 @@ def test_dashboard_with_created_by_can_be_accessed_by_public_users(self): self.logout() table = ( db.session - .query(models.SqlaTable) + .query(SqlaTable) .filter_by(table_name='birth_names') .one() ) diff --git a/tests/druid_tests.py b/tests/druid_tests.py index 8d026234d21ef..c2affa9634731 100644 --- a/tests/druid_tests.py +++ b/tests/druid_tests.py @@ -11,7 +11,8 @@ from mock import Mock, patch from superset import db, sm, security -from superset.models import DruidCluster, DruidDatasource +from superset.connectors.druid.models import DruidCluster, DruidDatasource +from superset.connectors.druid.models import PyDruid from .base_tests import SupersetTestCase @@ -70,13 +71,14 @@ class DruidTests(SupersetTestCase): def __init__(self, *args, **kwargs): super(DruidTests, self).__init__(*args, **kwargs) - @patch('superset.models.PyDruid') + @patch('superset.connectors.druid.models.PyDruid') def test_client(self, PyDruid): self.login(username='admin') instance = PyDruid.return_value instance.time_boundary.return_value = [ {'result': {'maxTime': '2016-01-01'}}] instance.segment_metadata.return_value = SEGMENT_METADATA + print(PyDruid()) cluster = ( db.session @@ -135,6 +137,7 @@ def test_client(self, PyDruid): datasource_id, json.dumps(form_data)) ) resp = self.get_json_resp(url) + print(resp) self.assertEqual("Canada", resp['data']['records'][0]['dim1']) form_data = { @@ -197,8 +200,13 @@ def test_druid_sync_from_config(self): } def check(): resp = self.client.post('/superset/sync_druid/', data=json.dumps(cfg)) - druid_ds = db.session.query(DruidDatasource).filter_by( - datasource_name="test_click").first() + print(resp) + druid_ds = ( + db.session + .query(DruidDatasource) + .filter_by(datasource_name="test_click") + .one() + ) col_names = set([c.column_name for c in druid_ds.columns]) assert {"affiliate_id", "campaign", "first_seen"} == col_names metric_names = {m.metric_name for m 
in druid_ds.metrics} @@ -224,7 +232,7 @@ def check(): } resp = self.client.post('/superset/sync_druid/', data=json.dumps(cfg)) druid_ds = db.session.query(DruidDatasource).filter_by( - datasource_name="test_click").first() + datasource_name="test_click").one() # columns and metrics are not deleted if config is changed as # user could define his own dimensions / metrics and want to keep them assert set([c.column_name for c in druid_ds.columns]) == set( diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index 2120fce9d6095..05536354cdf75 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -10,7 +10,11 @@ import pickle import unittest -from superset import db, models +from superset import db +from superset.models import core as models +from superset.connectors.druid.models import ( + DruidDatasource, DruidColumn, DruidMetric) +from superset.connectors.sqla.models import SqlaTable, TableColumn, SqlMetric from .base_tests import SupersetTestCase @@ -31,10 +35,10 @@ def delete_imports(cls): for dash in session.query(models.Dashboard): if 'remote_id' in dash.params_dict: session.delete(dash) - for table in session.query(models.SqlaTable): + for table in session.query(SqlaTable): if 'remote_id' in table.params_dict: session.delete(table) - for datasource in session.query(models.DruidDatasource): + for datasource in session.query(DruidDatasource): if 'remote_id' in datasource.params_dict: session.delete(datasource) session.commit() @@ -90,7 +94,7 @@ def create_dashboard(self, title, id=0, slcs=[]): def create_table( self, name, schema='', id=0, cols_names=[], metric_names=[]): params = {'remote_id': id, 'database_name': 'main'} - table = models.SqlaTable( + table = SqlaTable( id=id, schema=schema, table_name=name, @@ -98,15 +102,15 @@ def create_table( ) for col_name in cols_names: table.columns.append( - models.TableColumn(column_name=col_name)) + TableColumn(column_name=col_name)) for metric_name in metric_names: - table.metrics.append(models.SqlMetric(metric_name=metric_name)) + table.metrics.append(SqlMetric(metric_name=metric_name)) return table def create_druid_datasource( self, name, id=0, cols_names=[], metric_names=[]): params = {'remote_id': id, 'database_name': 'druid_test'} - datasource = models.DruidDatasource( + datasource = DruidDatasource( id=id, datasource_name=name, cluster_name='druid_test', @@ -114,9 +118,9 @@ def create_druid_datasource( ) for col_name in cols_names: datasource.columns.append( - models.DruidColumn(column_name=col_name)) + DruidColumn(column_name=col_name)) for metric_name in metric_names: - datasource.metrics.append(models.DruidMetric( + datasource.metrics.append(DruidMetric( metric_name=metric_name)) return datasource @@ -136,11 +140,11 @@ def get_dash_by_slug(self, dash_slug): slug=dash_slug).first() def get_datasource(self, datasource_id): - return db.session.query(models.DruidDatasource).filter_by( + return db.session.query(DruidDatasource).filter_by( id=datasource_id).first() def get_table_by_name(self, name): - return db.session.query(models.SqlaTable).filter_by( + return db.session.query(SqlaTable).filter_by( table_name=name).first() def assert_dash_equals(self, expected_dash, actual_dash, @@ -392,7 +396,7 @@ def test_import_override_dashboard_2_slices(self): def test_import_table_no_metadata(self): table = self.create_table('pure_table', id=10001) - imported_id = models.SqlaTable.import_obj(table, import_time=1989) + imported_id = SqlaTable.import_obj(table, import_time=1989) imported = 
self.get_table(imported_id) self.assert_table_equals(table, imported) @@ -400,7 +404,7 @@ def test_import_table_1_col_1_met(self): table = self.create_table( 'table_1_col_1_met', id=10002, cols_names=["col1"], metric_names=["metric1"]) - imported_id = models.SqlaTable.import_obj(table, import_time=1990) + imported_id = SqlaTable.import_obj(table, import_time=1990) imported = self.get_table(imported_id) self.assert_table_equals(table, imported) self.assertEquals( @@ -411,7 +415,7 @@ def test_import_table_2_col_2_met(self): table = self.create_table( 'table_2_col_2_met', id=10003, cols_names=['c1', 'c2'], metric_names=['m1', 'm2']) - imported_id = models.SqlaTable.import_obj(table, import_time=1991) + imported_id = SqlaTable.import_obj(table, import_time=1991) imported = self.get_table(imported_id) self.assert_table_equals(table, imported) @@ -420,12 +424,12 @@ def test_import_table_override(self): table = self.create_table( 'table_override', id=10003, cols_names=['col1'], metric_names=['m1']) - imported_id = models.SqlaTable.import_obj(table, import_time=1991) + imported_id = SqlaTable.import_obj(table, import_time=1991) table_over = self.create_table( 'table_override', id=10003, cols_names=['new_col1', 'col2', 'col3'], metric_names=['new_metric1']) - imported_over_id = models.SqlaTable.import_obj( + imported_over_id = SqlaTable.import_obj( table_over, import_time=1992) imported_over = self.get_table(imported_over_id) @@ -439,12 +443,12 @@ def test_import_table_override_idential(self): table = self.create_table( 'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'], metric_names=['new_metric1']) - imported_id = models.SqlaTable.import_obj(table, import_time=1993) + imported_id = SqlaTable.import_obj(table, import_time=1993) copy_table = self.create_table( 'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'], metric_names=['new_metric1']) - imported_id_copy = models.SqlaTable.import_obj( + imported_id_copy = SqlaTable.import_obj( copy_table, import_time=1994) self.assertEquals(imported_id, imported_id_copy) @@ -452,7 +456,7 @@ def test_import_table_override_idential(self): def test_import_druid_no_metadata(self): datasource = self.create_druid_datasource('pure_druid', id=10001) - imported_id = models.DruidDatasource.import_obj( + imported_id = DruidDatasource.import_obj( datasource, import_time=1989) imported = self.get_datasource(imported_id) self.assert_datasource_equals(datasource, imported) @@ -461,7 +465,7 @@ def test_import_druid_1_col_1_met(self): datasource = self.create_druid_datasource( 'druid_1_col_1_met', id=10002, cols_names=["col1"], metric_names=["metric1"]) - imported_id = models.DruidDatasource.import_obj( + imported_id = DruidDatasource.import_obj( datasource, import_time=1990) imported = self.get_datasource(imported_id) self.assert_datasource_equals(datasource, imported) @@ -474,7 +478,7 @@ def test_import_druid_2_col_2_met(self): datasource = self.create_druid_datasource( 'druid_2_col_2_met', id=10003, cols_names=['c1', 'c2'], metric_names=['m1', 'm2']) - imported_id = models.DruidDatasource.import_obj( + imported_id = DruidDatasource.import_obj( datasource, import_time=1991) imported = self.get_datasource(imported_id) self.assert_datasource_equals(datasource, imported) @@ -483,14 +487,14 @@ def test_import_druid_override(self): datasource = self.create_druid_datasource( 'druid_override', id=10003, cols_names=['col1'], metric_names=['m1']) - imported_id = models.DruidDatasource.import_obj( + imported_id = DruidDatasource.import_obj( datasource, 
import_time=1991)
 
         table_over = self.create_druid_datasource(
             'druid_override', id=10003, cols_names=['new_col1', 'col2', 'col3'],
             metric_names=['new_metric1'])
-        imported_over_id = models.DruidDatasource.import_obj(
+        imported_over_id = DruidDatasource.import_obj(
             table_over, import_time=1992)
 
         imported_over = self.get_datasource(imported_over_id)
@@ -504,13 +508,13 @@ def test_import_druid_override_idential(self):
         datasource = self.create_druid_datasource(
             'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
             metric_names=['new_metric1'])
-        imported_id = models.DruidDatasource.import_obj(
+        imported_id = DruidDatasource.import_obj(
             datasource, import_time=1993)
 
         copy_datasource = self.create_druid_datasource(
             'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
             metric_names=['new_metric1'])
-        imported_id_copy = models.DruidDatasource.import_obj(
+        imported_id_copy = DruidDatasource.import_obj(
             copy_datasource, import_time=1994)
 
         self.assertEquals(imported_id, imported_id_copy)
diff --git a/tests/model_tests.py b/tests/model_tests.py
index c9e60178abdb5..780a46db9fbb3 100644
--- a/tests/model_tests.py
+++ b/tests/model_tests.py
@@ -2,7 +2,7 @@
 
 from sqlalchemy.engine.url import make_url
 
-from superset.models import Database
+from superset.models.core import Database
 
 
 class DatabaseModelTestCase(unittest.TestCase):
diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py
index 1ddaf85c2bcb3..e1dac77103dbc 100644
--- a/tests/sqllab_tests.py
+++ b/tests/sqllab_tests.py
@@ -9,7 +9,9 @@
 import unittest
 from flask_appbuilder.security.sqla import models as ab_models
 
-from superset import db, models, utils, appbuilder, security, sm
+from superset import db, utils, appbuilder, sm
+from superset.models import core as models
+
 
 from .base_tests import SupersetTestCase
 
 
From a82c06941786b26758b8abdd9b4d5e9dc6291034 Mon Sep 17 00:00:00 2001
From: Maxime Beauchemin 
Date: Wed, 8 Mar 2017 11:54:58 -0800
Subject: [PATCH 5/8] Adding migration

---
 ...b6c_adding_verbose_name_to_druid_column.py | 22 +++++++++++++++++++
 1 file changed, 22 insertions(+)
 create mode 100644 superset/migrations/versions/b318dfe5fb6c_adding_verbose_name_to_druid_column.py

diff --git a/superset/migrations/versions/b318dfe5fb6c_adding_verbose_name_to_druid_column.py b/superset/migrations/versions/b318dfe5fb6c_adding_verbose_name_to_druid_column.py
new file mode 100644
index 0000000000000..d492427b644f8
--- /dev/null
+++ b/superset/migrations/versions/b318dfe5fb6c_adding_verbose_name_to_druid_column.py
@@ -0,0 +1,22 @@
+"""adding verbose_name to druid column
+
+Revision ID: b318dfe5fb6c
+Revises: d6db5a5cdb5d
+Create Date: 2017-03-08 11:48:10.835741
+
+"""
+
+# revision identifiers, used by Alembic. 
+revision = 'b318dfe5fb6c' +down_revision = 'd6db5a5cdb5d' + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + op.add_column('columns', sa.Column('verbose_name', sa.String(length=1024), nullable=True)) + + +def downgrade(): + op.drop_column('columns', 'verbose_name') From 4da1e8818c4662d0efa248f51a8736589bbc7439 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Wed, 8 Mar 2017 12:52:53 -0800 Subject: [PATCH 6/8] Tests --- tests/druid_tests.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/druid_tests.py b/tests/druid_tests.py index c2affa9634731..a1121425d22b2 100644 --- a/tests/druid_tests.py +++ b/tests/druid_tests.py @@ -78,7 +78,6 @@ def test_client(self, PyDruid): instance.time_boundary.return_value = [ {'result': {'maxTime': '2016-01-01'}}] instance.segment_metadata.return_value = SEGMENT_METADATA - print(PyDruid()) cluster = ( db.session @@ -137,7 +136,6 @@ def test_client(self, PyDruid): datasource_id, json.dumps(form_data)) ) resp = self.get_json_resp(url) - print(resp) self.assertEqual("Canada", resp['data']['records'][0]['dim1']) form_data = { @@ -200,7 +198,6 @@ def test_druid_sync_from_config(self): } def check(): resp = self.client.post('/superset/sync_druid/', data=json.dumps(cfg)) - print(resp) druid_ds = ( db.session .query(DruidDatasource) From f5aeccbdf6fad0c7e73575841a1c239ea04b44e2 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Wed, 8 Mar 2017 14:39:16 -0800 Subject: [PATCH 7/8] Final --- tests/access_tests.py | 2 +- tests/base_tests.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/access_tests.py b/tests/access_tests.py index 54235b3a7f6cb..ec1ce7483dcea 100644 --- a/tests/access_tests.py +++ b/tests/access_tests.py @@ -9,9 +9,9 @@ import unittest from superset import db, models, sm, security -from superset.connector_registry import ConnectorRegistry from superset.models import core as models +from superset.connectors.connector_registry import ConnectorRegistry from superset.connectors.sqla.models import SqlaTable from superset.connectors.druid.models import DruidDatasource diff --git a/tests/base_tests.py b/tests/base_tests.py index 6586802cc4155..1a86982adf265 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -160,7 +160,7 @@ def get_table_by_name(self, name): table_name=name).first() def get_druid_ds_by_name(self, name): - return db.session.query(models.DruidDatasource).filter_by( + return db.session.query(DruidDatasource).filter_by( datasource_name=name).first() def get_resp( From 8a2307343c24b6871c37894290ef6d11317d83a9 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Wed, 8 Mar 2017 20:45:13 -0800 Subject: [PATCH 8/8] Addressing comments --- superset/connectors/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/superset/connectors/base.py b/superset/connectors/base.py index 4cfeac496e68f..bdb7852f5129a 100644 --- a/superset/connectors/base.py +++ b/superset/connectors/base.py @@ -56,7 +56,7 @@ def column_formats(self): @property def data(self): - """data representation of the datasource sent to the frontend""" + """Data representation of the datasource sent to the frontend""" order_by_choices = [] for s in sorted(self.column_names): order_by_choices.append((json.dumps([s, True]), s + ' [asc]')) @@ -75,6 +75,8 @@ def data(self): 'order_by_choices': order_by_choices, 'type': self.type, } + + # TODO move this block to SqlaTable.data if self.type == 'table': grains = self.database.grains() or [] if grains:
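
For illustration only, and not part of the patches above: a minimal sketch of what a concrete
connector might supply on top of the BaseDatasource / BaseColumn / BaseMetric interface
formalized in superset/connectors/base.py. The "example" connector, the Example* class names,
the example_* table names, the baselink value, and the perm string format are all hypothetical;
only the base classes, the metric_class attribute, and the ConnectorRegistry.sources lookup
are taken from the patches in this series.

# Hypothetical connector sketch; assumes only the interface introduced in
# superset/connectors/base.py above. Example* names do not exist in the codebase.
from flask_appbuilder import Model
from sqlalchemy import Column, ForeignKey, Integer, String, Text
from sqlalchemy.orm import backref, relationship

from superset.connectors.base import BaseColumn, BaseDatasource, BaseMetric


class ExampleColumn(Model, BaseColumn):

    """Column metadata; id, column_name, type, groupby, etc. come from BaseColumn"""

    __tablename__ = 'example_columns'
    datasource_id = Column(Integer, ForeignKey('example_datasources.id'))
    datasource = relationship(
        'ExampleDatasource',
        backref=backref('columns', cascade='all, delete-orphan'),
        enable_typechecks=False)


class ExampleMetric(Model, BaseMetric):

    """Metric definition; declares the FK/relationship that the BaseMetric
    docstring asks concrete classes to provide"""

    __tablename__ = 'example_metrics'
    datasource_id = Column(Integer, ForeignKey('example_datasources.id'))
    datasource = relationship(
        'ExampleDatasource',
        backref=backref('metrics', cascade='all, delete-orphan'),
        enable_typechecks=False)
    json = Column(Text)

    @property
    def perm(self):
        # BaseMetric.perm raises NotImplementedError; the string below is only
        # a guess at a reasonable permission name, not the established format
        if not self.datasource:
            return None
        return '{obj.datasource.name}.[{obj.metric_name}](id:{obj.id})'.format(
            obj=self)


class ExampleDatasource(Model, BaseDatasource):

    """The queryable object exposed to the explore view"""

    type = 'example'              # likely the key used in ConnectorRegistry.sources ('table', 'druid')
    query_language = 'sql'        # drives query highlighting in the UI
    metric_class = ExampleMetric  # security.py syncs metric permissions through this

    __tablename__ = 'example_datasources'
    id = Column(Integer, primary_key=True)
    datasource_name = Column(String(255), unique=True)
    baselink = 'exampledatasourcemodelview'  # hypothetical FAB view endpoint

    @property
    def name(self):
        return self.datasource_name

A real connector would also need to back the attributes that BaseDatasource.data expects
(filter_select_enabled, metrics_combo, dttm_cols, and so on). Judging from the security.py and
views/core.py changes above, once such a class is registered in ConnectorRegistry.sources,
sync_role_definitions() appears to pick up its datasources and metrics generically, without the
core needing to import the connector's models directly.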