Basic Usage

Initializing a wrapper

Currently, we have 3 wrappers (sample, sculpt, and vote) implemented in capsa_torch.

from capsa_torch import sample # or sculpt, vote

# Initialize a wrapper object with wrapper-specific hyperparameters
wrapper = sample.Wrapper(n_samples=5, distribution=sample.Normal())

Note

All available hyperparameters for each wrapper are listed in the API Reference under the section for each wrapper.

Currently, we have 3 wrappers (sample, sculpt, and vote) implemented in capsa_tf.

from capsa_tf import sample # or sculpt, vote

# Initialize a wrapper object with wrapper-specific hyperparameters
wrapper = sample.Wrapper(n_samples=5, distribution=sample.Normal())

Note

All available hyperparameters for each wrapper are listed in the API Reference under the section for each wrapper.
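The n_samples and distribution hyperparameters point to a sampling-based approach. The sketch below is a framework-free illustration of the general idea (Monte Carlo sampling). It adds Normal noise to a scalar input n_samples times, then reports the mean output as the prediction and the standard deviation as the risk. The function mc_sample_risk and its sigma and seed parameters are hypothetical. The assumption that noise is applied to the input is ours alone, since this page does not say where or how capsa's sample wrapper applies its distribution. Check the API Reference for the real behaviour.

```python
import random
import statistics
from typing import Callable, Tuple


def mc_sample_risk(
    model: Callable[[float], float],
    x: float,
    n_samples: int = 5,
    sigma: float = 0.1,
    seed: int = 0,
) -> Tuple[float, float]:
    """Hypothetical Monte Carlo sketch, NOT capsa's sample wrapper.

    Perturbs the input with Normal(0, sigma) noise n_samples times and returns
    (mean of outputs, standard deviation of outputs) as (prediction, risk).
    """
    if n_samples < 2:
        raise ValueError("need at least 2 samples to measure spread")
    rng = random.Random(seed)  # local RNG so results are reproducible
    outputs = [model(x + rng.gauss(0.0, sigma)) for _ in range(n_samples)]
    prediction = statistics.fmean(outputs)  # average over the samples
    risk = statistics.stdev(outputs)  # disagreement between samples
    return prediction, risk


# With no noise every sample agrees, so the risk is exactly zero.
print(mc_sample_risk(lambda v: 2.0 * v, x=1.0, n_samples=5, sigma=0.0))  # (2.0, 0.0)
```

Seeding a local random.Random keeps the result reproducible without touching Python's global random state.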

Wrapping a model

A PyTorch model must be a torch.nn.Module to be wrapped. It can be either a custom Module or a Module from another package.

Note

Our wrappers work best on models that take (float/complex) Tensors as inputs and return (float/complex) Tensors as outputs. To read more about limitations, see the limitations page.

Reusing the wrapper we initialized in the previous step:

from torch import nn

# Define your model
class Model(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering layers on an nn.Module
        ...

    def forward(self, x):
        ...

model = Model()
wrapped_model = wrapper(model) # It's this easy

A TensorFlow model must be defined as a callable function to be wrapped.

Note

Our wrappers work best on models that take (float/complex) Tensors as inputs and return (float/complex) Tensors as outputs. To read more about limitations, see the limitations page.

Reusing the wrapper we initialized in the previous step:

# Define your model as a function
def model(x):
    ...

wrapped_model = wrapper(model) # It's this easy

Wrapping with decorator syntax

Another way to wrap models is to use our decorator syntax, which can sometimes be more convenient. Decorate a torch.nn.Module subclass with the wrapper, and each instance of the subclass is wrapped as soon as it is created.

Again reusing the same wrapper initialized above with no changes:

from torch import nn

@wrapper
class Model(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering layers on an nn.Module
        ...

    def forward(self, x):
        ...

wrapped_model = Model() # Initializes module and immediately wraps

Note

This approach even works if Model.__init__ takes additional arguments.
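The sketch below shows one way a single wrapper object can support both wrapper(model) and @wrapper on a class, wrapping each instance at construction time even when __init__ takes extra arguments. It is framework-free. ToyWrapper, WrappedModel and Scale are hypothetical names used only for illustration. This is not how capsa_torch is implemented, and a plain Python class stands in for torch.nn.Module so the example runs without PyTorch.

```python
import inspect


class WrappedModel:
    """Hypothetical wrapped model: forwards calls to the original model."""

    def __init__(self, model):
        self.model = model

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)


class ToyWrapper:
    """Hypothetical stand-in for a capsa wrapper object (NOT its real API)."""

    def __call__(self, target):
        if inspect.isclass(target):
            # Used as a class decorator: return a factory that constructs the
            # original class with any arguments, then wraps the new instance.
            def build_and_wrap(*args, **kwargs):
                return self(target(*args, **kwargs))

            build_and_wrap.__name__ = target.__name__
            return build_and_wrap
        # Called on an existing model object: wrap it directly.
        return WrappedModel(target)


wrapper = ToyWrapper()


@wrapper
class Scale:
    def __init__(self, factor):  # an extra __init__ argument is fine
        self.factor = factor

    def __call__(self, x):
        return self.factor * x


wrapped_model = Scale(3)  # builds Scale(3), then wraps it immediately
print(type(wrapped_model).__name__, wrapped_model(2))  # WrappedModel 6
```

This particular sketch has a trade-off. After decoration, the name Scale refers to a factory function rather than the class, so isinstance(obj, Scale) no longer works. A real implementation may handle this differently.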

Another way to wrap models is to use our decorator syntax. This is often more convenient when defining models as TensorFlow functions. Decorate the function with the wrapper, and calls to the function will automatically use the wrapped implementation.

Again reusing the same wrapper initialized above with no changes:

@wrapper
def model(x):
    ...

model(...) # Calls the wrapped model
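Decorating a function this way is ordinary Python. Writing @wrapper above def model is shorthand for model = wrapper(model), so every later call to model goes through the wrapped version. The sketch below shows that rebinding. It uses a hypothetical toy_wrapper in place of a capsa_tf wrapper and plain Python in place of TensorFlow. Because of functools.wraps, the decorated function keeps the original's name and docstring.

```python
import functools


def toy_wrapper(fn):
    """Hypothetical decorator standing in for a capsa_tf wrapper (NOT its API)."""
    calls = []  # records each call so we can see the wrapped path ran

    @functools.wraps(fn)  # keep the original __name__ and __doc__
    def wrapped(*args, **kwargs):
        calls.append(args)
        return fn(*args, **kwargs)

    wrapped.calls = calls
    return wrapped


@toy_wrapper  # equivalent to: model = toy_wrapper(model)
def model(x):
    """Doubles its input."""
    return 2 * x


print(model(4), model.__name__, len(model.calls))  # 8 model 1
```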

Getting Risk from a Wrapped Module

Once you’ve wrapped a model, you may not notice anything different about it at first. This is by design: the wrapped model behaves as much like the original unwrapped model as possible. In particular, it does not return any risk values by default. Risk is returned only when you call the model with the aptly named return_risk=True keyword argument.

# When called with the regular inputs, the outputs are the same as before wrapping
out = wrapped_model(...)

# Only when `return_risk=True`, risk is returned 
out, risk = wrapped_model(..., return_risk=True)

# Note: we return a `RiskOutput` namedtuple which can be 
# unpacked (as above) or used directly
risk_output = wrapped_model(..., return_risk=True)

risk_output.pred # Model prediction
risk_output.risk # Model risk
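The sketch below re-creates the documented calling convention in plain Python. Without return_risk, the wrapped model returns exactly what the original returns. With return_risk=True, it returns a RiskOutput named tuple with pred and risk fields, which can be unpacked or accessed by name. ToyWrappedModel and its risk_fn argument are hypothetical, and the risk formula is a placeholder rather than anything capsa computes.

```python
from typing import Any, Callable, NamedTuple


class RiskOutput(NamedTuple):
    """Hypothetical re-creation using the documented field names."""

    pred: Any  # the model's prediction
    risk: Any  # the risk estimate for that prediction


class ToyWrappedModel:
    """Hypothetical wrapped model showing the return_risk switch."""

    def __init__(self, model: Callable, risk_fn: Callable):
        self.model = model
        self.risk_fn = risk_fn  # placeholder; capsa computes risk its own way

    def __call__(self, x, *, return_risk: bool = False):
        pred = self.model(x)
        if not return_risk:
            return pred  # identical to calling the unwrapped model
        return RiskOutput(pred=pred, risk=self.risk_fn(x))


wrapped_model = ToyWrappedModel(lambda x: x * x, risk_fn=lambda x: abs(x) / 10)

out = wrapped_model(3.0)                          # plain prediction: 9.0
out, risk = wrapped_model(3.0, return_risk=True)  # unpacked like a tuple
risk_output = wrapped_model(3.0, return_risk=True)
print(risk_output.pred, risk_output.risk)         # 9.0 0.3
```

Making return_risk keyword-only (the bare * in the signature) means it cannot be passed positionally by accident, which matches how the documentation always spells it out as a keyword argument.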

Your model doesn’t have to return a single output. Capsa is designed to work with any arbitrarily nested “PyTree” structure. When called with return_risk=True, the wrapped model always returns a RiskOutput named tuple with two elements: the first is a PyTree representing the model’s prediction, and the second is a PyTree with the same structure representing the model’s risk.

For example, below is a model that returns a tuple containing a dictionary, which in turn contains a list.

out = model(...)
# out = ({"key": [y_1, y_2]}, y_3) 

wrapped_model = wrapper(model)

out = wrapped_model(...)
# out = ({"key": [y_1, y_2]}, y_3) 

out, risk = wrapped_model(..., return_risk=True)
# out = ({"key": [y_1, y_2]}, y_3) 
# risk = ({"key": [y_1_risk, y_2_risk]}, y_3_risk)
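A PyTree is any nesting of containers such as tuples, lists and dicts, whose non-container entries are the leaves. When the risk is a PyTree mirroring the prediction, every prediction leaf has a risk leaf in the same position. The minimal, dependency-free tree_map below rebuilds a structure leaf by leaf, which is enough to reproduce the shape of the example above. It is a hypothetical helper and not part of capsa. Libraries such as JAX (jax.tree_util) provide full-featured versions.

```python
from typing import Any, Callable


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to every leaf of nested dicts/lists/tuples, keeping the structure.

    Hypothetical helper for illustration; it rebuilds tuples (including
    namedtuples) as plain tuples and treats every other value as a leaf.
    """
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    if isinstance(tree, list):
        return [tree_map(fn, item) for item in tree]
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, item) for item in tree)
    return fn(tree)  # leaf


# Strings stand in for tensors so the structure is easy to read.
out = ({"key": ["y_1", "y_2"]}, "y_3")
risk = tree_map(lambda leaf: leaf + "_risk", out)
print(risk)  # ({'key': ['y_1_risk', 'y_2_risk']}, 'y_3_risk')

# Same structure: replacing every leaf with None gives identical trees.
print(tree_map(lambda _: None, out) == tree_map(lambda _: None, risk))  # True
```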