Sculpt Wrapper: Regression

In this tutorial we show how to wrap a simple regression model with the Capsa Sculpt wrapper. The data are random points scattered around a simple function. As the results demonstrate, the Sculpt wrapper predicts an uncertainty (risk) that reflects the noise in the data.

Step 1: Initial Setup

This step is common to every workflow, but its contents depend on the specific model and dataset you are using.

In this example our goal is to estimate a sinusoidal function. We generate synthetic data by adding random noise to the function, with much larger noise in one region (0 < x < π/2) than everywhere else, and fit the data with a simple neural network.

import torch
import math
import tqdm
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(1)
np.random.seed(1)

def fn(x):
    y = np.sin(x)
    high_risk_x = (x > 0) & (x < np.pi / 2)
    noise = np.random.normal(0, 1, x.shape) * high_risk_x + np.random.normal(0, 0.2, x.shape) * ~high_risk_x
    y += noise
    return y

x = np.linspace(-math.pi,math.pi,500)[:, np.newaxis]
y = fn(x)

x = torch.Tensor(x)
y = torch.Tensor(y)
plt.scatter(x, y, label='Ground truth data', s=5, c='blue',alpha=0.5)
plt.show()
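Before training, it can help to confirm that the data have the intended noise profile. The hypothetical helper below (not part of Capsa) measures the spread of y around sin(x) separately inside the high-noise region 0 < x < π/2 and everywhere else.

```python
import numpy as np

def residual_std_by_region(x, y):
    """Empirical std of y - sin(x) inside (0, pi/2) and outside it (hypothetical helper)."""
    x = np.asarray(x).ravel()
    residual = np.asarray(y).ravel() - np.sin(x)
    # Same mask as high_risk_x in fn
    high = (x > 0) & (x < np.pi / 2)
    return residual[high].std(), residual[~high].std()
```

With the tensors above, residual_std_by_region(x.numpy(), y.numpy()) should return values close to 1 and 0.2. Only about 125 of the 500 training points fall in the high-noise region, so expect some sampling fluctuation.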

In this tutorial we are using a simple sequential model.

from torch import nn

model = nn.Sequential(nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1))

Step 2: Wrapping the model

To wrap your model with any wrapper, you need to pass in the model along with a sample input to it. The Sculpt wrapper can output a risk as soon as the model is wrapped; however, this risk is not useful until the wrapped model is trained further (even if the unwrapped model is already trained).

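from capsa_torch import sculpt

# Wrap the model with the Sculpt wrapper, using a Normal output distribution
dist = sculpt.Normal
wrapped_model = sculpt.Wrapper(dist)(model)

# Run the wrapped model once on a sample input; the risk is not meaningful before training
with torch.no_grad():
    pred, risk = wrapped_model(x, return_risk=True)

# Training setup (the MSE loss is not used by the training loop below)
learning_rate = 1e-4
loss_func = torch.nn.MSELoss(reduction='mean')
opt = torch.optim.RMSprop(model.parameters(), lr=learning_rate)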

Note

Most wrappers require retraining or fine-tuning. Sculpt is one of them and will not provide useful outputs without it. Read the wrapper documentation to pick the one that suits your use case.

Step 3: Training

prog_bar = tqdm.trange(1000)
for t in prog_bar:
    # With return_risk=True the wrapped model returns two outputs: prediction and risk
    y_pred, risk = wrapped_model(x, return_risk=True)
    # Use the loss function that matches the chosen distribution
    loss = dist.loss_function(y, y_pred, risk)

    opt.zero_grad()
    loss.backward()
    opt.step()
    prog_bar.set_postfix(loss=loss.item())
100%|██████████| 1000/1000 [00:01<00:00, 730.94it/s, loss=0.234]
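The loop delegates the loss to dist.loss_function, whose exact form is not shown here. As a point of reference only, if we assume that risk is the standard deviation σ of a Normal distribution centred on the prediction μ, a natural choice is the Gaussian negative log-likelihood, log σ + (y - μ)²/(2σ²) + ½ log 2π, averaged over samples. The sketch below implements it in plain PyTorch. normal_nll is a hypothetical helper, not part of Capsa, and it is checked against torch.nn.GaussianNLLLoss, which takes the variance σ² rather than σ.

```python
import math
import torch

def normal_nll(y, mu, sigma, eps=1e-6):
    """Mean Gaussian negative log-likelihood of y under N(mu, sigma**2).

    Hypothetical helper; assumes sigma is a standard deviation, not a variance.
    """
    var = sigma.pow(2).clamp_min(eps)  # clamp the variance for numerical stability
    nll = 0.5 * (torch.log(var) + (y - mu).pow(2) / var + math.log(2 * math.pi))
    return nll.mean()

# Compare against PyTorch's built-in Gaussian NLL on random tensors
torch.manual_seed(0)
y_demo = torch.randn(100, 1)
mu_demo = torch.randn(100, 1)
sigma_demo = torch.rand(100, 1) + 0.5
reference = torch.nn.GaussianNLLLoss(full=True, eps=1e-6)(mu_demo, y_demo, sigma_demo ** 2)
print(torch.allclose(normal_nll(y_demo, mu_demo, sigma_demo), reference, atol=1e-5))
```

For a fixed residual r, this loss is minimised at σ = |r|, which is why a likelihood-based loss, unlike plain MSE, lets a risk output learn the noise level.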

Step 4: Testing & Evaluation

Once the wrapped model is trained, it returns a risk value for every output when called with return_risk=True. In the plot below, the predicted risk is visualized as an orange band of ±2 × risk around the prediction.

# Generate test data
x_test = torch.linspace(-math.pi, math.pi, 2000).unsqueeze(-1)
y_test = torch.Tensor(fn(x_test.numpy()))

with torch.no_grad():
    y_pred_test, risk_test = wrapped_model(x_test, return_risk=True)

# Shade prediction ± 2 * risk
upper = y_pred_test.ravel() + 2 * risk_test.ravel()
lower = y_pred_test.ravel() - 2 * risk_test.ravel()
plt.fill_between(x_test.ravel(), upper, lower, color='orange', alpha=1, label='95% Confidence Interval')
plt.scatter(x_test, y_test, label='Ground truth data', s=5, c='blue', alpha=0.5)
plt.plot(x_test, y_pred_test, label='Prediction', c='red')
plt.legend(loc='best')
plt.show()
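The band is drawn as prediction ± 2 × risk and labelled a 95% interval. That label is approximately right if risk is the standard deviation of a Normal predictive distribution, which is our assumption here, not something stated by the wrapper. The small hypothetical helper below computes the probability mass within k standard deviations of a Normal mean, using the standard error function.

```python
import math

def normal_coverage(k):
    """Probability that a Normal sample falls within k standard deviations of its mean."""
    return math.erf(k / math.sqrt(2))

print(f"+/-2 sigma covers {normal_coverage(2):.4f}")        # 0.9545
print(f"+/-1.96 sigma covers {normal_coverage(1.96):.4f}")  # 0.9500
```

So ±2σ covers about 95.45%, while exactly 95% corresponds to ±1.96σ; for a visual check the difference is negligible.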
plt.plot(x_test, risk_test, label='Sculpt Wrapper Risk', c='green')
plt.plot([-np.pi, 0, 0, np.pi/2, np.pi/2, np.pi], [0.2,0.2,1,1,0.2,0.2], 'k--', label='Actual Risk')
plt.legend(loc='best')
plt.show()
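The plot above compares the predicted risk with the true noise level by eye. To put a number on it, the hypothetical helpers below (not part of Capsa) rebuild the piecewise noise standard deviation used in fn (1 on 0 < x < π/2, 0.2 elsewhere) and report the mean absolute error of the predicted risk in each region. Like the plot, this assumes that risk is on the same scale as the noise standard deviation.

```python
import numpy as np

def true_noise_std(x):
    """Noise std used by fn: 1.0 on (0, pi/2) and 0.2 elsewhere."""
    x = np.asarray(x)
    return np.where((x > 0) & (x < np.pi / 2), 1.0, 0.2)

def risk_error_by_region(x, risk):
    """Mean absolute error between predicted risk and the true noise std, per region."""
    x = np.asarray(x).ravel()
    risk = np.asarray(risk).ravel()
    err = np.abs(risk - true_noise_std(x))
    high = (x > 0) & (x < np.pi / 2)
    return {"high_noise": err[high].mean(), "low_noise": err[~high].mean()}
```

For example, risk_error_by_region(x_test.numpy(), risk_test.numpy()) returns one error per region. Near the jumps at x = 0 and x = π/2, a smooth risk curve cannot follow the step exactly, so some error there is expected.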