capsa_tf.vote
- class capsa_tf.vote.Wrapper
- __init__(n_voters=4, alpha=3, weight_noise=0.1, inline=False)
Initialize a Vote Wrapper with configs.
- Parameters:
n_voters – Number of voters. More voters give a more diverse set of opinions and a better-quality uncertainty estimate, but cost more memory and compute time. (default: 4)
alpha – Controls how multiple votes are approximated with a shared internal representation. A smaller alpha (e.g., alpha=1) means more sharing between voters, which gives faster runtime and lower memory use. (default: 3)
weight_noise – Amount of noise used when initializing new weights. Suggested range: [0.0, 0.3]. (default: 0.1)
inline – Whether to inline inner functions. The default, inline=False, disallows inner functions. (default: False)
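To build intuition for the n_voters and weight_noise trade-off described above, here is a minimal, library-free sketch. It is not capsa_tf's implementation: the helper names (make_voters, predict, vote), the linear "model", and the use of the spread across voters as a risk proxy are all assumptions made for illustration. It also does not model alpha, since the source does not specify how the shared representation works.

```python
import random
import statistics

def make_voters(base_weights, n_voters=4, weight_noise=0.1, seed=0):
    """Create n_voters noisy copies of a shared weight vector (hypothetical helper)."""
    rng = random.Random(seed)
    return [
        [w + rng.gauss(0.0, weight_noise) for w in base_weights]
        for _ in range(n_voters)
    ]

def predict(weights, x):
    """A toy linear model: dot product of weights and inputs."""
    return sum(w * xi for w, xi in zip(weights, x))

def vote(voters, x):
    """Return the mean prediction and the spread across voters (assumed risk proxy)."""
    preds = [predict(w, x) for w in voters]
    return statistics.fmean(preds), statistics.pstdev(preds)

base = [0.5, -1.0, 2.0]
x = [1.0, 2.0, 3.0]

# Four voters, each initialized with a small amount of weight noise.
voters = make_voters(base, n_voters=4, weight_noise=0.1)
y, risk = vote(voters, x)

# With zero noise all voters are identical, so their disagreement vanishes.
_, no_risk = vote(make_voters(base, n_voters=4, weight_noise=0.0), x)
```

In this toy setting, each extra voter adds one more weight copy and one more forward pass, which mirrors the memory and compute cost noted for n_voters.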
- __call__(fn)
Applies the wrapper to a tf.function-decorated function.
- Parameters:
fn (Callable) – The function to wrap.
- Return type:
WrappedFunction
- Returns:
The wrapped function.
Example Usage
from capsa_tf import sample  # or sculpt, vote

wrapper = sample.Wrapper(n_samples=3)  # Initialize a wrapper object with your config options

def forward_pass(x):
    ...
    return y

wrapped_forward_pass = wrapper(forward_pass)  # Wrap your function
y = wrapped_forward_pass(x)  # Use the wrapped function as usual
y, risk = wrapped_forward_pass(x, return_risk=True)  # Use the wrapped function to obtain risk values

from capsa_tf import sample  # or sculpt, vote

@sample.Wrapper(n_samples=3)  # Initialize a wrapper object with your config options
def forward_pass(x):
    ...
    return y

y = forward_pass(x)  # Use the wrapped function as usual
y, risk = forward_pass(x, return_risk=True)  # Use the wrapped function to obtain risk values
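
The two forms above (calling the wrapper on a function, or using it as a decorator) follow the usual Python pattern of a configurable callable object. The sketch below shows that pattern with a hypothetical ToyWrapper class that runs without capsa_tf installed. It only mimics the calling convention and returns a placeholder risk of 0.0; it says nothing about how capsa_tf builds a WrappedFunction internally.

```python
import functools

class ToyWrapper:
    """Hypothetical stand-in that mimics the calling convention; not capsa_tf code."""

    def __init__(self, n_voters=4):
        self.n_voters = n_voters  # stored config, unused by this toy

    def __call__(self, fn):
        @functools.wraps(fn)
        def wrapped(*args, return_risk=False, **kwargs):
            y = fn(*args, **kwargs)
            if return_risk:
                risk = 0.0  # placeholder: a real wrapper would estimate this
                return y, risk
            return y
        return wrapped

def forward_pass(x):
    return 2 * x

# Explicit form: build the wrapper, then apply it to the function.
wrapped_forward_pass = ToyWrapper(n_voters=4)(forward_pass)

# Decorator form: equivalent, applied at definition time.
@ToyWrapper(n_voters=4)
def forward_pass_decorated(x):
    return 2 * x

y = wrapped_forward_pass(3)
y_with_risk, risk = forward_pass_decorated(3, return_risk=True)
```

Both forms produce a function that behaves like the original by default and returns an extra value only when return_risk=True is passed.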