capsa_tf.sample
- class capsa_tf.sample.Wrapper
- __init__(n_samples=5, distribution=Bernoulli(p=0.1), over=WrapOver.WEIGHTS, inline=False)
Initialize a Sample Wrapper with the given configuration options.
- Parameters:
  - n_samples (int) – Number of stochastic runs through the model. A higher number of samples gives a more accurate uncertainty value, but also costs more memory and compute time. (default: 5)
  - distribution (Distribution) – The distribution used to introduce noise into the model. (default: Bernoulli(p=0.1))
  - over (WrapOver) – The part of the model to which stochastic noise is applied. One of [WrapOver.WEIGHTS, WrapOver.ACTIVATIONS]. (default: WrapOver.WEIGHTS)
  - inline (bool) – Whether to inline inner functions. When False, inner functions are disallowed. (default: False)
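To make the `n_samples` trade-off concrete, here is a minimal plain-Python sketch. It is not capsa_tf code. It assumes the uncertainty value is the standard deviation of the outputs over `n_samples` stochastic forward passes, and it uses a hypothetical `noisy_forward` function in place of a real model. Each estimate costs `n_samples` calls, and repeated estimates vary less from one another as `n_samples` grows.

```python
import random
import statistics


def noisy_forward(x: float, rng: random.Random) -> float:
    # Hypothetical stochastic forward pass: a deterministic output plus Gaussian
    # noise, standing in for the randomness a sampling wrapper injects.
    return 2.0 * x + rng.gauss(0.0, 0.5)


def estimate_risk(x: float, n_samples: int, rng: random.Random) -> tuple[float, float]:
    # One call per sample, so compute time grows linearly with n_samples.
    outputs = [noisy_forward(x, rng) for _ in range(n_samples)]
    # Assumption: prediction = sample mean, uncertainty = sample standard deviation.
    return statistics.mean(outputs), statistics.stdev(outputs)


def spread_of_estimates(n_samples: int, trials: int = 50, seed: int = 0) -> float:
    # Repeat the estimate many times and measure how much the uncertainty value
    # itself varies; a smaller spread means a more reliable estimate.
    rng = random.Random(seed)
    risks = [estimate_risk(1.0, n_samples, rng)[1] for _ in range(trials)]
    return statistics.stdev(risks)


for n in (3, 30, 300):
    print(f"n_samples={n:>3}: spread of uncertainty estimates = {spread_of_estimates(n):.3f}")
```

In a real network each sample is a full forward pass, which is where the extra memory and compute time come from.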
- __call__(fn)
Applies the wrapper to a tf.function-decorated function.
- Parameters:
  - fn (Callable) – The function to wrap.
- Return type:
  WrappedFunction
- Returns:
  The wrapped function.
Example Usage
    from capsa_tf import sample  # or sculpt, vote

    wrapper = sample.Wrapper(n_samples=3)  # Initialize a wrapper object with your config options

    def forward_pass(x):
        ...
        return y

    wrapped_forward_pass = wrapper(forward_pass)  # Wrap your function
    y = wrapped_forward_pass(x)  # Use the wrapped function as usual
    y, risk = wrapped_forward_pass(x, return_risk=True)  # Use the wrapped function to obtain risk values
Alternatively, apply the wrapper as a decorator:

    from capsa_tf import sample  # or sculpt, vote

    @sample.Wrapper(n_samples=3)  # Initialize a wrapper object with your config options
    def forward_pass(x):
        ...
        return y

    y = forward_pass(x)  # Use the wrapped function as usual
    y, risk = forward_pass(x, return_risk=True)  # Use the wrapped function to obtain risk values
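The two calling conventions above can be imitated in plain Python to show what `return_risk=True` adds. This is a hypothetical stand-in, not capsa_tf's implementation. `toy_sample_wrapper` and `noisy_model` are illustrative names. The sketch assumes the returned output is the mean over `n_samples` calls and the risk is their standard deviation.

```python
import functools
import random
import statistics
from typing import Callable


def toy_sample_wrapper(n_samples: int = 5) -> Callable:
    # Hypothetical decorator factory mimicking the wrapper's call pattern.
    def decorate(fn: Callable[[float], float]) -> Callable:
        @functools.wraps(fn)
        def wrapped(x: float, return_risk: bool = False):
            # Call the stochastic function n_samples times on the same input.
            outputs = [fn(x) for _ in range(n_samples)]
            y = statistics.mean(outputs)  # assumption: output is the sample mean
            if return_risk:
                risk = statistics.stdev(outputs)  # assumption: risk is the sample std
                return y, risk
            return y
        return wrapped
    return decorate


rng = random.Random(0)


def noisy_model(x: float) -> float:
    # Stand-in for a model whose forward pass is stochastic.
    return 3.0 * x + rng.gauss(0.0, 0.1)


# Explicit wrapping, as in the first example.
wrapped_model = toy_sample_wrapper(n_samples=3)(noisy_model)
y = wrapped_model(1.0)
y, risk = wrapped_model(1.0, return_risk=True)


# Decorator form, as in the second example.
@toy_sample_wrapper(n_samples=3)
def forward_pass(x: float) -> float:
    return noisy_model(x)


y, risk = forward_pass(2.0, return_risk=True)
print(y, risk)
```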
- class capsa_tf.sample.WrapOver
Enum representing which part of the model is wrapped to inject noise into the model outputs.
- WEIGHTS = 'weights'
Inject noise into the model’s predictions through the model’s weights.
- ACTIVATIONS = 'activations'
Inject noise into the model’s predictions after every activation function call.
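The difference between the two `WrapOver` options can be sketched on a single dense layer with a ReLU. This is an illustrative plain-Python model, not capsa_tf internals. It assumes Bernoulli dropping as the noise, and `dense_relu`, `drop` and the example numbers are hypothetical. With `WEIGHTS`, individual weights are zeroed before the layer runs. With `ACTIVATIONS`, the layer runs with intact weights and its post-activation outputs are zeroed.

```python
import random


def drop(values: list[float], p: float, rng: random.Random) -> list[float]:
    # Zero each element independently with probability p.
    return [0.0 if rng.random() < p else v for v in values]


def dense_relu(weights: list[list[float]], x: list[float]) -> list[float]:
    # One dense layer followed by ReLU: out_i = max(0, sum_j W[i][j] * x[j]).
    return [max(0.0, sum(w * xj for w, xj in zip(row, x))) for row in weights]


def forward_noisy_weights(weights, x, p, rng):
    # WEIGHTS-style noise: perturb the weights, then run the layer.
    noisy_weights = [drop(row, p, rng) for row in weights]
    return dense_relu(noisy_weights, x)


def forward_noisy_activations(weights, x, p, rng):
    # ACTIVATIONS-style noise: run the layer, then perturb its outputs.
    return drop(dense_relu(weights, x), p, rng)


W = [[0.5, -1.0], [2.0, 1.0]]
x = [1.0, 2.0]
rng = random.Random(0)
print(dense_relu(W, x))                           # clean output: [0.0, 4.0]
print(forward_noisy_weights(W, x, 0.5, rng))      # some weights dropped
print(forward_noisy_activations(W, x, 0.5, rng))  # some outputs dropped
```

Under weight noise an output can also increase. For example, dropping the negative weight in the first row turns that output from 0.0 into 0.5, whereas activation noise can only zero an output.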
- class capsa_tf.sample.Bernoulli
Config class representing a Bernoulli distribution for the Sample Wrapper.
- __init__(p=0.1)
Initialize a Bernoulli distribution object.
- Parameters:
  - p (float) – The drop probability of the Bernoulli distribution. (default: 0.1)
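As a rough illustration of what the drop probability `p` controls, the sketch below draws a Bernoulli keep/drop mask in plain Python and checks that about a fraction `p` of entries are zeroed. The `bernoulli_mask` and `apply_mask` helpers are hypothetical. The docs do not say whether capsa_tf rescales the surviving values (as inverted dropout does), so the sketch applies the mask without rescaling.

```python
import random


def bernoulli_mask(n: int, p: float, rng: random.Random) -> list[int]:
    # 0 means "dropped" (probability p), 1 means "kept" (probability 1 - p).
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    return [0 if rng.random() < p else 1 for _ in range(n)]


def apply_mask(values: list[float], mask: list[int]) -> list[float]:
    # Element-wise product; dropped entries become 0.0, no rescaling of the rest.
    return [v * m for v, m in zip(values, mask)]


rng = random.Random(42)
mask = bernoulli_mask(10_000, p=0.1, rng=rng)
drop_rate = 1 - sum(mask) / len(mask)
print(f"observed drop rate: {drop_rate:.3f}")  # close to p = 0.1
print(apply_mask([1.0, 2.0, 3.0], bernoulli_mask(3, p=0.1, rng=rng)))
```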
- class capsa_tf.sample.Normal
Config class representing a Normal distribution for the Sample Wrapper.
- __init__(init_sigma=0.2)
Initialize a Normal distribution object.
- Parameters:
  - init_sigma – The starting value for sigma, the standard deviation of the distribution. (default: 0.2)
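A comparable sketch for `Normal` follows. It assumes the noise is additive, zero-mean Gaussian with standard deviation `init_sigma`, which is one common choice. The docs above only say that `init_sigma` is the starting value for sigma. They do not say whether capsa_tf adds or multiplies the noise, or how sigma changes afterwards. `add_gaussian_noise` is a hypothetical helper.

```python
import random
import statistics


def add_gaussian_noise(values: list[float], sigma: float, rng: random.Random) -> list[float]:
    # Assumption: additive noise, each value perturbed by independent N(0, sigma^2).
    if sigma < 0:
        raise ValueError("sigma must be non-negative")
    return [v + rng.gauss(0.0, sigma) for v in values]


rng = random.Random(7)
weights = [0.3] * 10_000
noisy = add_gaussian_noise(weights, sigma=0.2, rng=rng)
# The spread of the perturbed values should be close to sigma.
print(f"sample std: {statistics.stdev(noisy):.3f}")
```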
- class capsa_tf.sample.Distribution
Base config class representing a distribution for the Sample Wrapper.