Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add features to support ArviZ integration, Rebased #607

Merged
merged 3 commits into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
Allows to perform, e.g. MCMC in unconstrained space.
device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
`potential_fn.device` is used.
x_shape: Shape of the observed data.
"""

# Ensure device string.
Expand Down Expand Up @@ -132,6 +133,8 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior":

def _x_else_default_x(self, x: Optional[Array]) -> Tensor:
if x is not None:
# New x, reset posterior sampler.
self._posterior_sampler = None
return process_x(
x, x_shape=self._x_shape, allow_iid_x=self.potential_fn.allow_iid_x
)
Expand Down
131 changes: 110 additions & 21 deletions sbi/inference/posteriors/mcmc_posterior.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
from functools import partial
from typing import Any, Callable, Dict, Optional, Union
from math import ceil
from typing import Any, Callable, Dict, Optional, Tuple, Union
from warnings import warn

import arviz as az
import torch
import torch.distributions.transforms as torch_tf
from arviz.data import InferenceData
from joblib import Parallel, delayed
from numpy import ndarray
from pyro.infer.mcmc import HMC, NUTS
from pyro.infer.mcmc.api import MCMC
from torch import Tensor
Expand All @@ -17,14 +21,15 @@
from sbi.samplers.mcmc import (
IterateParameters,
Slice,
SliceSamplerSerial,
SliceSamplerVectorized,
proposal_init,
resample_given_potential_fn,
sir_init,
slice_np_parallized,
)
from sbi.simulators.simutils import tqdm_joblib
from sbi.types import Shape, TorchTransform
from sbi.utils import pyro_potential_wrapper, transformed_potential
from sbi.utils import pyro_potential_wrapper, tensor2numpy, transformed_potential
from sbi.utils.torchutils import ensure_theta_batched


Expand Down Expand Up @@ -102,6 +107,9 @@ def __init__(
self.init_strategy = init_strategy
self.init_strategy_parameters = init_strategy_parameters
self.num_workers = num_workers
self._posterior_sampler = None
# Hardcode parameter name to reduce clutter kwargs.
self.param_name = "theta"

if init_strategy_num_candidates is not None:
warn(
Expand Down Expand Up @@ -130,6 +138,11 @@ def mcmc_method(self, method: str) -> None:
"""See `set_mcmc_method`."""
self.set_mcmc_method(method)

@property
def posterior_sampler(self):
"""Returns sampler created by `sample`."""
return self._posterior_sampler

def set_mcmc_method(self, method: str) -> "NeuralPosterior":
"""Sets sampling method to for MCMC and returns `NeuralPosterior`.

