capsa_torch.neo

class capsa_torch.neo.Wrapper
__init__(integration_sites=2, layer_alpha=(2.0, 1.0), *, torch_compile=False, verbose=0, symbolic_trace=True)

Initialize a neo Wrapper with the given configuration options.

Parameters:
  • integration_sites (int) – The number of sites at which the neo wrapper is integrated into your model. More integration sites may produce more robust vacuitic uncertainty estimates, at a higher computational cost. (default: 2)

  • layer_alpha (tuple[float, float]) – A pair of floats that controls the structure of the neo integrations. Larger values produce more robust vacuitic uncertainty estimates, at the cost of more compute and memory overhead. (default: (2.0, 1.0))

  • torch_compile (bool) – Compile the wrapped model with PyTorch’s Inductor backend. This should improve the wrapped model’s runtime performance, at the cost of an initial compilation overhead. (default: False)

  • verbose (int) – Set the verbosity level for wrapping. 0 <= verbose <= 2 (default: 0)

  • symbolic_trace (bool) – Attempt to use symbolic shapes when tracing the module’s graph. Turning this off may help if the module fails to wrap. However, the resulting graph is then more likely to use fixed input dimensions and to trigger rewraps when fed inputs of different shapes. (default: True)
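The rewrap trade-off described for symbolic_trace can be pictured with a toy model. ToyWrappedModel below is hypothetical and is not how capsa_torch is implemented. As a simplification, it assumes that a symbolically traced graph depends only on the input's rank, while a fixed-shape graph is specialized to the exact dimensions. It then counts how many traces ("rewraps") a stream of batch shapes causes.

```python
from typing import Dict, Sequence, Tuple


class ToyWrappedModel:
    """Hypothetical stand-in, NOT capsa_torch internals. Keeps one "traced
    graph" per cache key and counts how many times it has to trace again."""

    def __init__(self, symbolic_trace: bool = True) -> None:
        self.symbolic_trace = symbolic_trace
        self.graphs: Dict[Tuple, str] = {}
        self.rewraps = 0

    def _key(self, shape: Sequence[int]) -> Tuple:
        if self.symbolic_trace:
            # Simplifying assumption: a symbolic graph only depends on the rank.
            return ("rank", len(shape))
        # Fixed-shape graph: specialized to the exact input dimensions.
        return tuple(shape)

    def __call__(self, shape: Sequence[int]) -> str:
        key = self._key(shape)
        if key not in self.graphs:
            self.rewraps += 1  # unseen key: trace (rewrap) again
            self.graphs[key] = f"graph #{self.rewraps}"
        return self.graphs[key]


batch_shapes = [(1, 16), (4, 16), (8, 16), (4, 16)]
fixed = ToyWrappedModel(symbolic_trace=False)
symbolic = ToyWrappedModel(symbolic_trace=True)
for shape in batch_shapes:
    fixed(shape)
    symbolic(shape)
print(fixed.rewraps, symbolic.rewraps)  # 3 1
```

Real symbolic tracing can still specialize some dimensions. The actual number of rewraps therefore depends on the model being wrapped.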

Note

torch_compile, verbose, and symbolic_trace are keyword-only arguments.
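The bare * in the signature is what makes the trailing parameters keyword-only. The stub below, neo_wrapper_init, is hypothetical. It copies the documented parameter list so the behaviour can be checked without capsa_torch installed. Assuming the real signature is exactly as documented, Wrapper(4, (3.0, 1.0), verbose=1) would be accepted in the same way, and Wrapper(4, (3.0, 1.0), True) would raise TypeError.

```python
def neo_wrapper_init(integration_sites=2, layer_alpha=(2.0, 1.0), *,
                     torch_compile=False, verbose=0, symbolic_trace=True):
    """Hypothetical stub with the same parameter list as the documented
    Wrapper.__init__; it only echoes its arguments back."""
    return {
        "integration_sites": integration_sites,
        "layer_alpha": layer_alpha,
        "torch_compile": torch_compile,
        "verbose": verbose,
        "symbolic_trace": symbolic_trace,
    }


# The first two parameters may be passed positionally; the rest only by keyword.
cfg = neo_wrapper_init(4, (3.0, 1.0), verbose=1)
print(cfg)

# Everything after the bare `*` is keyword-only, so this call is rejected.
try:
    neo_wrapper_init(4, (3.0, 1.0), True)
    raised = False
except TypeError:
    raised = True
print(raised)  # True
```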

__call__(module_or_module_class)

Applies the wrapper to either an instantiated torch.nn.Module or a class that subclasses torch.nn.Module, creating a new wrapped implementation.

Parameters:

module_or_module_class (T, a TypeVar constrained to torch.nn.Module or type[torch.nn.Module]) – The module instance or module class to wrap

Returns:

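The wrapped module, which shares its weights with the original module. If a module class is passed, a wrapped class is returned, and instantiating it creates a wrapped module.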
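The following plain-Python sketch makes the instance-or-class contract in the TypeVar annotation concrete. Passing in a module returns a module, and passing in a class returns a class. Everything in it (the Module stand-in, ToyWrapper, Scale) is hypothetical and torch-free. It only mirrors the documented behaviour, including the shared weights, and does not describe how capsa_torch actually performs the wrapping.

```python
from typing import Type, TypeVar


class Module:
    """Hypothetical stand-in for torch.nn.Module (keeps the sketch torch-free)."""

    def __call__(self, x):
        return self.forward(x)


# Mirrors the documented TypeVar(T, Module, type[torch.nn.Module]).
T = TypeVar("T", Module, Type[Module])


class ToyWrapper:
    """Hypothetical sketch of the instance-or-class contract, not capsa_torch code."""

    def __call__(self, module_or_module_class: T) -> T:
        if isinstance(module_or_module_class, type):
            base_cls = module_or_module_class

            # Class in -> wrapped subclass out; it is built with base_cls's __init__.
            class Wrapped(base_cls):
                def forward(self, x):
                    return ("wrapped", super().forward(x))

            Wrapped.__name__ = base_cls.__name__
            return Wrapped

        inner = module_or_module_class

        # Instance in -> new module that delegates to (and so shares state with) inner.
        class WrappedInstance(Module):
            def forward(self, x):
                return ("wrapped", inner(x))

        return WrappedInstance()


class Scale(Module):
    def __init__(self, factor):
        self.factor = factor  # plays the role of a weight

    def forward(self, x):
        return self.factor * x


original = Scale(2.0)
wrapped = ToyWrapper()(original)
original.factor = 3.0  # the update is visible through the wrapper
print(wrapped(5.0))  # ('wrapped', 15.0)

WrappedScale = ToyWrapper()(Scale)  # decorator-style use returns a class
print(WrappedScale(2.0)(5.0))  # ('wrapped', 10.0)
```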

Example Usage

Wrapping a Module
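import torch
from capsa_torch.neo import Wrapper # or capsa_torch.sample, capsa_torch.sculpt, capsa_torch.vote (each has its own config options)
wrapper = Wrapper(integration_sites=2, verbose=1) # Initialize a wrapper object with your config options

module = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1)) # example model; any torch.nn.Module works
x = torch.randn(4, 8) # example input batch

wrapped_module = wrapper(module) # wrap your module

y = wrapped_module(x) # Use the wrapped module as usual
y, risk = wrapped_module(x, return_risk=True) # Use the wrapped module to obtain risk values
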
Decorator approach
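import torch
from capsa_torch.neo import Wrapper # or capsa_torch.sample, capsa_torch.sculpt, capsa_torch.vote (each has its own config options)

@Wrapper(integration_sites=2, verbose=1) # Initialize a wrapper object with your config options
class MyModule(torch.nn.Module): # Note: MyModule must subclass torch.nn.Module
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features) # example layer; replace with your own

    def forward(self, x):
        return self.linear(x)

wrapped_module = MyModule(8, 1) # Call MyModule's __init__ fn as usual to create a wrapped module
x = torch.randn(4, 8) # example input batch of shape (batch, in_features)

y = wrapped_module(x) # Use the wrapped module as usual
y, risk = wrapped_module(x, return_risk=True) # Use the wrapped module to obtain risk values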