Basic Usage
Initializing a wrapper
Currently, three wrappers (sample, sculpt, and vote) are implemented in capsa_torch.
from capsa_torch import sample # or sculpt, vote
# Initialize a wrapper object with wrapper-specific hyperparameters
wrapper = sample.Wrapper(n_samples=5, distribution=sample.Normal())
Note
All available hyperparameters for each wrapper are listed in the API Reference under the section for each wrapper.
Currently, three wrappers (sample, sculpt, and vote) are implemented in capsa_tf.
from capsa_tf import sample # or sculpt, vote
# Initialize a wrapper object with wrapper-specific hyperparameters
wrapper = sample.Wrapper(n_samples=5, distribution=sample.Normal())
Note
All available hyperparameters for each wrapper are listed in the API Reference under the section for each wrapper.
Wrapping a model
To be wrapped, a PyTorch model must be a torch.nn.Module. This can be either a custom Module or a Module from a package.
Note
Our wrappers work best on models that take (float/complex) Tensors as input and return (float/complex) Tensors as outputs. To read more about limitations, click here.
Reusing the wrapper we initialized in the previous step:
from torch import nn

# Define your model
class Model(nn.Module):
    def __init__(self):
        ...

    def forward(self, x):
        ...

model = Model()
wrapped_model = wrapper(model)  # It's this easy
To be wrapped, a TensorFlow model must be defined as a callable function.
Note
Our wrappers work best on models that take (float/complex) Tensors as input and return (float/complex) Tensors as outputs. To read more about limitations, click here.
Reusing the wrapper we initialized in the previous step:
# Define your model as a function
def model(...):
    ...
wrapped_model = wrapper(model) # It's this easy
Wrapping with decorator syntax
Another way to wrap models is to use our decorator syntax, which can sometimes be more convenient. We simply decorate a torch.nn.Module subclass with the wrapper; when the subclass is instantiated, it is immediately wrapped.
Again reusing the same wrapper initialized above with no changes:
from torch import nn

@wrapper
class Model(nn.Module):
    def __init__(self):
        ...

    def forward(self, x):
        ...

wrapped_model = Model()  # Initializes module and immediately wraps
Note
This approach even works if Model.__init__ takes additional arguments.
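To make the class-decorator pattern concrete, here is a minimal pure-Python sketch of how a wrapper object could decorate a class and still pass constructor arguments through. It is an illustration only, not capsa's implementation: ToyWrapper, _wrap_instance, and the is_wrapped marker are hypothetical names, and a plain class stands in for a torch.nn.Module subclass.

```python
class ToyWrapper:
    """Hypothetical stand-in for a wrapper object; NOT capsa's implementation."""

    def __call__(self, target):
        if isinstance(target, type):
            # Used as a class decorator: return a factory that builds the
            # instance with whatever arguments __init__ takes, then wraps it.
            def factory(*args, **kwargs):
                return self._wrap_instance(target(*args, **kwargs))

            return factory
        # Called directly on an already constructed instance
        return self._wrap_instance(target)

    def _wrap_instance(self, instance):
        # Only mark the instance here; a real wrapper would transform the model
        instance.is_wrapped = True
        return instance


toy_wrapper = ToyWrapper()


@toy_wrapper
class Model:
    def __init__(self, hidden_size, activation="relu"):
        self.hidden_size = hidden_size
        self.activation = activation


# Constructor arguments are forwarded, and the result comes back wrapped
wrapped_model = Model(16, activation="tanh")
print(wrapped_model.is_wrapped, wrapped_model.hidden_size, wrapped_model.activation)
```

In this particular sketch the name Model ends up bound to a factory function rather than a class. Whether capsa behaves the same way is not something this page states.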
Another way to wrap models is to use our decorator syntax. This is often more convenient when defining models as TensorFlow functions. We simply decorate the function with the wrapper, and calls to the function will automatically use the wrapped implementation.
Again reusing the same wrapper initialized above with no changes:
@wrapper
def model(...):
    ...
model(...) # Calls the wrapped model
Getting Risk from a Wrapped Module
Once you’ve wrapped a model, you may not notice anything immediately different about it. This is by design: we make the wrapped model’s behaviour as similar to the original unwrapped model as possible. That means that, by default, we don’t return any risk values unless you call your model with the aptly named return_risk=True keyword argument.
# When called with the regular inputs, the outputs are the same as before wrapping
out = wrapped_model(...)
# Only when `return_risk=True`, risk is returned
out, risk = wrapped_model(..., return_risk=True)
# Note: we return a `RiskOutput` namedtuple which can be
# unpacked (as above) or used directly
risk_output = wrapped_model(..., return_risk=True)
risk_output.pred # Model prediction
risk_output.risk # Model risk
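The calling convention above can be mimicked in plain Python. The sketch below is not capsa's implementation: toy_wrap, double, and the constant placeholder risk are hypothetical. Only the return_risk keyword and the pred and risk field names come from the text above. It shows how one keyword can switch between the original output and a RiskOutput-style named tuple.

```python
from collections import namedtuple

# Stand-in for the named tuple described above; the field names `pred` and
# `risk` come from the docs, everything else here is hypothetical.
RiskOutput = namedtuple("RiskOutput", ["pred", "risk"])


def toy_wrap(fn):
    """Toy wrapper: illustrates the calling convention only, not capsa itself."""

    def wrapped(*args, return_risk=False, **kwargs):
        pred = fn(*args, **kwargs)
        if not return_risk:
            # Default call: behave exactly like the unwrapped function
            return pred
        # Placeholder risk value; a real wrapper would estimate this
        return RiskOutput(pred=pred, risk=0.0)

    return wrapped


def double(x):
    return 2 * x


wrapped_double = toy_wrap(double)

plain = wrapped_double(3)                        # same as double(3)
out, risk = wrapped_double(3, return_risk=True)  # tuple unpacking
result = wrapped_double(3, return_risk=True)     # attribute access
print(plain, out, risk, result.pred, result.risk)
```

Because a named tuple is still a tuple, both unpacking and attribute access work on the same return value.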
Your model doesn’t have to return a single output. Capsa is designed to work with any arbitrarily nested “PyTree” structure. With return_risk=True, the wrapped model always returns the RiskOutput named tuple of two elements. The first is a PyTree representing the model’s prediction and the second is a PyTree representing the model’s risk.
For example, below is a model that returns a tuple containing a dictionary that contains a list.
out = model(...)
# out = ({"key": [y_1, y_2]}, y_3)
wrapped_model = wrapper(model)
out = wrapped_model(...)
# out = ({"key": [y_1, y_2]}, y_3)
out, risk = wrapped_model(..., return_risk=True)
# out = ({"key": [y_1, y_2]}, y_3)
# risk = ({"key": [y_1_risk, y_2_risk]}, y_3_risk)
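In the example above, the risk has the same nesting as the prediction. The sketch below checks that property for plain Python containers. same_structure is a hypothetical helper, not part of capsa, and the float leaves are placeholders standing in for tensors.

```python
def same_structure(a, b):
    """Return True if two nested containers have the same shape.

    Handles only dicts, lists, and tuples; anything else counts as a leaf.
    """
    if isinstance(a, dict):
        return (
            isinstance(b, dict)
            and a.keys() == b.keys()
            and all(same_structure(a[k], b[k]) for k in a)
        )
    if isinstance(a, (list, tuple)):
        return (
            type(a) is type(b)
            and len(a) == len(b)
            and all(same_structure(x, y) for x, y in zip(a, b))
        )
    # A leaf matches any other leaf, but not a container
    return not isinstance(b, (dict, list, tuple))


# Placeholder leaves standing in for tensors
pred = ({"key": [1.0, 2.0]}, 3.0)
risk = ({"key": [0.1, 0.2]}, 0.3)

matching = same_structure(pred, risk)
mismatched = same_structure(pred, ({"key": [0.1]}, 0.3))
print(matching, mismatched)
```

A check like this can be handy in your own tests when post-processing predictions and risks leaf by leaf.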