Skip to content

Commit

Permalink
fix: support non-string groupbys for pie chart (#10493)
Browse files Browse the repository at this point in the history
* chore: add unit tests to pie chart

* refine logic for floats and nans and add more tests
  • Loading branch information
villebro authored Jul 31, 2020
1 parent 7645fc8 commit 9d9c348
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 1 deletion.
34 changes: 33 additions & 1 deletion superset/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,11 +1547,43 @@ class DistributionPieViz(NVD3Viz):
is_timeseries = False

def get_data(self, df: pd.DataFrame) -> VizData:
    """
    Build the pie chart records from a query result.

    The values of the groupby columns of each row are collapsed into a
    single label (``x``) and paired with the value of the first metric
    (``y``). Records are ordered by ``y``, largest first.

    :param df: query result holding the groupby columns and the metric
    :returns: list of ``{"x": label, "y": value}`` records, or ``None``
        when ``df`` is empty
    """

    def _label_aggfunc(labels: pd.Series) -> str:
        """
        Convert a single or multi column label into a single label, replacing
        null values with `NULL_STRING` and joining multiple columns together
        with a comma. Examples:
        >>> _label_aggfunc(pd.Series(["abc"]))
        'abc'
        >>> _label_aggfunc(pd.Series([1]))
        '1'
        >>> _label_aggfunc(pd.Series(["abc", "def"]))
        'abc, def'
        >>> # note: integer floats are stripped of decimal digits
        >>> _label_aggfunc(pd.Series([0.1, 2.0, 0.3]))
        '0.1, 2, 0.3'
        >>> _label_aggfunc(pd.Series([1, None, "abc", 0.8], dtype="object"))
        '1, <NULL>, abc, 0.8'
        """
        label_list: List[str] = []
        for label in labels:
            if isinstance(label, str):
                label_recast = label
            elif pd.api.types.is_scalar(label) and pd.isna(label):
                # Any scalar null marker (None, NaN, NaT, pd.NA) maps to the
                # null label. The scalar guard keeps list-like cell values
                # away from `pd.isna`, which would return an array for them.
                label_recast = NULL_STRING
            elif isinstance(label, float) and label.is_integer():
                # Render 2.0 as "2" so integer-valued floats read as integers.
                label_recast = str(int(label))
            else:
                label_recast = str(label)
            label_list.append(label_recast)

        return ", ".join(label_list)

    if df.empty:
        return None
    metric = self.metric_labels[0]
    # Plain `", ".join` only works for string columns; `_label_aggfunc`
    # additionally handles numeric, boolean and null groupby values.
    df = pd.DataFrame(
        {"x": df[self.groupby].agg(func=_label_aggfunc, axis=1), "y": df[metric]}
    )
    df.sort_values(by="y", ascending=False, inplace=True)
    return df.to_dict(orient="records")
Expand Down
58 changes: 58 additions & 0 deletions tests/viz_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import logging
from math import nan
from unittest.mock import Mock, patch
from typing import Any, Dict, List, Set

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1330,3 +1331,60 @@ def test_get_aggfunc_non_numeric(self):
viz.PivotTableViz.get_aggfunc("strcol", self.df, {"pandas_aggfunc": "min"})
== "min"
)


class TestDistributionPieViz(SupersetTestCase):
    """Label generation of ``DistributionPieViz`` for various groupby dtypes."""

    # Shared fixture; the last row of each column (fourth for "strcol") is null.
    base_df = pd.DataFrame(
        data={
            "intcol": [1, 2, 3, 4, None],
            "floatcol": [1.0, 0.2, 0.3, 0.4, None],
            "strcol_a": ["a", "a", "a", "a", None],
            "strcol": ["a", "b", "c", None, "d"],
        }
    )

    @staticmethod
    def get_cols(data: List[Dict[str, Any]]) -> Set[str]:
        """Return the distinct ``x`` labels present in the chart records."""
        return {record["x"] for record in data}

    def test_bool_groupby(self):
        """Booleans are rendered via ``str`` and nulls as the null label."""
        datasource = self.get_datasource_mock()
        frame = pd.DataFrame(
            data={"intcol": [1, 2, None], "boolcol": [True, None, False]}
        )
        form_data = {"metrics": ["intcol"], "groupby": ["boolcol"]}

        records = viz.DistributionPieViz(datasource, form_data).get_data(frame)

        expected = {"True", "False", "<NULL>"}
        assert self.get_cols(records) == expected

    def test_string_groupby(self):
        """String labels pass through unchanged; nulls become the null label."""
        datasource = self.get_datasource_mock()
        form_data = {"metrics": ["floatcol"], "groupby": ["strcol"]}

        records = viz.DistributionPieViz(datasource, form_data).get_data(
            self.base_df
        )

        expected = {"<NULL>", "a", "b", "c", "d"}
        assert self.get_cols(records) == expected

    def test_int_groupby(self):
        """Integer-valued labels are rendered without decimal digits."""
        datasource = self.get_datasource_mock()
        form_data = {"metrics": ["floatcol"], "groupby": ["intcol"]}

        records = viz.DistributionPieViz(datasource, form_data).get_data(
            self.base_df
        )

        expected = {"<NULL>", "1", "2", "3", "4"}
        assert self.get_cols(records) == expected

    def test_float_groupby(self):
        """Fractional floats keep their digits; ``1.0`` collapses to ``"1"``."""
        datasource = self.get_datasource_mock()
        form_data = {"metrics": ["intcol"], "groupby": ["floatcol"]}

        records = viz.DistributionPieViz(datasource, form_data).get_data(
            self.base_df
        )

        expected = {"<NULL>", "1", "0.2", "0.3", "0.4"}
        assert self.get_cols(records) == expected

    def test_multi_groupby(self):
        """Multiple groupby columns are joined into one label with ``", "``."""
        datasource = self.get_datasource_mock()
        form_data = {"metrics": ["floatcol"], "groupby": ["intcol", "strcol"]}

        records = viz.DistributionPieViz(datasource, form_data).get_data(
            self.base_df
        )

        expected = {"1, a", "2, b", "3, c", "4, <NULL>", "<NULL>, d"}
        assert self.get_cols(records) == expected

0 comments on commit 9d9c348

Please sign in to comment.