This semester, I took computational statistics (ISYE 6416) and machine learning (CS 4641). I’ve been studying for the finals, both of which cover expectation maximization (EM), and initially struggled to connect the concepts from the two classes. This post is mostly for me to conceptualize EM, both from a theoretical/algorithmic standpoint and as an applied strategy.
EM builds off of Maximum Likelihood Estimation (MLE), which is a method for estimating distribution parameters from some observations. The maximum likelihood estimator is the parameter value that maximizes the likelihood of the observed data.
Let’s say we have some data $x_1, x_2, \dots, x_n$, and we hypothesize that the data has been drawn independently from some distribution with parameter $\theta$. Then, the likelihood of drawing those observations is:

$$L(\theta) = \prod_{i=1}^{n} f(x_i \mid \theta)$$
Generally, you’ll have some distribution in mind (e.g. the exponential distribution), and $f(x \mid \theta)$ is the probability density function for that distribution. In practice, it is useful to work with the natural logarithm of the likelihood function for a few reasons:
- The maximizer/argmax of the log likelihood function is the same as that of the likelihood function, because the logarithm is strictly increasing.
- Log has useful properties that can be applied to simplify the function. Notably, $\log(ab) = \log a + \log b$, $\log(a/b) = \log a - \log b$, and $\log(a^b) = b \log a$.
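Thus, the log likelihood function usually looks something like this:

$$\ell(\theta) = \log L(\theta) = \sum_{i=1}^{n} \log f(x_i \mid \theta)$$

Generally, to find the maximum, we can take the derivative of the log likelihood with respect to $\theta$, set this equal to 0, and solve for $\theta$.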
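As a quick check of this recipe, here is a minimal Python sketch for the exponential distribution, assuming the rate parameterization $f(x \mid \lambda) = \lambda e^{-\lambda x}$. Setting the derivative of $\ell(\lambda) = n \log\lambda - \lambda \sum_i x_i$ to 0 gives $\hat{\lambda} = n / \sum_i x_i$. The function names and the simulated data (true rate 2.0) are hypothetical, chosen only for illustration.

```python
import math
import random


def exp_log_likelihood(lam, xs):
    """Log likelihood of i.i.d. Exponential(rate=lam) data: n*log(lam) - lam*sum(xs)."""
    return len(xs) * math.log(lam) - lam * sum(xs)


def exp_mle(xs):
    """Closed-form MLE from setting the derivative n/lam - sum(xs) to zero."""
    return len(xs) / sum(xs)


random.seed(0)
data = [random.expovariate(2.0) for _ in range(1000)]  # hypothetical sample, true rate 2.0
lam_hat = exp_mle(data)

# Sanity check: nudging lambda in either direction should lower the log likelihood.
for lam in (0.9 * lam_hat, 1.1 * lam_hat):
    assert exp_log_likelihood(lam, data) < exp_log_likelihood(lam_hat, data)
print(round(lam_hat, 3))
```

With 1000 draws, the estimate should land near the true rate; the assertion loop confirms that the closed-form answer really is a maximum of the log likelihood.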
Expectation Maximization Algorithm
EM is an iterative strategy for estimating maximum likelihood parameters when there are missing data or latent variables. The Computational Statistics textbook by Givens and Hoeting uses the following notation:
- $\mathbf{X}$: observed variables
- $\mathbf{Z}$: missing or latent variables
- $\mathbf{Y}$: complete data, $\mathbf{Y} = (\mathbf{X}, \mathbf{Z})$
Then, starting from an initial guess $\theta^{(0)}$, the algorithm alternates between the two steps below:
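- E step: We compute $Q(\theta \mid \theta^{(t)}) = E\left[\log L(\theta \mid \mathbf{Y}) \mid \mathbf{x}, \theta^{(t)}\right]$, which is the expected log likelihood for the complete data, conditioned on the observed data $\mathbf{x}$ and the current estimate $\theta^{(t)}$.
Since we have observed $\mathbf{X}$, $\mathbf{Z}$ is the only random portion of $\mathbf{Y}$, so the expectation is taken over the conditional distribution of $\mathbf{Z}$.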
- M step: Maximize $Q(\theta \mid \theta^{(t)})$ with respect to $\theta$. Set $\theta^{(t+1)}$ to this maximizer, and go back to the E step unless some stopping or convergence criterion has been met.
The observed-data log likelihood, $\log L(\theta^{(t)} \mid \mathbf{x})$, is monotonically non-decreasing over successive EM iterations.
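The loop structure can be sketched in Python as below. This assumes a scalar parameter, and it assumes the E step can be summarized by whatever quantities the M step needs (for example, expected sufficient statistics), so $Q$ never has to be built as a symbolic function. The names `run_em`, `e_step`, and `m_step`, and the stopping rule (absolute change in $\theta$ below `tol`), are hypothetical choices for illustration.

```python
from typing import Callable, List, TypeVar

S = TypeVar("S")  # whatever summary the E step hands to the M step


def run_em(
    theta0: float,
    e_step: Callable[[float], S],
    m_step: Callable[[S], float],
    tol: float = 1e-8,
    max_iter: int = 1000,
) -> List[float]:
    """Alternate E and M steps from theta0; return every iterate theta^(0), theta^(1), ..."""
    history = [theta0]
    for _ in range(max_iter):
        summary = e_step(history[-1])  # E step: expectations under the current estimate
        new_theta = m_step(summary)    # M step: maximizer of Q(theta | theta^(t))
        history.append(new_theta)
        if abs(new_theta - history[-2]) < tol:  # stopping criterion
            break
    return history
```

Plugging in a specific model then means writing only `e_step` and `m_step`; the loop itself does not change.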
The Givens and Hoeting textbook contains a toy example that illustrates the idea pretty well. In the example, there are two variables, $Y_1$ and $Y_2$, which are independently and identically distributed: $Y_1, Y_2 \sim \text{Exp}(\theta)$. The value $y_1$ is 5, but $y_2$ is missing.
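First, let’s find the log likelihood function of the complete data. Note that the probability density function of the exponential distribution is $f(y \mid \theta) = \theta e^{-\theta y}$ for $y \ge 0$, so

$$\log L(\theta \mid \mathbf{y}) = \log\left(\theta e^{-\theta y_1} \cdot \theta e^{-\theta y_2}\right) = 2\log\theta - \theta y_1 - \theta y_2$$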
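How do we use this log likelihood to find $Q(\theta \mid \theta^{(t)})$? Our current log likelihood is for the complete data, but we have only observed $y_1$. We can find the conditional expectation of $Y_2$ given the observed value of $y_1$. Note here that $Y_1$ and $Y_2$ are independent of each other, and the mean of an $\text{Exp}(\theta)$ distribution is $1/\theta$. Thus,

$$E\left[Y_2 \mid y_1, \theta^{(t)}\right] = E\left[Y_2 \mid \theta^{(t)}\right] = \frac{1}{\theta^{(t)}}$$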
We can use the observed data and the conditional expectation to write the E step as:

$$Q(\theta \mid \theta^{(t)}) = 2\log\theta - 5\theta - \frac{\theta}{\theta^{(t)}}$$
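To find the argmax of $Q(\theta \mid \theta^{(t)})$, we can differentiate with respect to $\theta$, set the derivative to 0, and solve:

$$\frac{2}{\theta} - 5 - \frac{1}{\theta^{(t)}} = 0 \quad\Longrightarrow\quad \theta^{(t+1)} = \frac{2\theta^{(t)}}{5\theta^{(t)} + 1}$$

Iterating this update converges to $\hat{\theta} = 0.2$, which matches the MLE from $y_1 = 5$ alone (the observed-data log likelihood $\log\theta - 5\theta$ is maximized at $\theta = 1/5$).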
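The toy example's update is easy to iterate numerically. This Python sketch starts from an arbitrary (hypothetical) initial guess $\theta^{(0)} = 1$, applies $\theta^{(t+1)} = 2\theta^{(t)} / (5\theta^{(t)} + 1)$ until the change falls below a tolerance, and records the observed-data log likelihood $\log\theta - 5\theta$ at each iterate, so you can confirm it never decreases. The function names are illustrative.

```python
import math

Y1 = 5.0  # observed value of Y_1 in the toy example; Y_2 is missing


def em_update(theta_t: float) -> float:
    """E step (E[Y_2] = 1/theta_t) and M step (argmax of Q) folded into one closed form."""
    return 2.0 * theta_t / (Y1 * theta_t + 1.0)


def observed_log_likelihood(theta: float) -> float:
    """Log likelihood of the observed data y_1 alone: log(theta) - theta * y_1."""
    return math.log(theta) - theta * Y1


def run_toy_em(theta0: float, tol: float = 1e-10, max_iter: int = 1000) -> list:
    """Iterate the EM update from theta0 and return all iterates."""
    thetas = [theta0]
    for _ in range(max_iter):
        thetas.append(em_update(thetas[-1]))
        if abs(thetas[-1] - thetas[-2]) < tol:
            break
    return thetas


history = run_toy_em(theta0=1.0)
log_liks = [observed_log_likelihood(t) for t in history]
print(len(history) - 1, history[-1])  # number of iterations, final estimate
```

Starting above 0.2, the iterates decrease toward it. Near the fixed point the error shrinks by roughly half per iteration, since the derivative of the update at $\theta = 0.2$ is $2/(5 \cdot 0.2 + 1)^2 = 0.5$, so convergence here is linear rather than fast.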
Clustering with EM
A common machine learning problem involves assigning data to clusters. EM is generally used in this context when we assume that our data was generated by a Gaussian mixture model. Here, each Gaussian component represents a cluster, and each point gets assigned some probability of belonging to each cluster. This is also called soft clustering, as opposed to hard clustering, in which each point is explicitly assigned to exactly one cluster.
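To make the soft/hard distinction concrete, here is a small Python sketch. Given fixed (hypothetical) parameters for a two-component 1-D Gaussian mixture, it computes each point's probability of belonging to each cluster via Bayes' rule, then turns that into a hard assignment by taking the most likely cluster. The parameter values and function names are assumptions for illustration.

```python
import math


def normal_pdf(x: float, mu: float, sigma: float) -> float:
    """Density of N(mu, sigma^2) evaluated at x."""
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))


def soft_assign(x, weights, mus, sigmas):
    """Soft clustering: P(cluster k | x) for every k, proportional to weight_k * N(x | mu_k, sigma_k^2)."""
    joint = [w * normal_pdf(x, m, s) for w, m, s in zip(weights, mus, sigmas)]
    total = sum(joint)
    return [p / total for p in joint]


def hard_assign(x, weights, mus, sigmas):
    """Hard clustering: index of the single most probable cluster."""
    probs = soft_assign(x, weights, mus, sigmas)
    return max(range(len(probs)), key=probs.__getitem__)


# Hypothetical parameters: equal weights, means 0 and 4, unit standard deviations.
params = ([0.5, 0.5], [0.0, 4.0], [1.0, 1.0])
for x in (1.0, 2.0, 3.5):
    print(x, soft_assign(x, *params), hard_assign(x, *params))
```

A point exactly halfway between the means (here $x = 2$) gets probability 0.5 for each cluster. A hard assignment throws that uncertainty away.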
Gaussian Mixture Models
A mixture model consists of $K$ component distributions, each with a weight, and all weights sum to 1. In a Gaussian mixture model, each of the component distributions is a normal distribution. For example, you may see mixture models written like this:

$$0.7\,\mathcal{N}(\mu_1, \sigma_1^2) + 0.3\,\mathcal{N}(\mu_2, \sigma_2^2)$$
This would imply that roughly 70 percent of observations are from the first distribution and 30 percent of observations are from the second one.
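More generally, a Gaussian mixture model can be expressed as:

$$p(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \sigma_k^2)$$

where $\pi_k \ge 0$ are the weights and $\sum_{k=1}^{K} \pi_k = 1$.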
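The mixture density can be evaluated directly. It also suggests how to sample from a mixture: first pick a component $k$ with probability $\pi_k$, then draw from $\mathcal{N}(\mu_k, \sigma_k^2)$. This Python sketch does both for a 70/30 mixture. The means and standard deviations, (0, 1) and (5, 2), are hypothetical values and are not taken from the original example.

```python
import math
import random


def normal_pdf(x, mu, sigma):
    """Density of N(mu, sigma^2) evaluated at x."""
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))


def mixture_pdf(x, weights, mus, sigmas):
    """p(x) = sum_k pi_k * N(x | mu_k, sigma_k^2)."""
    return sum(w * normal_pdf(x, m, s) for w, m, s in zip(weights, mus, sigmas))


def sample_mixture(n, weights, mus, sigmas, rng):
    """Draw (component, value) pairs: component k with probability pi_k, then x ~ N(mu_k, sigma_k^2)."""
    draws = []
    for _ in range(n):
        k = rng.choices(range(len(weights)), weights=weights)[0]
        draws.append((k, rng.gauss(mus[k], sigmas[k])))
    return draws


weights, mus, sigmas = [0.7, 0.3], [0.0, 5.0], [1.0, 2.0]  # hypothetical parameters
draws = sample_mixture(10_000, weights, mus, sigmas, random.Random(0))
share_first = sum(1 for k, _ in draws if k == 0) / len(draws)
print(round(share_first, 3))  # should be close to 0.7
print(mixture_pdf(0.0, weights, mus, sigmas))
```

In clustering, only the sampled values would be observed. The component labels `k` play the role of the latent variables.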
For clustering purposes, the parameters to estimate are the weight, mean, and standard deviation of each Gaussian component, and the hidden variables $z_i$ indicate which Gaussian component each point $x_i$ came from.
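Putting the pieces together, here is a minimal Python sketch of EM for a 1-D Gaussian mixture. The E step computes responsibilities $r_{ik} = P(z_i = k \mid x_i)$ under the current parameters. The M step sets each weight, mean, and standard deviation to its responsibility-weighted estimate. The function name `gmm_em`, the initialization (means at the data's minimum and maximum), and the simulated data are assumptions for illustration. A practical version would also guard against a component collapsing onto a single point (standard deviation going to 0) and would try several initializations, since EM only finds a local maximum.

```python
import math
import random


def normal_pdf(x, mu, sigma):
    """Density of N(mu, sigma^2) evaluated at x."""
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))


def gmm_em(xs, weights, mus, sigmas, max_iter=500, tol=1e-8):
    """Fit a 1-D Gaussian mixture by EM. Returns (weights, mus, sigmas, log-likelihood history)."""
    weights, mus, sigmas = list(weights), list(mus), list(sigmas)  # don't mutate the caller's lists
    n, k = len(xs), len(weights)
    history = []
    for _ in range(max_iter):
        # E step: responsibilities r[i][j] = P(z_i = j | x_i, current parameters),
        # plus the observed-data log likelihood of the current parameters.
        resp, log_lik = [], 0.0
        for x in xs:
            joint = [weights[j] * normal_pdf(x, mus[j], sigmas[j]) for j in range(k)]
            total = sum(joint)
            log_lik += math.log(total)
            resp.append([p / total for p in joint])
        history.append(log_lik)
        if len(history) > 1 and history[-1] - history[-2] < tol:
            break  # log likelihood has stopped improving
        # M step: responsibility-weighted estimates of each weight, mean, and std. dev.
        for j in range(k):
            n_j = sum(r[j] for r in resp)
            weights[j] = n_j / n
            mus[j] = sum(r[j] * x for r, x in zip(resp, xs)) / n_j
            var = sum(r[j] * (x - mus[j]) ** 2 for r, x in zip(resp, xs)) / n_j
            sigmas[j] = math.sqrt(var)
    return weights, mus, sigmas, history


# Hypothetical data: 300 points from N(0, 1) and 200 points from N(6, 1.5^2).
rng = random.Random(42)
data = [rng.gauss(0.0, 1.0) for _ in range(300)] + [rng.gauss(6.0, 1.5) for _ in range(200)]
w, m, s, lls = gmm_em(data, [0.5, 0.5], [min(data), max(data)], [1.0, 1.0])
print("weights", w, "means", m, "std devs", s, "iterations", len(lls))
```

The recorded `lls` history should be non-decreasing, which matches the monotonicity property of EM described earlier.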