Basic Usage

Initializing a wrapper

Currently, we have 4 wrappers (sample, sculpt, vote, and neo) implemented in capsa_torch.

from capsa_torch import vote # or sample, sculpt, neo

# Initialize a wrapper object with wrapper-specific hyperparameters
wrapper = vote.Wrapper()

Note

All available hyperparameters for each wrapper are listed in that wrapper's section of the API Reference.

Currently, we have 3 wrappers (sample, sculpt, and vote) implemented in capsa_tf.

from capsa_tf import sample # or sculpt, vote

# Initialize a wrapper object with wrapper-specific hyperparameters
wrapper = sample.Wrapper(n_samples=5, distribution=sample.Normal())

Note

All available hyperparameters for each wrapper are listed in that wrapper's section of the API Reference.

Wrapping a model

To be wrapped, a PyTorch model must be a torch.nn.Module. This can be either a custom Module or a Module from a package.

Note

Our wrappers work best on models that take (float/complex) Tensors as input and return (float/complex) Tensors as output. To read more about limitations, see the documentation on limitations.

Reusing the wrapper we initialized in the previous step:

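from torch import nn

# Define your model
class Model(nn.Module):
    def __init__(self):
        super().__init__()  # required before defining any layers
        ...  # define your layers here

    def forward(self, x):
        ...  # compute and return the output

model = Model()
wrapped_model = wrapper(model)  # It's this easy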

To be wrapped, a TensorFlow model must be defined as a callable function.

Note

Our wrappers work best on models that take (float/complex) Tensors as input and return (float/complex) Tensors as output. To read more about limitations, see the documentation on limitations.

Reusing the wrapper we initialized in the previous step:

# Define your model as a function
def model(x):
    ...  # your model's computation goes here

wrapped_model = wrapper(model)  # It's this easy

Wrapping with decorator syntax

Another way to wrap models is with our decorator syntax, which can be more convenient. Decorate a torch.nn.Module subclass with the wrapper, and every instance of that subclass is wrapped as soon as it is created.

Again reusing the same wrapper initialized above with no changes:

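from torch import nn

@wrapper
class Model(nn.Module):
    def __init__(self):
        super().__init__()  # required before defining any layers
        ...  # define your layers here

    def forward(self, x):
        ...  # compute and return the output

wrapped_model = Model()  # Initializes the module and immediately wraps it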

Note

This approach even works if Model.__init__ takes additional arguments.

You can also wrap TensorFlow models with our decorator syntax, which is often more convenient when models are defined as functions. Decorate the function with the wrapper, and every call to the function will use the wrapped implementation.

Again reusing the same wrapper initialized above with no changes:

@wrapper
def model(x):
    ...  # your model's computation goes here

out = model(x)  # Calls the wrapped model

Training a wrapped model

Once you’ve wrapped a model, you may not notice anything immediately different about it. This is by design: we keep the wrapped model’s behaviour as close as possible to that of the original unwrapped model. For the most part, you can train, finetune, and use a wrapped model just as you would an unwrapped one.

There are some exceptions, however, so we encourage you to read the documentation for each wrapper, along with the relevant tutorials and use case examples, to understand its proper usage. For example, when training or finetuning a model wrapped with the vote wrapper, you must pass a training=True flag; use training=False for validation, testing, and deployment.
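The training/evaluation distinction for the vote wrapper can be sketched as two small helpers. This is a minimal sketch, assuming `wrapped_model` is a PyTorch model wrapped with the vote wrapper and that the `training` flag is passed as a keyword argument on each forward call; combining `training=False` with `return_risk=True` in one call is also an assumption. The helper names `train_one_epoch` and `evaluate` are hypothetical and not part of Capsa.

```python
import torch
from torch import nn


def train_one_epoch(wrapped_model, loader, optimizer, loss_fn) -> float:
    """Run one pass over `loader` and return the mean training loss."""
    total_loss, n_batches = 0.0, 0
    for x, y in loader:
        optimizer.zero_grad()
        # Assumption: the vote wrapper reads `training` as a call keyword
        pred = wrapped_model(x, training=True)
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    return total_loss / max(n_batches, 1)


def evaluate(wrapped_model, x: torch.Tensor):
    """Return (prediction, risk) for validation, testing, or deployment."""
    # training=False outside of training; return_risk=True adds the risk output
    pred, risk = wrapped_model(x, training=False, return_risk=True)
    return pred, risk
```

In your own code, `wrapped_model` would be the result of `vote.Wrapper()(model)`, and `loader` can be any iterable of `(x, y)` batches, such as a `torch.utils.data.DataLoader`.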

Getting risk from a wrapped model

By default, a Capsa-wrapped model’s inputs and outputs look just like the unwrapped base model’s, and the outputs don’t contain any risk values. To get them, simply call your model with the aptly named return_risk=True keyword argument.

# When called with the regular inputs, the outputs are the same as before wrapping
out = wrapped_model(...)

# Risk is returned only when `return_risk=True`
out, risk = wrapped_model(..., return_risk=True)

# Note: we return a `RiskOutput` namedtuple, which can be
# unpacked (as above) or used directly
risk_output = wrapped_model(..., return_risk=True)

risk_output.pred # Model prediction
risk_output.risk # Model risk
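Because `RiskOutput` is a namedtuple, tuple unpacking and attribute access refer to the same two objects. The sketch below uses a hypothetical stand-in built with `collections.namedtuple` and the two fields shown above (`pred`, `risk`), with placeholder values; it is not Capsa's own definition.

```python
from collections import namedtuple

# Hypothetical stand-in with the same two fields as Capsa's RiskOutput
RiskOutput = namedtuple("RiskOutput", ["pred", "risk"])

risk_output = RiskOutput(pred=[0.9, 0.1], risk=[0.05, 0.3])

# Tuple-style unpacking: first the prediction, then the risk
out, risk = risk_output

# Attribute access and indexing return the same objects
assert out is risk_output.pred and risk is risk_output.risk
assert risk_output[0] is risk_output.pred
```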

Your model doesn’t have to return a single output. Capsa is designed to work with arbitrarily nested “PyTree” structures. When called with return_risk=True, the wrapped model always returns a two-element RiskOutput namedtuple: the first element is a PyTree representing the model’s prediction, and the second is a PyTree representing the model’s risk.

For example, the model below returns a tuple whose first element is a dictionary containing a list.

out = model(...)
# out = ({"key": [y_1, y_2]}, y_3) 

wrapped_model = wrapper(model)

out = wrapped_model(...)
# out = ({"key": [y_1, y_2]}, y_3) 

out, risk = wrapped_model(..., return_risk=True)
# out = ({"key": [y_1, y_2]}, y_3) 
# risk = ({"key": [y_1_risk, y_2_risk]}, y_3_risk)
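Since the risk PyTree mirrors the prediction PyTree, you can walk both together to pair each output with its risk. Below is a minimal pure-Python sketch, assuming the PyTrees are built only from tuples, lists, and dicts as in the example above; the helper `zip_leaves` is hypothetical and not part of Capsa, and the numbers are placeholders.

```python
from typing import Any, Iterator, Tuple


def zip_leaves(pred: Any, risk: Any, path: str = "") -> Iterator[Tuple[str, Any, Any]]:
    """Yield (path, pred_leaf, risk_leaf) for two PyTrees with identical structure."""
    if isinstance(pred, dict):
        if not isinstance(risk, dict) or pred.keys() != risk.keys():
            raise ValueError(f"structure mismatch at {path or 'root'}")
        for key in pred:
            yield from zip_leaves(pred[key], risk[key], f"{path}[{key!r}]")
    elif isinstance(pred, (list, tuple)):
        if not isinstance(risk, (list, tuple)) or len(pred) != len(risk):
            raise ValueError(f"structure mismatch at {path or 'root'}")
        for i, (p, r) in enumerate(zip(pred, risk)):
            yield from zip_leaves(p, r, f"{path}[{i}]")
    else:
        # Anything that is not a dict, list, or tuple (e.g. a Tensor) is a leaf
        yield path, pred, risk


# Same nested structure as the example above, with placeholder values
out = ({"key": [1.0, 2.0]}, 3.0)
risk = ({"key": [0.1, 0.2]}, 0.3)

for path, y, y_risk in zip_leaves(out, risk):
    print(path, y, y_risk)
# [0]['key'][0] 1.0 0.1
# [0]['key'][1] 2.0 0.2
# [1] 3.0 0.3
```

Raising on a structure mismatch, rather than silently truncating as `zip` would, surfaces bugs where a prediction and risk PyTree were accidentally taken from different calls.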