From cf0b670932e11a7d1fbfaaab7d6d6748340d4e80 Mon Sep 17 00:00:00 2001
From: Jeff Niu
Date: Mon, 25 Sep 2017 18:00:46 -0700
Subject: [PATCH] Druid refresh metadata performance improvements (#3527)

* parallelized refresh druid metadata
* fixed code style errors
* fixed code for python3
* added option to only scan for new druid datasources
* Increased code coverage

---
 superset/connectors/druid/models.py | 294 +++++++++++++++++-----------
 superset/connectors/druid/views.py  |  43 ++--
 tests/druid_tests.py                |  14 +-
 3 files changed, 220 insertions(+), 131 deletions(-)

diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index cc40b83f1f1ee..89e1ed90ccc39 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -5,12 +5,13 @@
 from copy import deepcopy
 from datetime import datetime, timedelta
 from six import string_types
+from multiprocessing import Pool
 import requests
 import sqlalchemy as sa
 from sqlalchemy import (
     Column, Integer, String, ForeignKey, Text, Boolean,
-    DateTime,
+    DateTime, or_, and_,
 )
 from sqlalchemy.orm import backref, relationship
 from dateutil.parser import parse as dparse
@@ -39,6 +40,12 @@
 DRUID_TZ = conf.get("DRUID_TZ")
+# Function wrapper because bound methods cannot
+# be passed to processes
+def _fetch_metadata_for(datasource):
+    return datasource.latest_metadata()
+
+
 class JavascriptPostAggregator(Postaggregator):
     def __init__(self, name, field_names, function):
         self.post_aggregator = {
@@ -101,15 +108,99 @@ def get_druid_version(self):
         ).format(obj=self)
         return json.loads(requests.get(endpoint).text)['version']

-    def refresh_datasources(self, datasource_name=None, merge_flag=False):
+    def refresh_datasources(
+            self,
+            datasource_name=None,
+            merge_flag=True,
+            refreshAll=True):
         """Refresh metadata of all datasources in the cluster
         If ``datasource_name`` is specified, only that datasource is updated
         """
         self.druid_version = self.get_druid_version()
-        for datasource in self.get_datasources():
-            if datasource not in conf.get('DRUID_DATA_SOURCE_BLACKLIST', []):
-                if not datasource_name or datasource_name == datasource:
-                    DruidDatasource.sync_to_db(datasource, self, merge_flag)
+        ds_list = self.get_datasources()
+        blacklist = conf.get('DRUID_DATA_SOURCE_BLACKLIST', [])
+        ds_refresh = []
+        if not datasource_name:
+            ds_refresh = list(filter(lambda ds: ds not in blacklist, ds_list))
+        elif datasource_name not in blacklist and datasource_name in ds_list:
+            ds_refresh.append(datasource_name)
+        else:
+            return
+        self.refresh_async(ds_refresh, merge_flag, refreshAll)
+
+    def refresh_async(self, datasource_names, merge_flag, refreshAll):
+        """
+        Fetches metadata for the specified datasources and
+        merges them into the Superset database
+        """
+        session = db.session
+        ds_list = (
+            session.query(DruidDatasource)
+            .filter(or_(DruidDatasource.datasource_name == name
+                    for name in datasource_names))
+        )
+
+        ds_map = {ds.name: ds for ds in ds_list}
+        for ds_name in datasource_names:
+            datasource = ds_map.get(ds_name, None)
+            if not datasource:
+                datasource = DruidDatasource(datasource_name=ds_name)
+                with session.no_autoflush:
+                    session.add(datasource)
+                flasher(
+                    "Adding new datasource [{}]".format(ds_name), 'success')
+                ds_map[ds_name] = datasource
+            elif refreshAll:
+                flasher(
+                    "Refreshing datasource [{}]".format(ds_name), 'info')
+            else:
+                del ds_map[ds_name]
+                continue
+            datasource.cluster = self
+            datasource.merge_flag = merge_flag
+        session.flush()
+
+        # Prepare multiprocess execution
+        pool = Pool()
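+        # pool.map() hands each datasource to a separate worker process through
+        # the module-level _fetch_metadata_for wrapper (bound methods cannot be
+        # pickled); results come back in the same order as ds_refresh.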
+        ds_refresh = list(ds_map.values())
+        metadata = pool.map(_fetch_metadata_for, ds_refresh)
+        pool.close()
+        pool.join()
+
+        for i in range(0, len(ds_refresh)):
+            datasource = ds_refresh[i]
+            cols = metadata[i]
+            col_objs_list = (
+                session.query(DruidColumn)
+                .filter(DruidColumn.datasource_name == datasource.datasource_name)
+                .filter(or_(DruidColumn.column_name == col for col in cols))
+            )
+            col_objs = {col.column_name: col for col in col_objs_list}
+            for col in cols:
+                if col == '__time':  # skip the time column
+                    continue
+                col_obj = col_objs.get(col, None)
+                if not col_obj:
+                    col_obj = DruidColumn(
+                        datasource_name=datasource.datasource_name,
+                        column_name=col)
+                    with session.no_autoflush:
+                        session.add(col_obj)
+                datatype = cols[col]['type']
+                if datatype == 'STRING':
+                    col_obj.groupby = True
+                    col_obj.filterable = True
+                if datatype == 'hyperUnique' or datatype == 'thetaSketch':
+                    col_obj.count_distinct = True
+                # Allow sum/min/max for long or double
+                if datatype == 'LONG' or datatype == 'DOUBLE':
+                    col_obj.sum = True
+                    col_obj.min = True
+                    col_obj.max = True
+                col_obj.type = datatype
+                col_obj.datasource = datasource
+            datasource.generate_metrics_for(col_objs_list)
+        session.commit()

     @property
     def perm(self):
@@ -160,16 +251,14 @@ def dimension_spec(self):
         if self.dimension_spec_json:
             return json.loads(self.dimension_spec_json)

-    def generate_metrics(self):
-        """Generate metrics based on the column metadata"""
-        M = DruidMetric  # noqa
-        metrics = []
-        metrics.append(DruidMetric(
+    def get_metrics(self):
+        metrics = {}
+        metrics['count'] = DruidMetric(
             metric_name='count',
             verbose_name='COUNT(*)',
             metric_type='count',
             json=json.dumps({'type': 'count', 'name': 'count'})
-        ))
+        )
         # Somehow we need to reassign this for UDAFs
         if self.type in ('DOUBLE', 'FLOAT'):
             corrected_type = 'DOUBLE'
@@ -179,49 +268,49 @@ def generate_metrics(self):
         if self.sum and self.is_num:
             mt = corrected_type.lower() + 'Sum'
             name = 'sum__' + self.column_name
-            metrics.append(DruidMetric(
+            metrics[name] = DruidMetric(
                 metric_name=name,
                 metric_type='sum',
                 verbose_name='SUM({})'.format(self.column_name),
                 json=json.dumps({
                     'type': mt, 'name': name, 'fieldName': self.column_name})
-            ))
+            )
         if self.avg and self.is_num:
             mt = corrected_type.lower() + 'Avg'
             name = 'avg__' + self.column_name
-            metrics.append(DruidMetric(
+            metrics[name] = DruidMetric(
                 metric_name=name,
                 metric_type='avg',
                 verbose_name='AVG({})'.format(self.column_name),
                 json=json.dumps({
                     'type': mt, 'name': name, 'fieldName': self.column_name})
-            ))
+            )
         if self.min and self.is_num:
             mt = corrected_type.lower() + 'Min'
             name = 'min__' + self.column_name
-            metrics.append(DruidMetric(
+            metrics[name] = DruidMetric(
                 metric_name=name,
                 metric_type='min',
                 verbose_name='MIN({})'.format(self.column_name),
                 json=json.dumps({
                     'type': mt, 'name': name, 'fieldName': self.column_name})
-            ))
+            )
         if self.max and self.is_num:
             mt = corrected_type.lower() + 'Max'
             name = 'max__' + self.column_name
-            metrics.append(DruidMetric(
+            metrics[name] = DruidMetric(
                 metric_name=name,
                 metric_type='max',
                 verbose_name='MAX({})'.format(self.column_name),
                 json=json.dumps({
                     'type': mt, 'name': name, 'fieldName': self.column_name})
-            ))
+            )
         if self.count_distinct:
             name = 'count_distinct__' + self.column_name
             if self.type == 'hyperUnique' or self.type == 'thetaSketch':
-                metrics.append(DruidMetric(
+                metrics[name] = DruidMetric(
                     metric_name=name,
                     verbose_name='COUNT(DISTINCT {})'.format(self.column_name),
                     metric_type=self.type,
                     json=json.dumps({
                         'type': self.type,
@@ -230,10 +319,9 @@ def generate_metrics(self):
                         'name': name,
                         'fieldName': self.column_name
                     })
-                ))
+                )
             else:
-                mt = 'count_distinct'
-                metrics.append(DruidMetric(
+                metrics[name] = DruidMetric(
                     metric_name=name,
                     verbose_name='COUNT(DISTINCT {})'.format(self.column_name),
                     metric_type='count_distinct',
                     json=json.dumps({
@@ -241,22 +329,25 @@ def generate_metrics(self):
                         'type': 'cardinality',
                         'name': name,
                         'fieldNames': [self.column_name]})
-                ))
-        session = get_session()
-        new_metrics = []
-        for metric in metrics:
-            m = (
-                session.query(M)
-                .filter(M.metric_name == metric.metric_name)
-                .filter(M.datasource_name == self.datasource_name)
-                .filter(DruidCluster.cluster_name == self.datasource.cluster_name)
-                .first()
-            )
+                )
+        return metrics
+
+    def generate_metrics(self):
+        """Generate metrics based on the column metadata"""
+        metrics = self.get_metrics()
+        dbmetrics = (
+            db.session.query(DruidMetric)
+            .filter(DruidCluster.cluster_name == self.datasource.cluster_name)
+            .filter(DruidMetric.datasource_name == self.datasource_name)
+            .filter(or_(
+                DruidMetric.metric_name == m for m in metrics
+            ))
+        )
+        dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
+        for metric in metrics.values():
             metric.datasource_name = self.datasource_name
-            if not m:
-                new_metrics.append(metric)
-                session.add(metric)
-                session.flush()
+            if not dbmetrics.get(metric.metric_name, None):
+                db.session.add(metric)

     @classmethod
     def import_obj(cls, i_column):
@@ -474,6 +565,7 @@ def int_or_0(v):

     def latest_metadata(self):
         """Returns segment metadata from the latest segment"""
+        logging.info("Syncing datasource [{}]".format(self.datasource_name))
         client = self.cluster.get_pydruid_client()
         results = client.time_boundary(datasource=self.datasource_name)
         if not results:
@@ -485,31 +577,33 @@
         # realtime segments, which triggered a bug (fixed in druid 0.8.2).
         # https://groups.google.com/forum/#!topic/druid-user/gVCqqspHqOQ
         lbound = (max_time - timedelta(days=7)).isoformat()
-        rbound = max_time.isoformat()
         if not self.version_higher(self.cluster.druid_version, '0.8.2'):
             rbound = (max_time - timedelta(1)).isoformat()
+        else:
+            rbound = max_time.isoformat()
         segment_metadata = None
         try:
             segment_metadata = client.segment_metadata(
                 datasource=self.datasource_name,
                 intervals=lbound + '/' + rbound,
                 merge=self.merge_flag,
-                analysisTypes=conf.get('DRUID_ANALYSIS_TYPES'))
+                analysisTypes=[])
         except Exception as e:
             logging.warning("Failed first attempt to get latest segment")
             logging.exception(e)
         if not segment_metadata:
             # if no segments in the past 7 days, look at all segments
             lbound = datetime(1901, 1, 1).isoformat()[:10]
-            rbound = datetime(2050, 1, 1).isoformat()[:10]
             if not self.version_higher(self.cluster.druid_version, '0.8.2'):
                 rbound = datetime.now().isoformat()
+            else:
+                rbound = datetime(2050, 1, 1).isoformat()[:10]
             try:
                 segment_metadata = client.segment_metadata(
                     datasource=self.datasource_name,
                     intervals=lbound + '/' + rbound,
                     merge=self.merge_flag,
-                    analysisTypes=conf.get('DRUID_ANALYSIS_TYPES'))
+                    analysisTypes=[])
             except Exception as e:
                 logging.warning("Failed 2nd attempt to get latest segment")
                 logging.exception(e)
@@ -517,17 +611,37 @@
             return segment_metadata[-1]['columns']

     def generate_metrics(self):
-        for col in self.columns:
-            col.generate_metrics()
+        self.generate_metrics_for(self.columns)
+
+    def generate_metrics_for(self, columns):
+        metrics = {}
+        for col in columns:
+            metrics.update(col.get_metrics())
+        dbmetrics = (
+            db.session.query(DruidMetric)
+            .filter(DruidCluster.cluster_name == self.cluster_name)
+            .filter(DruidMetric.datasource_name == self.datasource_name)
+            .filter(or_(DruidMetric.metric_name == m for m in metrics))
+        )
+        dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
+        for metric in metrics.values():
+            metric.datasource_name = self.datasource_name
+            if not dbmetrics.get(metric.metric_name, None):
+                with db.session.no_autoflush:
+                    db.session.add(metric)

     @classmethod
-    def sync_to_db_from_config(cls, druid_config, user, cluster):
+    def sync_to_db_from_config(
+            cls,
+            druid_config,
+            user,
+            cluster,
+            refresh=True):
         """Merges the ds config from druid_config into one stored in the db."""
-        session = db.session()
+        session = db.session
         datasource = (
             session.query(cls)
-            .filter_by(
-                datasource_name=druid_config['name'])
+            .filter_by(datasource_name=druid_config['name'])
             .first()
         )
         # Create a new datasource.
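The hunks above and below replace the old per-column and per-metric .first() lookups with a single query per datasource: every existing row is fetched at once through an or_() filter and keyed by name, so only the missing rows are created. A minimal, runnable sketch of that pattern, using an in-memory SQLite database and a simplified stand-in model (all names here are illustrative, not Superset's actual schema):

    from sqlalchemy import Column, Integer, String, create_engine, or_
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker

    Base = declarative_base()

    class DruidColumn(Base):
        # simplified stand-in for the DruidColumn model in models.py
        __tablename__ = 'druid_columns'
        id = Column(Integer, primary_key=True)
        datasource_name = Column(String(250))
        column_name = Column(String(250))

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    # pretend one column already exists in the metadata database
    session.add(DruidColumn(datasource_name='ds', column_name='existing_col'))
    session.commit()

    dimensions = ['existing_col', 'new_col']
    # one query for every column of interest, instead of one query per column
    col_objs = (
        session.query(DruidColumn)
        .filter(DruidColumn.datasource_name == 'ds')
        .filter(or_(DruidColumn.column_name == dim for dim in dimensions))
    )
    col_objs = {col.column_name: col for col in col_objs}
    for dim in dimensions:
        if dim not in col_objs:  # only the missing column gets created
            session.add(DruidColumn(datasource_name='ds', column_name=dim))
    session.commit()
    print(session.query(DruidColumn).count())  # 2

or_() accepts the generator of comparison clauses here exactly as it does in the patch; the point of the change is one round trip to the database per datasource rather than one per column or metric.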
@@ -540,16 +654,18 @@ def sync_to_db_from_config(cls, druid_config, user, cluster):
                 created_by_fk=user.id,
             )
             session.add(datasource)
+        elif not refresh:
+            return

         dimensions = druid_config['dimensions']
+        col_objs = (
+            session.query(DruidColumn)
+            .filter(DruidColumn.datasource_name == druid_config['name'])
+            .filter(or_(DruidColumn.column_name == dim for dim in dimensions))
+        )
+        col_objs = {col.column_name: col for col in col_objs}
         for dim in dimensions:
-            col_obj = (
-                session.query(DruidColumn)
-                .filter_by(
-                    datasource_name=druid_config['name'],
-                    column_name=dim)
-                .first()
-            )
+            col_obj = col_objs.get(dim, None)
             if not col_obj:
                 col_obj = DruidColumn(
                     datasource_name=druid_config['name'],
@@ -562,6 +678,13 @@ def sync_to_db_from_config(cls, druid_config, user, cluster):
                 )
                 session.add(col_obj)
         # Import Druid metrics
+        metric_objs = (
+            session.query(DruidMetric)
+            .filter(DruidMetric.datasource_name == druid_config['name'])
+            .filter(or_(DruidMetric.metric_name == spec['name']
+                    for spec in druid_config["metrics_spec"]))
+        )
+        metric_objs = {metric.metric_name: metric for metric in metric_objs}
         for metric_spec in druid_config["metrics_spec"]:
             metric_name = metric_spec["name"]
             metric_type = metric_spec["type"]
@@ -575,12 +698,7 @@ def sync_to_db_from_config(cls, druid_config, user, cluster):
                     "fieldName": metric_name,
                 })

-            metric_obj = (
-                session.query(DruidMetric)
-                .filter_by(
-                    datasource_name=druid_config['name'],
-                    metric_name=metric_name)
-            ).first()
+            metric_obj = metric_objs.get(metric_name, None)
             if not metric_obj:
                 metric_obj = DruidMetric(
                     metric_name=metric_name,
@@ -595,58 +713,6 @@ def sync_to_db_from_config(cls, druid_config, user, cluster):
                 session.add(metric_obj)
         session.commit()

-    @classmethod
-    def sync_to_db(cls, name, cluster, merge):
-        """Fetches metadata for that datasource and merges the Superset db"""
-        logging.info("Syncing Druid datasource [{}]".format(name))
-        session = get_session()
-        datasource = session.query(cls).filter_by(datasource_name=name).first()
-        if not datasource:
-            datasource = cls(datasource_name=name)
-            session.add(datasource)
-            flasher("Adding new datasource [{}]".format(name), "success")
-        else:
-            flasher("Refreshing datasource [{}]".format(name), "info")
-        session.flush()
-        datasource.cluster = cluster
-        datasource.merge_flag = merge
-        session.flush()
-
-        cols = datasource.latest_metadata()
-        if not cols:
-            logging.error("Failed at fetching the latest segment")
-            return
-        for col in cols:
-            # Skip the time column
-            if col == "__time":
-                continue
-            col_obj = (
-                session
-                .query(DruidColumn)
-                .filter_by(datasource_name=name, column_name=col)
-                .first()
-            )
-            datatype = cols[col]['type']
-            if not col_obj:
-                col_obj = DruidColumn(datasource_name=name, column_name=col)
-                session.add(col_obj)
-            if datatype == "STRING":
-                col_obj.groupby = True
-                col_obj.filterable = True
-            if datatype == "hyperUnique" or datatype == "thetaSketch":
-                col_obj.count_distinct = True
-            # If long or double, allow sum/min/max
-            if datatype == "LONG" or datatype == "DOUBLE":
-                col_obj.sum = True
-                col_obj.min = True
-                col_obj.max = True
-            if col_obj:
-                col_obj.type = cols[col]['type']
-                session.flush()
-            col_obj.datasource = datasource
-            col_obj.generate_metrics()
-            session.flush()
-
     @staticmethod
     def time_offset(granularity):
         if granularity == 'week_ending_saturday':
diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py
index f64b6c185009d..42fbdbbe987b0 100644
--- a/superset/connectors/druid/views.py
+++ b/superset/connectors/druid/views.py
@@ -235,17 +235,17 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin):  # noqa
     }

     def pre_add(self, datasource):
-        number_of_existing_datasources = db.session.query(
-            sqla.func.count('*')).filter(
-            models.DruidDatasource.datasource_name ==
-            datasource.datasource_name,
-            models.DruidDatasource.cluster_name == datasource.cluster.id
-        ).scalar()
-
-        # table object is already added to the session
-        if number_of_existing_datasources > 1:
-            raise Exception(get_datasource_exist_error_mgs(
-                datasource.full_name))
+        with db.session.no_autoflush:
+            query = (
+                db.session.query(models.DruidDatasource)
+                .filter(models.DruidDatasource.datasource_name ==
+                        datasource.datasource_name,
+                        models.DruidDatasource.cluster_name ==
+                        datasource.cluster.id)
+            )
+            if db.session.query(query.exists()).scalar():
+                raise Exception(get_datasource_exist_error_mgs(
+                    datasource.full_name))

     def post_add(self, datasource):
         datasource.generate_metrics()
@@ -273,14 +273,14 @@ class Druid(BaseSupersetView):

     @has_access
     @expose("/refresh_datasources/")
-    def refresh_datasources(self):
+    def refresh_datasources(self, refreshAll=True):
         """endpoint that refreshes druid datasources metadata"""
         session = db.session()
         DruidCluster = ConnectorRegistry.sources['druid'].cluster_class
         for cluster in session.query(DruidCluster).all():
             cluster_name = cluster.cluster_name
             try:
-                cluster.refresh_datasources()
+                cluster.refresh_datasources(refreshAll=refreshAll)
             except Exception as e:
                 flash(
                     "Error while processing cluster '{}'\n{}".format(
@@ -296,8 +296,25 @@ def refresh_datasources(self):
         session.commit()
         return redirect("/druiddatasourcemodelview/list/")

+    @has_access
+    @expose("/scan_new_datasources/")
+    def scan_new_datasources(self):
+        """
+        Calling this endpoint will cause a scan for new
+        datasources only and add them.
+        """
+        return self.refresh_datasources(refreshAll=False)
+
 appbuilder.add_view_no_menu(Druid)

+appbuilder.add_link(
+    "Scan New Datasources",
+    label=__("Scan New Datasources"),
+    href='/druid/scan_new_datasources/',
+    category='Sources',
+    category_label=__("Sources"),
+    category_icon='fa-database',
+    icon="fa-refresh")
 appbuilder.add_link(
     "Refresh Druid Metadata",
     label=__("Refresh Druid Metadata"),
diff --git a/tests/druid_tests.py b/tests/druid_tests.py
index 637afe984ce02..c506ebf596edb 100644
--- a/tests/druid_tests.py
+++ b/tests/druid_tests.py
@@ -16,6 +16,9 @@
 from .base_tests import SupersetTestCase

+class PickableMock(Mock):
+    def __reduce__(self):
+        return (Mock, ())

 SEGMENT_METADATA = [{
   "id": "some_id",
@@ -98,8 +101,8 @@ def test_client(self, PyDruid):
             metadata_last_refreshed=datetime.now())

         db.session.add(cluster)
-        cluster.get_datasources = Mock(return_value=['test_datasource'])
-        cluster.get_druid_version = Mock(return_value='0.9.1')
+        cluster.get_datasources = PickableMock(return_value=['test_datasource'])
+        cluster.get_druid_version = PickableMock(return_value='0.9.1')
         cluster.refresh_datasources()
         cluster.refresh_datasources(merge_flag=True)
         datasource_id = cluster.datasources[0].id
@@ -303,11 +306,14 @@ def test_sync_druid_perm(self, PyDruid):
             metadata_last_refreshed=datetime.now())

         db.session.add(cluster)
-        cluster.get_datasources = Mock(return_value=['test_datasource'])
-        cluster.get_druid_version = Mock(return_value='0.9.1')
+        cluster.get_datasources = PickableMock(return_value=['test_datasource'])
+        cluster.get_druid_version = PickableMock(return_value='0.9.1')
         cluster.refresh_datasources()
         datasource_id = cluster.datasources[0].id
+        cluster.datasources[0].merge_flag = True
+        metadata = cluster.datasources[0].latest_metadata()
+        self.assertEqual(len(metadata), 4)
         db.session.commit()

         view_menu_name = cluster.datasources[0].get_perm()
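A note on PickableMock above: multiprocessing pickles every object it sends to a worker process, and plain Mock instances cannot be pickled, so the test class overrides __reduce__ to reduce to a fresh Mock built on the receiving side. A minimal sketch of the idea (standard library only; the exact exception raised when pickling a plain Mock varies by Python version):

    import pickle

    try:
        from unittest.mock import Mock  # Python 3
    except ImportError:
        from mock import Mock           # Python 2

    class PickableMock(Mock):
        def __reduce__(self):
            # pickle as "call Mock() again" instead of serializing internal state
            return (Mock, ())

    try:
        pickle.dumps(Mock())
    except Exception as exc:
        print("plain Mock fails to pickle: {}".format(type(exc).__name__))

    restored = pickle.loads(pickle.dumps(PickableMock()))
    print(type(restored))  # a plain Mock, rebuilt wherever it is unpickled

The trade-off is that any configured behavior (such as return_value) is lost in the copy, which is acceptable here because only the parent process calls the mocked cluster methods.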