This repository contains the JAX implementation that accompanies the paper Probabilistic programming with programmable variational inference, as well as the experiments used to generate figures and numbers in the empirical evaluation section.
Table of contents
- Overview
- Reproducing the paper's claims
- Interpreting the results
- Notes on artifact evaluation
- Building on our work
Table of contents generated with markdown-toc
The artifact accompanying our paper is a Python library for probabilistic programming with variational inference.
Using our library, users can write probabilistic programs encoding probabilistic models and variational families. They can then define variational objectives to optimize, such as the ELBO. Our system can automate the estimation of gradients of these objectives, enabling gradient-based training.
For example, here is the model from Figure 2 of our paper, encoding a probability distribution on triples (x, y, z) located near a cone:
```python
import genjax
import jax.numpy as jnp  # used by the variational family and training loop below
from genjax import gensp, vi


@genjax.gen
def model():
    x = vi.normal_reparam(0.0, 10.0) @ "x"
    y = vi.normal_reparam(0.0, 10.0) @ "y"
    rs = x**2 + y**2
    z = vi.normal_reparam(rs, 0.1 + (rs / 100.0)) @ "z"
```
Our goal will be to infer probable values of the latent variables x and y, given an observed value for z:

```python
data = genjax.choice_map({"z": 5.0})
```
To do so, we define a variational family -- a parametric family of distributions over the latent variables (x, y).
```python
@genjax.gen
def variational_family(_, ϕ):
    μ1, μ2, log_σ1, log_σ2 = ϕ
    x = vi.normal_reparam(μ1, jnp.exp(log_σ1)) @ "x"
    y = vi.normal_reparam(μ2, jnp.exp(log_σ2)) @ "y"
```
We want to find parameters that maximize the ELBO, an objective that encourages the variational family to be close to the Bayesian posterior over the latent variables:
```python
objective = vi.elbo(model, variational_family, data)
```
For convenience, the ELBO is defined as a library function, but users can also define their own objectives.
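For reference (standard background, not code from the artifact), the ELBO being maximized here can be written in terms of the model's joint density $p$ and the variational family's density $q_\phi$ over the latents:

$$
\mathrm{ELBO}(\phi) \;=\; \mathbb{E}_{(x,\,y) \sim q_\phi}\big[\log p(x, y, z{=}5.0) - \log q_\phi(x, y)\big] \;\le\; \log p(z{=}5.0)
$$

Maximizing this bound over $\phi$ pushes $q_\phi$ toward the posterior $p(x, y \mid z{=}5.0)$.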
Our library can automatically estimate gradients of the ELBO and other probabilistic loss functions, and these gradients can be used to optimize the variational family's parameters.
```python
import jax
import jax.tree_util as jtu

# Training.
key = jax.random.PRNGKey(314159)
ϕ = (0.0, 0.0, 1.0, 1.0)
L = jax.jit(jax.vmap(objective.value_and_grad_estimate, in_axes=(0, None)))
for i in range(5000):
    key, sub_key = jax.random.split(key)
    sub_keys = jax.random.split(sub_key, 64)
    (loss, (_, (_, ϕ_grads))) = L(sub_keys, ((), (data, ϕ)))
    # Gradient ascent step using the batch-averaged gradient estimates.
    ϕ = jtu.tree_map(lambda v, g: v + 1e-3 * jnp.mean(g), ϕ, ϕ_grads)
```
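After training, ϕ holds the learned variational parameters. As a minimal illustrative check (not part of the artifact scripts), you can unpack them the same way `variational_family` does and recover the learned means and standard deviations:

```python
# Illustrative only: parameter names follow the variational_family definition above.
μ1, μ2, log_σ1, log_σ2 = ϕ
print("learned means:", μ1, μ2)
print("learned stddevs:", jnp.exp(log_σ1), jnp.exp(log_σ2))
```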
The full code to run this example and plot results can be found in the `experiments/fig_2_noisy_cone` directory.
To ensure reproducibility, we have packaged a frozen version of our `genjax` library in this repository. Documentation is publicly available for the relevant modules, but please note that some APIs have changed since we submitted our paper, and that documentation may be out-of-sync with the code in this repository. For guidance on using the exact library version included in this artifact, please instead see our included tutorial notebook on using and extending `genjax.vi`.
We have provided code to reproduce all of the experiments in the paper, namely the results in Figure 2, Figure 7, Table 1, Table 2, and Table 4. At a high level, these experiments validate our claims that (1) our implementation introduces minimal overhead compared to hand-coded gradient estimators, (2) our implementation of variational inference is faster than Pyro's, and on par with NumPyro's, for algorithms that all systems can express, and (3) our system supports gradient estimators that Pyro and NumPyro do not automate, some of which empirically converge faster than the algorithms they do support.
We've organized the experiments code under the `experiments` directory, which contains the following subdirectories, mapping onto the figures and tables in the submitted version of the paper:
- `fig_2_noisy_cone`
- `fig_7_air_estimator_evaluation`
- `table_1_minibatch_gradient_benchmark`
- `table_2_benchmark_timings`
- `table_4_objective_values`
Each directory contains code used to run the experiment, and also to reproduce the graphs or table results that appear in the submission.
Computational cost of the experiments: Some of the experiments can be run on CPU in a reasonable amount of time, whereas others require a GPU. (See further discussion below.) For reference:

- (CPU likely okay) For `fig_2_noisy_cone` and `table_4_objective_values`, a CPU should be sufficient to run the experiments. These experiments illustrate usage of our system to automate gradient estimators for hierarchical variational inference (HVI) and nested importance-weighted HVI.
- (GPU likely required) For `fig_7_air_estimator_evaluation`, `table_1_minibatch_gradient_benchmark`, and `table_2_benchmark_timings`, we recommend running on a GPU. These experiments illustrate performance comparisons between our system, handcoded gradient estimators, and Pyro for several variational objectives and estimators.
There are several possible ways to make an environment which can run the experiments.

If you're planning to use the GPU, or are on a Linux device (this process will not work natively on macOS or Windows, due to `jaxlib`), possibly the easiest way is to use `docker`. There are two ways to use `docker` to build an environment:
- (Docker Hub) We've pushed an image to Docker Hub which you can use directly.
- (Dockerfile) We've provided a `Dockerfile` with the repository. To build an image from it, run `docker build .`; this will build an image which you can use.
Once you have a built image (either by building it yourself, or by using the pre-built one from Docker Hub), you can run a container from the image:

- With GPU support (requires an underlying Nvidia driver with support for CUDA 12): If you have an Nvidia device and an installed driver that supports CUDA 12 (checked via `nvidia-smi`), you can run the image with GPU support using `docker run --runtime nvidia -it <YOUR_IMAGE_ID>`. NOTE: you may need to install the Nvidia container toolkit (see its GitHub repository) to support usage of the `--runtime nvidia` flag with Docker.
- Without GPU support: If you don't have access to a GPU, you can still run the image (without GPU support) using `docker run -it <YOUR_IMAGE_ID>`.

In both cases, `<YOUR_IMAGE_ID>` is the hash of the image you built or the image you pulled from Docker Hub.
With this method, you can skip the setup for `poetry` and `just` described below, and jump to GPU acceleration (if you have access to a GPU) and then Running the experiments.
If you just want to run some of the experiments on CPU, or are on a Mac, there's a recommended alternative path to set up a working Python environment: we use `poetry` to manage Python dependencies, and `just` as a command runner.

We recommend installing both of these tools using the documentation at the links provided (at a bare minimum, you'll need to install `poetry`, but we also recommend installing `just` to make use of our convenience commands for running experiments).

With `poetry` installed, run `poetry install` to create a virtual environment and install our dependencies into it. Then, run `poetry shell` to activate the virtual environment.
Using `conda` with `poetry`

Note that, if you're having issues with `poetry` environment creation via `poetry shell` or `poetry install`, you can also manage a `conda` environment directly (e.g. using `miniconda`):

```
conda create --name programmable-vi-3.10.12 python=3.10.12
conda activate programmable-vi-3.10.12
poetry install
```
Several of our experiments are computationally intensive, and we recommend GPU acceleration.

For GPU acceleration, we assume access to a CUDA 12 enabled environment. There is a convenience command, `just gpu`, to install `torch` and `jaxlib` with support for CUDA 12. This will fetch versions of `torch` and `jaxlib` which are compatible with each other (because we're benchmarking both `torch`- and `jax`-enabled code).

We've checked the selected versions of `torch` and `jaxlib` for compatibility on a device with Nvidia driver support up to CUDA 12.4 (as reported by `nvidia-smi`). If you want GPU acceleration, we recommend setting up your system so that you can run this command successfully. If you have a CUDA 12 enabled system, you should be okay, but please report any issues to us on this repository.
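To confirm that the GPU is actually visible after setup, a quick sanity check (illustrative, not part of the artifact scripts) is to ask JAX and PyTorch which devices they see from within your environment:

```python
# Both frameworks should report a CUDA device if GPU acceleration is working.
import jax
import torch

print(jax.devices())              # should list a GPU/CUDA device on a GPU build of jaxlib
print(torch.cuda.is_available())  # True if torch can see a CUDA device
```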
Important

Several of the experiments are computationally intensive, and may take a long time to run. We recommend running them on a machine with a GPU, using `jax` and `torch` backends that support GPU computation.

In particular, `fig_7_air_estimator_evaluation` (`just fig_7`), `table_2_benchmark_timings` (`just table_2`), and `table_1_minibatch_gradient_benchmark` (`just table_1`) will take quite a long time on a CPU.
Using `experiments.ipynb` to run experiments
There's a Jupyter notebook which provides access to all the experiments via a single interface. Jupyter also allows you to view the PDFs generated by the experiments, either in the notebook directly, or via the file browser interface.
Using `just` to run experiments
To run all of the experiments, it suffices to run (this will take a long time, and is not recommended on CPU):

`just run_all`

The experiments will be run in order, and the figure results will be saved to the `./figs` directory at the top level as PDF files.
You can also run each experiment individually; the set of recipes available via `just` is:
```
❯ just -l
Available recipes:
    fig_2   # These are components for the overview figure for the paper.
    fig_7   # These are components for the AIR figure for the paper.
    gpu     # get GPU jax
    table_1 # Not a plot, just timings printed out.
    table_2 # Not a plot, just timings printed out using `pytest-benchmark`.
    table_4 # Not a plot, just timings printed out.
```
meaning that one can run any of these experiments using `just`, for example: `just table_1`.
Moving figures to your local device

Note that, if you used Docker to run the experiments and produce figures into `./figs`, you can copy them to your local machine using `docker cp`:

`docker cp <container id>:/home/wizard/figs /tmp`
There are also several abbreviations which are not collected in a single place in the artifact:
- AIR -- Attend, Infer, Repeat, a generative model of multi-object images
- VAE -- variational autoencoder
- ELBO -- evidence lower bound
- MVD -- the measure-valued derivative estimator
- REINFORCE -- the score-function derivative estimator
- (IWELBO, or IWAE) -- importance weighted ELBO
- HVI -- hierarchical variational inference
- IWHVI -- importance weighted HVI
- DIWHVI -- doubly importance weighted HVI
Several of our experiments (the experiments which produce results for the tables) print out results to stdout. Below, we give a guide to interpreting the results:
Sign convention: Systems which implement variational optimization vary in their sign convention for objective values. In our system, we consider maximization (maximizing the ELBO); Pyro and NumPyro consider minimization (minimizing the negative ELBO). These two formulations are equivalent. Note that our print outs don't standardize the sign, but our reported results do: we convert Pyro and NumPyro negative objective values to objective values by multiplying by -1.
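As an illustrative one-liner (not part of the artifact code), the conversion we apply when tabulating Pyro and NumPyro results is just a sign flip:

```python
# Pyro/NumPyro report the negative ELBO as their loss; flipping the sign
# gives an objective value directly comparable to our (maximized) ELBO.
pyro_loss = 8.0826             # a reported negative-ELBO value (see the Table 4 print out below)
objective_value = -pyro_loss   # -8.0826, comparable to our ELBO estimates
```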
For Table 1, you'll see a print out which looks like the following (specific numbers may depend on your runtime environment):
```
GenJAX VI timings:
Batch sizes: [64, 128, 256, 512, 1024]
(array([0.17483307, 0.23848654, 0.40820748, 0.74701416, 1.5442369 ], dtype=float32),
 array([0.08583035, 0.06719527, 0.13345473, 0.24723288, 0.5431073 ], dtype=float32))

Handcoded timings:
Batch sizes: [64, 128, 256, 512, 1024]
(array([0.13013603, 0.21644622, 0.43758354, 0.75196713, 1.5987686 ], dtype=float32),
 array([0.03493173, 0.07299326, 0.18554626, 0.26609895, 0.55678135], dtype=float32))
```
which we used to produce Table 1. The "columns" of the print out correspond to the batch size column of the table: for each timing experiment, the first array in the printed tuple is the row of mean timings, and the second array is the row of standard deviations. The numbers from `GenJAX VI timings` are mapped to our system's row in Table 1, and the numbers from `Handcoded timings` are mapped to the handcoded row.
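As an illustrative (hypothetical) snippet, here's how the two printed arrays pair up with the batch sizes to give `mean ± std` entries like those in Table 1:

```python
import numpy as np

# Values copied from the GenJAX VI timings print out above.
batch_sizes = [64, 128, 256, 512, 1024]
means = np.array([0.17483307, 0.23848654, 0.40820748, 0.74701416, 1.5442369])
stds = np.array([0.08583035, 0.06719527, 0.13345473, 0.24723288, 0.5431073])

for b, m, s in zip(batch_sizes, means, stds):
    print(f"batch size {b}: {m:.3f} ± {s:.3f}")
```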
To generate Table 2 in the paper, we use `pytest-benchmark` to generate timing statistics. The print out will likely look something like this (the specific numbers may depend on your runtime environment, but the general trend, that our timings are orders of magnitude faster than Pyro's, doesn't depend on runtime):
We took the `Mean` and `StdDev` column numbers to generate the results (of the form mean ± standard deviation) reported in Table 2. The labels and the numbers for the columns in Table 2 are mapped from the named rows in the print out, according to the `Mean` and `StdDev` column numbers (as mentioned above):
- `test_benchmark_genjax_reinforce` and `test_benchmark_pyro[TraceGraph_ELBO]` $\rightarrow$ $\text{REINFORCE}$
- `test_benchmark_genjax_mvd` $\rightarrow$ $\text{MVD}$
- `test_benchmark_genjax_enum` and `test_benchmark_pyro[TraceEnum_ELBO]` $\rightarrow$ $\text{ENUM}$
- `test_benchmark_genjax_iwae_reinforce` and `test_benchmark_pyro_reinforce[RenyiELBO]` $\rightarrow$ $\text{IWAE} + \text{REINFORCE}$
- `test_benchmark_genjax_iwae_mvd` $\rightarrow$ $\text{IWAE} + \text{MVD}$
Each of the names on the right hand side of the arrows above corresponds to a particular gradient estimation strategy used in variational inference.
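For orientation (standard background, not specific to this artifact), the REINFORCE / score-function strategy listed above estimates gradients of expectations using only log-density gradients of the sampling distribution; for an integrand $f$ that does not itself depend on $\phi$:

$$
\nabla_\phi\, \mathbb{E}_{x \sim q_\phi}\big[f(x)\big] \;=\; \mathbb{E}_{x \sim q_\phi}\big[f(x)\, \nabla_\phi \log q_\phi(x)\big]
$$

MVD and reparameterization-based strategies estimate the same gradients via different decompositions, which is why the estimators can differ in variance and runtime.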
For Table 4 (using the abbreviations listed above), you'll see a print out which looks like this (specific numbers may depend on your runtime environment):
```
poetry run python experiments/table_4_objective_values/genjax_cone.py
ELBO:
(Array(-8.0759735, dtype=float32), Array(0.8189303, dtype=float32))
IWAE(K = 5):
(Array(-7.674431, dtype=float32), Array(2.659954, dtype=float32))

poetry run python experiments/table_4_objective_values/genjax_cone_marginal.py
HVI-ELBO(N = 1):
(Array(-9.751299, dtype=float32), Array(0.9588641, dtype=float32))
IWHVI(N = 5, K = 1):
(Array(-8.182691, dtype=float32), Array(0.91353637, dtype=float32))
IWHVI(N = 5, K = 5) (also called DIWHVI):
(Array(-7.298371, dtype=float32), Array(1.6482085, dtype=float32))

poetry run python experiments/table_4_objective_values/numpyro_cone.py
100%|█████████████████████| 6000/6000 [00:01<00:00, 5377.34it/s, init loss: 450.7343, avg. loss [5701-6000]: 8.0853]
NumPyro TraceGraph ELBO:
(Array(8.087664, dtype=float32), Array(0.11515266, dtype=float32))
100%|██████████████████████| 6000/6000 [00:01<00:00, 5432.70it/s, init loss: 71.4061, avg. loss [5701-6000]: 7.9096]
NumPyro RenyiELBO(k = 5):
(Array(7.869861, dtype=float32), Array(1.9360248, dtype=float32))

poetry run python experiments/table_4_objective_values/pyro_cone.py
Guessed max_plate_nesting = 1
Pyro ELBO:
(tensor(8.0826), tensor(0.1097))
Guessed max_plate_nesting = 1
Pyro IWAE(K = 5):
(tensor(7.8314), tensor(2.4545))
```
We've added spaces in the above print out between the independent experiments involved in producing Table 4. Every experiment prints a tuple, e.g. `Pyro IWAE(K = 5): (tensor(7.8314), tensor(2.4545))`, where the first element is the mean objective value over optimization, and the second is the standard deviation. We use these numbers to generate Table 4.
A few things to note:

- (Abbreviations and name changes) In the print out, the IWAE and RenyiELBO labels are equivalent to IWELBO in Table 4. The number of importance weighted samples (in the print out, `K = 5`) is written as $n = 5$ in Table 4 (e.g. for `Pyro IWAE(K = 5)` and our `IWAE(K = 5)`). In the experiments, `HVI-ELBO` is equivalent to `HVI` in Table 4. All other names are the same between the print out and Table 4.
- All system comparison experiments (Pyro and NumPyro) are labeled with their system names in the table.
- (What the numbers mean) We did not report the standard deviation in this table: for each experiment, the first element of the printed tuple is the mean over several trials, and the second is the standard deviation.
- (Objective value convention) Pyro and NumPyro use the negative ELBO as their objective (negative ELBO minimization); to compare with `genjax.vi`, what we've done in Table 4 is apply a minus sign to the Pyro and NumPyro results.
For our submission results, our hardware was a Linux box with an Nvidia RTX 4090 and an AMD Ryzen 7 7800X3D CPU, with CUDA support (checked via `nvidia-smi`) up to 12.4.

We also ran our experiments on a Linux box with an Nvidia Tesla V100 SXM2 16 GB and an Intel Xeon (8 cores) @ 2.2 GHz, with support for CUDA 12. In both experiment environments, we observed the same general trends, including the same order-of-magnitude speed-ups of our JAX-based gradient estimators compared to Pyro's.
Note that even on GPU, Pyro's implementation of the reweighted wake-sleep (RWS) algorithm may be prohibitively slow (due to the `batch_size = 1` restriction).
There are several ways to build on our work. We've provided a tutorial notebook that introduces our system in the context of using it for variational inference on new model and guide programs.