Adversarial Variational Bayes in Pytorch
In the previous post, we implemented a Variational Autoencoder and pointed out a few problems. The overlap between classes in the latent space was one of the key issues, and the normality assumption on the approximate posterior is also somewhat constraining.
In this post, I implement the recent paper Adversarial Variational Bayes in Pytorch. This addresses some of the issues with VAEs, and also provides some interesting links to GANs, the other popular approach to generative modelling.
Theory
Before diving in, if you are not familiar with VAEs, I suggest you take some time to review my previous post on the theory of VAEs.
The main modification proposed by the AVB paper is to change the encoder from a parameterized Gaussian to a fully implicit distribution. In a vanilla VAE, the encoder network takes a data point, x, as input and outputs the mean and variance of a normal distribution; we then use the reparameterization trick on standard normal samples to draw values of the latent variables.
In AVB, instead of outputting a mean and variance, the encoder returns a z value directly. Its inputs are now a data point, x, and standard Gaussian noise, so the encoder network learns how to transform the random noise into a sample from the approximate posterior directly.
This is most easily shown by figure 2 in the original paper.
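As a minimal sketch of this change (illustrative only: the class, layer sizes and variable names below are placeholders, not the ones used in the full model later in this post), the implicit encoder simply concatenates the data with the noise and maps the result straight to z, whereas a vanilla VAE encoder would return a mean and log-variance and compute z = mu + eps * std.

import torch
from torch import nn

class ImplicitEncoder(nn.Module):
    # AVB-style encoder: takes the data x and Gaussian noise eps as inputs
    # and returns a latent sample z directly, with no mean/variance in between.
    def __init__(self, x_dim=4, z_dim=2, hidden=64):
        super(ImplicitEncoder, self).__init__()
        self.l1 = nn.Linear(x_dim + z_dim, hidden)
        self.l2 = nn.Linear(hidden, z_dim)

    def forward(self, x, eps):
        h = torch.nn.functional.relu(self.l1(torch.cat((x, eps), dim=1)))
        return self.l2(h)  # z = f(x, eps): a sample from the implicit posterior

# Example usage: 8 data points of dimension 4, 2-dimensional latent space
enc = ImplicitEncoder()
x_batch = torch.rand(8, 4)
eps = torch.normal(torch.zeros(8, 2), torch.ones(8, 2))
z = enc(x_batch, eps)  # shape (8, 2)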
Implicit Likelihood Ratio
Once we have made the above change, we have a problem. The ELBO:
$\max_{\phi}\max_{\theta} \ E_{p_{D}}[\mathcal{L}] = \max_{\phi, \theta} E_{p_{D}}[E_{q_{\phi}(z \mid x)}[\log p_{\theta}(X \mid Z) + \log p(Z) - \log q_{\phi}(Z \mid X)]]$
contains probabilities under the approximate posterior, q. Now that we have an implicit model for q, we can no longer evaluate the probability of a sample under it. In order to deal with this problem, we use the idea from Learning in Implicit Generative Models, covered in a previous post. The idea is to use a discriminator to approximate the log ratio of q and the prior.
We introduce a second network, $T(x,z)$, which outputs a single value for each sample. If we label the samples from the posterior as class 1 and the samples from the prior as class 0, we can pass the output of T through a sigmoid and train it using binary cross entropy, in exactly the same way as in my previous post on Discriminators as likelihood ratios. The output of T (without the sigmoid) is then the logit, which at optimality of the discriminator equals the log ratio of the two distributions, $\log q_{\phi}(Z \mid X) - \log p(Z)$. We can therefore substitute $-T(X,Z)$ directly for the ratio term in the ELBO, $\log p(Z) - \log q_{\phi}(Z \mid X)$.
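For completeness, the standard derivation of this fact: with posterior samples labelled 1 and prior samples labelled 0 (and one sample of each per data point), the Bayes-optimal discriminator satisfies
$\sigma(T^{*}(x, z)) = \frac{q_{\phi}(z \mid x)}{q_{\phi}(z \mid x) + p(z)} \implies T^{*}(x, z) = \log q_{\phi}(z \mid x) - \log p(z)$
so that $\log p(Z) - \log q_{\phi}(Z \mid X) = -T^{*}(X, Z)$.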
This gives us two objectives that we optimise in an alternating fashion. The discriminator is trained by minimising the binary cross entropy (equivalently, maximising $\mathcal{L}_{D}$ below), and the encoder and decoder are trained by maximising the ELBO, but with the discriminator's estimate in place of the analytical $\log p(Z) - \log q_{\phi}(Z \mid X)$.
$\max_{T} \ \mathcal{L}_{D} = E_{p_{D}}[E_{q_{\phi}(z \mid x)}[\log \sigma(T(X, Z))] + E_{p(z)}[\log(1 - \sigma(T(X, Z)))]]$
$\max_{\theta, \phi} \ \mathcal{L}_{G} = E_{p_{D}}[E_{q_{\phi}(z \mid x)}[\log p_{\theta}(X \mid Z) - T(X, Z)]]$
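Note that if the discriminator were trained to optimality, substituting $T^{*}(X, Z) = \log q_{\phi}(Z \mid X) - \log p(Z)$ into $\mathcal{L}_{G}$ would recover exactly the ELBO above:
$\mathcal{L}_{G} = E_{p_{D}}[E_{q_{\phi}}[\log p_{\theta}(X \mid Z) - T^{*}(X, Z)]] = E_{p_{D}}[E_{q_{\phi}}[\log p_{\theta}(X \mid Z) + \log p(Z) - \log q_{\phi}(Z \mid X)]]$
In practice T is only approximately optimal, so $\mathcal{L}_{G}$ is an estimate of the ELBO rather than the exact quantity.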
In practice
I continue with the same toy example I used in the VAE post.
import torch
from torch import nn
from torch.autograd import Variable
import numpy as np
representation_size = 2
input_size = 4
n_samples = 2000
batch_size = 500
gen_hidden_size = 200
enc_hidden_size = 200
disc_hidden_size = 200
n_samples_per_batch = n_samples//input_size
y = np.array([i for i in range(input_size) for _ in range(n_samples_per_batch)])
d = np.identity(input_size)
x = np.array([d[i] for i in y], dtype=np.float32)
Some example data points:
print(x[[10, 58, 610, 790, 1123, 1258, 1506, 1988]])
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # Decoder (generator) layers
        self.gen_l1 = torch.nn.Linear(representation_size, gen_hidden_size)
        self.gen_l2 = torch.nn.Linear(gen_hidden_size, input_size)
        # Encoder layers: the input is the data concatenated with Gaussian noise
        self.enc_l1 = torch.nn.Linear(input_size+representation_size,
                                      enc_hidden_size)
        self.enc_l2 = torch.nn.Linear(enc_hidden_size, representation_size)
        # Discriminator T(x, z)
        self.disc_l1 = torch.nn.Linear(input_size+representation_size,
                                       disc_hidden_size)
        self.disc_l2 = torch.nn.Linear(disc_hidden_size, 1)

        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def sample_prior(self, s):
        # Standard normal samples during training; zeros (the prior mean) at eval time
        if self.training:
            m = torch.zeros((s.data.shape[0], representation_size))
            std = torch.ones((s.data.shape[0], representation_size))
            d = Variable(torch.normal(m, std))
        else:
            d = Variable(torch.zeros((s.data.shape[0], representation_size)))
        return d

    def discriminator(self, x, z):
        i = torch.cat((x, z), dim=1)
        h = self.relu(self.disc_l1(i))
        return self.disc_l2(h)

    def sample_posterior(self, x):
        # Implicit posterior: concatenate x with noise and map directly to z
        i = torch.cat((x, self.sample_prior(x)), dim=1)
        h = self.relu(self.enc_l1(i))
        return self.enc_l2(h)

    def decoder(self, z):
        i = self.relu(self.gen_l1(z))
        h = self.sigmoid(self.gen_l2(i))
        return h

    def forward(self, x):
        z_p = self.sample_prior(x)
        z_q = self.sample_posterior(x)
        log_d_prior = self.discriminator(x, z_p)
        log_d_posterior = self.discriminator(x, z_q)
        # Discriminator loss: posterior samples are class 1, prior samples are class 0
        disc_loss = torch.mean(
            torch.nn.functional.binary_cross_entropy_with_logits(
                log_d_posterior, torch.ones_like(log_d_posterior)
            )
            + torch.nn.functional.binary_cross_entropy_with_logits(
                log_d_prior, torch.zeros_like(log_d_prior))
        )

        x_recon = self.decoder(z_q)
        recon_likelihood = -torch.nn.functional.binary_cross_entropy(
            x_recon, x)*x.data.shape[0]
        # Negative ELBO, with T(x, z) standing in for log q(z|x) - log p(z)
        gen_loss = torch.mean(log_d_posterior) - torch.mean(recon_likelihood)

        return disc_loss, gen_loss
model = VAE()

# Separate the discriminator's parameters from the encoder/decoder parameters,
# so that each loss only updates its own network.
disc_params = []
gen_params = []
for name, param in model.named_parameters():
    if 'disc' in name:
        disc_params.append(param)
    else:
        gen_params.append(param)

disc_optimizer = torch.optim.Adam(disc_params, lr=1e-3)
gen_optimizer = torch.optim.Adam(gen_params, lr=1e-3)
def train(epoch, batches_per_epoch=501, log_interval=500):
    model.train()
    ind = np.arange(x.shape[0])
    for i in range(batches_per_epoch):
        # Sample a random minibatch
        data = torch.from_numpy(x[np.random.choice(ind, size=batch_size)])
        data = Variable(data, requires_grad=False)

        discrim_loss, gen_loss = model(data)

        # Update the encoder/decoder on the generator loss
        gen_optimizer.zero_grad()
        gen_loss.backward(retain_graph=True)
        gen_optimizer.step()

        # Update the discriminator on the classification loss
        disc_optimizer.zero_grad()
        discrim_loss.backward(retain_graph=True)
        disc_optimizer.step()

        if (i % log_interval == 0) and (epoch % 1 == 0):
            # Print progress
            print('Train Epoch: {} [{}/{}]\tDisc loss: {:.6f}\tGen loss: {:.6f}'.format(
                epoch, i * batch_size, batch_size*batches_per_epoch,
                discrim_loss.data[0] / len(data), gen_loss.data[0] / len(data)))
    print('====> Epoch: {} done!'.format(epoch))
for epoch in range(1, 15):
    train(epoch)
data = Variable(torch.from_numpy(x), requires_grad=False)
model.train()
zs = model.sample_posterior(data).data.numpy()
import matplotlib.pyplot as plt
%matplotlib inline
plt.scatter(zs[:,0], zs[:, 1], c=y)
The above plot shows stochastic z values, since the model is in training mode and noise is injected into the encoder. As we can see, the clusters are much more clearly separated than the equivalent plot from the vanilla VAE.
data = Variable(torch.from_numpy(x), requires_grad=False)
model.eval()
zs = model.sample_posterior(data).data.numpy()
plt.scatter(zs[:,0], zs[:, 1], c=y)
As this shows, the model has learnt that there should be 4 means for the posterior, and has positioned them far away from each other.
Some example points
The output of the decoder is the Bernoulli probability that each element is 1.
test_point = np.array([0.5, 0.6], dtype=np.float32).reshape(1,-1)
test_point = Variable(torch.from_numpy(test_point), requires_grad=False)
s = model.decoder(test_point)
s.data
As the boundary is much sharper, it is clear for almost any point which class the latent variable corresponds to. If we pick a point on the border, the network outputs much lower probabilities.
test_point = np.array([0., 0.], dtype=np.float32).reshape(1,-1)
test_point = Variable(torch.from_numpy(test_point), requires_grad=False)
s = model.decoder(test_point)
s.data
Summary
In this post I have implemented the AVB algorithm in Pytorch, and shown that it provides more intuitive latent codes than a vanilla VAE. This has been the first post to incorporate ideas from implicit generative modelling, and I hope to go over some more substantial theory in future posts.