Neo Wrapper: Anomaly Detection in Autonomous Driving with SegFormer

In this tutorial, we demonstrate how to wrap a pretrained SegFormer model using the Neo wrapper, enabling vacuity-based anomaly detection. We’ll train the wrapped model on anomaly-free data from the MUAD dataset, and then apply it to detect out-of-distribution (OOD) anomalies.

Requirements
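The tutorial uses the packages imported below. If you are starting from a fresh environment, the following install commands are a rough sketch only; the package names, in particular the pip distribution name for capsa_torch, are assumptions you should verify for your setup.

# Hypothetical install sketch; verify package names (especially for capsa_torch)
# against the official documentation before running:
# pip install torch torchvision transformers gdown einops opencv-python matplotlib tqdm numpy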

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
from torch.utils.data.dataloader import DataLoader
from torch.nn import functional as F
from capsa_torch import neo

1. Loading MUAD Data

We start by preparing the MUAD dataset for training and testing. MUAD contains both clean (anomaly-free) and out-of-distribution (OOD) data, which makes it ideal for anomaly detection tasks.

To process the images and semantic segmentation masks, we use the SegformerFeatureExtractor, which ensures the inputs are formatted correctly for the SegFormer model. We define a custom transform function that returns both the normalized image tensors and their corresponding label masks.

We then load:

  • The training split from the anomaly-free portion of MUAD

  • The test split from the OOD subset to evaluate anomaly detection performance

Below is the code that

  1. defines the MUAD dataset class

  2. performs the loading and preprocessing:

from pathlib import Path
from typing import Callable, Literal, NamedTuple
from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive
from torchvision import tv_tensors
from PIL import Image

class MUADClass(NamedTuple):
    name: str
    id: int
    color: tuple[int, int, int]
    cityscapes_id: int

class MUAD(VisionDataset):
    base_url = "https://zenodo.org/records/10619959/files/"
    zip_md5 = {
        "train": "cea6a672225b10dda1add8b2974a5982",
        "val": "957af9c1c36f0a85c33279e06b6cf8d8",
    }

    classes = [
        MUADClass("road", 0, (128, 64, 128), 7),
        MUADClass("sidewalk", 1, (244, 35, 232), 8),
        MUADClass("building", 2, (70, 70, 70), 11),
        MUADClass("wall", 3, (102, 102, 156), 12),
        MUADClass("fence", 4, (190, 153, 153), 13),
        MUADClass("pole", 5, (153, 153, 153), 17),
        MUADClass("traffic_light", 6, (250, 170, 30), 19),
        MUADClass("traffic_sign", 7, (220, 220, 0), 20),
        MUADClass("vegetation", 8, (107, 142, 35), 21),
        MUADClass("terrain", 9, (152, 251, 152), 22),
        MUADClass("sky", 10, (70, 130, 180), 23),
        MUADClass("person", 11, (220, 20, 60), 24),
        MUADClass("rider", 12, (255, 0, 0), 25),
        MUADClass("car", 13, (0, 0, 142), 26),
        MUADClass("truck", 14, (0, 0, 70), 27),
        MUADClass("bus", 15, (0, 60, 100), 28),
        MUADClass("train", 16, (0, 80, 100), 31),
        MUADClass("motorcycle", 17, (0, 0, 230), 32),
        MUADClass("bicycle", 18, (119, 11, 32), 33),
        MUADClass("bear deer cow", 19, (255, 228, 196), 0),
        MUADClass("garbage_bag stand_food trash_can", 20, (128, 128, 0), 0),
        MUADClass("unlabeled", 21, (0, 0, 0), 0)
    ]

    def __init__(
        self,
        root: str | Path,
        split: Literal["train", "ood"],
        target_type: Literal["semantic"] = "semantic",
        transforms: Callable | None = None,
        download: bool = True
    ) -> None:
        if split not in ["train", "ood"]:
            raise ValueError("Only 'train' and 'ood' splits are supported.")
        if target_type != "semantic":
            raise ValueError("Only 'semantic' target_type is supported.")

        # Handle full dataset root logic like original class
        if split != "ood":
            dataset_root = Path(root) / "MUAD"
        else:
            dataset_root = Path(root)

        super().__init__(dataset_root, transforms=transforms)
        self.root = dataset_root
        self.split = split
        self.target_type = target_type

        if not self._check_exists():
            if not download:
                raise FileNotFoundError(f"MUAD {split} not found at {self.root}.")
            if split == "ood":
                raise FileNotFoundError("No download available for 'ood' split. Place it manually.")
            self._download()

        self.samples = sorted((self.root / split / "leftImg8bit").glob("**/*"))
        self.targets = sorted((self.root / split / "leftLabel").glob("**/*"))
        self.len = len(self.samples)

    def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Mask]:
        image = tv_tensors.Image(Image.open(self.samples[index]).convert("RGB"))
        target = tv_tensors.Mask(Image.open(self.targets[index]))
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def __len__(self) -> int:
        return self.len

    def _check_exists(self) -> bool:
        img_dir = self.root / self.split / "leftImg8bit"
        lbl_dir = self.root / self.split / "leftLabel"
        return (
            img_dir.is_dir()
            and lbl_dir.is_dir()
            and len(list((img_dir).glob("**/*"))) > 1
            and len(list((lbl_dir).glob("**/*"))) > 1
        )

    def _download(self) -> None:
        filename = f"{self.split}.zip"
        url = self.base_url + filename
        md5 = self.zip_md5[self.split]
        print(f"[MUAD] Downloading {self.split} split from {url} ...")
        download_and_extract_archive(url, download_root=str(self.root.parent), md5=md5)

    @property
    def color_palette(self) -> list[list[int]]:
        return [list(c.color) for c in self.classes]

feature_extractor = SegformerFeatureExtractor()

def train_transform_muad(image, target):
    data = feature_extractor(image, target)
    return data["pixel_values"][0], data["labels"][0]

MUAD_TRAIN_ROOT = "muad/train"
train_dataset = MUAD(
    root=MUAD_TRAIN_ROOT,
    split="train",
    target_type="semantic",
    transforms=train_transform_muad
)
print(f"Train dataset size: {len(train_dataset)}")
train_dataloader = DataLoader(train_dataset, batch_size=16, num_workers=4, shuffle=True)

MUAD_OOD_ROOT = "muad/test_OOD"
test_dataset = MUAD(
    root=MUAD_OOD_ROOT,
    split="ood",
    target_type="semantic",
    transforms=train_transform_muad
)
print(f"Test dataset size: {len(test_dataset)}")
test_dataloader = DataLoader(test_dataset,
                             batch_size=1,
                             num_workers=4,
                             shuffle=False)
Train dataset size: 3420
Test dataset size: 102

2. Initializing and Wrapping the SegFormer Model

In this step, we load a pretrained SegFormer model (nvidia/mit-b3) and wrap it using the Neo wrapper to enable vacuity-based anomaly detection.

We load a previously fine-tuned checkpoint and freeze the model’s parameters so that only the Neo wrapper’s auxiliary heads are trained; we do this because SegFormer is large and slow to fine-tune.

We then wrap it with Neo. The wrapper takes several parameters that control how and where the anomaly detection heads are attached within the SegFormer model:

  • integration_sites: The number of uncertainty heads to insert. Using multiple heads helps improve the reliability of the detection signal.

  • layer_out_dims: Controls the size of internal features used by each head.

  • node_name_filter: Specifies where to attach the heads within the model, using internal layer names (see the inspection sketch after this list).

  • kernel_size, padding, stride: Basic configuration for the attached modules.

  • add_batch_norm: Enables optional normalization within the added components.

  • pixel_wise: When True, the model will produce pixel-level anomaly scores.
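Node names such as 'conv2d_58' refer to Neo’s internal trace of the model rather than to standard nn.Module paths, so they generally need to be looked up rather than guessed. As a rough sketch (an assumption, not Neo’s official API), once the segformer model has been instantiated as in the setup code below, you can at least enumerate its convolution layers to get a sense of the candidate attachment points:

# Inspection sketch (assumption: Neo's trace names like 'conv2d_58' roughly follow
# the order of Conv2d modules, which may not hold exactly; use this only as a guide):
for name, module in segformer.named_modules():
    if isinstance(module, nn.Conv2d):
        print(name, tuple(module.weight.shape))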

Finally, we run a dummy forward pass to initialize the graph and ensure the wrapped model is ready for training.

Below is the code that performs this setup:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Device used: {device}')

# Initialize Segformer model
MODEL_NAME = "nvidia/mit-b3"
id2label = {
    0: "road", 1: "sidewalk", 2: "building", 3: "wall", 4: "fence", 5: "pole",
    6: "traffic light", 7: "traffic sign", 8: "vegetation", 9: "terrain", 10: "sky",
    11: "person", 12: "rider", 13: "car", 14: "truck", 15: "bus", 16: "train",
    17: "motorcycle", 18: "bicycle", 19: "bear deer cow", 20: "garbage_bag stand_food trash_can"
}
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)

segformer = SegformerForSemanticSegmentation.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
segformer.to(device)

import gdown

file_id = "1vwpfy-5SGH0vyDcV3AckAFqsu4TOkppY"
url = f"https://drive.google.com/uc?id={file_id}"

output = "segformer_pretrained_muad.pt"
gdown.download(url, output, quiet=False)

model_name = 'segformer_pretrained_muad.pt'
segformer.load_state_dict(torch.load(model_name, map_location=device))

for param in segformer.parameters():
    param.requires_grad = False

wrapper = neo.Wrapper(
    integration_sites=3,
    layer_out_dims=(128, 64),
    node_name_filter=['conv2d_58'],
    kernel_size=3,
    padding=1,
    stride=1,
    add_batch_norm=True,
    pixel_wise=True,
)
wrapped_segformer = wrapper(segformer).to(device)
dummy_input = torch.randn(1, 3, 224, 224).to(device)
output = wrapped_segformer(dummy_input)  # Run a dummy forward pass to initialize the wrapped model
Device used: cuda:0
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b3 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Detected Segformer — disabling symbolic tracing.

3. Training on Anomaly-Free Data

We train the Neo-wrapped SegFormer model on the anomaly-free portion of the MUAD dataset. During training, the model minimizes its vacuity score on in-distribution data; at test time, inputs that differ from this distribution should then produce elevated vacuity, which is the anomaly signal we exploit.

optimizer = optim.Adam(wrapped_segformer.parameters(), lr=5e-5, weight_decay=1e-5)
min_running_loss = np.inf
model_name = 'segformer_neo.pt'
print('Train neo wrapped segformer on anomaly-free dataset ...')
for epoch in range(8):
    running_loss = 0.0

    for i, batch in tqdm(enumerate(train_dataloader)):
        optimizer.zero_grad()

        inputs = batch[0].to(device)
        labels = batch[1].to(device)  # labels are not used here: training minimizes vacuity only
        batch_outputs, vacuity_scores = wrapped_segformer(pixel_values=inputs, return_risk=True)

        loss = vacuity_scores.mean()

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if i % 100 == 99:
            print(f"Epoch {epoch+1}, iter {i+1} \t loss: {running_loss}")

            if running_loss < min_running_loss:
                torch.save(wrapped_segformer.state_dict(), model_name)
                print(f"Loss decreased: {min_running_loss} -> {running_loss}.")
                print(f"Model saved to {model_name}.")

            min_running_loss = min(min_running_loss, running_loss)
            running_loss = 0.0
Train neo wrapped segformer on anomaly-free dataset ...
Epoch 8, iter 200 	 loss: 26.032298654317856
200it [01:51,  1.71it/s]
Loss decreased: 27.030073761940002 -> 26.032298654317856.
Model saved to segformer_neo.pt.

Alternatively, we can download pretrained Neo weights and skip the training step.

import gdown

file_id = "1fHxezRKXc-wajOw7Rxo8WQsC9v2Tn2pt"
url = f"https://drive.google.com/uc?id={file_id}"

output = "segformer_neo.pt"
gdown.download(url, output, quiet=False)

model_name = 'segformer_neo.pt'
wrapped_segformer.load_state_dict(torch.load(model_name, map_location=device))
Downloading...
From (original): https://drive.google.com/uc?id=1fHxezRKXc-wajOw7Rxo8WQsC9v2Tn2pt
From (redirected): https://drive.google.com/uc?id=1fHxezRKXc-wajOw7Rxo8WQsC9v2Tn2pt&confirm=t&uuid=f116d104-ff17-468c-8c9e-ad4b49fd143e
To: yourfolder/segformer_neo.pt
100%|██████████| 190M/190M [00:01<00:00, 103MB/s]  
<All keys matched successfully>

4. Visualizing Anomaly Detection Results

Before running inference, we define a utility function that displays five informative views: the input image, ground truth segmentation, model prediction, and two vacuity maps. This layout helps interpret where and why the model expresses uncertainty in its outputs.

from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import cv2

def visualize(img, vac_label, vac_region, labels=None, pred_labels=None):
    fig, axs = plt.subplots(1, 5, figsize=(20, 4))

    # 1. Original Image
    axs[0].imshow((img - img.min()) / (img.max() - img.min()))
    axs[0].set_title("Original Image")
    axs[0].axis('off')

    # 2. Ground Truth Labels
    if labels is not None:
        cmap = plt.get_cmap("tab20", 21)
        axs[1].imshow(cmap(labels))
        axs[1].set_title("Segmentation Label")
        axs[1].axis('off')
    else:
        axs[1].axis('off')

    # 3. Predicted Labels
    if pred_labels is not None:
        cmap = plt.get_cmap("tab20", 21)
        axs[2].imshow(cmap(pred_labels))
        axs[2].set_title("Segmentation Prediction")
        axs[2].axis('off')
    else:
        axs[2].axis('off')

    # 4. Vacuity by Ground Truth Label
    im3 = axs[3].imshow(vac_label, cmap='jet')
    axs[3].set_title("Vacuity by Ground Truth Label")
    axs[3].axis('off')
    divider3 = make_axes_locatable(axs[3])
    cax3 = divider3.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im3, cax=cax3)

    # 5. Vacuity by Predicted Region
    im4 = axs[4].imshow(vac_region, cmap='jet')
    axs[4].set_title("Vacuity by Predicted Region")
    axs[4].axis('off')
    divider4 = make_axes_locatable(axs[4])
    cax4 = divider4.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im4, cax=cax4)

    plt.tight_layout()
    plt.show()


def average_vacuity_per_label(vacuity_scores, labels):
    """
    Average vacuity scores per label and replace the vacuity scores on each pixel
    with the averaged scores per label at that pixel.
    """
    vacuity_scores = vacuity_scores.cpu()
    labels = labels.cpu()
    unique_labels = torch.unique(labels)
    averaged_vacuity = torch.zeros_like(vacuity_scores)

    for label in unique_labels:
        label_mask = (labels == label)
        if label_mask.sum() > 0:
            avg_score = vacuity_scores[label_mask].mean()
            averaged_vacuity[label_mask] = avg_score

    return averaged_vacuity

def average_vacuity_per_region(vacuity_scores, labels):
    """
    Average vacuity scores per connected component of predicted label and replace the vacuity scores
    on each pixel with the averaged scores of the connected component at that pixel.
    """
    vacuity_scores = vacuity_scores.squeeze(0).cpu()  # Remove the batch dimension
    labels = labels.cpu()
    averaged_vacuity = torch.zeros_like(vacuity_scores)

    for label in torch.unique(labels):
        # Include all labels, including label == 0
        label_mask = (labels == label).byte().numpy()  # Convert to binary mask
        label_mask = np.squeeze(label_mask).astype(np.uint8)  # Ensure single channel and dtype uint8
        num_components, components = cv2.connectedComponents(label_mask, connectivity=8)
        for component_id in range(1, num_components):  # Skip background of connected components
            component_mask = (components == component_id)
            component_mask = torch.tensor(component_mask, dtype=torch.bool, device=vacuity_scores.device)  # Convert to tensor
            component_vacuity_scores = vacuity_scores[component_mask]
            avg_score = component_vacuity_scores.mean()
            averaged_vacuity[component_mask] = avg_score

    return averaged_vacuity

During inference, we compute vacuity scores for each test batch and use the visualization function to inspect a few samples. The aggregated vacuity views (by ground truth label and by predicted region) offer insight into how the model reacts to out-of-distribution content in the MUAD test set.

from einops import rearrange

wrapped_segformer.eval()

# Inference and visualization
for i, batch in enumerate(test_dataloader):
    inputs = batch[0].to(device)
    labels = batch[1].to(device)
    
    with torch.no_grad():
        batch_outputs, vacuity_scores = wrapped_segformer(pixel_values=inputs, return_risk=True)

    # Prepare vacuity map
    vacuity_scores = vacuity_scores.detach().cpu()
    vacuity_scores = nn.functional.interpolate(
        vacuity_scores.float(),
        size=labels.shape[-2:],
        mode="bilinear",
        align_corners=False,
    ).mean(dim=1, keepdim=False)

    # Get segmentation prediction
    pred_labels = nn.functional.interpolate(
        batch_outputs.logits.float(),
        size=labels.shape[-2:],
        mode="bilinear",
        align_corners=False,
    ).argmax(dim=1).squeeze(1)

    # Visualization (limit to 10 samples)
    if i < 10:
        img_in = rearrange(inputs.cpu(), 'b c h w -> b h w c')
        average_vacuity_scores = average_vacuity_per_label(vacuity_scores, labels)
        average_vacuity_scores_per_region = average_vacuity_per_region(vacuity_scores, pred_labels)

        for b in range(inputs.shape[0]):
            labels_squeezed = labels[b].cpu().numpy()
            pred_labels_squeezed = pred_labels[b].cpu().numpy()

            visualize(
                img=img_in[b],
                vac_label=average_vacuity_scores[b],
                vac_region=average_vacuity_scores_per_region,  # already (H, W): batch dim squeezed inside the helper
                labels=labels_squeezed,
                pred_labels=pred_labels_squeezed
            )
(Output: ten five-panel figures showing, for each test sample, the input image, ground-truth segmentation, predicted segmentation, vacuity averaged by ground-truth label, and vacuity averaged by predicted region.)

5. Summary and Observations

In this tutorial, we wrapped a pretrained SegFormer model with the Neo wrapper to enable vacuity-based anomaly detection. After training the model on anomaly-free MUAD data, we evaluated it on OOD samples and visualized its outputs.

The results show that the vacuity signal meaningfully highlights regions where the model is less confident, which often correspond to the anomalous content in the test data (animals and trash bags/cans, in this case). Both class-wise and region-wise vacuity maps provide interpretable uncertainty cues, even when the segmentation predictions appear plausible.
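As a closing sketch (not part of the original pipeline), the pixel-wise vacuity can also be scored quantitatively. Assuming scikit-learn is available and that the OOD content corresponds to label ids 19 and 20 from the class table above (the animal and trash classes), a pixel-level AUROC of vacuity against the OOD mask could be computed as follows:

from sklearn.metrics import roc_auc_score

all_scores, all_targets = [], []
wrapped_segformer.eval()
with torch.no_grad():
    for inputs, labels in test_dataloader:
        _, vacuity = wrapped_segformer(pixel_values=inputs.to(device), return_risk=True)
        # Upsample the vacuity maps to label resolution and average the heads
        vacuity = F.interpolate(
            vacuity.float().cpu(),
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).mean(dim=1)  # (B, H, W)
        # Assumption: ids 19 and 20 mark the OOD classes, as in the class table
        ood_mask = (labels == 19) | (labels == 20)
        all_scores.append(vacuity.flatten())
        all_targets.append(ood_mask.flatten())

auroc = roc_auc_score(torch.cat(all_targets).numpy(), torch.cat(all_scores).numpy())
print(f"Pixel-level OOD AUROC: {auroc:.3f}")

An AUROC well above 0.5 would confirm that vacuity separates OOD pixels from in-distribution ones; the exact value depends on the checkpoint and preprocessing used.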