capsa_torch.sculpt
- class capsa_torch.sculpt.Wrapper
- __init__(distribution=Dist[Normal], n_layers=3, *, torch_compile=False, verbose=0, symbolic_trace=True)
Initialize a Sculpt Wrapper with configs.
- Parameters:
distribution (type[Distribution]) – The type of distribution that the model’s outputs should follow. (default: Dist[Normal])
n_layers (int) – The number of model output layers that should be modified to produce standard deviation outputs. Note: “layers” here only loosely corresponds to the layers of a neural network. (default: 3)
torch_compile (bool) – Apply torch’s inductor to compile the wrapped model. This should improve model performance, at the cost of initial compilation overhead. (default: False)
verbose (int) – Set the verbosity level for wrapping; must satisfy 0 <= verbose <= 2. (default: 0)
symbolic_trace (bool) – Attempt to use symbolic shapes when tracing the module’s graph. Turning this off may help if the module is failing to wrap; however, the resulting graph is more likely to use fixed input dimensions and trigger rewraps when fed different input shapes. (default: True)
Note
torch_compile, verbose, and symbolic_trace are keyword-only arguments (they follow the * in the signature).
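The keyword-only behavior implied by the bare * in the signature can be illustrated with a small stand-in. This is a sketch only: make_wrapper_config is a hypothetical function that mirrors the parameter layout of Wrapper.__init__, it is not part of capsa_torch, and the range check on verbose is an assumption about how an out-of-range value might be handled.

```python
def make_wrapper_config(distribution="Normal", n_layers=3, *,
                        torch_compile=False, verbose=0, symbolic_trace=True):
    """Hypothetical stand-in that mirrors the parameter layout of Wrapper.__init__."""
    # Assumption: out-of-range verbosity is rejected; the real wrapper may behave differently.
    if not 0 <= verbose <= 2:
        raise ValueError("verbose must satisfy 0 <= verbose <= 2")
    return {
        "distribution": distribution,
        "n_layers": n_layers,
        "torch_compile": torch_compile,
        "verbose": verbose,
        "symbolic_trace": symbolic_trace,
    }


# Only distribution and n_layers may be passed positionally.
config = make_wrapper_config("Normal", 3, verbose=1, symbolic_trace=False)

# Passing torch_compile positionally fails because it follows the bare *.
try:
    make_wrapper_config("Normal", 3, True)
except TypeError as err:
    print("rejected:", err)
```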
- __call__(module_or_module_class)
Applies the wrapper to either an instantiated torch.nn.Module or a class that subclasses torch.nn.Module to create a new wrapped implementation.
- Parameters:
module_or_module_class (TypeVar(T, Module, type[torch.nn.Module])) – The Module to wrap
- Returns:
The wrapped module, with weights shared with module
Example Usage
Wrapping an instantiated module:

    from capsa_torch.sculpt import Wrapper

    wrapper = Wrapper(n_layers=3, verbose=1)  # Initialize a wrapper object with your config options
    wrapped_module = wrapper(module)  # Wrap your module

    y = wrapped_module(x)  # Use the wrapped module as usual
    y, risk = wrapped_module(x, return_risk=True)  # Use the wrapped module to obtain risk values

Decorating a module class:

    import torch
    from capsa_torch.sculpt import Wrapper

    @Wrapper(n_layers=3, verbose=1)  # Initialize a wrapper object with your config options
    class MyModule(torch.nn.Module):  # Note: MyModule must subclass torch.nn.Module
        def __init__(self, ...):
            ...

        def forward(self, ...):
            ...

    wrapped_module = MyModule(...)  # Call MyModule's __init__ fn as usual to create a wrapped module
    y = wrapped_module(x)  # Use the wrapped module as usual
    y, risk = wrapped_module(x, return_risk=True)  # Use the wrapped module to obtain risk values
- class capsa_torch.sculpt.Normal
Normal Distribution for the Sculpt Wrapper
- static loss_function(target, pred, risk, reduction='mean')
Compute the distribution loss from target, pred, and risk values.
- Parameters:
target (Tensor) – The target value for the model output
pred (Tensor) – The predicted output distribution mean value
risk (Tensor) – The predicted output distribution sigma value
reduction (str) – The reduction approach to use on the output. Options include: [“mean”, “sum”, “none”]. (default: 'mean')
- Return type:
Tensor
- Returns:
The computed loss value
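To make the roles of target, pred, risk, and reduction concrete, the sketch below computes a Gaussian negative log-likelihood. It assumes the loss is a standard Gaussian NLL, which this reference does not confirm, and it uses plain Python floats in place of tensors. The name gaussian_nll is illustrative and not part of capsa_torch.

```python
import math


def gaussian_nll(target, pred, sigma, reduction="mean"):
    """Gaussian negative log-likelihood per element, followed by a reduction.

    Assumption: sigma is a strictly positive standard deviation.
    """
    losses = [
        0.5 * math.log(2.0 * math.pi * s * s) + (t - p) ** 2 / (2.0 * s * s)
        for t, p, s in zip(target, pred, sigma)
    ]
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    if reduction == "none":
        return losses
    raise ValueError(f"unknown reduction: {reduction!r}")


target = [1.0, 2.0, 3.0]
pred = [1.1, 1.8, 3.5]
sigma = [0.5, 0.5, 1.0]
mean_loss = gaussian_nll(target, pred, sigma)               # scalar
per_element = gaussian_nll(target, pred, sigma, "none")     # one value per element
```

A larger sigma lowers the penalty for a given error but raises the log term, which is what lets the model trade confidence against accuracy.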
- static sample(out)
Samples from a normal distribution with mean mu and standard deviation sigma
- Parameters:
out (RiskOutput | tuple[PyTree, PyTree]) – Object containing both mu and sigma outputs from the module
- Return type:
PyTree
- Returns:
A sample from the distribution
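Sampling from a normal distribution with mean mu and standard deviation sigma can be sketched as mu + sigma * eps with eps drawn from a standard normal. The example below is a plain-Python illustration on lists of floats; sample_normal is a hypothetical helper, and how the library traverses PyTree outputs is not shown here.

```python
import random


def sample_normal(mu, sigma, rng=random):
    """Draw one sample per element as mu + sigma * eps, with eps ~ N(0, 1)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


rng = random.Random(0)  # seeded generator so repeated runs give the same draws
draw = sample_normal([0.0, 10.0], [1.0, 0.1], rng)
```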
- static to_risk(risk_out)
Convert predicted risk distribution parameters to a single risk value.
- Parameters:
risk_out (Any) – Risk output value from wrapped model (i.e. out.risk)
- Return type:
Any
- Returns:
The computed risk values
Example Usage

    from capsa_torch import sculpt

    wrapper = sculpt.Wrapper(distribution=sculpt.Normal)
    wrapped_module = wrapper(module)

    out = wrapped_module(x, return_risk=True)
    risk_values = sculpt.Normal.to_risk(out.risk)
- class capsa_torch.sculpt.MultiVariateNormal
MultiVariateNormal Distribution for the Sculpt Wrapper
- static loss_function(target, pred, risk, reduction='mean')
Compute the distribution loss from target, pred, and risk values.
- Parameters:
target (Tensor) – The target value for the model output
pred (Tensor) – The predicted output distribution mean value
risk (Tensor) – The predicted output distribution covariance scale_tril value
reduction (str) – The reduction approach to use on the output. Options include: [“mean”, “sum”, “none”]. (default: 'mean')
- Return type:
Tensor
- Returns:
The computed loss value
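The scale_tril parameterization can be illustrated with a multivariate normal negative log-likelihood in which the covariance is L Lᵀ for a lower-triangular L. This sketch assumes that interpretation of scale_tril and a standard NLL objective, neither of which is confirmed by this reference. It works on nested Python lists for a single target, and mvn_nll is an illustrative name.

```python
import math


def mvn_nll(target, mu, scale_tril):
    """NLL of one d-dimensional target under N(mu, L L^T), with L lower triangular.

    Assumption: the diagonal of scale_tril is strictly positive.
    """
    d = len(mu)
    diff = [t - m for t, m in zip(target, mu)]

    # Forward substitution: solve L z = diff, so that z.z is the Mahalanobis term.
    z = []
    for i in range(d):
        acc = diff[i] - sum(scale_tril[i][j] * z[j] for j in range(i))
        z.append(acc / scale_tril[i][i])

    # Sum of log-diagonal entries equals 0.5 * log det(L L^T).
    half_log_det = sum(math.log(scale_tril[i][i]) for i in range(d))
    return 0.5 * sum(v * v for v in z) + half_log_det + 0.5 * d * math.log(2.0 * math.pi)


loss = mvn_nll([1.0, 2.0], [0.0, 0.0], [[2.0, 0.0], [1.0, 3.0]])
```

Working with the triangular factor avoids inverting the covariance matrix and keeps the log-determinant a simple sum.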
- static sample(out)
Samples from a multivariate normal distribution with the predicted mu and cov values
- Parameters:
out (RiskOutput | tuple[PyTree, PyTree]) – Either RiskOutput or tuple containing mu and cov outputs of model
- Return type:
PyTree
- Returns:
A sample from the distribution
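One common way to draw from a multivariate normal with mean mu and covariance cov is to factor cov = L Lᵀ and return mu + L z, where z is a vector of independent standard normal draws. The sketch below shows that approach on nested Python lists. Whether the library factors cov itself or receives a triangular factor directly is an assumption left open here; cholesky and sample_mvn are illustrative helpers.

```python
import math
import random


def cholesky(cov):
    """Lower-triangular L with L L^T == cov (cov must be symmetric positive definite)."""
    d = len(cov)
    L = [[0.0] * d for _ in range(d)]
    for i in range(d):
        for j in range(i + 1):
            s = cov[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def sample_mvn(mu, cov, rng=random):
    """Draw one sample as mu + L z, with z a vector of independent N(0, 1) draws."""
    d = len(mu)
    L = cholesky(cov)
    z = [rng.gauss(0.0, 1.0) for _ in range(d)]
    return [mu[i] + sum(L[i][j] * z[j] for j in range(i + 1)) for i in range(d)]


rng = random.Random(1)
draw = sample_mvn([0.0, 5.0], [[4.0, 2.0], [2.0, 3.0]], rng)
```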
- static to_risk(risk_out)
Convert predicted risk distribution parameters to a single risk value.
- Parameters:
risk_out (Any) – Risk output value from wrapped model (i.e. out.risk)
- Return type:
Any
- Returns:
The computed risk values
Example Usage

    from capsa_torch import sculpt

    wrapper = sculpt.Wrapper(distribution=sculpt.MultiVariateNormal)
    wrapped_module = wrapper(module)

    out = wrapped_module(x, return_risk=True)
    risk_values = sculpt.MultiVariateNormal.to_risk(out.risk)
- class capsa_torch.sculpt.NormalInverseGamma
NormalInverseGamma Distribution for the Sculpt Wrapper
- static loss_function(target, pred, risk, coeff=1.0, reduction='mean')
Compute the distribution loss from target, pred, and risk values.
- Parameters:
target (Tensor) – The target value for the model output
pred (Tensor) – The predicted output distribution mean value
risk (NormalInverseGammaRisk | tuple[Tensor, Tensor, Tensor]) – The predicted output distribution hyperparameters (nu, alpha, beta)
coeff (float) – A multiplier for the regularization loss (default: 1.0)
reduction (str) – The reduction approach to use on the output. Options include: [“mean”, “sum”, “none”]. (default: 'mean')
- Return type:
Tensor
- Returns:
The computed loss value
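The role of coeff can be illustrated with the loss used in deep evidential regression: a Student-t negative log-likelihood plus a regularizer that penalizes confident errors. The sketch below assumes that formulation, which this reference does not confirm as the library's implementation. It uses plain Python floats, and nig_loss is an illustrative name.

```python
import math


def nig_loss(target, pred, nu, alpha, beta, coeff=1.0, reduction="mean"):
    """Evidential-regression style loss: Student-t NLL + coeff * regularizer.

    Assumption: nu > 0, alpha > 0, beta > 0 for every element.
    """
    losses = []
    for y, gamma, v, a, b in zip(target, pred, nu, alpha, beta):
        omega = 2.0 * b * (1.0 + v)
        nll = (
            0.5 * math.log(math.pi / v)
            - a * math.log(omega)
            + (a + 0.5) * math.log(v * (y - gamma) ** 2 + omega)
            + math.lgamma(a)
            - math.lgamma(a + 0.5)
        )
        # The regularizer grows with the error and with the claimed evidence (2*nu + alpha).
        reg = abs(y - gamma) * (2.0 * v + a)
        losses.append(nll + coeff * reg)
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    if reduction == "none":
        return losses
    raise ValueError(f"unknown reduction: {reduction!r}")


loss = nig_loss([1.0], [0.5], nu=[2.0], alpha=[3.0], beta=[1.5], coeff=1.0)
```

Setting coeff to 0 leaves only the likelihood term; larger values push the model toward reporting less evidence where its predictions are wrong.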
- static sample(out)
Samples from the predicted output normal-inverse-gamma distribution
- Parameters:
out (RiskOutput | tuple[Tensor, NormalInverseGammaRisk]) – Either RiskOutput or tuple containing model predictions and risk values
- Return type:
PyTree
- Returns:
A sample from the distribution
- static to_risk(risk_out)
Convert predicted risk distribution parameters to a single risk value.
- Parameters:
risk_out (Any) – Risk output value from wrapped model (i.e. out.risk)
- Return type:
Any
- Returns:
The computed risk values
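One plausible way to reduce normal-inverse-gamma parameters to a single value is the predictive variance, which splits into an aleatoric part beta / (alpha - 1) and an epistemic part beta / (nu * (alpha - 1)). The sketch below computes all three. It is an assumption about what such a conversion could look like, not a description of what to_risk returns, and nig_variances is an illustrative name.

```python
def nig_variances(nu, alpha, beta):
    """Return (total, aleatoric, epistemic) variance for normal-inverse-gamma parameters.

    Assumption: alpha > 1, otherwise these variances are undefined.
    """
    if alpha <= 1.0:
        raise ValueError("alpha must be greater than 1 for a finite variance")
    aleatoric = beta / (alpha - 1.0)         # expected noise variance
    epistemic = beta / (nu * (alpha - 1.0))  # variance of the predicted mean
    return aleatoric + epistemic, aleatoric, epistemic


total, aleatoric, epistemic = nig_variances(nu=2.0, alpha=3.0, beta=1.5)
```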
Example Usage

    from capsa_torch import sculpt

    wrapper = sculpt.Wrapper(distribution=sculpt.NormalInverseGamma)
    wrapped_module = wrapper(module)

    out = wrapped_module(x, return_risk=True)
    risk_values = sculpt.NormalInverseGamma.to_risk(out.risk)
- class capsa_torch.sculpt.NormalInverseWishart
NormalInverseWishart Distribution for the Sculpt Wrapper
- static loss_function(target, pred, risk, reduction='mean')
Compute the distribution loss from target, pred, and risk values.
- Parameters:
target (Tensor) – The target value for the model output
pred (Tensor) – The predicted output distribution mean value
risk (NormalInverseWishartRisk | tuple[Tensor, Tensor]) – The predicted output distribution hyperparameters
reduction (str) – The reduction approach to use on the output. Options include: [“mean”, “sum”, “none”]. (default: 'mean')
- Return type:
Tensor
- Returns:
The computed loss value
- static sample(out)
Samples from the predicted output normal-inverse-wishart distribution
- Parameters:
out (RiskOutput | tuple[Tensor, NormalInverseWishartRisk]) – Either RiskOutput or tuple containing model predictions and risk values
- Return type:
PyTree
- Returns:
A sample from the distribution
- static to_risk(risk_out)
Convert predicted risk distribution parameters to a single risk value.
- Parameters:
risk_out (Any) – Risk output value from wrapped model (i.e. out.risk)
- Return type:
Any
- Returns:
The computed risk values
Example Usage

    from capsa_torch import sculpt

    wrapper = sculpt.Wrapper(distribution=sculpt.NormalInverseWishart)
    wrapped_module = wrapper(module)

    out = wrapped_module(x, return_risk=True)
    risk_values = sculpt.NormalInverseWishart.to_risk(out.risk)
- class capsa_torch.sculpt.Distribution
Static Base Class to represent a Distribution for the Sculpt Wrapper
- abstract static loss_function(*args, **kwargs)
- Return type:
Tensor
- abstract static sample(out)
- Return type:
Any
- abstract static to_risk(risk_out)
- Return type:
Any
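The "abstract static" pattern listed above can be sketched with Python's abc module. DistributionSketch below is a stand-in that mirrors the three method names; it is not the capsa_torch class, and LaplaceSketch is a purely hypothetical subclass on plain floats. Whether the Sculpt Wrapper accepts user-defined distributions is not stated in this reference.

```python
import math
import random
from abc import ABC, abstractmethod


class DistributionSketch(ABC):
    """Stand-in base class with the same three abstract static methods."""

    @staticmethod
    @abstractmethod
    def loss_function(*args, **kwargs): ...

    @staticmethod
    @abstractmethod
    def sample(out): ...

    @staticmethod
    @abstractmethod
    def to_risk(risk_out): ...


class LaplaceSketch(DistributionSketch):
    """Hypothetical Laplace distribution with location pred and scale risk."""

    @staticmethod
    def loss_function(target, pred, risk, reduction="mean"):
        # Laplace negative log-likelihood: log(2b) + |y - mu| / b
        losses = [math.log(2.0 * b) + abs(t - p) / b for t, p, b in zip(target, pred, risk)]
        if reduction == "none":
            return losses
        return sum(losses) / len(losses) if reduction == "mean" else sum(losses)

    @staticmethod
    def sample(out, rng=random):
        mu, b = out
        # The difference of two unit exponentials follows a standard Laplace distribution.
        return mu + b * (rng.expovariate(1.0) - rng.expovariate(1.0))

    @staticmethod
    def to_risk(risk_out):
        return risk_out * math.sqrt(2.0)  # standard deviation of a Laplace with scale b


loss = LaplaceSketch.loss_function([0.0], [0.0], [1.0])
```

Because every method is static, the class is used as a namespace and never needs to be instantiated, which matches how the distributions are passed to the wrapper by type.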
- class capsa_torch.sculpt.NormalInverseGammaRisk
Normal Inverse Gamma Risk NamedTuple
- nu: Tensor
Alias for field number 0
- alpha: Tensor
Alias for field number 1
- beta: Tensor
Alias for field number 2
- count(value, /)
Return number of occurrences of value.
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
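Because the risk type is a NamedTuple, its fields can be read by name, by position, or by unpacking, and it inherits count and index from tuple. The stand-in below demonstrates this with floats instead of tensors; NIGRiskSketch is a hypothetical class that only borrows the field names and their order.

```python
from typing import NamedTuple


class NIGRiskSketch(NamedTuple):
    """Stand-in with the same field order: nu (0), alpha (1), beta (2)."""
    nu: float
    alpha: float
    beta: float


risk = NIGRiskSketch(nu=2.0, alpha=3.0, beta=1.5)

# Field names are aliases for positions, so these read the same value.
same_field = risk.alpha == risk[1]

# Unpacking follows the field order, which is why a plain 3-tuple can stand in for it.
nu, alpha, beta = risk

# count and index are inherited from tuple.
position = risk.index(1.5)
occurrences = risk.count(2.0)
```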
- class capsa_torch.sculpt.NormalInverseWishartRisk
Normal Inverse Wishart Risk NamedTuple
- nu: Tensor
Alias for field number 0
- lower_tril: Tensor
Alias for field number 1
- count(value, /)
Return number of occurrences of value.
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.