Update Gaussian mixture tutorial (#293)
* Update Gaussian mixture tutorial

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Some additional fixes

* More fixes

* Update tutorials/01-gaussian-mixture-model/01_gaussian-mixture-model.jmd

Co-authored-by: Rik Huijzer <rikhuijzer@pm.me>

* Some minor changes

* Another typo

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Rik Huijzer <rikhuijzer@pm.me>
3 people authored Mar 24, 2022
1 parent 71c6349 commit ce51cad
Showing 3 changed files with 243 additions and 202 deletions.
181 changes: 98 additions & 83 deletions tutorials/01-gaussian-mixture-model/01_gaussian-mixture-model.jmd
permalink: /:collection/:name/
redirect_from: tutorials/1-gaussianmixturemodel/
---

The following tutorial illustrates the use of Turing for clustering data using a Bayesian mixture model.
The aim of this task is to infer a latent grouping (hidden structure) from unlabelled data.

## Synthetic Data

We generate a synthetic dataset of $N = 60$ two-dimensional points $x_i \in \mathbb{R}^2$ drawn from a Gaussian mixture model.
For simplicity, we use $K = 2$ clusters with

- equal weights, i.e., we use mixture weights $w = [0.5, 0.5]$, and
- isotropic Gaussian distributions of the points in each cluster.

More concretely, we use the Gaussian distributions $\mathcal{N}([\mu_k, \mu_k]^\mathsf{T}, I)$ with parameters $\mu_1 = -3.5$ and $\mu_2 = 0.5$.

```julia
using Distributions
using FillArrays
using StatsPlots

using LinearAlgebra
using Random

# Set a random seed.
Random.seed!(3)

# Define Gaussian mixture model.
w = [0.5, 0.5]
μ = [-3.5, 0.5]
mixturemodel = MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ], w)

# We draw the data points.
N = 60
x = rand(mixturemodel, N);
```

The following plot shows the dataset.

```julia
scatter(x[1, :], x[2, :]; legend=false, title="Synthetic Dataset")
```

## Gaussian Mixture Model in Turing

We are interested in recovering the grouping from the dataset.
More precisely, we want to infer the mixture weights, the parameters $\mu_1$ and $\mu_2$, and the assignment of each datum to a cluster for the generative Gaussian mixture model.

In a Bayesian Gaussian mixture model with $K$ components, each data point $x_i$ ($i = 1,\ldots,N$) is generated according to the following generative process.
First we draw the model parameters, i.e., in our example we draw parameters $\mu_k$ for the mean of the isotropic normal distributions and the mixture weights $w$ of the $K$ clusters.
We use standard normal distributions as priors for $\mu_k$ and a Dirichlet distribution with parameters $\alpha_1 = \cdots = \alpha_K = 1$ as prior for $w$:
$$
\begin{aligned}
\mu_k &\sim \mathcal{N}(0, 1) \qquad (k = 1,\ldots,K)\\
w &\sim \operatorname{Dirichlet}(\alpha_1, \ldots, \alpha_K)
\end{aligned}
$$
After having constructed all the necessary model parameters, we can generate an observation by first selecting one of the clusters
$$
z_i \sim \operatorname{Categorical}(w) \qquad (i = 1,\ldots,N),
$$
and then drawing the datum accordingly, i.e., in our example drawing
$$
x_i \sim \mathcal{N}([\mu_{z_i}, \mu_{z_i}]^\mathsf{T}, I) \qquad (i=1,\ldots,N).
$$

For more details on Gaussian mixture models, we refer to Christopher M. Bishop, *Pattern Recognition and Machine Learning*, Section 9.
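
To make the generative process above concrete, the following sketch forward-simulates it directly with the packages loaded at the beginning of the tutorial. This is an illustrative addition rather than part of the tutorial; the `*_sim` variable names are chosen so that the tutorial's variables `μ`, `w`, `N`, and `x` are not overwritten.

```julia
# Illustrative sketch only: forward-simulate the generative process above,
# assuming K = 2 clusters and N = 60 observations.
K_sim, N_sim = 2, 60
μ_sim = rand(Normal(0, 1), K_sim)         # cluster locations μ_k ~ Normal(0, 1)
w_sim = rand(Dirichlet(K_sim, 1.0))       # mixture weights w ~ Dirichlet(1, …, 1)
z_sim = rand(Categorical(w_sim), N_sim)   # cluster assignments z_i ~ Categorical(w)
x_sim = reduce(hcat, [rand(MvNormal(Fill(μ_sim[zᵢ], 2), I)) for zᵢ in z_sim])
```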

We specify the model with Turing.

```julia
using Turing

@model function gaussian_mixture_model(x)
    # Draw the parameters for each of the K=2 clusters from a standard normal distribution.
    K = 2
    μ ~ MvNormal(Zeros(K), I)

    # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.
    w ~ Dirichlet(K, 1.0)
    # Alternatively, one could use a fixed set of weights.
    # w = fill(1/K, K)

    # Construct categorical distribution of assignments.
    distribution_assignments = Categorical(w)

    # Construct multivariate normal distributions of each cluster.
    D, N = size(x)
    distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]

    # Draw assignments for each datum and generate it from the multivariate normal distribution.
    k = Vector{Int}(undef, N)
    for i in 1:N
        k[i] ~ distribution_assignments
        x[:, i] ~ distribution_clusters[k[i]]
    end

    return k
end
```

```julia
model = gaussian_mixture_model(x);
```

We run an MCMC simulation to obtain an approximation of the posterior distribution of the parameters $\mu$ and $w$ and of the assignments $k$.
We use a `Gibbs` sampler that combines a [particle Gibbs](https://www.stats.ox.ac.uk/%7Edoucet/andrieu_doucet_holenstein_PMCMC.pdf) sampler for the discrete parameters (assignments $k$) and a Hamiltonian Monte Carlo sampler for the continuous parameters ($\mu$ and $w$).
We generate multiple chains in parallel using multi-threading.

```julia
sampler = Gibbs(PG(100, :k), HMC(0.05, 10, :μ, :w))
nsamples = 100
nchains = 3
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains);
```
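
As an optional sanity check (an illustrative addition, not part of the original tutorial), one can inspect summary statistics such as the effective sample size and $\hat{R}$ of the continuous parameters across the chains:

```julia
# Optional, illustrative check: summary statistics (mean, std, ESS, R-hat)
# of the continuous parameters, pooled over all chains.
MCMCChains.summarystats(chains[["μ[1]", "μ[2]", "w[1]", "w[2]"]])
```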

```julia; echo=false; error=false
let
    # Verify that the output of the chain is as expected.
    for i in MCMCChains.chains(chains)
        # μ[1] and μ[2] can switch places, so we sort the values first.
        chain = Array(chains[:, ["μ[1]", "μ[2]"], i])
        μ_mean = vec(mean(chain; dims=1))
        @assert isapprox(sort(μ_mean), μ; rtol=0.1)
    end
end
```

## Inferred Mixture Model

After sampling, we can visualize the trace and density of the parameters of interest.

We consider the samples of the location parameters $\mu_1$ and $\mu_2$ for the two clusters.

```julia
plot(chains[["μ[1]", "μ[2]"]]; colordim=:parameter, legend=true)
```

It can happen that the modes of $\mu_1$ and $\mu_2$ switch between chains.
For more information see the [Stan documentation](https://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html) for potential solutions.
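
One simple heuristic for this two-cluster example (an illustrative sketch, not part of the tutorial) is to relabel the draws post hoc, for instance by sorting the two location parameters within every sample:

```julia
# Illustrative post-hoc relabelling sketch: order the two location parameters
# within each sample so that the first column always holds the smaller mean.
# This simple heuristic is only sensible for this two-cluster example.
μ_samples = Array(chains[:, ["μ[1]", "μ[2]"], :])  # (iterations × chains) rows, 2 columns
μ_relabelled = sort(μ_samples; dims=2)
mean(μ_relabelled; dims=1)
```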

We also inspect the samples of the mixture weights $w$.

```julia
plot(chains[["w[1]", "w[2]"]]; colordim=:parameter, legend=true)
```

In the following, we just use the first chain to ensure the validity of our inference.

```julia
chain = chains[:, :, 1];
```

As the distributions of the samples for the parameters $\mu_1$, $\mu_2$, $w_1$, and $w_2$ are unimodal, we can safely visualize the density region of our model using the average values.

```julia
# Model with mean of samples as parameters.
μ_mean = [mean(chain, "μ[$i]") for i in 1:2]
w_mean = [mean(chain, "w[$i]") for i in 1:2]
mixturemodel_mean = MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ_mean], w_mean)

contour(
    range(-7.5, 3; length=1_000),
    range(-6.5, 3; length=1_000),
    (x, y) -> logpdf(mixturemodel_mean, [x, y]);
    widen=false,
)
scatter!(x[1, :], x[2, :]; legend=false, title="Synthetic Dataset")
```

## Inferred Assignments

Finally, we can inspect the assignments of the data points inferred using Turing.
As we can see, the dataset is partitioned into two distinct groups.

```julia
assignments = [mean(chain, "k[$i]") for i in 1:N]
scatter(
    x[1, :],
    x[2, :];
    legend=false,
    title="Assignments on Synthetic Dataset",
    zcolor=assignments,
)
```