diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py
index 302764a51..f45f67115 100644
--- a/sbi/inference/posteriors/base_posterior.py
+++ b/sbi/inference/posteriors/base_posterior.py
@@ -37,6 +37,7 @@ def __init__(
Allows to perform, e.g. MCMC in unconstrained space.
device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
`potential_fn.device` is used.
+ x_shape: Shape of the observed data.
"""
# Ensure device string.
@@ -132,6 +133,8 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior":
def _x_else_default_x(self, x: Optional[Array]) -> Tensor:
if x is not None:
+ # New x, reset posterior sampler.
+ self._posterior_sampler = None
return process_x(
x, x_shape=self._x_shape, allow_iid_x=self.potential_fn.allow_iid_x
)
diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py
index 1c9dd5889..7b2dbb531 100644
--- a/sbi/inference/posteriors/mcmc_posterior.py
+++ b/sbi/inference/posteriors/mcmc_posterior.py
@@ -1,12 +1,16 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see .
from functools import partial
-from typing import Any, Callable, Dict, Optional, Union
+from math import ceil
+from typing import Any, Callable, Dict, Optional, Tuple, Union
from warnings import warn
+import arviz as az
import torch
import torch.distributions.transforms as torch_tf
+from arviz.data import InferenceData
from joblib import Parallel, delayed
+from numpy import ndarray
from pyro.infer.mcmc import HMC, NUTS
from pyro.infer.mcmc.api import MCMC
from torch import Tensor
@@ -17,14 +21,15 @@
from sbi.samplers.mcmc import (
IterateParameters,
Slice,
+ SliceSamplerSerial,
+ SliceSamplerVectorized,
proposal_init,
resample_given_potential_fn,
sir_init,
- slice_np_parallized,
)
from sbi.simulators.simutils import tqdm_joblib
from sbi.types import Shape, TorchTransform
-from sbi.utils import pyro_potential_wrapper, transformed_potential
+from sbi.utils import pyro_potential_wrapper, tensor2numpy, transformed_potential
from sbi.utils.torchutils import ensure_theta_batched
@@ -102,6 +107,9 @@ def __init__(
self.init_strategy = init_strategy
self.init_strategy_parameters = init_strategy_parameters
self.num_workers = num_workers
+ self._posterior_sampler = None
+ # Hardcode parameter name to reduce clutter kwargs.
+ self.param_name = "theta"
if init_strategy_num_candidates is not None:
warn(
@@ -130,6 +138,11 @@ def mcmc_method(self, method: str) -> None:
"""See `set_mcmc_method`."""
self.set_mcmc_method(method)
+ @property
+ def posterior_sampler(self):
+ """Returns sampler created by `sample`."""
+ return self._posterior_sampler
+
def set_mcmc_method(self, method: str) -> "NeuralPosterior":
"""Sets sampling method to for MCMC and returns `NeuralPosterior`.
@@ -185,7 +198,7 @@ def sample(
sample_with: Optional[str] = None,
num_workers: Optional[int] = None,
show_progress_bars: bool = True,
- ) -> Tensor:
+ ) -> Union[Tensor, Tuple[Tensor, InferenceData]]:
r"""Return samples from posterior distribution $p(\theta|x)$ with MCMC.
Check the `__init__()` method for a description of all arguments as well as
@@ -291,11 +304,12 @@ def sample(
warmup_steps=warmup_steps, # type: ignore
num_chains=num_chains,
show_progress_bars=show_progress_bars,
- ).detach()
+ )
else:
raise NameError
samples = self.theta_transform.inv(transformed_samples)
+
return samples.reshape((*sample_shape, -1)) # type: ignore
def _build_mcmc_init_fn(
@@ -419,6 +433,7 @@ def _slice_np_mcmc(
warmup_steps: int,
vectorized: bool = False,
num_workers: int = 1,
+ init_width: Union[float, ndarray] = 0.01,
show_progress_bars: bool = True,
) -> Tensor:
"""Custom implementation of slice sampling using Numpy.
@@ -429,32 +444,47 @@ def _slice_np_mcmc(
initial_params: Initial parameters for MCMC chain.
thin: Thinning (subsampling) factor.
warmup_steps: Initial number of samples to discard.
- vectorized: Whether to use a vectorized implementation of
- the Slice sampler (still experimental).
- num_workers: number of CPU cores to use
- seed: seed that will be used to generate sub-seeds for each worker
+ vectorized: Whether to use a vectorized implementation of the Slice sampler.
+ num_workers: Number of CPU cores to use.
+ init_width: Initial width of brackets.
show_progress_bars: Whether to show a progressbar during sampling;
can only be turned off for vectorized sampler.
- Returns: Tensor of shape (num_samples, shape_of_single_theta).
+ Returns:
+ Tensor of shape (num_samples, shape_of_single_theta).
+ (The sampler is saved in `self._posterior_sampler` for later diagnostics.)
"""
num_chains, dim_samples = initial_params.shape
- samples = slice_np_parallized(
- potential_function,
- initial_params,
- num_samples,
+ if not vectorized:
+ SliceSamplerMultiChain = SliceSamplerSerial
+ else:
+ SliceSamplerMultiChain = SliceSamplerVectorized
+
+ posterior_sampler = SliceSamplerMultiChain(
+ init_params=tensor2numpy(initial_params),
+ log_prob_fn=potential_function,
+ num_chains=num_chains,
thin=thin,
- warmup_steps=warmup_steps,
- vectorized=vectorized,
+ verbose=show_progress_bars,
num_workers=num_workers,
- show_progress_bars=show_progress_bars,
+ init_width=init_width,
)
+ warmup_ = warmup_steps * thin
+ num_samples_ = ceil((num_samples * thin) / num_chains)
+ # Run mcmc including warmup
+ samples = posterior_sampler.run(warmup_ + num_samples_)
+ samples = samples[:, warmup_steps:, :] # discard warmup steps
+ samples = torch.from_numpy(samples) # chains x samples x dim
+
+ # Save posterior sampler.
+ self._posterior_sampler = posterior_sampler
# Save sample as potential next init (if init_strategy == 'latest_sample').
self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples)
+ # Collect samples from all chains.
samples = samples.reshape(-1, dim_samples)[:num_samples, :]
assert samples.shape[0] == num_samples
@@ -470,7 +500,7 @@ def _pyro_mcmc(
warmup_steps: int = 200,
num_chains: Optional[int] = 1,
show_progress_bars: bool = True,
- ):
+ ) -> Tensor:
r"""Return samples obtained using Pyro HMC, NUTS for slice kernels.
Args:
@@ -484,7 +514,9 @@ def _pyro_mcmc(
num_chains: Whether to sample in parallel. If None, use all but one CPU.
show_progress_bars: Whether to show a progressbar during sampling.
- Returns: Tensor of shape (num_samples, shape_of_single_theta).
+ Returns:
+ Tensor of shape (num_samples, shape_of_single_theta).
+ (The sampler is saved in `self._posterior_sampler` for later diagnostics.)
"""
num_chains = mp.cpu_count() - 1 if num_chains is None else num_chains
@@ -494,7 +526,7 @@ def _pyro_mcmc(
kernel=kernels[mcmc_method](potential_fn=potential_function),
num_samples=(thin * num_samples) // num_chains + num_chains,
warmup_steps=warmup_steps,
- initial_params={"": initial_params},
+ initial_params={self.param_name: initial_params},
num_chains=num_chains,
mp_context="spawn",
disable_progbar=not show_progress_bars,
@@ -505,10 +537,13 @@ def _pyro_mcmc(
-1, initial_params.shape[1] # .shape[1] = dim of theta
)
+ # Save posterior sampler.
+ self._posterior_sampler = sampler
+
samples = samples[::thin][:num_samples]
assert samples.shape[0] == num_samples
- return samples
+ return samples.detach()
def _prepare_potential(self, method: str) -> Callable:
"""Combines potential and transform and takes care of gradients and pyro.
@@ -612,6 +647,60 @@ def map(
force_update=force_update,
)
+ def get_arviz_inference_data(self) -> InferenceData:
+ """Returns arviz InferenceData object built from the most recent samples.
+
+ Note: the InferenceData is constructed using the posterior samples generated in
+ the most recent call to `.sample(...)`.
+
+ For Pyro HMC and NUTS kernels InferenceData will contain diagnostics, for Pyro
+ Slice or sbi slice sampling samples, only the samples are added.
+
+ Returns:
+ inference_data: Arviz InferenceData object.
+ """
+ assert (
+ self._posterior_sampler is not None
+ ), """No samples have been generated, call .sample() first."""
+
+ sampler: Union[
+ MCMC, SliceSamplerSerial, SliceSamplerVectorized
+ ] = self._posterior_sampler
+
+ # If Pyro sampler and samples not transformed, use arviz' from_pyro.
+ # Exclude 'slice' kernel as it lacks the 'divergence' diagnostics key.
+ if isinstance(self._posterior_sampler, (HMC, NUTS)) and isinstance(
+ self.theta_transform, torch_tf.IndependentTransform
+ ):
+ inference_data = az.from_pyro(sampler)
+
+ # otherwise get samples from sampler and transform to original space.
+ else:
+ transformed_samples = sampler.get_samples(group_by_chain=True)
+ # Pyro samplers return dicts; get the values.
+ if isinstance(transformed_samples, Dict):
+ # popitem returns the last (key, value) pair; [1] selects the value tensor.
+ transformed_samples = transformed_samples.popitem()[1]
+ # Our slice samplers return numpy arrays.
+ elif isinstance(transformed_samples, ndarray):
+ transformed_samples = torch.from_numpy(transformed_samples).type(
+ torch.float32
+ )
+ # For MultipleIndependent priors, transforms need the batch dim first;
+ # thus, flatten the chains before inverting and reshape back afterwards.
+ num_chains, samples_per_chain, dim_params = transformed_samples.shape
+ samples = self.theta_transform.inv( # type: ignore
+ transformed_samples.reshape(-1, dim_params)
+ ).reshape( # type: ignore
+ num_chains, samples_per_chain, dim_params
+ )
+
+ inference_data = az.convert_to_inference_data(
+ {f"{self.param_name}": samples}
+ )
+
+ return inference_data
+
def _maybe_use_dict_entry(default: Any, key: str, dict_to_check: Dict) -> Any:
"""Returns `default` if `key` is not in the dict and otherwise the dict entry.
diff --git a/sbi/samplers/mcmc/__init__.py b/sbi/samplers/mcmc/__init__.py
index 125cc5f86..7d3abe146 100644
--- a/sbi/samplers/mcmc/__init__.py
+++ b/sbi/samplers/mcmc/__init__.py
@@ -7,6 +7,6 @@
from sbi.samplers.mcmc.slice import Slice
from sbi.samplers.mcmc.slice_numpy import (
SliceSampler,
+ SliceSamplerSerial,
SliceSamplerVectorized,
- slice_np_parallized,
)
diff --git a/sbi/samplers/mcmc/slice_numpy.py b/sbi/samplers/mcmc/slice_numpy.py
index a0e90cf94..7924fa671 100644
--- a/sbi/samplers/mcmc/slice_numpy.py
+++ b/sbi/samplers/mcmc/slice_numpy.py
@@ -5,6 +5,7 @@
import sys
from math import ceil
from typing import Callable, Optional, Union
+from warnings import warn
import numpy as np
import torch
@@ -55,23 +56,34 @@ def gen(self, n_samples):
class SliceSampler(MCMCSampler):
def __init__(
- self, x, lp_f, max_width=float("inf"), thin=None, verbose: bool = False
+ self,
+ x,
+ lp_f,
+ max_width=float("inf"),
+ init_width: Union[float, np.ndarray] = 0.01,
+ thin=None,
+ tuning: int = 50,
+ verbose: bool = False,
):
"""Slice sampling for multivariate continuous probability distributions.
It cycles sampling from each conditional using univariate slice sampling.
Args:
- x: initial state
+ x: Initial state.
lp_f: Function that returns the log prob.
- max_width: maximum bracket width
- thin: amount of thinning; if None, no thinning.
+ max_width: maximum bracket width.
+ init_width: Initial width of brackets.
+ thin: Amount of thinning; if None, no thinning.
+ tuning: Number of tuning steps for brackets.
verbose: Whether to show progress bars (False).
"""
MCMCSampler.__init__(self, x, lp_f, thin, verbose=verbose)
self.max_width = max_width
+ self.init_width = init_width
self.width = None
+ self.tuning = tuning
def gen(
self,
@@ -141,15 +153,14 @@ def _tune_bracket_width(self, rng):
rng: Random number generator to use.
"""
- n_samples = 50
order = list(range(self.n_dims))
x = self.x.copy()
- self.width = np.full(self.n_dims, 0.01)
+ self.width = np.full(self.n_dims, self.init_width)
- tbar = trange(n_samples, miniters=10, disable=not self.verbose)
+ tbar = trange(self.tuning, miniters=10, disable=not self.verbose)
tbar.set_description("Tuning bracket width...")
for n in tbar:
- # for n in range(int(n_samples)):
+ # for n in range(int(self.tuning)):
rng.shuffle(order)
for i in range(self.n_dims):
@@ -203,31 +214,171 @@ def _sample_from_conditional(self, i: int, cxi, rng):
return xi, ux - lx
+class SliceSamplerSerial:
+ def __init__(
+ self,
+ log_prob_fn: Callable,
+ init_params: np.ndarray,
+ num_chains: int = 1,
+ thin: Optional[int] = None,
+ tuning: int = 50,
+ verbose: bool = True,
+ init_width: Union[float, np.ndarray] = 0.01,
+ max_width: float = float("inf"),
+ num_workers: int = 1,
+ ):
+ """Slice sampler in pure Numpy, running for each chain in serial.
+
+ Parallelization across CPUs is possible by setting num_workers > 1.
+
+ Args:
+ log_prob_fn: Log prob function.
+ init_params: Initial parameters.
+ num_chains: Number of MCMC chains to run in parallel
+ thin: amount of thinning; if None, no thinning.
+ tuning: Number of tuning steps for brackets.
+ verbose: Show/hide additional info such as progress bars.
+ init_width: Initial width of brackets.
+ max_width: Maximum width of brackets.
+ num_workers: Number of parallel workers to use.
+ """
+ self._log_prob_fn = log_prob_fn
+
+ self.x = init_params
+ self.num_chains = num_chains
+ self.thin = thin
+ self.tuning = tuning
+ self.verbose = verbose
+
+ self.init_width = init_width
+ self.max_width = max_width
+
+ self.n_dims = self.x.size
+ self.num_workers = num_workers
+ self._samples = None
+
+ def run(self, num_samples: int) -> np.ndarray:
+ """Runs MCMC and returns thinned samples.
+
+ Sampling is performed parallelized across CPUs if self.num_workers > 1.
+ Parallelization is seeded across workers.
+
+ Note: Thinning is performed internally.
+
+ Args:
+ num_samples: Number of samples to generate
+ Returns:
+ MCMC samples in shape (num_chains, num_samples_per_chain, num_dim)
+ """
+
+ num_chains, dim_samples = self.x.shape
+
+ # Generate seeds for workers from current random state.
+ seeds = torch.randint(high=2**31, size=(num_chains,))
+
+ with tqdm_joblib(
+ tqdm(
+ range(num_chains), # type: ignore
+ disable=not self.verbose or self.num_workers == 1,
+ desc=f"""Running {self.num_chains} MCMC chains with
+ {self.num_workers} worker{"s" if self.num_workers>1 else ""}.""",
+ total=self.num_chains,
+ )
+ ):
+ all_samples = Parallel(n_jobs=self.num_workers)(
+ delayed(self.run_fun)(num_samples, initial_params_batch, seed)
+ for initial_params_batch, seed in zip(self.x, seeds)
+ )
+
+ samples = np.stack(all_samples).astype(np.float32)
+ samples = samples.reshape(num_chains, -1, dim_samples) # chains, samples, dim
+ samples = samples[:, :: self.thin, :] # thin chains
+
+ # save samples
+ self._samples = samples
+
+ return samples
+
+ def run_fun(self, num_samples, inits, seed) -> np.ndarray:
+ """Runs MCMC for a given number of samples starting at inits."""
+ np.random.seed(seed)
+ posterior_sampler = SliceSampler(
+ inits,
+ lp_f=self._log_prob_fn,
+ max_width=self.max_width,
+ init_width=self.init_width,
+ thin=self.thin,
+ tuning=self.tuning,
+ # turn off pbars in parallel mode.
+ verbose=self.num_workers == 1 and self.verbose,
+ )
+ return posterior_sampler.gen(num_samples)
+
+ def get_samples(
+ self, num_samples: Optional[int] = None, group_by_chain: bool = True
+ ) -> np.ndarray:
+ """Returns samples from last call to self.run.
+
+ Raises ValueError if no samples have been generated yet.
+
+ Args:
+ num_samples: Number of samples to return (for each chain if grouped by
+ chain), if too large, all samples are returned (no error).
+ group_by_chain: Whether to return samples grouped by chain (chain x samples
+ x dim_params) or flattened (all_samples, dim_params).
+
+ Returns:
+ samples
+ """
+ if self._samples is None:
+ raise ValueError("No samples found from MCMC run.")
+ # if not grouped by chain, flatten samples into (all_samples, dim_params)
+ if not group_by_chain:
+ samples = self._samples.reshape(-1, self._samples.shape[2])
+ else:
+ samples = self._samples
+
+ # if not specified return all samples
+ if num_samples is None:
+ return samples
+ # otherwise return last num_samples (for each chain when grouped).
+ elif group_by_chain:
+ return samples[:, -num_samples:, :]
+ else:
+ return samples[-num_samples:, :]
+
+
class SliceSamplerVectorized:
def __init__(
self,
log_prob_fn: Callable,
init_params: np.ndarray,
num_chains: int = 1,
+ thin: Optional[int] = None,
tuning: int = 50,
verbose: bool = True,
init_width: Union[float, np.ndarray] = 0.01,
max_width: float = float("inf"),
+ num_workers: int = 1,
):
"""Slice sampler in pure Numpy, vectorized evaluations across chains.
Args:
log_prob_fn: Log prob function.
init_params: Initial parameters.
- verbose: Show/hide additional info such as progress bars.
+ num_chains: Number of MCMC chains to run in parallel
+ thin: amount of thinning; if None, no thinning.
tuning: Number of tuning steps for brackets.
+ verbose: Show/hide additional info such as progress bars.
init_width: Inital width of brackets.
max_width: Maximum width of brackets.
+ num_workers: Number of parallel workers to use (not implemented).
"""
self._log_prob_fn = log_prob_fn
self.x = init_params
self.num_chains = num_chains
+ self.thin = 1 if thin is None else thin
self.tuning = tuning
self.verbose = verbose
@@ -236,6 +387,14 @@ def __init__(
self.n_dims = self.x.size
+ self._samples = None
+
+ # TODO: implement parallelization across batches of chains.
+ if num_workers > 1:
+ warn(
+ """Parallelization of vectorized slice sampling not implement, running
+ serially."""
+ )
self._reset()
def _reset(self):
@@ -425,109 +584,41 @@ def run(self, num_samples: int) -> np.ndarray:
samples = np.stack([self.state[c]["samples"] for c in range(self.num_chains)])
+ samples = samples[:, :: self.thin, :] # thin chains
+
+ self._samples = samples
+
return samples
+ def get_samples(
+ self, num_samples: Optional[int] = None, group_by_chain: bool = True
+ ) -> np.ndarray:
+ """Returns samples from last call to self.run.
-def slice_np_parallized(
- potential_function: Callable,
- initial_params: torch.Tensor,
- num_samples: int,
- thin: int,
- warmup_steps: int,
- vectorized: bool,
- num_workers: int = 1,
- show_progress_bars: bool = False,
-):
- """Run slice np (vectorized) parallized over CPU cores.
-
- In case of the vectorized version of slice np parallization happens over batches of
- chains to still exploit vectorization.
-
- MCMC progress bars are omitted if num_workers > 1 to reduce clutter. Instead the
- progress over chains is shown.
-
- Args:
- potential_function: potential function
- initial_params: initital parameters, one for each chain
- num_samples: number of MCMC samples to produce
- thin: thinning factor
- warmup_steps: number of warmup / burnin steps
- vectorized: whether to use the vectorized version
- num_workers: number of CPU cores to use
- show_progress_bars: whether to show progress bars
-
- Returns:
- Tensor: final MCMC samples of each chain (num_chains, num_samples, dim_samples)
- """
- num_chains, dim_samples = initial_params.shape
-
- # Generate seeds for workers from current random state.
- seeds = torch.randint(high=2**31, size=(num_chains,))
-
- if not vectorized:
- # Define run function for given input.
- def run_slice_np(inits, seed):
- # Seed current job.
- np.random.seed(seed)
- posterior_sampler = SliceSampler(
- tensor2numpy(inits).reshape(-1),
- lp_f=potential_function,
- thin=thin,
- # Show pbars of workers only for single worker
- verbose=show_progress_bars and num_workers == 1,
- )
- if warmup_steps > 0:
- posterior_sampler.gen(int(warmup_steps))
- return posterior_sampler.gen(ceil(num_samples / num_chains))
-
- # For sequential chains each batch has only a single chain.
- batch_size = 1
- run_fun = run_slice_np
-
- else: # Sample all chains at the same time
-
- # Define local function to run a batch of chains vectorized.
- def run_slice_np_vectorized(inits, seed):
- # Seed current job.
- np.random.seed(seed)
- posterior_sampler = SliceSamplerVectorized(
- init_params=tensor2numpy(inits),
- log_prob_fn=potential_function,
- num_chains=inits.shape[0],
- # Show pbars of workers only for single worker
- verbose=show_progress_bars and num_workers == 1,
- )
- warmup_ = warmup_steps * thin
- num_samples_ = ceil((num_samples * thin) / num_chains)
- samples = posterior_sampler.run(warmup_ + num_samples_)
- samples = samples[:, warmup_:, :] # discard warmup steps
- samples = samples[:, ::thin, :] # thin chains
- samples = torch.from_numpy(samples) # chains x samples x dim
- return samples
+ Raises ValueError if no samples have been generated yet.
- # For vectorized case a batch contains multiple chains to exploit vectorization.
- batch_size = ceil(num_chains / num_workers)
- run_fun = run_slice_np_vectorized
-
- # Parallize over batch of chains.
- initial_params_in_batches = torch.split(initial_params, batch_size, dim=0)
- num_batches = len(initial_params_in_batches)
-
- # Show progress bars over batches.
- with tqdm_joblib(
- tqdm(
- range(num_batches), # type: ignore
- disable=not show_progress_bars or num_workers == 1,
- desc=f"""Running {num_chains} MCMC chains with {num_workers} worker{"s" if
- num_workers>1 else ""} (batch_size={batch_size}).""",
- total=num_chains,
- )
- ):
- all_samples = Parallel(n_jobs=num_workers)(
- delayed(run_fun)(initial_params_batch, seed)
- for initial_params_batch, seed in zip(initial_params_in_batches, seeds)
- )
- all_samples = np.stack(all_samples).astype(np.float32)
- samples = torch.from_numpy(all_samples)
+ Args:
+ num_samples: Number of samples to return (for each chain if grouped by
+ chain), if too large, all samples are returned (no error).
+ group_by_chain: Whether to return samples grouped by chain (chain x samples
+ x dim_params) or flattened (all_samples, dim_params).
- return samples.reshape(num_chains, -1, dim_samples) # chains x samples x dim
+ Returns:
+ samples
+ """
+ if self._samples is None:
+ raise ValueError("No samples found from MCMC run.")
+ # if not grouped by chain, flatten samples into (all_samples, dim_params)
+ if not group_by_chain:
+ samples = self._samples.reshape(-1, self._samples.shape[2])
+ else:
+ samples = self._samples
+
+ # if not specified return all samples
+ if num_samples is None:
+ return samples
+ # otherwise return last num_samples (for each chain when grouped).
+ elif group_by_chain:
+ return samples[:, -num_samples:, :]
+ else:
+ return samples[-num_samples:, :]
diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py
index bde9d7375..0dc2dd811 100644
--- a/sbi/utils/sbiutils.py
+++ b/sbi/utils/sbiutils.py
@@ -6,9 +6,12 @@
from math import pi
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+import arviz as az
import pyknos.nflows.transforms as transforms
import torch
import torch.distributions.transforms as torch_tf
+from arviz.data import InferenceData
+from numpy import ndarray
from pyro.distributions import Empirical
from torch import Tensor
from torch import nn as nn
diff --git a/setup.py b/setup.py
index f997f840e..6064d0cba 100644
--- a/setup.py
+++ b/setup.py
@@ -24,6 +24,7 @@
REQUIRES_PYTHON = ">=3.6.0"
REQUIRED = [
+ "arviz",
"joblib>=1.0.0",
"matplotlib",
"numpy",
diff --git a/tests/mcmc_test.py b/tests/mcmc_test.py
index c9035f2d4..f68717b13 100644
--- a/tests/mcmc_test.py
+++ b/tests/mcmc_test.py
@@ -3,19 +3,32 @@
from __future__ import annotations
-from typing import Union
+from math import ceil
+import arviz as az
import numpy as np
import pytest
import torch
from torch import eye, ones, zeros
-
+from torch.distributions import Uniform
+
+from sbi.inference import (
+ SNLE,
+ MCMCPosterior,
+ likelihood_estimator_based_potential,
+ prepare_for_sbi,
+ simulate_for_sbi,
+)
from sbi.samplers.mcmc.slice_numpy import (
SliceSampler,
+ SliceSamplerSerial,
SliceSamplerVectorized,
- slice_np_parallized,
)
-from sbi.simulators.linear_gaussian import true_posterior_linear_gaussian_mvn_prior
+from sbi.simulators.linear_gaussian import (
+ diagonal_linear_gaussian,
+ true_posterior_linear_gaussian_mvn_prior,
+)
+from sbi.utils import likelihood_nn, tensor2numpy
from tests.test_utils import check_c2st
@@ -43,17 +56,25 @@ def test_c2st_slice_np_on_Gaussian(num_dim: int):
def lp_f(x):
return target_distribution.log_prob(torch.as_tensor(x, dtype=torch.float32))
- sampler = SliceSampler(lp_f=lp_f, x=np.zeros((num_dim,)).astype(np.float32))
- _ = sampler.gen(warmup)
+ sampler = SliceSampler(
+ lp_f=lp_f,
+ x=np.zeros((num_dim,)).astype(np.float32),
+ tuning=warmup,
+ )
+ warmup_samples = sampler.gen(warmup)
+ assert warmup_samples.shape == (warmup, num_dim)
+
samples = sampler.gen(num_samples)
+ assert samples.shape == (num_samples, num_dim)
samples = torch.as_tensor(samples, dtype=torch.float32)
- check_c2st(samples, target_samples, alg=f"slice_np")
+ check_c2st(samples, target_samples, alg="slice_np")
@pytest.mark.parametrize("num_dim", (1, 2))
-def test_c2st_slice_np_vectorized_on_Gaussian(num_dim: int):
+@pytest.mark.parametrize("slice_sampler", (SliceSamplerVectorized, SliceSamplerSerial))
+def test_c2st_slice_np_vectorized_on_Gaussian(num_dim: int, slice_sampler):
"""Test MCMC on Gaussian, comparing to ground truth target via c2st.
Args:
@@ -63,6 +84,7 @@ def test_c2st_slice_np_vectorized_on_Gaussian(num_dim: int):
num_samples = 500
warmup = 500
num_chains = 5
+ thin = 2
likelihood_shift = -1.0 * ones(num_dim)
likelihood_cov = 0.3 * eye(num_dim)
@@ -77,7 +99,7 @@ def test_c2st_slice_np_vectorized_on_Gaussian(num_dim: int):
def lp_f(x):
return target_distribution.log_prob(torch.as_tensor(x, dtype=torch.float32))
- sampler = SliceSamplerVectorized(
+ sampler = slice_sampler(
log_prob_fn=lp_f,
init_params=np.zeros(
(
@@ -85,23 +107,32 @@ def lp_f(x):
num_dim,
)
).astype(np.float32),
+ tuning=warmup,
+ thin=thin,
num_chains=num_chains,
)
- samples = sampler.run(warmup + int(num_samples / num_chains))
+ samples = sampler.run(thin * (warmup + int(num_samples / num_chains)))
+ assert samples.shape == (
+ num_chains,
+ warmup + int(num_samples / num_chains),
+ num_dim,
+ )
samples = samples[:, warmup:, :]
samples = samples.reshape(-1, num_dim)
samples = torch.as_tensor(samples, dtype=torch.float32)
- check_c2st(samples, target_samples, alg="slice_np_vectorized")
+ alg = {
+ SliceSamplerVectorized: "slice_np_vectorized",
+ SliceSamplerSerial: "slice_np",
+ }[slice_sampler]
+
+ check_c2st(samples, target_samples, alg=alg)
@pytest.mark.parametrize("vectorized", (False, True))
@pytest.mark.parametrize("num_workers", (1, 10))
-@pytest.mark.parametrize("seed", (None, 42))
-def test_c2st_slice_np_parallelized(
- vectorized: bool, num_workers: int, seed: Union[None, int]
-):
+def test_c2st_slice_np_parallelized(vectorized: bool, num_workers: int):
"""Test MCMC on Gaussian, comparing to ground truth target via c2st.
Args:
@@ -110,8 +141,9 @@ def test_c2st_slice_np_parallelized(
"""
num_dim = 2
num_samples = 500
- warmup = 500
- num_chains = 5
+ warmup = 100
+ num_chains = 10
+ thin = 2
likelihood_shift = -1.0 * ones(num_dim)
likelihood_cov = 0.3 * eye(num_dim)
@@ -128,37 +160,73 @@ def lp_f(x):
initial_params = torch.zeros((num_chains, num_dim))
- # Maybe test seeding.
- if seed is not None:
- torch.manual_seed(seed)
-
- samples = slice_np_parallized(
- lp_f,
- initial_params,
- num_samples,
- thin=1,
- warmup_steps=warmup,
- vectorized=vectorized,
+ if not vectorized:
+ SliceSamplerMultiChain = SliceSamplerSerial
+ else:
+ SliceSamplerMultiChain = SliceSamplerVectorized
+
+ posterior_sampler = SliceSamplerMultiChain(
+ init_params=tensor2numpy(initial_params),
+ log_prob_fn=lp_f,
+ num_chains=num_chains,
+ thin=thin,
+ verbose=False,
num_workers=num_workers,
- show_progress_bars=False,
)
- # Repeat to test seeding.
- if seed is not None:
- torch.manual_seed(seed)
- samples_2 = slice_np_parallized(
- lp_f,
- initial_params,
- num_samples,
- thin=1,
- warmup_steps=warmup,
- vectorized=vectorized,
- num_workers=num_workers,
- )
- # Test seeding.
- assert torch.allclose(samples, samples_2)
-
+ warmup_ = warmup * thin
+ num_samples_ = ceil((num_samples * thin) / num_chains)
+ samples = posterior_sampler.run(warmup_ + num_samples_) # chains x samples x dim
+ samples = samples[:, warmup:, :] # discard warmup steps
samples = torch.as_tensor(samples, dtype=torch.float32).reshape(-1, num_dim)
check_c2st(
samples, target_samples, alg=f"slice_np {'vectorized' if vectorized else ''}"
)
+
+
+@pytest.mark.parametrize(
+ "method",
+ (
+ "nuts",
+ "hmc",
+ "slice",
+ "slice_np",
+ "slice_np_vectorized",
+ ),
+)
+def test_getting_inference_diagnostics(method):
+
+ num_samples = 100
+ num_dim = 2
+ num_chains = 2
+
+ # Use composed prior to test MultipleIndependent case.
+ prior = [
+ Uniform(low=-ones(1), high=ones(1)),
+ Uniform(low=-ones(1), high=ones(1)),
+ ]
+
+ simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior)
+ density_estimator = likelihood_nn("maf", num_transforms=3)
+ inference = SNLE(density_estimator=density_estimator, show_progress_bars=False)
+
+ theta, x = simulate_for_sbi(simulator, prior, 1000, simulation_batch_size=50)
+ likelihood_estimator = inference.append_simulations(theta, x).train(
+ training_batch_size=100
+ )
+
+ x_o = zeros((1, num_dim))
+ potential_fn, theta_transform = likelihood_estimator_based_potential(
+ prior=prior, likelihood_estimator=likelihood_estimator, x_o=x_o
+ )
+ posterior = MCMCPosterior(
+ proposal=prior,
+ potential_fn=potential_fn,
+ theta_transform=theta_transform,
+ thin=3,
+ num_chains=num_chains,
+ )
+ posterior.sample(sample_shape=(num_samples,), method=method)
+ idata = posterior.get_arviz_inference_data()
+
+ az.plot_trace(idata)
diff --git a/tests/posterior_sampler_test.py b/tests/posterior_sampler_test.py
new file mode 100644
index 000000000..b397ee23f
--- /dev/null
+++ b/tests/posterior_sampler_test.py
@@ -0,0 +1,80 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Affero General Public License v3, see .
+
+from __future__ import annotations
+
+import pytest
+from pyro.infer.mcmc import MCMC
+from torch import eye, zeros
+from torch.distributions import MultivariateNormal
+
+from sbi import utils as utils
+from sbi.inference import (
+ SNL,
+ MCMCPosterior,
+ likelihood_estimator_based_potential,
+ prepare_for_sbi,
+ simulate_for_sbi,
+)
+from sbi.samplers.mcmc import SliceSamplerSerial, SliceSamplerVectorized
+from sbi.simulators.linear_gaussian import diagonal_linear_gaussian
+
+
+@pytest.mark.parametrize(
+ "sampling_method",
+ (
+ "slice_np",
+ "slice_np_vectorized",
+ "slice",
+ "nuts",
+ "hmc",
+ ),
+)
+def test_api_posterior_sampler_set(sampling_method: str, set_seed):
+ """Runs SNL and checks that posterior_sampler is correctly set.
+
+ Args:
+ sampling_method: which mcmc method to use for sampling
+ set_seed: fixture for manual seeding
+ """
+
+ num_dim = 2
+ num_samples = 10
+ num_trials = 2
+ num_simulations = 10
+ x_o = zeros((num_trials, num_dim))
+ # Test for multiple chains is cheap when vectorized.
+ num_chains = 3 if sampling_method in "slice_np_vectorized" else 1
+
+ prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
+ simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior)
+ inference = SNL(prior, show_progress_bars=False)
+
+ theta, x = simulate_for_sbi(
+ simulator, prior, num_simulations, simulation_batch_size=10
+ )
+ estimator = inference.append_simulations(theta, x).train(max_num_epochs=5)
+ potential_fn, transform = likelihood_estimator_based_potential(
+ estimator, prior, x_o
+ )
+ posterior = MCMCPosterior(
+ potential_fn, theta_transform=transform, method=sampling_method, proposal=prior
+ )
+
+ assert posterior.posterior_sampler is None
+ posterior.sample(
+ sample_shape=(num_samples, num_chains),
+ x=x_o,
+ mcmc_parameters={
+ "thin": 3,
+ "num_chains": num_chains,
+ "init_strategy": "prior",
+ },
+ )
+
+ if sampling_method in ["slice", "hmc", "nuts"]:
+ assert type(posterior.posterior_sampler) is MCMC
+ elif sampling_method == "slice_np":
+ assert type(posterior.posterior_sampler) is SliceSamplerSerial
+ else: # sampling_method == "slice_np_vectorized"
+ assert type(posterior.posterior_sampler) is SliceSamplerVectorized
diff --git a/tutorials/14_multi-trial-data-and-mixed-data-types.ipynb b/tutorials/14_multi-trial-data-and-mixed-data-types.ipynb
index 2af57ffac..efba7bcc2 100644
--- a/tutorials/14_multi-trial-data-and-mixed-data-types.ipynb
+++ b/tutorials/14_multi-trial-data-and-mixed-data-types.ipynb
@@ -43,7 +43,6 @@
"outputs": [],
"source": [
"import torch\n",
- "import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from torch import zeros, ones, eye\n",
@@ -793,7 +792,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": "Python 3.8.13 ('sbi')",
"language": "python",
"name": "python3"
},
@@ -807,7 +806,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.7.11"
+ "version": "3.8.13"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "9ef9b53a5ce850816b9705a866e49207a37a04a71269aa157d9f9ab944ea42bf"
+ }
}
},
"nbformat": 4,
diff --git a/tutorials/15_mcmc_diagnostics_with_arviz.ipynb b/tutorials/15_mcmc_diagnostics_with_arviz.ipynb
new file mode 100644
index 000000000..7d85c0b63
--- /dev/null
+++ b/tutorials/15_mcmc_diagnostics_with_arviz.ipynb
@@ -0,0 +1,1218 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# MCMC diagnostics with [Arviz](https://python.arviz.org/)\n",
+ "\n",
+ "This tutorial shows how to evaluate the quality of MCMC samples generated via `sbi` using the `arviz` package. \n",
+ "\n",
+ "We demonstrate this case using the trial-based simulator presented in Tutorial 14: A toy simulator mimicking the drift-diffusion model of decision-making. \n",
+ "\n",
+ "Outline:\n",
+ "\n",
+ "1) Train MNLE to approximate the likelihood underlying the simulator\n",
+ "2) Run MCMC using `pyro` MCMC samplers via `sbi` interface\n",
+ "3) Use `arviz` to visualize the posterior, predictive distributions and MCMC diagnostics. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import arviz as az\n",
+ "import torch\n",
+ "\n",
+ "from sbi.inference import MNLE, likelihood_estimator_based_potential\n",
+ "from pyro.distributions import InverseGamma\n",
+ "from torch.distributions import Beta, Binomial, Gamma\n",
+ "from sbi.utils import MultipleIndependent\n",
+ "\n",
+ "from sbi.inference import MCMCPosterior\n",
+ "\n",
+ "# Seeding\n",
+ "torch.manual_seed(1);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Toy simulator for mixed data\n",
+ "def mixed_simulator(theta):\n",
+ " beta, ps = theta[:, :1], theta[:, 1:]\n",
+ "\n",
+ " choices = Binomial(probs=ps).sample()\n",
+ " rts = InverseGamma(concentration=2 * torch.ones_like(beta), rate=beta).sample()\n",
+ "\n",
+ " return torch.cat((rts, choices), dim=1)\n",
+ "\n",
+ "\n",
+ "# Define independent priors for each dimension.\n",
+ "prior = MultipleIndependent(\n",
+ " [\n",
+ " Gamma(torch.tensor([1.0]), torch.tensor([0.5])),\n",
+ " Beta(torch.tensor([2.0]), torch.tensor([2.0])),\n",
+ " ],\n",
+ " validate_args=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Train MNLE to approximate the likelihood\n",
+ "\n",
+ "For details see [tutorial 14](https://www.mackelab.org/sbi/tutorial/14_multi-trial-data-and-mixed-data-types/). \n",
+ "\n",
+ "Here, we pass `method=\"nuts\"` in order to use the underlying [`pyro` No-U-turn sampler](https://docs.pyro.ai/en/1.8.1/mcmc.html#nuts), but it would work as well with other samplers (e.g. \"slice_np_vectorized\", \"hmc\"). \n",
+ "\n",
+ "Additionally, after calling `posterior.sample(...)` we call `posterior.get_arviz_inference_data()` to obtain the [`Arviz InferenceData`](https://arviz-devs.github.io/arviz/api/generated/arviz.InferenceData.html#arviz.InferenceData) object. This object gives us access to the wealth of MCMC diagnostic tools provided by `arviz`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/janbolts/qode/sbi/sbi/neural_nets/mnle.py:60: UserWarning: The mixed neural likelihood estimator assumes that x contains\n",
+ " continuous data in the first n-1 columns (e.g., reaction times) and\n",
+ " categorical data in the last column (e.g., corresponding choices). If\n",
+ " this is not the case for the passed `x` do not use this function.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " Neural network successfully converged after 65 epochs."
+ ]
+ }
+ ],
+ "source": [
+ "# Generate training data and train MNLE.\n",
+ "num_simulations = 10000\n",
+ "theta = prior.sample((num_simulations,))\n",
+ "x = mixed_simulator(theta)\n",
+ "\n",
+ "trainer = MNLE(prior)\n",
+ "likelihood_estimator = trainer.append_simulations(theta, x).train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Run Pyro NUTS MCMC and obtain `arviz InferenceData` object"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/janbolts/qode/sbi/sbi/utils/sbiutils.py:280: UserWarning: An x with a batch size of 100 was passed. It will be interpreted as a batch of independent and identically\n",
+ " distributed data X={x_1, ..., x_n}, i.e., data generated based on the\n",
+ " same underlying (unknown) parameter. The resulting posterior will be with\n",
+ " respect to entire batch, i.e,. p(theta | X).\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Simulate \"observed\" data x_o\n",
+ "torch.manual_seed(42)\n",
+ "num_trials = 100\n",
+ "theta_o = prior.sample((1,))\n",
+ "x_o = mixed_simulator(theta_o.repeat(num_trials, 1))\n",
+ "\n",
+ "# Set MCMC parameters and run Pyro NUTS.\n",
+ "mcmc_parameters = dict(\n",
+ " num_chains=4,\n",
+ " thin=5,\n",
+ " warmup_steps=50,\n",
+ " init_strategy=\"proposal\",\n",
+ " method=\"nuts\",\n",
+ ")\n",
+ "num_samples = 1000\n",
+ "\n",
+ "# get the potential function and parameter transform for constructing the posterior\n",
+ "potential_fn, parameter_transform = likelihood_estimator_based_potential(\n",
+ " likelihood_estimator, prior, x_o\n",
+ ")\n",
+ "mnle_posterior = MCMCPosterior(\n",
+ " potential_fn, proposal=prior, theta_transform=parameter_transform, **mcmc_parameters\n",
+ ")\n",
+ "\n",
+ "mnle_samples = mnle_posterior.sample(\n",
+ " (num_samples,), x=x_o, show_progress_bars=False\n",
+ ")\n",
+ "# get arviz InferenceData object from posterior\n",
+ "inference_data = mnle_posterior.get_arviz_inference_data()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Generate `arviz` plots\n",
+ "\n",
+ "The resulting `InferenceData` object can be passed to most `arviz` plotting functions, and there are plenty; see [here](https://arviz-devs.github.io/arviz/examples/index.html#) for an overview.\n",
+ "\n",
+ "To get a better understanding of the `InferenceData` object see [here](https://arviz-devs.github.io/arviz/schema/schema.html). \n",
+ "\n",
+ "Below is an overview of common MCMC diagnostic plots; see the corresponding `arviz` documentation for interpretation of the plots. \n",
+ "\n",
+ "We will add a full use-case using the SBI-MCMC-arviz workflow soon."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n",
+ " \n",
+ "
\n",
+ " \n",
+ " - \n",
+ " \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
\n",
+ " \n",
+ "
<xarray.Dataset>\n",
+ "Dimensions: (chain: 4, draw: 1254, theta_dim_0: 2)\n",
+ "Coordinates:\n",
+ " * chain (chain) int64 0 1 2 3\n",
+ " * draw (draw) int64 0 1 2 3 4 5 6 ... 1248 1249 1250 1251 1252 1253\n",
+ " * theta_dim_0 (theta_dim_0) int64 0 1\n",
+ "Data variables:\n",
+ " theta (chain, draw, theta_dim_0) float32 2.125 0.8092 ... 0.8088\n",
+ "Attributes:\n",
+ " created_at: 2022-08-10T14:02:41.300799\n",
+ " arviz_version: 0.11.2
- chain: 4
- draw: 1254
- theta_dim_0: 2
theta
(chain, draw, theta_dim_0)
float32
2.125 0.8092 2.267 ... 1.848 0.8088
array([[[2.1245914 , 0.8092162 ],\n",
+ " [2.266512 , 0.8164251 ],\n",
+ " [2.036817 , 0.79519475],\n",
+ " ...,\n",
+ " [1.8315581 , 0.7797423 ],\n",
+ " [2.050106 , 0.7541253 ],\n",
+ " [1.9130744 , 0.7940467 ]],\n",
+ "\n",
+ " [[1.8672262 , 0.79227704],\n",
+ " [1.9084876 , 0.87156725],\n",
+ " [1.9282253 , 0.89998335],\n",
+ " ...,\n",
+ " [1.966494 , 0.7684441 ],\n",
+ " [1.9171734 , 0.76520354],\n",
+ " [1.9165115 , 0.8100004 ]],\n",
+ "\n",
+ " [[2.1789386 , 0.92230934],\n",
+ " [2.2388353 , 0.8388026 ],\n",
+ " [2.2388353 , 0.8388026 ],\n",
+ " ...,\n",
+ " [2.2749808 , 0.8510151 ],\n",
+ " [2.207828 , 0.843363 ],\n",
+ " [1.9331468 , 0.7714892 ]],\n",
+ "\n",
+ " [[2.0355892 , 0.902892 ],\n",
+ " [2.1481078 , 0.76634836],\n",
+ " [2.0999866 , 0.8813891 ],\n",
+ " ...,\n",
+ " [1.8813787 , 0.85843307],\n",
+ " [1.9198259 , 0.727575 ],\n",
+ " [1.8484387 , 0.80877745]]], dtype=float32)
- created_at :
- 2022-08-10T14:02:41.300799
- arviz_version :
- 0.11.2
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ "Inference data with groups:\n",
+ "\t> posterior"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "inference_data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Diagnostic plots"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([,\n",
+ " ],\n",
+ " dtype=object)"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "