Vote Wrapper: MNIST Classification¶
In this tutorial we show how to wrap an MNIST classifier with the Vote wrapper. Unlike the Sample wrapper, the Vote wrapper requires further training after wrapping, so we show here how to initialize and train the Vote-wrapped model.
Step 1: Initial Setup¶
import torch
import torchvision
import torch.nn.functional as F
from torch.optim import Adam
import matplotlib.pyplot as plt
import tqdm
torch.manual_seed(42)
n_epochs = 5
batch_size = 128
learning_rate = 1e-3
Load the MNIST training and test sets:
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        ".",
        train=True,
        download=True,
        transform=torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        ".",
        train=False,
        download=True,
        transform=torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
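Because both loaders use drop_last=True, the final partial batch is discarded, so the number of batches per epoch is the dataset size divided by the batch size, rounded down. The arithmetic can be sketched as follows, assuming the standard MNIST split sizes of 60,000 training and 10,000 test images. The helper name batches_per_epoch is introduced here for illustration only.

```python
def batches_per_epoch(n_samples: int, batch_size: int, drop_last: bool = True) -> int:
    """Number of batches a data loader yields in one pass over the dataset."""
    full, remainder = divmod(n_samples, batch_size)
    if drop_last or remainder == 0:
        return full          # the partial batch (if any) is discarded
    return full + 1          # the partial batch is kept as a final, smaller batch

# Assumed split sizes for MNIST: 60,000 training and 10,000 test images
train_batches = batches_per_epoch(60_000, 128)
test_batches = batches_per_epoch(10_000, 128)
print(train_batches, test_batches)  # 468 78
```

The training figure matches the 468 iterations per epoch shown by the progress bar in Step 4.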
Step 2: Define a model¶
In this tutorial we will use a simple convolutional neural network.
class ConvMod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = torch.nn.Dropout2d()
        self.fc1 = torch.nn.Linear(320, 50)
        self.fc2 = torch.nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        # Apply dropout only in training mode (F.dropout defaults to training=True)
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x

model = ConvMod()
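The input size of 320 for fc1 follows from the shapes of the layers above. As a quick check, the sketch below traces the spatial size of a 28x28 MNIST image through the two convolution and pooling stages. The helpers conv_out and pool_out are illustrative names, and they assume stride-1 unpadded convolutions and non-overlapping pooling, which are the defaults used in ConvMod.

```python
def conv_out(size: int, kernel: int) -> int:
    """Spatial size after a stride-1 convolution without padding."""
    return size - kernel + 1

def pool_out(size: int, window: int) -> int:
    """Spatial size after max pooling with stride equal to the window."""
    return size // window

size = 28                                # MNIST images are 28x28
size = pool_out(conv_out(size, 5), 2)    # conv1 + pool: 28 -> 24 -> 12
size = pool_out(conv_out(size, 5), 2)    # conv2 + pool: 12 -> 8 -> 4
flat_features = 20 * size * size         # conv2 outputs 20 channels
print(flat_features)  # 320
```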
Step 3: Instantiate the wrapper and wrap the model¶
Here we create a new instance of the Vote wrapper with n_voters=8. Then we wrap our model.
Since the Vote wrapper produces new parameters during wrapping, we need to trigger the wrapping process before instantiating our optimizer. To do so, simply call the wrapped model with sample inputs.
from capsa_torch import vote

wrapper = vote.Wrapper(n_voters=8)
model = wrapper(model)

# Call the wrapped model once so the new parameters exist before the optimizer is created
sample_x, sample_y = next(iter(train_loader))
_ = model(sample_x)
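To see why the order matters, consider a toy module that registers an extra parameter on its first call. This is only an analogy in plain PyTorch, not a description of how the Vote wrapper is implemented, and the class name LateParams is hypothetical. An optimizer built before the first call never sees the late parameter and therefore would not update it.

```python
import torch

class LateParams(torch.nn.Module):
    """Toy module (hypothetical) that creates an extra parameter on its first call."""

    def __init__(self):
        super().__init__()
        self.base = torch.nn.Linear(4, 2)
        self.register_parameter("extra", None)  # placeholder, filled in later

    def forward(self, x):
        if self.extra is None:
            self.extra = torch.nn.Parameter(torch.zeros(2))
        return self.base(x) + self.extra

toy = LateParams()
early_optim = torch.optim.Adam(toy.parameters(), lr=1e-3)  # built too early
_ = toy(torch.zeros(1, 4))                                 # triggers parameter creation
late_optim = torch.optim.Adam(toy.parameters(), lr=1e-3)   # built after the first call

n_early = len(early_optim.param_groups[0]["params"])
n_late = len(late_optim.param_groups[0]["params"])
print(n_early, n_late)  # 2 3
```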
Step 4: Training¶
The Vote wrapper must be trained in order to return accurate measures of uncertainty. Note that while training with the Vote wrapper in particular, you must pass tile_and_reduce=False to the wrapped model. You will also need to ensure that your batch_size is a multiple of n_voters. If you have an original trained model, the batch size after wrapping can be set to the number of voters times the original batch size.
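The batch-size requirement can be verified before training starts. The following is a minimal sketch: check_vote_batch is a hypothetical helper, not part of capsa_torch, and reading the quotient as a per-voter share of the batch is an assumption about how the wrapper splits a batch.

```python
def check_vote_batch(batch_size: int, n_voters: int) -> int:
    """Return batch_size // n_voters, raising if the division is not exact."""
    if batch_size % n_voters != 0:
        raise ValueError(
            f"batch_size={batch_size} is not a multiple of n_voters={n_voters}"
        )
    return batch_size // n_voters

# Settings used in this tutorial
per_voter = check_vote_batch(128, 8)
print(per_voter)  # 16

# Scaling an original batch size (32 here, chosen only as an example) by the number of voters
original_batch_size = 32
wrapped_batch_size = 8 * original_batch_size
print(wrapped_batch_size)  # 256
```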
model.train()
model.to('cuda')
optim = Adam(model.parameters(), lr=learning_rate)
for epoch in range(n_epochs):
    prog_bar = tqdm.tqdm(enumerate(train_loader), total=len(train_loader), postfix=dict(epoch=0, loss=0))
    for batch_idx, (x, y) in prog_bar:
        optim.zero_grad()
        # tile_and_reduce=False is required while training the Vote wrapper
        pred = model(x.to('cuda'), tile_and_reduce=False)
        loss = F.cross_entropy(pred, y.to('cuda'), reduction="mean")
        loss.backward()
        optim.step()
        prog_bar.set_postfix(epoch=epoch, loss=loss.item())
100%|██████████| 468/468 [00:08<00:00, 54.27it/s, epoch=0, loss=0.22]
100%|██████████| 468/468 [00:07<00:00, 60.22it/s, epoch=1, loss=0.181]
100%|██████████| 468/468 [00:07<00:00, 60.33it/s, epoch=2, loss=0.259]
100%|██████████| 468/468 [00:07<00:00, 59.75it/s, epoch=3, loss=0.0963]
100%|██████████| 468/468 [00:07<00:00, 59.91it/s, epoch=4, loss=0.203]
Step 5: Visualize and analyze the uncertainty¶
We visualize the images with the smallest and largest uncertainties in a batch from the test set, and plot the risks as error bars on the predicted logits for each label.
# Move the model to CPU
model.to('cpu')
# Get a batch of test images and their labels
test_imgs, labels = next(iter(test_loader))
# Get predictions and risks for the test images
predictions, risk = model(test_imgs, return_risk=True)
# Calculate the average risk for each image
avg_risks = risk.mean(dim=1)
# Find the indices of the images with the smallest and largest average risk
min_risk_index = torch.argmin(avg_risks).item()
max_risk_index = torch.argmax(avg_risks).item()
# Plot the image with the smallest risk
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes[0, 0].imshow(test_imgs[min_risk_index][0], cmap="gray")
axes[0, 0].set_title("Image with Smallest Risk")
axes[0, 1].errorbar(list(range(10)), predictions[min_risk_index].detach().numpy(),
                    yerr=risk[min_risk_index].detach().numpy(), fmt='o')
axes[0, 1].set_ylabel('Logit')
axes[0, 1].set_ylim(-20, 20)
axes[0, 1].set_xticks(list(range(10)))
axes[0, 1].set_xlabel('Label')
axes[0, 1].set_title(f"Prediction: {torch.argmax(predictions[min_risk_index]).item()}, Avg Risk: {avg_risks[min_risk_index].item():.2f}")
# Plot the image with the largest risk
axes[1, 0].imshow(test_imgs[max_risk_index][0], cmap="gray")
axes[1, 0].set_title("Image with Largest Risk")
axes[1, 1].errorbar(list(range(10)), predictions[max_risk_index].detach().numpy(),
                    yerr=risk[max_risk_index].detach().numpy(), fmt='o')
axes[1, 1].set_ylabel('Logit')
axes[1, 1].set_ylim(-20, 20)
axes[1, 1].set_xticks(list(range(10)))
axes[1, 1].set_xlabel('Label')
axes[1, 1].set_title(f"Prediction: {torch.argmax(predictions[max_risk_index]).item()}, Avg Risk: {avg_risks[max_risk_index].item():.2f}")
plt.tight_layout()
plt.show()
Note that the risk of a classification is not the same as the probability distribution over the predicted classes. The image with the largest risk is not necessarily the one whose predicted label has the smallest probability; instead, it is the image whose predicted probability distribution has the largest variance. The epistemic uncertainty in a classification task represents the uncertainty of the predicted probability distribution itself; in other words, it is a “distribution over distributions”.
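The difference between low confidence and high risk can be illustrated with a small two-class example. The numbers below are made up, and the variance across voters is used only as an illustrative stand-in for risk; it is not necessarily the quantity the Vote wrapper computes. The function name summarize is introduced here for the sketch.

```python
from statistics import mean, pvariance

def summarize(voter_probs):
    """Return mean class probabilities and the average per-class variance across voters."""
    per_class = list(zip(*voter_probs))                 # regroup values by class
    mean_probs = [mean(c) for c in per_class]
    disagreement = mean([pvariance(c) for c in per_class])
    return mean_probs, disagreement

# Four voters agree on an unconfident 60/40 split
agreeing = [[0.6, 0.4]] * 4
# Three voters are confident in class 0, one is confident in class 1
disagreeing = [[0.95, 0.05]] * 3 + [[0.05, 0.95]]

mean_a, risk_a = summarize(agreeing)
mean_b, risk_b = summarize(disagreeing)

print(max(mean_a), risk_a)
print(max(mean_b), risk_b)
```

In this example the second input has the higher top-class probability on average, yet its voters disagree far more, so it carries the larger variance across voters.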