Vote Wrapper: Generating MNIST Digits with VAE

In this tutorial we show how to wrap a VAE model with the Vote wrapper and train it on the MNIST dataset.

Step 1: Initial Setup

This step is not specific to the wrapper; it depends on the model and dataset you are using. In this tutorial, it covers creating a VAE model and loading the MNIST dataset.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch.distributions
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import tqdm

Note: Edit the line below to match your GPU setup. The code falls back to the CPU if no GPU is available.

device = "cuda:0" if torch.cuda.is_available() else "cpu"

Creating the Decoder module.

class Decoder(nn.Module):
    def __init__(self, latent_dims):
        super(Decoder, self).__init__()
        self.lin = nn.Linear(latent_dims, 32 * 7 * 7)
        self.decoder = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        z = self.lin(z)
        z = torch.reshape(z, (z.shape[0], 32, 7, 7))
        return self.decoder(z)
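
As a quick check of the layer sizes above, the following sketch applies the output-size formula given in the PyTorch documentation for nn.ConvTranspose2d to the two decoder layers. The helper name conv_transpose2d_out is made up for this illustration and is not part of PyTorch; it assumes square inputs.

```python
def conv_transpose2d_out(size, kernel_size, stride=1, padding=0,
                         output_padding=0, dilation=1):
    """Output height/width of a transposed convolution on a square input.

    Hypothetical helper that mirrors the formula in the
    nn.ConvTranspose2d documentation.
    """
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


# The decoder reshapes the linear output to 32 x 7 x 7, then applies two
# transposed convolutions with kernel_size=4, stride=2, padding=1.
size = 7
sizes = [size]
for _ in range(2):
    size = conv_transpose2d_out(size, kernel_size=4, stride=2, padding=1)
    sizes.append(size)

print(sizes)  # [7, 14, 28]
```

Each layer doubles the spatial size, so the 7 x 7 feature map ends up as a 28 x 28 image, which matches the MNIST resolution.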

Creating the Encoder module.

class VariationalEncoder(nn.Module):
    def __init__(self, latent_dims):
        super(VariationalEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(32 * 7 * 7, latent_dims)
        self.fc_sigma = nn.Linear(32 * 7 * 7, latent_dims)

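        self.N = torch.distributions.Normal(0, 1)
        # Move the distribution parameters to `device` so sampling happens there.
        # .to(device) also works when `device` is "cpu", unlike .cuda(device).
        self.N.loc = self.N.loc.to(device)
        self.N.scale = self.N.scale.to(device)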
        self.kl = 0

    def forward(self, x):
        x = self.encoder(x)
        mu = self.fc_mu(x)
        sigma = torch.exp(self.fc_sigma(x))
        z = mu + sigma * self.N.sample(mu.shape)
        # KL(N(mu, sigma^2) || N(0, 1)), summed over the batch and latent dimensions
        self.kl = (0.5 * (sigma**2 + mu**2 - 1) - torch.log(sigma)).sum()
        return z
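
The KL term stored by the encoder measures how far the approximate posterior N(mu, sigma^2) is from the standard normal prior. For a single latent dimension the closed form is 0.5 * (sigma^2 + mu^2 - 1) - log(sigma). The sketch below, in plain Python, compares that closed form with a numerical integration of the definition. The function names and the test values mu=0.5, sigma=0.8 are chosen for this illustration only.

```python
import math


def kl_to_standard_normal(mu, sigma):
    """Closed-form KL(N(mu, sigma^2) || N(0, 1)) for one dimension."""
    return 0.5 * (sigma**2 + mu**2 - 1.0) - math.log(sigma)


def kl_numeric(mu, sigma, lo=-20.0, hi=20.0, steps=100_000):
    """Midpoint-rule estimate of the integral of q(x) * log(q(x) / p(x))."""
    dx = (hi - lo) / steps
    log_norm = 0.5 * math.log(2.0 * math.pi)
    total = 0.0
    for k in range(steps):
        x = lo + (k + 0.5) * dx
        log_q = -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - log_norm
        log_p = -0.5 * x**2 - log_norm
        total += math.exp(log_q) * (log_q - log_p) * dx
    return total


closed = kl_to_standard_normal(0.5, 0.8)
numeric = kl_numeric(0.5, 0.8)
print(abs(closed - numeric) < 1e-6)  # True
```

The divergence is zero exactly when mu = 0 and sigma = 1, so this term pulls the encoder outputs toward the prior during training.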

Creating the VAE module.

class VariationalAutoencoder(nn.Module):
    def __init__(self, latent_dims):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = VariationalEncoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

Loading the MNIST dataset.

data = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        ".", transform=torchvision.transforms.ToTensor(), download=True
    ),
    batch_size=128,
    shuffle=True,
    drop_last=True,
)
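
To see what these loader settings mean for one epoch, here is a small arithmetic sketch. It assumes the standard MNIST training split of 60,000 images, which is what torchvision.datasets.MNIST loads by default.

```python
num_samples = 60_000  # size of the default MNIST training split
batch_size = 128

# With drop_last=True the final, incomplete batch is discarded.
full_batches, leftover = divmod(num_samples, batch_size)

print(full_batches, leftover)  # 468 96
```

Every epoch therefore consists of 468 batches of exactly 128 images, and the 96 skipped images differ from epoch to epoch because shuffle=True.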

The latent space dimension is set to 2 for easier visualization.

vae = VariationalAutoencoder(2).to(device)  # GPU

Step 2: Wrapping the model

To wrap your model with any wrapper, you pass in the model along with a sample input to your model. In this tutorial only the decoder of the VAE is wrapped.

# Wrap the model
from capsa_torch import vote

x, y = next(iter(data))
x = x.to(device)
vae.decoder = vote.Wrapper(n_voters=2)(vae.decoder)

Note

Most wrappers require retraining or fine-tuning. Vote is one of them and will not provide useful outputs otherwise. Read the wrapper documentation to pick the one that suits your use case.

Step 3: Training

Note

When calling the Vote-wrapped decoder during training, add the tile_and_reduce=False flag.

opt = torch.optim.Adam(vae.parameters())
for epoch in range(20):
    for x, y in data:
        x = x.to(device)
        opt.zero_grad()
        z = vae.encoder(x)
        # NOTE: add the tile_and_reduce = False flag for the Vote Wrapper
        x_hat = vae.decoder(z, tile_and_reduce=False, return_risk=False)
        loss = ((x - x_hat) ** 2).sum() + vae.encoder.kl
        loss.backward()
        opt.step()

Step 4: Testing & Evaluation

Once you have trained your wrapped model, it returns a risk value for every output when you call it with return_risk=True. Before looking at the risk predicted by the wrapper, you can visualize how the encoder maps images into the two-dimensional latent space; different colors indicate different digit labels. This gives a sense of how the digits are distributed in the latent space that the decoder reconstructs from.

def plot_latent(autoencoder, data, num_batches=100):
    for i, (x, y) in enumerate(data):
        z = autoencoder.encoder(x.to(device))
        z = z.to('cpu').detach().numpy()
        plt.scatter(z[:, 0], z[:, 1], c=y, cmap='tab10')
        if i > num_batches:
            plt.colorbar()
            break
plot_latent(vae, data)
[Figure: latent-space scatter plot produced by plot_latent(vae, data), colored by digit label]

Then, you can plot the reconstructed images together with the risk estimated by the Vote wrapper. The x and y axes of the plots are the two latent variables.

def normalize(arr):
    return (arr - np.min(arr))/(np.max(arr) - np.min(arr))


def plot_reconstructed(vae, r0=(-1, 1), r1=(-1, 1), n=12):
    w = 28
    img = np.zeros((n*w, n*w))
    risks = np.zeros((n*w, n*w))
    for i, y in enumerate(np.linspace(*r1, n)):
        for j, x in enumerate(np.linspace(*r0, n)):
            z = torch.Tensor([[x, y]]).to(device)
            x_hat, r = vae.decoder(z.repeat(64, 1), return_risk = True)
            x_hat = x_hat[0].reshape(28, 28).to('cpu').detach().numpy()
            r = r[0].reshape(28, 28).to('cpu').detach().numpy()
            img[(n-1-i)*w:(n-1-i+1)*w, j*w:(j+1)*w] = normalize(x_hat)
            risks[(n-1-i)*w:(n-1-i+1)*w, j*w:(j+1)*w] = normalize(r)
            
    fig = plt.figure()
    ax1 = fig.add_subplot(1,2,1)
    ax1.imshow(img, extent=[*r0, *r1])
    ax1.set_title("Reconstructed images")
    ax2 = fig.add_subplot(1,2,2)
    ax2.imshow(risks, extent=[*r0, *r1])
    ax2.set_title("Pixelwise uncertainty for\nthe reconstructed images")
    
    
plot_reconstructed(vae)
[Figure: output of plot_reconstructed(vae), with the panels "Reconstructed images" and "Pixelwise uncertainty for the reconstructed images"]

The uncertainty map shows that high uncertainty appears mostly at the edges of the digits.
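
The index arithmetic in plot_reconstructed can be hard to read at first glance. The sketch below isolates it: each decoded 28 x 28 image is written into one tile of an n*w by n*w canvas, and the row index is flipped so that the smallest value of the second latent variable ends up at the bottom of the plot. The helper tile_slices is a name made up for this illustration; it reproduces the slices used in the function above.

```python
def tile_slices(i, j, n=12, w=28):
    """Row and column slices of the canvas tile for grid position (i, j).

    Hypothetical helper: i indexes the second latent variable (vertical axis),
    j indexes the first latent variable (horizontal axis).
    """
    row = n - 1 - i  # flip so that i = 0 lands in the bottom row of tiles
    return slice(row * w, (row + 1) * w), slice(j * w, (j + 1) * w)


n, w = 12, 28
# Count how often each canvas cell is written across all n * n tiles.
hits = [[0] * (n * w) for _ in range(n * w)]
for i in range(n):
    for j in range(n):
        rows, cols = tile_slices(i, j, n, w)
        for r in range(rows.start, rows.stop):
            for c in range(cols.start, cols.stop):
                hits[r][c] += 1

covered_once = all(h == 1 for row in hits for h in row)
print(tile_slices(0, 0))  # (slice(308, 336, None), slice(0, 28, None))
print(covered_once)       # True
```

Every canvas cell is written exactly once, so the tiles cover the canvas without gaps or overlap.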