Vote¶
Background¶
Voting-based wrappers compute a measure of model (i.e., epistemic) uncertainty by training a diverse set of model parameters, each of which independently votes on the prediction. The final prediction is the average of the voters' predictions, and the uncertainty is measured as the disagreement between the voters. If the voters' predictions have a large variance (i.e., high disagreement), this is a strong indicator that the prediction is uncertain and should not be trusted.
A naive approach to a voting-based system is to copy the model multiple times (each with its own random initialization) and train each copy on a different subset of the data. While this simple approach is the gold standard of uncertainty estimation, it carries substantial compute, speed, and memory costs, since it requires training, storing, and deploying multiple copies of the model in parallel, along with manual changes to the training and deployment infrastructure. This wrapper automates voting-based uncertainty estimation and achieves efficiency through several mechanisms (e.g., soft-sharing of voter parameters in a shared representation space, voter approximation across the training evolution, etc.).
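As a concrete illustration of the naive approach (not how this wrapper is implemented), a deep ensemble in plain PyTorch might look like the sketch below, where the prediction is the voters' mean and the uncertainty is their standard deviation. The layer sizes and voter count here are arbitrary.
import torch
from torch import nn
n_voters = 5
# N independently initialized copies of the same architecture (illustrative sizes)
voters = [nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))
          for _ in range(n_voters)]
# ... assume each voter is trained separately on its own subset of the data ...
x = torch.randn(16, 8)  # a batch of 16 inputs
with torch.no_grad():
    outputs = torch.stack([voter(x) for voter in voters])  # shape: [N, B, 1]
prediction = outputs.mean(dim=0)   # final prediction: average of the voters
uncertainty = outputs.std(dim=0)   # disagreement between the voters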
Usage¶
Wrapping your model with capsa_torch.vote
from torch import nn
from capsa_torch import vote
# Define your model
model = nn.Sequential(...)
# Build a voting wrapper and wrap your model!
wrapper = vote.Wrapper(n_voters=N_VOTERS, alpha=1)
wrapped_model = wrapper(model)
Calling your wrapped model
# By default, your wrapped model returns a prediction (the average of all voters)
prediction = wrapped_model(input_batch)
# But if you pass `return_risk=True`, you also get the uncertainty
prediction, uncertainty = wrapped_model(input_batch, return_risk=True)
Training your wrapped model
Vote wrappers must be trained in order to return accurate measures of uncertainty. Ideally, each voter sees different batches of data so that the voters learn diverse representations.
For training
Use the tile_and_reduce=False argument when training. Multiple input batches of shape [B * N, …] are sharded across the voters, where B is the batch size and N is the number of voters. The prediction is not reduced and can be used to train each voter independently, as in the snippet below.
input_batches = get_batch(batch_size * n_voters) # shape: [B * N, …]
predictions = wrapped_model(input_batches, tile_and_reduce=False) # shape: [B * N, …]
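Putting this together, a training loop might look like the following sketch. The optimizer, the loss, and the data helper are placeholders; here get_batch is assumed to also return targets matching the [B * N, …] input shape, so each voter is supervised on its own shard.
import torch
optimizer = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
for step in range(num_steps):
    # Hypothetical helper: assumed to return matching targets as well
    input_batches, targets = get_batch(batch_size * n_voters)  # shapes: [B * N, …]
    # Unreduced output: each voter is trained on its own shard of the batch
    predictions = wrapped_model(input_batches, tile_and_reduce=False)  # shape: [B * N, …]
    loss = torch.nn.functional.mse_loss(predictions, targets)  # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()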
For inference
Use tile_and_reduce=True (the default) during inference. Each batch is replicated and passed to all voters, and the prediction is the average of all voter outputs.
input_batch = get_batch(batch_size) # shape: [B, …]
prediction = wrapped_model(input_batch) # shape: [B, …]
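To also get uncertainty at inference time, keep the default tile_and_reduce=True and pass return_risk=True. The sketch below gates predictions on an uncertainty threshold; the threshold value is arbitrary and application-specific.
input_batch = get_batch(batch_size)  # shape: [B, …]
prediction, uncertainty = wrapped_model(input_batch, return_risk=True)
# Illustrative policy: only act on predictions whose uncertainty is low
threshold = 0.5  # arbitrary, application-specific value
trusted = uncertainty < threshold  # boolean mask over the batch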
Note
tile_and_reduce=False and return_risk=True cannot be used together.
Examples
API Reference
Wrapping your model with capsa_tf.vote
# Import the module
from capsa_tf import vote
# Define the wrapper arguments
wrapper = vote.Wrapper(alpha=2)
# Wrap the model
model.call_default = wrapper(model.call_default)
Calling your wrapped model
# A single batch is replicated and passed to all voters
# Prediction is the average of all of the voter outputs.
input_batch = get_batch(batch_size) # shape: [B, …]
prediction, risk = model(input_batch, return_risk=True) # shape: [B, …]
Training your wrapped model
Vote wrappers must be trained in order to return accurate measures of uncertainty. Ideally, each voter sees different batches of data so that the voters learn diverse representations.
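A minimal TensorFlow training-loop sketch is shown below. The optimizer, the loss, and the get_batch helper are placeholders (get_batch is assumed to return an (inputs, targets) pair), and per-voter batch sharding analogous to capsa_torch's tile_and_reduce=False is not shown.
import tensorflow as tf
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
for step in range(num_steps):
    # Hypothetical helper: assumed to return an (inputs, targets) pair
    input_batch, targets = get_batch(batch_size)
    with tf.GradientTape() as tape:
        prediction = model(input_batch)  # wrapped model, voters averaged
        loss = tf.reduce_mean(tf.square(prediction - targets))  # placeholder loss
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))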
API Reference