From d2c30cb3da2ca32d3093bd9dab607da40ede0c0c Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Wed, 3 Jul 2024 14:26:14 +0530 Subject: [PATCH 01/24] First commit for essplot and scatter_xy visual element --- src/arviz_plots/plots/__init__.py | 2 + src/arviz_plots/plots/essplot.py | 251 ++++++++++++++++++++++++++++ src/arviz_plots/visuals/__init__.py | 9 + 3 files changed, 262 insertions(+) create mode 100644 src/arviz_plots/plots/essplot.py diff --git a/src/arviz_plots/plots/__init__.py b/src/arviz_plots/plots/__init__.py index 72304d5..25dcd4f 100644 --- a/src/arviz_plots/plots/__init__.py +++ b/src/arviz_plots/plots/__init__.py @@ -2,6 +2,7 @@ from .compareplot import plot_compare from .distplot import plot_dist +from .essplot import plot_ess from .forestplot import plot_forest from .ridgeplot import plot_ridge from .tracedistplot import plot_trace_dist @@ -13,5 +14,6 @@ "plot_forest", "plot_trace", "plot_trace_dist", + "plot_ess", "plot_ridge", ] diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py new file mode 100644 index 0000000..262f194 --- /dev/null +++ b/src/arviz_plots/plots/essplot.py @@ -0,0 +1,251 @@ +"""ess plot code.""" + +# imports +# import warnings +from copy import copy + +import arviz_stats # pylint: disable=unused-import +import numpy as np +import xarray as xr +from arviz_base import rcParams +from arviz_base.labels import BaseLabeller + +from arviz_plots.plot_collection import PlotCollection +from arviz_plots.plots.utils import filter_aes, process_group_variables_coords +from arviz_plots.visuals import labelled_title, scatter_xy + + +# function signature +def plot_ess( + # initial base arguments + dt, + var_names=None, + filter_vars=None, + group="posterior", + coords=None, + sample_dims=None, + # plot specific arguments + kind="local", + relative=False, + # rug=False, + # rug_kind="diverging", + n_points=20, + # extra_methods=False, + # min_ess=400, + # more base arguments + plot_collection=None, + backend=None, + labeller=None, + aes_map=None, + plot_kwargs=None, + stats_kwargs=None, + pc_kwargs=None, +): + """Plot effective sample size plots. + + Parameters + ---------- + dt : DataTree or dict of {str : DataTree} + Input data. In case of dictionary input, the keys are taken to be model names. + In such cases, a dimension "model" is generated and can be used to map to aesthetics. + var_names : str or sequence of str, optional + One or more variables to be plotted. + Prefix the variables by ~ when you want to exclude them from the plot. + filter_vars : {None, “like”, “regex”}, default None + If None, interpret `var_names` as the real variables names. + If “like”, interpret `var_names` as substrings of the real variables names. + If “regex”, interpret `var_names` as regular expressions on the real variables names. + group : str, default "posterior" + Group to be plotted. + coords : dict, optional + sample_dims : str or sequence of hashable, optional + Dimensions to reduce unless mapped to an aesthetic. + Defaults to ``rcParams["data.sample_dims"]`` + kind : {"local", "quantile", "evolution"}, default "local" + Specify the kind of plot: + + * The ``kind="local"`` argument generates the ESS' local efficiency for + estimating quantiles of a desired posterior. + * The ``kind="quantile"`` argument generates the ESS' local efficiency + for estimating small-interval probability of a desired posterior. + * The ``kind="evolution"`` argument generates the estimated ESS' + with incrised number of iterations of a desired posterior. + WIP: add the other kinds for each kind of ess computation in arviz stats + + relative : bool, default False + Show relative ess in plot ``ress = ess / N``. + rug : bool, default False + Add a `rug plot `_ for a specific subset of values. + rug_kind : str, default "diverging" + Variable in sample stats to use as rug mask. Must be a boolean variable. + n_points : int, default 20 + Number of points for which to plot their quantile/local ess or number of subsets + in the evolution plot. + extra_methods : bool, default False + Plot mean and sd ESS as horizontal lines. Not taken into account if ``kind = 'evolution'``. + min_ess : int, default 400 + Minimum number of ESS desired. If ``relative=True`` the line is plotted at + ``min_ess / n_samples`` for local and quantile kinds and as a curve following + the ``min_ess / n`` dependency in evolution kind. + plot_collection : PlotCollection, optional + backend : {"matplotlib", "bokeh"}, optional + labeller : labeller, optional + aes_map : mapping of {str : sequence of str or False}, optional + Mapping of artists to aesthetics that should use their mapping in `plot_collection` + when plotted. Valid keys are the same as for `plot_kwargs`. + + plot_kwargs : mapping of {str : mapping or False}, optional + Valid keys are: + + * One of "local", "quantile", "evolution", matching the `kind` argument. + * "local" -> passed to :func:`~arviz_plots.visuals.scatter_xy` + * "quantile" -> passed to :func:`~arviz_plots.visuals.line_xy` + * "evolution" -> passed to :func:`~arviz_plots.visuals.line_xy` + + * divergence -> passed to :func:`~.visuals.trace_rug` + * title -> passed to :func:`~arviz_plots.visuals.labelled_title` + + stats_kwargs : mapping, optional + Valid keys are: + + * ess -> passed to ess + + pc_kwargs : mapping + Passed to :class:`arviz_plots.PlotCollection.wrap` + + Returns + ------- + PlotCollection + """ + # initial defaults + if sample_dims is None: + sample_dims = rcParams["data.sample_dims"] + if isinstance(sample_dims, str): + sample_dims = [sample_dims] + + # mutable inputs + if plot_kwargs is None: + plot_kwargs = {} + if pc_kwargs is None: + pc_kwargs = {} + else: + pc_kwargs = pc_kwargs.copy() + + if stats_kwargs is None: + stats_kwargs = {} + + # processing dt/group/coords/filtering + distribution = process_group_variables_coords( + dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords + ) + + # set plot collection initialization defaults if it doesnt exist + if plot_collection is None: + if backend is None: + backend = rcParams["plot.backend"] + pc_kwargs.setdefault("col_wrap", 5) + pc_kwargs.setdefault( + "cols", + ["__variable__"] + + [dim for dim in distribution.dims if dim not in {"model"}.union(sample_dims)], + ) + if "model" in distribution: + pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() + pc_kwargs["aes"].setdefault("color", ["model"]) + plot_collection = PlotCollection.wrap( + distribution, + backend=backend, + **pc_kwargs, + ) + + # set plot collection dependent defaults (like aesthetics mappings for each artist) + if aes_map is None: + aes_map = {} + else: + aes_map = aes_map.copy() + aes_map.setdefault(kind, plot_collection.aes_set) + if labeller is None: + labeller = BaseLabeller() + + # compute and add ess subplots + # step 1 + ess_kwargs = copy(plot_kwargs.get(kind, {})) + + if ess_kwargs is not False: + # step 2 + ess_dims, _, ess_ignore = filter_aes(plot_collection, aes_map, kind, sample_dims) + if kind == "local": + probs = np.linspace(0, 1, n_points, endpoint=False) + xdata = probs + + # step 3 + ess_y_dataset = xr.concat( + [ + distribution.azstats.ess( + dims=ess_dims, + method="local", + relative=relative, + prob=[p, (p + 1 / n_points)], + **stats_kwargs.get("ess", {}), + ) + for p in probs + ], + dim="ess_dim", + ) + print(f"\n ess_y_dataset = {ess_y_dataset}") + + # broadcasting xdata to match ess_y_dataset's shape + xdata_da = xr.DataArray(xdata, dims="ess_dim") + print(f"\n xdata_da ={xdata_da}") + + # broadcasting xdata_da to match shape of each variable in ess_y_dataset + xdata_broadcasted_dict = {} + for var in ess_y_dataset.data_vars: + _, xdata_broadcasted = xr.broadcast(ess_y_dataset[var], xdata_da) + xdata_broadcasted_dict[var] = xdata_broadcasted + + # creating a new dataset from dict of broadcasted xdata + xdata_dataset = xr.Dataset(xdata_broadcasted_dict) + print(f"\n xdata_dataset = {xdata_dataset}") + + # concatenating xdata_dataset and ess_y_dataset along plot_axis + ess_dataset = xr.concat([xdata_dataset, ess_y_dataset], dim="plot_axis") + + # assigning 'x' and 'y' coordinates to the 'plot_axis' dimension + ess_dataset["plot_axis"] = ["x", "y"] + print(f"\n ess_dataset = {ess_dataset}") + + # step 4 + # if "color" not in ess_aes: + # ess_kwargs.setdefault("color", "gray") + + # step 5 + plot_collection.map( + scatter_xy, "local", data=ess_dataset, ignore_aes=ess_ignore, **ess_kwargs + ) + + # WIP: repeat previous pattern for all ess methods as kind='method' + + # all the ess methods supported in arviz stats: + # valid_methods = { + # "bulk", "tail", "mean", "sd", "quantile", "local", "median", "mad", + # "z_scale", "folded", "identity" + # } + + # plot titles for each facetted subplot + title_kwargs = copy(plot_kwargs.get("title", {})) + + if title_kwargs is not False: + _, title_aes, title_ignore = filter_aes(plot_collection, aes_map, "title", sample_dims) + if "color" not in title_aes: + title_kwargs.setdefault("color", "black") + plot_collection.map( + labelled_title, + "title", + ignore_aes=title_ignore, + subset_info=True, + labeller=labeller, + **title_kwargs, + ) + + return plot_collection diff --git a/src/arviz_plots/visuals/__init__.py b/src/arviz_plots/visuals/__init__.py index 406bf42..9c79d46 100644 --- a/src/arviz_plots/visuals/__init__.py +++ b/src/arviz_plots/visuals/__init__.py @@ -89,6 +89,15 @@ def scatter_x(da, target, backend, y=None, **kwargs): return plot_backend.scatter(da, y, target, **kwargs) +def scatter_xy(da, target, backend, **kwargs): + """Plot a scatter plot x vs y. + + The input argument `da` is split into x and y using the dimension ``plot_axis``. + """ + plot_backend = import_module(f"arviz_plots.backend.{backend}") + return plot_backend.scatter(da.sel(plot_axis="x"), da.sel(plot_axis="y"), target, **kwargs) + + def ecdf_line(values, target, backend, **kwargs): """Plot an ecdf line.""" plot_backend = import_module(f"arviz_plots.backend.{backend}") From 0e662cd5ccbfae7644ef05513d7e12028d9c1a71 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Mon, 8 Jul 2024 15:32:38 +0530 Subject: [PATCH 02/24] update for ess plot and addition of 'x' aesthetic for 'model' dim --- src/arviz_plots/plots/essplot.py | 33 +++++++++++++---------------- src/arviz_plots/visuals/__init__.py | 8 +++++-- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index 262f194..deabfa1 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -64,10 +64,10 @@ def plot_ess( kind : {"local", "quantile", "evolution"}, default "local" Specify the kind of plot: - * The ``kind="local"`` argument generates the ESS' local efficiency for - estimating quantiles of a desired posterior. - * The ``kind="quantile"`` argument generates the ESS' local efficiency + * The ``kind="local"`` argument generates the ESS' local efficiency for estimating small-interval probability of a desired posterior. + * The ``kind="quantile"`` argument generates the ESS' local efficiency + for estimating quantiles of a desired posterior. * The ``kind="evolution"`` argument generates the estimated ESS' with incrised number of iterations of a desired posterior. WIP: add the other kinds for each kind of ess computation in arviz stats @@ -152,6 +152,7 @@ def plot_ess( if "model" in distribution: pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() pc_kwargs["aes"].setdefault("color", ["model"]) + pc_kwargs["aes"].setdefault("x", ["model"]) plot_collection = PlotCollection.wrap( distribution, backend=backend, @@ -192,28 +193,24 @@ def plot_ess( ], dim="ess_dim", ) - print(f"\n ess_y_dataset = {ess_y_dataset}") + # print(f"\n ess_y_dataset = {ess_y_dataset}") # broadcasting xdata to match ess_y_dataset's shape xdata_da = xr.DataArray(xdata, dims="ess_dim") - print(f"\n xdata_da ={xdata_da}") - - # broadcasting xdata_da to match shape of each variable in ess_y_dataset - xdata_broadcasted_dict = {} - for var in ess_y_dataset.data_vars: - _, xdata_broadcasted = xr.broadcast(ess_y_dataset[var], xdata_da) - xdata_broadcasted_dict[var] = xdata_broadcasted + # print(f"\n xdata_da ={xdata_da}") + # broadcasting xdata_da to match shape of each variable in ess_y_dataset and # creating a new dataset from dict of broadcasted xdata - xdata_dataset = xr.Dataset(xdata_broadcasted_dict) - print(f"\n xdata_dataset = {xdata_dataset}") + xdata_dataset = xr.Dataset( + {var_name: xdata_da.broadcast_like(da) for var_name, da in ess_y_dataset.items()} + ) + # print(f"\n xdata_dataset = {xdata_dataset}") # concatenating xdata_dataset and ess_y_dataset along plot_axis - ess_dataset = xr.concat([xdata_dataset, ess_y_dataset], dim="plot_axis") - - # assigning 'x' and 'y' coordinates to the 'plot_axis' dimension - ess_dataset["plot_axis"] = ["x", "y"] - print(f"\n ess_dataset = {ess_dataset}") + ess_dataset = xr.concat([xdata_dataset, ess_y_dataset], dim="plot_axis").assign_coords( + plot_axis=["x", "y"] + ) + print(f"\n ess_dataset = {ess_dataset!r}") # step 4 # if "color" not in ess_aes: diff --git a/src/arviz_plots/visuals/__init__.py b/src/arviz_plots/visuals/__init__.py index 9c79d46..1639021 100644 --- a/src/arviz_plots/visuals/__init__.py +++ b/src/arviz_plots/visuals/__init__.py @@ -89,13 +89,17 @@ def scatter_x(da, target, backend, y=None, **kwargs): return plot_backend.scatter(da, y, target, **kwargs) -def scatter_xy(da, target, backend, **kwargs): +def scatter_xy(da, target, backend, x=None, **kwargs): """Plot a scatter plot x vs y. The input argument `da` is split into x and y using the dimension ``plot_axis``. """ plot_backend = import_module(f"arviz_plots.backend.{backend}") - return plot_backend.scatter(da.sel(plot_axis="x"), da.sel(plot_axis="y"), target, **kwargs) + # note: temporary patch only before plot_ridge merge to use _process_da_x_y + # print(f"\n x arg = {x}") + x = da.sel(plot_axis="x") if x is None else da.sel(plot_axis="x") + x + # print(f"\n to plot x = {x}") + return plot_backend.scatter(x, da.sel(plot_axis="y"), target, **kwargs) def ecdf_line(values, target, backend, **kwargs): From 32167b34f6e4c6926a83d641d1b8be3a82c234f4 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Sun, 14 Jul 2024 17:34:08 +0530 Subject: [PATCH 03/24] addition of quantile plot and updated x aesthetic mapping --- src/arviz_plots/plots/essplot.py | 153 +++++++++++++++++++++++++++++-- 1 file changed, 146 insertions(+), 7 deletions(-) diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index deabfa1..90c60ab 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -12,7 +12,12 @@ from arviz_plots.plot_collection import PlotCollection from arviz_plots.plots.utils import filter_aes, process_group_variables_coords -from arviz_plots.visuals import labelled_title, scatter_xy +from arviz_plots.visuals import ( # trace_rug,; line_x,; line_fixed_x, + labelled_title, + labelled_x, + labelled_y, + scatter_xy, +) # function signature @@ -97,18 +102,23 @@ def plot_ess( plot_kwargs : mapping of {str : mapping or False}, optional Valid keys are: - * One of "local", "quantile", "evolution", matching the `kind` argument. + * One of "local" or "quantile", matching the `kind` argument. * "local" -> passed to :func:`~arviz_plots.visuals.scatter_xy` - * "quantile" -> passed to :func:`~arviz_plots.visuals.line_xy` - * "evolution" -> passed to :func:`~arviz_plots.visuals.line_xy` + * "quantile" -> passed to :func:`~arviz_plots.visuals.scatter_xy` * divergence -> passed to :func:`~.visuals.trace_rug` * title -> passed to :func:`~arviz_plots.visuals.labelled_title` + * label -> passed to :func:`~arviz_plots.visuals.labelled_x` and + :func:`~arviz_plots.visuals.labelled_y` + * mean -> + * sd -> stats_kwargs : mapping, optional Valid keys are: - * ess -> passed to ess + * ess -> passed to ess, method = 'local' or 'quantile' based on `kind` + * mean -> passed to ess, method='mean' + * sd -> passed to ess, method='sd' pc_kwargs : mapping Passed to :class:`arviz_plots.PlotCollection.wrap` @@ -152,6 +162,11 @@ def plot_ess( if "model" in distribution: pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() pc_kwargs["aes"].setdefault("color", ["model"]) + # setting x aesthetic to np.linspace(-x_diff/3, x_diff/3, length of 'model' dim) + # x_diff = span of x axis (1) divided by number of points to be plotted (n_points) + x_diff = 1 / n_points + if "x" not in pc_kwargs: + pc_kwargs["x"] = np.linspace(-x_diff / 3, x_diff / 3, distribution.sizes["model"]) pc_kwargs["aes"].setdefault("x", ["model"]) plot_collection = PlotCollection.wrap( distribution, @@ -165,6 +180,10 @@ def plot_ess( else: aes_map = aes_map.copy() aes_map.setdefault(kind, plot_collection.aes_set) + aes_map.setdefault("mean", plot_collection.aes_set) + aes_map.setdefault("mean", ["linestyle", "-"]) + aes_map.setdefault("sd", plot_collection.aes_set) + aes_map.setdefault("sd", ["linestyle", "-"]) if labeller is None: labeller = BaseLabeller() @@ -178,6 +197,7 @@ def plot_ess( if kind == "local": probs = np.linspace(0, 1, n_points, endpoint=False) xdata = probs + ylabel = "{} for small intervals" # step 3 ess_y_dataset = xr.concat( @@ -195,7 +215,7 @@ def plot_ess( ) # print(f"\n ess_y_dataset = {ess_y_dataset}") - # broadcasting xdata to match ess_y_dataset's shape + # converting xdata into a xr dataarray xdata_da = xr.DataArray(xdata, dims="ess_dim") # print(f"\n xdata_da ={xdata_da}") @@ -221,7 +241,50 @@ def plot_ess( scatter_xy, "local", data=ess_dataset, ignore_aes=ess_ignore, **ess_kwargs ) - # WIP: repeat previous pattern for all ess methods as kind='method' + # WIP: repeat previous pattern for ess method kind = 'quantile' + if kind == "quantile": + probs = np.linspace(1 / n_points, 1 - 1 / n_points, n_points) + xdata = probs + ylabel = "{} for quantiles" + + # step 3 + ess_y_dataset = xr.concat( + [ + distribution.azstats.ess( + dims=ess_dims, + method="quantile", + relative=relative, + prob=p, + **stats_kwargs.get("ess", {}), + ) + for p in probs + ], + dim="ess_dim", + ) + + # converting xdata into an xr datarray + xdata_da = xr.DataArray(xdata, dims="ess_dim") + + # broadcasting xdata_da to match shape of each variable in ess_y_dataset and + # creating a new dataset from dict of broadcasted xdata + xdata_dataset = xr.Dataset( + {var_name: xdata_da.broadcast_like(da) for var_name, da in ess_y_dataset.items()} + ) + + # concatenating xdata_dataset and ess_y_dataset along plot_axis + ess_dataset = xr.concat([xdata_dataset, ess_y_dataset], dim="plot_axis").assign_coords( + plot_axis=["x", "y"] + ) + print(f"\n ess_dataset = {ess_dataset!r}") + + # step 4 + # if "color" not in ess_aes: + # ess_kwargs.setdefault("color", "gray") + + # step 5 + plot_collection.map( + scatter_xy, "quantile", data=ess_dataset, ignore_aes=ess_ignore, **ess_kwargs + ) # all the ess methods supported in arviz stats: # valid_methods = { @@ -229,6 +292,45 @@ def plot_ess( # "z_scale", "folded", "identity" # } + # plot rug + + # WIP: plot mean and sd (to be done after plot_ridge PR merge for line_xy update) + a = """xyz. + if extra_methods is not False: + x_range = [0, 1] + x_range = xr.DataArray(x_range) + + mean_kwargs = copy(plot_kwargs.get("mean", {})) + if mean_kwargs is not False: + mean_dims, _, mean_ignore = filter_aes(plot_collection, aes_map, "mean", sample_dims) + mean_ess = distribution.azstats.ess( + dims=mean_dims, method="mean", relative=relative, **stats_kwargs.get("mean", {}) + ) + print(f"\nmean_ess = {mean_ess!r}") + + plot_collection.map( + line_x, + "mean", + data=mean_ess, + x=x_range, + ignore_aes=mean_ignore, + **mean_kwargs, + ) + + sd_kwargs = copy(plot_kwargs.get("sd", {})) + if sd_kwargs is not False: + sd_dims, _, sd_ignore = filter_aes(plot_collection, aes_map, "sd", sample_dims) + sd_ess = distribution.azstats.ess( + dims=sd_dims, method="sd", relative=relative, **stats_kwargs.get("sd", {}) + ) + print(f"\nsd_ess = {sd_ess!r}") + + plot_collection.map( + line_fixed_x, "sd", data=sd_ess, x=x_range, ignore_aes=sd_ignore, **sd_kwargs + ) + """ + print(a) + # plot titles for each facetted subplot title_kwargs = copy(plot_kwargs.get("title", {})) @@ -245,4 +347,41 @@ def plot_ess( **title_kwargs, ) + # plot x and y axis labels + # Add varnames as x and y labels + _, labels_aes, labels_ignore = filter_aes(plot_collection, aes_map, "label", sample_dims) + label_kwargs = plot_kwargs.get("label", {}).copy() + + if "color" not in labels_aes: + label_kwargs.setdefault("color", "black") + + # label_kwargs.setdefault("size", textsize) + + # formatting ylabel and setting xlabel + if relative is not False: + ylabel = ylabel.format("Relative ESS") + else: + ylabel = ylabel.format("ESS") + xlabel = "Quantile" + + plot_collection.map( + labelled_x, + "xlabel", + ignore_aes=labels_ignore, + subset_info=True, + text=xlabel, + store_artist=False, + **label_kwargs, + ) + + plot_collection.map( + labelled_y, + "ylabel", + ignore_aes=labels_ignore, + subset_info=True, + text=ylabel, + store_artist=False, + **label_kwargs, + ) + return plot_collection From c6a8c2d841780163f3b2c2310a19e34bec653795 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Tue, 16 Jul 2024 16:53:01 +0530 Subject: [PATCH 04/24] Added rugplot to essplot --- src/arviz_plots/plots/essplot.py | 87 +++++++++++++++++++++++------ src/arviz_plots/visuals/__init__.py | 4 +- 2 files changed, 72 insertions(+), 19 deletions(-) diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index 90c60ab..c5f90e6 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -11,13 +11,8 @@ from arviz_base.labels import BaseLabeller from arviz_plots.plot_collection import PlotCollection -from arviz_plots.plots.utils import filter_aes, process_group_variables_coords -from arviz_plots.visuals import ( # trace_rug,; line_x,; line_fixed_x, - labelled_title, - labelled_x, - labelled_y, - scatter_xy, -) +from arviz_plots.plots.utils import filter_aes, get_group, process_group_variables_coords +from arviz_plots.visuals import labelled_title, labelled_x, labelled_y, scatter_xy, trace_rug # function signature @@ -32,8 +27,8 @@ def plot_ess( # plot specific arguments kind="local", relative=False, - # rug=False, - # rug_kind="diverging", + rug=False, + rug_kind="diverging", n_points=20, # extra_methods=False, # min_ess=400, @@ -159,8 +154,12 @@ def plot_ess( ["__variable__"] + [dim for dim in distribution.dims if dim not in {"model"}.union(sample_dims)], ) + pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() + if "chain" in distribution: + pc_kwargs["aes"].setdefault("overlay", ["chain"]) # so rug for each chain is overlaid + # doing this^ sets overlay: chain for each artist. But we only want overlay for the + # rug so .difference() has to be be applied to the aes_map defaults if not wanted if "model" in distribution: - pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() pc_kwargs["aes"].setdefault("color", ["model"]) # setting x aesthetic to np.linspace(-x_diff/3, x_diff/3, length of 'model' dim) # x_diff = span of x axis (1) divided by number of points to be plotted (n_points) @@ -168,22 +167,30 @@ def plot_ess( if "x" not in pc_kwargs: pc_kwargs["x"] = np.linspace(-x_diff / 3, x_diff / 3, distribution.sizes["model"]) pc_kwargs["aes"].setdefault("x", ["model"]) + aux_dim_list = [dim for dim in pc_kwargs["cols"] if dim != "__variable__"] # for divergence plot_collection = PlotCollection.wrap( distribution, backend=backend, **pc_kwargs, ) + else: + aux_dim_list = list( + set( + dim for child in plot_collection.viz.children.values() for dim in child["plot"].dims + ) + ) # set plot collection dependent defaults (like aesthetics mappings for each artist) if aes_map is None: aes_map = {} else: aes_map = aes_map.copy() - aes_map.setdefault(kind, plot_collection.aes_set) + aes_map.setdefault(kind, plot_collection.aes_set.difference({"overlay"})) aes_map.setdefault("mean", plot_collection.aes_set) aes_map.setdefault("mean", ["linestyle", "-"]) aes_map.setdefault("sd", plot_collection.aes_set) aes_map.setdefault("sd", ["linestyle", "-"]) + aes_map.setdefault("divergence", {"overlay"}) if labeller is None: labeller = BaseLabeller() @@ -230,6 +237,7 @@ def plot_ess( ess_dataset = xr.concat([xdata_dataset, ess_y_dataset], dim="plot_axis").assign_coords( plot_axis=["x", "y"] ) + print(f"\n distribution = {distribution!r}") print(f"\n ess_dataset = {ess_dataset!r}") # step 4 @@ -241,7 +249,6 @@ def plot_ess( scatter_xy, "local", data=ess_dataset, ignore_aes=ess_ignore, **ess_kwargs ) - # WIP: repeat previous pattern for ess method kind = 'quantile' if kind == "quantile": probs = np.linspace(1 / n_points, 1 - 1 / n_points, n_points) xdata = probs @@ -286,13 +293,57 @@ def plot_ess( scatter_xy, "quantile", data=ess_dataset, ignore_aes=ess_ignore, **ess_kwargs ) - # all the ess methods supported in arviz stats: - # valid_methods = { - # "bulk", "tail", "mean", "sd", "quantile", "local", "median", "mad", - # "z_scale", "folded", "identity" - # } - # plot rug + # overlaying divergences for each chain + if rug: + sample_stats = get_group(dt, "sample_stats", allow_missing=True) + divergence_kwargs = copy(plot_kwargs.get("divergence", {})) + if ( + sample_stats is not None + and "diverging" in sample_stats.data_vars + and np.any(sample_stats[rug_kind]) # 'diverging' by default + and divergence_kwargs is not False + ): + divergence_mask = dt.sample_stats[rug_kind] # 'diverging' by default + print(f"\n divergence_mask = {divergence_mask!r}") + _, div_aes, div_ignore = filter_aes(plot_collection, aes_map, "divergence", sample_dims) + if "color" not in div_aes: + divergence_kwargs.setdefault("color", "black") + if "marker" not in div_aes: + divergence_kwargs.setdefault("marker", "|") + # if "width" not in div_aes: # should this be hardcoded? + # divergence_kwargs.setdefault("width", linewidth) + if "size" not in div_aes: + divergence_kwargs.setdefault("size", 30) + div_reduce_dims = [dim for dim in distribution.dims if dim not in aux_dim_list] + + # xname is used to pick subset of dataset in map() to be masked + xname = None # xname logic from traceplot + default_xname = sample_dims[0] if len(sample_dims) == 1 else "draw" + if (default_xname not in distribution.dims) or ( + not np.issubdtype(distribution[default_xname].dtype, np.number) + ): + default_xname = None + xname = divergence_kwargs.get("xname", default_xname) + divergence_kwargs["xname"] = xname + print(f"\n div_reduce_dims = {div_reduce_dims!r}") + print(f"\n xname = {xname}") + + draw_length = distribution.sizes["draw"] # used to scale xvalues to between 0-1 + + # print(f"\n distribution = {distribution}") + + plot_collection.map( + trace_rug, + "divergence", + data=distribution, + ignore_aes=div_ignore, + # xname=xname, + y=distribution.min(div_reduce_dims), + mask=divergence_mask, + scale=draw_length, + **divergence_kwargs, + ) # WIP: plot mean and sd (to be done after plot_ridge PR merge for line_xy update) a = """xyz. diff --git a/src/arviz_plots/visuals/__init__.py b/src/arviz_plots/visuals/__init__.py index 1639021..00fa36b 100644 --- a/src/arviz_plots/visuals/__init__.py +++ b/src/arviz_plots/visuals/__init__.py @@ -60,7 +60,7 @@ def line(da, target, backend, xname=None, **kwargs): return plot_backend.line(xvalues, yvalues, target, **kwargs) -def trace_rug(da, target, backend, mask, xname=None, y=None, **kwargs): +def trace_rug(da, target, backend, mask, xname=None, y=None, scale=1, **kwargs): """Create a rug plot with the subset of `da` indicated by `mask`.""" xname = xname.item() if hasattr(xname, "item") else xname if xname is False: @@ -76,6 +76,8 @@ def trace_rug(da, target, backend, mask, xname=None, y=None, **kwargs): y = da.min().item() if len(xvalues.shape) != 1: raise ValueError(f"Expected unidimensional data but got {xvalues.sizes}") + xvalues = xvalues / scale + # print(f"\n trace_rug call. xvalues = {xvalues}\nmask = {mask}") return scatter_x(xvalues[mask], target=target, backend=backend, y=y, **kwargs) From eacdb4334ceb2f993dc3e2b4bd0fd1f47cd17f08 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Tue, 23 Jul 2024 15:36:49 +0530 Subject: [PATCH 05/24] updates to essplot --- src/arviz_plots/plots/essplot.py | 372 +++++++++++++++---------------- 1 file changed, 183 insertions(+), 189 deletions(-) diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index c5f90e6..8ee8b95 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -12,27 +12,30 @@ from arviz_plots.plot_collection import PlotCollection from arviz_plots.plots.utils import filter_aes, get_group, process_group_variables_coords -from arviz_plots.visuals import labelled_title, labelled_x, labelled_y, scatter_xy, trace_rug +from arviz_plots.visuals import ( + labelled_title, + labelled_x, + labelled_y, + line_xy, + scatter_xy, + trace_rug, +) -# function signature def plot_ess( - # initial base arguments dt, var_names=None, filter_vars=None, group="posterior", coords=None, sample_dims=None, - # plot specific arguments kind="local", relative=False, rug=False, rug_kind="diverging", n_points=20, - # extra_methods=False, - # min_ess=400, - # more base arguments + extra_methods=True, # default False + min_ess=400, plot_collection=None, backend=None, labeller=None, @@ -68,9 +71,6 @@ def plot_ess( for estimating small-interval probability of a desired posterior. * The ``kind="quantile"`` argument generates the ESS' local efficiency for estimating quantiles of a desired posterior. - * The ``kind="evolution"`` argument generates the estimated ESS' - with incrised number of iterations of a desired posterior. - WIP: add the other kinds for each kind of ess computation in arviz stats relative : bool, default False Show relative ess in plot ``ress = ess / N``. @@ -82,11 +82,10 @@ def plot_ess( Number of points for which to plot their quantile/local ess or number of subsets in the evolution plot. extra_methods : bool, default False - Plot mean and sd ESS as horizontal lines. Not taken into account if ``kind = 'evolution'``. + Plot mean and sd ESS as horizontal lines. min_ess : int, default 400 Minimum number of ESS desired. If ``relative=True`` the line is plotted at - ``min_ess / n_samples`` for local and quantile kinds and as a curve following - the ``min_ess / n`` dependency in evolution kind. + ``min_ess / n_samples`` for local and quantile kinds plot_collection : PlotCollection, optional backend : {"matplotlib", "bokeh"}, optional labeller : labeller, optional @@ -97,16 +96,18 @@ def plot_ess( plot_kwargs : mapping of {str : mapping or False}, optional Valid keys are: - * One of "local" or "quantile", matching the `kind` argument. - * "local" -> passed to :func:`~arviz_plots.visuals.scatter_xy` - * "quantile" -> passed to :func:`~arviz_plots.visuals.scatter_xy` + * ess -> passed to :func:`~arviz_plots.visuals.scatter_xy` + if `kind`='local', + else passed to :func:`~arviz_plots.visuals.scatter_xy` + if `kind` = 'quantile' - * divergence -> passed to :func:`~.visuals.trace_rug` + * rug -> passed to :func:`~.visuals.trace_rug` * title -> passed to :func:`~arviz_plots.visuals.labelled_title` - * label -> passed to :func:`~arviz_plots.visuals.labelled_x` and - :func:`~arviz_plots.visuals.labelled_y` - * mean -> - * sd -> + * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x` + * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y` + * mean -> passed to :func:`~arviz.plots.visuals.line_xy` + * sd -> passed to :func:`~arviz.plots.visuals.line_xy` + * min_ess -> passed to :func:`~arviz.plots.visuals.line_xy` stats_kwargs : mapping, optional Valid keys are: @@ -156,18 +157,14 @@ def plot_ess( ) pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() if "chain" in distribution: - pc_kwargs["aes"].setdefault("overlay", ["chain"]) # so rug for each chain is overlaid - # doing this^ sets overlay: chain for each artist. But we only want overlay for the - # rug so .difference() has to be be applied to the aes_map defaults if not wanted + pc_kwargs["aes"].setdefault("overlay", ["chain"]) if "model" in distribution: pc_kwargs["aes"].setdefault("color", ["model"]) - # setting x aesthetic to np.linspace(-x_diff/3, x_diff/3, length of 'model' dim) - # x_diff = span of x axis (1) divided by number of points to be plotted (n_points) - x_diff = 1 / n_points - if "x" not in pc_kwargs: - pc_kwargs["x"] = np.linspace(-x_diff / 3, x_diff / 3, distribution.sizes["model"]) + n_models = distribution.sizes["model"] + x_diff = min(1 / n_points / 3, 1 / n_points * n_models / 10) + pc_kwargs.setdefault("x", np.linspace(-x_diff, x_diff, n_models)) pc_kwargs["aes"].setdefault("x", ["model"]) - aux_dim_list = [dim for dim in pc_kwargs["cols"] if dim != "__variable__"] # for divergence + aux_dim_list = [dim for dim in pc_kwargs["cols"] if dim != "__variable__"] plot_collection = PlotCollection.wrap( distribution, backend=backend, @@ -186,181 +183,133 @@ def plot_ess( else: aes_map = aes_map.copy() aes_map.setdefault(kind, plot_collection.aes_set.difference({"overlay"})) - aes_map.setdefault("mean", plot_collection.aes_set) - aes_map.setdefault("mean", ["linestyle", "-"]) - aes_map.setdefault("sd", plot_collection.aes_set) - aes_map.setdefault("sd", ["linestyle", "-"]) - aes_map.setdefault("divergence", {"overlay"}) + aes_map.setdefault("rug", {"overlay"}) if labeller is None: labeller = BaseLabeller() # compute and add ess subplots - # step 1 - ess_kwargs = copy(plot_kwargs.get(kind, {})) + ess_kwargs = copy(plot_kwargs.get("ess", {})) if ess_kwargs is not False: - # step 2 ess_dims, _, ess_ignore = filter_aes(plot_collection, aes_map, kind, sample_dims) if kind == "local": probs = np.linspace(0, 1, n_points, endpoint=False) - xdata = probs ylabel = "{} for small intervals" - - # step 3 - ess_y_dataset = xr.concat( - [ - distribution.azstats.ess( - dims=ess_dims, - method="local", - relative=relative, - prob=[p, (p + 1 / n_points)], - **stats_kwargs.get("ess", {}), - ) - for p in probs - ], - dim="ess_dim", - ) - # print(f"\n ess_y_dataset = {ess_y_dataset}") - - # converting xdata into a xr dataarray - xdata_da = xr.DataArray(xdata, dims="ess_dim") - # print(f"\n xdata_da ={xdata_da}") - - # broadcasting xdata_da to match shape of each variable in ess_y_dataset and - # creating a new dataset from dict of broadcasted xdata - xdata_dataset = xr.Dataset( - {var_name: xdata_da.broadcast_like(da) for var_name, da in ess_y_dataset.items()} - ) - # print(f"\n xdata_dataset = {xdata_dataset}") - - # concatenating xdata_dataset and ess_y_dataset along plot_axis - ess_dataset = xr.concat([xdata_dataset, ess_y_dataset], dim="plot_axis").assign_coords( - plot_axis=["x", "y"] - ) - print(f"\n distribution = {distribution!r}") - print(f"\n ess_dataset = {ess_dataset!r}") - - # step 4 - # if "color" not in ess_aes: - # ess_kwargs.setdefault("color", "gray") - - # step 5 - plot_collection.map( - scatter_xy, "local", data=ess_dataset, ignore_aes=ess_ignore, **ess_kwargs - ) - - if kind == "quantile": + elif kind == "quantile": probs = np.linspace(1 / n_points, 1 - 1 / n_points, n_points) - xdata = probs ylabel = "{} for quantiles" + xdata = probs + + ess_y_dataset = xr.concat( + [ + distribution.azstats.ess( + dims=ess_dims, + method=kind, + relative=relative, + prob=[p, (p + 1 / n_points)] if kind == "local" else p, + **stats_kwargs.get("ess", {}), + ) + for p in probs + ], + dim="ess_dim", + ) - # step 3 - ess_y_dataset = xr.concat( - [ - distribution.azstats.ess( - dims=ess_dims, - method="quantile", - relative=relative, - prob=p, - **stats_kwargs.get("ess", {}), - ) - for p in probs - ], - dim="ess_dim", - ) - - # converting xdata into an xr datarray - xdata_da = xr.DataArray(xdata, dims="ess_dim") - - # broadcasting xdata_da to match shape of each variable in ess_y_dataset and - # creating a new dataset from dict of broadcasted xdata - xdata_dataset = xr.Dataset( - {var_name: xdata_da.broadcast_like(da) for var_name, da in ess_y_dataset.items()} - ) - - # concatenating xdata_dataset and ess_y_dataset along plot_axis - ess_dataset = xr.concat([xdata_dataset, ess_y_dataset], dim="plot_axis").assign_coords( - plot_axis=["x", "y"] - ) - print(f"\n ess_dataset = {ess_dataset!r}") - - # step 4 - # if "color" not in ess_aes: - # ess_kwargs.setdefault("color", "gray") + xdata_da = xr.DataArray(xdata, dims="ess_dim") + # broadcasting xdata_da to match shape of each variable in ess_y_dataset and + # creating a new dataset from dict of broadcasted xdata + xdata_dataset = xr.Dataset( + {var_name: xdata_da.broadcast_like(da) for var_name, da in ess_y_dataset.items()} + ) + # concatenating xdata_dataset and ess_y_dataset along plot_axis + ess_dataset = xr.concat([xdata_dataset, ess_y_dataset], dim="plot_axis").assign_coords( + plot_axis=["x", "y"] + ) - # step 5 - plot_collection.map( - scatter_xy, "quantile", data=ess_dataset, ignore_aes=ess_ignore, **ess_kwargs - ) + plot_collection.map( + scatter_xy, "ess", data=ess_dataset, ignore_aes=ess_ignore, **ess_kwargs + ) # plot rug - # overlaying divergences for each chain + # overlaying divergences(or other 'rug_kind') for each chain if rug: sample_stats = get_group(dt, "sample_stats", allow_missing=True) - divergence_kwargs = copy(plot_kwargs.get("divergence", {})) + rug_kwargs = copy(plot_kwargs.get("rug", {})) + if rug_kwargs is False: + raise ValueError("plot_kwargs['rug'] can't be False, use rug=False to remove the rug") if ( sample_stats is not None - and "diverging" in sample_stats.data_vars + and rug_kind in sample_stats.data_vars and np.any(sample_stats[rug_kind]) # 'diverging' by default - and divergence_kwargs is not False + and rug_kwargs is not False ): - divergence_mask = dt.sample_stats[rug_kind] # 'diverging' by default - print(f"\n divergence_mask = {divergence_mask!r}") - _, div_aes, div_ignore = filter_aes(plot_collection, aes_map, "divergence", sample_dims) + rug_mask = dt.sample_stats[rug_kind] # 'diverging' by default + _, div_aes, div_ignore = filter_aes(plot_collection, aes_map, "rug", sample_dims) if "color" not in div_aes: - divergence_kwargs.setdefault("color", "black") + rug_kwargs.setdefault("color", "black") if "marker" not in div_aes: - divergence_kwargs.setdefault("marker", "|") - # if "width" not in div_aes: # should this be hardcoded? - # divergence_kwargs.setdefault("width", linewidth) + rug_kwargs.setdefault("marker", "|") + # WIP: if using a default linewidth once defined in backend/agnostic defaults + # if "width" not in div_aes: + # # get default linewidth for backends + # plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") + # default_linewidth = plot_bknd.get_default_aes("linewidth", 1, {}) + # rug_kwargs.setdefault("width", default_linewidth) if "size" not in div_aes: - divergence_kwargs.setdefault("size", 30) + rug_kwargs.setdefault("size", 30) div_reduce_dims = [dim for dim in distribution.dims if dim not in aux_dim_list] # xname is used to pick subset of dataset in map() to be masked - xname = None # xname logic from traceplot + xname = None default_xname = sample_dims[0] if len(sample_dims) == 1 else "draw" if (default_xname not in distribution.dims) or ( not np.issubdtype(distribution[default_xname].dtype, np.number) ): default_xname = None - xname = divergence_kwargs.get("xname", default_xname) - divergence_kwargs["xname"] = xname - print(f"\n div_reduce_dims = {div_reduce_dims!r}") - print(f"\n xname = {xname}") - - draw_length = distribution.sizes["draw"] # used to scale xvalues to between 0-1 + xname = rug_kwargs.get("xname", default_xname) + rug_kwargs["xname"] = xname - # print(f"\n distribution = {distribution}") + draw_length = ( + distribution.sizes[sample_dims[0]] + if len(sample_dims) == 1 + else distribution.sizes["draw"] + ) # used to scale xvalues to between 0-1 plot_collection.map( trace_rug, - "divergence", + "rug", data=distribution, ignore_aes=div_ignore, # xname=xname, y=distribution.min(div_reduce_dims), - mask=divergence_mask, + mask=rug_mask, scale=draw_length, - **divergence_kwargs, - ) + **rug_kwargs, + ) # note: after plot_ppc merge, the `trace_rug` function might change - # WIP: plot mean and sd (to be done after plot_ridge PR merge for line_xy update) - a = """xyz. - if extra_methods is not False: - x_range = [0, 1] - x_range = xr.DataArray(x_range) + # defining x_range (used for mean, sd, minimum ess plotting) + x_range = [0, 1] + x_range = xr.DataArray(x_range) + # plot mean and sd + if extra_methods is not False: mean_kwargs = copy(plot_kwargs.get("mean", {})) if mean_kwargs is not False: - mean_dims, _, mean_ignore = filter_aes(plot_collection, aes_map, "mean", sample_dims) + mean_dims, mean_aes, mean_ignore = filter_aes( + plot_collection, aes_map, "mean", sample_dims + ) mean_ess = distribution.azstats.ess( dims=mean_dims, method="mean", relative=relative, **stats_kwargs.get("mean", {}) ) - print(f"\nmean_ess = {mean_ess!r}") + print(f"\n mean_ess = {mean_ess}") + + if "linestyle" not in mean_aes: + if backend == "matplotlib": + mean_kwargs.setdefault("linestyle", "--") + elif backend == "bokeh": + mean_kwargs.setdefault("linestyle", "dashed") plot_collection.map( - line_x, + line_xy, "mean", data=mean_ess, x=x_range, @@ -370,17 +319,58 @@ def plot_ess( sd_kwargs = copy(plot_kwargs.get("sd", {})) if sd_kwargs is not False: - sd_dims, _, sd_ignore = filter_aes(plot_collection, aes_map, "sd", sample_dims) + sd_dims, sd_aes, sd_ignore = filter_aes(plot_collection, aes_map, "sd", sample_dims) sd_ess = distribution.azstats.ess( dims=sd_dims, method="sd", relative=relative, **stats_kwargs.get("sd", {}) ) - print(f"\nsd_ess = {sd_ess!r}") + print(f"\n sd_ess = {sd_ess}") + + if "linestyle" not in sd_aes: + if backend == "matplotlib": + sd_kwargs.setdefault("linestyle", "--") + elif backend == "bokeh": + sd_kwargs.setdefault("linestyle", "dashed") plot_collection.map( - line_fixed_x, "sd", data=sd_ess, x=x_range, ignore_aes=sd_ignore, **sd_kwargs + line_xy, "sd", data=sd_ess, ignore_aes=sd_ignore, x=x_range, **sd_kwargs + ) + + # plot minimum ess + min_ess_kwargs = copy(plot_kwargs.get("min_ess", {})) + + if min_ess_kwargs is not False: + min_ess_dims, min_ess_aes, min_ess_ignore = filter_aes( + plot_collection, aes_map, "min_ess", sample_dims + ) + + if relative: + min_ess = min_ess / n_points + + # for each variable of distribution, put min_ess as the value, reducing all min_ess_dims + min_ess_data = {} + for var in distribution.data_vars: + reduced_data = distribution[var].mean( + dim=[dim for dim in distribution[var].dims if dim in min_ess_dims] ) - """ - print(a) + min_ess_data[var] = xr.full_like(reduced_data, min_ess) + + min_ess_dataset = xr.Dataset(min_ess_data) + print(f"\n min_ess = {min_ess_dataset}") + + if "linestyle" not in min_ess_aes: + if backend == "matplotlib": + min_ess_kwargs.setdefault("linestyle", "--") + elif backend == "bokeh": + min_ess_kwargs.setdefault("linestyle", "dashed") + + plot_collection.map( + line_xy, + "min_ess", + data=min_ess_dataset, + ignore_aes=min_ess_ignore, + x=x_range, + **min_ess_kwargs, + ) # plot titles for each facetted subplot title_kwargs = copy(plot_kwargs.get("title", {})) @@ -400,39 +390,43 @@ def plot_ess( # plot x and y axis labels # Add varnames as x and y labels - _, labels_aes, labels_ignore = filter_aes(plot_collection, aes_map, "label", sample_dims) - label_kwargs = plot_kwargs.get("label", {}).copy() + _, labels_aes, labels_ignore = filter_aes(plot_collection, aes_map, "xlabel", sample_dims) + xlabel_kwargs = plot_kwargs.get("xlabel", {}).copy() + if xlabel_kwargs is not False: + if "color" not in labels_aes: + xlabel_kwargs.setdefault("color", "black") - if "color" not in labels_aes: - label_kwargs.setdefault("color", "black") + # formatting ylabel and setting xlabel + xlabel_kwargs.setdefault("text", "Quantile") - # label_kwargs.setdefault("size", textsize) + plot_collection.map( + labelled_x, + "xlabel", + ignore_aes=labels_ignore, + subset_info=True, + store_artist=False, + **xlabel_kwargs, + ) - # formatting ylabel and setting xlabel - if relative is not False: - ylabel = ylabel.format("Relative ESS") - else: - ylabel = ylabel.format("ESS") - xlabel = "Quantile" - - plot_collection.map( - labelled_x, - "xlabel", - ignore_aes=labels_ignore, - subset_info=True, - text=xlabel, - store_artist=False, - **label_kwargs, - ) + _, labels_aes, labels_ignore = filter_aes(plot_collection, aes_map, "ylabel", sample_dims) + ylabel_kwargs = plot_kwargs.get("ylabel", {}).copy() + if ylabel_kwargs is not False: + if "color" not in labels_aes: + ylabel_kwargs.setdefault("color", "black") - plot_collection.map( - labelled_y, - "ylabel", - ignore_aes=labels_ignore, - subset_info=True, - text=ylabel, - store_artist=False, - **label_kwargs, - ) + if relative is not False: + ylabel_text = ylabel.format("Relative ESS") + else: + ylabel_text = ylabel.format("ESS") + ylabel_kwargs.setdefault("text", ylabel_text) + + plot_collection.map( + labelled_y, + "ylabel", + ignore_aes=labels_ignore, + subset_info=True, + store_artist=False, + **ylabel_kwargs, + ) return plot_collection From 9adcb3246cedfc3a7b0fcb374199ca611ccfa3e5 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Tue, 23 Jul 2024 15:47:00 +0530 Subject: [PATCH 06/24] fixed default value for arg 'extra_methods' --- src/arviz_plots/plots/essplot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index 8ee8b95..4c3f5e5 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -34,7 +34,7 @@ def plot_ess( rug=False, rug_kind="diverging", n_points=20, - extra_methods=True, # default False + extra_methods=False, min_ess=400, plot_collection=None, backend=None, From 1674ab6425f5c7290114a1a9113d2b261980d808 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Tue, 23 Jul 2024 17:15:59 +0530 Subject: [PATCH 07/24] modified scatter_xy visual element to take into account _process_da_x_y update --- src/arviz_plots/visuals/__init__.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/arviz_plots/visuals/__init__.py b/src/arviz_plots/visuals/__init__.py index 00fa36b..3db78a2 100644 --- a/src/arviz_plots/visuals/__init__.py +++ b/src/arviz_plots/visuals/__init__.py @@ -91,17 +91,16 @@ def scatter_x(da, target, backend, y=None, **kwargs): return plot_backend.scatter(da, y, target, **kwargs) -def scatter_xy(da, target, backend, x=None, **kwargs): +def scatter_xy(da, target, backend, x=None, y=None, **kwargs): """Plot a scatter plot x vs y. The input argument `da` is split into x and y using the dimension ``plot_axis``. + If additional x and y arguments are provided, x and y are added to the values + in the `da` dataset sliced along plot_axis='x' and plot_axis='y'. """ plot_backend = import_module(f"arviz_plots.backend.{backend}") - # note: temporary patch only before plot_ridge merge to use _process_da_x_y - # print(f"\n x arg = {x}") - x = da.sel(plot_axis="x") if x is None else da.sel(plot_axis="x") + x - # print(f"\n to plot x = {x}") - return plot_backend.scatter(x, da.sel(plot_axis="y"), target, **kwargs) + x, y = _process_da_x_y(da, x, y) + return plot_backend.scatter(x, y, target, **kwargs) def ecdf_line(values, target, backend, **kwargs): From 358cd5e52336a8bbd654f84620c3729f30d90282 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Sat, 3 Aug 2024 03:12:13 +0530 Subject: [PATCH 08/24] Added color/linestyles aesthetics and simplified min_ess plotting --- src/arviz_plots/plots/essplot.py | 69 ++++++++++++++------------------ 1 file changed, 31 insertions(+), 38 deletions(-) diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index 4c3f5e5..19d7f17 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -3,6 +3,7 @@ # imports # import warnings from copy import copy +from importlib import import_module import arviz_stats # pylint: disable=unused-import import numpy as np @@ -145,6 +146,11 @@ def plot_ess( dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords ) + # ensuring plot_kwargs['rug'] is not False + rug_kwargs = copy(plot_kwargs.get("rug", {})) + if rug_kwargs is False: + raise ValueError("plot_kwargs['rug'] can't be False, use rug=False to remove the rug") + # set plot collection initialization defaults if it doesnt exist if plot_collection is None: if backend is None: @@ -184,6 +190,10 @@ def plot_ess( aes_map = aes_map.copy() aes_map.setdefault(kind, plot_collection.aes_set.difference({"overlay"})) aes_map.setdefault("rug", {"overlay"}) + if "model" in distribution: + aes_map.setdefault("mean", {"color"}) + aes_map.setdefault("sd", {"color"}) + aes_map.setdefault("min_ess", {"color"}) if labeller is None: labeller = BaseLabeller() @@ -233,9 +243,6 @@ def plot_ess( # overlaying divergences(or other 'rug_kind') for each chain if rug: sample_stats = get_group(dt, "sample_stats", allow_missing=True) - rug_kwargs = copy(plot_kwargs.get("rug", {})) - if rug_kwargs is False: - raise ValueError("plot_kwargs['rug'] can't be False, use rug=False to remove the rug") if ( sample_stats is not None and rug_kind in sample_stats.data_vars @@ -248,12 +255,6 @@ def plot_ess( rug_kwargs.setdefault("color", "black") if "marker" not in div_aes: rug_kwargs.setdefault("marker", "|") - # WIP: if using a default linewidth once defined in backend/agnostic defaults - # if "width" not in div_aes: - # # get default linewidth for backends - # plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") - # default_linewidth = plot_bknd.get_default_aes("linewidth", 1, {}) - # rug_kwargs.setdefault("width", default_linewidth) if "size" not in div_aes: rug_kwargs.setdefault("size", 30) div_reduce_dims = [dim for dim in distribution.dims if dim not in aux_dim_list] @@ -290,6 +291,12 @@ def plot_ess( x_range = [0, 1] x_range = xr.DataArray(x_range) + # getting backend specific linestyles + plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") + linestyles = plot_bknd.get_default_aes("linestyle", 4, {}) + # and default color + default_color = plot_bknd.get_default_aes("color", 1, {})[0] + # plot mean and sd if extra_methods is not False: mean_kwargs = copy(plot_kwargs.get("mean", {})) @@ -300,13 +307,12 @@ def plot_ess( mean_ess = distribution.azstats.ess( dims=mean_dims, method="mean", relative=relative, **stats_kwargs.get("mean", {}) ) - print(f"\n mean_ess = {mean_ess}") - if "linestyle" not in mean_aes: - if backend == "matplotlib": - mean_kwargs.setdefault("linestyle", "--") - elif backend == "bokeh": - mean_kwargs.setdefault("linestyle", "dashed") + # getting 2nd default linestyle for chosen backend and assigning it by default + mean_kwargs.setdefault("linestyle", linestyles[1]) + + if "color" not in mean_aes: + mean_kwargs.setdefault("color", default_color) plot_collection.map( line_xy, @@ -323,13 +329,11 @@ def plot_ess( sd_ess = distribution.azstats.ess( dims=sd_dims, method="sd", relative=relative, **stats_kwargs.get("sd", {}) ) - print(f"\n sd_ess = {sd_ess}") - if "linestyle" not in sd_aes: - if backend == "matplotlib": - sd_kwargs.setdefault("linestyle", "--") - elif backend == "bokeh": - sd_kwargs.setdefault("linestyle", "dashed") + sd_kwargs.setdefault("linestyle", linestyles[2]) + + if "color" not in sd_aes: + sd_kwargs.setdefault("color", default_color) plot_collection.map( line_xy, "sd", data=sd_ess, ignore_aes=sd_ignore, x=x_range, **sd_kwargs @@ -339,36 +343,25 @@ def plot_ess( min_ess_kwargs = copy(plot_kwargs.get("min_ess", {})) if min_ess_kwargs is not False: - min_ess_dims, min_ess_aes, min_ess_ignore = filter_aes( + _, min_ess_aes, min_ess_ignore = filter_aes( plot_collection, aes_map, "min_ess", sample_dims ) if relative: min_ess = min_ess / n_points - # for each variable of distribution, put min_ess as the value, reducing all min_ess_dims - min_ess_data = {} - for var in distribution.data_vars: - reduced_data = distribution[var].mean( - dim=[dim for dim in distribution[var].dims if dim in min_ess_dims] - ) - min_ess_data[var] = xr.full_like(reduced_data, min_ess) - - min_ess_dataset = xr.Dataset(min_ess_data) - print(f"\n min_ess = {min_ess_dataset}") + min_ess_kwargs.setdefault("linestyle", linestyles[3]) - if "linestyle" not in min_ess_aes: - if backend == "matplotlib": - min_ess_kwargs.setdefault("linestyle", "--") - elif backend == "bokeh": - min_ess_kwargs.setdefault("linestyle", "dashed") + if "color" not in min_ess_aes: + min_ess_kwargs.setdefault("color", "gray") plot_collection.map( line_xy, "min_ess", - data=min_ess_dataset, + data=distribution, ignore_aes=min_ess_ignore, x=x_range, + y=min_ess, **min_ess_kwargs, ) From 7513d77c25991061836ba79dd8360cf35a55de7b Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Sun, 11 Aug 2024 19:41:15 +0530 Subject: [PATCH 09/24] Added annotate_xy visual element, applied to essplot for extra_methods --- src/arviz_plots/plots/essplot.py | 73 ++++++++++++++++++++++++++++- src/arviz_plots/visuals/__init__.py | 14 ++++++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index 19d7f17..0c451a3 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -14,6 +14,7 @@ from arviz_plots.plot_collection import PlotCollection from arviz_plots.plots.utils import filter_aes, get_group, process_group_variables_coords from arviz_plots.visuals import ( + annotate_xy, labelled_title, labelled_x, labelled_y, @@ -107,6 +108,8 @@ def plot_ess( * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x` * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y` * mean -> passed to :func:`~arviz.plots.visuals.line_xy` + * mean_text -> passed to :func:`~arviz.plots.visuals.annotate_xy` + * sd_text -> passed to :func:`~arviz.plots.visuals.annotate_xy` * sd -> passed to :func:`~arviz.plots.visuals.line_xy` * min_ess -> passed to :func:`~arviz.plots.visuals.line_xy` @@ -194,6 +197,10 @@ def plot_ess( aes_map.setdefault("mean", {"color"}) aes_map.setdefault("sd", {"color"}) aes_map.setdefault("min_ess", {"color"}) + if "mean" in aes_map and "mean_text" not in aes_map: + aes_map["mean_text"] = aes_map["mean"] + if "sd" in aes_map and "sd_text" not in aes_map: + aes_map["sd_text"] = aes_map["sd"] if labeller is None: labeller = BaseLabeller() @@ -297,7 +304,7 @@ def plot_ess( # and default color default_color = plot_bknd.get_default_aes("color", 1, {})[0] - # plot mean and sd + # plot mean and sd and annotate them if extra_methods is not False: mean_kwargs = copy(plot_kwargs.get("mean", {})) if mean_kwargs is not False: @@ -339,6 +346,70 @@ def plot_ess( line_xy, "sd", data=sd_ess, ignore_aes=sd_ignore, x=x_range, **sd_kwargs ) + mean_text_kwargs = copy(plot_kwargs.get("mean_text", {})) + if ( + mean_text_kwargs is not False and mean_ess + ): # mean_ess has to exist for an annotation to be applied + _, mean_text_aes, mean_text_ignore = filter_aes( + plot_collection, aes_map, "mean_text", sample_dims + ) + + if "color" not in mean_text_aes: + mean_text_kwargs.setdefault("color", "black") + + mean_text_kwargs.setdefault("x", 1) + mean_text_kwargs.setdefault("horizontal_align", "right") + mean_text_kwargs.setdefault( + "vertical_align", "bottom" + ) # by default set to bottom for mean + + # pass the sd_ess data to be facetted/subsetted too for vertical alignment setting + if sd_ess: + extra_da = sd_ess + else: + extra_da = None + + plot_collection.map( + annotate_xy, + "mean_text", + text="mean", + data=mean_ess, + extra_da=extra_da, + ignore_aes=mean_text_ignore, + **mean_text_kwargs, + ) + + sd_text_kwargs = copy(plot_kwargs.get("sd_text", {})) + if ( + sd_text_kwargs is not False and sd_ess + ): # sd_ess has to exist for an annotation to be applied + _, sd_text_aes, sd_text_ignore = filter_aes( + plot_collection, aes_map, "sd_text", sample_dims + ) + + if "color" not in sd_text_aes: + sd_text_kwargs.setdefault("color", "black") + + sd_text_kwargs.setdefault("x", 1) + sd_text_kwargs.setdefault("horizontal_align", "right") + sd_text_kwargs.setdefault("vertical_align", "top") # by default set to top for sd + + # pass the mean_ess data to be facetted/subsetted too for vertical alignment setting + if mean_ess: + extra_da = mean_ess + else: + extra_da = None + + plot_collection.map( + annotate_xy, + "sd_text", + text="sd", + data=sd_ess, + extra_da=extra_da, + ignore_aes=sd_text_ignore, + **sd_text_kwargs, + ) + # plot minimum ess min_ess_kwargs = copy(plot_kwargs.get("min_ess", {})) diff --git a/src/arviz_plots/visuals/__init__.py b/src/arviz_plots/visuals/__init__.py index 3db78a2..240cfe5 100644 --- a/src/arviz_plots/visuals/__init__.py +++ b/src/arviz_plots/visuals/__init__.py @@ -154,6 +154,20 @@ def _ensure_scalar(*args): return tuple(arg.item() if hasattr(arg, "item") else arg for arg in args) +def annotate_xy(da, target, backend, *, text, x=None, y=None, extra_da=None, **kwargs): + """Annotate a point (x, y) in a plot.""" + # kwargs["vertical_align"] will depend on extra_da if passed to this function + if extra_da is not None: + if da.values > extra_da.values: + kwargs["vertical_align"] = "bottom" + if da.values < extra_da.values: + kwargs["vertical_align"] = "top" + # if equal, default/pre-set vertical_aligns are used + x, y = _process_da_x_y(da, x, y) + plot_backend = import_module(f"arviz_plots.backend.{backend}") + return plot_backend.text(x, y, text, target, **kwargs) + + def point_estimate_text( da, target, backend, *, point_estimate, x=None, y=None, point_label="x", **kwargs ): From 4fd998dda07f56106cd6aede5c654cd7735337b8 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Mon, 19 Aug 2024 20:28:28 +0530 Subject: [PATCH 10/24] visual element vertical alignment logic modification and arviz-stats compute_ranks addition attempt --- pyproject.toml | 2 +- src/arviz_plots/plots/essplot.py | 77 ++++++++++++++++------------- src/arviz_plots/visuals/__init__.py | 17 +++---- 3 files changed, 51 insertions(+), 45 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9a3ec7a..e08b0ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ dynamic = ["version", "description"] dependencies = [ "arviz-base==0.2", - "arviz-stats[xarray]==0.2", + "arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats@rankdata", ] [tool.flit.module] diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index 0c451a3..2dcde04 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -133,6 +133,17 @@ def plot_ess( if isinstance(sample_dims, str): sample_dims = [sample_dims] + # from importlib.metadata import version, PackageNotFoundError + + # get the version of the arviz_stats module + # try: + # arviz_stats_version = version("arviz_stats") + # print(f"arviz_stats version: {arviz_stats_version}") + # except PackageNotFoundError: + # print("arviz_stats package is not installed") + + # print(f"arviz_stats version: {arviz_stats.__version__}") + # mutable inputs if plot_kwargs is None: plot_kwargs = {} @@ -266,31 +277,16 @@ def plot_ess( rug_kwargs.setdefault("size", 30) div_reduce_dims = [dim for dim in distribution.dims if dim not in aux_dim_list] - # xname is used to pick subset of dataset in map() to be masked - xname = None - default_xname = sample_dims[0] if len(sample_dims) == 1 else "draw" - if (default_xname not in distribution.dims) or ( - not np.issubdtype(distribution[default_xname].dtype, np.number) - ): - default_xname = None - xname = rug_kwargs.get("xname", default_xname) - rug_kwargs["xname"] = xname - - draw_length = ( - distribution.sizes[sample_dims[0]] - if len(sample_dims) == 1 - else distribution.sizes["draw"] - ) # used to scale xvalues to between 0-1 + values = distribution.azstats.compute_ranks(relative=False) + print(f"\n compute_ranks values = {values}") plot_collection.map( trace_rug, "rug", - data=distribution, + data=values, ignore_aes=div_ignore, - # xname=xname, y=distribution.min(div_reduce_dims), mask=rug_mask, - scale=draw_length, **rug_kwargs, ) # note: after plot_ppc merge, the `trace_rug` function might change @@ -306,6 +302,9 @@ def plot_ess( # plot mean and sd and annotate them if extra_methods is not False: + mean_ess = None + sd_ess = None + mean_kwargs = copy(plot_kwargs.get("mean", {})) if mean_kwargs is not False: mean_dims, mean_aes, mean_ignore = filter_aes( @@ -346,9 +345,15 @@ def plot_ess( line_xy, "sd", data=sd_ess, ignore_aes=sd_ignore, x=x_range, **sd_kwargs ) + sd_va_align = None + mean_va_align = None + if mean_ess is not None and sd_ess is not None: + sd_va_align = xr.where(mean_ess < sd_ess, "bottom", "top") + mean_va_align = xr.where(mean_ess < sd_ess, "top", "bottom") + mean_text_kwargs = copy(plot_kwargs.get("mean_text", {})) if ( - mean_text_kwargs is not False and mean_ess + mean_text_kwargs is not False and mean_ess is not None ): # mean_ess has to exist for an annotation to be applied _, mean_text_aes, mean_text_ignore = filter_aes( plot_collection, aes_map, "mean_text", sample_dims @@ -359,29 +364,29 @@ def plot_ess( mean_text_kwargs.setdefault("x", 1) mean_text_kwargs.setdefault("horizontal_align", "right") - mean_text_kwargs.setdefault( - "vertical_align", "bottom" - ) # by default set to bottom for mean + # mean_text_kwargs.setdefault( + # "vertical_align", "bottom" + # ) # by default set to bottom for mean - # pass the sd_ess data to be facetted/subsetted too for vertical alignment setting - if sd_ess: - extra_da = sd_ess + # pass the mean vertical_align data for vertical alignment setting + if mean_va_align is not None: + vertical_align = mean_va_align else: - extra_da = None + vertical_align = "bottom" plot_collection.map( annotate_xy, "mean_text", text="mean", data=mean_ess, - extra_da=extra_da, + vertical_align=vertical_align, ignore_aes=mean_text_ignore, **mean_text_kwargs, ) sd_text_kwargs = copy(plot_kwargs.get("sd_text", {})) if ( - sd_text_kwargs is not False and sd_ess + sd_text_kwargs is not False and sd_ess is not None ): # sd_ess has to exist for an annotation to be applied _, sd_text_aes, sd_text_ignore = filter_aes( plot_collection, aes_map, "sd_text", sample_dims @@ -392,20 +397,20 @@ def plot_ess( sd_text_kwargs.setdefault("x", 1) sd_text_kwargs.setdefault("horizontal_align", "right") - sd_text_kwargs.setdefault("vertical_align", "top") # by default set to top for sd + # sd_text_kwargs.setdefault("vertical_align", "top") # by default set to top for sd - # pass the mean_ess data to be facetted/subsetted too for vertical alignment setting - if mean_ess: - extra_da = mean_ess + # pass the sd vertical_align data for vertical alignment setting + if sd_va_align is not None: + vertical_align = sd_va_align else: - extra_da = None + vertical_align = "top" plot_collection.map( annotate_xy, "sd_text", text="sd", data=sd_ess, - extra_da=extra_da, + vertical_align=vertical_align, ignore_aes=sd_text_ignore, **sd_text_kwargs, ) @@ -493,4 +498,8 @@ def plot_ess( **ylabel_kwargs, ) + # print(f"\n plot_collection.viz = {plot_collection.viz}") + + # print(f"\n plot_collection.aes = {plot_collection.aes}") + return plot_collection diff --git a/src/arviz_plots/visuals/__init__.py b/src/arviz_plots/visuals/__init__.py index 240cfe5..c7c86d8 100644 --- a/src/arviz_plots/visuals/__init__.py +++ b/src/arviz_plots/visuals/__init__.py @@ -60,8 +60,9 @@ def line(da, target, backend, xname=None, **kwargs): return plot_backend.line(xvalues, yvalues, target, **kwargs) -def trace_rug(da, target, backend, mask, xname=None, y=None, scale=1, **kwargs): +def trace_rug(da, target, backend, mask, xname=None, y=None, **kwargs): """Create a rug plot with the subset of `da` indicated by `mask`.""" + # print(f'\n da = {da}') xname = xname.item() if hasattr(xname, "item") else xname if xname is False: xvalues = da @@ -76,8 +77,8 @@ def trace_rug(da, target, backend, mask, xname=None, y=None, scale=1, **kwargs): y = da.min().item() if len(xvalues.shape) != 1: raise ValueError(f"Expected unidimensional data but got {xvalues.sizes}") - xvalues = xvalues / scale # print(f"\n trace_rug call. xvalues = {xvalues}\nmask = {mask}") + print(f"\n trace_rug call. xvalues[mask] = {xvalues[mask]}") return scatter_x(xvalues[mask], target=target, backend=backend, y=y, **kwargs) @@ -154,15 +155,11 @@ def _ensure_scalar(*args): return tuple(arg.item() if hasattr(arg, "item") else arg for arg in args) -def annotate_xy(da, target, backend, *, text, x=None, y=None, extra_da=None, **kwargs): +def annotate_xy(da, target, backend, *, text, x=None, y=None, vertical_align=None, **kwargs): """Annotate a point (x, y) in a plot.""" - # kwargs["vertical_align"] will depend on extra_da if passed to this function - if extra_da is not None: - if da.values > extra_da.values: - kwargs["vertical_align"] = "bottom" - if da.values < extra_da.values: - kwargs["vertical_align"] = "top" - # if equal, default/pre-set vertical_aligns are used + if vertical_align is not None: + # print(f"\n vertical_align.item() = {vertical_align.item()}") + kwargs["vertical_align"] = vertical_align.item() x, y = _process_da_x_y(da, x, y) plot_backend = import_module(f"arviz_plots.backend.{backend}") return plot_backend.text(x, y, text, target, **kwargs) From 6a8d8eec7491e13dc3112e2a8f172fed657ab3d8 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Wed, 21 Aug 2024 03:05:02 +0530 Subject: [PATCH 11/24] added docs for essplot --- docs/source/api/plots.rst | 1 + docs/source/api/visuals.rst | 1 + .../inference_diagnostics/plot_ess_local.py | 25 ++++++ .../inference_diagnostics/plot_ess_models.py | 25 ++++++ .../plot_ess_quantile.py | 25 ++++++ src/arviz_plots/plots/essplot.py | 79 +++++++++++++++++-- 6 files changed, 151 insertions(+), 5 deletions(-) create mode 100644 docs/source/gallery/inference_diagnostics/plot_ess_local.py create mode 100644 docs/source/gallery/inference_diagnostics/plot_ess_models.py create mode 100644 docs/source/gallery/inference_diagnostics/plot_ess_quantile.py diff --git a/docs/source/api/plots.rst b/docs/source/api/plots.rst index c6f3cd2..da290c4 100644 --- a/docs/source/api/plots.rst +++ b/docs/source/api/plots.rst @@ -23,3 +23,4 @@ A complementary introduction and guide to ``plot_...`` functions is available at plot_ridge plot_trace plot_trace_dist + plot_ess diff --git a/docs/source/api/visuals.rst b/docs/source/api/visuals.rst index 3c76f3f..e567692 100644 --- a/docs/source/api/visuals.rst +++ b/docs/source/api/visuals.rst @@ -24,6 +24,7 @@ Data and axis annotating elements :toctree: generated/ annotate_label + annotate_xy point_estimate_text labelled_title labelled_y diff --git a/docs/source/gallery/inference_diagnostics/plot_ess_local.py b/docs/source/gallery/inference_diagnostics/plot_ess_local.py new file mode 100644 index 0000000..4d5e6af --- /dev/null +++ b/docs/source/gallery/inference_diagnostics/plot_ess_local.py @@ -0,0 +1,25 @@ +""" +# ESS Local plot + +Facetted local ESS plot + +--- + +:::{seealso} +API Documentation: {func}`~arviz_plots.plot_ess` +::: +""" + +from arviz_base import load_arviz_data + +import arviz_plots as azp + +azp.style.use("arviz-clean") + +data = load_arviz_data("centered_eight") +pc = azp.plot_ess( + data, + kind="local", + backend="none", # change to preferred backend +) +pc.show() diff --git a/docs/source/gallery/inference_diagnostics/plot_ess_models.py b/docs/source/gallery/inference_diagnostics/plot_ess_models.py new file mode 100644 index 0000000..eeb17b2 --- /dev/null +++ b/docs/source/gallery/inference_diagnostics/plot_ess_models.py @@ -0,0 +1,25 @@ +""" +# ESS comparison plot + +Full ESS (Either local or quantile) comparison between different models + +--- + +:::{seealso} +API Documentation: {func}`~arviz_plots.plot_ess` +::: +""" + +from arviz_base import load_arviz_data + +import arviz_plots as azp + +azp.style.use("arviz-clean") + +c = load_arviz_data("centered_eight") +n = load_arviz_data("non_centered_eight") +pc = azp.plot_ess( + {"Centered": c, "Non Centered": n}, + backend="none", # change to preferred backend +) +pc.show() diff --git a/docs/source/gallery/inference_diagnostics/plot_ess_quantile.py b/docs/source/gallery/inference_diagnostics/plot_ess_quantile.py new file mode 100644 index 0000000..e2e833c --- /dev/null +++ b/docs/source/gallery/inference_diagnostics/plot_ess_quantile.py @@ -0,0 +1,25 @@ +""" +# ESS Quantile plot + +Facetted quantile ESS plot + +--- + +:::{seealso} +API Documentation: {func}`~arviz_plots.plot_ess` +::: +""" + +from arviz_base import load_arviz_data + +import arviz_plots as azp + +azp.style.use("arviz-clean") + +data = load_arviz_data("centered_eight") +pc = azp.plot_ess( + data, + kind="quantile", + backend="none", # change to preferred backend +) +pc.show() diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index 2dcde04..83d4dce 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -66,7 +66,7 @@ def plot_ess( sample_dims : str or sequence of hashable, optional Dimensions to reduce unless mapped to an aesthetic. Defaults to ``rcParams["data.sample_dims"]`` - kind : {"local", "quantile", "evolution"}, default "local" + kind : {"local", "quantile"}, default "local" Specify the kind of plot: * The ``kind="local"`` argument generates the ESS' local efficiency @@ -99,10 +99,6 @@ def plot_ess( Valid keys are: * ess -> passed to :func:`~arviz_plots.visuals.scatter_xy` - if `kind`='local', - else passed to :func:`~arviz_plots.visuals.scatter_xy` - if `kind` = 'quantile' - * rug -> passed to :func:`~.visuals.trace_rug` * title -> passed to :func:`~arviz_plots.visuals.labelled_title` * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x` @@ -126,6 +122,77 @@ def plot_ess( Returns ------- PlotCollection + + Examples + -------- + The following examples focus on behaviour specific to ``plot_ess``. + For a general introduction to batteries-included functions like this one and common + usage examples see :ref:`plots_intro` + + Default plot_ess for a single model: + + .. plot:: + :context: close-figs + + >>> from arviz_plots import plot_ess, style + >>> style.use("arviz-clean") + >>> from arviz_base import load_arviz_data + >>> centered = load_arviz_data('centered_eight') + >>> non_centered = load_arviz_data('non_centered_eight') + >>> pc = plot_ess(centered) + + Default plot_ess for multiple models: (Depending on the number of models, a slight + x-axis separation aesthetic is applied for each ess point for distinguishability in + case of overlap) + + .. plot:: + :context: close-figs + + >>> pc = plot_ess( + >>> {"centered": centered, "non centered": non_centered}, + >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, + >>> ) + >>> pc.add_legend("model") + + We can also manually map the color to the variable, and have the mapping apply + to the title too instead of only the ess markers: + + .. plot:: + :context: close-figs + + >>> pc = plot_ess( + >>> non_centered, + >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, + >>> pc_kwargs={"aes": {"color": ["__variable__"]}}, + >>> aes_map={"title": ["color"]}, + >>> ) + + If we add a mapping (like color) manually to the variable, but not specify which artist + to apply the mapping to- then it is applied to the 'ess' marker artist by default: + + .. plot:: + :context: close-figs + + >>> pc = plot_ess( + >>> centered, + >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, + >>> pc_kwargs={"aes": {"color": ["__variable__"]}}, + >>> ) + + The artists' visual features can also be customized through plot_kwargs, based on the + kwargs that the visual element function for the artist accepts- like all the other + batteries included plots. For example, for the 'ess' artist, the scatter_xy function is + used. So if we want to change the marker: + + .. plot:: + :context: close-figs + + >>> pc = plot_ess( + >>> centered, + >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, + >>> plot_kwargs={"ess": {"marker": "_"}}, + >>> ) + """ # initial defaults if sample_dims is None: @@ -133,6 +200,8 @@ def plot_ess( if isinstance(sample_dims, str): sample_dims = [sample_dims] + ylabel = "{}" + # from importlib.metadata import version, PackageNotFoundError # get the version of the arviz_stats module From 20ff025baf68c1874f330e87c96b0bf0fd677d52 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Wed, 21 Aug 2024 15:35:40 +0530 Subject: [PATCH 12/24] tests for essplot --- tests/test_hypothesis_plots.py | 62 +++++++++++++++++++++++++++++++++- tests/test_plots.py | 49 +++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 1 deletion(-) diff --git a/tests/test_hypothesis_plots.py b/tests/test_hypothesis_plots.py index 77bcb44..7d8c534 100644 --- a/tests/test_hypothesis_plots.py +++ b/tests/test_hypothesis_plots.py @@ -8,7 +8,7 @@ from datatree import DataTree from hypothesis import given -from arviz_plots import plot_dist, plot_forest, plot_ridge +from arviz_plots import plot_dist, plot_ess, plot_forest, plot_ridge pytestmark = pytest.mark.usefixtures("no_artist_kwargs") @@ -192,3 +192,63 @@ def test_plot_ridge(datatree, combined, plot_kwargs, labels_shade_label): assert all(key in child for child in pc.viz.children.values()) elif key not in ("remove_axis", "ticklabels"): assert all(key in child for child in pc.viz.children.values()) + + +ess_kind_value = st.sampled_from(("local", "quantile")) +ess_relative = st.booleans() +ess_rug = st.booleans() +ess_extra_methods = st.booleans() + + +@st.composite +def ess_n_points(draw): + return draw(st.integers(min_value=1, max_value=50)) # should this range be changed? + + +@st.composite +def ess_min_ess(draw): + return draw(st.integers(min_value=10, max_value=150)) # max samples = 3 x 50 = 150 + + +@given( + plot_kwargs=st.fixed_dictionaries( + {}, + optional={ + "ess": plot_kwargs_value, + "rug": st.sampled_from(({}, {"color": "red"})), + "xlabel": st.sampled_from(({}, {"color": "red"})), + "ylabel": st.sampled_from(({}, {"color": "red"})), + "mean": plot_kwargs_value, + "mean_text": plot_kwargs_value, + "sd": plot_kwargs_value, + "sd_text": plot_kwargs_value, + "min_ess": plot_kwargs_value, + "title": plot_kwargs_value, + "remove_axis": st.just(False), + }, + ), + kind=ess_kind_value, + relative=ess_relative, + rug=ess_rug, + n_points=ess_n_points(), + extra_methods=ess_extra_methods, + min_ess=ess_min_ess(), +) +def test_plot_ess(datatree, kind, relative, rug, n_points, extra_methods, min_ess, plot_kwargs): + pc = plot_ess( + datatree, + backend="none", + kind=kind, + relative=relative, + rug=rug, + n_points=n_points, + extra_methods=extra_methods, + min_ess=min_ess, + plot_kwargs=plot_kwargs, + ) + assert all("plot" in child for child in pc.viz.children.values()) + for key, value in plot_kwargs.items(): + if value is False: + assert all(key not in child for child in pc.viz.children.values()) + elif key != "remove_axis": + assert all(key in child for child in pc.viz.children.values()) diff --git a/tests/test_plots.py b/tests/test_plots.py index 64532af..cd95fba 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -8,6 +8,7 @@ from arviz_plots import ( plot_compare, plot_dist, + plot_ess, plot_forest, plot_ridge, plot_trace, @@ -291,3 +292,51 @@ def test_plot_compare_kwargs(self, cmp, backend): pc_kwargs={"plot_grid_kws": {"figsize": (1000, 200)}}, backend=backend, ) + + def test_plot_ess(self, datatree, backend): + pc = plot_ess(datatree, backend=backend, rug=True) + assert "chart" in pc.viz.data_vars + assert "plot" not in pc.viz.data_vars + assert "ess" in pc.viz["mu"] + assert "min_ess" in pc.viz["mu"] + assert "title" in pc.viz["mu"] + assert "rug" in pc.viz["mu"] + assert "hierarchy" not in pc.viz["mu"].dims + assert "hierarchy" in pc.viz["theta"].dims + assert pc.viz["mu"].rug.shape == (4,) # checking rug artist shape (4 chains overlaid) + # checking aesthetics + assert "overlay" in pc.aes["mu"].data_vars # overlay of chains + + a = """ + def test_plot_ess_sample(self, datatree_sample, backend): + pc = plot_ess(datatree_sample, backend=backend, rug=True, sample_dims="sample") + assert "chart" in pc.viz.data_vars + assert "plot" not in pc.viz.data_vars + assert "ess" in pc.viz["mu"] + assert "min_ess" in pc.viz["mu"] + assert "title" in pc.viz["mu"] + assert "rug" in pc.viz["mu"] + assert "hierarchy" not in pc.viz["mu"].dims + assert "hierarchy" in pc.viz["theta"].dims + assert pc.viz["mu"].trace.shape == () # 0 chains here, so no overlay + """ # error when running this: ValueError: dims must be of length 2 (from arviz-stats ) + if a: + print("a") # just to temporarily suspend linting error + + def test_plot_ess_models(self, datatree, datatree2, backend): + pc = plot_ess({"c": datatree, "n": datatree2}, backend=backend, rug=False) + assert "chart" in pc.viz.data_vars + assert "plot" not in pc.viz.data_vars + assert "ess" in pc.viz["mu"] + assert "min_ess" in pc.viz["mu"] + assert "title" in pc.viz["mu"] + # assert "rug" in pc.viz["mu"] + assert "hierarchy" not in pc.viz["mu"].dims + assert "hierarchy" in pc.viz["theta"].dims + assert "model" in pc.viz["mu"].dims + # assert pc.viz["mu"].rug.shape == (2, 4) # since there are 2 Ms, 4 Cs in datatree + # checking aesthetics + assert "model" in pc.aes["mu"].dims + assert "x" in pc.aes["mu"].data_vars + assert "color" in pc.aes["mu"].data_vars + assert "overlay" in pc.aes["mu"].data_vars # overlay of chains From 6c3c9069cf505b0d6f4d86436782b210dd127c98 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Wed, 21 Aug 2024 16:15:45 +0530 Subject: [PATCH 13/24] added scatter_xy to visuals.rst --- docs/source/api/visuals.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/api/visuals.rst b/docs/source/api/visuals.rst index e567692..0881105 100644 --- a/docs/source/api/visuals.rst +++ b/docs/source/api/visuals.rst @@ -13,6 +13,7 @@ Data plotting elements line line_xy line_x + scatter_xy scatter_x ecdf_line hist From ac8b62a2ca91b21754b957dc14758c2b7d2977f6 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Wed, 21 Aug 2024 19:53:58 +0530 Subject: [PATCH 14/24] added rug=True to example gallery plot_ess_local --- docs/source/gallery/inference_diagnostics/plot_ess_local.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/gallery/inference_diagnostics/plot_ess_local.py b/docs/source/gallery/inference_diagnostics/plot_ess_local.py index 4d5e6af..8583e2f 100644 --- a/docs/source/gallery/inference_diagnostics/plot_ess_local.py +++ b/docs/source/gallery/inference_diagnostics/plot_ess_local.py @@ -21,5 +21,6 @@ data, kind="local", backend="none", # change to preferred backend + rug=True, ) pc.show() From da85925bc48a0fb316427e0bb675ae6b0c03f25c Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Wed, 21 Aug 2024 20:45:35 +0530 Subject: [PATCH 15/24] fixes for rugplot issue and hypothesis test failures --- src/arviz_plots/plots/essplot.py | 3 ++- src/arviz_plots/visuals/__init__.py | 5 ++++- tests/test_hypothesis_plots.py | 5 +++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index 83d4dce..01ad37c 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -346,7 +346,7 @@ def plot_ess( rug_kwargs.setdefault("size", 30) div_reduce_dims = [dim for dim in distribution.dims if dim not in aux_dim_list] - values = distribution.azstats.compute_ranks(relative=False) + values = distribution.azstats.compute_ranks(relative=True) print(f"\n compute_ranks values = {values}") plot_collection.map( @@ -356,6 +356,7 @@ def plot_ess( ignore_aes=div_ignore, y=distribution.min(div_reduce_dims), mask=rug_mask, + xname=False, **rug_kwargs, ) # note: after plot_ppc merge, the `trace_rug` function might change diff --git a/src/arviz_plots/visuals/__init__.py b/src/arviz_plots/visuals/__init__.py index c7c86d8..e241827 100644 --- a/src/arviz_plots/visuals/__init__.py +++ b/src/arviz_plots/visuals/__init__.py @@ -159,7 +159,10 @@ def annotate_xy(da, target, backend, *, text, x=None, y=None, vertical_align=Non """Annotate a point (x, y) in a plot.""" if vertical_align is not None: # print(f"\n vertical_align.item() = {vertical_align.item()}") - kwargs["vertical_align"] = vertical_align.item() + if hasattr(vertical_align, "item"): + kwargs["vertical_align"] = vertical_align.item() + else: + kwargs["vertical_align"] = vertical_align # if a string and not a dataarray x, y = _process_da_x_y(da, x, y) plot_backend = import_module(f"arviz_plots.backend.{backend}") return plot_backend.text(x, y, text, target, **kwargs) diff --git a/tests/test_hypothesis_plots.py b/tests/test_hypothesis_plots.py index 7d8c534..e66a690 100644 --- a/tests/test_hypothesis_plots.py +++ b/tests/test_hypothesis_plots.py @@ -250,5 +250,10 @@ def test_plot_ess(datatree, kind, relative, rug, n_points, extra_methods, min_es for key, value in plot_kwargs.items(): if value is False: assert all(key not in child for child in pc.viz.children.values()) + elif key == "rug": + if rug is False: + assert all(key not in child for child in pc.viz.children.values()) + else: + assert all(key in child for child in pc.viz.children.values()) elif key != "remove_axis": assert all(key in child for child in pc.viz.children.values()) From 24fe1948351d5c6e33f82d2dfa37f04dcf534a2d Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Mon, 26 Aug 2024 17:27:04 +0530 Subject: [PATCH 16/24] shifted mean_ess, sd_ess computing to before plot_kwargs check+artist plotting logic and modified hypothesis tests --- src/arviz_plots/plots/essplot.py | 47 +++++++---------------------- src/arviz_plots/visuals/__init__.py | 2 -- tests/test_hypothesis_plots.py | 5 +++ 3 files changed, 16 insertions(+), 38 deletions(-) diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index 01ad37c..9e61dd4 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -202,17 +202,6 @@ def plot_ess( ylabel = "{}" - # from importlib.metadata import version, PackageNotFoundError - - # get the version of the arviz_stats module - # try: - # arviz_stats_version = version("arviz_stats") - # print(f"arviz_stats version: {arviz_stats_version}") - # except PackageNotFoundError: - # print("arviz_stats package is not installed") - - # print(f"arviz_stats version: {arviz_stats.__version__}") - # mutable inputs if plot_kwargs is None: plot_kwargs = {} @@ -347,7 +336,6 @@ def plot_ess( div_reduce_dims = [dim for dim in distribution.dims if dim not in aux_dim_list] values = distribution.azstats.compute_ranks(relative=True) - print(f"\n compute_ranks values = {values}") plot_collection.map( trace_rug, @@ -372,18 +360,20 @@ def plot_ess( # plot mean and sd and annotate them if extra_methods is not False: - mean_ess = None - sd_ess = None + # computing mean_ess + mean_dims, mean_aes, mean_ignore = filter_aes(plot_collection, aes_map, "mean", sample_dims) + mean_ess = distribution.azstats.ess( + dims=mean_dims, method="mean", relative=relative, **stats_kwargs.get("mean", {}) + ) + + # computing sd_ess + sd_dims, sd_aes, sd_ignore = filter_aes(plot_collection, aes_map, "sd", sample_dims) + sd_ess = distribution.azstats.ess( + dims=sd_dims, method="sd", relative=relative, **stats_kwargs.get("sd", {}) + ) mean_kwargs = copy(plot_kwargs.get("mean", {})) if mean_kwargs is not False: - mean_dims, mean_aes, mean_ignore = filter_aes( - plot_collection, aes_map, "mean", sample_dims - ) - mean_ess = distribution.azstats.ess( - dims=mean_dims, method="mean", relative=relative, **stats_kwargs.get("mean", {}) - ) - # getting 2nd default linestyle for chosen backend and assigning it by default mean_kwargs.setdefault("linestyle", linestyles[1]) @@ -401,11 +391,6 @@ def plot_ess( sd_kwargs = copy(plot_kwargs.get("sd", {})) if sd_kwargs is not False: - sd_dims, sd_aes, sd_ignore = filter_aes(plot_collection, aes_map, "sd", sample_dims) - sd_ess = distribution.azstats.ess( - dims=sd_dims, method="sd", relative=relative, **stats_kwargs.get("sd", {}) - ) - sd_kwargs.setdefault("linestyle", linestyles[2]) if "color" not in sd_aes: @@ -434,9 +419,6 @@ def plot_ess( mean_text_kwargs.setdefault("x", 1) mean_text_kwargs.setdefault("horizontal_align", "right") - # mean_text_kwargs.setdefault( - # "vertical_align", "bottom" - # ) # by default set to bottom for mean # pass the mean vertical_align data for vertical alignment setting if mean_va_align is not None: @@ -467,7 +449,6 @@ def plot_ess( sd_text_kwargs.setdefault("x", 1) sd_text_kwargs.setdefault("horizontal_align", "right") - # sd_text_kwargs.setdefault("vertical_align", "top") # by default set to top for sd # pass the sd vertical_align data for vertical alignment setting if sd_va_align is not None: @@ -543,7 +524,6 @@ def plot_ess( "xlabel", ignore_aes=labels_ignore, subset_info=True, - store_artist=False, **xlabel_kwargs, ) @@ -564,12 +544,7 @@ def plot_ess( "ylabel", ignore_aes=labels_ignore, subset_info=True, - store_artist=False, **ylabel_kwargs, ) - # print(f"\n plot_collection.viz = {plot_collection.viz}") - - # print(f"\n plot_collection.aes = {plot_collection.aes}") - return plot_collection diff --git a/src/arviz_plots/visuals/__init__.py b/src/arviz_plots/visuals/__init__.py index e241827..dd6459d 100644 --- a/src/arviz_plots/visuals/__init__.py +++ b/src/arviz_plots/visuals/__init__.py @@ -77,8 +77,6 @@ def trace_rug(da, target, backend, mask, xname=None, y=None, **kwargs): y = da.min().item() if len(xvalues.shape) != 1: raise ValueError(f"Expected unidimensional data but got {xvalues.sizes}") - # print(f"\n trace_rug call. xvalues = {xvalues}\nmask = {mask}") - print(f"\n trace_rug call. xvalues[mask] = {xvalues[mask]}") return scatter_x(xvalues[mask], target=target, backend=backend, y=y, **kwargs) diff --git a/tests/test_hypothesis_plots.py b/tests/test_hypothesis_plots.py index e66a690..44c839f 100644 --- a/tests/test_hypothesis_plots.py +++ b/tests/test_hypothesis_plots.py @@ -250,6 +250,11 @@ def test_plot_ess(datatree, kind, relative, rug, n_points, extra_methods, min_es for key, value in plot_kwargs.items(): if value is False: assert all(key not in child for child in pc.viz.children.values()) + elif key in ["mean", "sd", "mean_text", "sd_text"]: + if extra_methods is False: + assert all(key not in child for child in pc.viz.children.values()) + else: + assert all(key in child for child in pc.viz.children.values()) elif key == "rug": if rug is False: assert all(key not in child for child in pc.viz.children.values()) From e66beb289df1c4bc1d6c0af244567910b29f10ff Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Fri, 30 Aug 2024 16:51:23 +0530 Subject: [PATCH 17/24] Updated plot_ess and tests --- src/arviz_plots/plots/essplot.py | 47 +++++++++++++++++++++++++---- src/arviz_plots/visuals/__init__.py | 2 -- tests/test_hypothesis_plots.py | 27 ++++++----------- tests/test_plots.py | 5 ++- 4 files changed, 52 insertions(+), 29 deletions(-) diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index 9e61dd4..0ffa28d 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -11,7 +11,7 @@ from arviz_base import rcParams from arviz_base.labels import BaseLabeller -from arviz_plots.plot_collection import PlotCollection +from arviz_plots.plot_collection import PlotCollection, leaf_dataset, process_facet_dims from arviz_plots.plots.utils import filter_aes, get_group, process_group_variables_coords from arviz_plots.visuals import ( annotate_xy, @@ -223,16 +223,52 @@ def plot_ess( if rug_kwargs is False: raise ValueError("plot_kwargs['rug'] can't be False, use rug=False to remove the rug") + if backend is None: + backend = rcParams["plot.backend"] + plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") + # set plot collection initialization defaults if it doesnt exist + + # figsizing related plot collection initialization if plot_collection is None: - if backend is None: - backend = rcParams["plot.backend"] + figsize = pc_kwargs.get("plot_grid_kws", {}).get("figsize", None) + figsize_units = pc_kwargs.get("plot_grid_kws", {}).get("figsize_units", "inches") pc_kwargs.setdefault("col_wrap", 5) pc_kwargs.setdefault( "cols", ["__variable__"] + [dim for dim in distribution.dims if dim not in {"model"}.union(sample_dims)], ) + n_plots, _ = process_facet_dims(distribution, pc_kwargs["cols"]) + col_wrap = pc_kwargs["col_wrap"] + if n_plots <= col_wrap: + n_rows, n_cols = 1, n_plots + else: + div_mod = divmod(n_plots, col_wrap) + n_rows = div_mod[0] + (div_mod[1] != 0) + n_cols = col_wrap + else: + figsize, figsize_units = plot_bknd.get_figsize(plot_collection) + n_rows = leaf_dataset(plot_collection.viz, "row").max().to_array().max().item() + n_cols = leaf_dataset(plot_collection.viz, "col").max().to_array().max().item() + + plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") + + figsize = plot_bknd.scale_fig_size( + figsize, + rows=n_rows, + cols=n_cols, + figsize_units=figsize_units, + ) + + # other plot collection related initialization + if plot_collection is None: + # making copy of pc_kwargs["plot_grid_kws"] to pass to .wrap() + pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy() + if "figsize" not in pc_kwargs["plot_grid_kws"]: + pc_kwargs["plot_grid_kws"]["figsize"] = figsize + pc_kwargs["plot_grid_kws"]["figsize_units"] = "dots" + pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() if "chain" in distribution: pc_kwargs["aes"].setdefault("overlay", ["chain"]) @@ -353,7 +389,6 @@ def plot_ess( x_range = xr.DataArray(x_range) # getting backend specific linestyles - plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") linestyles = plot_bknd.get_default_aes("linestyle", 4, {}) # and default color default_color = plot_bknd.get_default_aes("color", 1, {})[0] @@ -425,13 +460,13 @@ def plot_ess( vertical_align = mean_va_align else: vertical_align = "bottom" + mean_text_kwargs.setdefault("vertical_align", vertical_align) plot_collection.map( annotate_xy, "mean_text", text="mean", data=mean_ess, - vertical_align=vertical_align, ignore_aes=mean_text_ignore, **mean_text_kwargs, ) @@ -455,13 +490,13 @@ def plot_ess( vertical_align = sd_va_align else: vertical_align = "top" + sd_text_kwargs.setdefault("vertical_align", vertical_align) plot_collection.map( annotate_xy, "sd_text", text="sd", data=sd_ess, - vertical_align=vertical_align, ignore_aes=sd_text_ignore, **sd_text_kwargs, ) diff --git a/src/arviz_plots/visuals/__init__.py b/src/arviz_plots/visuals/__init__.py index dd6459d..97f7ce8 100644 --- a/src/arviz_plots/visuals/__init__.py +++ b/src/arviz_plots/visuals/__init__.py @@ -62,7 +62,6 @@ def line(da, target, backend, xname=None, **kwargs): def trace_rug(da, target, backend, mask, xname=None, y=None, **kwargs): """Create a rug plot with the subset of `da` indicated by `mask`.""" - # print(f'\n da = {da}') xname = xname.item() if hasattr(xname, "item") else xname if xname is False: xvalues = da @@ -156,7 +155,6 @@ def _ensure_scalar(*args): def annotate_xy(da, target, backend, *, text, x=None, y=None, vertical_align=None, **kwargs): """Annotate a point (x, y) in a plot.""" if vertical_align is not None: - # print(f"\n vertical_align.item() = {vertical_align.item()}") if hasattr(vertical_align, "item"): kwargs["vertical_align"] = vertical_align.item() else: diff --git a/tests/test_hypothesis_plots.py b/tests/test_hypothesis_plots.py index 44c839f..d1a518a 100644 --- a/tests/test_hypothesis_plots.py +++ b/tests/test_hypothesis_plots.py @@ -37,9 +37,11 @@ def datatree(seed=31): kind_value = st.sampled_from(("kde", "ecdf")) +ess_kind_value = st.sampled_from(("local", "quantile")) ci_kind_value = st.sampled_from(("eti", "hdi")) point_estimate_value = st.sampled_from(("mean", "median")) plot_kwargs_value = st.sampled_from(({}, False, {"color": "red"})) +plot_kwargs_value_no_false = st.sampled_from(({}, {"color": "red"})) @st.composite @@ -194,17 +196,6 @@ def test_plot_ridge(datatree, combined, plot_kwargs, labels_shade_label): assert all(key in child for child in pc.viz.children.values()) -ess_kind_value = st.sampled_from(("local", "quantile")) -ess_relative = st.booleans() -ess_rug = st.booleans() -ess_extra_methods = st.booleans() - - -@st.composite -def ess_n_points(draw): - return draw(st.integers(min_value=1, max_value=50)) # should this range be changed? - - @st.composite def ess_min_ess(draw): return draw(st.integers(min_value=10, max_value=150)) # max samples = 3 x 50 = 150 @@ -215,9 +206,9 @@ def ess_min_ess(draw): {}, optional={ "ess": plot_kwargs_value, - "rug": st.sampled_from(({}, {"color": "red"})), - "xlabel": st.sampled_from(({}, {"color": "red"})), - "ylabel": st.sampled_from(({}, {"color": "red"})), + "rug": plot_kwargs_value_no_false, + "xlabel": plot_kwargs_value_no_false, + "ylabel": plot_kwargs_value_no_false, "mean": plot_kwargs_value, "mean_text": plot_kwargs_value, "sd": plot_kwargs_value, @@ -228,10 +219,10 @@ def ess_min_ess(draw): }, ), kind=ess_kind_value, - relative=ess_relative, - rug=ess_rug, - n_points=ess_n_points(), - extra_methods=ess_extra_methods, + relative=st.booleans(), + rug=st.booleans(), + n_points=st.integers(min_value=1, max_value=5), + extra_methods=st.booleans(), min_ess=ess_min_ess(), ) def test_plot_ess(datatree, kind, relative, rug, n_points, extra_methods, min_ess, plot_kwargs): diff --git a/tests/test_plots.py b/tests/test_plots.py index cd95fba..656ecc9 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -303,7 +303,7 @@ def test_plot_ess(self, datatree, backend): assert "rug" in pc.viz["mu"] assert "hierarchy" not in pc.viz["mu"].dims assert "hierarchy" in pc.viz["theta"].dims - assert pc.viz["mu"].rug.shape == (4,) # checking rug artist shape (4 chains overlaid) + assert "chain" in pc.viz["mu"].rug.dims # checking rug artist overlay # checking aesthetics assert "overlay" in pc.aes["mu"].data_vars # overlay of chains @@ -330,11 +330,10 @@ def test_plot_ess_models(self, datatree, datatree2, backend): assert "ess" in pc.viz["mu"] assert "min_ess" in pc.viz["mu"] assert "title" in pc.viz["mu"] - # assert "rug" in pc.viz["mu"] + assert "rug" not in pc.viz["mu"] assert "hierarchy" not in pc.viz["mu"].dims assert "hierarchy" in pc.viz["theta"].dims assert "model" in pc.viz["mu"].dims - # assert pc.viz["mu"].rug.shape == (2, 4) # since there are 2 Ms, 4 Cs in datatree # checking aesthetics assert "model" in pc.aes["mu"].dims assert "x" in pc.aes["mu"].data_vars From 85d011787ddf65146e30f867c590274e6401b312 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Mon, 2 Sep 2024 13:15:16 +0530 Subject: [PATCH 18/24] updated .toml file for arviz-stats dependency --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e08b0ec..635fb26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ dynamic = ["version", "description"] dependencies = [ "arviz-base==0.2", - "arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats@rankdata", + "arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats", ] [tool.flit.module] From f7698a2fc8b003b6a86b4653b4733338379c7845 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Tue, 3 Sep 2024 01:00:09 +0530 Subject: [PATCH 19/24] Modified figsize to more of a plot_forest approach, fixed order of plots in plots.rst and expanded max limit for test methods in testplots.py in .pylintrc --- .pylintrc | 2 +- docs/source/api/plots.rst | 4 +- src/arviz_plots/plots/essplot.py | 79 +++++++++++++++----------------- 3 files changed, 41 insertions(+), 44 deletions(-) diff --git a/.pylintrc b/.pylintrc index 0e48a26..23d42d7 100644 --- a/.pylintrc +++ b/.pylintrc @@ -474,7 +474,7 @@ max-locals=15 max-parents=7 # Maximum number of public methods for a class (see R0904). -max-public-methods=20 +max-public-methods=25 # Maximum number of return / yield for function / method body max-returns=6 diff --git a/docs/source/api/plots.rst b/docs/source/api/plots.rst index da290c4..828dbd2 100644 --- a/docs/source/api/plots.rst +++ b/docs/source/api/plots.rst @@ -19,8 +19,8 @@ A complementary introduction and guide to ``plot_...`` functions is available at plot_compare plot_dist + plot_ess plot_forest plot_ridge plot_trace - plot_trace_dist - plot_ess + plot_trace_dist \ No newline at end of file diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index 0ffa28d..b1b19d0 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -1,7 +1,5 @@ """ess plot code.""" -# imports -# import warnings from copy import copy from importlib import import_module @@ -11,7 +9,7 @@ from arviz_base import rcParams from arviz_base.labels import BaseLabeller -from arviz_plots.plot_collection import PlotCollection, leaf_dataset, process_facet_dims +from arviz_plots.plot_collection import PlotCollection, process_facet_dims from arviz_plots.plots.utils import filter_aes, get_group, process_group_variables_coords from arviz_plots.visuals import ( annotate_xy, @@ -224,52 +222,23 @@ def plot_ess( raise ValueError("plot_kwargs['rug'] can't be False, use rug=False to remove the rug") if backend is None: - backend = rcParams["plot.backend"] - plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") + if plot_collection is None: + backend = rcParams["plot.backend"] + else: + backend = plot_collection.backend - # set plot collection initialization defaults if it doesnt exist + plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") - # figsizing related plot collection initialization + # set plot collection initialization defaults if plot_collection is None: - figsize = pc_kwargs.get("plot_grid_kws", {}).get("figsize", None) - figsize_units = pc_kwargs.get("plot_grid_kws", {}).get("figsize_units", "inches") + pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy() + pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() pc_kwargs.setdefault("col_wrap", 5) pc_kwargs.setdefault( "cols", ["__variable__"] + [dim for dim in distribution.dims if dim not in {"model"}.union(sample_dims)], ) - n_plots, _ = process_facet_dims(distribution, pc_kwargs["cols"]) - col_wrap = pc_kwargs["col_wrap"] - if n_plots <= col_wrap: - n_rows, n_cols = 1, n_plots - else: - div_mod = divmod(n_plots, col_wrap) - n_rows = div_mod[0] + (div_mod[1] != 0) - n_cols = col_wrap - else: - figsize, figsize_units = plot_bknd.get_figsize(plot_collection) - n_rows = leaf_dataset(plot_collection.viz, "row").max().to_array().max().item() - n_cols = leaf_dataset(plot_collection.viz, "col").max().to_array().max().item() - - plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") - - figsize = plot_bknd.scale_fig_size( - figsize, - rows=n_rows, - cols=n_cols, - figsize_units=figsize_units, - ) - - # other plot collection related initialization - if plot_collection is None: - # making copy of pc_kwargs["plot_grid_kws"] to pass to .wrap() - pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy() - if "figsize" not in pc_kwargs["plot_grid_kws"]: - pc_kwargs["plot_grid_kws"]["figsize"] = figsize - pc_kwargs["plot_grid_kws"]["figsize_units"] = "dots" - - pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() if "chain" in distribution: pc_kwargs["aes"].setdefault("overlay", ["chain"]) if "model" in distribution: @@ -279,7 +248,35 @@ def plot_ess( pc_kwargs.setdefault("x", np.linspace(-x_diff, x_diff, n_models)) pc_kwargs["aes"].setdefault("x", ["model"]) aux_dim_list = [dim for dim in pc_kwargs["cols"] if dim != "__variable__"] - plot_collection = PlotCollection.wrap( + figsize = pc_kwargs.get("plot_grid_kws", {}).get("figsize", None) + figsize_units = pc_kwargs.get("plot_grid_kws", {}).get("figsize_units", "inches") + if figsize is None: + coeff = 0.2 + if "chain" in distribution.dims: + coeff += 0.1 + if "model" in distribution.dims: + coeff += 0.1 * distribution.sizes["model"] + n_plots, _ = process_facet_dims(distribution, pc_kwargs["cols"]) + col_wrap = pc_kwargs["col_wrap"] + print(f"\n n_plots = {n_plots},\n col_wrap = {col_wrap}") + if n_plots <= col_wrap: + n_rows, n_cols = 1, n_plots + else: + div_mod = divmod(n_plots, col_wrap) + n_rows = div_mod[0] + (div_mod[1] != 0) + n_cols = col_wrap + print(f"\n n_rows = {n_rows},\n n_cols = {n_cols}") + figsize = plot_bknd.scale_fig_size( + figsize, + rows=n_rows, + cols=n_cols, + figsize_units=figsize_units, + ) + print(f"\n figsize = {figsize!r}") + figsize_units = "dots" + pc_kwargs["plot_grid_kws"]["figsize"] = figsize + pc_kwargs["plot_grid_kws"]["figsize_units"] = figsize_units + plot_collection = PlotCollection.grid( distribution, backend=backend, **pc_kwargs, From 74e2dd83f9a152ecc0ec3b9710b275ccecff975d Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Tue, 3 Sep 2024 01:31:28 +0530 Subject: [PATCH 20/24] Updated plot_ess docstring --- src/arviz_plots/plots/essplot.py | 61 ++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index b1b19d0..eff9072 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -121,13 +121,20 @@ def plot_ess( ------- PlotCollection - Examples + Notes + ----- + Depending on the number of models, a slight x-axis separation aesthetic is applied for each + ess point for distinguishability in case of overlap + + See Also -------- - The following examples focus on behaviour specific to ``plot_ess``. - For a general introduction to batteries-included functions like this one and common - usage examples see :ref:`plots_intro` + :ref:`plots_intro` : + General introduction to batteries-included plotting functions, common use and logic overview - Default plot_ess for a single model: + Examples + -------- + We can manually map the color to the variable, and have the mapping apply + to the title too instead of only the ess markers: .. plot:: :context: close-figs @@ -135,25 +142,28 @@ def plot_ess( >>> from arviz_plots import plot_ess, style >>> style.use("arviz-clean") >>> from arviz_base import load_arviz_data - >>> centered = load_arviz_data('centered_eight') >>> non_centered = load_arviz_data('non_centered_eight') - >>> pc = plot_ess(centered) + >>> pc = plot_ess( + >>> non_centered, + >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, + >>> pc_kwargs={"aes": {"color": ["__variable__"]}}, + >>> aes_map={"title": ["color"]}, + >>> ) - Default plot_ess for multiple models: (Depending on the number of models, a slight - x-axis separation aesthetic is applied for each ess point for distinguishability in - case of overlap) + We can add extra methods to plot the mean and standard deviation as lines, and adjust + the minimum ess baseline as well: .. plot:: :context: close-figs >>> pc = plot_ess( - >>> {"centered": centered, "non centered": non_centered}, + >>> non_centered, >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, + >>> extra_methods=True, + >>> min_ess=200, >>> ) - >>> pc.add_legend("model") - We can also manually map the color to the variable, and have the mapping apply - to the title too instead of only the ess markers: + Rugs can also be added: .. plot:: :context: close-figs @@ -161,36 +171,33 @@ def plot_ess( >>> pc = plot_ess( >>> non_centered, >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, - >>> pc_kwargs={"aes": {"color": ["__variable__"]}}, - >>> aes_map={"title": ["color"]}, + >>> rug=True, >>> ) - If we add a mapping (like color) manually to the variable, but not specify which artist - to apply the mapping to- then it is applied to the 'ess' marker artist by default: + Relative ESS can be plotted instead of absolute: .. plot:: :context: close-figs >>> pc = plot_ess( - >>> centered, + >>> non_centered, >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, - >>> pc_kwargs={"aes": {"color": ["__variable__"]}}, + >>> relative=True, >>> ) - The artists' visual features can also be customized through plot_kwargs, based on the - kwargs that the visual element function for the artist accepts- like all the other - batteries included plots. For example, for the 'ess' artist, the scatter_xy function is - used. So if we want to change the marker: + We can also adjust the number of points: .. plot:: :context: close-figs >>> pc = plot_ess( - >>> centered, - >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, - >>> plot_kwargs={"ess": {"marker": "_"}}, + >>> non_centered, + >>> coords={"school": ["Choate", "Deerfield", "Hotchkiss"]}, + >>> n_points=10, >>> ) + .. minigallery:: plot_ess + """ # initial defaults if sample_dims is None: From 1ca070c059a4e35b94571896c5a7468b59d798bb Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Tue, 3 Sep 2024 04:22:57 +0530 Subject: [PATCH 21/24] Switched from .grid to .wrap, removed unused figsize coeffs, disabled pylint warning on testplots.py --- .pylintrc | 2 +- src/arviz_plots/plots/essplot.py | 10 +--------- tests/test_plots.py | 2 +- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/.pylintrc b/.pylintrc index 23d42d7..0e48a26 100644 --- a/.pylintrc +++ b/.pylintrc @@ -474,7 +474,7 @@ max-locals=15 max-parents=7 # Maximum number of public methods for a class (see R0904). -max-public-methods=25 +max-public-methods=20 # Maximum number of return / yield for function / method body max-returns=6 diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index eff9072..565ca56 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -258,32 +258,24 @@ def plot_ess( figsize = pc_kwargs.get("plot_grid_kws", {}).get("figsize", None) figsize_units = pc_kwargs.get("plot_grid_kws", {}).get("figsize_units", "inches") if figsize is None: - coeff = 0.2 - if "chain" in distribution.dims: - coeff += 0.1 - if "model" in distribution.dims: - coeff += 0.1 * distribution.sizes["model"] n_plots, _ = process_facet_dims(distribution, pc_kwargs["cols"]) col_wrap = pc_kwargs["col_wrap"] - print(f"\n n_plots = {n_plots},\n col_wrap = {col_wrap}") if n_plots <= col_wrap: n_rows, n_cols = 1, n_plots else: div_mod = divmod(n_plots, col_wrap) n_rows = div_mod[0] + (div_mod[1] != 0) n_cols = col_wrap - print(f"\n n_rows = {n_rows},\n n_cols = {n_cols}") figsize = plot_bknd.scale_fig_size( figsize, rows=n_rows, cols=n_cols, figsize_units=figsize_units, ) - print(f"\n figsize = {figsize!r}") figsize_units = "dots" pc_kwargs["plot_grid_kws"]["figsize"] = figsize pc_kwargs["plot_grid_kws"]["figsize_units"] = figsize_units - plot_collection = PlotCollection.grid( + plot_collection = PlotCollection.wrap( distribution, backend=backend, **pc_kwargs, diff --git a/tests/test_plots.py b/tests/test_plots.py index 656ecc9..7da202c 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -111,7 +111,7 @@ def cmp(): @pytest.mark.parametrize("backend", ["matplotlib", "bokeh", "plotly", "none"]) -class TestPlots: +class TestPlots: # pylint: disable=too-many-public-methods @pytest.mark.parametrize("kind", ["kde", "hist", "ecdf"]) def test_plot_dist(self, datatree, backend, kind): pc = plot_dist(datatree, backend=backend, kind=kind) From 7726c571f6f099d57fd1bfc043acf809b57802e7 Mon Sep 17 00:00:00 2001 From: "Oriol (ProDesk)" Date: Tue, 10 Sep 2024 02:28:28 +0200 Subject: [PATCH 22/24] final fixes now only waiting for us to figure out behaviour and scope in arviz-stats and xarray-einstats --- pyproject.toml | 2 +- src/arviz_plots/plots/essplot.py | 130 ++++++++++++++++--------------- src/arviz_plots/plots/utils.py | 2 +- tests/test_hypothesis_plots.py | 10 +-- tests/test_plots.py | 6 +- 5 files changed, 73 insertions(+), 77 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 635fb26..3786161 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ dynamic = ["version", "description"] dependencies = [ "arviz-base==0.2", - "arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats", + "arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats@1d_ess", ] [tool.flit.module] diff --git a/src/arviz_plots/plots/essplot.py b/src/arviz_plots/plots/essplot.py index 565ca56..1a01a89 100644 --- a/src/arviz_plots/plots/essplot.py +++ b/src/arviz_plots/plots/essplot.py @@ -93,6 +93,14 @@ def plot_ess( Mapping of artists to aesthetics that should use their mapping in `plot_collection` when plotted. Valid keys are the same as for `plot_kwargs`. + By default, no aesthetic mappings are defined. Only when multiple models + are present a color and x shift is generated to distinguish the data + coming from the different models. + + When ``mean`` or ``sd`` keys are present in `aes_map` but ``mean_text`` + or ``sd_text`` are not, the respective ``_text`` key will be added + with the same values as ``mean`` or ``sd`` ones. + plot_kwargs : mapping of {str : mapping or False}, optional Valid keys are: @@ -121,11 +129,6 @@ def plot_ess( ------- PlotCollection - Notes - ----- - Depending on the number of models, a slight x-axis separation aesthetic is applied for each - ess point for distinguishability in case of overlap - See Also -------- :ref:`plots_intro` : @@ -235,6 +238,10 @@ def plot_ess( backend = plot_collection.backend plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") + # getting backend specific linestyles + linestyles = plot_bknd.get_default_aes("linestyle", 4, {}) + # and default color + default_color = plot_bknd.get_default_aes("color", 1, {})[0] # set plot collection initialization defaults if plot_collection is None: @@ -309,7 +316,7 @@ def plot_ess( ess_kwargs = copy(plot_kwargs.get("ess", {})) if ess_kwargs is not False: - ess_dims, _, ess_ignore = filter_aes(plot_collection, aes_map, kind, sample_dims) + ess_dims, ess_aes, ess_ignore = filter_aes(plot_collection, aes_map, kind, sample_dims) if kind == "local": probs = np.linspace(0, 1, n_points, endpoint=False) ylabel = "{} for small intervals" @@ -343,70 +350,74 @@ def plot_ess( plot_axis=["x", "y"] ) + if "color" not in ess_aes: + ess_kwargs.setdefault("color", default_color) + plot_collection.map( scatter_xy, "ess", data=ess_dataset, ignore_aes=ess_ignore, **ess_kwargs ) # plot rug - # overlaying divergences(or other 'rug_kind') for each chain - if rug: - sample_stats = get_group(dt, "sample_stats", allow_missing=True) - if ( - sample_stats is not None - and rug_kind in sample_stats.data_vars - and np.any(sample_stats[rug_kind]) # 'diverging' by default - and rug_kwargs is not False - ): - rug_mask = dt.sample_stats[rug_kind] # 'diverging' by default - _, div_aes, div_ignore = filter_aes(plot_collection, aes_map, "rug", sample_dims) - if "color" not in div_aes: - rug_kwargs.setdefault("color", "black") - if "marker" not in div_aes: - rug_kwargs.setdefault("marker", "|") - if "size" not in div_aes: - rug_kwargs.setdefault("size", 30) - div_reduce_dims = [dim for dim in distribution.dims if dim not in aux_dim_list] - - values = distribution.azstats.compute_ranks(relative=True) + sample_stats = get_group(dt, "sample_stats", allow_missing=True) + if ( + rug + and sample_stats is not None + and rug_kind in sample_stats.data_vars + and np.any(sample_stats[rug_kind]) + ): + rug_mask = dt.sample_stats[rug_kind] + _, div_aes, div_ignore = filter_aes(plot_collection, aes_map, "rug", sample_dims) + if "color" not in div_aes: + rug_kwargs.setdefault("color", "black") + if "marker" not in div_aes: + rug_kwargs.setdefault("marker", "|") + if "size" not in div_aes: + rug_kwargs.setdefault("size", 30) + div_reduce_dims = [dim for dim in distribution.dims if dim not in aux_dim_list] + + values = distribution.azstats.compute_ranks(dims=sample_dims, relative=True) - plot_collection.map( - trace_rug, - "rug", - data=values, - ignore_aes=div_ignore, - y=distribution.min(div_reduce_dims), - mask=rug_mask, - xname=False, - **rug_kwargs, - ) # note: after plot_ppc merge, the `trace_rug` function might change + plot_collection.map( + trace_rug, + "rug", + data=values, + ignore_aes=div_ignore, + y=distribution.min(div_reduce_dims), + mask=rug_mask, + xname=False, + **rug_kwargs, + ) # defining x_range (used for mean, sd, minimum ess plotting) - x_range = [0, 1] - x_range = xr.DataArray(x_range) - - # getting backend specific linestyles - linestyles = plot_bknd.get_default_aes("linestyle", 4, {}) - # and default color - default_color = plot_bknd.get_default_aes("color", 1, {})[0] + x_range = xr.DataArray([0, 1]) # plot mean and sd and annotate them if extra_methods is not False: + mean_kwargs = copy(plot_kwargs.get("mean", {})) + mean_text_kwargs = copy(plot_kwargs.get("mean_text", {})) + sd_kwargs = copy(plot_kwargs.get("sd", {})) + sd_text_kwargs = copy(plot_kwargs.get("sd_text", {})) + # computing mean_ess mean_dims, mean_aes, mean_ignore = filter_aes(plot_collection, aes_map, "mean", sample_dims) - mean_ess = distribution.azstats.ess( - dims=mean_dims, method="mean", relative=relative, **stats_kwargs.get("mean", {}) - ) + mean_ess = None + if (mean_kwargs is not False) or (mean_text_kwargs is not False): + mean_ess = distribution.azstats.ess( + dims=mean_dims, method="mean", relative=relative, **stats_kwargs.get("mean", {}) + ) # computing sd_ess sd_dims, sd_aes, sd_ignore = filter_aes(plot_collection, aes_map, "sd", sample_dims) - sd_ess = distribution.azstats.ess( - dims=sd_dims, method="sd", relative=relative, **stats_kwargs.get("sd", {}) - ) + sd_ess = None + if (sd_kwargs is not False) or (sd_text_kwargs is not False): + sd_ess = distribution.azstats.ess( + dims=sd_dims, method="sd", relative=relative, **stats_kwargs.get("sd", {}) + ) - mean_kwargs = copy(plot_kwargs.get("mean", {})) if mean_kwargs is not False: # getting 2nd default linestyle for chosen backend and assigning it by default - mean_kwargs.setdefault("linestyle", linestyles[1]) + if "linestyle" not in mean_aes: + mean_kwargs.setdefault("linestyle", linestyles[1]) if "color" not in mean_aes: mean_kwargs.setdefault("color", default_color) @@ -420,9 +431,9 @@ def plot_ess( **mean_kwargs, ) - sd_kwargs = copy(plot_kwargs.get("sd", {})) if sd_kwargs is not False: - sd_kwargs.setdefault("linestyle", linestyles[2]) + if "linestyle" not in sd_aes: + sd_kwargs.setdefault("linestyle", linestyles[2]) if "color" not in sd_aes: sd_kwargs.setdefault("color", default_color) @@ -437,10 +448,7 @@ def plot_ess( sd_va_align = xr.where(mean_ess < sd_ess, "bottom", "top") mean_va_align = xr.where(mean_ess < sd_ess, "top", "bottom") - mean_text_kwargs = copy(plot_kwargs.get("mean_text", {})) - if ( - mean_text_kwargs is not False and mean_ess is not None - ): # mean_ess has to exist for an annotation to be applied + if mean_text_kwargs is not False: _, mean_text_aes, mean_text_ignore = filter_aes( plot_collection, aes_map, "mean_text", sample_dims ) @@ -467,10 +475,7 @@ def plot_ess( **mean_text_kwargs, ) - sd_text_kwargs = copy(plot_kwargs.get("sd_text", {})) - if ( - sd_text_kwargs is not False and sd_ess is not None - ): # sd_ess has to exist for an annotation to be applied + if sd_text_kwargs is not False: _, sd_text_aes, sd_text_ignore = filter_aes( plot_collection, aes_map, "sd_text", sample_dims ) @@ -508,7 +513,8 @@ def plot_ess( if relative: min_ess = min_ess / n_points - min_ess_kwargs.setdefault("linestyle", linestyles[3]) + if "linestyle" not in min_ess_aes: + min_ess_kwargs.setdefault("linestyle", linestyles[3]) if "color" not in min_ess_aes: min_ess_kwargs.setdefault("color", "gray") diff --git a/src/arviz_plots/plots/utils.py b/src/arviz_plots/plots/utils.py index 3b83592..a870033 100644 --- a/src/arviz_plots/plots/utils.py +++ b/src/arviz_plots/plots/utils.py @@ -37,7 +37,7 @@ def get_group(data, group, allow_missing=False): try: data = data[group] except KeyError: - if allow_missing: + if not allow_missing: raise return None if isinstance(data, Dataset): diff --git a/tests/test_hypothesis_plots.py b/tests/test_hypothesis_plots.py index d1a518a..93163fc 100644 --- a/tests/test_hypothesis_plots.py +++ b/tests/test_hypothesis_plots.py @@ -196,11 +196,6 @@ def test_plot_ridge(datatree, combined, plot_kwargs, labels_shade_label): assert all(key in child for child in pc.viz.children.values()) -@st.composite -def ess_min_ess(draw): - return draw(st.integers(min_value=10, max_value=150)) # max samples = 3 x 50 = 150 - - @given( plot_kwargs=st.fixed_dictionaries( {}, @@ -215,7 +210,6 @@ def ess_min_ess(draw): "sd_text": plot_kwargs_value, "min_ess": plot_kwargs_value, "title": plot_kwargs_value, - "remove_axis": st.just(False), }, ), kind=ess_kind_value, @@ -223,7 +217,7 @@ def ess_min_ess(draw): rug=st.booleans(), n_points=st.integers(min_value=1, max_value=5), extra_methods=st.booleans(), - min_ess=ess_min_ess(), + min_ess=st.integers(min_value=10, max_value=150), ) def test_plot_ess(datatree, kind, relative, rug, n_points, extra_methods, min_ess, plot_kwargs): pc = plot_ess( @@ -251,5 +245,5 @@ def test_plot_ess(datatree, kind, relative, rug, n_points, extra_methods, min_es assert all(key not in child for child in pc.viz.children.values()) else: assert all(key in child for child in pc.viz.children.values()) - elif key != "remove_axis": + else: assert all(key in child for child in pc.viz.children.values()) diff --git a/tests/test_plots.py b/tests/test_plots.py index 7da202c..54fee88 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -307,7 +307,6 @@ def test_plot_ess(self, datatree, backend): # checking aesthetics assert "overlay" in pc.aes["mu"].data_vars # overlay of chains - a = """ def test_plot_ess_sample(self, datatree_sample, backend): pc = plot_ess(datatree_sample, backend=backend, rug=True, sample_dims="sample") assert "chart" in pc.viz.data_vars @@ -318,10 +317,7 @@ def test_plot_ess_sample(self, datatree_sample, backend): assert "rug" in pc.viz["mu"] assert "hierarchy" not in pc.viz["mu"].dims assert "hierarchy" in pc.viz["theta"].dims - assert pc.viz["mu"].trace.shape == () # 0 chains here, so no overlay - """ # error when running this: ValueError: dims must be of length 2 (from arviz-stats ) - if a: - print("a") # just to temporarily suspend linting error + assert pc.viz["mu"].rug.shape == () # 0 chains here, so no overlay def test_plot_ess_models(self, datatree, datatree2, backend): pc = plot_ess({"c": datatree, "n": datatree2}, backend=backend, rug=False) From 8b053d6f5d78c4a60f7dbc0e8f1d5cde76185a8c Mon Sep 17 00:00:00 2001 From: "Oriol (ProDesk)" Date: Thu, 17 Oct 2024 11:37:30 +0200 Subject: [PATCH 23/24] update pyproject requirements --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3786161..635fb26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ dynamic = ["version", "description"] dependencies = [ "arviz-base==0.2", - "arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats@1d_ess", + "arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats", ] [tool.flit.module] From 56cb9c6ee12ceacaf73f831f6c2b7456f48999bd Mon Sep 17 00:00:00 2001 From: "Oriol (ProDesk)" Date: Thu, 17 Oct 2024 11:42:49 +0200 Subject: [PATCH 24/24] pylint --- .pylintrc | 1 + 1 file changed, 1 insertion(+) diff --git a/.pylintrc b/.pylintrc index 0e48a26..c4ef136 100644 --- a/.pylintrc +++ b/.pylintrc @@ -67,6 +67,7 @@ disable=missing-docstring, no-member, import-error, possibly-used-before-assignment, + too-many-positional-arguments, fixme