capsa_torch.interpret

capsa_torch.interpret.aleatoric_misclassification_prob_binary(y_pred, y_sigma, risk_threshold=0.0)

Compute the aleatoric misclassification probability using the CDF interpretation of y_pred, y_sigma, and risk_threshold for multi-label binary tasks (multiple labels can be positive). The CDF risk interpretation is defined as the probability that a sample drawn from N(y_pred, y_sigma) falls on the opposite side of risk_threshold from y_pred (i.e., is misclassified).

Parameters:
  • y_pred (Tensor) – The output LOGITS from model prediction

  • y_sigma (Tensor) – The risk values from model prediction

  • risk_threshold (float) – Threshold for each class in LOGIT space (default: 0.0)

Return type:

Tensor

Returns:

The computed probability of misclassification risk tensor
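
A minimal usage sketch (the random tensors below are stand-ins for real model outputs, not part of this API):

    import torch
    from capsa_torch import interpret

    # Stand-ins for wrapped-model outputs: logits and per-logit aleatoric risk
    y_pred = torch.randn(8, 5)        # [batch, n_labels], logit space
    y_sigma = torch.rand(8, 5) * 0.5  # risk values, same shape as y_pred

    # Per-label probability of crossing the default 0.0 logit threshold
    p_misclass = interpret.aleatoric_misclassification_prob_binary(y_pred, y_sigma)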

capsa_torch.interpret.aleatoric_misclassification_prob_categorical(y_pred, y_sigma, dim=-1)

Compute the aleatoric misclassification probability for categorical tasks (only one positive label per instance).

Parameters:
  • y_pred (Tensor) – The output LOGITS from model prediction

  • y_sigma (Tensor) – The risk values from model prediction

  • dim (int) – The dimension which corresponds to the logits over the classes (default: -1)

Return type:

Tensor

Returns:

The computed aleatoric probability of misclassification risk tensor
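
A minimal sketch for a single-label (categorical) head; the tensors are placeholders for real model outputs:

    import torch
    from capsa_torch import interpret

    y_pred = torch.randn(8, 10)        # [batch, n_classes] class logits
    y_sigma = torch.rand(8, 10) * 0.3  # aleatoric risk per logit

    # dim=-1 selects the class axis; yields one probability per instance
    p = interpret.aleatoric_misclassification_prob_categorical(y_pred, y_sigma, dim=-1)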

capsa_torch.interpret.epistemic_misclassification_prob_binary(y_pred, y_sigma, risk_threshold=0.0)

Compute the epistemic misclassification probability using the CDF interpretation of y_pred, y_sigma, and risk_threshold for multi-label binary tasks (multiple labels can be positive). The CDF risk interpretation is defined as the probability that a sample drawn from N(y_pred, y_sigma) falls on the opposite side of risk_threshold from y_pred (i.e., is misclassified).

Parameters:
  • y_pred (Tensor) – The output LOGITS from model prediction

  • y_sigma (Tensor) – The risk values from model prediction

  • risk_threshold (float) – Threshold for each class in LOGIT space (default: 0.0)

Return type:

Tensor

Returns:

The computed probability of misclassification risk tensor
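
The call mirrors the aleatoric variant. A sketch with a non-default threshold (mapping an 85% confidence requirement into logit space is illustrative, not prescribed by this API):

    import torch
    from capsa_torch import interpret

    y_pred = torch.randn(8, 5)
    y_sigma = torch.rand(8, 5) * 0.5

    # risk_threshold lives in LOGIT space: logit(0.85) ≈ 1.73
    threshold = torch.logit(torch.tensor(0.85)).item()
    p = interpret.epistemic_misclassification_prob_binary(
        y_pred, y_sigma, risk_threshold=threshold
    )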

capsa_torch.interpret.epistemic_misclassification_prob_categorical(y_pred, y_sigma, dim=-1, num_points_integral=15)

Compute the misclassification probability using the CDF interpretation of y_pred and y_sigma for categorical tasks (only one positive label per instance). The CDF risk interpretation is defined as the probability that the largest element of y_pred (along dim) does NOT correspond to the true largest value of y ~ N(y_pred, y_sigma). It should go to 0 as the elements of y_sigma go to 0.

Parameters:
  • y_pred (Tensor) – The output LOGITS from model prediction

  • y_sigma (Tensor) – The risk values from model prediction

  • dim (int) – The dimension which corresponds to the logits over the classes (default: -1)

  • num_points_integral (int) – The number of points over which to numerically calculate the integral (default: 15)

Return type:

Tensor

Returns:

The computed epistemic probability of misclassification risk tensor
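
A sketch of the num_points_integral trade-off: more points refine the numerical integral at extra compute (the values 15 vs. 51, and the near-zero 1e-6 risk, are illustrative):

    import torch
    from capsa_torch import interpret

    y_pred = torch.randn(8, 10)
    y_sigma = torch.rand(8, 10) * 0.3

    # Coarser (default 15 points) vs. finer numerical integration
    p_coarse = interpret.epistemic_misclassification_prob_categorical(y_pred, y_sigma)
    p_fine = interpret.epistemic_misclassification_prob_categorical(
        y_pred, y_sigma, num_points_integral=51
    )

    # As y_sigma shrinks toward 0, the probability should also go to 0 (see above)
    p_small = interpret.epistemic_misclassification_prob_categorical(
        y_pred, torch.full_like(y_sigma, 1e-6)
    )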

capsa_torch.interpret.misclassification_prob_binary(y_pred, y_sigma, risk_threshold=0.0, num_points_integral=15)

Compute the misclassification probability using the CDF interpretation of y_pred, y_sigma, and risk_threshold for multi-label binary tasks (multiple labels can be positive). The CDF risk interpretation is defined as the probability that a sample drawn from N(y_pred, y_sigma) falls on the opposite side of risk_threshold from y_pred (i.e., is misclassified).

Parameters:
  • y_pred (Tensor) – The output LOGITS from model prediction

  • y_sigma (Tensor) – The risk values from model prediction

  • risk_threshold (float) – Threshold for each class in LOGIT space (default: 0.0)

  • num_points_integral (int) – The number of points over which to numerically calculate the integral (default: 15)

Return type:

Tensor

Returns:

The computed probability of misclassification risk tensor
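
A minimal sketch combining both keyword arguments (placeholder tensors as above; the value 25 is illustrative):

    import torch
    from capsa_torch import interpret

    y_pred = torch.randn(8, 5)        # [batch, n_labels] logits
    y_sigma = torch.rand(8, 5) * 0.5  # risk values

    p = interpret.misclassification_prob_binary(
        y_pred, y_sigma, risk_threshold=0.0, num_points_integral=25
    )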

capsa_torch.interpret.misclassification_prob_categorical(y_pred, y_sigma, dim=-1, num_points_integral=15)

Compute the misclassification probability for categorical tasks (only one positive label per instance). The misclassification risk is defined as the complement of the maximum expected softmax value of y ~ N(y_pred, y_sigma). As y_sigma goes to 0, it converges to the softmax uncertainty (the complement of the maximum softmax value of y_pred).

Parameters:
  • y_pred (Tensor) – The output LOGITS from model prediction

  • y_sigma (Tensor) – The risk values from model prediction

  • dim (int) – The dimension which corresponds to the logits over the classes (default: -1)

  • num_points_integral (int) – The number of points over which to numerically calculate the integral (default: 15)

Return type:

Tensor

Returns:

The computed probability of misclassification risk tensor
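
A sketch of the limiting behaviour described above, comparing against plain softmax uncertainty (the near-zero 1e-6 risk is an illustrative stand-in for y_sigma going to 0):

    import torch
    from capsa_torch import interpret

    y_pred = torch.randn(8, 10)
    y_sigma = torch.rand(8, 10) * 0.3

    p = interpret.misclassification_prob_categorical(y_pred, y_sigma)

    # With near-zero risk this should approach the softmax uncertainty:
    softmax_uncertainty = 1.0 - torch.softmax(y_pred, dim=-1).max(dim=-1).values
    p_small = interpret.misclassification_prob_categorical(
        y_pred, torch.full_like(y_sigma, 1e-6)
    )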

capsa_torch.interpret.top_percent_risk_cut_accuracy(outputs, risks, gt, risk_thresholds)

For multi-class classification problems. Returns the accuracy of the model on subsets of the dataset obtained by cutting off the top x% highest-risk inputs.

Parameters:
  • outputs (Tensor) – A tensor of output labels; typically integers in the range [1, n_classes]. Must be of shape [len(dataset)].

  • risks (Tensor) – Risk values corresponding to outputs; higher-risk outputs are cut first. These should be produced with misclassification_prob_categorical(). Must be of shape [len(dataset)], with risks[i] corresponding to outputs[i].

  • gt (Tensor) – Ground truth target labels; typically integers in the range [1, n_classes], but can be anything that supports equality with the outputs. Must be of shape [len(dataset)], with gt[i] corresponding to outputs[i].

  • risk_thresholds – If float(s), each must lie in [0, 1) and represents the risk quantile at which to cut outputs; the accuracy will be reported only for outputs with risk below that quantile. If an int n, behaves the same as passing torch.linspace(1 / (n + 1), 1., n) as a list.

Return type:

Tensor

Returns:

Accuracy for each cut percentage, as a tensor.
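
An end-to-end sketch (random tensors stand in for a real model and dataset; labels only need to be consistent between outputs and gt):

    import torch
    from capsa_torch import interpret

    n = 1000
    logits = torch.randn(n, 10)        # stand-in class logits
    sigma = torch.rand(n, 10) * 0.3    # stand-in risk values
    gt = torch.randint(0, 10, (n,))    # ground-truth labels

    outputs = logits.argmax(dim=-1)    # predicted labels, shape [n]
    risks = interpret.misclassification_prob_categorical(logits, sigma)

    # Accuracy keeping only the lowest-risk 25%, 50%, and 75% of inputs
    accs = interpret.top_percent_risk_cut_accuracy(outputs, risks, gt, [0.25, 0.5, 0.75])

    # Passing an int n instead behaves like torch.linspace(1/(n+1), 1., n) as a list
    accs4 = interpret.top_percent_risk_cut_accuracy(outputs, risks, gt, 4)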