Skip to content

Commit

Permalink
mcse quantiles computating shifted to plot_mcse function (out of visu…
Browse files Browse the repository at this point in the history
…al element)
  • Loading branch information
imperorrp committed Aug 6, 2024
1 parent 5efbef8 commit 46bdc17
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 87 deletions.
84 changes: 4 additions & 80 deletions src/arviz_plots/plots/mcseplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,94 +233,18 @@ def plot_mcse(
# else:
# use new errorbar visual element function to plot errorbars
else:
# ----------------------------------------------
# || WIP ||
# ----------------------------------------------
# data to pass to .map() and get facetted:
# - mcse_dataset

# data that has to be computed before: quantile values
# thi can be computed with distribution and probs and the quantile func from arviz-stats

# ^for every subsel of variable/dimension-by-coords in distribution, the values are
# flattened. then call quantile(values, probs) for each of these subselections, and
# concatenate it to the mcse_dataset for later facetting by map() and destructuring by
# the errorbar visual element func

a = """
plotters = xarray_sel_iter(
distribution, skip_dims=mcse_dims
) # create subselections of distribution, looping over all dims except mcse_dims
z_mcse_dataset = (
xr.Dataset()
) # to be concatenated with mcse_dataset later along dim=plot_axis
# z_mcse_dict = {}
# compute quantile values for each subselection of distribution
for var_name, sel, isel in plotters:
print(f"\nvar={var_name} | sel={sel} | sel.items()={sel.items()} | isel={isel}")
da = distribution[var_name].sel(sel)
da = da.values.flatten()
# print(f"\n da = {da}")
quantile_values = array_stats.quantile(da, probs)
# print(f"quantile_values = {quantile_values}")
# Convert quantile_values to a DataArray
quantile_da = xr.DataArray(quantile_values, dims=["mcse_dim"], name=var_name)
print(f"\n quantile_da = {quantile_da}")
# Expand dims of dataarray and assign new coords 'z' and sel dim coords
quantile_da_expanded = quantile_da.expand_dims({"plot_axis": ["z"]})
# .assign_coords(
# {"plot_axis": ["z"]}
# )
# .to_dataset( # .assign_coords(plot_axis=["z"])
# name=var_name
# )
if sel: # adding sel dims to quantile_da
sel_expanded = {key: [value] for key, value in sel.items()}
print(f"sel_expanded = {sel_expanded}")
quantile_da_expanded = quantile_da_expanded.expand_dims(sel_expanded)
# dim=list(sel.keys())
# ).assign_coords(**sel)
# quantile_da_expanded = quantile_da_expanded.assign_coords(**sel)
print(f"\n quantile_da_expanded = {quantile_da_expanded}")
# print(f"\n mcse_dataset sel = {mcse_dataset[var_name].sel(**sel)!r}")
# print(f"\n mcse_dataset isel = {mcse_dataset[var_name].isel(isel)!r}")
if var_name in z_mcse_dataset:
combined_dataarray = z_mcse_dataset[var_name].combine_first(
quantile_da_expanded
)
print(f"\n combined_dataarray = {combined_dataarray}")
z_mcse_dataset.update({var_name: (var_name, combined_dataarray)})
# .concat(), .merge() are not in-place functions. .update() is in-place but
# not working rn
# z_mcse_dataset[var_name] =
# xr.concat(
# [z_mcse_dataset[var_name], quantile_da_expanded],
# dim="school",
# )
else:
z_mcse_dataset[var_name] = quantile_da_expanded
# print(f"\n adding var {var_name} to z_mcse_dataset: {z_mcse_dataset}")
print(f"\n z_mcse_dataset = {z_mcse_dataset}")"""
print(len(a))

# print(f"\n final z_mcse_dataset = {z_mcse_dataset}")
quantiles_dataset = distribution.quantile(probs, dim=mcse_dims)
print(f"\n quantiles_dataset = {quantiles_dataset}")

# for now the quantile_values can be computed in the visual element function itself
plot_collection.map(
error_bar,
"mcse",
data=mcse_dataset,
ignore_aes=mcse_ignore,
distribution=distribution, # map() subsets this before passing to error_bar
quantiles_dataset=quantiles_dataset,
# distribution=distribution, # map() subsets this before passing to error_bar
**mcse_kwargs,
)

Expand Down
9 changes: 2 additions & 7 deletions src/arviz_plots/visuals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,14 @@
from arviz_stats.numba import array_stats


def error_bar(da, target, backend, distribution, x=None, y=None, **kwargs):
def error_bar(da, target, backend, quantiles_dataset, x=None, y=None, **kwargs):
"""Plot error bars.
Note: This uses subset info from .map() so toggle it to true when calling this func
"""
plot_backend = import_module(f"arviz_plots.backend.{backend}")
probs, yerr = _process_da_x_y(da, x, y)
# the dataarray of distribution subset (by map()) is flattened and
# used to calculate the quantile values
da = distribution
# print(f"\n subsetted distribution = {da}")
da = da.values.flatten()
quantile_values = array_stats.quantile(da, probs)
quantile_values = quantiles_dataset
return plot_backend.errorbar(probs, quantile_values, yerr, target, **kwargs)


Expand Down

0 comments on commit 46bdc17

Please sign in to comment.