capsa_torch.sample
- class capsa_torch.sample.Wrapper
- __init__(n_samples=5, distribution=Bernoulli(p=0.1), over=WrapOver.WEIGHTS, trainable=False, param_filter=None, *, torch_compile=False, verbose=0, symbolic_trace=True)
Initialize a sample Wrapper with the given configuration.
- Parameters:
  - n_samples (int) – Number of runs through the model. A higher number of samples gives a more accurate uncertainty value, but also costs more memory and compute time. (default: 5)
  - distribution (Distribution) – The distribution to use when introducing noise into the model. (default: Bernoulli(p=0.1))
  - over (WrapOver) – The part of the model to apply stochastic noise to. One of [WrapOver.WEIGHTS]. (default: WrapOver.WEIGHTS)
  - trainable (bool) – If True, creates a more traditional Bayesian neural network. If False, approximates one with a fixed noise distribution. (default: False)
  - param_filter (str | Callable[[str, Tensor], bool] | None) – Either a string containing a regex pattern that matches the parameters to wrap, or a Callable that accepts a parameter name (str) and value (Tensor) and returns True if the parameter should be wrapped, False otherwise. This hyperparameter is only used when over=WrapOver.WEIGHTS. (default: None)
  - torch_compile (bool) – Compile the wrapped model with PyTorch’s TorchInductor backend. This should improve model performance at the cost of initial compilation overhead. (default: False)
  - verbose (int) – Verbosity level for wrapping, with 0 <= verbose <= 2. (default: 0)
  - symbolic_trace (bool) – Attempt to use symbolic shapes when tracing the module’s graph. Turning this off may help if the module fails to wrap; however, the resulting graph is more likely to use fixed input dimensions and to trigger rewraps when fed inputs of different shapes. (default: True)
Note
torch_compile, verbose, and symbolic_trace are keyword-only arguments (they follow the * in the signature).
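The two accepted forms of param_filter can be illustrated with a small, library-independent sketch. The helper `select_params` below is hypothetical and not part of capsa_torch. It assumes that a string filter is matched with `re.search` anywhere in the parameter name and that `None` selects every parameter; the library may apply different rules. It takes (name, value) pairs such as those yielded by `torch.nn.Module.named_parameters()`.

```python
import re
from typing import Any, Callable, Iterable, List, Tuple, Union

# Hypothetical helper, not part of capsa_torch: mimics the documented
# semantics of param_filter (regex string or (name, value) -> bool callable).
ParamFilter = Union[str, Callable[[str, Any], bool], None]


def select_params(
    named_params: Iterable[Tuple[str, Any]], param_filter: ParamFilter = None
) -> List[str]:
    """Return the names of the parameters that would be wrapped."""
    if param_filter is None:
        # Assumption: no filter means every parameter is eligible.
        return [name for name, _ in named_params]
    if isinstance(param_filter, str):
        pattern = re.compile(param_filter)
        # Assumption: a match anywhere in the name counts (re.search).
        return [name for name, _ in named_params if pattern.search(name)]
    # Callable form: receives the parameter name and its value.
    return [name for name, value in named_params if param_filter(name, value)]


# Stand-in for module.named_parameters(); the values are placeholders.
params = [
    ("encoder.0.weight", [[0.1, 0.2]]),
    ("encoder.0.bias", [0.0]),
    ("head.weight", [[0.3]]),
    ("head.bias", [0.0]),
]

print(select_params(params, r"\.weight$"))                              # only weight tensors
print(select_params(params, lambda name, value: name.startswith("head")))  # only the head layer
```

The callable form is the more flexible of the two, since it can also inspect the tensor value (for example its shape) rather than only the name.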
- __call__(module_or_module_class)
Applies the wrapper to either an instantiated torch.nn.Module or a class that subclasses torch.nn.Module, creating a new wrapped implementation.
- Parameters:
  - module_or_module_class (TypeVar(T, Module, type[torch.nn.Module])) – The module instance or module class to wrap.
- Return type:
  TypeVar(T, Module, type[torch.nn.Module])
- Returns:
  The wrapped module, which shares its weights with the original module.
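The dual behaviour of __call__ (accepting either an instance or a class) can be sketched in plain Python. The classes `ToyWrapper`, `WrappedModel`, and `Scale` below are hypothetical and do not reflect capsa_torch's implementation; the sketch only shows how one callable can dispatch on `isinstance(obj, type)`, and how a wrapped object can share weights with the original by holding a reference to it instead of a copy.

```python
from typing import Any


class WrappedModel:
    """Hypothetical wrapper that delegates to the original object without copying it."""

    def __init__(self, inner: Any) -> None:
        self.inner = inner  # a reference, so weights are shared, not copied

    def __call__(self, x: float) -> float:
        return self.inner(x)


class ToyWrapper:
    def __call__(self, module_or_class: Any) -> Any:
        if isinstance(module_or_class, type):
            cls = module_or_class

            # Class case: return a factory whose instances come back wrapped.
            def factory(*args: Any, **kwargs: Any) -> WrappedModel:
                return WrappedModel(cls(*args, **kwargs))

            return factory
        # Instance case: wrap the existing object directly.
        return WrappedModel(module_or_class)


class Scale:
    """Toy 'module' with a single weight."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def __call__(self, x: float) -> float:
        return self.weight * x


base = Scale(2.0)
wrapped = ToyWrapper()(base)
base.weight = 3.0      # mutate the original module...
print(wrapped(1.0))    # ...and the wrapped one sees the change: 3.0

DecoratedScale = ToyWrapper()(Scale)  # decorator-style use on a class
print(DecoratedScale(4.0)(1.0))       # 4.0
```

Returning a plain factory function in the class case is a simplification; the documented return type indicates that capsa_torch returns an object of the same kind as its input.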
Example Usage
    import torch
    from capsa_torch.sample import Wrapper  # or capsa_torch.sculpt, capsa_torch.vote

    module = torch.nn.Linear(4, 2)  # any torch.nn.Module; the sizes here are illustrative
    x = torch.randn(8, 4)           # an example input batch

    wrapper = Wrapper(n_samples=3, verbose=1)      # Initialize a wrapper object with your config options
    wrapped_module = wrapper(module)               # Wrap your module
    y = wrapped_module(x)                          # Use the wrapped module as usual
    y, risk = wrapped_module(x, return_risk=True)  # Use the wrapped module to obtain risk values
    import torch
    from capsa_torch.sample import Wrapper  # or capsa_torch.sculpt, capsa_torch.vote

    @Wrapper(n_samples=3, verbose=1)  # Initialize a wrapper object with your config options
    class MyModule(torch.nn.Module):  # Note: MyModule must subclass torch.nn.Module
        def __init__(self, in_features, out_features):
            super().__init__()
            self.linear = torch.nn.Linear(in_features, out_features)  # illustrative layer

        def forward(self, x):
            return self.linear(x)

    wrapped_module = MyModule(4, 2)  # Call MyModule's __init__ as usual to create a wrapped module
    x = torch.randn(8, 4)            # an example input batch
    y = wrapped_module(x)            # Use the wrapped module as usual
    y, risk = wrapped_module(x, return_risk=True)  # Use the wrapped module to obtain risk values
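Because the sample Wrapper derives its uncertainty value from n_samples runs through the model, a library-independent sketch of that general idea may help: the same input is passed through a stochastic toy model n_samples times, the mean output serves as the prediction, and the standard deviation serves as a risk score. The toy model, the helper names, and the use of the standard deviation are assumptions for illustration only; capsa_torch's actual risk computation may differ.

```python
import random
import statistics
from typing import Callable, List, Tuple


def toy_stochastic_model(x: float, rng: random.Random, p: float = 0.1) -> float:
    """Hypothetical toy model y = x * (w1 + w2 + w3), each weight dropped with probability p."""
    weights = [0.5, -0.25, 1.0]
    # Each weight is zeroed independently with probability p on every pass.
    kept = [w if rng.random() >= p else 0.0 for w in weights]
    return x * sum(kept)


def sample_with_risk(
    model: Callable[[float, random.Random], float],
    x: float,
    n_samples: int = 5,
    seed: int = 0,
) -> Tuple[float, float]:
    """Run n_samples stochastic passes and return (mean, population std dev)."""
    rng = random.Random(seed)
    outputs: List[float] = [model(x, rng) for _ in range(n_samples)]
    return statistics.mean(outputs), statistics.pstdev(outputs)


y, risk = sample_with_risk(toy_stochastic_model, x=2.0, n_samples=50)
print(f"prediction={y:.3f} risk={risk:.3f}")
```

Each extra sample costs one more forward pass, which is the memory and compute trade-off described for n_samples; a deterministic model always yields a risk of zero.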
- class capsa_torch.sample.WrapOver
Enum representing which part of the model is wrapped to inject noise into the model outputs.
- WEIGHTS = 'weights'
Inject noise into the model predictions through the model’s weights
- ACTIVATIONS = 'activations'
Inject noise into model predictions after every activation function call
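The difference between the two WrapOver options can be illustrated on a single toy neuron. In this library-independent sketch (the function names, the ReLU neuron, and the fixed 0/1 masks are assumptions chosen for clarity), WEIGHTS-style noise masks individual weights before the weighted sum, whereas ACTIVATIONS-style noise masks the output after the activation function.

```python
from typing import List


def relu(v: float) -> float:
    return max(0.0, v)


def neuron(inputs: List[float], weights: List[float]) -> float:
    """A single ReLU neuron: relu(sum_i w_i * x_i)."""
    return relu(sum(w * x for w, x in zip(weights, inputs)))


def noisy_over_weights(inputs: List[float], weights: List[float], mask: List[float]) -> float:
    # WEIGHTS-style: perturb each weight, then run the neuron as usual.
    return neuron(inputs, [w * m for w, m in zip(weights, mask)])


def noisy_over_activations(inputs: List[float], weights: List[float], mask: float) -> float:
    # ACTIVATIONS-style: run the neuron, then perturb its activation output.
    return neuron(inputs, weights) * mask


x = [1.0, 2.0, 3.0]
w = [0.5, -1.0, 1.0]
print(neuron(x, w))                               # 1.5
print(noisy_over_weights(x, w, [1.0, 0.0, 1.0]))  # drops w[1]: 3.5
print(noisy_over_activations(x, w, 0.0))          # drops the whole output: 0.0
```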
- class capsa_torch.sample.Bernoulli
Config Class to represent a Bernoulli Distribution for the Sample Wrapper
- __init__(p=0.1)
Initialize a Bernoulli Distribution object
- Parameters:
  - p (float) – The drop probability of the Bernoulli distribution. (default: 0.1)
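The meaning of p can be illustrated with a library-independent sketch that zeroes each weight independently with probability p. The helper `bernoulli_drop` is hypothetical. Whether capsa_torch also rescales the surviving weights (as inverted dropout does, by 1 / (1 - p)) is not stated in this reference, so the sketch does not rescale.

```python
import random
from typing import List


def bernoulli_drop(weights: List[float], p: float = 0.1, seed: int = 0) -> List[float]:
    """Zero each weight independently with probability p (hypothetical helper)."""
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    rng = random.Random(seed)
    # rng.random() is uniform on [0, 1), so it is < p with probability p.
    return [0.0 if rng.random() < p else w for w in weights]


weights = [1.0] * 10_000
dropped = bernoulli_drop(weights, p=0.1)
print(dropped.count(0.0) / len(dropped))  # close to 0.1
```

With p=0.0 the weights pass through unchanged, and with p=1.0 every weight is dropped.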
- class capsa_torch.sample.Normal
Config Class to represent a Normal Distribution for the Sample Wrapper
- __init__(init_sigma=0.2)
Initialize a Normal Distribution object
- Parameters:
  - init_sigma – The starting value for sigma, the standard deviation of the Normal distribution. (default: 0.2)
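This reference does not specify how Normal noise is applied to the model. One common choice, multiplicative Gaussian noise as in Gaussian dropout, multiplies each weight by a factor drawn from N(1, sigma^2); the sketch below assumes that scheme purely to show what init_sigma controls. The helper `gaussian_perturb` is hypothetical.

```python
import random
import statistics
from typing import List


def gaussian_perturb(weights: List[float], sigma: float = 0.2, seed: int = 0) -> List[float]:
    """Multiply each weight by a draw from N(1, sigma**2) (assumed noise scheme)."""
    rng = random.Random(seed)
    return [w * rng.gauss(1.0, sigma) for w in weights]


weights = [1.0] * 10_000
for sigma in (0.05, 0.2, 0.5):
    noisy = gaussian_perturb(weights, sigma=sigma)
    # With unit weights, the spread of the perturbed weights is roughly sigma.
    print(sigma, round(statistics.pstdev(noisy), 3))
```

Under this assumption, a larger sigma means larger weight perturbations and therefore a wider spread across the n_samples runs.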
- class capsa_torch.sample.Distribution
Base Config Class to represent a Distribution for the Sample Wrapper