capsa_tf.vote

class capsa_tf.vote.Wrapper
__init__(n_voters=4, alpha=3, weight_noise=0.1, inline=False)

Initialize a vote Wrapper with the given configuration.

Parameters:
  • n_voters – Number of voters. More voters give a more diverse set of opinions and higher-quality uncertainty estimates, but also cost more memory and compute time. (default: 4)

  • alpha – Controls how much the voters share an internal representation, which is used to approximate multiple votes cheaply. A smaller alpha (e.g., alpha=1) means more sharing between voters, giving faster runtime and lower memory use. (default: 3)

  • weight_noise – Amount of noise to use when initializing new weights. Suggested range: [0.0, 0.3]. (default: 0.1)

  • inline – Whether to inline inner functions. With the default, inline=False, inner functions are not allowed. (default: False)
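The effect of n_voters and weight_noise can be made concrete with a toy, framework-free sketch. It assumes each voter is a one-weight linear model y = w * x whose weight is a base weight plus Gaussian noise with standard deviation weight_noise, and it treats the spread of the votes as a risk score. The function toy_vote and its arguments are hypothetical; this is not how capsa_tf.vote implements voting, and alpha (representation sharing) is not modeled at all.

```python
import random
import statistics
from typing import Tuple


def toy_vote(x: float, base_weight: float, n_voters: int = 4,
             weight_noise: float = 0.1, seed: int = 0) -> Tuple[float, float]:
    """Hypothetical toy ensemble (not capsa_tf's implementation).

    Each voter predicts w_i * x, where w_i = base_weight + N(0, weight_noise).
    Returns the mean vote and the population standard deviation of the votes,
    i.e. how much the voters disagree, used here as a stand-in risk score.
    """
    rng = random.Random(seed)  # fixed seed so the toy is reproducible
    weights = [base_weight + rng.gauss(0.0, weight_noise) for _ in range(n_voters)]
    votes = [w * x for w in weights]
    return statistics.mean(votes), statistics.pstdev(votes)


# With no weight noise every voter agrees: y == 6.0 and risk == 0.0
y, risk = toy_vote(3.0, base_weight=2.0, n_voters=4, weight_noise=0.0)
```

In this toy the voters are never trained, so their disagreement only reflects the injected noise; increasing weight_noise (with the same seed) widens the spread, and increasing n_voters gives a less noisy estimate of it at proportionally higher cost.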

__call__(fn)

Applies the wrapper to a tf.function-decorated function.

Parameters:

fn (Callable) – The function to wrap

Return type:

WrappedFunction

Returns:

The wrapped function
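The calling convention described here (wrap a function, then call the result with or without return_risk=True) can be mimicked by a hypothetical, framework-free stand-in. ToyVoteWrapper and its scales argument are invented for this sketch: each toy voter multiplies fn's output by a fixed factor, which for a function linear in a single weight (such as y = w * x) is equivalent to scaling that weight. It does not reflect how capsa_tf.vote.Wrapper modifies a model internally; only the return_risk keyword and the (y, risk) result mirror the usage examples on this page.

```python
import functools
import statistics
from typing import Callable, Sequence, Tuple, Union


class ToyVoteWrapper:
    """Hypothetical stand-in for a vote-style wrapper (NOT capsa_tf.vote.Wrapper)."""

    def __init__(self, scales: Sequence[float] = (0.9, 1.0, 1.1)):
        # One output-scale factor per toy voter (an assumption of this sketch).
        self.scales = tuple(scales)

    def __call__(self, fn: Callable[[float], float]):
        @functools.wraps(fn)  # keep fn's name and docstring on the wrapped function
        def wrapped(x: float, return_risk: bool = False) -> Union[float, Tuple[float, float]]:
            votes = [s * fn(x) for s in self.scales]
            y = statistics.mean(votes)
            if return_risk:
                # Disagreement between the toy voters serves as the risk value.
                return y, statistics.pstdev(votes)
            return y

        return wrapped


@ToyVoteWrapper(scales=(0.5, 1.0, 1.5))
def forward_pass(x):
    return 2.0 * x


y = forward_pass(3.0)                          # votes are 3.0, 6.0, 9.0, so y == 6.0
y, risk = forward_pass(3.0, return_risk=True)  # risk is about 2.449 (sqrt(6))
```

Because __call__ returns a new function, the same object works both as wrapper(fn) and as an @decorator, which is the pattern the examples below use.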

Example Usage

Wrapping a tf.function
import tensorflow as tf
from capsa_tf import sample  # or sculpt, vote

wrapper = sample.Wrapper(n_samples=3)  # Initialize a wrapper object with your config options

@tf.function
def forward_pass(x):
    ...
    return y

wrapped_forward_pass = wrapper(forward_pass) # wrap your function

y = wrapped_forward_pass(x) # Use the wrapped function as usual
y, risk = wrapped_forward_pass(x, return_risk=True) # Use the wrapped function to obtain risk values

Decorator approach
import tensorflow as tf
from capsa_tf import sample  # or sculpt, vote

@sample.Wrapper(n_samples=3)  # Initialize a wrapper object with your config options
@tf.function
def forward_pass(x):
    ...
    return y

y = forward_pass(x) # Use the wrapped function as usual
y, risk = forward_pass(x, return_risk=True)  # Use the wrapped function to obtain risk values
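Once risk values are returned, a common next step is to flag the least trustworthy predictions, for example to route them to review. The sketch below assumes one scalar risk value per example, already converted to a plain Python list (for an eager TensorFlow tensor, e.g., via risk.numpy().tolist()), and that a higher value means higher uncertainty. The helper riskiest_indices and its fraction cutoff are hypothetical choices, not part of capsa_tf.

```python
import math
from typing import List, Sequence


def riskiest_indices(risk: Sequence[float], fraction: float = 0.1) -> List[int]:
    """Hypothetical helper: indices of the `fraction` of predictions with the
    highest risk, ordered from most to least risky (ties keep input order)."""
    if not risk:
        return []
    # Always flag at least one prediction when any are given.
    k = max(1, math.ceil(fraction * len(risk)))
    # sorted() is stable, so equal risks stay in their original order.
    order = sorted(range(len(risk)), key=lambda i: risk[i], reverse=True)
    return order[:k]


flagged = riskiest_indices([0.1, 0.9, 0.3, 0.7], fraction=0.5)  # [1, 3]
```

Sorting the whole list costs O(n log n); for very large batches, heapq.nlargest(k, range(n), key=risk.__getitem__) avoids the full sort.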