From Variational Inference to Variational Autoencoders

In previous posts, we introduced Variational Inference, mainly from the perspective of computing an intractable posterior. However, this is just one use of a very general framework. In this post, I will bridge the gap between those posts and Variational Autoencoders.

Recap

In previous posts, we had a fixed model and wanted to compute an approximate posterior, often because computing it exactly is intractable. To do this, we derived the ELBO, which uses the KL divergence to measure the 'closeness' of the true and approximate posteriors. Let's denote the parameters of our approximate posterior by $\phi$.

$\mathcal{L} = E_{q}\left[\log \frac{P(X,Z)}{q_{\phi}(Z)}\right] = E_{q}[\log P(X,Z) - \log q_{\phi}(Z)]$
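
It is worth recalling why maximising this quantity gives a 'close' posterior: the ELBO and the KL divergence to the true posterior always sum to the (fixed) log evidence,

$\log P(X) = \mathcal{L} + KL(q_{\phi}(Z) \ \| \ P(Z \mid X))$

so pushing $\mathcal{L}$ up is exactly the same as pushing the KL divergence down.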

In order to find our approximation, $q$, we solve:

$\max_{\phi} \mathcal{L}$

Learning the true joint

The natural question to ask is "what if we don't know the model for the joint distribution?". Let's first imagine that we do have a form for the model - e.g. it is normally distributed - but we don't know its parameters. We can use maximum likelihood to learn the parameters of the true model, which we will denote as $\theta$.

So, if we want to learn both the model and the approximate posterior, we can rewrite the above objective as:

$\max_{\phi}\max_{\theta} \ E_{p_{D}}[\mathcal{L}] = \max_{\phi, \theta} E_{p_{D}}[E_{q}[\log P_{\theta}(X,Z) - \log q_{\phi}(Z)]]$

where $p_{D}(X)$ denotes the distribution of the data. In practice, we have a dataset $X$, and we maximise the mean of the ELBO over this data (equivalently, we minimise the mean of the negative ELBO as a loss function).

This represents a very minor change to the previous framework, and it is quite natural to handle in modern autodiff frameworks.
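
To make this concrete, here is a minimal sketch, assuming PyTorch, of jointly optimising $\theta$ and $\phi$ by gradient ascent on the ELBO. The toy model (a single global latent $Z \sim N(0,1)$ with $X \mid Z \sim N(\theta + Z, 1)$) and all names are my own illustrative assumptions, not something from the earlier posts.

```python
# A minimal sketch, assuming PyTorch; model and names are illustrative only.
import math
import torch

torch.manual_seed(0)
X = 2.0 + torch.randn(100)                     # fake data; we hope to recover theta close to 2

theta = torch.zeros(1, requires_grad=True)     # model parameter, learned by maximum likelihood
m = torch.zeros(1, requires_grad=True)         # variational mean            (part of phi)
log_s = torch.zeros(1, requires_grad=True)     # variational log std         (part of phi)
opt = torch.optim.Adam([theta, m, log_s], lr=1e-2)

def log_normal(x, mean, std):
    """Elementwise log density of N(mean, std^2)."""
    return -0.5 * ((x - mean) / std) ** 2 - torch.log(std) - 0.5 * math.log(2 * math.pi)

for step in range(2000):
    eps = torch.randn(1)
    z = m + torch.exp(log_s) * eps             # reparameterised sample from q_phi(Z)
    log_joint = (log_normal(X, theta + z, torch.tensor(1.0)).sum()
                 + log_normal(z, torch.tensor(0.0), torch.tensor(1.0)).sum())
    log_q = log_normal(z, m, torch.exp(log_s)).sum()
    loss = -(log_joint - log_q)                # negative single-sample estimate of the ELBO
    opt.zero_grad()
    loss.backward()
    opt.step()
```

The only new ingredient compared with the earlier posts is that $\theta$ sits in the optimiser alongside $\phi$; writing the sample as $m + \sigma\epsilon$ is one standard way to let gradients flow through it.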

Variational Autoencoders

One final step brings us to Variational Autoencoders. The problem with doing simple maximum likelihood as outlined above is that it is very constraining. If the data is well described by a simple parametric family such as a Gaussian, this is fine, but if I want to learn a model of the distribution of images that contain zebras, specifying the model explicitly is a problem!

To solve this problem, VAEs replace the explicit model with a neural network. Firstly, we make the near-trivial extension that the approximate posterior is conditioned on the data, $q(z \mid X)$. The original paper took the posterior as:

$q(z \mid X) \sim N(\mu_{q \mid X}, \Sigma_{q \mid X})$

The notation here means that every $X$ value has its own mean and covariance associated with it - so if we have $n$ data points, we learn $n$ means and $n$ covariances.

We also specify the joint as a prior on $Z$ and a likelihood term on $X$. The prior is a unit Gaussian.

$P(Z) \sim N(0,I)$

The form of $P(X \mid Z)$ is either a Gaussian or Bernoulli, depending on the data.
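
Given these choices, it is worth noting (a standard rearrangement, not spelled out above, but following directly from the definitions) that the ELBO splits into a reconstruction term and a KL term against the prior:

$\mathcal{L} = E_{q(z \mid X)}[\log P(X \mid Z)] - KL(q(z \mid X) \ \| \ P(Z))$

For the Gaussian posterior above with a diagonal covariance, the KL term even has a closed form:

$KL(N(\mu, \mathrm{diag}(\sigma^{2})) \ \| \ N(0, I)) = \frac{1}{2}\sum_{j}(\mu_{j}^{2} + \sigma_{j}^{2} - \log \sigma_{j}^{2} - 1)$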

Distributions from Neural Networks

In order to represent both the approximate posterior and the likelihood, we use neural networks to learn their parameters. In the case of the likelihood, we have a neural network that takes a sample of $Z$ as input, and outputs the mean and variance of the likelihood distribution if it is Gaussian, or the Bernoulli parameter if not.

The approximate posterior is parameterised by a neural network that takes a data point as input, and outputs a mean and variance for the approximate posterior at this $X$ value.

In short, the neural networks act a bit like lookup tables - we pass in an input value and get the parameters of the distribution out. From there, everything is much the same as before - we maximise the ELBO.
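
To make the 'lookup table' picture concrete, here is a minimal sketch, assuming PyTorch and a Gaussian likelihood; the layer sizes, dimensions and class names are illustrative assumptions, not anything from the original paper.

```python
# A minimal sketch, assuming PyTorch and a Gaussian likelihood; sizes and
# names below are illustrative only.
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Maps a data point X to the parameters of q(z | X)."""
    def __init__(self, x_dim=784, z_dim=20, hidden=400):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, z_dim)
        self.log_var = nn.Linear(hidden, z_dim)   # log-variance, so the variance stays positive

    def forward(self, x):
        h = self.body(x)
        return self.mu(h), self.log_var(h)

class Decoder(nn.Module):
    """Maps a latent Z to the parameters of P(X | Z)."""
    def __init__(self, x_dim=784, z_dim=20, hidden=400):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, x_dim)
        self.log_var = nn.Linear(hidden, x_dim)

    def forward(self, z):
        h = self.body(z)
        return self.mu(h), self.log_var(h)

# "Lookup table" behaviour: one shared network gives every X its own mean and
# variance, rather than storing separate parameters per data point.
enc, dec = Encoder(), Decoder()
mu_q, log_var_q = enc(torch.randn(32, 784))        # parameters of q(z | X) for a batch of 32 points
mu_x, log_var_x = dec(torch.randn(32, 20))         # parameters of P(X | Z) for 32 latent samples
```

Outputting the log-variance rather than the variance is a common trick to keep the variance positive without constraining the network weights.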

Let's assume that the likelihood is given by a Gaussian:

$P(X \mid Z) \sim N(f(Z), g(Z))$

where $f$ and $g$ are our neural networks. Assuming we don't use a trivial network (e.g. one with no hidden layers or no non-linearities), $f$ and $g$ are likely to be far from linear. This means the joint is no longer Gaussian, but this is a good thing! If the network is sufficiently flexible, it can learn very complex joints that are not at all similar to a Gaussian.
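
A tiny self-contained illustration of this point (plain numpy, with an arbitrary stand-in for $f$ of my own choosing): pushing a Gaussian $Z$ through a nonlinearity and adding Gaussian noise already produces a clearly non-Gaussian distribution over $X$.

```python
# Plain numpy; f here is an arbitrary stand-in for the decoder network.
import numpy as np

rng = np.random.default_rng(0)
z = rng.standard_normal(10_000)                # samples from P(Z) = N(0, 1), one-dimensional here

def f(z):
    """Stand-in for the mean network; any nonlinearity will do."""
    return np.tanh(3.0 * z)

x = f(z) + 0.1 * rng.standard_normal(z.shape)  # X | Z ~ N(f(Z), 0.1^2), i.e. g is constant here

# Histogram counts pile up near -1 and +1 with a dip in between:
counts, _ = np.histogram(x, bins=np.linspace(-1.5, 1.5, 16))
print(counts)                                  # two modes, which no single Gaussian can produce
```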

Some complications to be aware of

One obvious problem comes from the posterior - whereas before we only had to learn a single posterior, now we need one for each data point. If we have lots of data, this can become a problem!

$q(z \mid X) \sim N(h(X), k(X))$

Another, more subtle issue comes from the ELBO itself.

$\mathcal{L} = E_{q}\left[\log \frac{P(X,Z)}{q(Z \mid X)}\right]$

We encounter problems if the supports of the distributions $P$ and $q$ differ. If $q$ places appreciable mass at a point where $P$ is very small, the term inside the expectation becomes hugely negative and the bound becomes very loose (equivalently, the loss we minimise becomes very large). If $P$ is actually or numerically zero at such a point, the bound is minus infinity.

As the bound degrades like this, the gradient signal we get becomes smaller and smaller, and so we may well encounter convergence issues. We will come back to this in more detail in a later post.
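
A quick numerical illustration of the 'numerically zero' case (plain numpy, with toy numbers of my own): once a density underflows to exactly zero, its log - and hence any Monte Carlo estimate of the bound containing that term - is minus infinity.

```python
# Plain numpy; sigma and x are toy numbers chosen to force underflow.
import numpy as np

sigma = 1e-3                                   # a very narrow Gaussian likelihood
x = 5.0                                        # a point far outside its (numerical) support

p = np.exp(-0.5 * (x / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
with np.errstate(divide="ignore"):
    print(p, np.log(p))                        # 0.0 -inf

# Working in log space avoids the -inf, but the term is still enormous and
# swamps everything else in the estimate:
log_p = -0.5 * (x / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)
print(log_p)                                   # roughly -1.25e7
```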

Summary

In this post, we have bridged the (theoretical) gap between VI and VAEs. We have discussed how to extend the ELBO to cope with learning the model as well as the approximation, and how to use neural networks to help us out for complex distributions.

In the next post, we will see some code implementing these ideas, and in the future we will expand on some of the problems raised at the end of the post, and see some recent solutions.