Vote Wrapper: MNIST Classification

In this tutorial we show how to wrap an MNIST classifier with the Vote wrapper. Unlike the Sample wrapper, the Vote wrapper requires further training after wrapping, so here we show how to initialize and train the Vote-wrapped model.

Step 1: Initial Setup

import torch
import torchvision
import torch.nn.functional as F
from torch.optim import Adam
import matplotlib.pyplot as plt
import tqdm
torch.manual_seed(42)
n_epochs = 5
batch_size = 128
learning_rate = 1e-3

Load the MNIST training and test sets.

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        ".",
        train=True,
        download=True,
        transform=torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        ".",
        train=False,
        download=True,
        transform=torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

Step 2: Define a model

In this tutorial we will use a simple convolutional neural network.

class ConvMod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = torch.nn.Dropout2d()
        self.fc1 = torch.nn.Linear(320, 50)
        self.fc2 = torch.nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x

model = ConvMod()

Step 3: Instantiate the wrapper and wrap the model

Here we create a new instance of the Vote wrapper with n_voters=8.

Then we wrap our model.

Since the Vote wrapper produces new parameters during wrapping, we need to trigger the wrapping process before instantiating our optimizer. To do so, simply call the wrapped model with sample inputs.

from capsa_torch import vote

wrapper = vote.Wrapper(n_voters=8)
model = wrapper(model)

sample_x, sample_y = next(iter(train_loader))
_ = model(sample_x)
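
As an optional check (not part of the original tutorial) that wrapping has introduced the new voter parameters, you can count the trainable parameters after this first call. This only assumes the wrapped model still behaves like a regular torch.nn.Module, which it must, since we pass model.parameters() to the optimizer below.

# Optional: after the first forward pass, the wrapped model exposes its full
# (voter-augmented) parameter set, which the optimizer below will train.
n_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters after wrapping: {n_params}")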

Step 4: Training

The Vote wrapper must be trained in order to return accurate measures of uncertainty. Note that when training with the Vote wrapper in particular, you must pass tile_and_reduce=False to the wrapped model. You also need to ensure that your batch_size is a multiple of n_voters; if you are starting from an already-trained model, the batch size after wrapping can simply be set to n_voters times the original batch size.
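
With the settings above (batch_size=128, n_voters=8) this constraint is already satisfied. A minimal sanity check you could add before training, assuming the n_voters value chosen earlier, is:

# Sanity check: the training batch size must be divisible by n_voters (8 in this tutorial)
assert batch_size % 8 == 0, "batch_size must be a multiple of n_voters"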

model.train()
model.to('cuda')
optim = Adam(model.parameters(), lr=learning_rate)

for epoch in range(n_epochs):
    prog_bar = tqdm.tqdm(enumerate(train_loader), total=len(train_loader), postfix=dict(epoch=0, loss=0))

    for batch_idx, (x, y) in prog_bar:
        optim.zero_grad()
        pred = model(x.to('cuda'), tile_and_reduce=False)
        loss = F.cross_entropy(pred, y.to('cuda'), reduction="mean")
        loss.backward()
        optim.step()

        prog_bar.set_postfix(epoch=epoch, loss=loss.item())
100%|██████████| 468/468 [00:08<00:00, 54.27it/s, epoch=0, loss=0.22] 
100%|██████████| 468/468 [00:07<00:00, 60.22it/s, epoch=1, loss=0.181]
100%|██████████| 468/468 [00:07<00:00, 60.33it/s, epoch=2, loss=0.259] 
100%|██████████| 468/468 [00:07<00:00, 59.75it/s, epoch=3, loss=0.0963]
100%|██████████| 468/468 [00:07<00:00, 59.91it/s, epoch=4, loss=0.203] 

Step 5: Visualize and analyze the uncertainty

We visualize the test images with the smallest and largest uncertainties, and plot the risks as error bars on the predicted logits for each label.

# Switch to evaluation mode and move the model to CPU
model.eval()
model.to('cpu')

# Get a batch of test images and their labels
test_imgs, labels = next(iter(test_loader))

# Get predictions and risks for the test images
predictions, risk = model(test_imgs, return_risk=True)

# Calculate the average risk for each image
avg_risks = risk.mean(dim=1)

# Find the indices of the images with the smallest and largest average risk
min_risk_index = torch.argmin(avg_risks).item()
max_risk_index = torch.argmax(avg_risks).item()

# Plot the image with the smallest risk
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

axes[0, 0].imshow(test_imgs[min_risk_index][0], cmap="gray")
axes[0, 0].set_title("Image with Smallest Risk")
axes[0, 1].errorbar(list(range(10)), predictions[min_risk_index].detach().numpy(), 
                    yerr=risk[min_risk_index].detach().numpy(), fmt='o')
axes[0, 1].set_ylabel('Logit')
axes[0, 1].set_ylim(-20, 20)
axes[0, 1].set_xticks(list(range(10)))
axes[0, 1].set_xlabel('Label')
axes[0, 1].set_title(f"Prediction: {torch.argmax(predictions[min_risk_index]).item()}, Avg Risk: {avg_risks[min_risk_index].item():.2f}")

# Plot the image with the largest risk
axes[1, 0].imshow(test_imgs[max_risk_index][0], cmap="gray")
axes[1, 0].set_title("Image with Largest Risk")
axes[1, 1].errorbar(list(range(10)), predictions[max_risk_index].detach().numpy(), 
                    yerr=risk[max_risk_index].detach().numpy(), fmt='o')
axes[1, 1].set_ylabel('Logit')
axes[1, 1].set_ylim(-20, 20)
axes[1, 1].set_xticks(list(range(10)))
axes[1, 1].set_xlabel('Label')
axes[1, 1].set_title(f"Prediction: {torch.argmax(predictions[max_risk_index]).item()}, Avg Risk: {avg_risks[max_risk_index].item():.2f}")

plt.tight_layout()
plt.show()
(Output: a 2x2 figure showing the smallest-risk and largest-risk test images, each alongside an error-bar plot of the predicted logits per label.)

Here we note that the classification risk is not the same as the probability distribution over the predicted classes. The image with the largest risk is not necessarily the one whose predicted label has the smallest probability; rather, it is the image whose predicted probability distribution has the largest variance. Epistemic uncertainty in a classification task represents the uncertainty of the predicted probability distribution itself, in other words, it is a “distribution over distributions”.
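
To make this distinction concrete, the illustrative snippet below (not part of the original tutorial) converts the logits to class probabilities and prints them alongside the average risk for the two images above; the highest-risk image need not be the one with the lowest predicted class probability.

# Illustrative comparison of predicted class probability vs. average risk
probs = torch.softmax(predictions.detach(), dim=-1)
for name, idx in [("smallest risk", min_risk_index), ("largest risk", max_risk_index)]:
    top_prob, top_class = probs[idx].max(dim=-1)
    print(f"{name}: predicted class {top_class.item()}, "
          f"class probability {top_prob.item():.2f}, "
          f"avg risk {avg_risks[idx].item():.2f}")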