Neo Wrapper: Image Classification with Out-of-Distribution Inputs

In this tutorial, we show how to wrap convolutional models with the Neo wrapper, train them, and use the resulting vacuity scores to detect out-of-distribution inputs.

The CIFAR-10 dataset is a widely used benchmark dataset in machine learning, particularly for image classification tasks. It consists of 60,000 color images, each of size 32x32 pixels, spread across 10 different classes: airplanes, automobiles, birds, cats, deer, dogs, frogs, horses, ships, and trucks. The dataset is divided into 50,000 training images and 10,000 test images. CIFAR-10 is commonly used for evaluating the performance of classification algorithms due to its manageable size and diverse categories.
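As a quick sanity check of these statistics, you can load the dataset with torchvision and inspect it (a minimal sketch, assuming torchvision is installed and the data can be downloaded to ./data):

import torchvision

train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)
print(len(train), len(test))  # 50000 10000
print(train.data.shape)       # (50000, 32, 32, 3)
print(train.classes)          # the 10 class names, 'airplane' through 'truck'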

Step 1: Initial Setup and Wrapping

Import requirements

import copy
import logging
import os
import sys
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import roc_curve, auc
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
# This tutorial assumes a CUDA-capable GPU; fall back to CPU if none is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)

Define and initialize the model

We can load various convolutional models pretrained on CIFAR-10 and wrap them. The pretrained weights come from the pytorch-cifar-models repository. Here we use cifar10_repvgg_a2 as an example; you can wrap other models by changing the model name, for example cifar10_resnet20, cifar10_resnet56, or cifar10_vgg19_bn. The full list of models can be found on the repository's GitHub page.
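If you prefer to discover the available model names programmatically, torch.hub can list the repository's entry points (a small sketch of our own, assuming network access):

hub_models = torch.hub.list("chenyaofo/pytorch-cifar-models")
print([name for name in hub_models if name.startswith("cifar10_")])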

from capsa_torch import neo

def create_model(model_name, device):
    # Load a model pretrained on CIFAR-10 from the pytorch-cifar-models hub
    net = torch.hub.load("chenyaofo/pytorch-cifar-models", model_name, pretrained=True)
    net = net.to(device)

    # Wrap the model with the Neo wrapper
    wrapper = neo.Wrapper(num_attachment_points=2, layers_spec=(2, 1))
    wrapped_model = wrapper(net)
    wrapped_model = wrapped_model.to(device)

    logging.info(wrapped_model)
    return wrapped_model

model_name = "cifar10_repvgg_a2"
wrapped_model = create_model(model_name, device)
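Once wrapped, the model can be called with return_risk=True to obtain a risk (vacuity) estimate alongside its prediction; this is the call signature used throughout this tutorial. A minimal smoke test with random inputs (our addition, not part of the original notebook):

# Sanity-check the wrapped model's forward signature on a CIFAR-sized batch
x = torch.randn(8, 3, 32, 32, device=device)
output, risk = wrapped_model(x, return_risk=True)
print(output.shape, risk.shape)  # inspect the prediction and risk shapes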

Step 2: Initialize Dataset

The model wrapped with the Neo Wrapper must be trained. To compare vacuity scores, we split the classes of the CIFAR-10 dataset into ID (in-distribution) and OOD (out-of-distribution) groups, and train the model on ID data only.

Augment CIFAR-10 Dataset with ID and OOD Class Labels

class FilteredCIFAR10(torchvision.datasets.CIFAR10):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, id_classes=None, ood_classes=None):
        super().__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)
        self.id_classes = id_classes
        self.ood_classes = ood_classes
        self.selected_indices = []
        self.id_ood_labels = []

        if self.id_classes is not None and self.ood_classes is not None:
            for i, target in enumerate(self.targets):
                class_name = self.classes[target]
                if class_name in self.id_classes:
                    self.selected_indices.append(i)
                    self.id_ood_labels.append(0)  # 0 for ID
                elif class_name in self.ood_classes:
                    self.selected_indices.append(i)
                    self.id_ood_labels.append(1)  # 1 for OOD

            self.targets = [self.targets[i] for i in self.selected_indices]
            self.data = self.data[self.selected_indices]

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        id_ood_label = self.id_ood_labels[index]

        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, id_ood_label
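Note that, unlike the base dataset, __getitem__ now returns a triple: the image, the class target, and a binary ID/OOD flag. A small illustration (our own example, assuming the ./data directory used in the next step):

ds = FilteredCIFAR10(root='./data', train=False, download=True,
                     id_classes=['dog'], ood_classes=['cat'])
img, target, id_ood = ds[0]
print(target, id_ood)  # class index, and 0 for ID or 1 for OOD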

Download and Configure Dataset

# Define transformations (normalization uses the standard CIFAR-10 channel means/stds)
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Generator for the DataLoaders, created on the default device set above
g = torch.Generator(device=device)


# Load the full CIFAR-10 training and test datasets
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)

# Define filtered dataset for ID and OOD classes
ood_classes = ['airplane', 'automobile', 'bird', 'cat']
id_classes = ['deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Filtered dataset with only ID classes
ID_trainset = FilteredCIFAR10(root='./data', train=True, download=True, transform=transform_train, id_classes=id_classes, ood_classes=[])

# Use the full test set for evaluation
testset = FilteredCIFAR10(root='./data', train=False, download=True, transform=transform_test, id_classes=id_classes, ood_classes=ood_classes)

# DataLoader for filtered datasets
trainloader = DataLoader(ID_trainset, batch_size=100, shuffle=True, num_workers=0, generator=g)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=0, generator=g)

print(f"train dataset size: {len(trainset)}")

print(f"Filtered train dataset size (ID only): {len(ID_trainset)}")
print(f"Filtered test dataset size (ID + OOD): {len(testset)}")
logging.info(f"train dataset size: {len(trainset)}")

logging.info(f"Filtered train dataset size (ID only): {len(ID_trainset)}")
logging.info(f"Filtered test dataset size (ID + OOD): {len(testset)}")
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
train dataset size: 50000
Filtered train dataset size (ID only): 30000
Filtered test dataset size (ID + OOD): 10000
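As expected, the ID-only training set keeps 6 of the 10 classes (6 × 5,000 = 30,000 images), while the filtered test set retains all 10,000 test images, each now tagged as ID or OOD.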

Step 3: Train Model

Here we provide the training function for the wrapped model, along with a simple evaluation pass. Running the evaluation pass once before training sends sample inputs through the model, which finalizes the wrapping.

def evaluate_model(model, dataloaders, device='cuda:0'):
    # One forward pass over the data; running this once before training
    # finalizes the wrapping of the model
    model.eval()
    with torch.no_grad():
        for inputs, labels, _ in dataloaders:
            inputs = inputs.to(device)
            labels = labels.to(device)
            output, risk = model(inputs, return_risk=True)

def train_wrapper(model_name, model, train_loader, val_loader, dataset_sizes, optimizer, scheduler, num_epochs=25, patience=5):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    train_losses = []
    val_losses = []

    patience_counter = 0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)
        logging.info('Epoch {}/{}'.format(epoch + 1, num_epochs))
        logging.info('-' * 10)
        model.train()
        running_loss = 0.0

        for i, (inputs, labels, id_ood) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output, risk = model(inputs, return_risk=True)
            loss = risk.mean()  # the training loss is the mean of the wrapper's risk (vacuity) output
            loss.backward()
            optimizer.step()

            # Statistics
            running_loss += loss.item() * inputs.size(0)
            print("\rTrain Iteration: {}/{}, Loss: {}.".format(i + 1, len(train_loader), loss.item() * inputs.size(0)), end="")
            sys.stdout.flush()

        epoch_loss = running_loss / dataset_sizes['train']
        train_losses.append(epoch_loss)

        scheduler.step()

        # Validation phase
        model.eval()   # Set model to evaluate mode
        running_loss = 0.0

        with torch.no_grad():
            for i, (inputs, labels, id_ood) in enumerate(val_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass
                output, risk = model(inputs, return_risk=True)
                loss = risk.mean()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                print("\rVal Iteration: {}/{}, Loss: {}.".format(i + 1, len(val_loader), loss.item() * inputs.size(0)), end="")
                sys.stdout.flush()

        epoch_val_loss = running_loss / dataset_sizes['val']
        val_losses.append(epoch_val_loss)

        # Deep copy the model
        if epoch_val_loss < best_loss:
            best_loss = epoch_val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            patience_counter = 0  # Reset the patience counter
        else:
            patience_counter += 1

        # Early stopping check
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            logging.info(f"Early stopping at epoch {epoch + 1}")
            break

        # Print the metrics
        print()
        print('Train Loss: {:.4f}'.format(train_losses[-1]))
        print('Val Loss: {:.4f}'.format(val_losses[-1]))
        print()
        logging.info('Train Loss: {:.4f}'.format(train_losses[-1]))
        logging.info('Val Loss: {:.4f}'.format(val_losses[-1]))

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Loss: {:4f}'.format(best_loss))
    logging.info('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    logging.info('Best val Loss: {:4f}'.format(best_loss))

    # Load best model weights
    model.load_state_dict(best_model_wts)

    # Create the directory if it doesn't exist
    directory = os.path.join('wrapped_models', model_name)
    if not os.path.exists(directory):
        os.makedirs(directory)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    # Save the best model weights
    torch.save(model.state_dict(), os.path.join(directory, f"best_model_{timestamp}.pth"))
    
    # Plot the train and validation losses
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    lossplotname = os.path.join(directory, "training_loss.png")
    counter = 2
    while os.path.exists(lossplotname):
        lossplotname = os.path.join(directory, f"training_loss_{counter}.png")
        counter += 1
    plt.savefig(lossplotname)
    print(f"Figure saved to {lossplotname}")
    logging.info(f"Figure saved to {lossplotname}")
    plt.show()
    plt.close()

    return model, best_loss
# Before setting up the optimizer and training, run one evaluation pass so that the wrapping is finalized
evaluate_model(wrapped_model, testloader)

# Observe that all parameters are being optimized
optimizer = optim.SGD(wrapped_model.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

wrapped_model, best_loss = train_wrapper(model_name, wrapped_model, trainloader, testloader,
                                         {'train': len(ID_trainset), 'val': len(testset)},
                                         optimizer, exp_lr_scheduler, num_epochs=50)
Epoch 1/50
----------
Val Iteration: 100/100, Loss: 99.20313954353333.
Train Loss: 1.2113
Val Loss: 0.9822

Epoch 2/50
----------
Val Iteration: 100/100, Loss: 82.21189975738525.
Train Loss: 0.6886
Val Loss: 0.8138

Epoch 3/50
----------
Val Iteration: 100/100, Loss: 73.35951924324036.
Train Loss: 0.5840
Val Loss: 0.7236

Epoch 4/50
----------
Val Iteration: 100/100, Loss: 67.3624575138092.
Train Loss: 0.5269
Val Loss: 0.6669

Epoch 5/50
----------
Val Iteration: 100/100, Loss: 62.98609972000122.
Train Loss: 0.4878
Val Loss: 0.6233

Epoch 6/50
----------
Val Iteration: 100/100, Loss: 59.66009497642517.
Train Loss: 0.4585
Val Loss: 0.5900

Epoch 7/50
----------
Val Iteration: 100/100, Loss: 56.70555830001831.
Train Loss: 0.4364
Val Loss: 0.5610

Epoch 8/50
----------
Val Iteration: 100/100, Loss: 56.7879319190979.
Train Loss: 0.4258
Val Loss: 0.5614

Epoch 9/50
----------
Val Iteration: 100/100, Loss: 56.17443919181824.
Train Loss: 0.4248
Val Loss: 0.5568

Epoch 10/50
----------
Val Iteration: 100/100, Loss: 56.15739822387695.
Train Loss: 0.4227
Val Loss: 0.5559

Epoch 11/50
----------
Val Iteration: 100/100, Loss: 55.894654989242554.
Train Loss: 0.4212
Val Loss: 0.5526

Epoch 12/50
----------
Val Iteration: 100/100, Loss: 55.53848743438721.
Train Loss: 0.4192
Val Loss: 0.5497

Epoch 13/50
----------
Val Iteration: 100/100, Loss: 55.46305775642395.
Train Loss: 0.4175
Val Loss: 0.5497

Epoch 14/50
----------
Val Iteration: 100/100, Loss: 55.09125590324402.
Train Loss: 0.4157
Val Loss: 0.5460

Epoch 15/50
----------
Val Iteration: 100/100, Loss: 55.22012710571289.
Train Loss: 0.4145
Val Loss: 0.5457

Epoch 16/50
----------
Val Iteration: 100/100, Loss: 55.02273440361023.
Train Loss: 0.4143
Val Loss: 0.5444

Epoch 17/50
----------
Val Iteration: 100/100, Loss: 54.9222469329834.
Train Loss: 0.4151
Val Loss: 0.5443

Epoch 18/50
----------
Val Iteration: 100/100, Loss: 55.11597394943237.
Train Loss: 0.4150
Val Loss: 0.5457

Epoch 19/50
----------
Val Iteration: 100/100, Loss: 54.95551824569702.
Train Loss: 0.4147
Val Loss: 0.5430

Epoch 20/50
----------
Val Iteration: 100/100, Loss: 55.0877571105957.
Train Loss: 0.4140
Val Loss: 0.5441

Epoch 21/50
----------
Val Iteration: 100/100, Loss: 55.13790845870972.
Train Loss: 0.4135
Val Loss: 0.5454

Epoch 22/50
----------
Val Iteration: 100/100, Loss: 54.91640567779541.
Train Loss: 0.4139
Val Loss: 0.5439

Epoch 23/50
----------
Val Iteration: 100/100, Loss: 54.926955699920654.
Train Loss: 0.4146
Val Loss: 0.5438

Epoch 24/50
----------
Val Iteration: 100/100, Loss: 55.03297448158264.
Train Loss: 0.4137
Val Loss: 0.5428

Epoch 25/50
----------
Val Iteration: 100/100, Loss: 55.15057444572449.
Train Loss: 0.4137
Val Loss: 0.5458

Epoch 26/50
----------
Val Iteration: 100/100, Loss: 55.037474632263184.
Train Loss: 0.4144
Val Loss: 0.5440

Epoch 27/50
----------
Val Iteration: 100/100, Loss: 54.771679639816284.
Train Loss: 0.4141
Val Loss: 0.5417

Epoch 28/50
----------
Val Iteration: 100/100, Loss: 55.06795644760132.
Train Loss: 0.4147
Val Loss: 0.5436

Epoch 29/50
----------
Val Iteration: 100/100, Loss: 55.27934432029724.
Train Loss: 0.4139
Val Loss: 0.5458

Epoch 30/50
----------
Val Iteration: 100/100, Loss: 55.08577823638916.
Train Loss: 0.4135
Val Loss: 0.5446

Epoch 31/50
----------
Val Iteration: 100/100, Loss: 55.09892702102661.
Train Loss: 0.4135
Val Loss: 0.5450

Epoch 32/50
----------
Val Iteration: 100/100, Loss: 54.99833822250366.
Early stopping at epoch 32
Training complete in 15m 31s
Best val Loss: 0.541698
Figure saved to wrapped_models/cifar10_repvgg_a2/training_loss_3.png
[Figure: training and validation loss curves for cifar10_repvgg_a2]
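If you want to reuse the trained wrapper later without retraining, rebuild the wrapped model and load the saved weights. A sketch under the conventions above (the checkpoint filename contains a timestamp, so adjust the path; we run one evaluation pass first so the wrapping is finalized before the state dict is loaded):

restored = create_model(model_name, device)
evaluate_model(restored, testloader)  # forward passes finalize the wrapping
checkpoint = os.path.join('wrapped_models', model_name, 'best_model_<timestamp>.pth')  # adjust to your file
restored.load_state_dict(torch.load(checkpoint))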

Step 4: Evaluate Vacuity-Based ID/OOD Classification

We added ID/OOD labels to the dataset. Since the vacuity score should reflect the type of uncertainty that arises from a lack of sufficient data or information, we expect higher vacuity scores for OOD data. We can then find the best threshold on the vacuity score for dividing the test dataset into ID and OOD, and compare the result with the true labels. To evaluate this classification, we plot the ROC curve and compute its AUROC.

def evaluate_model(model, dataloaders, device='cuda:0'):
    model.eval()
    all_labels = []
    all_vacuity_scores = []
    all_id_ood_labels = []

    with torch.no_grad():
        for inputs, labels, id_ood_label in dataloaders:
            inputs = inputs.to(device)
            labels = labels.to(device)
            output, risk = model(inputs, return_risk=True)
            all_labels.extend(labels.cpu().numpy())
            all_vacuity_scores.extend(risk.cpu().numpy())
            all_id_ood_labels.extend(id_ood_label.cpu().numpy())

    return np.array(all_labels), np.array(all_vacuity_scores), np.array(all_id_ood_labels)

def calculate_average_vacuity_scores(labels, vacuity_scores, id_ood_labels):
    # Convert inputs to numpy arrays if they are not already
    labels = np.array(labels)
    vacuity_scores = np.array(vacuity_scores)
    id_ood_labels = np.array(id_ood_labels)
    
    # Separate the vacuity scores based on id_ood_labels
    vacuity_scores_id_0 = vacuity_scores[id_ood_labels == 0]
    vacuity_scores_id_1 = vacuity_scores[id_ood_labels == 1]
    
    # Calculate the average vacuity scores
    avg_vacuity_score_id_0 = np.mean(vacuity_scores_id_0)
    avg_vacuity_score_id_1 = np.mean(vacuity_scores_id_1)
    
    return avg_vacuity_score_id_0, avg_vacuity_score_id_1

def plot_roc_curve(model_name, id_ood_labels, vacuity_scores):
    fpr, tpr, thresholds = roc_curve(id_ood_labels, vacuity_scores)
    roc_auc = auc(fpr, tpr)

    # Plot ROC curve
    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc="lower right")
    directory = os.path.join('wrapped_models', model_name)
    if not os.path.exists(directory):
        os.makedirs(directory)
    plotname = os.path.join(directory, "roc_curve.png")
    counter = 2
    while os.path.exists(plotname):
        plotname = os.path.join(directory, f"roc_curve_{counter}.png")
        counter += 1
    plt.savefig(plotname)
    plt.show()
    plt.close()

    return fpr, tpr, thresholds, roc_auc
labels, vacuity_scores, id_ood_labels = evaluate_model(wrapped_model, testloader)
avg_vacuity_score_id_0, avg_vacuity_score_id_1 = calculate_average_vacuity_scores(labels, vacuity_scores, id_ood_labels)
fpr, tpr, thresholds, roc_auc = plot_roc_curve(model_name, id_ood_labels, vacuity_scores)
[Figure: ROC curve for vacuity-based ID/OOD classification with cifar10_repvgg_a2]
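As mentioned above, we can pick a best threshold on the vacuity score. A common choice, sketched here as our own addition, is the threshold that maximizes Youden's J statistic (TPR − FPR) along the ROC curve; we also print the average vacuity scores computed above:

print(f"avg vacuity (ID):  {avg_vacuity_score_id_0:.4f}")
print(f"avg vacuity (OOD): {avg_vacuity_score_id_1:.4f}")
print(f"AUROC: {roc_auc:.4f}")

best_idx = np.argmax(tpr - fpr)        # maximize Youden's J = TPR - FPR
best_threshold = thresholds[best_idx]
pred_ood = (vacuity_scores >= best_threshold).astype(int)
print(f"best threshold: {best_threshold:.4f}, "
      f"ID/OOD accuracy: {(pred_ood == id_ood_labels).mean():.4f}")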

Given the AUROC, we see that the trained wrapped model produces a vacuity score that reflects the lack of sufficient data or information, separating OOD inputs from ID inputs.

Additional Testing: Evaluating Other Models

In addition to the repvgg_a2 model, we can wrap other models; here we provide an example using a ResNet-20 classifier for CIFAR-10. You can change model_name to other supported models such as "cifar10_vgg19_bn".

# initialize the model
model_name = "cifar10_resnet20"
wrapped_model = create_model(model_name, device)

# pass sample inputs through the model before training to finalize the wrapping
evaluate_model(wrapped_model, testloader)

# train the model
optimizer = optim.SGD(wrapped_model.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
wrapped_model, best_loss = train_wrapper(model_name, wrapped_model, trainloader, testloader,
                                         {'train': len(ID_trainset), 'val': len(testset)},
                                         optimizer, exp_lr_scheduler, num_epochs=50)
Epoch 1/50
----------
Val Iteration: 100/100, Loss: 99.89910125732422.
Train Loss: 1.2265
Val Loss: 0.9956

Epoch 2/50
----------
Val Iteration: 100/100, Loss: 81.63678646087646.
Train Loss: 0.7741
Val Loss: 0.8146

Epoch 3/50
----------
Val Iteration: 100/100, Loss: 71.86244130134583.
Train Loss: 0.6563
Val Loss: 0.7190

Epoch 4/50
----------
Val Iteration: 100/100, Loss: 66.1555290222168.
Train Loss: 0.5932
Val Loss: 0.6635

Epoch 5/50
----------
Val Iteration: 100/100, Loss: 61.7536723613739.
Train Loss: 0.5521
Val Loss: 0.6206

Epoch 6/50
----------
Val Iteration: 100/100, Loss: 58.13770294189453.
Train Loss: 0.5225
Val Loss: 0.5858

Epoch 7/50
----------
Val Iteration: 100/100, Loss: 55.31335473060608.
Train Loss: 0.4992
Val Loss: 0.5578

Epoch 8/50
----------
Val Iteration: 100/100, Loss: 55.15536665916443.
Train Loss: 0.4873
Val Loss: 0.5556

Epoch 9/50
----------
Val Iteration: 100/100, Loss: 54.678040742874146.
Train Loss: 0.4858
Val Loss: 0.5513

Epoch 10/50
----------
Val Iteration: 100/100, Loss: 54.43543195724487.
Train Loss: 0.4842
Val Loss: 0.5493

Epoch 11/50
----------
Val Iteration: 100/100, Loss: 54.40511107444763.
Train Loss: 0.4815
Val Loss: 0.5489

Epoch 12/50
----------
Val Iteration: 100/100, Loss: 54.15217876434326.
Train Loss: 0.4801
Val Loss: 0.5462

Epoch 13/50
----------
Val Iteration: 100/100, Loss: 54.02219891548157.
Train Loss: 0.4783
Val Loss: 0.5448

Epoch 14/50
----------
Val Iteration: 100/100, Loss: 53.82266044616699.
Train Loss: 0.4771
Val Loss: 0.5429

Epoch 15/50
----------
Val Iteration: 100/100, Loss: 53.550535440444946.
Train Loss: 0.4756
Val Loss: 0.5404

Epoch 16/50
----------
Val Iteration: 100/100, Loss: 53.86887788772583.
Train Loss: 0.4754
Val Loss: 0.5432

Epoch 17/50
----------
Val Iteration: 100/100, Loss: 53.52661609649658.
Train Loss: 0.4756
Val Loss: 0.5403

Epoch 18/50
----------
Val Iteration: 100/100, Loss: 53.57630252838135.
Train Loss: 0.4749
Val Loss: 0.5411

Epoch 19/50
----------
Val Iteration: 100/100, Loss: 53.83332371711731.
Train Loss: 0.4750
Val Loss: 0.5433

Epoch 20/50
----------
Val Iteration: 100/100, Loss: 53.37501764297485.
Train Loss: 0.4749
Val Loss: 0.5391

Epoch 21/50
----------
Val Iteration: 100/100, Loss: 53.625160455703735.
Train Loss: 0.4746
Val Loss: 0.5413

Epoch 22/50
----------
Val Iteration: 100/100, Loss: 53.5938560962677.
Train Loss: 0.4745
Val Loss: 0.5407

Epoch 23/50
----------
Val Iteration: 100/100, Loss: 53.84281277656555.
Train Loss: 0.4745
Val Loss: 0.5427

Epoch 24/50
----------
Val Iteration: 100/100, Loss: 53.799569606781006.
Train Loss: 0.4746
Val Loss: 0.5428

Epoch 25/50
----------
Val Iteration: 100/100, Loss: 53.48985195159912.
Early stopping at epoch 25
Training complete in 5m 29s
Best val Loss: 0.539133
Figure saved to wrapped_models/cifar10_resnet20/training_loss.png
[Figure: training and validation loss curves for cifar10_resnet20]
labels, vacuity_scores, id_ood_labels = evaluate_model(wrapped_model, testloader)
avg_vacuity_score_id_0, avg_vacuity_score_id_1 = calculate_average_vacuity_scores(labels, vacuity_scores, id_ood_labels)
fpr, tpr, thresholds, roc_auc = plot_roc_curve(model_name, id_ood_labels, vacuity_scores)
[Figure: ROC curve for vacuity-based ID/OOD classification with cifar10_resnet20]
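Because plot_roc_curve returns the AUROC, comparing several wrapped backbones is just a loop over model names. A sketch of our own (each iteration retrains from scratch, so this is slow):

for name in ["cifar10_repvgg_a2", "cifar10_resnet20", "cifar10_vgg19_bn"]:
    m = create_model(name, device)
    evaluate_model(m, testloader)  # finalize the wrapping before training
    opt = optim.SGD(m.parameters(), lr=0.001, momentum=0.9)
    sched = lr_scheduler.StepLR(opt, step_size=7, gamma=0.1)
    m, _ = train_wrapper(name, m, trainloader, testloader,
                         {'train': len(ID_trainset), 'val': len(testset)},
                         opt, sched, num_epochs=50)
    _, scores, flags = evaluate_model(m, testloader)
    _, _, _, auroc = plot_roc_curve(name, flags, scores)
    print(f"{name}: AUROC = {auroc:.3f}")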