capsa-torch
Submodules
Top-Level Attributes
- class capsa_torch.RiskOutput
- RiskOutput(pred, risk): Namedtuple representing the output of a Module with risk. Returned by wrapped modules when the return_risk=True argument is passed. It can be treated as a regular tuple:

    # Automatically unpacks like a regular tuple
    pred, risk = wrapped_module(..., return_risk=True)

    out = wrapped_module(..., return_risk=True)
    type(out)  # RiskOutput
    out.pred   # (or out[0]) Access module prediction
    out.risk   # (or out[1]) Access risk associated with module prediction
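The calling convention above can be tried without capsa_torch installed. This is a minimal sketch under stated assumptions: `RiskOutput` is rebuilt here as a plain `collections.namedtuple` with the documented fields, and `wrapped_module` is a hypothetical toy function that only mimics the documented `return_risk=True` behaviour. Neither is the library's own code.

```python
from collections import namedtuple

# Hypothetical stand-in: assumes RiskOutput behaves like a plain
# collections.namedtuple with fields (pred, risk), as its method list suggests.
RiskOutput = namedtuple("RiskOutput", ["pred", "risk"])


def wrapped_module(x, return_risk=False):
    """Hypothetical toy 'wrapped module' mimicking the documented convention."""
    pred = [2 * v for v in x]            # toy prediction
    if not return_risk:
        return pred                      # plain call: prediction only
    risk = [abs(v) * 0.1 for v in x]     # toy per-element risk score
    return RiskOutput(pred, risk)


x = [1.0, -2.0, 3.0]

# Without return_risk, only the prediction is returned.
plain = wrapped_module(x)

# With return_risk=True, the result unpacks like a regular 2-tuple...
pred, risk = wrapped_module(x, return_risk=True)

# ...or can be kept whole and accessed by field name or by index.
out = wrapped_module(x, return_risk=True)
assert out.pred is out[0] and out.risk is out[1]
print(type(out).__name__, out.pred)
```

Because a namedtuple is a tuple subclass, code that only expects a `(pred, risk)` pair keeps working unchanged.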
- static __new__(_cls, pred, risk)
Create new instance of RiskOutput(pred, risk)
- count(value, /)
Return number of occurrences of value.
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- pred
Alias for field number 0
- risk
Alias for field number 1
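The members listed above are the ones a standard namedtuple provides. The sketch below exercises them on a hypothetical stand-in built with `collections.namedtuple`, which assumes capsa_torch's class behaves the same way. `count` and `index` are inherited from `tuple` and compare elements with `==`, so they are mainly useful for scalar fields. For fields holding distinct multi-element tensors, that comparison may raise because a tensor's truth value is ambiguous.

```python
from collections import namedtuple

# Hypothetical stand-in mirroring the documented fields and methods
# (__new__, count, index, pred, risk); not the library's own definition.
RiskOutput = namedtuple("RiskOutput", ["pred", "risk"])

# __new__(_cls, pred, risk): positional or keyword construction.
a = RiskOutput(0.9, 0.05)
b = RiskOutput(pred=0.9, risk=0.05)
assert a == b            # namedtuples compare like ordinary tuples

# pred and risk are aliases for fields 0 and 1.
assert a.pred == a[0] and a.risk == a[1]

# count/index are inherited from tuple.
c = RiskOutput(0.5, 0.5)
print(c.count(0.5))      # 2: both fields equal 0.5
print(c.index(0.5))      # 0: first matching position

try:
    c.index(1.0)
except ValueError:
    print("1.0 not present")  # index raises when the value is missing
```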
- capsa_torch.collate(risk_out)
Converts a RiskOutput object containing two pytrees (prediction and risk) into a pytree containing RiskOutput leaves.
- Parameters:
  risk_out (RiskOutput) – A RiskOutput object whose pred and risk fields are pytrees with matching structure.
- Returns:
  A pytree with the same structure as the input pytrees, whose leaves are RiskOutput namedtuples.
    out = model(...)
    # out = {"key": (t1, [t2, t3]), "key2": t4}

    risk_out = model(..., return_risk=True)
    # risk_out = RiskOutput(
    #     {"key": (t1_pred, [t2_pred, t3_pred]), "key2": t4_pred},
    #     {"key": (t1_risk, [t2_risk, t3_risk]), "key2": t4_risk}
    # )
    # `risk_out.pred` is the same as `out`
    # `risk_out.risk` is a pytree with the same nesting structure as `out`
    # but contains risk values as leaves.

    collated_risk_out = capsa_torch.collate(risk_out)
    # collated_risk_out = \
    #     {"key": (RiskOutput(t1_pred, t1_risk),
    #              [RiskOutput(t2_pred, t2_risk),
    #               RiskOutput(t3_pred, t3_risk)]),
    #      "key2": RiskOutput(t4_pred, t4_risk)}
    # `collated_risk_out` is a pytree with the same nesting structure as `out`
    # but contains `RiskOutput` namedtuples as leaves.
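To make the transformation concrete, here is a pure-Python sketch of what collate does, assuming the pytrees are built only from plain `dict`, `list` and `tuple` containers and that `pred` and `risk` share exactly the same structure. `collate_sketch` and the stand-in `RiskOutput` are hypothetical names introduced for illustration. The real `capsa_torch.collate` may be implemented differently and may support other container types.

```python
from collections import namedtuple

# Hypothetical stand-in for capsa_torch.RiskOutput.
RiskOutput = namedtuple("RiskOutput", ["pred", "risk"])


def collate_sketch(risk_out):
    """Hypothetical re-implementation of the documented collate behaviour.

    Walks the pred and risk pytrees in lockstep and pairs corresponding
    leaves into RiskOutput namedtuples. Only plain dict/list/tuple
    containers are recognised; anything else is treated as a leaf.
    """
    def zip_trees(p, r):
        if isinstance(p, dict):
            assert isinstance(r, dict) and p.keys() == r.keys(), "structure mismatch"
            return {k: zip_trees(p[k], r[k]) for k in p}
        if isinstance(p, (list, tuple)):
            assert type(p) is type(r) and len(p) == len(r), "structure mismatch"
            # Rebuild the same container type (list stays list, tuple stays tuple).
            return type(p)(zip_trees(pi, ri) for pi, ri in zip(p, r))
        # Leaf (e.g. a tensor): pair prediction with its risk.
        return RiskOutput(p, r)

    return zip_trees(risk_out.pred, risk_out.risk)


# Mirrors the example above, with strings standing in for tensors.
risk_out = RiskOutput(
    {"key": ("t1_pred", ["t2_pred", "t3_pred"]), "key2": "t4_pred"},
    {"key": ("t1_risk", ["t2_risk", "t3_risk"]), "key2": "t4_risk"},
)
collated = collate_sketch(risk_out)
print(collated["key2"])       # RiskOutput(pred='t4_pred', risk='t4_risk')
print(collated["key"][1][0])  # RiskOutput(pred='t2_pred', risk='t2_risk')
```

Pairing leaves this way is convenient when downstream code wants each prediction and its risk side by side, for example when iterating over the outputs of a multi-head model.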