capsa-torch

Submodules

Top-Level Attributes

class capsa_torch.RiskOutput
RiskOutput(pred, risk): Namedtuple representing the output of a module with risk. Returned by wrapped modules when the return_risk=True argument is passed. It can be treated as a regular tuple.

# Automatically unpacks like a regular tuple
pred, risk = wrapped_module(..., return_risk=True)

out = wrapped_module(..., return_risk=True)
type(out) # RiskOutput
out.pred # (or out[0]) Access module prediction
out.risk # (or out[1]) Access risk associated with module prediction
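The tuple-like behavior above is what `collections.namedtuple` provides. As a minimal sketch that runs without PyTorch, the snippet below defines a hypothetical stand-in with the same two fields (it is not imported from capsa_torch) and a hypothetical `toy_wrapped_module` function that only mimics the `return_risk=True` convention on plain Python numbers.

```python
from collections import namedtuple

# Hypothetical stand-in with the same field names as capsa_torch.RiskOutput;
# the real class is provided by the library.
RiskOutput = namedtuple("RiskOutput", ["pred", "risk"])


def toy_wrapped_module(x, return_risk=False):
    """Hypothetical wrapped module: doubles x and, on request, attaches a dummy risk."""
    pred = 2 * x
    if not return_risk:
        return pred  # plain prediction, like calling the module without return_risk
    risk = abs(x) * 0.1  # placeholder risk value, purely illustrative
    return RiskOutput(pred, risk)


pred, risk = toy_wrapped_module(3.0, return_risk=True)  # unpacks like a tuple
out = toy_wrapped_module(3.0, return_risk=True)
print(type(out).__name__)  # RiskOutput
print(out.pred == out[0], out.risk == out[1])  # True True
```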
static __new__(_cls, pred, risk)

Create new instance of RiskOutput(pred, risk)

count(value, /)

Return number of occurrences of value.

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

pred

Alias for field number 0

risk

Alias for field number 1
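Because RiskOutput is a namedtuple, count and index are inherited from tuple, and pred/risk are aliases for positions 0 and 1. The sketch below uses a hypothetical stand-in built with `collections.namedtuple` (not the library class) to illustrate these inherited behaviors, including the ValueError raised by index for a missing value.

```python
from collections import namedtuple

# Hypothetical stand-in mirroring the documented fields of capsa_torch.RiskOutput.
RiskOutput = namedtuple("RiskOutput", ["pred", "risk"])

out = RiskOutput(pred=0.9, risk=0.2)

# Field names map to tuple positions 0 and 1.
print(RiskOutput._fields)  # ('pred', 'risk')
print(out.pred is out[0], out.risk is out[1])  # True True

# Methods inherited from tuple.
print(out.count(0.9))  # 1
print(out.index(0.2))  # 1
try:
    out.index(0.5)
except ValueError:
    print("0.5 not present")  # index raises ValueError for missing values
```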

capsa_torch.collate(risk_out)

Converts a RiskOutput object containing two pytrees (prediction and risk) into a pytree containing RiskOutput leaves.

Parameters:

risk_out (RiskOutput) – A RiskOutput whose pred and risk fields are pytrees with matching structure.

Returns:

A pytree with RiskOutput leaves in the same structure as the input pytrees.

out = model(...)
# out = {"key": (t1, [t2, t3]), "key2": t4}

risk_out = model(..., return_risk=True)
# risk_out = RiskOutput(
#     {"key": (t1_pred, [t2_pred, t3_pred]), "key2": t4_pred},
#     {"key": (t1_risk, [t2_risk, t3_risk]), "key2": t4_risk}
# )
# `risk_out.pred` is the same as `out`
# `risk_out.risk` is a pytree with the same nesting structure as `out`
# but contains risk values as leaves.

collated_risk_out = capsa_torch.collate(risk_out)
# collated_risk_out = \
#     {"key": (RiskOutput(t1_pred, t1_risk),
#         [RiskOutput(t2_pred, t2_risk),
#         RiskOutput(t3_pred, t3_risk)]),
#      "key2": RiskOutput(t4_pred, t4_risk)}
# `collated_risk_out` is a pytree with the same nesting structure as `out`
# but contains `RiskOutput` namedtuples as leaves.
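To make the transformation concrete, here is a minimal pure-Python sketch of the same idea. It is not capsa_torch's implementation: it assumes pytrees are built only from dicts, lists, and tuples (real pytree utilities recognize more container types), uses a hypothetical namedtuple stand-in for RiskOutput, and walks the pred and risk trees in lockstep, wrapping each pair of matching leaves.

```python
from collections import namedtuple

# Hypothetical stand-in for capsa_torch.RiskOutput.
RiskOutput = namedtuple("RiskOutput", ["pred", "risk"])


def collate_sketch(risk_out):
    """Pair matching leaves of risk_out.pred and risk_out.risk into RiskOutput leaves.

    Assumes both trees share the same structure and that only dict, list,
    and plain tuple act as containers; anything else is treated as a leaf.
    """

    def zip_trees(pred, risk):
        if isinstance(pred, dict):
            if not isinstance(risk, dict) or pred.keys() != risk.keys():
                raise ValueError("pred and risk dicts do not match")
            return {k: zip_trees(pred[k], risk[k]) for k in pred}
        # Exact type check so namedtuples (including RiskOutput) count as leaves.
        if type(pred) in (list, tuple):
            if type(pred) is not type(risk) or len(pred) != len(risk):
                raise ValueError("pred and risk sequences do not match")
            return type(pred)(zip_trees(p, r) for p, r in zip(pred, risk))
        return RiskOutput(pred, risk)  # leaf: wrap the matching pair

    return zip_trees(risk_out.pred, risk_out.risk)


risk_out = RiskOutput(
    {"key": (1.0, [2.0, 3.0]), "key2": 4.0},
    {"key": (0.1, [0.2, 0.3]), "key2": 0.4},
)
collated = collate_sketch(risk_out)
print(collated["key"][1][0])  # RiskOutput(pred=2.0, risk=0.2)
print(collated["key2"])  # RiskOutput(pred=4.0, risk=0.4)
```

The containers are rebuilt with their original types, so the result keeps the nesting of the prediction pytree, matching the structure shown in the commented example.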