add dataset->stacked dataarray/dataframe converters #25

Open · wants to merge 6 commits into main
12 changes: 11 additions & 1 deletion docs/source/api/index.md
@@ -1,13 +1,23 @@
# API reference

## Data reorganization

```{eval-rst}
.. autosummary::
:toctree: generated/

arviz_base.extract
arviz_base.dataset_to_dataarray
arviz_base.dataset_to_dataframe
```

## User facing converters

```{eval-rst}
.. autosummary::
:toctree: generated/

arviz_base.convert_to_datatree
arviz_base.extract
arviz_base.from_dict
```
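The implementations of `dataset_to_dataarray` and `dataset_to_dataframe` live in `reorg.py` and are not shown in this diff, so the snippet below is only a usage sketch inferred from this API listing and the PR title; the `sample_dims` argument and the stacked sample dimension are assumptions borrowed from `extract`.

```python
import arviz_base as az

idata = az.load_arviz_data("centered_eight")
post = idata.posterior.ds  # posterior group as an xarray.Dataset

# Stack every variable into a single DataArray, with all samples
# along one stacked dimension (hypothetical signature).
da = az.dataset_to_dataarray(post, sample_dims=["chain", "draw"])

# Same flattening, but as a pandas.DataFrame with one row per
# posterior sample (hypothetical signature).
df = az.dataset_to_dataframe(post, sample_dims=["chain", "draw"])
```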

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,7 +27,7 @@ dynamic = ["version", "description"]
dependencies = [
"numpy>=1.23",
"xarray>=2022.6.0,<2024.9.1",
"xarray-datatree",
"xarray-datatree<0.0.15",
"typing-extensions>=3.10",
]

3 changes: 2 additions & 1 deletion src/arviz_base/__init__.py
@@ -7,10 +7,11 @@

from arviz_base._version import __version__
from arviz_base.base import dict_to_dataset, generate_dims_coords, make_attrs, ndarray_to_dataarray
from arviz_base.converters import *
from arviz_base.converters import convert_to_dataset, convert_to_datatree
from arviz_base.datasets import clear_data_home, get_data_home, list_datasets, load_arviz_data
from arviz_base.io_cmdstanpy import from_cmdstanpy
from arviz_base.io_dict import from_dict
from arviz_base.io_emcee import from_emcee
from arviz_base.rcparams import rc_context, rcParams
from arviz_base.reorg import extract, dataset_to_dataarray, dataset_to_dataframe
from arviz_base.sel_utils import *
184 changes: 9 additions & 175 deletions src/arviz_base/converters.py
@@ -1,14 +1,19 @@
"""Generalistic converters."""
"""Generalistic converters.

Here "generalistic" means catch anything that can be converter into datatree and
convert it via its specific function.
"""

import numpy as np
import xarray as xr
from datatree import DataTree, open_datatree

from arviz_base.base import dict_to_dataset
from arviz_base.rcparams import rcParams
from arviz_base.utils import _var_names

__all__ = ["convert_to_datatree", "convert_to_dataset", "extract"]
__all__ = [
"convert_to_datatree",
"convert_to_dataset",
]


# pylint: disable=too-many-return-statements
@@ -185,174 +190,3 @@ def convert_to_dataset(obj, *, group="posterior", **kwargs):
f"Can not extract {group} from {obj}! See docs for other " "conversion utilities."
)
return dataset.to_dataset()
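As the reworked module docstring says, the surviving converters only recognise the input type and dispatch to its specific converter. A minimal sketch of the two entry points, assuming dict inputs are routed through `dict_to_dataset` with the usual leading `(chain, draw)` convention:

```python
import numpy as np
import arviz_base as az

# A mapping of arrays: leading dimensions are interpreted as
# (chain, draw), so this is 4 chains of 500 draws each.
raw = {"mu": np.random.default_rng(0).normal(size=(4, 500))}

# Route the dict through its specific converter into a DataTree...
dt = az.convert_to_datatree(raw, group="posterior")

# ...or go one step further and pull that group out as a Dataset.
post = az.convert_to_dataset(raw, group="posterior")
```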


# TODO: remove this ignore about too many statements once the code uses validator functions
def extract( # noqa: PLR0915
data,
group="posterior",
sample_dims=None,
*,
combined=True,
var_names=None,
filter_vars=None,
num_samples=None,
weights=None,
resampling_method=None,
keep_dataset=False,
random_seed=None,
):
"""Extract a group or group subset from a DataTree.

Parameters
----------
data : DataTree_like
DataTree from which to extract the data.
group : str, optional
Which group to extract data from.
sample_dims : sequence of hashable, optional
List of dimensions that should be considered sampling dimensions.
Random subsets and potential stacking if ``combined=True`` happen
over these dimensions only. Defaults to ``rcParams["data.sample_dims"]``.
combined : bool, optional
Combine `sample_dims` dimensions into ``sample``. Won't work if
a dimension named ``sample`` already exists.
It is irrelevant and ignored when `sample_dims` is a single dimension.
var_names : str or list of str, optional
Variables to be extracted. Prefix the variables by `~` when you want to exclude them.
filter_vars: {None, "like", "regex"}, optional
If `None` (default), interpret var_names as the real variables names. If "like",
interpret var_names as substrings of the real variables names. If "regex",
interpret var_names as regular expressions on the real variables names. A la
`pandas.filter`.
Like with plotting, sometimes it's easier to subset by saying what to exclude
instead of what to include.
num_samples : int, optional
Extract only a subset of the samples. Only valid if ``combined=True`` or
`sample_dims` represents a single dimension.
weights : array-like, optional
Extract a weighted subset of the samples. Only valid if `num_samples` is not ``None``.
resampling_method : str, optional
Method to use for resampling. Options are "multinomial" and "stratified".
For stratified resampling, weights must be provided.
Defaults to "stratified" if weights are provided, "multinomial" otherwise.
keep_dataset : bool, optional
If true, always return a Dataset. If false (default), return a DataArray
when there is a single variable.
random_seed : int, numpy.Generator, optional
Random number generator or seed. Only used if ``weights`` is not ``None``
or if ``num_samples`` is not ``None``.

Returns
-------
xarray.DataArray or xarray.Dataset

Examples
--------
The default behaviour is to return the posterior group after stacking the chain and
draw dimensions.

.. jupyter-execute::

import arviz_base as az
idata = az.load_arviz_data("centered_eight")
az.extract(idata)

You can also indicate a subset to be returned, both in variables and in samples:

.. jupyter-execute::

az.extract(idata, var_names="theta", num_samples=100)

To keep the chain and draw dimensions, use ``combined=False``.

.. jupyter-execute::

az.extract(idata, group="prior", combined=False)

"""
# TODO: use validator function
if sample_dims is None:
sample_dims = rcParams["data.sample_dims"]
if isinstance(sample_dims, str):
sample_dims = [sample_dims]
if len(sample_dims) == 1:
combined = True
if num_samples is not None and not combined:
raise ValueError(
"num_samples is only compatible with combined=True or length 1 sample_dims"
)
if weights is not None and num_samples is None:
raise ValueError("weights are only compatible with num_samples")

data = convert_to_dataset(data, group=group)
var_names = _var_names(var_names, data, filter_vars)
if var_names is not None:
if len(var_names) == 1 and not keep_dataset:
var_names = var_names[0]
data = data[var_names]
elif len(data.data_vars) == 1:
data = data[list(data.data_vars)[0]]

if weights is not None:
resampling_method = "stratified" if resampling_method is None else resampling_method
weights = np.array(weights).ravel()
if len(weights) != np.prod([data.sizes[dim] for dim in sample_dims]):
raise ValueError("Weights must have the same size as `sample_dims`")
else:
resampling_method = "multinomial" if resampling_method is None else resampling_method

if resampling_method not in ("multinomial", "stratified"):
raise ValueError(f"Invalid resampling_method: {resampling_method}")

if combined and len(sample_dims) != 1:
data = data.stack(sample=sample_dims)
combined_dim = "sample"
elif len(sample_dims) == 1:
combined_dim = sample_dims[0]

if weights is not None or num_samples is not None:
if random_seed is None:
rng = np.random.default_rng()
elif isinstance(random_seed, int | np.integer):
rng = np.random.default_rng(random_seed)
elif isinstance(random_seed, np.random.Generator):
rng = random_seed
else:
raise ValueError(f"Invalid random_seed value: {random_seed}")

replace = weights is not None

if resampling_method == "multinomial":
resample_indices = rng.choice(
np.arange(data.sizes[combined_dim]),
size=num_samples,
p=weights,
replace=replace,
)
elif resampling_method == "stratified":
if weights is None:
raise ValueError("Weights must be provided for stratified resampling")
resample_indices = _stratified_resample(weights, rng)

data = data.isel({combined_dim: resample_indices})

return data
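Even though `extract` is being moved out of this module, the weighted path above deserves an illustration: passing `weights` triggers resampling with replacement over the combined sample dimension. A sketch against the documented API, with uniform weights standing in for e.g. normalized importance weights:

```python
import numpy as np
import arviz_base as az

idata = az.load_arviz_data("centered_eight")
post = idata.posterior.ds

# One weight per (chain, draw) pair; uniform here for simplicity.
n_samples = post.sizes["chain"] * post.sizes["draw"]
w = np.full(n_samples, 1 / n_samples)

# Multinomial resampling draws exactly num_samples indices with
# probabilities proportional to the weights.
sub = az.extract(
    idata,
    num_samples=200,
    weights=w,
    resampling_method="multinomial",
    random_seed=17,
)
```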


def _stratified_resample(weights, rng):
"""Stratified resampling."""
N = len(weights)
single_uniform = (rng.random(N) + np.arange(N)) / N
indexes = np.zeros(N, dtype=int)
cum_sum = np.cumsum(weights)

i, j = 0, 0
while i < N:
if single_uniform[i] < cum_sum[j]:
indexes[i] = j
i += 1
else:
j += 1

return indexes
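`_stratified_resample` is the classic low-variance scheme: one uniform draw is taken from each of the `N` strata `[i/N, (i+1)/N)` and matched against the cumulative weights, so each index is selected a number of times roughly proportional to its weight, with less variance than multinomial draws. A standalone demo with illustrative weights (it assumes the function definition above is in scope):

```python
import numpy as np

rng = np.random.default_rng(42)
weights = np.array([0.1, 0.2, 0.3, 0.4])  # must sum to 1

# Returns len(weights) indices into the original sample axis;
# high-weight entries (like index 3 here) appear more often.
idx = _stratified_resample(weights, rng)
print(idx)
```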