Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Histogram breaks #767

Merged
merged 7 commits into from
Aug 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# CHANGELOG

## 0.9.1dev
* [Feature] Added `--breaks/-B` to ggplot histogram for specifying breaks (#719)
* [Fix] Fix boxplot for duckdb native ([#728](https://github.com/ploomber/jupysql/issues/728))
* [Doc] Add Redshift tutorial
* [Feature] Adds Redshift support for `%sqlplot boxplot`
Expand Down
10 changes: 10 additions & 0 deletions doc/api/magic-plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ Shortcut: `%sqlplot hist`

`-b`/`--bins` (default: `50`) Number of bins

`-B`/`--breaks` Custom bin intervals

`-w`/`--with` Use a previously saved query as input data

+++
Expand All @@ -151,6 +153,14 @@ When plotting a histogram, it divides a range with the number of bins - 1 to cal
%sqlplot histogram --table penguins.csv --column body_mass_g --bins 100
```

### Specifying breaks

edublancas marked this conversation as resolved.
Show resolved Hide resolved
Breaks allow you to set custom intervals for a histogram. You can specify breaks by passing the desired end points and break points, separated by whitespace, after `-B/--breaks`. Since these break points define the range of data points to plot, the bar widths, and the number of bars in the histogram, make sure to pass more than one point, that the points are strictly increasing, and that the range they cover includes at least one data point. Note that using both `-b/--bins` and `-B/--breaks` isn't allowed.

```{code-cell} ipython3
%sqlplot histogram --table penguins.csv --column body_mass_g --breaks 3200 3400 3600 3800 4000 4200 4400 4600 4800
```

### Multiple columns

```{code-cell} ipython3
Expand Down
7 changes: 6 additions & 1 deletion src/sql/ggplot/geom/geom_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@ class geom_histogram(geom):

cmap : str, default 'viridis'
Apply a color map to the stacked graph

breaks : list
Divide bins with custom intervals
"""

def __init__(self, bins=None, fill=None, cmap=None, **kwargs):
def __init__(self, bins=None, fill=None, cmap=None, breaks=None, **kwargs):
edublancas marked this conversation as resolved.
Show resolved Hide resolved
self.bins = bins
self.fill = fill
self.cmap = cmap
self.breaks = breaks
super().__init__(**kwargs)

@telemetry.log_call("ggplot-histogram")
Expand All @@ -40,5 +44,6 @@ def draw(self, gg, ax=None, facet=None):
edgecolor=gg.mapping.color,
facet=facet,
ax=ax or gg.axs[0],
breaks=self.breaks,
)
return gg
18 changes: 17 additions & 1 deletion src/sql/magic_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ class SqlPlotMagic(Magics, Configurable):
action="store_true",
help="Show number of observations",
)
@argument(
"-B",
"--breaks",
type=float,
nargs="+",
help="Histogram breaks",
)
@modify_exceptions
def execute(self, line="", cell="", local_ns=None):
"""
Expand Down Expand Up @@ -103,12 +110,21 @@ def execute(self, line="", cell="", local_ns=None):
conn=None,
)
elif cmd.args.line[0] in {"hist", "histogram"}:
# to avoid passing bins default value when breaks are given by a user
bin_specified = " --bins " in line or " -b " in line
edublancas marked this conversation as resolved.
Show resolved Hide resolved
breaks_specified = " --breaks " in line or " -B " in line
if breaks_specified and not bin_specified:
bins = None
else:
bins = cmd.args.bins

return plot.histogram(
table=table,
column=column,
bins=cmd.args.bins,
bins=bins,
with_=with_,
conn=None,
breaks=cmd.args.breaks,
)
elif cmd.args.line[0] in {"bar"}:
return plot.bar(
Expand Down
175 changes: 137 additions & 38 deletions src/sql/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def _are_numeric_values(*values):
def _get_bar_width(ax, bins, bin_size):
"""
Return a single bar width based on number of bins
or a list of bar widths if `breaks` is given.
If bins values are str, calculate value based on figure size.

Parameters
Expand All @@ -281,16 +282,16 @@ def _get_bar_width(ax, bins, bin_size):
bins : tuple
Contains bins' midpoints as float

bin_size : int or None
bin_size : int or list or None
Calculated bin_size from the _histogram function
or from consecutive differences in `breaks`

Returns
-------
width : float or list
A single bar width, or a list of bar widths when `bin_size` is a list
"""

if _are_numeric_values(bin_size):
if _are_numeric_values(bin_size) or isinstance(bin_size, list):
width = bin_size
else:
fig = plt.gcf()
Expand All @@ -316,6 +317,7 @@ def histogram(
edgecolor=None,
ax=None,
facet=None,
breaks=None,
):
"""Plot histogram

Expand Down Expand Up @@ -358,6 +360,21 @@ def histogram(
"""
if not conn:
conn = sql.connection.ConnectionManager.current
if isinstance(breaks, list):
if len(breaks) < 2:
raise exceptions.ValueError(
f"Breaks given : {breaks}. When using breaks, please ensure "
neelasha23 marked this conversation as resolved.
Show resolved Hide resolved
"to specify at least two points."
)
if not all([b2 > b1 for b1, b2 in zip(breaks[:-1], breaks[1:])]):
raise exceptions.ValueError(
f"Breaks given : {breaks}. When using breaks, please ensure that "
"breaks are strictly increasing."
neelasha23 marked this conversation as resolved.
Show resolved Hide resolved
)
if bins:
raise exceptions.ValueError(
"Both bins and breaks are specified. Must specify only one of them."
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error repeats itself. Can we delete the error handler from here?

Also, please change ValueError to exceptions.ValueError

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deleted the error handler in magic_plot and changed ValueError to exceptions.ValueError


ax = ax or plt.gca()
payload["connection_info"] = conn._get_database_information()
Expand All @@ -375,10 +392,20 @@ def histogram(
if column is None or len(column) == 0:
raise ValueError("Column name has not been specified")

bin_, height, bin_size = _histogram(table, column, bins, with_=with_, conn=conn)
bin_, height, bin_size = _histogram(
table, column, bins, with_=with_, conn=conn, breaks=breaks
)
width = _get_bar_width(ax, bin_, bin_size)
data = _histogram_stacked(
table, column, category, bin_, bin_size, with_=with_, conn=conn, facet=facet
table,
column,
category,
bin_,
bin_size,
with_=with_,
conn=conn,
facet=facet,
breaks=breaks,
)
cmap = plt.get_cmap(cmap or "viridis")
norm = Normalize(vmin=0, vmax=len(data))
Expand Down Expand Up @@ -422,7 +449,7 @@ def histogram(
ax.legend(handles[::-1], labels[::-1])
elif isinstance(column, str):
bin_, height, bin_size = _histogram(
table, column, bins, with_=with_, conn=conn, facet=facet
table, column, bins, with_=with_, conn=conn, facet=facet, breaks=breaks
)
width = _get_bar_width(ax, bin_, bin_size)

Expand All @@ -439,9 +466,13 @@ def histogram(
ax.set_xlabel(column)

else:
if breaks and len(column) > 1:
raise exceptions.UsageError(
"Multiple columns don't support breaks. Please use bins instead."
)
for i, col in enumerate(column):
bin_, height, bin_size = _histogram(
table, col, bins, with_=with_, conn=conn, facet=facet
table, col, bins, with_=with_, conn=conn, facet=facet, breaks=breaks
)
width = _get_bar_width(ax, bin_, bin_size)

Expand Down Expand Up @@ -474,7 +505,7 @@ def histogram(


@modify_exceptions
def _histogram(table, column, bins, with_=None, conn=None, facet=None):
def _histogram(table, column, bins, with_=None, conn=None, facet=None, breaks=None):
"""Compute bins and heights"""
if not conn:
conn = sql.connection.ConnectionManager.current
Expand All @@ -493,33 +524,85 @@ def _histogram(table, column, bins, with_=None, conn=None, facet=None):
bin_size = None

if _are_numeric_values(min_, max_):
if not isinstance(bins, int):
if breaks:
if min_ > breaks[-1]:
raise exceptions.UsageError(
f"All break points are lower than the min data point of {min_}."
)
elif max_ < breaks[0]:
raise exceptions.UsageError(
f"All break points are higher than the max data point of {max_}."
)

cases, bin_size = [], []
for b_start, b_end in zip(breaks[:-1], breaks[1:]):
case = f"WHEN {{{{column}}}} > {b_start} AND {{{{column}}}} <= {b_end} \
THEN {(b_start+b_end)/2}"
cases.append(case)
bin_size.append(b_end - b_start)
cases[0] = cases[0].replace(">", ">=", 1)
bin_midpoints = [
(b_start + b_end) / 2 for b_start, b_end in zip(breaks[:-1], breaks[1:])
]
all_bins = " union ".join([f"select {mid} as bin" for mid in bin_midpoints])

# Group data based on the intervals in breaks
# Left join is used to ensure count=0
template_ = (
"select all_bins.bin, coalesce(count_table.count, 0) as count "
f"from ({all_bins}) as all_bins "
"left join ("
f"select case {' '.join(cases)} end as bin, "
"count(*) as count "
'from "{{table}}" '
"{{filter_query}} "
"group by bin) "
"as count_table on all_bins.bin = count_table.bin "
"order by all_bins.bin;"
)

breaks_filter_query = (
f'"{column}" >= {breaks[0]} and "{column}" <= {breaks[-1]}'
)
filter_query = _filter_aggregate(
filter_query_1, filter_query_2, breaks_filter_query
)

if use_backticks:
template_ = template_.replace('"', "`")

template = Template(template_)

query = template.render(
table=table, column=column, filter_query=filter_query
)
elif not isinstance(bins, int):
raise ValueError(
f"bins are '{bins}'. Please specify a valid number of bins."
)
else:
# Use bins - 1 instead of bins and round half down instead of floor
# to mimic right-closed histogram intervals in R ggplot
range_ = max_ - min_
bin_size = range_ / (bins - 1)
template_ = """
select
ceiling("{{column}}"/{{bin_size}} - 0.5)*{{bin_size}} as bin,
count(*) as count
from "{{table}}"
{{filter_query}}
group by bin
order by bin;
"""

# Use bins - 1 instead of bins and round half down instead of floor
# to mimic right-closed histogram intervals in R ggplot
range_ = max_ - min_
bin_size = range_ / (bins - 1)
template_ = """
select
ceiling("{{column}}"/{{bin_size}} - 0.5)*{{bin_size}} as bin,
count(*) as count
from "{{table}}"
{{filter_query}}
group by bin
order by bin;
"""

if use_backticks:
template_ = template_.replace('"', "`")
if use_backticks:
template_ = template_.replace('"', "`")

template = Template(template_)
template = Template(template_)

query = template.render(
table=table, column=column, bin_size=bin_size, filter_query=filter_query
)
query = template.render(
table=table, column=column, bin_size=bin_size, filter_query=filter_query
)
else:
template_ = """
select
Expand Down Expand Up @@ -554,29 +637,45 @@ def _histogram_stacked(
with_=None,
conn=None,
facet=None,
breaks=None,
):
"""Compute the corresponding heights of each bin based on the category"""
if not conn:
conn = sql.connection.ConnectionManager.current

cases = []
tolerance = bin_size / 1000 # Use to avoid floating point error
for bin in bins:
# Use round half down instead of floor to mimic
# right-closed histogram intervals in R ggplot
case = (
f"SUM(CASE WHEN ABS(CEILING({column}/{bin_size} - 0.5)*{bin_size} "
f"- {bin}) <= {tolerance} THEN 1 ELSE 0 END) AS '{bin}',"
if breaks:
breaks_filter_query = (
f'"{column}" >= {breaks[0]} and "{column}" <= {breaks[-1]}'
)
cases.append(case)
for b_start, b_end in zip(breaks[:-1], breaks[1:]):
case = f'SUM(CASE WHEN {column} > {b_start} AND {column} <= {b_end} \
THEN 1 ELSE 0 END) AS "{(b_start+b_end)/2}",'
cases.append(case)
cases[0] = cases[0].replace(">", ">=", 1)
else:
tolerance = bin_size / 1000 # Use to avoid floating point error
for bin in bins:
# Use round half down instead of floor to mimic
# right-closed histogram intervals in R ggplot
case = (
f"SUM(CASE WHEN ABS(CEILING({column}/{bin_size} - 0.5)*{bin_size} "
f"- {bin}) <= {tolerance} THEN 1 ELSE 0 END) AS '{bin}',"
)
cases.append(case)

cases = " ".join(cases)

filter_query_1 = f'"{column}" IS NOT NULL'

filter_query_2 = f"{facet['key']} == '{facet['value']}'" if facet else None

filter_query = _filter_aggregate(filter_query_1, filter_query_2)
if breaks:
filter_query = _filter_aggregate(
filter_query_1, filter_query_2, breaks_filter_query
)
else:
filter_query = _filter_aggregate(filter_query_1, filter_query_2)

template = Template(
"""
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions src/tests/integration/test_generic_db_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,16 @@ def test_telemetry_execute_command_has_connection_info(
"%sqlplot histogram --with plot_something_subset --table\
plot_something_subset --column x --bins 10"
),
(
"%sqlplot histogram --with plot_something_subset --table\
plot_something_subset --column x --breaks 0 2 3 4 5"
),
],
ids=[
"histogram",
"hist",
"histogram-bins",
"histogram-breaks",
],
)
@pytest.mark.parametrize(
Expand Down
Loading
Loading