Update: Revised for PyTorch 0.4 on Oct 28, 2018
Mixture models allow rich probability distributions to be represented as a combination of simpler “component” distributions. For example, consider the mixture of 1-dimensional gaussians in the image below:
While the representational capacity of a single gaussian is limited, a mixture can approximate any smooth density to within any desired nonzero error, given enough components[2].
In practice, mixture models are used for a variety of statistical learning problems such as classification, image segmentation and clustering. My own interest stems from their role as an important precursor to more advanced generative models. For example, variational autoencoders provide a framework for learning mixture distributions with an infinite number of components, and can model complex high-dimensional data such as images.
In this post I will offer a brief introduction to the gaussian mixture model and implement it in PyTorch. The full code will be available on my GitHub.
The Gaussian Mixture Model
A gaussian mixture model with $K$ components takes the form[1]:

$$p(x) = \sum_{k=1}^{K} p(x \mid z = k)\, p(z = k)$$

where $z$ is a categorical latent variable indicating the component identity. For brevity we will denote the prior $\pi_k := p(z = k)$. The likelihood term for the $k$th component is the parameterised gaussian:

$$p(x \mid z = k) = \mathcal{N}(x \mid \mu_k, \Sigma_k)$$

Our goal is to learn the means $\mu_k$, covariances $\Sigma_k$ and priors $\pi_k$ using an iterative procedure called expectation maximisation (EM).
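As a rough illustration of the mixture density above, here is a sketch that evaluates $\log p(x)$ for a single point, using `torch.distributions.MultivariateNormal` for the $K$ component likelihoods and `torch.logsumexp` to sum the prior-weighted terms in the log domain. The function name `gmm_log_density` and the example parameters are hypothetical:

```python
import torch
from torch.distributions import MultivariateNormal

def gmm_log_density(x, pi, mu, cov):
    """Evaluate log p(x) = log sum_k pi_k N(x | mu_k, Sigma_k).

    x: (d,) point; pi: (K,) priors; mu: (K, d) means; cov: (K, d, d) covariances.
    """
    # A batch of K gaussians; log_prob broadcasts x against the batch -> shape (K,)
    components = MultivariateNormal(loc=mu, covariance_matrix=cov)
    log_lik = components.log_prob(x)
    # log sum_k exp(log pi_k + log N_k), computed stably
    return torch.logsumexp(torch.log(pi) + log_lik, dim=0)

# Example: two 2-D components with equal priors and identity covariances
pi = torch.tensor([0.5, 0.5])
mu = torch.tensor([[0.0, 0.0], [3.0, 3.0]])
cov = torch.eye(2).repeat(2, 1, 1)  # shape (2, 2, 2)
print(gmm_log_density(torch.tensor([0.0, 0.0]), pi, mu, cov))  # ≈ -2.5309
```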
The basic EM algorithm has three steps:
1. Randomly initialise the parameters of the component distributions.
2. Estimate the probability of each data point under the component parameters.
3. Recalculate the parameters based on the estimated probabilities, then repeat from step 2.
Convergence is reached when the total likelihood of the data under the model stops increasing.
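One way to make the stopping rule concrete: write the total log-likelihood of $N$ data points as $\mathcal{L}(\theta)$ and pick a small tolerance $\epsilon$ (an illustrative assumption, not a value from the original post). Every EM iteration satisfies the middle inequality, so the loop can stop once the improvement drops below $\epsilon$:

$$\mathcal{L}(\theta) = \sum_{i=1}^{N} \log \sum_{k=1}^{K} \pi_k\, \mathcal{N}(x_i \mid \mu_k, \Sigma_k), \qquad \mathcal{L}(\theta^{(t+1)}) \ge \mathcal{L}(\theta^{(t)}), \qquad \text{stop when } \mathcal{L}(\theta^{(t+1)}) - \mathcal{L}(\theta^{(t)}) < \epsilon$$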
To quickly test my implementation, I created a synthetic dataset of points sampled from three 2-dimensional gaussians.
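A minimal sketch of how such a dataset might be generated with `torch.distributions.MultivariateNormal`. The function name `make_synthetic_data`, the means, the covariances, the 100 points per cluster and the seed are all illustrative assumptions, not the values from the original post:

```python
import torch
from torch.distributions import MultivariateNormal

def make_synthetic_data(n_per_cluster=100, seed=0):
    """Sample 2-D points from three gaussians with illustrative parameters."""
    torch.manual_seed(seed)  # note: sets the global RNG seed
    means = torch.tensor([[-4.0, 0.0], [0.0, 4.0], [4.0, 0.0]])
    covs = torch.stack([
        torch.tensor([[1.0, 0.5], [0.5, 1.0]]),
        torch.tensor([[1.0, 0.0], [0.0, 0.5]]),
        torch.tensor([[0.5, -0.3], [-0.3, 1.0]]),
    ])
    # Draw n_per_cluster points from each gaussian, then stack them
    samples = [MultivariateNormal(m, c).sample((n_per_cluster,)) for m, c in zip(means, covs)]
    return torch.cat(samples, dim=0)  # shape (3 * n_per_cluster, 2)

X = make_synthetic_data()
print(X.shape)  # torch.Size([300, 2])
```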
Initialising the Parameters
For the sake of simplicity, I just randomly select $K$ points from my dataset to act as initial means. I use a fixed initial variance and a uniform prior.
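The initialisation might look like the following sketch. It assumes the data is an `(N, d)` tensor and that the diagonal covariances are stored as a `(K, d)` tensor of variances. The function name `initialise` and the default `init_var=1.0` are hypothetical choices:

```python
import torch

def initialise(X, K, init_var=1.0):
    """Hypothetical initialiser: K random data points as means, fixed variance, uniform prior.

    X: (N, d) data. Returns means (K, d), diagonal variances (K, d) and priors (K,).
    """
    N, d = X.shape
    idx = torch.randperm(N)[:K]          # K distinct row indices chosen at random
    mu = X[idx].clone()                  # initial means are copies of those data points
    var = torch.full((K, d), init_var)   # same fixed variance for every component and dimension
    pi = torch.full((K,), 1.0 / K)       # uniform prior over components
    return mu, var, pi

X = torch.randn(50, 2)
mu, var, pi = initialise(X, K=3)
print(pi)  # tensor([0.3333, 0.3333, 0.3333])
```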
The Multivariate Gaussian
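Step 2 of the EM algorithm requires us to compute the relative likelihood of each data point under each component. The p.d.f. of the $d$-dimensional multivariate gaussian is

$$\mathcal{N}(x \mid \mu_k, \Sigma_k) = \frac{1}{\sqrt{(2\pi)^d\, |\Sigma_k|}} \exp\left(-\frac{1}{2}(x - \mu_k)^\top \Sigma_k^{-1} (x - \mu_k)\right)$$

By only considering diagonal covariance matrices $\Sigma_k = \operatorname{diag}(\sigma_{k,1}^2, \dots, \sigma_{k,d}^2)$, we can greatly simplify the computation (at the loss of some flexibility). The determinant becomes a product of the variances:

$$|\Sigma_k| = \prod_{j=1}^{d} \sigma_{k,j}^2$$

Instead of computing the matrix inverse, we can simply invert the variances:

$$\Sigma_k^{-1} = \operatorname{diag}(\lambda_{k,1}, \dots, \lambda_{k,d}), \qquad \lambda_{k,j} = \frac{1}{\sigma_{k,j}^2}$$

And lastly, the exponent simplifies to

$$-\frac{1}{2}(x - \mu_k)^\top \Sigma_k^{-1} (x - \mu_k) = -\frac{1}{2}\left[(x - \mu_k) \odot (x - \mu_k)\right]^\top \lambda_k$$

where $\odot$ represents element-wise multiplication and $\lambda_k = (\lambda_{k,1}, \dots, \lambda_{k,d})^\top$ is our vector of inverse variances.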
It is worth taking a minute to reflect on the form of the exponent in the last equation. Because a diagonal covariance assumes the dimensions are uncorrelated within each component (which, for a gaussian, means independent), the computation reduces to evaluating a one-dimensional gaussian p.d.f. for each dimension and taking their product (or their sum in the log domain).
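Written out in the log domain, with inverse variances $\lambda_{k,j} = 1/\sigma_{k,j}^2$, the diagonal gaussian expands into a sum of one-dimensional log-densities. This short worked expansion shows where each term of a log p.d.f. implementation comes from:

$$\log \mathcal{N}(x \mid \mu_k, \Sigma_k) = \sum_{j=1}^{d} \log \mathcal{N}(x_j \mid \mu_{k,j}, \sigma_{k,j}^2) = -\frac{d}{2}\log 2\pi + \frac{1}{2}\sum_{j=1}^{d}\log \lambda_{k,j} - \frac{1}{2}\sum_{j=1}^{d} \lambda_{k,j}\,(x_j - \mu_{k,j})^2$$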
In high dimensions the likelihood calculation can suffer from numerical underflow, because it multiplies together many small per-dimension densities. It is therefore typical to work with the log p.d.f. instead (i.e. the exponent we derived above, plus the log of the normalisation constant). Note that we could use the built-in PyTorch distributions package for this; however, for transparency I wrote my own functional implementation.
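A minimal sketch of such a functional implementation for a single point and a single component, assuming tensors of shape `(d,)`. The name `log_gaussian_diag` is hypothetical; the result is compared against `torch.distributions.Normal`, summed over dimensions:

```python
import math
import torch
from torch.distributions import Normal

def log_gaussian_diag(x, mu, var):
    """Log p.d.f. of a diagonal gaussian N(x | mu, diag(var)) for tensors of shape (d,)."""
    d = x.shape[-1]
    precision = 1.0 / var  # vector of inverse variances (lambda)
    # log normalisation constant: -d/2 log(2 pi) + 1/2 sum_j log lambda_j
    log_norm = -0.5 * d * math.log(2 * math.pi) + 0.5 * torch.log(precision).sum()
    # exponent: -1/2 [(x - mu) * (x - mu)]^T lambda, using element-wise multiplication
    log_exp = -0.5 * ((x - mu) * (x - mu) * precision).sum()
    return log_norm + log_exp

x = torch.tensor([0.5, -1.0, 2.0])
mu = torch.tensor([0.0, 0.0, 1.0])
var = torch.tensor([1.0, 2.0, 0.5])
reference = Normal(mu, var.sqrt()).log_prob(x).sum()  # sum of per-dimension 1-D log-densities
print(log_gaussian_diag(x, mu, var))                   # ≈ -4.1318
print(torch.allclose(log_gaussian_diag(x, mu, var), reference))  # True
```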
To compute the likelihood of every point under every gaussian in parallel, we can exploit tensor broadcasting instead of looping over points and components.
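A sketch of the broadcast version, assuming data `X` of shape `(N, d)` and per-component means and variances of shape `(K, d)`. Inserting singleton dimensions makes the subtraction broadcast to `(N, K, d)`, and summing over the last dimension leaves an `(N, K)` matrix of log-likelihoods. The name `log_likelihoods` is hypothetical, and the result is checked against `torch.distributions.Normal`:

```python
import math
import torch
from torch.distributions import Normal

def log_likelihoods(X, mu, var):
    """Log-likelihood of every point under every diagonal gaussian.

    X: (N, d) data; mu, var: (K, d) means and variances. Returns an (N, K) tensor.
    """
    d = X.shape[-1]
    X = X.unsqueeze(1)                   # (N, 1, d)
    mu = mu.unsqueeze(0)                 # (1, K, d)
    precision = 1.0 / var.unsqueeze(0)   # (1, K, d) inverse variances
    # Normalisation term, one value per component: shape (1, K)
    log_norm = -0.5 * d * math.log(2 * math.pi) + 0.5 * torch.log(precision).sum(-1)
    # (N, 1, d) - (1, K, d) broadcasts to (N, K, d); summing over d gives (N, K)
    log_exp = -0.5 * ((X - mu) ** 2 * precision).sum(-1)
    return log_norm + log_exp

X = torch.randn(5, 2)
mu = torch.randn(3, 2)
var = torch.rand(3, 2) + 0.5
ll = log_likelihoods(X, mu, var)
print(ll.shape)  # torch.Size([5, 3])
reference = Normal(mu.unsqueeze(0), var.sqrt().unsqueeze(0)).log_prob(X.unsqueeze(1)).sum(-1)
print(torch.allclose(ll, reference, atol=1e-5))  # True
```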
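In order to recompute the parameters, we apply Bayes' rule to the likelihoods:

$$p(z = k \mid x_i) = \frac{p(x_i \mid z = k)\, \pi_k}{\sum_{j=1}^{K} p(x_i \mid z = j)\, \pi_j}$$

The resulting values are sometimes referred to as the "membership weights", as they express how well component $k$ explains the observation $x_i$. Since our likelihoods are in the log domain, we use the logsumexp trick to compute the denominator stably: $\log \sum_j e^{a_j} = m + \log \sum_j e^{a_j - m}$ with $m = \max_j a_j$.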
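A sketch of this step, assuming an `(N, K)` matrix of log-likelihoods and a `(K,)` tensor of priors. The function name `e_step` is hypothetical; `torch.logsumexp` is used for the numerically stable denominator. The example shows that exponentiating very negative log-likelihoods directly produces `nan`, while the log-domain version does not:

```python
import torch

def e_step(log_lik, pi):
    """Membership weights p(z=k | x_i) via Bayes' rule, computed in the log domain.

    log_lik: (N, K) log p(x_i | z=k); pi: (K,) priors.
    Returns (N, K) weights whose rows sum to 1, and the total log-likelihood.
    """
    log_joint = log_lik + torch.log(pi)                             # log p(x_i | z=k) + log pi_k
    log_evidence = torch.logsumexp(log_joint, dim=1, keepdim=True)  # log p(x_i), shape (N, 1)
    weights = torch.exp(log_joint - log_evidence)                   # normalise in log space
    return weights, log_evidence.sum()

# Log-likelihoods this small underflow to zero if exponentiated directly
log_lik = torch.tensor([[-1000.0, -1001.0]])
pi = torch.tensor([0.5, 0.5])
naive = torch.exp(log_lik) * pi
print(naive / naive.sum(dim=1, keepdim=True))  # tensor([[nan, nan]])
weights, total = e_step(log_lik, pi)
print(weights)  # approximately [[0.7311, 0.2689]]
```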
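Using the membership weights $w_{ik} = p(z = k \mid x_i)$, and writing $N_k = \sum_{i=1}^{N} w_{ik}$ where $N$ is the number of data points, the parameter update proceeds in three steps:
- Set the new mean of each component to a weighted average of the data points: $\mu_k = \frac{1}{N_k}\sum_{i=1}^{N} w_{ik}\, x_i$.
- Set the new covariance matrix to a weighted combination of the covariances for each data point. With diagonal covariances, the variance of component $k$ in dimension $j$ becomes $\sigma_{k,j}^2 = \frac{1}{N_k}\sum_{i=1}^{N} w_{ik}\,(x_{ij} - \mu_{k,j})^2$.
- Set the new prior to the normalised sum of the membership weights: $\pi_k = N_k / N$.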
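A sketch of these three updates for the diagonal case, assuming data `X` of shape `(N, d)` and membership weights of shape `(N, K)`. The function name `m_step` and the small `eps` added to the variances (to stop a component collapsing onto a single point) are our own assumptions. With hard 0/1 weights, the updates reduce to per-cluster sample means, variances and proportions:

```python
import torch

def m_step(X, weights, eps=1e-6):
    """Re-estimate diagonal gaussian parameters from membership weights.

    X: (N, d) data; weights: (N, K) membership weights (rows sum to 1).
    Returns means (K, d), variances (K, d) and priors (K,).
    """
    Nk = weights.sum(dim=0)                                  # effective points per component, (K,)
    mu = weights.t() @ X / Nk.unsqueeze(1)                   # weighted average of the data, (K, d)
    diff_sq = (X.unsqueeze(1) - mu.unsqueeze(0)) ** 2        # squared deviations, (N, K, d)
    var = (weights.unsqueeze(2) * diff_sq).sum(dim=0) / Nk.unsqueeze(1)  # weighted variances, (K, d)
    var = var + eps                                          # small floor (an assumption, see above)
    pi = Nk / X.shape[0]                                     # normalised sum of the weights
    return mu, var, pi

X = torch.tensor([[0.0, 0.0], [2.0, 2.0], [10.0, 10.0], [12.0, 14.0]])
weights = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
mu, var, pi = m_step(X, weights)
print(mu)  # tensor([[ 1.,  1.], [11., 12.]])
```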
Apart from some simple training logic, that is the bulk of the algorithm! Here is a visualisation of EM fitting three components to the synthetic data I generated earlier:
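Putting the pieces together, the training logic might look like the following self-contained sketch. It is not the author's code: the function name `fit_gmm`, the tolerance `tol`, the iteration cap, the initial variance and the variance floor of `1e-6` are all hypothetical choices. It stops once the log-likelihood improves by less than `tol`:

```python
import math
import torch

def fit_gmm(X, K, n_iter=100, tol=1e-4, init_var=1.0, seed=0):
    """Fit a diagonal-covariance GMM with EM; returns (mu, var, pi, log-likelihood history)."""
    torch.manual_seed(seed)
    N, d = X.shape
    # Initialise: K random data points as means, fixed variance, uniform prior
    mu = X[torch.randperm(N)[:K]].clone()
    var = torch.full((K, d), init_var)
    pi = torch.full((K,), 1.0 / K)
    history = []
    for _ in range(n_iter):
        # E-step: (N, K) log-likelihoods via broadcasting, then membership weights
        precision = 1.0 / var
        log_lik = (-0.5 * d * math.log(2 * math.pi)
                   + 0.5 * torch.log(precision).sum(-1)
                   - 0.5 * ((X.unsqueeze(1) - mu) ** 2 * precision).sum(-1))
        log_joint = log_lik + torch.log(pi)
        log_evidence = torch.logsumexp(log_joint, dim=1, keepdim=True)
        weights = torch.exp(log_joint - log_evidence)
        history.append(log_evidence.sum().item())
        # Stop once the log-likelihood no longer increases by more than tol
        if len(history) > 1 and history[-1] - history[-2] < tol:
            break
        # M-step: weighted means, weighted diagonal variances, normalised priors
        Nk = weights.sum(0)
        mu = weights.t() @ X / Nk.unsqueeze(1)
        var = (weights.unsqueeze(2) * (X.unsqueeze(1) - mu) ** 2).sum(0) / Nk.unsqueeze(1) + 1e-6
        pi = Nk / N
    return mu, var, pi, history

# Usage on three well-separated illustrative clusters
torch.manual_seed(1)
centres = torch.tensor([[-5.0, 0.0], [5.0, 0.0], [0.0, 6.0]])
X = torch.cat([torch.randn(100, 2) + c for c in centres])
mu, var, pi, history = fit_gmm(X, K=3)
print(len(history), history[-1])
```

The returned `history` should be essentially non-decreasing (up to floating-point error and the tiny variance floor). Because the means are initialised from random data points, different seeds can converge to different local optima.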
Thanks for Reading!
If you found this post interesting or informative, have questions, or would like to offer feedback or corrections, feel free to get in touch by email or on Twitter. Also stay tuned for my upcoming post on Variational Autoencoders!
For a more rigorous treatment of the EM algorithm, see [1].
- [1] Bishop, C. M. (2006). Pattern Recognition and Machine Learning, Ch. 9. Springer.
- [2] Goodfellow, I., Bengio, Y., Courville, A. (2016). Deep Learning. MIT Press.