Rename fine distributions to BMMs (#160)
pawel-czyz authored Jun 28, 2024
1 parent 7fc5280 commit f4fa4f4
Showing 23 changed files with 210 additions and 190 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/build.yml
@@ -15,7 +15,8 @@ jobs:
strategy:
fail-fast: true
matrix:
-python-version: ["3.11", "3.12"]
+# Use only 3.11 as for 3.12 pytype is not supported yet
+python-version: ["3.11"]
poetry-version: ["1.3.2"]

steps:
4 changes: 2 additions & 2 deletions docs/api/fine-distributions.md → docs/api/bmm.md
@@ -1,4 +1,4 @@
-# Fine distributions
+# Bend and Mix Models

## Core utilities

@@ -12,7 +12,7 @@

::: bmi.samplers._tfp.ProductDistribution

-::: bmi.samplers._tfp.FineSampler
+::: bmi.samplers._tfp.BMMSampler

## Basic distributions

4 changes: 2 additions & 2 deletions docs/api/index.md
@@ -12,8 +12,8 @@

[Samplers](samplers.md) represent joint probability distributions with known mutual information from which one can sample. They are lower level than `Tasks` and can be used to define new tasks by transformations which preserve mutual information.

-### Fine distributions
-[Subpackage](fine-distributions.md) implementing distributions in which the ground-truth mutual information may not be known analytically, but can be efficiently approximated using Monte Carlo methods.
+### Bend and Mix Models
+[Subpackage](bmm.md) implementing distributions known as *Bend and Mix Models*, for which the ground-truth mutual information may not be known analytically, but can be efficiently approximated using Monte Carlo methods.

## Interfaces
[Interfaces](interfaces.md) defines the main interfaces used in the package.
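Taken together, the renamed pieces are used like this: a minimal sketch assembled only from API calls visible in this commit (`bmm.mixture`, `bmm.MultivariateNormalDistribution`, `bmm.BMMSampler`, `bmm.monte_carlo_mi_estimate`). The mixture weights, correlation value, and sample sizes are illustrative, and the `sampler.sample` signature is assumed rather than shown in this diff.

```python
import jax
import jax.numpy as jnp

from bmi.samplers import bmm

# A two-component Gaussian mixture (a "Bend and Mix Model"), echoing the
# `task_x` construction further down in this commit.
dist = bmm.mixture(
    proportions=jnp.array([0.5, 0.5]),
    components=[
        bmm.MultivariateNormalDistribution(
            dim_x=1,
            dim_y=1,
            mean=jnp.zeros(2),
            covariance=jnp.array([[1.0, rho], [rho, 1.0]]),
        )
        for rho in [-0.9, 0.9]
    ],
)

# The ground-truth MI of such a mixture is generally not analytic, but it can
# be approximated by Monte Carlo; the second return value is an auxiliary
# statistic, left unused here as in the gmm.py hunk below.
key = jax.random.PRNGKey(0)
mi, _ = bmm.monte_carlo_mi_estimate(key=key, dist=dist, n=100_000)

# BMMSampler wraps the distribution as a sampler usable in benchmark tasks.
sampler = bmm.BMMSampler(dist, mi_estimate_sample=100_000)
x, y = sampler.sample(1_000, rng=0)  # signature assumed, not shown in this diff
```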
4 changes: 2 additions & 2 deletions docs/api/samplers.md
@@ -26,9 +26,9 @@ Samplers represent probability distributions with known mutual information.

::: bmi.samplers.ZeroInflatedPoissonizationSampler

-## Fine distributions
+## Bend and Mix Models

-See the [fine distributions subpackage API](fine-distributions.md) for more information.
+See the [Bend and Mix Models subpackage API](bmm.md) for more information.

### Auxiliary

106 changes: 54 additions & 52 deletions docs/fine-distributions.md → docs/bmm.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mkdocs.yml
@@ -48,7 +48,7 @@ nav:
- Home: index.md
- Estimators: estimators.md
- Benchmark: benchmarking-new-estimator.md
-  - Fine distributions: fine-distributions.md
+  - Bend and Mix Models: bmm.md
- Contributing: contributing.md
- API: api/index.md

5 changes: 5 additions & 0 deletions pyproject.toml
@@ -22,6 +22,11 @@ scipy = "^1.10.1"
tqdm = "^4.64.1"
tensorflow-probability = {extras = ["jax"], version = "^0.20.1"}

+[tool.poetry.group.bayes]
+optional = true
+
+[tool.poetry.group.bayes.dependencies]
+numpyro = "^0.14.0"

[tool.poetry.group.dev]
optional = true
7 changes: 4 additions & 3 deletions references.bib
@@ -1,10 +1,11 @@
-@misc{mixtures-neural-critics-2023,
-title={The Mixtures and the Neural Critics: On the Pointwise Mutual Information Profiles of Fine Distributions},
+@misc{pmi-profiles-bmms-2023,
+title={On the Properties and Estimation of Pointwise Mutual Information Profiles},
author={Paweł Czyż and Frederic Grabowski and Julia E. Vogt and Niko Beerenwinkel and Alexander Marx},
year={2023},
eprint={2310.10240},
archivePrefix={arXiv},
-primaryClass={stat.ML}
+primaryClass={stat.ML},
+url={https://arxiv.org/abs/2310.10240}
}

@inproceedings{beyond-normal-2023,
50 changes: 25 additions & 25 deletions src/bmi/benchmark/tasks/mixtures.py
@@ -4,7 +4,7 @@
import bmi.samplers as samplers
import bmi.transforms as transforms
from bmi.benchmark.task import Task
-from bmi.samplers import fine
+from bmi.samplers import bmm

_MC_MI_ESTIMATE_SAMPLE = 100_000

Expand All @@ -15,10 +15,10 @@ def task_x(
) -> Task:
"""The X distribution."""

-dist = fine.mixture(
+dist = bmm.mixture(
proportions=jnp.array([0.5, 0.5]),
components=[
-fine.MultivariateNormalDistribution(
+bmm.MultivariateNormalDistribution(
covariance=samplers.canonical_correlation([x * gaussian_correlation]),
mean=jnp.zeros(2),
dim_x=1,
@@ -27,7 +27,7 @@ def task_x(
for x in [-1, 1]
],
)
-sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)
+sampler = bmm.BMMSampler(dist, mi_estimate_sample=mi_estimate_sample)

return Task(
sampler=sampler,
@@ -47,44 +47,44 @@ def task_ai(
corr = 0.95
var_x = 0.04

-dist = fine.mixture(
+dist = bmm.mixture(
proportions=jnp.full(6, fill_value=1 / 6),
components=[
# I components
-fine.MultivariateNormalDistribution(
+bmm.MultivariateNormalDistribution(
dim_x=1,
dim_y=1,
mean=jnp.array([1.0, 0.0]),
covariance=np.diag([0.01, 0.2]),
),
-fine.MultivariateNormalDistribution(
+bmm.MultivariateNormalDistribution(
dim_x=1,
dim_y=1,
mean=jnp.array([1.0, 1]),
covariance=np.diag([0.05, 0.001]),
),
-fine.MultivariateNormalDistribution(
+bmm.MultivariateNormalDistribution(
dim_x=1,
dim_y=1,
mean=jnp.array([1.0, -1]),
covariance=np.diag([0.05, 0.001]),
),
# A components
-fine.MultivariateNormalDistribution(
+bmm.MultivariateNormalDistribution(
dim_x=1,
dim_y=1,
mean=jnp.array([-0.8, -0.2]),
covariance=np.diag([0.03, 0.001]),
),
-fine.MultivariateNormalDistribution(
+bmm.MultivariateNormalDistribution(
dim_x=1,
dim_y=1,
mean=jnp.array([-1.2, 0.0]),
covariance=jnp.array(
[[var_x, jnp.sqrt(var_x * 0.2) * corr], [jnp.sqrt(var_x * 0.2) * corr, 0.2]]
),
),
-fine.MultivariateNormalDistribution(
+bmm.MultivariateNormalDistribution(
dim_x=1,
dim_y=1,
mean=jnp.array([-0.4, 0.0]),
@@ -94,7 +94,7 @@ def task_ai(
),
],
)
-sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)
+sampler = bmm.BMMSampler(dist, mi_estimate_sample=mi_estimate_sample)

return Task(
sampler=sampler,
@@ -110,10 +110,10 @@ def task_galaxy(
) -> Task:
"""The Galaxy distribution."""

-balls_mixt = fine.mixture(
+balls_mixt = bmm.mixture(
proportions=jnp.array([0.5, 0.5]),
components=[
-fine.MultivariateNormalDistribution(
+bmm.MultivariateNormalDistribution(
covariance=samplers.canonical_correlation([0.0], additional_y=1),
mean=jnp.array([x, x, x]) * distance / 2,
dim_x=2,
@@ -123,7 +123,7 @@ def task_galaxy(
],
)

-base_sampler = fine.FineSampler(balls_mixt, mi_estimate_sample=mi_estimate_sample)
+base_sampler = bmm.BMMSampler(balls_mixt, mi_estimate_sample=mi_estimate_sample)
a = jnp.array([[0, -1], [1, 0]])
spiral = transforms.Spiral(a, speed=speed)

@@ -150,10 +150,10 @@ def task_waves(

assert n_components > 0

-base_dist = fine.mixture(
+base_dist = bmm.mixture(
proportions=jnp.ones(n_components) / n_components,
components=[
-fine.MultivariateNormalDistribution(
+bmm.MultivariateNormalDistribution(
covariance=jnp.diag(jnp.array([0.1, 1.0, 0.1])),
mean=jnp.array([x, 0, x % 4]) * 1.5,
dim_x=2,
Expand All @@ -162,7 +162,7 @@ def task_waves(
for x in range(n_components)
],
)
-base_sampler = fine.FineSampler(base_dist, mi_estimate_sample=mi_estimate_sample)
+base_sampler = bmm.BMMSampler(base_dist, mi_estimate_sample=mi_estimate_sample)
aux_sampler = samplers.TransformedSampler(
base_sampler,
transform_x=lambda x: x
@@ -193,10 +193,10 @@ def task_concentric_multinormal(

assert n_components > 0

-dist = fine.mixture(
+dist = bmm.mixture(
proportions=jnp.ones(n_components) / n_components,
components=[
-fine.MultivariateNormalDistribution(
+bmm.MultivariateNormalDistribution(
covariance=jnp.diag(jnp.array(dim_x * [i**2] + [0.0001])),
mean=jnp.array(dim_x * [0.0] + [1.0 * i]),
dim_x=dim_x,
@@ -205,7 +205,7 @@ def task_concentric_multinormal(
for i in range(1, 1 + n_components)
],
)
-sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)
+sampler = bmm.BMMSampler(dist, mi_estimate_sample=mi_estimate_sample)

return Task(
sampler=sampler,
@@ -238,23 +238,23 @@ def task_multinormal_sparse_w_inliers(
eta_x=strength,
)

-signal_dist = fine.MultivariateNormalDistribution(
+signal_dist = bmm.MultivariateNormalDistribution(
dim_x=dim_x,
dim_y=dim_y,
covariance=params.correlation,
)

-noise_dist = fine.ProductDistribution(
+noise_dist = bmm.ProductDistribution(
dist_x=signal_dist.dist_x,
dist_y=signal_dist.dist_y,
)

-dist = fine.mixture(
+dist = bmm.mixture(
proportions=jnp.array([1 - inlier_fraction, inlier_fraction]),
components=[signal_dist, noise_dist],
)

-sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)
+sampler = bmm.BMMSampler(dist, mi_estimate_sample=mi_estimate_sample)

task_id = f"mult-sparse-w-inliers-{dim_x}-{dim_y}-{n_interacting}-{strength}-{inlier_fraction}"
return Task(
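One pattern worth noting in `task_multinormal_sparse_w_inliers` above: `ProductDistribution` builds an independent (zero-MI) distribution out of the signal's own marginals, and mixing it in dilutes the mutual information. A minimal sketch under the same API; the 0.8/0.2 proportions and the correlation value are illustrative, not taken from the task.

```python
import jax
import jax.numpy as jnp

from bmi.samplers import bmm

# Correlated bivariate Gaussian acting as the "signal" component.
signal = bmm.MultivariateNormalDistribution(
    dim_x=1,
    dim_y=1,
    mean=jnp.zeros(2),
    covariance=jnp.array([[1.0, 0.8], [0.8, 1.0]]),
)

# Same marginals, but X and Y independent: this component carries zero MI.
noise = bmm.ProductDistribution(dist_x=signal.dist_x, dist_y=signal.dist_y)

# 20% "inlier" contamination drawn from the independent component.
contaminated = bmm.mixture(
    proportions=jnp.array([0.8, 0.2]),
    components=[signal, noise],
)

# The MI of the contaminated mixture is approximated by Monte Carlo.
key = jax.random.PRNGKey(42)
mi, _ = bmm.monte_carlo_mi_estimate(key=key, dist=contaminated, n=100_000)
```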
30 changes: 20 additions & 10 deletions src/bmi/estimators/external/gmm.py
@@ -1,3 +1,13 @@
+"""A Gaussian mixture model estimator, allowing for model-based
+Bayesian estimator of mutual information.
+The full description can be found [here](https://arxiv.org/abs/2310.10240).
+Note that to use this estimator you need to install external dependencies:
+```bash
+$ pip install benchmark-mi[bayes]
+```
+"""

try:
import numpyro # type: ignore
import numpyro.distributions as dist # type: ignore
@@ -12,7 +12,7 @@
from numpy.typing import ArrayLike

from bmi.interface import BaseModel, IMutualInformationPointEstimator
-from bmi.samplers import fine
+from bmi.samplers import bmm
from bmi.utils import ProductSpace


@@ -74,14 +84,14 @@ def model(
)


-def sample_into_fine_distribution(
+def sample_into_bmm_distribution(
means: jnp.ndarray,
covariances: jnp.ndarray,
proportions: jnp.ndarray,
dim_x: int,
dim_y: int,
-) -> fine.JointDistribution:
-"""Builds a fine distribution from a Gaussian mixture model parameters."""
+) -> bmm.JointDistribution:
+"""Builds a bmm distribution from a Gaussian mixture model parameters."""
# Check if the dimensions are right
n_components = proportions.shape[0]
n_dims = dim_x + dim_y
@@ -90,7 +100,7 @@ def sample_into_fine_distribution(

# Build components
components = [
-fine.MultivariateNormalDistribution(
+bmm.MultivariateNormalDistribution(
dim_x=dim_x,
dim_y=dim_y,
mean=mean,
Expand All @@ -100,7 +110,7 @@ def sample_into_fine_distribution(
]

# Build a mixture model
-return fine.mixture(proportions=proportions, components=components)
+return bmm.mixture(proportions=proportions, components=components)


class GMMEstimatorParams(BaseModel):
@@ -185,12 +195,12 @@ def run_mcmc(self, x: ArrayLike, y: ArrayLike):
self._dim_x = space.dim_x
self._dim_y = space.dim_y

-def get_fine_distribution(self, idx: int) -> fine.JointDistribution:
+def get_bmm_distribution(self, idx: int) -> bmm.JointDistribution:
if self._mcmc is None:
raise ValueError("You need to run MCMC first. See the `run_mcmc` method.")

samples = self._mcmc.get_samples()
-return sample_into_fine_distribution(
+return sample_into_bmm_distribution(
means=samples["mu"][idx],
covariances=samples["cov"][idx],
proportions=samples["pi"][idx],
@@ -204,8 +214,8 @@ def get_sample_mi(self, idx: int, mc_samples: Optional[int] = None, key=None) ->
if key is None:
self.key, key = jax.random.split(self.key)

-distribution = self.get_fine_distribution(idx)
-mi, _ = fine.monte_carlo_mi_estimate(key=key, dist=distribution, n=mc_samples)
+distribution = self.get_bmm_distribution(idx)
+mi, _ = bmm.monte_carlo_mi_estimate(key=key, dist=distribution, n=mc_samples)
return mi

def get_posterior_mi(
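For orientation, the Bayesian workflow this file implements, sketched from the method signatures visible in the hunks above. The class name `GMMEstimator` and its constructor are assumptions (neither appears in this diff); `run_mcmc`, `get_bmm_distribution`, and `get_sample_mi` are taken verbatim from the code.

```python
import numpy as np

from bmi.estimators.external.gmm import GMMEstimator  # class name assumed

rng = np.random.default_rng(0)
x = rng.normal(size=(500, 1))      # placeholder data, shape (n_points, dim_x)
y = x + rng.normal(size=(500, 1))  # placeholder data, shape (n_points, dim_y)

estimator = GMMEstimator()  # constructor arguments not shown in this diff

# Fit a posterior over Gaussian mixture parameters with numpyro's MCMC.
estimator.run_mcmc(x=x, y=y)

# Each posterior draw `idx` defines a BMM; its MI is estimated by Monte Carlo,
# so repeating this over draws yields samples from the MI posterior.
dist0 = estimator.get_bmm_distribution(idx=0)
mi0 = estimator.get_sample_mi(idx=0, mc_samples=10_000)
```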
4 changes: 2 additions & 2 deletions src/bmi/samplers/__init__.py
@@ -13,7 +13,7 @@
)

# isort: on
-import bmi.samplers._tfp as fine
+import bmi.samplers._tfp as bmm
from bmi.samplers._independent_coordinates import IndependentConcatenationSampler
from bmi.samplers._split_student_t import SplitStudentT
from bmi.samplers._splitmultinormal import BivariateNormalSampler, SplitMultinormal
@@ -33,7 +33,7 @@
"AdditiveUniformSampler",
"BaseSampler",
"canonical_correlation",
-"fine",
+"bmm",
"parametrised_correlation_matrix",
"BivariateNormalSampler",
"SplitMultinormal",
4 changes: 2 additions & 2 deletions src/bmi/samplers/_tfp/__init__.py
@@ -19,7 +19,7 @@

# isort: on
from bmi.samplers._tfp._product import ProductDistribution
-from bmi.samplers._tfp._wrapper import FineSampler
+from bmi.samplers._tfp._wrapper import BMMSampler

__all__ = [
"JointDistribution",
@@ -30,7 +30,7 @@
"MultivariateNormalDistribution",
"MultivariateStudentDistribution",
"ProductDistribution",
-"FineSampler",
+"BMMSampler",
"construct_multivariate_normal_distribution",
"construct_multivariate_student_distribution",
]
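Net effect of the rename on downstream imports, condensed from the two `__init__.py` hunks above into a before/after sketch (the covariance value is illustrative):

```python
import jax.numpy as jnp

from bmi.samplers import bmm  # previously: from bmi.samplers import fine

dist = bmm.MultivariateNormalDistribution(
    dim_x=1,
    dim_y=1,
    mean=jnp.zeros(2),
    covariance=jnp.array([[1.0, 0.5], [0.5, 1.0]]),
)

# Previously: sampler = fine.FineSampler(dist, mi_estimate_sample=100_000)
sampler = bmm.BMMSampler(dist, mi_estimate_sample=100_000)
```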
