diff --git a/superset/assets/javascripts/explore/stores/visTypes.js b/superset/assets/javascripts/explore/stores/visTypes.js index d1807aeb9db85..a8819d331ade4 100644 --- a/superset/assets/javascripts/explore/stores/visTypes.js +++ b/superset/assets/javascripts/explore/stores/visTypes.js @@ -1147,6 +1147,7 @@ export const visTypes = { controlSetRows: [ ['all_columns_x'], ['row_limit'], + ['groupby'], ], }, { @@ -1156,20 +1157,26 @@ export const visTypes = { ['color_scheme'], ['link_length'], ['x_axis_label', 'y_axis_label'], + ['global_opacity'], ['normalized'], ], }, ], controlOverrides: { all_columns_x: { - label: t('Numeric Column'), - description: t('Select the numeric column to draw the histogram'), + label: t('Numeric Columns'), + description: t('Select the numeric columns to draw the histogram'), + multi: true, }, link_length: { label: t('No of Bins'), description: t('Select number of bins for the histogram'), default: 5, }, + global_opacity: { + description: t('Opacity of the bars. Between 0 and 1'), + renderTrigger: true, + }, }, }, diff --git a/superset/assets/visualizations/histogram.js b/superset/assets/visualizations/histogram.js index 2f292311102c9..8ab9c1188f91d 100644 --- a/superset/assets/visualizations/histogram.js +++ b/superset/assets/visualizations/histogram.js @@ -1,4 +1,5 @@ import d3 from 'd3'; +import nv from 'nvd3'; import { getColorFromScheme } from '../javascripts/modules/colors'; require('./histogram.css'); @@ -10,6 +11,7 @@ function histogram(slice, payload) { const normalized = slice.formData.normalized; const xAxisLabel = slice.formData.x_axis_label; const yAxisLabel = slice.formData.y_axis_label; + const opacity = slice.formData.global_opacity; const draw = function () { // Set Margins @@ -39,18 +41,28 @@ function histogram(slice, payload) { .scale(y) .orient('left') .ticks(numTicks, 's'); - // Calculate bins for the data - let bins = d3.layout.histogram().bins(numBins)(data); - if (normalized) { - const total = data.length; - bins = bins.map(d => ({ ...d, y: d.y / total })); - } // Set the x-values - const max = d3.max(data); - const min = d3.min(data); + const max = d3.max(data, d => d3.max(d.values)); + const min = d3.min(data, d => d3.min(d.values)); x.domain([min, max]) .range([0, width], 0.1); + + // Calculate bins for the data + let bins = []; + data.forEach((d) => { + let b = d3.layout.histogram().bins(numBins)(d.values); + const color = getColorFromScheme(d.key, slice.formData.color_scheme); + const w = d3.max([(x(b[0].dx) - x(0)) - 1, 0]); + const key = d.key; + // normalize if necessary + if (normalized) { + const total = d.values.length; + b = b.map(v => ({ ...v, y: v.y / total })); + } + bins = bins.concat(b.map(v => ({ ...v, color, width: w, key, opacity }))); + }); + // Set the y-values y.domain([0, d3.max(bins, d => d.y)]) .range([height, 0]); @@ -80,17 +92,38 @@ function histogram(slice, payload) { svg.attr('width', slice.width()) .attr('height', slice.height()); - // Create the bars in the svg - const bar = svg.select('.bars').selectAll('.bar').data(bins); - bar.enter().append('rect'); - bar.exit().remove(); - // Set the Height and Width for each bar - bar.attr('width', (x(bins[0].dx) - x(0)) - 1) - .attr('x', d => x(d.x)) - .attr('y', d => y(d.y)) - .attr('height', d => y.range()[0] - y(d.y)) - .style('fill', getColorFromScheme(1, slice.formData.color_scheme)) - .order(); + // make legend + const legend = nv.models.legend() + .color(d => getColorFromScheme(d.key, slice.formData.color_scheme)) + .width(width); + const gLegend = gEnter.append('g').attr('class', 'nv-legendWrap') + .attr('transform', 'translate(0,' + (-margin.top) + ')') + .datum(data.map(d => ({ ...d, disabled: false }))); + + // function to draw bars and legends + function update(selectedBins) { + // Create the bars in the svg + const bar = svg.select('.bars') + .selectAll('rect') + .data(selectedBins, d => d.key + d.x); + // Set the Height and Width for each bar + bar.enter() + .append('rect') + .attr('width', d => d.width) + .attr('x', d => x(d.x)) + .style('fill', d => d.color) + .style('fill-opacity', d => d.opacity) + .attr('y', d => y(d.y)) + .attr('height', d => y.range()[0] - y(d.y)); + bar.exit() + .attr('y', y(0)) + .attr('height', 0) + .remove(); + // apply legend + gLegend.call(legend); + } + + update(bins); // Update the x-axis svg.append('g') @@ -109,6 +142,14 @@ function histogram(slice, payload) { .filter(function (d) { return d; }) .classed('minor', true); + // set callback on legend toggle + legend.dispatch.on('stateChange', function (newState) { + const activeKeys = data + .filter((d, i) => !newState.disabled[i]) + .map(d => d.key); + update(bins.filter(d => activeKeys.indexOf(d.key) >= 0)); + }); + // add axis labels if passed if (xAxisLabel) { svg.append('text') diff --git a/superset/viz.py b/superset/viz.py index 87dd1a9ebe01c..f27406b758a7a 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -17,6 +17,7 @@ from itertools import product import logging import math +import re import traceback import uuid @@ -1364,15 +1365,31 @@ def query_obj(self): d = super(HistogramViz, self).query_obj() d['row_limit'] = self.form_data.get( 'row_limit', int(config.get('VIZ_ROW_LIMIT'))) - numeric_column = self.form_data.get('all_columns_x') - if numeric_column is None: - raise Exception(_('Must have one numeric column specified')) - d['columns'] = [numeric_column] + numeric_columns = self.form_data.get('all_columns_x') + if numeric_columns is None: + raise Exception(_('Must have at least one numeric column specified')) + self.columns = numeric_columns + d['columns'] = numeric_columns + self.groupby + # override groupby entry to avoid aggregation + d['groupby'] = [] return d def get_data(self, df): """Returns the chart data""" - chart_data = df[df.columns[0]].values.tolist() + chart_data = [] + if len(self.groupby) > 0: + groups = df.groupby(self.groupby) + else: + groups = [((), df)] + for keys, data in groups: + if isinstance(keys, str): + keys = (keys,) + # removing undesirable characters + keys = [re.sub(r'\W+', r'_', k) for k in keys] + chart_data.extend([{ + 'key': '__'.join([c] + keys), + 'values': data[c].tolist()} + for c in self.columns]) return chart_data