diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f6950357..d110f755b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## 0.10.0dev +* [Feature] Add `--binwidth/-W` to ggplot histogram for specifying binwidth (#784) * [Feature] Add `%sqlcmd profile` support for DBAPI connections (#743) * [Fix] Perform `ROLLBACK` when SQLAlchemy raises `PendingRollbackError` * [Fix] Perform `ROLLBACK` when `psycopg2` raises `current transaction is aborted, commands ignored until end of transaction block` diff --git a/doc/api/magic-plot.md b/doc/api/magic-plot.md index 7305d7331..63d9e0f95 100644 --- a/doc/api/magic-plot.md +++ b/doc/api/magic-plot.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.14.7 + jupytext_version: 1.15.0 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -128,8 +128,14 @@ Shortcut: `%sqlplot hist` `-B`/`--breaks` Custom bin intervals +`-W`/`--binwidth` Width of each bin + `-w`/`--with` Use a previously saved query as input data +```{note} +When using -b/--bins, -B/--breaks, or -W/--binwidth, you can only specify one of them. If none of them is specified, the default value for -b/--bins will be used. +``` + +++ Histogram supports NULL values by skipping them. Now we can @@ -147,7 +153,9 @@ When plotting a histogram, it divides a range with the number of bins - 1 to cal +++ -### Number of bins +### Specifying bins + +Bins allow you to set the number of bins in a histogram, and it's useful when you are interested in the overall distribution. ```{code-cell} ipython3 %sqlplot histogram --table penguins.csv --column body_mass_g --bins 100 @@ -155,12 +163,20 @@ When plotting a histogram, it divides a range with the number of bins - 1 to cal ### Specifying breaks -Breaks allow you to set custom intervals for a histogram. You can specify breaks by passing desired each end and break points separated by whitespace after `-B/--breaks`. 
Since those break points define a range of data points to plot, bar width, and number of bars in a histogram, make sure to pass more than 1 point that is strictly increasing and includes at least one data point. Note that using both `-b/--bins` and `-B/--breaks` isn't allowed. +Breaks allow you to set custom intervals for a histogram. They are useful when you want to view the distribution within a specific range. You can specify breaks by passing the desired end points and break points, separated by whitespace, after `-B/--breaks`. Since those break points define the range of data points to plot, the bar widths, and the number of bars in the histogram, make sure to pass more than one point, in strictly increasing order, covering a range that includes at least one data point. ```{code-cell} ipython3 %sqlplot histogram --table penguins.csv --column body_mass_g --breaks 3200 3400 3600 3800 4000 4200 4400 4600 4800 ``` +### Specifying binwidth + +Binwidth allows you to set the width of bins in a histogram. It is useful when you want to directly adjust the granularity of the histogram. To specify the binwidth, pass the desired width after `-W/--binwidth`. Since the binwidth determines the level of detail of the distribution, make sure to pass a suitable positive numeric value based on your data. 
+ +```{code-cell} ipython3 +%sqlplot histogram --table penguins.csv --column body_mass_g --binwidth 150 +``` + ### Multiple columns ```{code-cell} ipython3 diff --git a/src/sql/ggplot/geom/geom_histogram.py b/src/sql/ggplot/geom/geom_histogram.py index f52a0470b..34d0e8293 100644 --- a/src/sql/ggplot/geom/geom_histogram.py +++ b/src/sql/ggplot/geom/geom_histogram.py @@ -21,13 +21,19 @@ class geom_histogram(geom): breaks : list Divide bins with custom intervals + + binwidth : int or float + Width of each bin """ - def __init__(self, bins=None, fill=None, cmap=None, breaks=None, **kwargs): + def __init__( + self, bins=None, fill=None, cmap=None, breaks=None, binwidth=None, **kwargs + ): self.bins = bins self.fill = fill self.cmap = cmap self.breaks = breaks + self.binwidth = binwidth super().__init__(**kwargs) @telemetry.log_call("ggplot-histogram") @@ -45,5 +51,6 @@ def draw(self, gg, ax=None, facet=None): facet=facet, ax=ax or gg.axs[0], breaks=self.breaks, + binwidth=self.binwidth, ) return gg diff --git a/src/sql/magic_plot.py b/src/sql/magic_plot.py index 74d051f53..5c9ba07b8 100644 --- a/src/sql/magic_plot.py +++ b/src/sql/magic_plot.py @@ -66,6 +66,12 @@ class SqlPlotMagic(Magics, Configurable): nargs="+", help="Histogram breaks", ) + @argument( + "-W", + "--binwidth", + type=float, + help="Histogram binwidth", + ) @modify_exceptions def execute(self, line="", cell="", local_ns=None): """ @@ -110,13 +116,13 @@ def execute(self, line="", cell="", local_ns=None): conn=None, ) elif cmd.args.line[0] in {"hist", "histogram"}: - # to avoid passing bins default value when breaks are given by a user + # to avoid passing bins default value when breaks or binwidth is specified bin_specified = " --bins " in line or " -b " in line breaks_specified = " --breaks " in line or " -B " in line - if breaks_specified and not bin_specified: + binwidth_specified = " --binwidth " in line or " -W " in line + bins = cmd.args.bins + if not bin_specified and any([breaks_specified, 
binwidth_specified]): bins = None - else: - bins = cmd.args.bins return plot.histogram( table=table, @@ -125,6 +131,7 @@ def execute(self, line="", cell="", local_ns=None): with_=with_, conn=None, breaks=cmd.args.breaks, + binwidth=cmd.args.binwidth, ) elif cmd.args.line[0] in {"bar"}: return plot.bar( diff --git a/src/sql/plot.py b/src/sql/plot.py index 5b87daf17..1cd94bf0a 100644 --- a/src/sql/plot.py +++ b/src/sql/plot.py @@ -8,6 +8,8 @@ from sql import exceptions, display from sql.stats import _summary_stats +from sql.util import _are_numeric_values, validate_mutually_exclusive_args +from sql.display import message try: import matplotlib.pyplot as plt @@ -264,11 +266,7 @@ def _min_max(con, table, column, with_=None, use_backticks=False): return min_, max_ -def _are_numeric_values(*values): - return all([isinstance(value, (int, float)) for value in values]) - - -def _get_bar_width(ax, bins, bin_size): +def _get_bar_width(ax, bins, bin_size, binwidth): """ Return a single bar width based on number of bins or a list of bar widths if `breaks` is given. @@ -286,6 +284,9 @@ def _get_bar_width(ax, bins, bin_size): Calculated bin_size from the _histogram function or from consecutive differences in `breaks` + binwidth : int or float or None + Specified binwidth from a user + Returns ------- width : float @@ -293,6 +294,8 @@ def _get_bar_width(ax, bins, bin_size): """ if _are_numeric_values(bin_size) or isinstance(bin_size, list): width = bin_size + elif _are_numeric_values(binwidth): + width = binwidth else: fig = plt.gcf() bbox = ax.get_window_extent() @@ -318,6 +321,7 @@ def histogram( ax=None, facet=None, breaks=None, + binwidth=None, ): """Plot histogram @@ -371,10 +375,23 @@ def histogram( f"Breaks given : {breaks}. When using breaks, please ensure that " "breaks are strictly increasing." ) - if bins: + + if _are_numeric_values(binwidth): + if binwidth <= 0: raise exceptions.ValueError( - "Both bins and breaks are specified. Must specify only one of them." 
+ f"Binwidth given : {binwidth}. When using binwidth, please ensure to " + "pass a positive value." ) + binwidth = float(binwidth) + elif binwidth is not None: + raise exceptions.ValueError( + f"Binwidth given : {binwidth}. When using binwidth, please ensure to " + "pass a numeric value." + ) + + validate_mutually_exclusive_args( + ["bins", "breaks", "binwidth"], [bins, breaks, binwidth] + ) ax = ax or plt.gca() payload["connection_info"] = conn._get_database_information() @@ -393,9 +410,15 @@ def histogram( raise ValueError("Column name has not been specified") bin_, height, bin_size = _histogram( - table, column, bins, with_=with_, conn=conn, breaks=breaks + table, + column, + bins, + with_=with_, + conn=conn, + breaks=breaks, + binwidth=binwidth, ) - width = _get_bar_width(ax, bin_, bin_size) + width = _get_bar_width(ax, bin_, bin_size, binwidth) data = _histogram_stacked( table, column, @@ -406,6 +429,7 @@ def histogram( conn=conn, facet=facet, breaks=breaks, + binwidth=binwidth, ) cmap = plt.get_cmap(cmap or "viridis") norm = Normalize(vmin=0, vmax=len(data)) @@ -449,9 +473,16 @@ def histogram( ax.legend(handles[::-1], labels[::-1]) elif isinstance(column, str): bin_, height, bin_size = _histogram( - table, column, bins, with_=with_, conn=conn, facet=facet, breaks=breaks + table, + column, + bins, + with_=with_, + conn=conn, + facet=facet, + breaks=breaks, + binwidth=binwidth, ) - width = _get_bar_width(ax, bin_, bin_size) + width = _get_bar_width(ax, bin_, bin_size, binwidth) ax.bar( bin_, @@ -472,9 +503,16 @@ def histogram( ) for i, col in enumerate(column): bin_, height, bin_size = _histogram( - table, col, bins, with_=with_, conn=conn, facet=facet, breaks=breaks + table, + col, + bins, + with_=with_, + conn=conn, + facet=facet, + breaks=breaks, + binwidth=binwidth, ) - width = _get_bar_width(ax, bin_, bin_size) + width = _get_bar_width(ax, bin_, bin_size, binwidth) if isinstance(color, list): color_ = color[i] @@ -505,7 +543,9 @@ def histogram( 
@modify_exceptions -def _histogram(table, column, bins, with_=None, conn=None, facet=None, breaks=None): +def _histogram( + table, column, bins, with_=None, conn=None, facet=None, breaks=None, binwidth=None +): """Compute bins and heights""" if not conn: conn = sql.connection.ConnectionManager.current @@ -576,7 +616,7 @@ def _histogram(table, column, bins, with_=None, conn=None, facet=None, breaks=No query = template.render( table=table, column=column, filter_query=filter_query ) - elif not isinstance(bins, int): + elif not binwidth and not isinstance(bins, int): raise ValueError( f"bins are '{bins}'. Please specify a valid number of bins." ) @@ -584,7 +624,15 @@ def _histogram(table, column, bins, with_=None, conn=None, facet=None, breaks=No # Use bins - 1 instead of bins and round half down instead of floor # to mimic right-closed histogram intervals in R ggplot range_ = max_ - min_ - bin_size = range_ / (bins - 1) + if binwidth: + bin_size = binwidth + if binwidth > range_: + message( + f"Specified binwidth {binwidth} is larger than " + f"the range {range_}. Please choose a smaller binwidth." 
+ ) + else: + bin_size = range_ / (bins - 1) template_ = """ select ceiling("{{column}}"/{{bin_size}} - 0.5)*{{bin_size}} as bin, @@ -638,6 +686,7 @@ def _histogram_stacked( conn=None, facet=None, breaks=None, + binwidth=None, ): """Compute the corresponding heights of each bin based on the category""" if not conn: @@ -654,6 +703,8 @@ def _histogram_stacked( cases.append(case) cases[0] = cases[0].replace(">", ">=", 1) else: + if binwidth: + bin_size = binwidth tolerance = bin_size / 1000 # Use to avoid floating point error for bin in bins: # Use round half down instead of floor to mimic diff --git a/src/sql/util.py b/src/sql/util.py index 424852ba8..a8d91efcc 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -494,3 +494,27 @@ def get_default_configs(sql): del default_configs["parent"] del default_configs["config"] return default_configs + + +def _are_numeric_values(*values): + return all([isinstance(value, (int, float)) for value in values]) + + +def validate_mutually_exclusive_args(arg_names, args): + """ + Raises ValueError if a list of values from arg_names filtered by + args' boolean representations is longer than one. + + Parameters + ---------- + arg_names : list + args' names in string + args : list + args values + """ + specified_args = [arg_name for arg_name, arg in zip(arg_names, args) if arg] + if len(specified_args) > 1: + raise exceptions.ValueError( + f"{pretty_print(specified_args)} are specified. " + "You can only specify one of them." 
+ ) diff --git a/src/tests/baseline_images/test_ggplot/histogram_binwidth_facet_wrap.png b/src/tests/baseline_images/test_ggplot/histogram_binwidth_facet_wrap.png new file mode 100644 index 000000000..588831221 Binary files /dev/null and b/src/tests/baseline_images/test_ggplot/histogram_binwidth_facet_wrap.png differ diff --git a/src/tests/baseline_images/test_ggplot/histogram_binwidth_with_multiple_cols.png b/src/tests/baseline_images/test_ggplot/histogram_binwidth_with_multiple_cols.png new file mode 100644 index 000000000..be068e91d Binary files /dev/null and b/src/tests/baseline_images/test_ggplot/histogram_binwidth_with_multiple_cols.png differ diff --git a/src/tests/baseline_images/test_ggplot/histogram_stacked_with_binwidth.png b/src/tests/baseline_images/test_ggplot/histogram_stacked_with_binwidth.png new file mode 100644 index 000000000..5c704a37a Binary files /dev/null and b/src/tests/baseline_images/test_ggplot/histogram_stacked_with_binwidth.png differ diff --git a/src/tests/baseline_images/test_ggplot/histogram_with_binwidth.png b/src/tests/baseline_images/test_ggplot/histogram_with_binwidth.png new file mode 100644 index 000000000..f3565b727 Binary files /dev/null and b/src/tests/baseline_images/test_ggplot/histogram_with_binwidth.png differ diff --git a/src/tests/baseline_images/test_ggplot/histogram_with_narrow_binwidth.png b/src/tests/baseline_images/test_ggplot/histogram_with_narrow_binwidth.png new file mode 100644 index 000000000..8b4d69e16 Binary files /dev/null and b/src/tests/baseline_images/test_ggplot/histogram_with_narrow_binwidth.png differ diff --git a/src/tests/baseline_images/test_magic_plot/hist_binwidth.png b/src/tests/baseline_images/test_magic_plot/hist_binwidth.png new file mode 100644 index 000000000..f3565b727 Binary files /dev/null and b/src/tests/baseline_images/test_magic_plot/hist_binwidth.png differ diff --git a/src/tests/integration/baseline_images/test_questDB/histogram_binwidth_with_multiple_cols.png 
b/src/tests/integration/baseline_images/test_questDB/histogram_binwidth_with_multiple_cols.png new file mode 100644 index 000000000..0770980ac Binary files /dev/null and b/src/tests/integration/baseline_images/test_questDB/histogram_binwidth_with_multiple_cols.png differ diff --git a/src/tests/integration/baseline_images/test_questDB/histogram_stacked_with_binwidth.png b/src/tests/integration/baseline_images/test_questDB/histogram_stacked_with_binwidth.png new file mode 100644 index 000000000..f6a27e244 Binary files /dev/null and b/src/tests/integration/baseline_images/test_questDB/histogram_stacked_with_binwidth.png differ diff --git a/src/tests/integration/baseline_images/test_questDB/histogram_with_binwidth.png b/src/tests/integration/baseline_images/test_questDB/histogram_with_binwidth.png new file mode 100644 index 000000000..d0fc0e916 Binary files /dev/null and b/src/tests/integration/baseline_images/test_questDB/histogram_with_binwidth.png differ diff --git a/src/tests/integration/baseline_images/test_questDB/histogram_with_narrow_binwidth.png b/src/tests/integration/baseline_images/test_questDB/histogram_with_narrow_binwidth.png new file mode 100644 index 000000000..715a07ef2 Binary files /dev/null and b/src/tests/integration/baseline_images/test_questDB/histogram_with_narrow_binwidth.png differ diff --git a/src/tests/integration/test_generic_db_operations.py b/src/tests/integration/test_generic_db_operations.py index 81a9133d5..6c26e7ec9 100644 --- a/src/tests/integration/test_generic_db_operations.py +++ b/src/tests/integration/test_generic_db_operations.py @@ -292,12 +292,17 @@ def test_telemetry_execute_command_has_connection_info( "%sqlplot histogram --with plot_something_subset --table\ plot_something_subset --column x --breaks 0 2 3 4 5" ), + ( + "%sqlplot histogram --with plot_something_subset --table\ + plot_something_subset --column x --binwidth 1" + ), ], ids=[ "histogram", "hist", "histogram-bins", "histogram-breaks", + "histogram-binwidth", ], 
) @pytest.mark.parametrize( diff --git a/src/tests/integration/test_questDB.py b/src/tests/integration/test_questDB.py index fc1dc1596..306306afe 100644 --- a/src/tests/integration/test_questDB.py +++ b/src/tests/integration/test_questDB.py @@ -483,6 +483,62 @@ def test_histogram_breaks_over_max(ip_questdb, diamonds_data): ) +@_cleanup_cm() +@image_comparison( + baseline_images=["histogram_with_binwidth"], + extensions=["png"], + remove_text=True, +) +def test_histogram_with_binwidth(ip_questdb, penguins_no_nulls_questdb): + ( + ggplot(table="no_nulls", with_="no_nulls", mapping=aes(x="body_mass_g")) + + geom_histogram(binwidth=150) + ) + + +@_cleanup_cm() +@image_comparison( + baseline_images=["histogram_stacked_with_binwidth"], + extensions=["png"], + remove_text=True, +) +def test_histogram_stacked_with_binwidth(ip_questdb, penguins_no_nulls_questdb): + ( + ggplot(table="no_nulls", with_="no_nulls", mapping=aes(x="body_mass_g")) + + geom_histogram(binwidth=150, fill="species") + ) + + +@_cleanup_cm() +@image_comparison( + baseline_images=["histogram_binwidth_with_multiple_cols"], + extensions=["png"], + remove_text=True, +) +def test_histogram_binwidth_with_multiple_cols(ip_questdb, penguins_no_nulls_questdb): + ( + ggplot( + table="no_nulls", + with_="no_nulls", + mapping=aes(x=["bill_length_mm", "bill_depth_mm"]), + ) + + geom_histogram(binwidth=1.5) + ) + + +@_cleanup_cm() +@image_comparison( + baseline_images=["histogram_with_narrow_binwidth"], + extensions=["png"], + remove_text=True, +) +def test_histogram_with_narrow_binwidth(ip_questdb, penguins_no_nulls_questdb): + ( + ggplot(table="no_nulls", with_="no_nulls", mapping=aes(x="body_mass_g")) + + geom_histogram(binwidth=10) + ) + + @_cleanup_cm() @pytest.mark.parametrize( "x, expected_error, expected_error_message", diff --git a/src/tests/test_ggplot.py b/src/tests/test_ggplot.py index ebeb5d8fd..2b8ce004f 100644 --- a/src/tests/test_ggplot.py +++ b/src/tests/test_ggplot.py @@ -496,6 +496,72 @@ def 
test_histogram_with_extreme_breaks(penguins_data): ) +@_cleanup_cm() +@image_comparison( + baseline_images=["histogram_with_binwidth"], + extensions=["png"], + remove_text=True, +) +def test_histogram_with_binwidth(penguins_data): + ( + ggplot(table=penguins_data, mapping=aes(x="body_mass_g")) + + geom_histogram(binwidth=150) + ) + + +@_cleanup_cm() +@image_comparison( + baseline_images=["histogram_stacked_with_binwidth"], + extensions=["png"], + remove_text=True, +) +def test_histogram_stacked_with_binwidth(penguins_data): + ( + ggplot(table=penguins_data, mapping=aes(x="body_mass_g")) + + geom_histogram(binwidth=150, fill="species") + ) + + +@_cleanup_cm() +@image_comparison( + baseline_images=["histogram_binwidth_with_multiple_cols"], + extensions=["png"], + remove_text=True, +) +def test_histogram_binwidth_with_multiple_cols(penguins_data): + ( + ggplot(table=penguins_data, mapping=aes(x=["bill_length_mm", "bill_depth_mm"])) + + geom_histogram(binwidth=1.5) + ) + + +@_cleanup_cm() +@image_comparison( + baseline_images=["histogram_binwidth_facet_wrap"], + extensions=["png"], + remove_text=True, +) +def test_histogram_binwidth_facet_wrap(penguins_data): + ( + ggplot(table=penguins_data, mapping=aes(x=["body_mass_g"])) + + geom_histogram(binwidth=150) + + facet_wrap("species") + ) + + +@_cleanup_cm() +@image_comparison( + baseline_images=["histogram_with_narrow_binwidth"], + extensions=["png"], + remove_text=True, +) +def test_histogram_with_narrow_binwidth(penguins_data): + ( + ggplot(table=penguins_data, mapping=aes(x="body_mass_g")) + + geom_histogram(binwidth=10) + ) + + @pytest.mark.parametrize( "x, expected_error, expected_error_message", [ @@ -555,7 +621,7 @@ def test_histogram_no_bins_error(diamonds_data): ( 40, [3000.0, 4000.0, 5000.0], - "Both bins and breaks are specified. Must specify only one of them.", + "'bins', and 'breaks' are specified. 
You can only specify one of them.", ), ], ) @@ -568,3 +634,55 @@ def test_hist_breaks_error(penguins_data, bins, breaks, error_message): assert error.value.error_type == "ValueError" assert error_message in str(error.value) + + +@pytest.mark.parametrize( + "bins, breaks, binwidth, error_message", + [ + ( + None, + [1000, 2000, 3000], + 150, + ( + "'binwidth', and 'breaks' are specified. " + "You can only specify one of them." + ), + ), + ( + 50, + [1000, 2000, 3000], + 150, + ( + "'bins', 'binwidth', and 'breaks' are specified. " + "You can only specify one of them." + ), + ), + ( + None, + None, + "invalid", + ( + "Binwidth given : invalid. When using binwidth, " + "please ensure to pass a numeric value." + ), + ), + ( + None, + None, + 0, + ( + "Binwidth given : 0. When using binwidth, " + "please ensure to pass a positive value." + ), + ), + ], +) +def test_hist_binwidth_error(penguins_data, bins, breaks, binwidth, error_message): + with pytest.raises(UsageError) as error: + ( + ggplot(penguins_data, aes(x="body_mass_g")) + + geom_histogram(bins=bins, breaks=breaks, binwidth=binwidth) + ) + + assert error.value.error_type == "ValueError" + assert error_message in str(error.value) diff --git a/src/tests/test_magic_plot.py b/src/tests/test_magic_plot.py index 387898d7b..354b5a4c0 100644 --- a/src/tests/test_magic_plot.py +++ b/src/tests/test_magic_plot.py @@ -104,14 +104,14 @@ def test_validate_arguments(tmp_empty, ip, cell, error_message): "%sqlplot histogram --table penguins.csv --column body_mass_g " "--breaks 3000 4000 5000 --bins 50" ), - "Both bins and breaks are specified. Must specify only one of them.", + "'bins', and 'breaks' are specified. You can only specify one of them.", ], [ ( "%sqlplot histogram --table penguins.csv --bins 50 --column body_mass_g" " --breaks 3000 4000 5000" ), - "Both bins and breaks are specified. Must specify only one of them.", + "'bins', and 'breaks' are specified. 
You can only specify one of them.", ], [ ( @@ -129,6 +129,74 @@ def test_validate_breaks_arguments(load_penguin, ip, cell, error_message): assert error_message in str(excinfo.value) +@pytest.mark.parametrize( + "cell, error_message", + [ + [ + ( + "%sqlplot histogram --table penguins.csv --column body_mass_g " + "--bins 50 --binwidth 1000" + ), + "'bins', and 'binwidth' are specified. You can only specify one of them.", + ], + [ + ( + "%sqlplot histogram --table penguins.csv --column body_mass_g " + "-W 50 --breaks 3000 4000 5000" + ), + "'binwidth', and 'breaks' are specified. You can only specify one of them.", + ], + [ + ( + "%sqlplot histogram --table penguins.csv --column body_mass_g " + "--binwidth 0" + ), + ( + "Binwidth given : 0.0. When using binwidth, " + "please ensure to pass a positive value." + ), + ], + [ + ( + "%sqlplot histogram --table penguins.csv --column body_mass_g " + "--binwidth -10" + ), + ( + "Binwidth given : -10.0. When using binwidth, " + "please ensure to pass a positive value." + ), + ], + ], +) +def test_validate_binwidth_arguments(load_penguin, ip, cell, error_message): + with pytest.raises(UsageError) as excinfo: + ip.run_cell(cell) + + assert error_message in str(excinfo.value) + assert excinfo.value.error_type == "ValueError" + + +def test_validate_binwidth_text_argument(tmp_empty, ip): + with pytest.raises(UsageError) as excinfo: + ip.run_cell( + "%sqlplot histogram --table penguins.csv " + "--column body_mass_g --binwidth test" + ) + + assert "argument -W/--binwidth: invalid float value: 'test'" == str(excinfo.value) + + +def test_binwidth_larger_than_range(load_penguin, ip, capsys): + ip.run_cell( + "%sqlplot histogram --table penguins.csv --column body_mass_g --binwidth 3601" + ) + out, _ = capsys.readouterr() + assert ( + "Specified binwidth 3601.0 is larger than the range 3600. " + "Please choose a smaller binwidth." 
+ ) in out + + @_cleanup_cm() @pytest.mark.parametrize( "cell", @@ -136,6 +204,7 @@ def test_validate_breaks_arguments(load_penguin, ip, cell, error_message): "%sqlplot histogram --table data.csv --column x", "%sqlplot hist --table data.csv --column x", "%sqlplot histogram --table data.csv --column x --bins 10", + "%sqlplot histogram --table data.csv --column x --binwidth 1", pytest.param( "%sqlplot histogram --table nas.csv --column x", marks=pytest.mark.xfail(reason="Not implemented yet"), @@ -195,6 +264,7 @@ def test_validate_breaks_arguments(load_penguin, ip, cell, error_message): "histogram", "hist", "histogram-bins", + "histogram-binwidth", "histogram-nas", "boxplot", "boxplot-with", @@ -491,6 +561,23 @@ def test_hist_breaks(load_penguin, ip): ) +@pytest.mark.parametrize( + "binwidth", + [ + "--binwidth", + "-W", + ], +) +@_cleanup_cm() +@image_comparison( + baseline_images=["hist_binwidth"], extensions=["png"], remove_text=True +) +def test_hist_binwidth(load_penguin, ip, binwidth): + ip.run_cell( + f"%sqlplot histogram --table penguins.csv --column body_mass_g {binwidth} 150" + ) + + @pytest.mark.parametrize( "arg", [