Inference Generation Tradeoffs in VAEs
Introduction
This post covers the recent paper Distribution Matching in Variational Inference by Mihaela Rosca, Balaji Lakshminarayanan and Shakir Mohamed. The paper covers a variety of topics in deep generative modelling, and this post is my own take on some parts of it, so I also recommend reading the original! Many of the ideas in this post have appeared before; for example, see Ferenc's post on a similar topic, which lists a lot of related resources.
Inference Generation Tradeoff in VAEs
In VAEs, we maximise the ELBO, which is a lower bound on the marginal likelihood $p_{\theta}(x)$. We simultaneously learn the parameters $\theta$ of the 'true' (generative) distribution $p$ and the parameters $\phi$ of the approximate posterior $q$.
Recall the ELBO:
$\log p_{\theta}(x) \geq \mathbb{E}_{q_{\phi}(z \mid x)}[\log p_{\theta}(x \mid z)] - \mathrm{KL}[q_{\phi}(z \mid x) \,\|\, p(z)]$
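The bound is easiest to check when every term can be computed exactly. The sketch below assumes a toy model with a binary latent $z \in \{0, 1\}$ rather than the continuous latents of a real VAE; the prior, the likelihood values and the candidate $q$ distributions are made-up numbers. It evaluates the ELBO exactly and confirms that it never exceeds $\log p_{\theta}(x)$, reaching it only when $q$ equals the true posterior $p_{\theta}(z \mid x)$.

```python
import numpy as np

# Hypothetical toy model: binary latent z in {0, 1}, one fixed observation x.
prior = np.array([0.5, 0.5])   # p(z), made-up values
lik = np.array([0.9, 0.2])     # p_theta(x | z) for z = 0 and z = 1, made-up values

def elbo(q):
    """Exact ELBO for a discrete q(z|x): E_q[log p(x|z)] - KL[q(z|x) || p(z)]."""
    recon = np.sum(q * np.log(lik))
    kl = np.sum(q * np.log(q / prior))
    return recon - kl

log_px = np.log(np.sum(prior * lik))            # exact log marginal likelihood
true_post = prior * lik / np.sum(prior * lik)   # exact posterior p_theta(z | x)

# Two arbitrary approximations, then the true posterior itself.
for q in [np.array([0.5, 0.5]), np.array([0.7, 0.3]), true_post]:
    print(q.round(3), round(elbo(q), 4), "<=", round(log_px, 4))
```

The gap between the two sides is exactly $\mathrm{KL}[q_{\phi}(z \mid x) \,\|\, p_{\theta}(z \mid x)]$, which is why the uniform $q$, the furthest from the true posterior, scores worst here.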
As the authors of the paper note, the first term of the ELBO is an approximation of the log-likelihood of the sample: it is the expected reconstruction log-likelihood under $q$. To maximise this term, the network needs to learn $p$ and $q$ such that $q_{\phi}(z \mid x)$ is sufficiently different from $q_{\phi}(z \mid x^{*})$ whenever $x$ and $x^{*}$ are distinct samples. If this isn't true, $z$ will not encode much information about the sample. In addition, the network should learn $p$ such that the sample $x$ has high likelihood under $p_{\theta}(x \mid z)$.
To summarise this point, the first term is large when $q$ encodes a large amount of information about the sample, and also when $p$ assigns high mass to the sample when conditioned on $z$.
The second term, however, is subtracted, so it must be made small in order to maximise the ELBO. It penalises the divergence between the approximate posterior and the prior.
So the inference-generation tradeoff comes from these two terms: the first strives to encode lots of information about $x$ into $q$, whilst the KL divergence pushes $q_{\phi}(z \mid x)$ towards the prior, which does not depend on $x$ at all.
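For the common choice of a diagonal Gaussian encoder $q_{\phi}(z \mid x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ and a standard normal prior (an assumption; the post does not fix these choices), the KL term has the closed form $\frac{1}{2}\sum_j(\mu_j^2 + \sigma_j^2 - 1 - \log\sigma_j^2)$. The sketch below evaluates it and makes the tension visible: the only way to drive the KL to zero is $\mu = 0, \sigma = 1$, which makes $q$ identical for every input, while sharp, input-specific codes (large $|\mu|$, small $\sigma$), which the reconstruction term favours, are exactly what the KL penalises.

```python
import numpy as np

def kl_to_standard_normal(mu, sigma):
    """KL[N(mu, diag(sigma^2)) || N(0, I)], summed over latent dimensions.

    Assumes a diagonal Gaussian q(z|x) and a standard normal prior.
    """
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    return 0.5 * np.sum(mu**2 + sigma**2 - 1.0 - np.log(sigma**2))

# Uninformative code: q(z|x) equals the prior, so the KL penalty vanishes.
print(kl_to_standard_normal([0.0], [1.0]))
# Informative codes: shifting the mean or shrinking the variance both cost KL.
print(kl_to_standard_normal([2.0], [1.0]))
print(kl_to_standard_normal([0.0], [0.1]))
```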
Sampling vs Reconstructions
With VAEs, we train the network by taking a sample, passing it through $q$ to get a latent variable, then reconstructing it by passing the latent sample through $p$. The gradients of the ELBO are then used to update both $\theta$ and $\phi$.
So, during training, the decoder $p$ only ever sees values of $z$ sampled from $q$. However, when we generate new data, we sample $z$ from the prior $p(z)$. If the prior and approximate posterior are very close, this is probably fine, and under this assumption we can generate both good samples and good reconstructions.
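The asymmetry can be made concrete. Assuming a Gaussian encoder with the usual reparameterisation $z = \mu + \sigma\epsilon$ (a standard choice, but an assumption here), the sketch below draws the latents the decoder receives at training time and at generation time. The encoder outputs `mu` and `sigma` are hypothetical values for two training points whose codes sit far from the origin, as can happen when the likelihood term dominates.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical encoder outputs for two training points (1D latent).
mu = np.array([-2.0, 2.0])
sigma = np.array([0.1, 0.1])

# Training / reconstruction path: z ~ q(z|x) via z = mu + sigma * eps.
eps = rng.standard_normal((1000, 2))
z_train = mu + sigma * eps

# Generation path: z ~ p(z) = N(0, 1); the encoder is not used at all.
z_gen = rng.standard_normal(1000)

# Fraction of latents that land within 1 unit of the origin in each case.
print("train:   ", np.mean(np.abs(z_train) < 1.0))
print("generate:", np.mean(np.abs(z_gen) < 1.0))
```

Essentially none of the training-time latents fall near the origin, yet roughly two thirds of the prior samples do, so at generation time the decoder is mostly queried in a region it never saw during training.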
However, the paper shows that this assumption rarely holds in practice: in almost every case the likelihood term dominates, and the KL term remains non-zero, often substantially so. This means we can get good reconstructions but poor samples, which is likely one reason VAEs and their extensions have always struggled to compete with GANs on the quality of generated samples.
This is the crux of the inference-generation tradeoff: by making the likelihood term large, the network cannot keep the approximate posterior reasonably close to the prior, so when we sample from the prior we probably won't get the same quality as when we reconstruct.
This can be seen as a generalization problem: the VAE does just fine when we reconstruct from latents sampled from $q$, but nowhere near as well when we sample latents from the prior $p(z)$. What we want is:
$\int p_{\theta}(x \mid z)\,q_{\phi}(z \mid x^{*})\,p(x^{*})\,dx^{*} = p_{\theta}(x \mid z)\,p(z)$
where $p(x^{*})$ is the true data distribution. In words, we want the 'average' over the true samples to be the same under either the prior or the approximate posterior. Put differently, every region with appreciable mass under the prior should correspond to the encoding of some point on the data manifold. Since $p_{\theta}(x \mid z)$ does not depend on $x^{*}$, it can be pulled out of the integral:
$p_{\theta}(x \mid z) \int q_{\phi}(z \mid x^{*})\,p(x^{*})\,dx^{*} = p_{\theta}(x \mid z)\,p(z)$
$p_{\theta}(x \mid z)\,q_{\phi}(z) = p_{\theta}(x \mid z)\,p(z)$
This implies we want the marginal $q_{\phi}(z)$ (often called the aggregate posterior), not the conditional distribution $q_{\phi}(z \mid x)$, to be close to the prior.
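For a finite dataset the marginal is an equal-weight mixture, $q_{\phi}(z) = \frac{1}{N}\sum_n q_{\phi}(z \mid x_n)$. The sketch below assumes 1D Gaussian encoders, a standard normal prior and two hypothetical data points, and estimates $\mathrm{KL}[q_{\phi}(z) \,\|\, p(z)]$ by Monte Carlo. A collapsed encoder matches the prior exactly, while sharp codes at $\pm 2$ produce a marginal that looks nothing like the prior, which shows up as a large KL.

```python
import numpy as np

def normal_pdf(z, mu, sigma):
    return np.exp(-0.5 * ((z - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

def aggregate_posterior_pdf(z, mus, sigmas):
    """q(z) = (1/N) sum_n q(z | x_n): an equal-weight mixture over the dataset."""
    comps = normal_pdf(z[:, None], mus[None, :], sigmas[None, :])
    return comps.mean(axis=1)

def mc_kl_aggregate_to_prior(mus, sigmas, n=10000, seed=0):
    """Monte Carlo estimate of KL[q(z) || N(0, 1)] using samples from q(z)."""
    rng = np.random.default_rng(seed)
    idx = rng.integers(len(mus), size=n)                 # pick a data point x_n
    z = mus[idx] + sigmas[idx] * rng.standard_normal(n)  # then z ~ q(z | x_n)
    return np.mean(np.log(aggregate_posterior_pdf(z, mus, sigmas))
                   - np.log(normal_pdf(z, 0.0, 1.0)))

# Hypothetical collapsed encoder: every q(z|x) equals the prior, so q(z) = p(z).
print(mc_kl_aggregate_to_prior(np.array([0.0, 0.0]), np.array([1.0, 1.0])))
# Hypothetical informative encoder: sharp codes at -2 and +2.
print(mc_kl_aggregate_to_prior(np.array([-2.0, 2.0]), np.array([0.1, 0.1])))
```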
An example of when this is a problem
Last year, Rui Shu published a blog post called Autoencoding a single bit. Rui proposes a simple problem where the VAE takes either 0 or 1 as input and has a one-dimensional latent variable.
Here I replicate the experiment using exactly the same code as in my previous post on VAEs. Some code (the VAE class and the train function) is omitted for clarity; you can check out the previous post and make the relevant changes easily.
import torch
from torch import nn
from torch.autograd import Variable
import numpy as np

# Single-bit problem: one input dimension and one latent dimension
representation_size = 1
input_size = 1
n_samples = 2
batch_size = 2
n_samples_per_batch = n_samples//input_size

# Dataset: one 0 and one 1, as a float32 column vector of shape (2, 1)
x = np.concatenate([np.zeros(n_samples//2, dtype=np.float32), np.ones(n_samples//2, dtype=np.float32)]).reshape(-1,1)
print(x)

# VAE and train() come from the previous post (omitted here)
model = VAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
for epoch in range(1, 50):
    train(epoch)
As this shows, the KL term dominates in this case. The optimal ELBO is therefore found by moving $q$ close to the prior, and once we do this we essentially force independence between $x$ and $z$, so $p_{\theta}(x \mid z) = p_{\theta}(x)$. Rui's original post explains this in much more detail, so check it out.
The key point here is that this is a scenario where the VAE strongly favours generation over inference: because the KL term dominates, samples generated from the prior and from the approximate posterior will be of roughly the same quality.
However, we sacrifice inference: we have learnt a really poor model for $p(x \mid z)$, one that ignores $z$ entirely.
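To see why collapse is the optimum here, consider a hypothetical setup (not Rui's code or the code from my previous post): a Bernoulli decoder $p_{\theta}(x = 1 \mid z) = \mathrm{sigmoid}(kz)$, Gaussian encoders $q(z \mid 0) = \mathcal{N}(-m, s^2)$ and $q(z \mid 1) = \mathcal{N}(m, s^2)$, and a standard normal prior. The sketch computes the ELBO averaged over the two data points by numerical integration on a grid. The collapsed solution ($m = 0$, $s = 1$, $k = 0$, so the decoder always outputs 0.5) reaches $\log 0.5 \approx -0.693$. No model can do better on average for a balanced bit, because the average ELBO is at most the average $\log p_{\theta}(x)$, which is itself at most $\log 0.5$. An informative encoder reconstructs well but pays far more in KL than it gains.

```python
import numpy as np

# Grid for numerical integration over the 1D latent z.
z = np.linspace(-10.0, 10.0, 20001)
dz = z[1] - z[0]

def normal_pdf(z, mu, s):
    return np.exp(-0.5 * ((z - mu) / s) ** 2) / (s * np.sqrt(2 * np.pi))

def log_sigmoid(a):
    # Numerically stable log(sigmoid(a)) = -log(1 + exp(-a)).
    return -np.logaddexp(0.0, -a)

def average_elbo(m, s, k):
    """Average ELBO over x in {0, 1} for the hypothetical single-bit VAE."""
    total = 0.0
    for x, mu in [(0, -m), (1, m)]:
        q = normal_pdf(z, mu, s)  # q(z | x) evaluated on the grid
        # log p(x | z) for a Bernoulli decoder with logit k * z
        log_lik = log_sigmoid(k * z) if x == 1 else log_sigmoid(-k * z)
        recon = np.sum(q * log_lik) * dz
        kl = 0.5 * (mu**2 + s**2 - 1.0 - np.log(s**2))  # closed form vs N(0, 1)
        total += 0.5 * (recon - kl)  # each data point has weight 1/2
    return total

print("collapsed:  ", average_elbo(m=0.0, s=1.0, k=0.0))
print("informative:", average_elbo(m=2.0, s=0.5, k=5.0))
print("log 0.5:    ", np.log(0.5))
```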
The problem
The real problem lies within the KL divergence. In most VAE formulations, the prior is used to shape the latent space (have a look at a few of my previous posts to see this in action). However, the shape of the latent space is really a property of the marginal distribution $q_{\phi}(z)$, not of any single conditional $q_{\phi}(z \mid x)$.
We would really like the overall shape to be given by the prior, e.g. the unit circle, and for this space to be 'divided up' so that conditioning on different samples picks out different parts of it.
Because the ELBO requires us to minimise $\mathrm{KL}[q_{\phi}(z \mid x) \,\|\, p(z)]$ for every $x$, this is not what we get. What we really want is to minimise $\mathrm{KL}[q_{\phi}(z) \,\|\, p(z)]$. However, we don't have easy access to the marginal of $q$, so the ELBO uses the conditional as a way to approximately achieve this. This is the cause of all our problems!
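Unfortunately, knowing the problem doesn't help much: computing the marginal $q_{\phi}(z) = \int q_{\phi}(z \mid x^{*})\,p(x^{*})\,dx^{*}$ requires integrating over the whole data distribution, which is really expensive. We would also need a bound that involves the marginal. The ELBO Surgery paper shows: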
$\mathbb{E}_{p(x^{*})}\left[\mathrm{KL}[q_{\phi}(z \mid x^{*}) \,\|\, p(z)]\right] \geq \mathrm{KL}[q_{\phi}(z) \,\|\, p(z)]$
So if we simply replace the KL term in the ELBO with the KL on the marginal of $q$, the objective can only increase, and we may no longer have a lower bound on $\log p_{\theta}(x)$.
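The gap in the inequality $\mathbb{E}_{p(x^{*})}[\mathrm{KL}[q_{\phi}(z \mid x^{*}) \,\|\, p(z)]] \geq \mathrm{KL}[q_{\phi}(z) \,\|\, p(z)]$ is known exactly: the averaged KL decomposes as $I_q(x; z) + \mathrm{KL}[q_{\phi}(z) \,\|\, p(z)]$, where $I_q(x; z)$ is the mutual information between data and code under $q_{\phi}(z \mid x)\,p(x)$. Mutual information is non-negative, so the inequality follows, and it is tight only when the code carries no information about the data. The sketch below checks the identity on a small discrete example; the data distribution, encoder table and prior are arbitrary made-up values, not anything from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

n_x, n_z = 4, 3
p_x = np.full(n_x, 1.0 / n_x)                   # data distribution p(x*), uniform
q_z_given_x = rng.dirichlet(np.ones(n_z), n_x)  # rows: q(z | x*), made-up encoder
p_z = np.array([0.2, 0.3, 0.5])                 # prior p(z), made-up values

def kl(a, b):
    """KL divergence between two discrete distributions."""
    return np.sum(a * np.log(a / b))

# Left-hand side: the KL term the ELBO actually uses, averaged over the data.
avg_kl = sum(p_x[i] * kl(q_z_given_x[i], p_z) for i in range(n_x))

# Right-hand side pieces: marginal KL and mutual information under q(x, z).
q_z = p_x @ q_z_given_x                 # aggregate posterior q(z)
marginal_kl = kl(q_z, p_z)
joint = p_x[:, None] * q_z_given_x      # q(x, z) = p(x) q(z | x)
mutual_info = np.sum(joint * np.log(q_z_given_x / q_z[None, :]))

print(avg_kl, mutual_info + marginal_kl, marginal_kl)
```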
Summary
This post has been a quick summary of the recent paper on the inference-generation tradeoff in VAEs. The problem has also been noted in many other places, and in general it should be no surprise: if we pick a very simple prior, it will be constraining.