Variational Bayesian inference

Today I learned a cool trick for implementing Bayesian inference in practice.

Bayesians are interested in calculating the posterior probability distribution of unobserved parameters X given data, i.e. the observed values of variables Y.

To do so, they need only know the form of the likelihood function (the probability of Y given X) and their own prior distribution over X. Then they can apply Bayes’ rule…

P(X | Y) = P(Y | X) P(X) / P(Y)

… and voila, Bayesian inference complete.
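To make Bayes' rule concrete, here is a minimal Python sketch for a discrete X, where the marginal likelihood in the denominator is just a finite sum. The coin model (three candidate biases, a uniform prior, 7 heads in 10 flips) is a hypothetical example chosen for illustration.

```python
import math

# Hypothetical example: X is a coin's bias, restricted to three candidate values.
biases = [0.25, 0.5, 0.75]
prior = [1 / 3, 1 / 3, 1 / 3]  # P(X): uniform prior

# Observed data Y: 7 heads in 10 flips (made-up numbers).
heads, flips = 7, 10

def likelihood(p):
    """P(Y | X = p) for a binomial count of heads."""
    return math.comb(flips, heads) * p**heads * (1 - p) ** (flips - heads)

# Numerator of Bayes' rule: P(Y | X) P(X) for each candidate X.
joint = [likelihood(p) * w for p, w in zip(biases, prior)]

# Denominator: the marginal likelihood P(Y), a sum over every value of X.
evidence = sum(joint)

# Posterior P(X | Y): normalize the numerator by the marginal likelihood.
posterior = [j / evidence for j in joint]
for p, post in zip(biases, posterior):
    print(f"P(X={p} | Y) = {post:.3f}")
```

With only three values of X the sum is trivial; the difficulty described next comes from X having many dimensions or ranging over a continuum.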

The trickiest part of this process is calculating the term in the denominator, the marginal likelihood P(Y). Calculating it analytically is typically very computationally expensive: it involves summing the likelihood multiplied by the prior over every possible value of the parameters X. If X ranges over a continuum of possible values, then the marginal likelihood becomes an integral, and that integral is typically completely intractable.

P(Y) = ∫ P(Y | X) P(X) dX
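For a single continuous parameter the integral can still be brute-forced on a grid. The sketch below assumes a hypothetical conjugate model, X ~ N(0, 1) and Y | X ~ N(X, 1), chosen because its marginal likelihood has a closed form (Y ~ N(0, 2)) to check against. A grid with n points per dimension needs n^d evaluations for d parameters, which is why this approach breaks down quickly as the number of parameters grows.

```python
import math

def normal_pdf(x, mean, var):
    """Density of N(mean, var) at x."""
    return math.exp(-(x - mean) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

y = 1.3  # hypothetical observed value

# Riemann-sum approximation of P(Y) = ∫ P(Y | X) P(X) dX on a 1-D grid.
lo, hi, n = -10.0, 10.0, 4001
dx = (hi - lo) / (n - 1)
grid = [lo + i * dx for i in range(n)]
evidence_grid = sum(normal_pdf(y, x, 1.0) * normal_pdf(x, 0.0, 1.0) for x in grid) * dx

# For this conjugate model the integral has a closed form: Y ~ N(0, 2).
evidence_exact = normal_pdf(y, 0.0, 2.0)
print(evidence_grid, evidence_exact)
```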

Variational Bayesian inference is a procedure that sidesteps this problem through a clever trick.

We start by restricting our search for the posterior to a space F of probability distributions over X that are easy to integrate.

Our goal is not to find the exact form of the posterior, although if we do, that’s great. Instead, we want to find the function Q(X) within F that is as close to the posterior P(X | Y) as possible.

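The usual way to measure how far one probability distribution is from another is the information (Kullback-Leibler) divergence D(Q, P). It is not a true distance, since it is not symmetric, but it is always non-negative and equals zero only when the two distributions are the same. Here P stands for the posterior P(X | Y), and the divergence is defined by…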

D(Q, P) = ∫ Q(X) log(Q(X) / P(X|Y)) dX
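For discrete distributions the integral becomes a sum. This minimal sketch, with made-up distributions q and p, shows the properties that matter here: the divergence is zero when the distributions match, positive otherwise, and not symmetric in its arguments.

```python
import math

def kl_divergence(q, p):
    """D(Q, P) = sum_x Q(x) log(Q(x) / P(x)) for discrete distributions.

    Terms with Q(x) = 0 contribute 0 by convention; P(x) must be > 0
    wherever Q(x) > 0, otherwise the divergence is infinite.
    """
    return sum(qx * math.log(qx / px) for qx, px in zip(q, p) if qx > 0)

# Hypothetical distributions over three outcomes.
q = [0.7, 0.2, 0.1]
p = [0.4, 0.4, 0.2]

print(kl_divergence(q, p))  # positive: q differs from p
print(kl_divergence(p, q))  # also positive, but a different value (asymmetry)
print(kl_divergence(q, q))  # zero: identical distributions
```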

To explicitly calculate and minimize this, we would need to know the form of the posterior P(X | Y) from the start. But let’s plug in the definition of conditional probability…

P(X | Y) = P(X, Y) / P(Y)

D(Q, P) = ∫ Q(X) log(Q(X) P(Y) / P(X, Y)) dX
= ∫ Q(X) log(Q(X) / P(X, Y)) dX  +  ∫ Q(X) log P(Y) dX

The second term is easy to calculate. Since log P(Y) does not depend on X, it comes out of the integral, and because Q is a probability distribution, ∫ Q(X) dX = 1. So the integral just becomes…

∫ Q(X) log P(Y) dX = log P(Y)

Rearranging, we get…

log P(Y) = D(Q, P)  –  ∫ Q(X) log(Q(X) / P(X, Y)) dX

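The second term depends only on Q(X) and the joint probability P(X, Y), which we can calculate easily as the product of the likelihood P(Y | X) and the prior P(X). We call the negative of this integral the variational free energy, L(Q) (it is also widely known as the evidence lower bound, or ELBO), so that…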

log P(Y) = D(Q, P) + L(Q)
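This identity holds for any Q, not just a good one. Here is a quick numerical check on a hypothetical discrete model (a coin with three candidate biases, a uniform prior, and 7 heads in 10 flips), using an arbitrary, made-up Q and taking L(Q) to be the negative of the integral above.

```python
import math

# Hypothetical discrete model: three candidate coin biases, uniform prior, 7 heads in 10 flips.
biases = [0.25, 0.5, 0.75]
prior = [1 / 3] * 3
joint = [math.comb(10, 7) * p**7 * (1 - p) ** 3 * w for p, w in zip(biases, prior)]  # P(X, Y)
evidence = sum(joint)                        # P(Y)
posterior = [j / evidence for j in joint]    # P(X | Y)

def kl(q, p):
    """D(Q, P) = sum_x Q(x) log(Q(x) / P(x))."""
    return sum(qx * math.log(qx / px) for qx, px in zip(q, p) if qx > 0)

def free_energy(q, joint):
    """L(Q) = -sum_x Q(x) log(Q(x) / P(x, Y)); uses only the joint, never P(Y)."""
    return -sum(qx * math.log(qx / jx) for qx, jx in zip(q, joint) if qx > 0)

# An arbitrary (hypothetical) approximating distribution Q.
q = [0.5, 0.3, 0.2]

lhs = math.log(evidence)
rhs = kl(q, posterior) + free_energy(q, joint)
print(lhs, rhs)  # equal up to floating-point rounding
```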

Now, on the left-hand side we have the log of the marginal likelihood, and on the right we have the information divergence plus the variational free energy.

Notice that the left side is not a function of Q. This is really important! It tells us that as we vary Q to minimize D(Q, P), the sum on the right side stays constant.

In other words, any increase in L(Q) is necessarily an equal decrease in D(Q, P). This means that the Q that minimizes D(Q, P) is exactly the Q that maximizes L(Q)!

We can use this to minimize D(Q, P) without ever explicitly knowing the posterior P(X | Y) or the marginal likelihood P(Y).
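A small illustration of this equivalence, assuming a hypothetical conjugate Gaussian model (X ~ N(0, 1), Y | X ~ N(X, 1), one observed y) and taking F to be all Gaussians N(m, s^2). Both L(Q) and D(Q, P) have closed forms here. The divergence uses the exact posterior N(y/2, 1/2), which is only available because the model is conjugate, and serves purely as a check: a grid search finds the same Q whether we maximize L or minimize D.

```python
import math

y = 1.3  # hypothetical observation; model: X ~ N(0, 1), Y | X ~ N(X, 1)

def elbo(m, s):
    """Closed-form L(Q) for Q = N(m, s^2) under the hypothetical Gaussian model."""
    log_prior = -0.5 * math.log(2 * math.pi) - 0.5 * (m**2 + s**2)        # E_Q[log P(X)]
    log_lik = -0.5 * math.log(2 * math.pi) - 0.5 * ((y - m) ** 2 + s**2)  # E_Q[log P(Y | X)]
    entropy = 0.5 * math.log(2 * math.pi * math.e * s**2)                 # -E_Q[log Q(X)]
    return log_prior + log_lik + entropy

# Exact posterior for this conjugate model: N(y/2, 1/2). Used only for checking.
post_mean, post_var = y / 2, 0.5

def kl_to_posterior(m, s):
    """D(Q, P) between N(m, s^2) and the exact posterior N(post_mean, post_var)."""
    return (0.5 * math.log(post_var / s**2)
            + (s**2 + (m - post_mean) ** 2) / (2 * post_var) - 0.5)

# Grid over Q = N(m, s^2) with m in [-2, 2] and s in [0.1, 1.5], step 0.01.
grid = [(-2 + 0.01 * i, 0.1 + 0.01 * j) for i in range(401) for j in range(141)]
best_by_elbo = max(grid, key=lambda ms: elbo(*ms))
best_by_kl = min(grid, key=lambda ms: kl_to_posterior(*ms))
print(best_by_elbo, best_by_kl)  # the same grid point
```

Only `elbo` is needed to find the optimum; `kl_to_posterior` exists solely to confirm that the two searches agree.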

Recalling the definition of the variational free energy, we have…

L(Q) = – ∫ Q(X) log(Q(X) / P(X, Y)) dX
= ∫ Q(X) log P(X, Y) dX – ∫ Q(X) log Q(X) dX
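When these integrals have no closed form, one standard approach, a technique beyond the derivation here, is to estimate L(Q) by Monte Carlo: draw samples from Q and average log P(X, Y) - log Q(X). This only ever evaluates the joint, never P(Y). A sketch under a hypothetical Gaussian model (X ~ N(0, 1), Y | X ~ N(X, 1)) with an arbitrary Q = N(0.3, 0.8^2), compared against the closed-form value:

```python
import math
import random

y = 1.3  # hypothetical observation; model: X ~ N(0, 1), Y | X ~ N(X, 1)

def log_normal(x, mean, var):
    """Log density of N(mean, var) at x."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

def log_joint(x):
    """log P(X, Y) = log P(X) + log P(Y | X); no marginal likelihood needed."""
    return log_normal(x, 0.0, 1.0) + log_normal(y, x, 1.0)

def elbo_monte_carlo(m, s, n=100_000, seed=0):
    """Estimate L(Q) = E_Q[log P(X, Y) - log Q(X)] by sampling from Q = N(m, s^2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x = rng.gauss(m, s)  # one sample from Q
        total += log_joint(x) - log_normal(x, m, s**2)
    return total / n

def elbo_exact(m, s):
    """Closed form of the same expectation for this Gaussian model."""
    return (-math.log(2 * math.pi) - 0.5 * (m**2 + s**2) - 0.5 * ((y - m) ** 2 + s**2)
            + 0.5 * math.log(2 * math.pi * math.e * s**2))

mc = elbo_monte_carlo(0.3, 0.8)
exact = elbo_exact(0.3, 0.8)
print(mc, exact)  # close, up to Monte Carlo noise
```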

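Both of these integrals are computable as long as we made a good choice for the function space F. We can then find Q*, the best approximation to P in F, by maximizing L(Q): exactly, when the maximization can be solved analytically, or approximately with a numerical optimizer. Knowing Q*, we can calculate L(Q*), which serves as a lower bound on the log of the marginal likelihood, log P(Y), because the information divergence is never negative. (In fact L(Q) is a lower bound for every Q; Q* makes it as tight as F allows, and the remaining gap is exactly D(Q*, P).)

log P(Y) = D(Q*, P) + L(Q*), and D(Q*, P) ≥ 0
so log P(Y) ≥ L(Q*)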
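To see the bound with a nonzero gap, restrict F so that it cannot contain the true posterior. This sketch assumes the hypothetical Gaussian model X ~ N(0, 1), Y | X ~ N(X, 1) and takes F = {N(m, 1)}, fixed unit variance, purely for illustration. Gradient ascent on L(Q) over m finds Q*, and log P(Y) - L(Q*) equals D(Q*, P), which is positive because the true posterior N(y/2, 1/2) has a different variance.

```python
import math

y = 1.3  # hypothetical observation; model: X ~ N(0, 1), Y | X ~ N(X, 1)

def elbo(m, s=1.0):
    """Closed-form L(Q) for Q = N(m, s^2); here F is restricted to s = 1 (hypothetical choice)."""
    return (-math.log(2 * math.pi) - 0.5 * (m**2 + s**2) - 0.5 * ((y - m) ** 2 + s**2)
            + 0.5 * math.log(2 * math.pi * math.e * s**2))

def d_elbo_dm(m):
    """Derivative of L with respect to m: -m + (y - m)."""
    return y - 2 * m

# Gradient ascent on L over the restricted family {N(m, 1)}.
m, step = 0.0, 0.1
for _ in range(200):
    m += step * d_elbo_dm(m)

# Exact log P(Y) = log N(y; 0, 2), available here only because the model is conjugate.
log_evidence = -0.5 * math.log(2 * math.pi * 2.0) - y**2 / (2 * 2.0)
gap = log_evidence - elbo(m)  # equals D(Q*, P) > 0, since N(y/2, 1/2) is not in F
print(m, elbo(m), log_evidence, gap)
```

For this family the optimum is m = y/2, and the gap works out to 0.5 - 0.5 log 2, about 0.153, independent of y.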

Summing up…

  1. Variational Bayesian inference approximates the posterior distribution P(X | Y) with a distribution Q(X) from a tractable function space F.
  2. We find the Q* in F that is closest to P(X | Y), in the sense of minimizing D(Q, P), by maximizing L(Q), which requires only the joint P(X, Y).
  3. L(Q*) gives us a lower bound on the log of the marginal likelihood, log P(Y).