capsa_tf.sample

class capsa_tf.sample.Wrapper
__init__(n_samples=5, distribution=Bernoulli(p=0.1), over=WrapOver.WEIGHTS, inline=False)

Initialize a Sample Wrapper with configs.

Parameters:
  • n_samples (int) – Number of forward passes through the model. More samples give a more accurate uncertainty estimate, at the cost of more memory and compute time. (default: 5)

  • distribution (Distribution) – The distribution to use when introducing noise into the model. (default: Bernoulli(p=0.1))

  • over (WrapOver) – The part of the model to apply stochastic noise to. One of [WrapOver.WEIGHTS, WrapOver.ACTIVATIONS]. (default: <WrapOver.WEIGHTS: 'weights'>)

  • inline (bool) – Whether to inline inner functions. With inline=False, inner functions are disallowed. (default: False)
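
The trade-off behind n_samples can be illustrated without the library. The following is a minimal sketch, not capsa_tf's implementation: it assumes a toy linear model whose weights are dropped with probability p, and it assumes the risk is the standard deviation across passes. The helper names noisy_forward and sample_prediction are hypothetical.

```python
import random
import statistics

def noisy_forward(x, weights, p, rng):
    """One stochastic pass of a toy linear model: each weight is dropped with probability p."""
    kept = [w if rng.random() >= p else 0.0 for w in weights]
    return sum(w * xi for w, xi in zip(kept, x))

def sample_prediction(x, weights, n_samples=5, p=0.1, seed=0):
    """Run n_samples noisy passes and return (mean prediction, standard deviation).

    Assumption: the spread across passes stands in for the risk value.
    """
    rng = random.Random(seed)
    outputs = [noisy_forward(x, weights, p, rng) for _ in range(n_samples)]
    return statistics.fmean(outputs), statistics.pstdev(outputs)

x = [1.0, 2.0, 3.0]
weights = [0.5, -0.25, 0.75]

# Few samples: cheap, but the estimate is noisy.
mean_5, risk_5 = sample_prediction(x, weights, n_samples=5)
# Many samples: more stable estimate, proportionally more compute.
mean_500, risk_500 = sample_prediction(x, weights, n_samples=500)

print(mean_5, risk_5)
print(mean_500, risk_500)
```

Each extra sample is one more pass through the model, so cost grows linearly with n_samples while the estimate stabilizes more slowly.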

__call__(fn)

Applies the wrapper to a tf.function-decorated function.

Parameters:

fn (Callable) – The function to wrap

Return type:

WrappedFunction

Returns:

The wrapped function

Example Usage

Wrapping a tf.function

import tensorflow as tf
from capsa_tf import sample # or sculpt, vote

wrapper = sample.Wrapper(n_samples=3) # Initialize a wrapper object with your config options

@tf.function
def forward_pass(x):
    ...
    return y

wrapped_forward_pass = wrapper(forward_pass) # Wrap your function

y = wrapped_forward_pass(x) # Use the wrapped function as usual
y, risk = wrapped_forward_pass(x, return_risk=True) # Also return risk values alongside the output

Decorator approach

import tensorflow as tf
from capsa_tf import sample # or sculpt, vote

@sample.Wrapper(n_samples=3) # Initialize a wrapper object with your config options
@tf.function
def forward_pass(x):
    ...
    return y

y = forward_pass(x) # Use the wrapped function as usual
y, risk = forward_pass(x, return_risk=True) # Use the wrapped function to obtain risk values

class capsa_tf.sample.WrapOver

Enum specifying which part of the model receives the injected noise that perturbs the model outputs

WEIGHTS = 'weights'

Inject noise into the model predictions through the model’s weights

ACTIVATIONS = 'activations'

Inject noise into model predictions after every activation function call
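
The difference between the two options can be sketched on a toy layer. This is an illustration only, assuming Bernoulli drop noise and a single ReLU layer; it does not show how capsa_tf applies the noise internally, and the function names are hypothetical.

```python
import random

def relu(v):
    return [max(0.0, a) for a in v]

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def forward_weight_noise(W, x, p, rng):
    """WEIGHTS-style noise: drop individual weights, then apply the layer."""
    W_noisy = [[w if rng.random() >= p else 0.0 for w in row] for row in W]
    return relu(matvec(W_noisy, x))

def forward_activation_noise(W, x, p, rng):
    """ACTIVATIONS-style noise: apply the layer unchanged, then drop activations."""
    h = relu(matvec(W, x))
    return [a if rng.random() >= p else 0.0 for a in h]

W = [[1.0, 2.0], [3.0, 4.0]]
x = [1.0, 1.0]
rng = random.Random(0)

print(forward_weight_noise(W, x, p=0.5, rng=rng))
print(forward_activation_noise(W, x, p=0.5, rng=rng))
```

In this sketch, weight noise can produce partial sums (for example only one of the two weights in a row contributes), whereas activation noise leaves each output either intact or zeroed.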

class capsa_tf.sample.Bernoulli

Config Class to represent a Bernoulli Distribution for the Sample Wrapper

__init__(p=0.1)

Initialize a Bernoulli Distribution object

Parameters:

p (float) – The drop probability of the Bernoulli distribution (default: 0.1)
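
As a rough illustration of what a drop probability means, the sketch below draws a Bernoulli mask in plain Python. It assumes that p is the chance that an entry is zeroed and that kept entries are left unscaled; whether capsa_tf rescales kept values is not stated here. The helper bernoulli_drop_mask is hypothetical.

```python
import random

def bernoulli_drop_mask(n, p=0.1, seed=0):
    """Return n mask entries: 0 (dropped) with probability p, otherwise 1 (kept)."""
    rng = random.Random(seed)
    return [0 if rng.random() < p else 1 for _ in range(n)]

mask = bernoulli_drop_mask(10_000, p=0.1)
drop_fraction = 1 - sum(mask) / len(mask)
print(f"dropped {drop_fraction:.1%} of entries")
```

Multiplying weights or activations elementwise by such a mask gives a different perturbed model on each pass.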

class capsa_tf.sample.Normal

Config Class to represent a Normal Distribution for the Sample Wrapper

__init__(init_sigma=0.2)

Initialize a Normal Distribution object

Parameters:

init_sigma (float) – The starting value for sigma, the standard deviation of the Normal distribution (default: 0.2)
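
To give a feel for the scale that sigma controls, here is a minimal sketch of Gaussian perturbation in plain Python. It assumes zero-mean additive noise with standard deviation sigma; the library may apply the noise differently and may adjust sigma from its starting value. The helper perturb is hypothetical.

```python
import random
import statistics

def perturb(values, sigma=0.2, seed=0):
    """Add zero-mean Gaussian noise with standard deviation sigma to each value."""
    rng = random.Random(seed)
    return [v + rng.gauss(0.0, sigma) for v in values]

# Perturbing zeros makes the spread of the noise directly visible.
weights = [0.0] * 10_000
noisy = perturb(weights, sigma=0.2)
print(statistics.pstdev(noisy))
```

A larger sigma spreads the sampled values further from the originals, which widens the range of outputs seen across passes.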

class capsa_tf.sample.Distribution

Base Config Class to represent a Distribution for the Sample Wrapper