capsa_torch.neo¶
- class capsa_torch.neo.AEWrapper¶
- __init__(integration_sites=2, encoder_dims=(128, 64), decoder_dims=(64, 128), finetune=False, conv_params=None, node_name_filter=None, pixel_wise=False, freeze_encoder=False, sparsity_metric=None, topk_k=None, *, verbose=0, torch_compile=False, symbolic_trace=False)¶
Initialize a Neo Wrapper with configs.
- Parameters:
integration_sites (int) – The number of sites to use when integrating the neo wrapper into your model. More integration sites may produce more robust vacuitic uncertainty estimates but will increase computation costs. (default: 2)
encoder_dims (tuple[int, ...]) – Tuple of integers specifying the dimensions of the encoder layers. The final value must match the first value in decoder_dims. (default: (128, 64))
decoder_dims (tuple[int, ...]) – Tuple of integers specifying the dimensions of the decoder layers. The first value must match the final value in encoder_dims. (default: (64, 128))
finetune (bool) – Freeze the existing module weights and only train the added parameters. (default: False)
conv_params (dict | None) – Dictionary of convolution parameters for integration layers. Supported keys: kernel_size: size of the convolution kernels (int); padding: padding applied to the layers (int); stride: stride of the convolution (int). (default: None)
node_name_filter (list[str] | None) – List of node names to filter out from the graph. If provided, must be a list matching the number of integration sites, or a list of a single string. (default: None)
pixel_wise (bool) – If True, compute vacuity scores per pixel; otherwise, compute a global score. (default: False)
freeze_encoder (bool) – If True, freeze the weights of the encoder layers. This is independent of the finetune argument, which controls whether the original model weights are frozen. (default: False)
sparsity_metric (str | None) – If not None, apply a sparsity regularization to the output of the encoder. Supported values are 'topk' and 'batch_topk'. (default: None)
topk_k (int | None) – The number of top elements to consider when applying sparsity regularization. Defaults to half of the final encoder dimension if sparsity_metric is set but topk_k is not provided. (default: None)
verbose (int) – Set the verbosity level for wrapping. 0 <= verbose <= 2. (default: 0)
torch_compile (bool) – Apply torch's inductor to compile the wrapped model. This should improve model performance, at the cost of initial overhead. (default: False)
symbolic_trace (bool) – Attempt to use symbolic shapes when tracing the module's graph. Turning this on helps avoid rewrapping when the model is fed different input shapes; however, it is not well supported by all models. (default: False)
Note
verbose and symbolic_trace are keyword-only arguments
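The parameter descriptions above couple several arguments: the encoder and decoder dimensions must meet in the middle, and topk_k has a derived default. The standalone sketch below (check_neo_ae_config is a hypothetical helper written for illustration, not part of capsa_torch) makes those documented constraints concrete:

```python
# Hypothetical sketch (not part of capsa_torch): enforces the documented
# constraints on AEWrapper's configuration arguments.

def check_neo_ae_config(encoder_dims=(128, 64), decoder_dims=(64, 128),
                        sparsity_metric=None, topk_k=None):
    """Validate dims and resolve the documented topk_k default."""
    # The final encoder dimension must match the first decoder dimension.
    if encoder_dims[-1] != decoder_dims[0]:
        raise ValueError(
            f"encoder_dims[-1] ({encoder_dims[-1]}) must equal "
            f"decoder_dims[0] ({decoder_dims[0]})"
        )
    if sparsity_metric is not None:
        if sparsity_metric not in ("topk", "batch_topk"):
            raise ValueError(f"unsupported sparsity_metric: {sparsity_metric!r}")
        # topk_k defaults to half of the final encoder dimension.
        if topk_k is None:
            topk_k = encoder_dims[-1] // 2
    return topk_k

print(check_neo_ae_config())                                 # no sparsity -> None
print(check_neo_ae_config(sparsity_metric="topk"))           # defaults to 64 // 2 = 32
print(check_neo_ae_config((256, 32), (32, 256), "topk", 8))  # explicit topk_k is kept
```

With the default encoder_dims=(128, 64), enabling 'topk' sparsity without an explicit topk_k therefore keeps the top 32 encoder activations.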
- __call__(module_or_module_class)¶
Applies the wrapper to either an instantiated torch.nn.Module or a class that subclasses torch.nn.Module to create a new wrapped implementation.
- Parameters:
module_or_module_class – The Module to wrap
- Return type:
_JointBaseModule | type[_JointBaseModule]
- Returns:
The wrapped module, with weights shared with the original module
Example Usage
Wrapping a Module¶

```python
from capsa_torch.sample import Wrapper  # or capsa_torch.sculpt, capsa_torch.vote

wrapper = Wrapper(n_samples=3, verbose=1)      # Initialize a wrapper object with your config options
wrapped_module = wrapper(module)               # Wrap your module
y = wrapped_module(x)                          # Use the wrapped module as usual
y, risk = wrapped_module(x, return_risk=True)  # Use the wrapped module to obtain risk values
```
Decorator approach¶

```python
from capsa_torch.sample import Wrapper  # or capsa_torch.sculpt, capsa_torch.vote

@Wrapper(n_samples=3, verbose=1)  # Initialize a wrapper object with your config options
class MyModule(torch.nn.Module):  # Note: MyModule must subclass torch.nn.Module
    def __init__(self, ...): ...
    def forward(self, ...): ...

wrapped_module = MyModule(...)                 # Call MyModule's __init__ fn as usual to create a wrapped module
y = wrapped_module(x)                          # Use the wrapped module as usual
y, risk = wrapped_module(x, return_risk=True)  # Use the wrapped module to obtain risk values
```
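Both neo wrappers accept pixel_wise, which switches between one vacuity score per pixel and a single global score. The toy function below (vacuity_scores is a hypothetical illustration, not the library's implementation) sketches that documented distinction on a small error map:

```python
# Hypothetical sketch (not part of capsa_torch): the documented difference
# between pixel_wise=True (a score per pixel) and pixel_wise=False (one
# global score), shown on a toy 2x2 map of per-pixel scores.

def vacuity_scores(errors, pixel_wise=False):
    """errors: 2D list of per-pixel scores."""
    if pixel_wise:
        return errors                # one score per pixel, same shape as input
    flat = [e for row in errors for e in row]
    return sum(flat) / len(flat)     # collapse to a single global score

errs = [[0.1, 0.3],
        [0.2, 0.2]]
print(vacuity_scores(errs, pixel_wise=True))   # the full 2x2 map, unchanged
print(vacuity_scores(errs, pixel_wise=False))  # the mean over all pixels (~0.2)
```

Here the global score is simply a mean over pixels; whatever reduction the library actually applies, the shape difference between the two modes is the point.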
- class capsa_torch.neo.Wrapper¶
- __init__(integration_sites=2, layer_alpha=(2.0, 1.0), layer_out_dims=None, finetune=False, conv_params=None, add_batch_norm=False, node_name_filter=None, pixel_wise=False, *, verbose=0, torch_compile=False, symbolic_trace=False)¶
Initialize a Neo Wrapper with configs.
- Parameters:
integration_sites (int) – The number of sites to use when integrating the neo wrapper into your model. More integration sites may produce more robust vacuitic uncertainty estimates but will increase computation costs. (default: 2)
layer_alpha (tuple[float, ...]) – Controls the structure of neo integrations. Pair of floats. Larger values produce more robust vacuitic uncertainty estimates but with more compute and memory overhead. (default: (2.0, 1.0))
layer_out_dims (tuple[int, ...] | None) – Optional sizes for the output dimensions of integration layers. If provided, must be a tuple matching the number of integration sites. (default: None)
finetune (bool) – Freeze the existing module weights and only train the added parameters. (default: False)
conv_params (dict | None) – Dictionary of convolution parameters for integration layers. Supported keys: kernel_size: size of the convolution kernels (int); padding: padding applied to the layers (int); stride: stride of the convolution (int). (default: None)
add_batch_norm (bool) – Whether to insert BatchNorm layers before the integration convolutions. (default: False)
node_name_filter (list[str] | None) – List of node names to filter out from the graph. If provided, must be a list matching the number of integration sites, or a list of a single string. (default: None)
pixel_wise (bool) – If True, compute vacuity scores per pixel; otherwise, compute a global score. (default: False)
verbose (int) – Set the verbosity level for wrapping. 0 <= verbose <= 2. (default: 0)
torch_compile (bool) – Apply torch's inductor to compile the wrapped model. This should improve model performance, at the cost of initial overhead. (default: False)
symbolic_trace (bool) – Attempt to use symbolic shapes when tracing the module's graph. Turning this on helps avoid rewrapping when the model is fed different input shapes; however, it is not well supported by all models. (default: False)
Note
verbose and symbolic_trace are keyword-only arguments
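Several of this wrapper's arguments are per-integration-site and must line up with integration_sites. The sketch below (check_neo_config is a hypothetical helper for illustration, not part of capsa_torch) encodes those documented length constraints:

```python
# Hypothetical sketch (not part of capsa_torch): checks that per-site
# arguments of neo.Wrapper line up with integration_sites as documented.

def check_neo_config(integration_sites=2, layer_out_dims=None, node_name_filter=None):
    """Validate that per-site options match the number of integration sites."""
    # layer_out_dims, if given, needs exactly one entry per integration site.
    if layer_out_dims is not None and len(layer_out_dims) != integration_sites:
        raise ValueError(
            f"layer_out_dims has {len(layer_out_dims)} entries, "
            f"expected {integration_sites}"
        )
    # node_name_filter may match the site count, or be a single-string list.
    if node_name_filter is not None and len(node_name_filter) not in (1, integration_sites):
        raise ValueError("node_name_filter must have 1 or integration_sites entries")
    return True

print(check_neo_config())                                # defaults are consistent
print(check_neo_config(3, layer_out_dims=(64, 64, 32)))  # one output dim per site
print(check_neo_config(3, node_name_filter=["relu"]))    # a single filter string is allowed
```

All three calls print True; passing, say, layer_out_dims=(64, 64) with integration_sites=3 would raise a ValueError under these rules.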
- __call__(module_or_module_class)¶
Applies the wrapper to either an instantiated torch.nn.Module or a class that subclasses torch.nn.Module to create a new wrapped implementation.
- Parameters:
module_or_module_class – The Module to wrap
- Return type:
_JointBaseModule | type[_JointBaseModule]
- Returns:
The wrapped module, with weights shared with the original module
Example Usage
Wrapping a Module¶

```python
from capsa_torch.sample import Wrapper  # or capsa_torch.sculpt, capsa_torch.vote

wrapper = Wrapper(n_samples=3, verbose=1)      # Initialize a wrapper object with your config options
wrapped_module = wrapper(module)               # Wrap your module
y = wrapped_module(x)                          # Use the wrapped module as usual
y, risk = wrapped_module(x, return_risk=True)  # Use the wrapped module to obtain risk values
```
Decorator approach¶

```python
from capsa_torch.sample import Wrapper  # or capsa_torch.sculpt, capsa_torch.vote

@Wrapper(n_samples=3, verbose=1)  # Initialize a wrapper object with your config options
class MyModule(torch.nn.Module):  # Note: MyModule must subclass torch.nn.Module
    def __init__(self, ...): ...
    def forward(self, ...): ...

wrapped_module = MyModule(...)                 # Call MyModule's __init__ fn as usual to create a wrapped module
y = wrapped_module(x)                          # Use the wrapped module as usual
y, risk = wrapped_module(x, return_risk=True)  # Use the wrapped module to obtain risk values
```