Expand Down Expand Up @@ -185,7 +198,7 @@ def sample(
sample_with: Optional[str] = None,
num_workers: Optional[int] = None,
show_progress_bars: bool = True,
) -> Tensor:
) -> Union[Tensor, Tuple[Tensor, InferenceData]]:
r"""Return samples from posterior distribution $p(\theta|x)$ with MCMC.

Check the `__init__()` method for a description of all arguments as well as
Expand Down Expand Up @@ -291,11 +304,12 @@ def sample(
warmup_steps=warmup_steps, # type: ignore
num_chains=num_chains,
show_progress_bars=show_progress_bars,
).detach()
janfb marked this conversation as resolved.
Show resolved Hide resolved
)
else:
raise NameError

samples = self.theta_transform.inv(transformed_samples)

return samples.reshape((*sample_shape, -1)) # type: ignore

def _build_mcmc_init_fn(
Expand Down Expand Up @@ -419,6 +433,7 @@ def _slice_np_mcmc(
warmup_steps: int,
vectorized: bool = False,
num_workers: int = 1,
init_width: Union[float, ndarray] = 0.01,
janfb marked this conversation as resolved.
Show resolved Hide resolved
show_progress_bars: bool = True,
) -> Tensor:
"""Custom implementation of slice sampling using Numpy.
Expand All @@ -429,32 +444,47 @@ def _slice_np_mcmc(
initial_params: Initial parameters for MCMC chain.
thin: Thinning (subsampling) factor.
warmup_steps: Initial number of samples to discard.
vectorized: Whether to use a vectorized implementation of
the Slice sampler (still experimental).
num_workers: number of CPU cores to use
seed: seed that will be used to generate sub-seeds for each worker
vectorized: Whether to use a vectorized implementation of the Slice sampler.
num_workers: Number of CPU cores to use.
init_width: Inital width of brackets.
show_progress_bars: Whether to show a progressbar during sampling;
can only be turned off for vectorized sampler.

Returns: Tensor of shape (num_samples, shape_of_single_theta).
Returns:
Tensor of shape (num_samples, shape_of_single_theta).
Arviz InferenceData object.
"""

num_chains, dim_samples = initial_params.shape

samples = slice_np_parallized(
potential_function,
initial_params,
num_samples,
if not vectorized:
SliceSamplerMultiChain = SliceSamplerSerial
else:
SliceSamplerMultiChain = SliceSamplerVectorized

posterior_sampler = SliceSamplerMultiChain(
janfb marked this conversation as resolved.
Show resolved Hide resolved
init_params=tensor2numpy(initial_params),
log_prob_fn=potential_function,
num_chains=num_chains,
thin=thin,
warmup_steps=warmup_steps,
vectorized=vectorized,
verbose=show_progress_bars,
num_workers=num_workers,
show_progress_bars=show_progress_bars,
init_width=init_width,
)
warmup_ = warmup_steps * thin
num_samples_ = ceil((num_samples * thin) / num_chains)
# Run mcmc including warmup
samples = posterior_sampler.run(warmup_ + num_samples_)
samples = samples[:, warmup_steps:, :] # discard warmup steps
samples = torch.from_numpy(samples) # chains x samples x dim

# Save posterior sampler.
self._posterior_sampler = posterior_sampler

# Save sample as potential next init (if init_strategy == 'latest_sample').
self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples)

# Collect samples from all chains.
samples = samples.reshape(-1, dim_samples)[:num_samples, :]
assert samples.shape[0] == num_samples

Expand All @@ -470,7 +500,7 @@ def _pyro_mcmc(
warmup_steps: int = 200,
num_chains: Optional[int] = 1,
show_progress_bars: bool = True,
):
) -> Tensor:
r"""Return samples obtained using Pyro HMC, NUTS for slice kernels.

Args:
Expand All @@ -484,7 +514,9 @@ def _pyro_mcmc(
num_chains: Whether to sample in parallel. If None, use all but one CPU.
show_progress_bars: Whether to show a progressbar during sampling.

Returns: Tensor of shape (num_samples, shape_of_single_theta).
Returns:
Tensor of shape (num_samples, shape_of_single_theta).
Arviz InferenceData object.
"""
num_chains = mp.cpu_count() - 1 if num_chains is None else num_chains

Expand All @@ -494,7 +526,7 @@ def _pyro_mcmc(
kernel=kernels[mcmc_method](potential_fn=potential_function),
num_samples=(thin * num_samples) // num_chains + num_chains,
warmup_steps=warmup_steps,
initial_params={"": initial_params},
initial_params={self.param_name: initial_params},
num_chains=num_chains,
mp_context="spawn",
disable_progbar=not show_progress_bars,
Expand All @@ -505,10 +537,13 @@ def _pyro_mcmc(
-1, initial_params.shape[1] # .shape[1] = dim of theta
)

# Save posterior sampler.
self._posterior_sampler = sampler

samples = samples[::thin][:num_samples]
assert samples.shape[0] == num_samples

return samples
return samples.detach()

def _prepare_potential(self, method: str) -> Callable:
"""Combines potential and transform and takes care of gradients and pyro.
Expand Down Expand Up @@ -612,6 +647,60 @@ def map(
force_update=force_update,
)

def get_arviz_inference_data(self) -> InferenceData:
"""Returns arviz InferenceData object constructed most recent samples.

Note: the InferenceData is constructed using the posterior samples generated in
most recent call to `.sample(...)`.

For Pyro HMC and NUTS kernels InferenceData will contain diagnostics, for Pyro
Slice or sbi slice sampling samples, only the samples are added.

Returns:
inference_data: Arviz InferenceData object.
"""
assert (
self._posterior_sampler is not None
), """No samples have been generated, call .sample() first."""

sampler: Union[
MCMC, SliceSamplerSerial, SliceSamplerVectorized
] = self._posterior_sampler

# If Pyro sampler and samples not transformed, use arviz' from_pyro.
# Exclude 'slice' kernel as it lacks the 'divergence' diagnostics key.
if isinstance(self._posterior_sampler, (HMC, NUTS)) and isinstance(
self.theta_transform, torch_tf.IndependentTransform
):
inference_data = az.from_pyro(sampler)

# otherwise get samples from sampler and transform to original space.
else:
transformed_samples = sampler.get_samples(group_by_chain=True)
# Pyro samplers returns dicts, get values.
if isinstance(transformed_samples, Dict):
# popitem gets last items, [1] get the values as tensor.
transformed_samples = transformed_samples.popitem()[1]
# Our slice samplers return numpy arrays.
elif isinstance(transformed_samples, ndarray):
transformed_samples = torch.from_numpy(transformed_samples).type(
torch.float32
)
# For MultipleIndependent priors transforms first dim must be batch dim.
# thus, reshape back and forth to have batch dim in front.
num_chains, samples_per_chain, dim_params = transformed_samples.shape
samples = self.theta_transform.inv( # type: ignore
transformed_samples.reshape(-1, dim_params)
).reshape( # type: ignore
num_chains, samples_per_chain, dim_params
)

inference_data = az.convert_to_inference_data(
{f"{self.param_name}": samples}
)

return inference_data


def _maybe_use_dict_entry(default: Any, key: str, dict_to_check: Dict) -> Any:
"""Returns `default` if `key` is not in the dict and otherwise the dict entry.
Expand Down
2 changes: 1 addition & 1 deletion sbi/samplers/mcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
from sbi.samplers.mcmc.slice import Slice
from sbi.samplers.mcmc.slice_numpy import (
SliceSampler,
SliceSamplerSerial,
SliceSamplerVectorized,
slice_np_parallized,
)
Loading