Vote Wrapper: Regression

In this tutorial we show how to wrap a simple regression model with the Capsa Vote wrapper. The Vote wrapper predicts epistemic uncertainty, which increases in regions where data is sparse or absent. This differs from aleatoric uncertainty, which increases where the relationship between inputs and outputs is noisy or ambiguous (compare with the Sculpt wrapper regression tutorial).

Step 1: Initial Setup

This step is not specific to Capsa; it depends on the model and dataset you are using.

In this example our goal is to estimate a sinusoidal function. We generate noisy synthetic samples of the function, deliberately leaving a gap in the training data between -π and π, and fit a simple neural network to them.

import torch
import math
import tqdm
import matplotlib.pyplot as plt

torch.manual_seed(1)

def fn(x, noise_level=0.1):
    # Sinusoid with additive Gaussian noise
    y = torch.sin(x)
    noise = torch.randn(x.shape)
    y += noise * noise_level
    return y

# Sample the function on [-2π, -π] and [π, 2π], leaving a gap in (-π, π)
x_train_left = torch.linspace(-2*math.pi, -math.pi, 1000).unsqueeze(-1)
x_train_right = torch.linspace(math.pi, 2*math.pi, 1000).unsqueeze(-1)
x_train_full = torch.cat([x_train_left, x_train_right], dim=0)
y_train_full = fn(x_train_full)

plt.scatter(x_train_full, y_train_full, label='Ground truth data', s=5, c='blue', alpha=0.5)
plt.show()
[Figure: scatter plot of the noisy training data, with no samples between -π and π]

In this tutorial we are using a simple sequential model.

from torch import nn

model = nn.Sequential(
    nn.Linear(1, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
)

Step 2: Wrapping the Model

Using Capsa requires constructing a wrapper and applying it to your model. This wrapper adds new outputs and parameters to the model, which will be trained in the following step.

from capsa_torch import vote

# Wrap the model with the Vote wrapper
wrapper = vote.Wrapper(n_voters=4, weight_noise=0.5)
wrapped_model = wrapper(model)

Note

The vote-wrapped model will not produce useful outputs unless it is first trained or finetuned, as seen in the following step. Read the wrapper documentation for more details.
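Before training, you can optionally run a quick sanity check on the wrapped model's call signature. This is a minimal sketch based on the return_risk usage shown later in this tutorial; the risk values themselves are not meaningful until the model has been trained.

# Sanity check only: the risk output is not meaningful before training
with torch.no_grad():
    pred, risk = wrapped_model(x_train_full[:8], return_risk=True)
print(pred.shape, risk.shape)  # expect matching shapes, e.g. torch.Size([8, 1])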

Step 3: Training

from sklearn.model_selection import train_test_split

opt = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
loss_func = torch.nn.MSELoss(reduction='mean')

x_train, x_val, y_train, y_val = train_test_split(x_train_full.numpy(), y_train_full.numpy(), test_size=0.2, random_state=42)
x_train, x_val, y_train, y_val = torch.from_numpy(x_train), torch.from_numpy(x_val), torch.from_numpy(y_train), torch.from_numpy(y_val)

prog_bar = tqdm.trange(500)
for t in prog_bar:
    # With training=True the wrapped model returns predictions only; pass
    # return_risk=True at inference to also get a risk estimate
    y_pred = wrapped_model(x_train, training=True)
    # Use the loss term appropriate for your task (here, mean squared error)
    loss = loss_func(y_pred, y_train)

    opt.zero_grad()
    loss.backward()
    opt.step()
    
    prog_bar.set_postfix(loss = loss.item())
100%|██████████| 500/500 [00:01<00:00, 296.97it/s, loss=0.0109]
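The train_test_split above holds out a validation set that the training loop does not use. As a quick check, we can compute the validation loss on those held-out points; this is a sketch that reuses the same return_risk=True call shown in the evaluation step below.

# Evaluate the held-out validation split (not used during training)
with torch.no_grad():
    y_val_pred, _ = wrapped_model(x_val, return_risk=True)
print(f"Validation MSE: {loss_func(y_val_pred, y_val).item():.4f}")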

Step 4: Testing & Evaluation

Once you have trained your wrapped model, calling it with return_risk=True returns a risk value for every model output. Below, the predicted risk is visualized as the shaded orange band around the prediction.

# Generate test data over the full input range, including the gap
x_test = torch.linspace(-2*math.pi,2*math.pi,500).unsqueeze(-1)
y_test_noisy = fn(x_test)
y_test_true = fn(x_test, noise_level = 0.0)

with torch.no_grad():
    y_pred_test, risk_test = wrapped_model(x_test, return_risk = True)
ln2 = plt.axvspan(-torch.pi, torch.pi, color='lightgray', label='No Training Data')
plt.fill_between(x_test.ravel(),
                 y_pred_test.detach().ravel() + 3*risk_test.ravel(),
                 y_pred_test.detach().ravel() - 3*risk_test.ravel(),
                 color='orange', alpha=1, label='Prediction ± 3·risk')
plt.scatter(x_test, y_test_noisy, label='Ground truth data', s=5, c='blue',alpha=0.5)
plt.plot(x_test, y_pred_test.detach(), label='Prediction', c='red')
plt.plot(x_test, y_test_true, label='True function', c='black')
plt.legend(loc='best')
[Figure: model prediction with the ± 3·risk band, which widens over the region with no training data]
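To complement the plot with a single number, we can compute the mean squared error of the prediction against the noise-free function. This is a small sketch reusing the tensors defined above.

# Overall test error against the noise-free sinusoid
test_mse = torch.mean((y_test_true - y_pred_test) ** 2)
print(f"Test MSE vs. true function: {test_mse.item():.4f}")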

We can also compare the ground truth error with the epistemic risk score.

gt_error = torch.abs(y_test_true - y_pred_test.detach())

plt.figure()
ax = plt.gca()
plt.ylabel('Epistemic Uncertainty')
ln2 = ax.axvspan(-torch.pi, torch.pi, color='lightgray', label='No Training Data')
ln1 = ax.plot(x_test, risk_test, label='Epistemic Uncertainty', c='green')
ax2 = plt.twinx()
ln3 = ax2.plot(x_test, gt_error, label='Ground Truth Error', c='red')
plt.ylabel('Ground Truth Error')
lns = ln1+[ln2]+ln3
labs = [l.get_label() for l in lns]
ax.legend(lns, labs, loc='upper right')
plt.show()
[Figure: epistemic uncertainty and ground truth error, both peaking in the region with no training data]
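As a rough numeric check on the plot above, we can compare the average predicted risk inside the data gap with the risk over the regions that were covered by training data. This sketch reuses x_test and risk_test from the evaluation step; if the wrapper behaves as described, the risk inside the gap should be clearly higher.

# Mean epistemic risk inside vs. outside the region with no training data
gap_mask = (x_test > -math.pi) & (x_test < math.pi)
print(f"Mean risk inside the gap:  {risk_test[gap_mask].mean().item():.4f}")
print(f"Mean risk outside the gap: {risk_test[~gap_mask].mean().item():.4f}")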