diff --git a/docs/source/api/backend/none.rst b/docs/source/api/backend/none.rst new file mode 100644 index 00000000..cd2b1cd9 --- /dev/null +++ b/docs/source/api/backend/none.rst @@ -0,0 +1,48 @@ +============ +None backend +============ + +.. automodule:: arviz_plots.backend.none + + +Object creation and I/O +....................... + +.. autosummary:: + :toctree: generated/ + + create_plotting_grid + show + +Geoms +..... + +.. autosummary:: + :toctree: generated/ + + line + scatter + text + +Plot appearance +............... + +.. autosummary:: + :toctree: generated/ + + title + ylabel + xlabel + xticks + yticks + ticklabel_props + remove_ticks + remove_axis + +Legend +...... + +.. autosummary:: + :toctree: generated/ + + legend diff --git a/docs/source/api/backend/plotly.rst b/docs/source/api/backend/plotly.rst new file mode 100644 index 00000000..6bf1ec29 --- /dev/null +++ b/docs/source/api/backend/plotly.rst @@ -0,0 +1,48 @@ +============== +Plotly backend +============== + +.. automodule:: arviz_plots.backend.plotly + + +Object creation and I/O +....................... + +.. autosummary:: + :toctree: generated/ + + create_plotting_grid + show + +Geoms +..... + +.. autosummary:: + :toctree: generated/ + + line + scatter + text + +Plot appearance +............... + +.. autosummary:: + :toctree: generated/ + + title + ylabel + xlabel + xticks + yticks + ticklabel_props + remove_ticks + remove_axis + +Legend +...... + +.. 
autosummary:: + :toctree: generated/ + + legend diff --git a/docs/source/gallery/backreferences.json b/docs/source/gallery/backreferences.json new file mode 100644 index 00000000..44b2f455 --- /dev/null +++ b/docs/source/gallery/backreferences.json @@ -0,0 +1 @@ +{"plot_forest": [{"basename": "plot_forest_ess", "refname": "gallery_forest_ess", "title": "Forest plot with ESS", "description": "Multiple panel visualization with a forest plot and ESS information"}, {"basename": "plot_forest", "refname": "gallery_forest", "title": "Forest plot", "description": "Default forest plot with marginal distribution summaries"}, {"basename": "plot_forest_shade", "refname": "gallery_forest_shade", "title": "Forest plot with shading", "description": "Forest plot marginal summaries with row shading to enhance reading"}, {"basename": "plot_forest_models", "refname": "gallery_forest_models", "title": "Forest plot comparison", "description": "Forest plot summaries for 1D marginal distributions"}, {"basename": "plot_forest_pp_obs", "refname": "gallery_forest_pp_obs", "title": "Posterior predictive and observations forest plot", "description": "Overlay of forest plot for the posterior predictive samples and the actual observations"}], "plot_dist": [{"basename": "plot_dist_ecdf", "refname": "gallery_dist_ecdf", "title": "ECDF plot", "description": "Facetted ECDF plots for 1D marginals of the distribution"}, {"basename": "plot_dist_hist", "refname": "gallery_dist_hist", "title": "Histogram plot", "description": "Facetted histogram plots for 1D marginals of the distribution"}, {"basename": "plot_dist_kde", "refname": "gallery_dist_kde", "title": "KDE plot", "description": "Facetted KDE plots for 1D marginals of the distribution"}, {"basename": "plot_dist_models", "refname": "gallery_dist_models", "title": "Marginal distribution comparison plot", "description": "Full marginal distribution comparison between different models"}], "plot_mcse": [{"basename": "plot_mcse", "refname": 
"gallery_mcse", "title": "MCSE Quantile plot", "description": "Facetted quantile MCSE plot"}, {"basename": "plot_mcse_errorbar", "refname": "gallery_mcse_errorbar", "title": "MCSE Quantile plot with errorbars", "description": "Facetted quantile MCSE plot with errorbars"}, {"basename": "plot_mcse_models", "refname": "gallery_mcse_models", "title": "MCSE comparison plot", "description": "Full MCSE comparison between different models"}], "plot_trace": [{"basename": "plot_trace", "refname": "gallery_trace", "title": "Trace plot", "description": "Facetted plot with MCMC traces for each variable"}]} \ No newline at end of file diff --git a/docs/source/gallery/inference_diagnostics/plot_mcse.py b/docs/source/gallery/inference_diagnostics/plot_mcse.py new file mode 100644 index 00000000..0e9006b3 --- /dev/null +++ b/docs/source/gallery/inference_diagnostics/plot_mcse.py @@ -0,0 +1,24 @@ +""" +# MCSE Quantile plot + +Facetted quantile MCSE plot + +--- + +:::{seealso} +API Documentation: {func}`~arviz_plots.plot_mcse` +::: +""" + +from arviz_base import load_arviz_data + +import arviz_plots as azp + +azp.style.use("arviz-clean") + +data = load_arviz_data("centered_eight") +pc = azp.plot_mcse( + data, + backend="none", # change to preferred backend +) +pc.show() diff --git a/docs/source/gallery/inference_diagnostics/plot_mcse_errorbar.py b/docs/source/gallery/inference_diagnostics/plot_mcse_errorbar.py new file mode 100644 index 00000000..c1ad9996 --- /dev/null +++ b/docs/source/gallery/inference_diagnostics/plot_mcse_errorbar.py @@ -0,0 +1,25 @@ +""" +# MCSE Quantile plot with errorbars + +Facetted quantile MCSE plot with errorbars + +--- + +:::{seealso} +API Documentation: {func}`~arviz_plots.plot_mcse` +::: +""" + +from arviz_base import load_arviz_data + +import arviz_plots as azp + +azp.style.use("arviz-clean") + +data = load_arviz_data("centered_eight") +pc = azp.plot_mcse( + data, + errorbar=True, + backend="none", # change to preferred backend +) +pc.show() diff --git 
a/docs/source/gallery/inference_diagnostics/plot_mcse_models.py b/docs/source/gallery/inference_diagnostics/plot_mcse_models.py new file mode 100644 index 00000000..0257ecd9 --- /dev/null +++ b/docs/source/gallery/inference_diagnostics/plot_mcse_models.py @@ -0,0 +1,25 @@ +""" +# MCSE comparison plot + +Full MCSE comparison between different models + +--- + +:::{seealso} +API Documentation: {func}`~arviz_plots.plot_mcse` +::: +""" + +from arviz_base import load_arviz_data + +import arviz_plots as azp + +azp.style.use("arviz-clean") + +c = load_arviz_data("centered_eight") +n = load_arviz_data("non_centered_eight") +pc = azp.plot_mcse( + {"Centered": c, "Non Centered": n}, + backend="none", # change to preferred backend +) +pc.show() diff --git a/docs/source/gallery/inference_diagnostics/plot_mcse_models_errorbar.py b/docs/source/gallery/inference_diagnostics/plot_mcse_models_errorbar.py new file mode 100644 index 00000000..0788f5de --- /dev/null +++ b/docs/source/gallery/inference_diagnostics/plot_mcse_models_errorbar.py @@ -0,0 +1,26 @@ +""" +# MCSE comparison plot with errorbars + +Full MCSE comparison between different models with errorbars + +--- + +:::{seealso} +API Documentation: {func}`~arviz_plots.plot_mcse` +::: +""" + +from arviz_base import load_arviz_data + +import arviz_plots as azp + +azp.style.use("arviz-clean") + +c = load_arviz_data("centered_eight") +n = load_arviz_data("non_centered_eight") +pc = azp.plot_mcse( + {"Centered": c, "Non Centered": n}, + errorbar=True, + backend="none", # change to preferred backend +) +pc.show() diff --git a/pyproject.toml b/pyproject.toml index 9a3ec7a1..635fb26b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ dynamic = ["version", "description"] dependencies = [ "arviz-base==0.2", - "arviz-stats[xarray]==0.2", + "arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats", ] [tool.flit.module] diff --git a/src/arviz_plots/backend/bokeh/__init__.py 
b/src/arviz_plots/backend/bokeh/__init__.py index 0079f02e..8c6b235a 100644 --- a/src/arviz_plots/backend/bokeh/__init__.py +++ b/src/arviz_plots/backend/bokeh/__init__.py @@ -298,6 +298,49 @@ def scatter( return target.scatter(np.atleast_1d(x), np.atleast_1d(y), **kwargs) +def errorbar( + x, + y, + error, + target, + *, + size=unset, + marker=unset, + color=unset, + facecolor=unset, + edgecolor=unset, + width=unset, + **artist_kws, +): + """Interface to bokeh for an errorbar plot.""" + if color is not unset: + if facecolor is unset and edgecolor is unset: + facecolor = color + edgecolor = color + elif facecolor is unset: + facecolor = color + elif edgecolor is unset: + edgecolor = color + kwargs = { + "size": size, + "marker": marker, + "fill_color": facecolor, + "line_color": edgecolor, + "line_width": width, + } + kwargs = _filter_kwargs(kwargs, artist_kws) + if marker == "|": + kwargs["marker"] = "dash" + kwargs["angle"] = np.pi / 2 + + target.scatter(np.atleast_1d(x), np.atleast_1d(y), **kwargs) + + x_err = list(zip(x, x)) + y_err = [(y_i - err, y_i + err) for y_i, err in zip(y, error)] + target.multi_line(xs=x_err, ys=y_err, **kwargs) + return target + + def text( x, y, diff --git a/src/arviz_plots/backend/matplotlib/__init__.py b/src/arviz_plots/backend/matplotlib/__init__.py index 3a605839..49af59d7 100644 --- a/src/arviz_plots/backend/matplotlib/__init__.py +++ b/src/arviz_plots/backend/matplotlib/__init__.py @@ -274,6 +274,42 @@ def scatter( return target.scatter(x, y, **_filter_kwargs(kwargs, None, artist_kws)) +def errorbar( + x, + y, + error, + target, + *, + size=unset, + marker=unset, + color=unset, + facecolor=unset, + edgecolor=unset, + width=unset, + **artist_kws, +): + """Interface to matplotlib for an errorbar plot.""" + artist_kws.setdefault("zorder", 2) + fillable_marker = (marker is unset) or (marker in Line2D.filled_markers) + if color is not unset: + if facecolor is unset and edgecolor is unset: + facecolor = color + if fillable_marker: + 
edgecolor = color + elif facecolor is unset: + facecolor = color + elif edgecolor is unset and fillable_marker: + edgecolor = color + kwargs = { + "capsize": size, + "marker": marker, + "markerfacecolor": facecolor, + "markeredgecolor": edgecolor, + "elinewidth": width, + } + return target.errorbar(x, y, error, **_filter_kwargs(kwargs, None, artist_kws)) + + def text( x, y, diff --git a/src/arviz_plots/backend/none/__init__.py b/src/arviz_plots/backend/none/__init__.py index 84770816..d50cf6cd 100644 --- a/src/arviz_plots/backend/none/__init__.py +++ b/src/arviz_plots/backend/none/__init__.py @@ -271,6 +271,49 @@ def scatter( return artist_element +def errorbar( + x, + y, + error, + target, + *, + size=unset, + marker=unset, + color=unset, + facecolor=unset, + edgecolor=unset, + width=unset, + **artist_kws, +): + """Interface to an errorbar plot.""" + if color is not unset: + if facecolor is unset and edgecolor is unset: + facecolor = color + edgecolor = color + elif facecolor is unset: + facecolor = color + elif edgecolor is unset: + edgecolor = color + kwargs = { + "capsize": size, + "marker": marker, + "markerfacecolor": facecolor, + "markeredgecolor": edgecolor, + "elinewidth": width, + } + if not ALLOW_KWARGS and artist_kws: + raise ValueError("artist_kws not empty") + artist_element = { + "function": "errorbar", + "x": np.atleast_1d(x), + "y": np.atleast_1d(y), + "error": np.atleast_1d(error), + **_filter_kwargs(kwargs, artist_kws), + } + target.append(artist_element) + return artist_element + + def text( x, y, diff --git a/src/arviz_plots/plots/__init__.py b/src/arviz_plots/plots/__init__.py index 72304d5e..f1122fe6 100644 --- a/src/arviz_plots/plots/__init__.py +++ b/src/arviz_plots/plots/__init__.py @@ -3,6 +3,7 @@ from .compareplot import plot_compare from .distplot import plot_dist from .forestplot import plot_forest +from .mcseplot import plot_mcse from .ridgeplot import plot_ridge from .tracedistplot import plot_trace_dist from .traceplot import 
plot_trace @@ -14,4 +15,5 @@ "plot_trace", "plot_trace_dist", "plot_ridge", + "plot_mcse", ] diff --git a/src/arviz_plots/plots/mcseplot.py b/src/arviz_plots/plots/mcseplot.py new file mode 100644 index 00000000..e8f7dc87 --- /dev/null +++ b/src/arviz_plots/plots/mcseplot.py @@ -0,0 +1,538 @@ +"""mcse plot code.""" + +# imports +# import warnings +from copy import copy +from importlib import import_module + +import arviz_stats # pylint: disable=unused-import +import numpy as np +import xarray as xr +from arviz_base import rcParams +from arviz_base.labels import BaseLabeller + +from arviz_plots.plot_collection import PlotCollection, process_facet_dims +from arviz_plots.plots.utils import filter_aes, get_group, process_group_variables_coords +from arviz_plots.visuals import ( + annotate_xy, + error_bar, + labelled_title, + labelled_x, + labelled_y, + line_xy, + scatter_xy, + trace_rug, +) + + +def plot_mcse( + dt, + var_names=None, + filter_vars=None, + group="posterior", + coords=None, + sample_dims=None, + errorbar=False, + rug=False, + rug_kind="diverging", + n_points=20, + extra_methods=False, + plot_collection=None, + backend=None, + labeller=None, + aes_map=None, + plot_kwargs=None, + stats_kwargs=None, + pc_kwargs=None, +): + """Plot quantile Monte Carlo Standard Error. + + Parameters + ---------- + dt : DataTree or dict of {str : DataTree} + Input data. In case of dictionary input, the keys are taken to be model names. + In such cases, a dimension "model" is generated and can be used to map to aesthetics. + var_names : str or sequence of str, optional + One or more variables to be plotted. + Prefix the variables by ~ when you want to exclude them from the plot. + filter_vars : {None, “like”, “regex”}, default None + If None, interpret `var_names` as the real variables names. + If “like”, interpret `var_names` as substrings of the real variables names. + If “regex”, interpret `var_names` as regular expressions on the real variables names. 
+ group : str, default "posterior" + Group to be plotted. + coords : dict, optional + sample_dims : str or sequence of hashable, optional + Dimensions to reduce unless mapped to an aesthetic. + Defaults to ``rcParams["data.sample_dims"]`` + errorbar : bool, default False + Plot quantile value +/- mcse instead of plotting mcse. + rug : bool, default False + Add a `rug plot <https://en.wikipedia.org/wiki/Rug_plot>`_ for a specific subset of values. + rug_kind : str, default "diverging" + Variable in sample stats to use as rug mask. Must be a boolean variable. + n_points : int, default 20 + Number of points for which to plot their quantile MCSE. + extra_methods : bool, default False + Plot mean and sd MCSE as horizontal lines. Requires ``errorbar=False``; + combining it with ``errorbar=True`` raises a ``ValueError``. + plot_collection : PlotCollection, optional + backend : {"matplotlib", "bokeh", "plotly", "none"}, optional + labeller : labeller, optional + aes_map : mapping of {str : sequence of str or False}, optional + Mapping of artists to aesthetics that should use their mapping in `plot_collection` + when plotted. Valid keys are the same as for `plot_kwargs`. + + plot_kwargs : mapping of {str : mapping or False}, optional + Valid keys are: + + * mcse -> passed to :func:`~arviz_plots.visuals.scatter_xy` + if ``errorbar=False``. + Passed to :func:`~arviz_plots.visuals.error_bar` + if ``errorbar=True``. 
+ + * rug -> passed to :func:`~.visuals.trace_rug` + * title -> passed to :func:`~arviz_plots.visuals.labelled_title` + * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x` + * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y` + * mean -> passed to :func:`~arviz_plots.visuals.line_xy` + * sd -> passed to :func:`~arviz_plots.visuals.line_xy` + * mean_text -> passed to :func:`~arviz_plots.visuals.annotate_xy` + * sd_text -> passed to :func:`~arviz_plots.visuals.annotate_xy` + + stats_kwargs : mapping, optional + Valid keys are: + + * mcse -> passed to mcse, method='quantile' + * mean -> passed to mcse, method='mean' + * sd -> passed to mcse, method='sd' + + pc_kwargs : mapping + Passed to :meth:`arviz_plots.PlotCollection.wrap` + + Returns + ------- + PlotCollection + + Notes + ----- + Depending on the number of models, a slight x-axis separation aesthetic is applied for each + MCSE point for distinguishability in case of overlap. + + See Also + -------- + :ref:`plots_intro` : + General introduction to batteries-included plotting functions, common use and logic overview + + Examples + -------- + We can manually map the color to the variable, and have the mapping apply + to the title too instead of only the mcse markers: + + .. plot:: + :context: close-figs + >>> from arviz_plots import plot_mcse, style + >>> style.use("arviz-clean") + >>> from arviz_base import load_arviz_data + >>> centered = load_arviz_data('centered_eight') + >>> non_centered = load_arviz_data('non_centered_eight') + >>> pc = plot_mcse( + >>> non_centered, + >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, + >>> pc_kwargs={"aes": {"color": ["__variable__"]}}, + >>> aes_map={"title": ["color"]}, + >>> ) + + We can add extra methods to plot the mean and standard deviation as lines + + .. 
plot:: + :context: close-figs + >>> pc = plot_mcse( + >>> non_centered, + >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, + >>> extra_methods=True, + >>> ) + + Rugs can also be added: + .. plot:: + :context: close-figs + >>> pc = plot_mcse( + >>> non_centered, + >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, + >>> rug=True, + >>> ) + + We can also adjust the number of points: + + .. plot:: + :context: close-figs + >>> pc = plot_mcse( + >>> non_centered, + >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, + >>> n_points=10, + >>> ) + + If we want to plot quantile values +/- mcse, we can turn errorbars on: + + .. plot:: + :context: close-figs + >>> pc = plot_mcse( + >>> non_centered, + >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, + >>> errorbar=True, + >>> ) + + .. minigallery:: plot_mcse + + """ + # initial defaults + if sample_dims is None: + sample_dims = rcParams["data.sample_dims"] + if isinstance(sample_dims, str): + sample_dims = [sample_dims] + + # mutable inputs + if plot_kwargs is None: + plot_kwargs = {} + if pc_kwargs is None: + pc_kwargs = {} + else: + pc_kwargs = pc_kwargs.copy() + + if stats_kwargs is None: + stats_kwargs = {} + + # processing dt/group/coords/filtering + distribution = process_group_variables_coords( + dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords + ) + + # ensuring plot_kwargs['rug'] is not False + rug_kwargs = copy(plot_kwargs.get("rug", {})) + if rug_kwargs is False: + raise ValueError("plot_kwargs['rug'] can't be False, use rug=False to remove the rug") + + if backend is None: + if plot_collection is None: + backend = rcParams["plot.backend"] + else: + backend = plot_collection.backend + + plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") + + # set plot collection initialization defaults if it doesnt exist + if plot_collection is None: + pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy() + pc_kwargs["aes"] 
= pc_kwargs.get("aes", {}).copy() + pc_kwargs.setdefault("col_wrap", 5) + pc_kwargs.setdefault( + "cols", + ["__variable__"] + + [dim for dim in distribution.dims if dim not in {"model"}.union(sample_dims)], + ) + if "chain" in distribution: + pc_kwargs["aes"].setdefault("overlay", ["chain"]) + if "model" in distribution: + pc_kwargs["aes"].setdefault("color", ["model"]) + n_models = distribution.sizes["model"] + x_diff = min(1 / n_points / 3, 1 / n_points * n_models / 10) + pc_kwargs.setdefault("x", np.linspace(-x_diff, x_diff, n_models)) + pc_kwargs["aes"].setdefault("x", ["model"]) + figsize = pc_kwargs.get("plot_grid_kws", {}).get("figsize", None) + figsize_units = pc_kwargs.get("plot_grid_kws", {}).get("figsize_units", "inches") + if figsize is None: + n_plots, _ = process_facet_dims(distribution, pc_kwargs["cols"]) + col_wrap = pc_kwargs["col_wrap"] + if n_plots <= col_wrap: + n_rows, n_cols = 1, n_plots + else: + div_mod = divmod(n_plots, col_wrap) + n_rows = div_mod[0] + (div_mod[1] != 0) + n_cols = col_wrap + figsize = plot_bknd.scale_fig_size( + figsize, + rows=n_rows, + cols=n_cols, + figsize_units=figsize_units, + ) + figsize_units = "dots" + pc_kwargs["plot_grid_kws"]["figsize"] = figsize + pc_kwargs["plot_grid_kws"]["figsize_units"] = figsize_units + plot_collection = PlotCollection.wrap( + distribution, + backend=backend, + **pc_kwargs, + ) + + # set plot collection dependent defaults (like aesthetics mappings for each artist) + if aes_map is None: + aes_map = {} + else: + aes_map = aes_map.copy() + aes_map.setdefault("mcse", plot_collection.aes_set.difference({"overlay"})) + aes_map.setdefault("rug", {"overlay"}) + if "model" in distribution: + aes_map.setdefault("mean", {"color"}) + aes_map.setdefault("sd", {"color"}) + if "mean" in aes_map and "mean_text" not in aes_map: + aes_map["mean_text"] = aes_map["mean"] + if "sd" in aes_map and "sd_text" not in aes_map: + aes_map["sd_text"] = aes_map["sd"] + if labeller is None: + labeller = BaseLabeller() 
+ + # compute and add mcse subplots + mcse_kwargs = copy(plot_kwargs.get("mcse", {})) + + if mcse_kwargs is not False: + mcse_dims, _, mcse_ignore = filter_aes(plot_collection, aes_map, "mcse", sample_dims) + probs = np.linspace(1 / n_points, 1 - 1 / n_points, n_points) + xdata = probs + + mcse_y_dataset = xr.concat( + [ + distribution.azstats.mcse( + dims=mcse_dims, + method="quantile", + prob=p, + **stats_kwargs.get("mcse", {}), + ) + for p in probs + ], + dim="mcse_dim", + ) + + xdata_da = xr.DataArray(xdata, dims="mcse_dim") + # broadcasting xdata_da to match shape of each variable in mcse_y_dataset and + # creating a new dataset from dict of broadcasted xdata + xdata_dataset = xr.Dataset( + {var_name: xdata_da.broadcast_like(da) for var_name, da in mcse_y_dataset.items()} + ) + # concatenating xdata_dataset and ess_y_dataset along plot_axis + mcse_dataset = xr.concat([xdata_dataset, mcse_y_dataset], dim="plot_axis").assign_coords( + plot_axis=["x", "y"] + ) + + if errorbar is False: + plot_collection.map( + scatter_xy, "mcse", data=mcse_dataset, ignore_aes=mcse_ignore, **mcse_kwargs + ) + + # else: + # use new errorbar visual element function to plot errorbars + else: + quantiles_dataset = distribution.quantile(probs, dim=mcse_dims) + + plot_collection.map( + error_bar, + "mcse", + data=mcse_dataset, + ignore_aes=mcse_ignore, + quantiles_dataset=quantiles_dataset, + **mcse_kwargs, + ) + + # plot rug + # overlaying divergences(or other 'rug_kind') for each chain + if rug: + sample_stats = get_group(dt, "sample_stats", allow_missing=True) + if ( + sample_stats is not None + and rug_kind in sample_stats.data_vars + and np.any(sample_stats[rug_kind]) # 'diverging' by default + and rug_kwargs is not False + ): + rug_mask = dt.sample_stats[rug_kind] # 'diverging' by default + _, div_aes, div_ignore = filter_aes(plot_collection, aes_map, "rug", sample_dims) + if "color" not in div_aes: + rug_kwargs.setdefault("color", "black") + if "marker" not in div_aes: + 
rug_kwargs.setdefault("marker", "|") + if "size" not in div_aes: + rug_kwargs.setdefault("size", 30) + + values = distribution.azstats.compute_ranks(relative=True) + + plot_collection.map( + trace_rug, + "rug", + data=values, + ignore_aes=div_ignore, + y=0, + mask=rug_mask, + xname=False, + **rug_kwargs, + ) # note: after plot_ppc merge, the `trace_rug` function might change + + # defining x_range (used for mean, sd plotting) + x_range = [0, 1] + x_range = xr.DataArray(x_range) + + # getting backend specific linestyles + linestyles = plot_bknd.get_default_aes("linestyle", 4, {}) + # and default color + default_color = plot_bknd.get_default_aes("color", 1, {})[0] + + # plot mean and sd + if extra_methods is not False: + if errorbar is not False: + raise ValueError("Please ensure errorbar=False if you want to plot mean and sd") + + # computing mean_mcse + mean_dims, mean_aes, mean_ignore = filter_aes(plot_collection, aes_map, "mean", sample_dims) + mean_mcse = distribution.azstats.mcse( + dims=mean_dims, method="mean", **stats_kwargs.get("mean", {}) + ) + + # computing sd_mcse + sd_dims, sd_aes, sd_ignore = filter_aes(plot_collection, aes_map, "sd", sample_dims) + sd_mcse = distribution.azstats.mcse(dims=sd_dims, method="sd", **stats_kwargs.get("sd", {})) + + mean_kwargs = copy(plot_kwargs.get("mean", {})) + if mean_kwargs is not False: + # getting 2nd default linestyle for chosen backend and assigning it by default + mean_kwargs.setdefault("linestyle", linestyles[1]) + + if "color" not in mean_aes: + mean_kwargs.setdefault("color", default_color) + + plot_collection.map( + line_xy, + "mean", + data=mean_mcse, + x=x_range, + ignore_aes=mean_ignore, + **mean_kwargs, + ) + + sd_kwargs = copy(plot_kwargs.get("sd", {})) + if sd_kwargs is not False: + sd_kwargs.setdefault("linestyle", linestyles[2]) + + if "color" not in sd_aes: + sd_kwargs.setdefault("color", default_color) + + plot_collection.map( + line_xy, "sd", data=sd_mcse, ignore_aes=sd_ignore, x=x_range, 
**sd_kwargs + ) + + sd_va_align = None + mean_va_align = None + if mean_mcse is not None and sd_mcse is not None: + sd_va_align = xr.where(mean_mcse < sd_mcse, "bottom", "top") + mean_va_align = xr.where(mean_mcse < sd_mcse, "top", "bottom") + + mean_text_kwargs = copy(plot_kwargs.get("mean_text", {})) + if ( + mean_text_kwargs is not False and mean_mcse is not None + ): # mean_mcse has to exist for an annotation to be applied + _, mean_text_aes, mean_text_ignore = filter_aes( + plot_collection, aes_map, "mean_text", sample_dims + ) + + if "color" not in mean_text_aes: + mean_text_kwargs.setdefault("color", "black") + + mean_text_kwargs.setdefault("x", 1) + mean_text_kwargs.setdefault("horizontal_align", "right") + + # pass the mean vertical_align data for vertical alignment setting + if mean_va_align is not None: + vertical_align = mean_va_align + else: + vertical_align = "bottom" + mean_text_kwargs.setdefault("vertical_align", vertical_align) + + plot_collection.map( + annotate_xy, + "mean_text", + text="mean", + data=mean_mcse, + ignore_aes=mean_text_ignore, + **mean_text_kwargs, + ) + + sd_text_kwargs = copy(plot_kwargs.get("sd_text", {})) + if ( + sd_text_kwargs is not False and sd_mcse is not None + ): # sd_mcse has to exist for an annotation to be applied + _, sd_text_aes, sd_text_ignore = filter_aes( + plot_collection, aes_map, "sd_text", sample_dims + ) + + if "color" not in sd_text_aes: + sd_text_kwargs.setdefault("color", "black") + + sd_text_kwargs.setdefault("x", 1) + sd_text_kwargs.setdefault("horizontal_align", "right") + + # pass the sd vertical_align data for vertical alignment setting + if sd_va_align is not None: + vertical_align = sd_va_align + else: + vertical_align = "top" + sd_text_kwargs.setdefault("vertical_align", vertical_align) + + plot_collection.map( + annotate_xy, + "sd_text", + text="sd", + data=sd_mcse, + ignore_aes=sd_text_ignore, + **sd_text_kwargs, + ) + + # plot titles for each facetted subplot + title_kwargs = 
copy(plot_kwargs.get("title", {})) + + if title_kwargs is not False: + _, title_aes, title_ignore = filter_aes(plot_collection, aes_map, "title", sample_dims) + if "color" not in title_aes: + title_kwargs.setdefault("color", "black") + plot_collection.map( + labelled_title, + "title", + ignore_aes=title_ignore, + subset_info=True, + labeller=labeller, + **title_kwargs, + ) + + # plot x and y axis labels + # Add varnames as x and y labels + _, labels_aes, labels_ignore = filter_aes(plot_collection, aes_map, "xlabel", sample_dims) + xlabel_kwargs = plot_kwargs.get("xlabel", {}).copy() + if xlabel_kwargs is not False: + if "color" not in labels_aes: + xlabel_kwargs.setdefault("color", "black") + + # formatting ylabel and setting xlabel + xlabel_kwargs.setdefault("text", "Quantile") + + plot_collection.map( + labelled_x, + "xlabel", + ignore_aes=labels_ignore, + subset_info=True, + **xlabel_kwargs, + ) + + _, labels_aes, labels_ignore = filter_aes(plot_collection, aes_map, "ylabel", sample_dims) + ylabel_kwargs = plot_kwargs.get("ylabel", {}).copy() + if ylabel_kwargs is not False: + if "color" not in labels_aes: + ylabel_kwargs.setdefault("color", "black") + + ylabel_text = r"Value $\pm$ MCSE for quantiles" if errorbar else "MCSE for quantiles" + + ylabel_kwargs.setdefault("text", ylabel_text) + + plot_collection.map( + labelled_y, + "ylabel", + ignore_aes=labels_ignore, + subset_info=True, + **ylabel_kwargs, + ) + + return plot_collection diff --git a/src/arviz_plots/visuals/__init__.py b/src/arviz_plots/visuals/__init__.py index 406bf427..8e356c9b 100644 --- a/src/arviz_plots/visuals/__init__.py +++ b/src/arviz_plots/visuals/__init__.py @@ -11,6 +11,15 @@ import numpy as np import xarray as xr from arviz_base.labels import BaseLabeller +from arviz_stats.numba import array_stats + + +def error_bar(da, target, backend, quantiles_dataset, x=None, y=None, **kwargs): + """Plot error bars.""" + plot_backend = import_module(f"arviz_plots.backend.{backend}") + probs, yerr = 
_process_da_x_y(da, x, y) + quantile_values = quantiles_dataset + return plot_backend.errorbar(probs, quantile_values, yerr, target, **kwargs) def hist(da, target, backend, **kwargs): @@ -89,6 +98,18 @@ def scatter_x(da, target, backend, y=None, **kwargs): return plot_backend.scatter(da, y, target, **kwargs) +def scatter_xy(da, target, backend, x=None, y=None, **kwargs): + """Plot a scatter plot x vs y. + + The input argument `da` is split into x and y using the dimension ``plot_axis``. + If additional x and y arguments are provided, x and y are added to the values + in the `da` dataset sliced along plot_axis='x' and plot_axis='y'. + """ + plot_backend = import_module(f"arviz_plots.backend.{backend}") + x, y = _process_da_x_y(da, x, y) + return plot_backend.scatter(x, y, target, **kwargs) + + def ecdf_line(values, target, backend, **kwargs): """Plot an ecdf line.""" plot_backend = import_module(f"arviz_plots.backend.{backend}") @@ -140,6 +161,18 @@ def _ensure_scalar(*args): return tuple(arg.item() if hasattr(arg, "item") else arg for arg in args) +def annotate_xy(da, target, backend, *, text, x=None, y=None, vertical_align=None, **kwargs): + """Annotate a point (x, y) in a plot.""" + if vertical_align is not None: + if hasattr(vertical_align, "item"): + kwargs["vertical_align"] = vertical_align.item() + else: + kwargs["vertical_align"] = vertical_align # if a string and not a dataarray + x, y = _process_da_x_y(da, x, y) + plot_backend = import_module(f"arviz_plots.backend.{backend}") + return plot_backend.text(x, y, text, target, **kwargs) + + def point_estimate_text( da, target, backend, *, point_estimate, x=None, y=None, point_label="x", **kwargs ): diff --git a/tests/test_hypothesis_plots.py b/tests/test_hypothesis_plots.py index 77bcb44a..ab64c5f8 100644 --- a/tests/test_hypothesis_plots.py +++ b/tests/test_hypothesis_plots.py @@ -8,7 +8,7 @@ from datatree import DataTree from hypothesis import given -from arviz_plots import plot_dist, plot_forest, plot_ridge 
+from arviz_plots import plot_dist, plot_forest, plot_mcse, plot_ridge pytestmark = pytest.mark.usefixtures("no_artist_kwargs") @@ -192,3 +192,59 @@ def test_plot_ridge(datatree, combined, plot_kwargs, labels_shade_label): assert all(key in child for child in pc.viz.children.values()) elif key not in ("remove_axis", "ticklabels"): assert all(key in child for child in pc.viz.children.values()) + + +mcse_rug = st.booleans() +mcse_extra_methods = st.booleans() + + +@st.composite +def mcse_n_points(draw): + return draw(st.integers(min_value=1, max_value=50)) # should this range be changed? + + +@given( + plot_kwargs=st.fixed_dictionaries( + {}, + optional={ + "mcse": plot_kwargs_value, + "rug": st.sampled_from(({}, {"color": "red"})), + "xlabel": st.sampled_from(({}, {"color": "red"})), + "ylabel": st.sampled_from(({}, {"color": "red"})), + "mean": plot_kwargs_value, + "mean_text": plot_kwargs_value, + "sd": plot_kwargs_value, + "sd_text": plot_kwargs_value, + "title": plot_kwargs_value, + "remove_axis": st.just(False), + }, + ), + rug=mcse_rug, + n_points=mcse_n_points(), + extra_methods=mcse_extra_methods, +) +def test_plot_mcse(datatree, rug, n_points, extra_methods, plot_kwargs): + pc = plot_mcse( + datatree, + backend="none", + rug=rug, + n_points=n_points, + extra_methods=extra_methods, + plot_kwargs=plot_kwargs, + ) + assert all("plot" in child for child in pc.viz.children.values()) + for key, value in plot_kwargs.items(): + if value is False: + assert all(key not in child for child in pc.viz.children.values()) + elif key in ["mean", "sd", "mean_text", "sd_text"]: + if extra_methods is False: + assert all(key not in child for child in pc.viz.children.values()) + else: + assert all(key in child for child in pc.viz.children.values()) + elif key == "rug": + if rug is False: + assert all(key not in child for child in pc.viz.children.values()) + else: + assert all(key in child for child in pc.viz.children.values()) + elif key != "remove_axis": + assert all(key in 
child for child in pc.viz.children.values()) diff --git a/tests/test_plots.py b/tests/test_plots.py index 64532af4..fc46a6f6 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -9,6 +9,7 @@ plot_compare, plot_dist, plot_forest, + plot_mcse, plot_ridge, plot_trace, plot_trace_dist, @@ -110,7 +111,7 @@ def cmp(): @pytest.mark.parametrize("backend", ["matplotlib", "bokeh", "plotly", "none"]) -class TestPlots: +class TestPlots: # pylint: disable=too-many-public-methods @pytest.mark.parametrize("kind", ["kde", "hist", "ecdf"]) def test_plot_dist(self, datatree, backend, kind): pc = plot_dist(datatree, backend=backend, kind=kind) @@ -291,3 +292,49 @@ def test_plot_compare_kwargs(self, cmp, backend): pc_kwargs={"plot_grid_kws": {"figsize": (1000, 200)}}, backend=backend, ) + + def test_plot_mcse(self, datatree, backend): + pc = plot_mcse(datatree, backend=backend, rug=True) + assert "chart" in pc.viz.data_vars + assert "plot" not in pc.viz.data_vars + assert "mcse" in pc.viz["mu"] + assert "title" in pc.viz["mu"] + assert "rug" in pc.viz["mu"] + assert "hierarchy" not in pc.viz["mu"].dims + assert "hierarchy" in pc.viz["theta"].dims + assert pc.viz["mu"].rug.shape == (4,) # checking rug artist shape (4 chains overlaid) + # checking aesthetics + assert "overlay" in pc.aes["mu"].data_vars # overlay of chains + + a = """ + def test_plot_mcse_sample(self, datatree_sample, backend): + pc = plot_mcse(datatree_sample, backend=backend, rug=True, sample_dims="sample") + assert "chart" in pc.viz.data_vars + assert "plot" not in pc.viz.data_vars + assert "mcse" in pc.viz["mu"] + assert "title" in pc.viz["mu"] + assert "rug" in pc.viz["mu"] + assert "hierarchy" not in pc.viz["mu"].dims + assert "hierarchy" in pc.viz["theta"].dims + assert pc.viz["mu"].trace.shape == () # 0 chains here, so no overlay + + """ # error when running this: ValueError: dims must be of length 2 (from arviz-stats ) + if a: + print("a") # just to temporarily suspend linting error + + def 
test_plot_mcse_models(self, datatree, datatree2, backend): + pc = plot_mcse({"c": datatree, "n": datatree2}, backend=backend, rug=False) + assert "chart" in pc.viz.data_vars + assert "plot" not in pc.viz.data_vars + assert "mcse" in pc.viz["mu"] + assert "title" in pc.viz["mu"] + # assert "rug" in pc.viz["mu"] + assert "hierarchy" not in pc.viz["mu"].dims + assert "hierarchy" in pc.viz["theta"].dims + assert "model" in pc.viz["mu"].dims + # assert pc.viz["mu"].rug.shape == (2, 4) # since there are 2 Ms, 4 Cs in datatree + # checking aesthetics + assert "model" in pc.aes["mu"].dims + assert "x" in pc.aes["mu"].data_vars + assert "color" in pc.aes["mu"].data_vars + assert "overlay" in pc.aes["mu"].data_vars # overlay of chains