Variational Inference — or How to Estimate Unknown Probability Function 2

DZ
7 min read · Oct 2, 2023


In a previous post, we discussed the problem of estimating a probability function that cannot be calculated directly, and we saw a method called Markov Chain Monte Carlo that addresses it. In this post, we restate the problem (so you don’t have to read the previous post) and then look at another method, called variational inference, that finds an approximate solution.

In the field of probability and statistics, it is not uncommon to encounter situations where the probability distribution function cannot be calculated directly. One example of such a situation is Bayesian inference.

A popular example of Bayesian inference is the estimation of model parameters. The parameters are the elements that define the model structure. For example, for a normal distribution, the parameters are the mean and the variance. For an exponential distribution, the parameter is λ.

In this setting, we usually have samples x drawn from a distribution with unknown parameters θ, and we assume we have some prior knowledge about θ that we can express as a probability function p(θ). Using Bayes’ theorem, we can find the probability of θ given the observed data x (called the posterior):

$$
p(\theta \mid x) = \frac{p(x \mid \theta)\, p(\theta)}{p(x)}
$$

In most cases, the numerator (the likelihood p(x|θ) times the prior p(θ)) is not too difficult to compute. However, the normalization factor p(x) can be a challenge. To find p(x), we usually need to compute the integral

$$
p(x) = \int p(x \mid \theta)\, p(\theta)\, d\theta
$$

In some simple cases (especially in low dimensions), we can compute that integral, but in higher dimensions, it may become intractable. So, instead of computing p(θ|x) directly, we need to use some approximation.
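To make the dimensionality problem concrete, here is a small Python sketch (an illustration, not code from this post) that approximates a d-dimensional integral with a midpoint grid. The integrand exp(−(θ₁ + … + θ_d)) on the box [0, 10]^d is a hypothetical stand-in, chosen only because its exact value, (1 − e^{−10})^d, is known. With m points per axis, the grid needs m^d evaluations.

```python
import itertools
import math


def midpoint_grid_integral(f, d, m, upper=10.0):
    """Midpoint-rule integral of f over [0, upper]^d with m points per axis.

    Returns (estimate, number_of_function_evaluations).
    """
    h = upper / m
    # Midpoints of the m sub-intervals along a single axis.
    axis = [(i + 0.5) * h for i in range(m)]
    total = 0.0
    evaluations = 0
    # The Cartesian product visits m ** d grid points.
    for point in itertools.product(axis, repeat=d):
        total += f(point)
        evaluations += 1
    return total * h ** d, evaluations


def integrand(theta):
    # Hypothetical test integrand: product of exp(-theta_i) over all coordinates.
    return math.exp(-sum(theta))


if __name__ == "__main__":
    m = 50
    for d in (1, 2, 3):
        estimate, n_evals = midpoint_grid_integral(integrand, d, m)
        exact = (1.0 - math.exp(-10.0)) ** d
        print(f"d={d}: estimate={estimate:.5f}, exact={exact:.5f}, evaluations={n_evals}")
```

Going from d = 1 to d = 3 already multiplies the number of evaluations by 2,500 (from 50 to 125,000). With tens of parameters, a grid like this is far out of reach.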

There are two main approaches to tackling this problem. The first is Markov Chain Monte Carlo (MCMC), which draws samples from the posterior without needing its normalization factor; you can read about it here. The second is variational inference, which turns the problem into an optimization over a family of simpler distributions. That second approach is the subject of this post.

Variational Inference

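Suppose we have a probability distribution function that we know only up to a scaling factor, as in the case of Bayesian inference. We can write

$$
p(\theta \mid x) = \frac{p(x \mid \theta)\, p(\theta)}{p(x)} \propto p(x \mid \theta)\, p(\theta)
$$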

In variational inference, we look for a simple, easy-to-compute distribution q(θ) from a family of parametric distributions (for example, normal distributions form a family parametrized by the mean and the variance). We want the q(θ) that minimizes some error function (or maximizes some similarity function) with respect to the posterior, using only the unnormalized distribution that we can compute (for example, p(x|θ)p(θ)). In other words, we search for the values of q’s own parameters (not θ, which is the variable q is defined over) that give the best approximation q(θ) within the chosen family. Sometimes, q is called a surrogate function.

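Mathematically, we can write it as

$$
q^*(\theta) = \arg\min_{q \in \mathcal{Q}} \mathrm{Err}\big(q(\theta),\, p(\theta \mid x)\big)
$$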

where Q is a distribution family, Err is some error function that we want to minimize, and q* is the probability distribution function that gives us the best approximation by minimizing the error function.

To define the error function, we need a measure of how different two distributions are, and its minimizer must not depend on the normalization factor we don’t know. The KL divergence fits this role. It is always non-negative, it equals zero only when the two distributions are identical, and it grows as they differ. (Strictly speaking, it is not a metric, since it is not symmetric.)

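We can write the problem as

$$
q^*(\theta) = \arg\min_{q \in \mathcal{Q}} \mathrm{KL}\big(q(\theta) \,\|\, p(\theta \mid x)\big)
$$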

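where the KL divergence is

$$
\mathrm{KL}\big(q(\theta) \,\|\, p(\theta \mid x)\big) = \int q(\theta) \log \frac{q(\theta)}{p(\theta \mid x)}\, d\theta
$$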
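A quick numerical sanity check of this definition (a sketch, not part of the original post): the Python below integrates q log(q/p) with a midpoint rule for two one-dimensional normal distributions. The two normals are a hypothetical test case, chosen because their KL divergence has a well-known closed form. The sketch also shows that KL(q || p) is zero for identical distributions and is not symmetric.

```python
import math


def normal_pdf(t, mean, std):
    """Density of N(mean, std^2) at t."""
    z = (t - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def kl_numeric(q, p, lower=-20.0, upper=20.0, n=20000):
    """Midpoint-rule estimate of KL(q || p) = integral of q(t) * log(q(t) / p(t)) dt."""
    h = (upper - lower) / n
    total = 0.0
    for i in range(n):
        t = lower + (i + 0.5) * h
        qt = q(t)
        if qt > 0.0:  # points where q vanishes contribute nothing
            total += qt * math.log(qt / p(t))
    return total * h


def kl_normal_closed_form(m1, s1, m2, s2):
    """Known closed form of KL(N(m1, s1^2) || N(m2, s2^2)), used as a reference."""
    return math.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2.0 * s2 ** 2) - 0.5


if __name__ == "__main__":
    q = lambda t: normal_pdf(t, 0.0, 1.0)
    p = lambda t: normal_pdf(t, 1.0, 2.0)
    print(kl_numeric(q, p), kl_normal_closed_form(0.0, 1.0, 1.0, 2.0))  # both ~0.443
    print(kl_numeric(p, q))  # ~1.307: swapping the arguments changes the value
    print(kl_numeric(q, q))  # 0.0 for identical distributions
```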

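It’s important to note that, in general, θ can be a vector, so the integral is multi-dimensional. The problem with the expression above is that we still don’t know p(θ|x). We do, however, know p(x|θ)p(θ) = p(θ, x), the joint distribution of x and θ. Substituting p(θ|x) = p(θ, x)/p(x) (Bayes’ theorem) and using ∫q(θ)dθ = 1, we can write

$$
\begin{aligned}
\mathrm{KL}\big(q(\theta) \,\|\, p(\theta \mid x)\big) &= \int q(\theta) \log \frac{q(\theta)\, p(x)}{p(\theta, x)}\, d\theta \\
&= \int q(\theta) \log q(\theta)\, d\theta - \int q(\theta) \log p(\theta, x)\, d\theta + \log p(x)
\end{aligned}
$$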

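So finally, we can write

$$
\mathrm{KL}\big(q(\theta) \,\|\, p(\theta \mid x)\big) = \log p(x) - \mathcal{L}(q), \qquad \mathcal{L}(q) = \mathbb{E}_{q}\left[\log p(\theta, x)\right] - \mathbb{E}_{q}\left[\log q(\theta)\right]
$$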

L(q) is also called the Evidence Lower Bound (ELBO), since the log-likelihood of the evidence (the observed data x), log p(x), can never be smaller than it. We can see this by rearranging the equation:

$$
\log p(x) = \mathcal{L}(q) + \mathrm{KL}\big(q(\theta) \,\|\, p(\theta \mid x)\big)
$$

Since the KL divergence is non-negative, L(q) is less than or equal to the log-likelihood of the evidence; hence, L(q) is a lower bound on it.
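The identity log p(x) = L(q) + KL can be checked numerically. The following Python sketch borrows the model from the example later in this post: an Exp(1) prior on θ, a N(θ, 1) likelihood, an exponential q with rate ϕ (an assumed parameterization), and x = 1.3. It evaluates all three quantities with a midpoint rule; the integration limit and grid size are arbitrary choices.

```python
import math

X_OBS = 1.3  # the observed value used in the example later in the post


def log_joint(theta, x=X_OBS):
    """log p(theta, x) = log p(x | theta) + log p(theta), for theta >= 0."""
    log_lik = -0.5 * math.log(2.0 * math.pi) - 0.5 * (x - theta) ** 2
    log_prior = -theta  # Exp(1) prior: p(theta) = exp(-theta) on theta >= 0
    return log_lik + log_prior


def log_q(theta, phi):
    """log of q_phi(theta) = phi * exp(-phi * theta), an exponential with rate phi."""
    return math.log(phi) - phi * theta


def integrate(f, upper=40.0, n=20000):
    """Midpoint rule on [0, upper]; the tails beyond 'upper' are negligible here."""
    h = upper / n
    return h * sum(f((i + 0.5) * h) for i in range(n))


def elbo_identity_terms(phi, x=X_OBS):
    """Return (log p(x), ELBO, KL(q || posterior)) for q = Exp(rate=phi)."""
    log_px = math.log(integrate(lambda t: math.exp(log_joint(t, x))))
    # ELBO: E_q[log p(theta, x) - log q(theta)]
    elbo = integrate(lambda t: math.exp(log_q(t, phi)) * (log_joint(t, x) - log_q(t, phi)))
    # KL(q || p(theta | x)), with log p(theta | x) = log p(theta, x) - log p(x)
    kl = integrate(
        lambda t: math.exp(log_q(t, phi)) * (log_q(t, phi) - (log_joint(t, x) - log_px))
    )
    return log_px, elbo, kl


if __name__ == "__main__":
    for phi in (0.5, 1.0, 2.0):
        log_px, elbo, kl = elbo_identity_terms(phi)
        print(f"phi={phi}: log p(x)={log_px:.5f}, ELBO={elbo:.5f}, KL={kl:.5f}, ELBO+KL={elbo + kl:.5f}")
```

For every ϕ, the ELBO stays below log p(x), and the gap is exactly the KL divergence.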

Recall that we originally wanted to find the distribution q(θ) that minimizes the KL divergence. From what we saw, we can define an equivalent optimization problem in terms of L(q):

$$
q^*(\theta) = \arg\max_{q \in \mathcal{Q}} \mathcal{L}(q)
$$

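This is equivalent because log p(x) does not depend on q. Since log p(x) = L(q) + KL, any increase in L(q) must come with an equal decrease in the KL divergence. Moreover, L(q) involves only the joint distribution p(θ, x), which we can compute.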

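Let’s take another look at L(q). Writing log p(θ, x) = log p(x|θ) + log p(θ), we can split it into two terms:

$$
\mathcal{L}(q) = \mathbb{E}_{q}\left[\log p(x \mid \theta)\right] - \mathrm{KL}\big(q(\theta) \,\|\, p(\theta)\big)
$$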

The first term is the expected log-likelihood. It is maximized by a q that puts its mass on the values of θ that best explain the observed data x. The second term is the negative KL divergence between the approximation and the prior. It is maximized by a q that stays close to the prior. In summary, this optimization problem expresses the trade-off between fitting the likelihood and staying close to the prior.
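To see the trade-off numerically, the Python sketch below evaluates each term on its own. It assumes the model of the example later in this post: an Exp(1) prior, a N(θ, 1) likelihood, an exponential q_ϕ with rate ϕ, and x = 1.3. Both terms have closed forms for this model. The expected log-likelihood is −½ log(2π) − x²/2 + x/ϕ − 1/ϕ², and KL(Exp(ϕ) || Exp(1)) = log ϕ − 1 + 1/ϕ.

```python
import math

X_OBS = 1.3  # observed value used in the example


def expected_log_likelihood(phi, x=X_OBS):
    """E_q[log N(x; theta, 1)] for q = Exp(rate=phi), using E[theta]=1/phi, E[theta^2]=2/phi^2."""
    return -0.5 * math.log(2.0 * math.pi) - 0.5 * x ** 2 + x / phi - 1.0 / phi ** 2


def kl_to_prior(phi):
    """KL(Exp(phi) || Exp(1)) = log(phi) - 1 + 1/phi."""
    return math.log(phi) - 1.0 + 1.0 / phi


def argmax_on_grid(f, lo=0.05, hi=5.0, n=100000):
    """Return the grid point in [lo, hi] where f is largest (a crude 1-D maximizer)."""
    step = (hi - lo) / n
    return max((lo + i * step for i in range(n + 1)), key=f)


if __name__ == "__main__":
    best_fit = argmax_on_grid(expected_log_likelihood)
    best_prior = argmax_on_grid(lambda phi: -kl_to_prior(phi))
    print(f"data term alone prefers phi ~ {best_fit:.3f} (2/x = {2 / X_OBS:.3f})")
    print(f"prior term alone prefers phi ~ {best_prior:.3f}")
```

The data term alone is maximized at ϕ = 2/x ≈ 1.538, while the prior term alone is maximized at ϕ = 1, where q equals the prior. For x = 1.3, their sum (the ELBO) peaks between these two values.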

Simple Example

Let’s look at an example to better understand variational inference. This example follows the example in this video. As explained earlier, we usually know the prior p(θ) and the likelihood p(x|θ), and our goal is to approximate p(θ|x) with a simpler distribution q(θ) from some family of distributions.

We assume the following:

$$
\theta \sim \mathrm{Exp}(1), \qquad x \mid \theta \sim \mathcal{N}(\theta, 1)
$$

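This means p(θ) is an exponential distribution with parameter λ = 1, and p(x|θ) is a normal distribution with mean θ and variance 1. We can write them explicitly as:

$$
p(\theta) = e^{-\theta}\, \mathbb{I}(\theta \ge 0), \qquad p(x \mid \theta) = \frac{1}{\sqrt{2\pi}}\, e^{-\frac{(x-\theta)^2}{2}}
$$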

where I is the indicator function, which equals 1 when the condition in its argument holds and 0 otherwise. The evidence p(x) is then

$$
p(x) = \int_0^\infty \frac{1}{\sqrt{2\pi}}\, e^{-\frac{(x-\theta)^2}{2}}\, e^{-\theta}\, d\theta
$$

but this integral has no closed form in elementary functions, so we treat p(x) as unavailable (although, in this simple one-dimensional case, we could evaluate it numerically). Instead, we use variational inference to find an approximation to p(θ|x).
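In this one-dimensional case, p(x) can be evaluated numerically. As a minimal sketch (assuming x = 1.3, the value used later), the Python below does this in two ways. The first is a midpoint-rule quadrature of the joint. The second is a Monte Carlo average of p(x|θ) over samples from the prior. Both are compared with p(x) = e^{1/2 − x} Φ(x − 1), which follows from completing the square in the exponent. Here Φ is the standard normal CDF, which is itself not an elementary function.

```python
import math
import random


def normal_cdf(z):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def joint(theta, x):
    """p(theta, x) = N(x; theta, 1) * exp(-theta) * I(theta >= 0)."""
    if theta < 0.0:
        return 0.0
    return math.exp(-0.5 * (x - theta) ** 2 - theta) / math.sqrt(2.0 * math.pi)


def evidence_quadrature(x, upper=40.0, n=40000):
    """Midpoint rule for the integral of p(theta, x) over theta in [0, upper]."""
    h = upper / n
    return h * sum(joint((i + 0.5) * h, x) for i in range(n))


def evidence_monte_carlo(x, n_samples=100_000, seed=0):
    """p(x) = E_{theta ~ p(theta)}[p(x | theta)], estimated with Exp(1) prior samples."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        theta = rng.expovariate(1.0)
        total += math.exp(-0.5 * (x - theta) ** 2) / math.sqrt(2.0 * math.pi)
    return total / n_samples


def evidence_closed_form(x):
    """Completing the square gives p(x) = exp(1/2 - x) * Phi(x - 1)."""
    return math.exp(0.5 - x) * normal_cdf(x - 1.0)


if __name__ == "__main__":
    x = 1.3
    print(f"quadrature:  {evidence_quadrature(x):.6f}")
    print(f"Monte Carlo: {evidence_monte_carlo(x):.6f}")
    print(f"closed form: {evidence_closed_form(x):.6f}")
```

All three should agree at about 0.2776 for x = 1.3, with the Monte Carlo value subject to sampling noise. Brute-force approaches like these are practical here only because θ is one-dimensional.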

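First, we write down p(θ, x), the joint distribution:

$$
p(\theta, x) = p(x \mid \theta)\, p(\theta) = \frac{1}{\sqrt{2\pi}}\, e^{-\frac{(x-\theta)^2}{2} - \theta}\, \mathbb{I}(\theta \ge 0)
$$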

Now, let’s define q(θ) as an exponential distribution parameterized by its rate ϕ, so

$$
q_\phi(\theta) = \phi\, e^{-\phi\theta}\, \mathbb{I}(\theta \ge 0)
$$

Note that q_ϕ(θ) represents the family of exponential distributions parameterized by ϕ. Our goal is to find the ϕ for which q_ϕ(θ) is the best approximation over all possible values of ϕ, or mathematically

$$
\phi^* = \arg\max_{\phi > 0} \mathcal{L}(q_\phi)
$$

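We now have an expression for each of the elements, so we can expand the expression we want to maximize:

$$
\begin{aligned}
\mathcal{L}(q_\phi) &= \mathbb{E}_{q_\phi}\left[\log p(\theta, x)\right] - \mathbb{E}_{q_\phi}\left[\log q_\phi(\theta)\right] \\
&= \mathbb{E}_{q_\phi}\left[-\tfrac{1}{2}\log(2\pi) - \tfrac{(x-\theta)^2}{2} - \theta\right] - \mathbb{E}_{q_\phi}\left[\log\phi - \phi\theta\right] \\
&= -\tfrac{1}{2}\log(2\pi) - \tfrac{x^2}{2} + \frac{x}{\phi} - \frac{1}{\phi^2} - \frac{1}{\phi} - \log\phi + 1 \\
&= -\tfrac{1}{2}\log(2\pi) - \tfrac{x^2}{2} + \frac{x-1}{\phi} - \frac{1}{\phi^2} - \log\phi + 1
\end{aligned}
$$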

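where we used the fact that

$$
\mathbb{E}_{q_\phi}[\theta^2] = \mathrm{Var}_{q_\phi}(\theta) + \left(\mathbb{E}_{q_\phi}[\theta]\right)^2 = \frac{1}{\phi^2} + \frac{1}{\phi^2} = \frac{2}{\phi^2}
$$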

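and the mean and variance of an exponential distribution with rate ϕ:

$$
\mathbb{E}_{q_\phi}[\theta] = \frac{1}{\phi}, \qquad \mathrm{Var}_{q_\phi}(\theta) = \frac{1}{\phi^2}
$$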
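To double-check the algebra, here is a minimal Python sketch. It compares the closed-form L(ϕ) = −½ log(2π) − x²/2 + (x − 1)/ϕ − 1/ϕ² − log ϕ + 1 with a Monte Carlo estimate that averages log p(θ, x) − log q_ϕ(θ) over samples θ ~ q_ϕ. It assumes x = 1.3 and the rate parameterization q_ϕ(θ) = ϕe^{−ϕθ}. The same sampling estimator is a natural fallback when the expectations cannot be computed analytically.

```python
import math
import random

X_OBS = 1.3  # observed value used in the example


def log_joint(theta, x=X_OBS):
    """log p(theta, x) for theta >= 0: Exp(1) prior times N(theta, 1) likelihood."""
    return -0.5 * math.log(2.0 * math.pi) - 0.5 * (x - theta) ** 2 - theta


def log_q(theta, phi):
    """log q_phi(theta) for the exponential distribution with rate phi."""
    return math.log(phi) - phi * theta


def elbo_closed_form(phi, x=X_OBS):
    """The expression derived above."""
    return (-0.5 * math.log(2.0 * math.pi) - 0.5 * x ** 2
            + (x - 1.0) / phi - 1.0 / phi ** 2 - math.log(phi) + 1.0)


def elbo_monte_carlo(phi, x=X_OBS, n_samples=100_000, seed=0):
    """Average of log p(theta, x) - log q(theta) over samples theta ~ q_phi."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        theta = rng.expovariate(phi)  # rate parameterization: mean 1/phi
        total += log_joint(theta, x) - log_q(theta, phi)
    return total / n_samples


if __name__ == "__main__":
    for phi in (1.0, 1.272, 2.0):
        print(f"phi={phi}: closed form={elbo_closed_form(phi):.4f}, Monte Carlo={elbo_monte_carlo(phi):.4f}")
```

With 100,000 samples, the two values should agree to roughly two decimal places. The Monte Carlo error shrinks roughly as one over the square root of the sample count.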

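We can now differentiate L with respect to ϕ and set the derivative to zero to find the value that maximizes it:

$$
\frac{\partial \mathcal{L}}{\partial \phi} = -\frac{x-1}{\phi^2} + \frac{2}{\phi^3} - \frac{1}{\phi} = 0
\;\Longrightarrow\; \phi^2 + (x-1)\,\phi - 2 = 0
\;\Longrightarrow\; \phi = \frac{-(x-1) \pm \sqrt{(x-1)^2 + 8}}{2}
$$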

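Recall that x is known to us; let’s assume, for example, that x = 1.3. Since ϕ must be positive, and the root with the minus sign is always negative (because √((x − 1)² + 8) > |x − 1|), the solution is the one with the plus sign:

$$
\phi^* = \frac{-0.3 + \sqrt{8.09}}{2} \approx 1.272
$$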
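The Python sketch below assumes x = 1.3 and the rate parameterization. It computes both roots of ϕ² + (x − 1)ϕ − 2 = 0 and confirms that the derivative of L vanishes at the positive root. It then cross-checks that root with a brute-force search over a grid of positive ϕ values.

```python
import math

X_OBS = 1.3  # observed value used in the example


def elbo(phi, x=X_OBS):
    """Closed-form ELBO for q = Exp(rate=phi), Exp(1) prior, N(theta, 1) likelihood."""
    return (-0.5 * math.log(2.0 * math.pi) - 0.5 * x ** 2
            + (x - 1.0) / phi - 1.0 / phi ** 2 - math.log(phi) + 1.0)


def elbo_derivative(phi, x=X_OBS):
    """dL/dphi, as derived above."""
    return -(x - 1.0) / phi ** 2 + 2.0 / phi ** 3 - 1.0 / phi


def stationary_points(x=X_OBS):
    """Both roots of phi^2 + (x - 1) phi - 2 = 0: (plus root, minus root)."""
    disc = math.sqrt((x - 1.0) ** 2 + 8.0)
    return (-(x - 1.0) + disc) / 2.0, (-(x - 1.0) - disc) / 2.0


if __name__ == "__main__":
    plus_root, minus_root = stationary_points()
    print(f"plus root:  {plus_root:.4f}")   # the admissible (positive) solution
    print(f"minus root: {minus_root:.4f}")  # always negative, so rejected
    # Sanity check: a brute-force search over phi in (0, 5] should land near the plus root.
    grid = [0.01 * k for k in range(1, 501)]
    print(f"grid argmax: {max(grid, key=elbo):.2f}")
```

The plus root comes out at about 1.2721 and the minus root at about −1.5721, and the grid search lands on the grid point nearest the plus root.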

In blue, the joint distribution p(θ, x) for the observed x = 1.3 (as a function of θ, this is not a probability density function: it integrates to p(x), not to 1). In orange, the best approximation q* for x = 1.3. In green, the true posterior density, computed numerically.

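You can see that q*_ϕ(θ) captures the general shape of the true posterior, as we would expect from an approximation. Keep in mind that the result depends directly on the family of distributions we choose. For example, an exponential q always has its mode at θ = 0, whereas the true posterior here peaks at θ = x − 1 = 0.3.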
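To quantify how close the fit is, the Python sketch below normalizes p(θ, x) numerically (the same idea as the green curve). It then compares the posterior’s mean and mode with those of q*. It assumes x = 1.3 and the rate parameterization q_ϕ(θ) = ϕe^{−ϕθ}, whose mean is 1/ϕ and whose mode is 0.

```python
import math

X_OBS = 1.3  # observed value used in the example


def joint(theta, x=X_OBS):
    """p(theta, x) for theta >= 0; proportional to the posterior p(theta | x)."""
    return math.exp(-0.5 * (x - theta) ** 2 - theta) / math.sqrt(2.0 * math.pi)


def posterior_mean_and_mode(x=X_OBS, upper=40.0, n=40000):
    """Normalize the joint numerically (midpoint rule) and summarize the true posterior."""
    h = upper / n
    grid = [(i + 0.5) * h for i in range(n)]
    weights = [joint(t, x) for t in grid]
    evidence = h * sum(weights)  # numerical p(x)
    mean = h * sum(t * w for t, w in zip(grid, weights)) / evidence
    mode = grid[max(range(n), key=weights.__getitem__)]
    return mean, mode


def q_star_mean_and_mode(x=X_OBS):
    """Mean 1/phi* and mode 0 of the fitted exponential q*."""
    phi_star = (-(x - 1.0) + math.sqrt((x - 1.0) ** 2 + 8.0)) / 2.0
    return 1.0 / phi_star, 0.0


if __name__ == "__main__":
    print("posterior (mean, mode):", posterior_mean_and_mode())
    print("q* (mean, mode):       ", q_star_mean_and_mode())
```

Under these assumptions, the fitted exponential has a mean of about 0.786, compared with a posterior mean of about 0.917, and its mode sits at 0 instead of 0.3. The approximation is close but shifted toward small θ. This comes from the exponential family chosen for q, not from the optimization.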
