capsa_torch.sculpt

class capsa_torch.sculpt.Wrapper
__init__(distribution=Dist[Normal], n_layers=3, *, torch_compile=False, verbose=0, symbolic_trace=True)

Initialize a Sculpt Wrapper with configs.

Parameters:
  • distribution (type[Distribution]) – The distribution type used to model the wrapped module’s outputs. (default: Dist[Normal])

  • n_layers (int) – The number of model output layers that should be modified to produce standard deviation outputs. Note: “layers” here only weakly approximates the layers of a neural network. (default: 3)

  • torch_compile (bool) – Compile the wrapped model with torch’s Inductor backend. This should improve model performance at the cost of some initial compilation overhead. (default: False)

  • verbose (int) – Set the verbosity level for wrapping. 0 <= verbose <= 2 (default: 0)

  • symbolic_trace (bool) – Attempt to use symbolic shapes when tracing the module’s graph. Turning this off may help if the module fails to wrap; however, the resulting graph is more likely to use fixed input dimensions and to trigger rewraps when fed inputs of different shapes. (default: True)

Note

torch_compile, verbose, and symbolic_trace are keyword-only arguments.
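To make the keyword-only rule concrete, the hypothetical stub below (not part of capsa_torch) has the same parameter layout as the documented __init__: distribution and n_layers may be passed positionally, while every parameter after the bare * must be named.

```python
def wrapper_init_sketch(distribution=None, n_layers=3, *, torch_compile=False,
                        verbose=0, symbolic_trace=True):
    """Hypothetical stub with the same parameter layout as sculpt.Wrapper.__init__."""
    # Parameters after the bare `*` can only be supplied by keyword.
    return {
        "distribution": distribution,
        "n_layers": n_layers,
        "torch_compile": torch_compile,
        "verbose": verbose,
        "symbolic_trace": symbolic_trace,
    }


cfg = wrapper_init_sketch("Normal", 2, verbose=1)  # OK: keyword-only args are named

try:
    wrapper_init_sketch("Normal", 2, True)  # torch_compile passed positionally
except TypeError as err:
    print("TypeError:", err)
```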

__call__(module_or_module_class)

Applies the wrapper to either an instantiated torch.nn.Module or a class that subclasses torch.nn.Module, creating a new wrapped implementation.

Parameters:

module_or_module_class (TypeVar(T, Module, type[torch.nn.Module])) – The Module to wrap

Returns:

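The wrapped module, which shares its weights with the original module. When a class is passed (e.g. as a decorator), a wrapped class is returned whose instances are wrapped modules.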

Example Usage

Wrapping a Module
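import torch
from capsa_torch.sculpt import Wrapper  # or capsa_torch.sample, capsa_torch.vote (their config options differ)

# A small example network and a batch of inputs
module = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
x = torch.randn(8, 4)

wrapper = Wrapper(n_layers=3, verbose=1)  # Initialize a wrapper object with your config options
wrapped_module = wrapper(module)  # wrap your module

y = wrapped_module(x)  # Use the wrapped module as usual
y, risk = wrapped_module(x, return_risk=True)  # Use the wrapped module to obtain risk values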
Decorator approach
import torch
from capsa_torch.sculpt import Wrapper  # or capsa_torch.sample, capsa_torch.vote (their config options differ)

@Wrapper(n_layers=3, verbose=1)  # Initialize a wrapper object with your config options
class MyModule(torch.nn.Module):  # Note: MyModule must subclass torch.nn.Module
    def __init__(self, in_features, out_features):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_features, 16), torch.nn.ReLU(), torch.nn.Linear(16, out_features)
        )

    def forward(self, x):
        return self.net(x)

wrapped_module = MyModule(4, 1)  # Call MyModule's __init__ as usual to create a wrapped module
x = torch.randn(8, 4)

y = wrapped_module(x)  # Use the wrapped module as usual
y, risk = wrapped_module(x, return_risk=True)  # Use the wrapped module to obtain risk values
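The same wrapper object can be used both ways because __call__ can branch on whether it received a class or an instance. The sketch below illustrates that dispatch pattern in plain Python, with a hypothetical ToyWrapper and simple callables in place of torch.nn.Module; it is not capsa_torch's implementation, which this reference does not describe.

```python
class _WrappedInstance:
    """Forwards calls to the original object, so its state (e.g. weights) is shared."""

    def __init__(self, inner, scale):
        self.inner = inner  # the same object, not a copy
        self.scale = scale

    def __call__(self, x):
        return self.scale * self.inner(x)


class ToyWrapper:
    """Hypothetical wrapper usable as wrapper(instance) or as @wrapper on a class."""

    def __init__(self, scale=2.0):
        self.scale = scale

    def __call__(self, obj_or_class):
        scale = self.scale
        if isinstance(obj_or_class, type):
            # Decorator path: return a subclass; constructing it yields a wrapped object.
            class Wrapped(obj_or_class):
                def __call__(self, x):
                    return scale * super().__call__(x)

            Wrapped.__name__ = obj_or_class.__name__
            return Wrapped
        # Instance path: wrap the existing object without copying it.
        return _WrappedInstance(obj_or_class, scale)


class Affine:
    def __init__(self, w):
        self.w = w

    def __call__(self, x):
        return self.w * x


m = Affine(3.0)
wrapped = ToyWrapper(2.0)(m)  # instance path
print(wrapped(1.0))


@ToyWrapper(2.0)  # decorator path
class MyAffine(Affine):
    pass


print(MyAffine(3.0)(1.0))
```

Because the instance path keeps a reference rather than a copy, later changes to the original object are visible through the wrapper, which mirrors the "weights shared with the original module" behaviour described above.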
class capsa_torch.sculpt.Normal

Normal Distribution for the Sculpt Wrapper

static loss_function(target, pred, risk, reduction='mean')

Compute the distribution loss from target, pred, and risk values.

Parameters:
  • target (Tensor) – The target value for the model output

  • pred (Tensor) – The predicted output distribution mean value

  • risk (Tensor) – The predicted output distribution sigma value

  • reduction (str) – The reduction approach to use on the output. Options include: [“mean”, “sum”, “none”]. (default: 'mean')

Return type:

Tensor

Returns:

The computed loss value
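This reference does not spell out the loss formula. Assuming it is the usual Gaussian negative log-likelihood with risk interpreted as the standard deviation sigma, a scalar sketch of the per-element loss and the three reduction options looks like this (normal_nll is a hypothetical name, and plain lists stand in for tensors):

```python
import math


def normal_nll(target, pred, risk, reduction="mean"):
    """Elementwise Gaussian NLL, taking risk as the standard deviation (assumption)."""
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"unsupported reduction: {reduction!r}")
    losses = [
        0.5 * math.log(2.0 * math.pi * s * s) + (t - m) ** 2 / (2.0 * s * s)
        for t, m, s in zip(target, pred, risk)
    ]
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    return losses  # "none": one loss value per element


target = [1.0, 2.0, 0.5]
pred = [0.8, 2.5, 0.5]
sigma = [0.5, 1.0, 0.1]
print(normal_nll(target, pred, sigma, reduction="none"))
print(normal_nll(target, pred, sigma))  # mean over elements
```

A larger sigma shrinks the squared-error term but grows the log term; for a fixed error e the loss is minimized at sigma = |e|, which is what pushes the predicted sigma toward the model's actual error.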

static sample(out)

Samples from a normal distribution with mean mu and standard deviation sigma

Parameters:

out (RiskOutput | tuple[PyTree, PyTree]) – Object containing both mu and sigma outputs from the module

Return type:

PyTree

Returns:

A sample from the distribution
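Drawing from a normal with predicted mean mu and standard deviation sigma can be sketched with the reparameterization mu + sigma * eps, eps ~ N(0, 1). The sketch assumes the tuple form (mu, sigma) of out; sample_normal and its rng argument are hypothetical, and lists stand in for PyTrees of tensors.

```python
import random


def sample_normal(out, rng=random):
    """Reparameterized draw: mu + sigma * eps with eps ~ N(0, 1), elementwise."""
    mu, sigma = out  # assumes the (pred, risk) tuple form of `out`
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


rng = random.Random(0)  # seeded for reproducibility
print(sample_normal(([0.0, 10.0], [1.0, 0.1]), rng))
```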

static to_risk(risk_out)

Convert predicted risk distribution parameters to a single risk value.

Parameters:

risk_out (Any) – Risk output value from wrapped model. (i.e. out.risk)

Return type:

Any

Returns:

The computed risk values

Example Usage

from capsa_torch import sculpt
wrapper = sculpt.Wrapper(distribution=sculpt.Normal)
wrapped_module = wrapper(module)
out = wrapped_module(x, return_risk=True)

risk_values = sculpt.Normal.to_risk(out.risk)
class capsa_torch.sculpt.MultiVariateNormal

MultiVariateNormal Distribution for the Sculpt Wrapper

static loss_function(target, pred, risk, reduction='mean')

Compute the distribution loss from target, pred, and risk values.

Parameters:
  • target (Tensor) – The target value for the model output

  • pred (Tensor) – The predicted output distribution mean value

  • risk (Tensor) – The predicted lower-triangular Cholesky factor (scale_tril) of the output distribution covariance

  • reduction (str) – The reduction approach to use on the output. Options include: [“mean”, “sum”, “none”]. (default: 'mean')

Return type:

Tensor

Returns:

The computed loss value
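With the covariance parameterized by its Cholesky factor L (so Sigma = L L^T), the negative log-likelihood needs only a triangular solve and the diagonal of L. The sketch below assumes the loss is this standard multivariate Gaussian NLL, shown for a single sample without reduction; mvn_nll is a hypothetical name and nested lists stand in for tensors.

```python
import math


def mvn_nll(y, mu, scale_tril):
    """NLL of y under N(mu, L L^T), where L = scale_tril is lower-triangular."""
    d = len(y)
    r = [yi - mi for yi, mi in zip(y, mu)]
    # Forward substitution solves L z = r, so z^T z = r^T Sigma^{-1} r.
    z = []
    for i in range(d):
        acc = r[i] - sum(scale_tril[i][j] * z[j] for j in range(i))
        z.append(acc / scale_tril[i][i])
    mahalanobis = sum(zi * zi for zi in z)
    # log det(L L^T) = 2 * sum(log L_ii) for positive diagonal entries
    log_det = 2.0 * sum(math.log(scale_tril[i][i]) for i in range(d))
    return 0.5 * (d * math.log(2.0 * math.pi) + log_det + mahalanobis)


L = [[1.0, 0.0],
     [0.5, 2.0]]
print(mvn_nll([1.0, 0.0], [0.0, 0.0], L))
```

Working with L directly avoids ever forming or inverting Sigma, and a positive diagonal on L guarantees Sigma is positive definite.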

static sample(out)

Samples from a multivariate normal distribution with the predicted mu and covariance values

Parameters:

out (RiskOutput | tuple[PyTree, PyTree]) – Either RiskOutput or tuple containing mu and cov outputs of model

Return type:

PyTree

Returns:

A sample from the distribution
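Given mu and the Cholesky factor L, a draw is y = mu + L eps with eps ~ N(0, I), which has covariance L L^T. This is the standard construction, not necessarily how capsa_torch implements sample; sample_mvn is a hypothetical name operating on lists.

```python
import random


def sample_mvn(mu, scale_tril, rng=random):
    """One draw y = mu + L @ eps with eps ~ N(0, I) and L lower-triangular."""
    eps = [rng.gauss(0.0, 1.0) for _ in mu]
    return [
        mu[i] + sum(scale_tril[i][j] * eps[j] for j in range(i + 1))
        for i in range(len(mu))
    ]


L = [[1.0, 0.0],
     [0.5, 1.0]]  # implies Sigma = L L^T = [[1.0, 0.5], [0.5, 1.25]]
rng = random.Random(0)
print(sample_mvn([0.0, 0.0], L, rng))
```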

static to_risk(risk_out)

Convert predicted risk distribution parameters to a single risk value.

Parameters:

risk_out (Any) – Risk output value from wrapped model. (i.e. out.risk)

Return type:

Any

Returns:

The computed risk values

Example Usage

from capsa_torch import sculpt
wrapper = sculpt.Wrapper(distribution=sculpt.MultiVariateNormal)
wrapped_module = wrapper(module)
out = wrapped_module(x, return_risk=True)

risk_values = sculpt.MultiVariateNormal.to_risk(out.risk)
class capsa_torch.sculpt.NormalInverseGamma

NormalInverseGamma Distribution for the Sculpt Wrapper

static loss_function(target, pred, risk, coeff=1.0, reduction='mean')

Compute the distribution loss from target, pred, and risk values.

Parameters:
  • target (Tensor) – The target value for the model output

  • pred (Tensor) – The predicted output distribution mean value

  • risk (NormalInverseGammaRisk | tuple[Tensor, Tensor, Tensor]) – The predicted output distribution hyperparameters (nu, alpha, beta)

  • coeff (float) – A multiplier for the regularization loss (default: 1.0)

  • reduction (str) – The reduction approach to use on the output. Options include: [“mean”, “sum”, “none”]. (default: 'mean')

Return type:

Tensor

Returns:

The computed loss value
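This reference only states that coeff multiplies a regularization loss. Assuming capsa_torch follows the deep evidential regression objective of Amini et al. (2020), which fits that description, the per-element loss is the negative log-likelihood of the Student-t marginal of the NIG distribution plus coeff times |y - gamma| (2 nu + alpha). A scalar sketch, with gamma playing the role of pred and (nu, alpha, beta) the fields of the risk tuple; nig_loss is a hypothetical name.

```python
import math


def nig_loss(target, gamma, nu, alpha, beta, coeff=1.0):
    """Evidential regression loss for one element (Amini et al., 2020 form).

    gamma is the predicted mean; nu, alpha, beta are the positive NIG risk outputs.
    """
    omega = 2.0 * beta * (1.0 + nu)
    err = target - gamma
    # Negative log-likelihood of the Student-t marginal of the NIG distribution
    nll = (
        0.5 * math.log(math.pi / nu)
        - alpha * math.log(omega)
        + (alpha + 0.5) * math.log(nu * err * err + omega)
        + math.lgamma(alpha)
        - math.lgamma(alpha + 0.5)
    )
    # Evidence regularizer: penalizes confident (high-evidence) wrong predictions
    reg = abs(err) * (2.0 * nu + alpha)
    return nll + coeff * reg


print(nig_loss(target=0.7, gamma=0.2, nu=1.5, alpha=2.5, beta=0.8))
```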

static sample(out)

Samples from the predicted output normal-inverse-gamma distribution

Parameters:

out (RiskOutput | tuple[Tensor, NormalInverseGammaRisk]) – Either RiskOutput or tuple containing model predictions and risk values

Return type:

PyTree

Returns:

A sample from the distribution
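A normal-inverse-gamma distribution NIG(gamma, nu, alpha, beta) is a joint distribution over a mean mu and a variance sigma^2, so it can be sampled hierarchically: first sigma^2 from an inverse-gamma, then mu given sigma^2. Whether capsa_torch's sample returns such a (mu, sigma^2) pair or a draw of the target itself is not stated here; sample_nig is a hypothetical sketch of the joint draw.

```python
import random


def sample_nig(gamma, nu, alpha, beta, rng=random):
    """Draw (mu, sigma2) from NIG(gamma, nu, alpha, beta)."""
    # sigma2 ~ Inverse-Gamma(alpha, beta): reciprocal of Gamma(shape=alpha, rate=beta).
    # random.gammavariate takes (shape, scale), so scale = 1 / beta.
    sigma2 = 1.0 / rng.gammavariate(alpha, 1.0 / beta)
    # mu | sigma2 ~ Normal(gamma, sigma2 / nu)
    mu = rng.gauss(gamma, (sigma2 / nu) ** 0.5)
    return mu, sigma2


rng = random.Random(0)
print(sample_nig(gamma=0.0, nu=2.0, alpha=5.0, beta=4.0, rng=rng))
```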

static to_risk(risk_out)

Convert predicted risk distribution parameters to a single risk value.

Parameters:

risk_out (Any) – Risk output value from wrapped model. (i.e. out.risk)

Return type:

Any

Returns:

The computed risk values
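Two scalar quantities are commonly derived from NIG parameters: the expected data noise E[sigma^2] = beta / (alpha - 1) (aleatoric) and the variance of the mean Var[mu] = beta / (nu (alpha - 1)) (epistemic). Which of these, if either, to_risk returns is not specified in this reference; nig_uncertainties is a hypothetical helper showing the two formulas.

```python
def nig_uncertainties(nu, alpha, beta):
    """Aleatoric and epistemic variance implied by NIG(gamma, nu, alpha, beta)."""
    if alpha <= 1:
        raise ValueError("alpha must be > 1 for these moments to be finite")
    aleatoric = beta / (alpha - 1)         # E[sigma^2]: expected data noise
    epistemic = beta / (nu * (alpha - 1))  # Var[mu]: uncertainty in the mean
    return aleatoric, epistemic


print(nig_uncertainties(nu=4.0, alpha=3.0, beta=2.0))  # (1.0, 0.25)
```

Increasing nu (more "evidence" for the mean) shrinks only the epistemic term, which is why the two are reported separately.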

Example Usage

from capsa_torch import sculpt
wrapper = sculpt.Wrapper(distribution=sculpt.NormalInverseGamma)
wrapped_module = wrapper(module)
out = wrapped_module(x, return_risk=True)

risk_values = sculpt.NormalInverseGamma.to_risk(out.risk)
class capsa_torch.sculpt.NormalInverseWishart

NormalInverseWishart Distribution for the Sculpt Wrapper

static loss_function(target, pred, risk, reduction='mean')

Compute the distribution loss from target, pred, and risk values.

Parameters:
  • target (Tensor) – The target value for the model output

  • pred (Tensor) – The predicted output distribution mean value

  • risk (NormalInverseWishartRisk | tuple[Tensor, Tensor]) – The predicted output distribution hyperparameters

  • reduction (str) – The reduction approach to use on the output. Options include: [“mean”, “sum”, “none”]. (default: 'mean')

Return type:

Tensor

Returns:

The computed loss value

static sample(out)

Samples from the predicted output normal-inverse-wishart distribution

Parameters:

out (RiskOutput | tuple[Tensor, NormalInverseWishartRisk]) – Either RiskOutput or tuple containing model predictions and risk values

Return type:

PyTree

Returns:

A sample from the distribution

static to_risk(risk_out)

Convert predicted risk distribution parameters to a single risk value.

Parameters:

risk_out (Any) – Risk output value from wrapped model. (i.e. out.risk)

Return type:

Any

Returns:

The computed risk values

Example Usage

from capsa_torch import sculpt
wrapper = sculpt.Wrapper(distribution=sculpt.NormalInverseWishart)
wrapped_module = wrapper(module)
out = wrapped_module(x, return_risk=True)

risk_values = sculpt.NormalInverseWishart.to_risk(out.risk)
class capsa_torch.sculpt.Distribution

Static base class representing a distribution for the Sculpt Wrapper

abstract static loss_function(*args, **kwargs)
Return type:

Tensor

abstract static sample(out)
Return type:

Any

abstract static to_risk(risk_out)
Return type:

Any
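The base class fixes a small static interface: loss_function, sample, and to_risk. The sketch below uses a local stand-in ABC with that interface and a hypothetical Laplace subclass to show what implementing it involves; whether capsa_torch accepts user-defined subclasses as the distribution argument is not stated in this reference, and neither class is imported from capsa_torch.

```python
import math
import random
from abc import ABC, abstractmethod


class Distribution(ABC):
    """Local stand-in with the documented static interface (not capsa_torch's class)."""

    @staticmethod
    @abstractmethod
    def loss_function(*args, **kwargs):
        """Return the training loss."""

    @staticmethod
    @abstractmethod
    def sample(out):
        """Return a draw given (pred, risk)."""

    @staticmethod
    @abstractmethod
    def to_risk(risk_out):
        """Collapse risk parameters to a single risk value."""


class Laplace(Distribution):
    """Hypothetical example: pred is the location mu, risk is the scale b > 0."""

    @staticmethod
    def loss_function(target, pred, risk):
        # Laplace negative log-likelihood: log(2b) + |y - mu| / b
        return math.log(2.0 * risk) + abs(target - pred) / risk

    @staticmethod
    def sample(out, rng=random):
        mu, b = out
        # The difference of two i.i.d. exponentials with mean b is Laplace(0, b)
        return mu + rng.expovariate(1.0 / b) - rng.expovariate(1.0 / b)

    @staticmethod
    def to_risk(risk_out):
        # Hypothetical choice: report the scale b directly as the risk value
        return risk_out


print(Laplace.loss_function(target=1.0, pred=0.5, risk=0.25))
```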

class capsa_torch.sculpt.NormalInverseGammaRisk

Normal Inverse Gamma Risk NamedTuple

nu: Tensor

Alias for field number 0

alpha: Tensor

Alias for field number 1

beta: Tensor

Alias for field number 2

count(value, /)

Return number of occurrences of value.

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

class capsa_torch.sculpt.NormalInverseWishartRisk

Normal Inverse Wishart Risk NamedTuple

nu: Tensor

Alias for field number 0

lower_tril: Tensor

Alias for field number 1

count(value, /)

Return number of occurrences of value.

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.
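Because both risk types are NamedTuples, a field can be read by name or by its index (the "Alias for field number" entries above), and an instance works anywhere a plain tuple is expected, which matches the tuple alternatives accepted by loss_function. The stand-ins below are hypothetical and use floats and lists instead of tensors.

```python
from typing import NamedTuple


class NIGRiskSketch(NamedTuple):
    """Stand-in for NormalInverseGammaRisk, with floats instead of Tensors."""
    nu: float
    alpha: float
    beta: float


class NIWRiskSketch(NamedTuple):
    """Stand-in for NormalInverseWishartRisk; lower_tril is a nested list here."""
    nu: float
    lower_tril: list


risk = NIGRiskSketch(nu=4.0, alpha=3.0, beta=2.0)
print(risk.alpha, risk[1])  # field name and index refer to the same value
nu, alpha, beta = risk      # unpacks like the plain-tuple form
print(NIWRiskSketch._fields)
```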