Variational Autoencoder (VAE) — PyTorch Tutorial
In a standard autoencoder, each input is mapped to a single fixed point in the latent space, which suits models that make deterministic predictions. A variational autoencoder (VAE) instead maps each input to a probability distribution over the latent space, described by a mean (μ) and a standard deviation (σ). Because we can sample from this distribution and decode the samples, a VAE is a generative model. To find out more about the intuition behind VAEs, I recommend reading Understanding Variational Autoencoders (VAEs) and What is a variational autoencoder?
In this article, we focus only on building a simple VAE in PyTorch and visualizing its latent representation after training on the MNIST dataset. Let’s begin by importing some libraries:
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from mpl_toolkits.axes_grid1 import ImageGrid
from torchvision.utils import save_image, make_grid
Next, we download the MNIST dataset and create the dataloaders:
# create a transform to apply to each datapoint
transform = transforms.Compose([transforms.ToTensor()])
# download the MNIST training and test splits
path = '~/datasets'
train_dataset = MNIST(path, train=True, transform=transform, download=True)
test_dataset = MNIST(path, train=False, transform=transform, download=True)
# create train and test dataloaders
batch_size = 100
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
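The fully connected VAE in this tutorial expects each image as a flat 784-vector. The sketch below shows how a batch coming out of a `DataLoader` can be flattened. It uses random tensors as a hypothetical stand-in for MNIST (the names `fake_images`, `fake_labels` and `fake_loader` are assumptions for illustration), so it runs without downloading anything.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 250 random "images" of shape (1, 28, 28)
# with integer labels 0-9, so the example runs without a download.
fake_images = torch.rand(250, 1, 28, 28)
fake_labels = torch.randint(0, 10, (250,))
fake_loader = DataLoader(TensorDataset(fake_images, fake_labels),
                         batch_size=100, shuffle=False)

batch_shapes = []
for x, _ in fake_loader:
    # flatten each image to a 784-vector; x.size(0) also handles
    # a last batch that is smaller than batch_size
    flat = x.view(x.size(0), -1)
    batch_shapes.append(tuple(flat.shape))

print(batch_shapes)  # [(100, 784), (100, 784), (50, 784)]
```

Reading the batch size from `x.size(0)` rather than hard-coding it keeps the reshape valid when the dataset size is not a multiple of the batch size.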
Let’s visualize some sample training data:
# get 25 sample training images for visualization
dataiter = iter(train_loader)
image = next(dataiter)  # one (images, labels) batch
num_samples = 25
sample_images = [image[0][i, 0] for i in range(num_samples)]
fig = plt.figure(figsize=(5, 5))
grid = ImageGrid(fig, 111, nrows_ncols=(5, 5), axes_pad=0.1)
for ax, im in zip(grid, sample_images):
    ax.imshow(im, cmap='gray')
    ax.axis('off')
plt.show()
Now we create a simple VAE with a fully connected encoder and decoder. The input dimension is 784, the flattened size of an MNIST image (28×28). The encoder compresses the input through a 400-unit layer to a 200-unit hidden representation (latent_dim in the code), and two linear layers then map it to the mean (μ) and log-variance (log σ²) vectors that form our variational representation. To sample a latent vector we use the reparameterization trick: we draw ε from a standard normal distribution and compute z = μ + σ·ε. This moves the randomness into ε, so gradients can flow through μ and σ during backpropagation. Read more on reparameterization here.
The μ and log σ² vectors each have dimension 2, so the latent space is two-dimensional. This continuous 2D space is what lets us sample and decode images from the VAE.
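To make the reparameterization step concrete, here is a minimal sketch in isolation. The `mean` and `logvar` values are made up for illustration and stand in for the encoder outputs. The aim is only to show that gradients reach both tensors even though `z` is random.

```python
import torch

torch.manual_seed(0)

# Hypothetical encoder outputs for one input: a 2D mean and log-variance.
mean = torch.tensor([0.5, -1.0], requires_grad=True)
logvar = torch.tensor([0.0, -2.0], requires_grad=True)

# Reparameterization: z = mean + std * eps, with eps ~ N(0, I).
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mean + std * eps

# The randomness lives in eps, so gradients flow back to mean and logvar.
z.sum().backward()
print(mean.grad)    # dz/dmean is 1 for every component
print(logvar.grad)  # dz/dlogvar is 0.5 * std * eps
```

Sampling `z` directly from a normal distribution parameterized by `mean` and `std` would not give the optimizer a differentiable path back to the encoder. Writing the sample as a deterministic function of `mean`, `std` and an independent noise term does.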
class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=200, device=device):
        super(VAE, self).__init__()
        # encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, latent_dim),
            nn.LeakyReLU(0.2)
        )
        # latent mean and log-variance
        self.mean_layer = nn.Linear(latent_dim, 2)
        self.logvar_layer = nn.Linear(latent_dim, 2)
        # decoder
        self.decoder = nn.Sequential(
            nn.Linear(2, latent_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )
    def encode(self, x):
        x = self.encoder(x)
        mean, logvar = self.mean_layer(x), self.logvar_layer(x)
        return mean, logvar
    def reparameterization(self, mean, logvar):
        # convert log-variance to standard deviation before scaling the noise
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)  # created on the same device as std
        z = mean + std * epsilon
        return z
    def decode(self, x):
        return self.decoder(x)
    def forward(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterization(mean, logvar)
        x_hat = self.decode(z)
        return x_hat, mean, logvar
Now, we can define our model and optimizer:
model = VAE().to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
The VAE loss consists of a reconstruction loss (named reproduction_loss in the code) and the Kullback–Leibler (KL) divergence. The KL divergence measures how much one probability distribution differs from another; here it penalizes the encoder's distribution for drifting away from a standard normal prior. It is not a true distance metric, because it is not symmetric. KL divergence is an important concept in generative modelling, but we won’t go into more detail in this tutorial. Read more: Intuitive Guide to Understanding KL Divergence.
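For reference, the KL term has a closed form when the encoder outputs a diagonal Gaussian and the prior is a standard normal. This is the expression the code implements, with log_var playing the role of log σ². The index j and the latent dimension J (here J = 2) are notation introduced for this sketch.

```latex
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big)
  = -\frac{1}{2} \sum_{j=1}^{J} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)
```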
def loss_function(x, x_hat, mean, log_var):
    reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return reproduction_loss + KLD
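As a sanity check, the closed-form KL term in the loss function above can be compared against PyTorch's own `torch.distributions.kl_divergence`. The `mean` and `log_var` tensors below are made-up values standing in for encoder outputs; this is a sketch, not part of the training code.

```python
import torch
from torch.distributions import Normal, kl_divergence

# Hypothetical encoder outputs for a batch of 3 inputs with a 2D latent space.
mean = torch.tensor([[0.0, 0.0], [1.0, -0.5], [2.0, 0.3]])
log_var = torch.tensor([[0.0, 0.0], [-1.0, 0.5], [0.2, -2.0]])

# Closed form, written the same way as in the loss function.
kld_closed_form = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

# Reference: KL(N(mean, std) || N(0, 1)), summed over every element.
q = Normal(mean, torch.exp(0.5 * log_var))
p = Normal(torch.zeros_like(mean), torch.ones_like(mean))
kld_reference = kl_divergence(q, p).sum()

print(kld_closed_form.item(), kld_reference.item())
```

The first row (zero mean, zero log-variance) already matches the prior, so it contributes nothing to either sum.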
Finally, we can train our model:
def train(model, optimizer, epochs, device):
    model.train()
    for epoch in range(epochs):
        overall_loss = 0
        for batch_idx, (x, _) in enumerate(train_loader):
            # flatten each image to a 784-vector; x.size(0) also handles a smaller last batch
            x = x.view(x.size(0), -1).to(device)
            optimizer.zero_grad()
            x_hat, mean, log_var = model(x)
            loss = loss_function(x, x_hat, mean, log_var)
            overall_loss += loss.item()
            loss.backward()
            optimizer.step()
        # average loss per training sample
        print("\tEpoch", epoch + 1, "\tAverage Loss: ", overall_loss / len(train_loader.dataset))
    return overall_loss
train(model, optimizer, epochs=50, device=device)
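The tutorial creates a test dataloader but only trains on the training split. One possible way to measure the loss on held-out data is sketched below. The names `evaluate`, `vae_loss` and `TinyVAE` are hypothetical: `TinyVAE` is a deliberately small stand-in and the data is random, so the sketch runs on its own. In practice you would pass the trained model and `test_loader` instead.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def vae_loss(x, x_hat, mean, log_var):
    # same form as the tutorial's loss: reconstruction term plus KL term
    bce = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return bce + kld

@torch.no_grad()
def evaluate(model, loader, device):
    """Average per-sample loss over a loader, without updating any weights."""
    model.eval()
    total = 0.0
    for x, _ in loader:
        x = x.view(x.size(0), -1).to(device)
        x_hat, mean, log_var = model(x)
        total += vae_loss(x, x_hat, mean, log_var).item()
    return total / len(loader.dataset)

class TinyVAE(nn.Module):
    # Hypothetical minimal stand-in so the sketch is self-contained.
    def __init__(self):
        super().__init__()
        self.enc = nn.Linear(784, 4)  # outputs mean (2) and log-variance (2)
        self.dec = nn.Linear(2, 784)

    def forward(self, x):
        mean, log_var = self.enc(x).chunk(2, dim=1)
        z = mean + torch.exp(0.5 * log_var) * torch.randn_like(mean)
        return torch.sigmoid(self.dec(z)), mean, log_var

torch.manual_seed(0)
data = TensorDataset(torch.rand(30, 1, 28, 28), torch.zeros(30))
avg = evaluate(TinyVAE(), DataLoader(data, batch_size=8), torch.device("cpu"))
print(avg)
```

Dividing by `len(loader.dataset)` gives a per-sample average that stays correct even when the last batch is smaller than the others.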
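We now know that all we need to generate an image from the latent space is two float values: the two coordinates of a latent vector z. (The function below names its arguments mean and var, but they are simply z[0] and z[1], passed straight to the decoder.) Let’s generate some images from the latent space: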
def generate_digit(mean, var):
    # the two arguments are the coordinates of a 2D latent vector z
    z_sample = torch.tensor([[mean, var]], dtype=torch.float).to(device)
    x_decoded = model.decode(z_sample)
    digit = x_decoded.detach().cpu().reshape(28, 28)  # reshape vector to 2D array
    plt.imshow(digit, cmap='gray')
    plt.axis('off')
    plt.show()
generate_digit(0.0, 1.0)
generate_digit(1.0, 0.0)
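Instead of choosing latent coordinates by hand, a common alternative is to draw them from the standard normal prior and decode a whole batch at once. The sketch below assumes a decoder with the same layer sizes as the one in this tutorial. It is an untrained stand-in so the code runs on its own, which means the outputs are noise rather than digits. With a trained model you would use `model.decoder` in its place.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Untrained stand-in with the same layer sizes as the tutorial's decoder.
decoder = nn.Sequential(
    nn.Linear(2, 200), nn.LeakyReLU(0.2),
    nn.Linear(200, 400), nn.LeakyReLU(0.2),
    nn.Linear(400, 784), nn.Sigmoid(),
)

with torch.no_grad():
    z = torch.randn(16, 2)                # 16 latent points drawn from N(0, I)
    images = decoder(z).view(-1, 28, 28)  # one 28x28 image per latent point

print(images.shape)  # torch.Size([16, 28, 28])
```

The resulting tensor can be shown with the same `ImageGrid` code used earlier for the training samples.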
Interesting! A more impressive view of the latent space:
def plot_latent_space(model, scale=1.0, n=25, digit_size=28, figsize=15):
    # display an n*n 2D manifold of digits
    figure = np.zeros((digit_size * n, digit_size * n))
    # construct a grid of latent coordinates
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = torch.tensor([[xi, yi]], dtype=torch.float).to(device)
            x_decoded = model.decode(z_sample)
            digit = x_decoded[0].detach().cpu().reshape(digit_size, digit_size)
            figure[i * digit_size : (i + 1) * digit_size, j * digit_size : (j + 1) * digit_size] = digit
    plt.figure(figsize=(figsize, figsize))
    plt.title('VAE Latent Space Visualization')
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    # the axes are the two latent coordinates, not a mean and a variance
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()
plot_latent_space(model)
This is what the latent space looks like for latent coordinates between -1.0 and 1.0. What happens if we widen the range to between -5.0 and 5.0?
plot_latent_space(model, scale=5.0)
Again, interesting! We can now see the range of latent values within which most digit representations lie. We now know how to build a simple VAE from scratch, sample images, and visualize the latent space. VAEs do not end here, though: more advanced techniques make representation learning even more fascinating.
Try the code in Google Colab: