Neo Wrapper: Anomaly Detection in Autonomous Driving with SegFormer
In this tutorial, we demonstrate how to wrap a pretrained SegFormer model using the Neo wrapper, enabling vacuity-based anomaly detection. We’ll train the wrapped model on anomaly-free data from the MUAD dataset, and then apply it to detect out-of-distribution (OOD) anomalies.
Requirements
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
from torch.utils.data.dataloader import DataLoader
from torch.nn import functional as F
from capsa_torch import neo
1. Loading MUAD Data
We start by preparing the MUAD dataset for training and testing. MUAD contains both clean (anomaly-free) and out-of-distribution (OOD) data, which makes it ideal for anomaly detection tasks.
To process the images and semantic segmentation masks, we use `SegformerFeatureExtractor`, which ensures the inputs are formatted correctly for the SegFormer model. We define a custom transform function that returns both the normalized image tensors and their corresponding label masks.
We then load:
- the training split from the anomaly-free portion of MUAD, and
- the test split from the OOD subset, to evaluate anomaly detection performance.
Below is the code that:
- defines the `MUAD` dataset class, and
- performs the loading and preprocessing:
from pathlib import Path
from typing import Callable, Literal, NamedTuple
from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive
from torchvision import tv_tensors
from torchvision import tv_tensors
from PIL import Image
from collections import namedtuple
class MUADClass(NamedTuple):
    """Metadata describing one MUAD semantic class."""

    name: str  # human-readable class name
    id: int  # MUAD class id (presumably the value stored in the label masks — verify)
    color: tuple[int, int, int]  # 3-component display color (exposed via MUAD.color_palette)
    cityscapes_id: int  # matching Cityscapes id; 0 is used for the anomaly and unlabeled classes
class MUAD(VisionDataset):
    """MUAD semantic-segmentation dataset (train split or a local OOD split).

    Images are read from ``<dataset_root>/<split>/leftImg8bit`` and masks from
    ``<dataset_root>/<split>/leftLabel``. Images and masks are paired by sorted
    path order. For ``split="train"`` the dataset root is ``<root>/MUAD``, and
    the split can be downloaded from Zenodo. For ``split="ood"`` the dataset
    root is ``<root>`` itself, and the data must be placed there manually.

    Args:
        root: Base directory of the dataset.
        split: ``"train"`` or ``"ood"``.
        target_type: Only ``"semantic"`` is supported.
        transforms: Optional callable applied jointly to ``(image, target)``.
        download: Download the train split if it is not found on disk.

    Raises:
        ValueError: If ``split`` or ``target_type`` is unsupported, or if the
            number of images and label maps differ.
        FileNotFoundError: If the data is missing and cannot be downloaded.
    """

    base_url = "https://zenodo.org/records/10619959/files/"
    zip_md5 = {
        "train": "cea6a672225b10dda1add8b2974a5982",
        "val": "957af9c1c36f0a85c33279e06b6cf8d8",
    }
    # One entry per MUAD class. Ids 19 and 20 are the OOD object classes
    # (animals, trash) used as anomalies in this tutorial; 21 is "unlabeled".
    classes = [
        MUADClass("road", 0, (128, 64, 128), 7),
        MUADClass("sidewalk", 1, (244, 35, 232), 8),
        MUADClass("building", 2, (70, 70, 70), 11),
        MUADClass("wall", 3, (102, 102, 156), 12),
        MUADClass("fence", 4, (190, 153, 153), 13),
        MUADClass("pole", 5, (153, 153, 153), 17),
        MUADClass("traffic_light", 6, (250, 170, 30), 19),
        MUADClass("traffic_sign", 7, (220, 220, 0), 20),
        MUADClass("vegetation", 8, (107, 142, 35), 21),
        MUADClass("terrain", 9, (152, 251, 152), 22),
        MUADClass("sky", 10, (70, 130, 180), 23),
        MUADClass("person", 11, (220, 20, 60), 24),
        MUADClass("rider", 12, (255, 0, 0), 25),
        MUADClass("car", 13, (0, 0, 142), 26),
        MUADClass("truck", 14, (0, 0, 70), 27),
        MUADClass("bus", 15, (0, 60, 100), 28),
        MUADClass("train", 16, (0, 80, 100), 31),
        MUADClass("motorcycle", 17, (0, 0, 230), 32),
        MUADClass("bicycle", 18, (119, 11, 32), 33),
        MUADClass("bear deer cow", 19, (255, 228, 196), 0),
        MUADClass("garbage_bag stand_food trash_can", 20, (128, 128, 0), 0),
        MUADClass("unlabeled", 21, (0, 0, 0), 0)
    ]

    def __init__(
        self,
        root: str | Path,
        split: Literal["train", "ood"],
        target_type: Literal["semantic"] = "semantic",
        transforms: Callable | None = None,
        download: bool = True
    ) -> None:
        if split not in ["train", "ood"]:
            raise ValueError("Only 'train' and 'ood' splits are supported.")
        if target_type != "semantic":
            raise ValueError("Only 'semantic' target_type is supported.")

        # The train split lives in a "MUAD" sub-folder of root, while the
        # manually placed OOD split sits directly in root.
        if split != "ood":
            dataset_root = Path(root) / "MUAD"
        else:
            dataset_root = Path(root)
        super().__init__(dataset_root, transforms=transforms)
        self.root = dataset_root
        self.split = split
        self.target_type = target_type

        if not self._check_exists():
            if not download:
                raise FileNotFoundError(f"MUAD {split} not found at {self.root}.")
            if split == "ood":
                raise FileNotFoundError("No download available for 'ood' split. Place it manually.")
            self._download()

        # Keep regular files only. A recursive glob also yields sub-directories,
        # which Image.open cannot read and which would shift the image/mask
        # pairing below.
        self.samples = self._list_files(self.root / split / "leftImg8bit")
        self.targets = self._list_files(self.root / split / "leftLabel")
        if len(self.samples) != len(self.targets):
            raise ValueError(
                f"MUAD {split}: found {len(self.samples)} images but "
                f"{len(self.targets)} label maps; they are paired by sorted order."
            )
        self.len = len(self.samples)

    @staticmethod
    def _list_files(directory: Path) -> list[Path]:
        """Return every regular file below ``directory``, sorted by path."""
        return sorted(p for p in directory.glob("**/*") if p.is_file())

    def __getitem__(self, index: int) -> tuple[tv_tensors.Image, tv_tensors.Mask]:
        """Load the image/mask pair at ``index`` and apply ``transforms`` if set."""
        # Context managers close the underlying files right away instead of
        # leaving the handles open until the PIL objects are garbage-collected.
        with Image.open(self.samples[index]) as image_file:
            image = tv_tensors.Image(image_file.convert("RGB"))
        with Image.open(self.targets[index]) as target_file:
            target = tv_tensors.Mask(target_file)
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def __len__(self) -> int:
        return self.len

    def _check_exists(self) -> bool:
        """Return True if both image and label directories exist and are non-trivially populated."""
        img_dir = self.root / self.split / "leftImg8bit"
        lbl_dir = self.root / self.split / "leftLabel"
        return (
            img_dir.is_dir()
            and lbl_dir.is_dir()
            and len(list((img_dir).glob("**/*"))) > 1
            and len(list((lbl_dir).glob("**/*"))) > 1
        )

    def _download(self) -> None:
        """Download and extract the archive for the current split."""
        filename = f"{self.split}.zip"
        url = self.base_url + filename
        md5 = self.zip_md5[self.split]
        print(f"[MUAD] Downloading {self.split} split from {url} ...")
        # NOTE(review): the archive is extracted into self.root.parent (the
        # user-supplied root), while the data is looked up under
        # root/MUAD/<split>. This assumes the archive's top-level folder is
        # "MUAD" — confirm.
        download_and_extract_archive(url, download_root=str(self.root.parent), md5=md5)

    @property
    def color_palette(self) -> list[list[int]]:
        """Per-class colors as lists, in the order of ``classes``."""
        return [list(c.color) for c in self.classes]
feature_extractor = SegformerFeatureExtractor()


def train_transform_muad(image, target):
    """Preprocess one (image, mask) pair with the SegFormer feature extractor.

    The extractor is called on a single example, and entry ``[0]`` is taken
    from both ``pixel_values`` and ``labels`` to unwrap that single example.

    Returns:
        ``(pixel_values, labels)`` for the one input example.
    """
    encoded = feature_extractor(image, target)
    pixel_values = encoded["pixel_values"]
    label_maps = encoded["labels"]
    return pixel_values[0], label_maps[0]
# Anomaly-free MUAD training split (the MUAD class downloads it if missing).
MUAD_TRAIN_ROOT = "muad/train"
train_dataset = MUAD(
    root=MUAD_TRAIN_ROOT,
    split="train",
    target_type="semantic",
    transforms=train_transform_muad,
)
print(f"Train dataset size: {len(train_dataset)}")
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    num_workers=4,
    shuffle=True,
)

# OOD test split: it has no download and must already be on disk. It is
# iterated one image at a time, in a fixed order, and uses the same
# preprocessing function as the training split.
MUAD_OOD_ROOT = "muad/test_OOD"
test_dataset = MUAD(
    root=MUAD_OOD_ROOT,
    split="ood",
    target_type="semantic",
    transforms=train_transform_muad,
)
print(f"Test dataset size: {len(test_dataset)}")
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=4,
    shuffle=False,
)
Train dataset size: 3420
Test dataset size: 102
2. Initializing and Wrapping the SegFormer Model
In this step, we load a pretrained SegFormer model (`nvidia/mit-b3`) and wrap it using the Neo wrapper to enable vacuity-based anomaly detection.
We load a previously fine-tuned checkpoint and freeze the model’s parameters, so that only the Neo wrapper’s auxiliary heads are trained. We do this because SegFormer is large and slow to train.
We then wrap the model with Neo. The wrapper takes several parameters that control how and where the anomaly detection heads are attached within the SegFormer model:
- `integration_sites`: The number of uncertainty heads to insert. Using multiple heads helps improve the reliability of the detection signal.
- `layer_out_dims`: Controls the size of the internal features used by each head.
- `node_name_filter`: Specifies where to attach the heads within the model, using internal layer names.
- `kernel_size`, `padding`, `stride`: Basic configuration for the attached modules.
- `add_batch_norm`: Enables optional normalization within the added components.
- `pixel_wise`: When `True`, the model produces pixel-level anomaly scores.
Finally, we run a dummy forward pass to initialize the graph and ensure the wrapped model is ready for training.
Below is the code that performs this setup:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Device used: {device}')

# Initialize Segformer model
MODEL_NAME = "nvidia/mit-b3"
id2label = {
    0: "road", 1: "sidewalk", 2: "building", 3: "wall", 4: "fence", 5: "pole",
    6: "traffic light", 7: "traffic sign", 8: "vegetation", 9: "terrain", 10: "sky",
    11: "person", 12: "rider", 13: "car", 14: "truck", 15: "bus", 16: "train",
    17: "motorcycle", 18: "bicycle", 19: "bear deer cow", 20: "garbage_bag stand_food trash_can"
}
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)

# The new 21-class head does not match the ImageNet checkpoint, hence
# ignore_mismatched_sizes=True.
segformer = SegformerForSemanticSegmentation.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
segformer.to(device)

# Download and load the SegFormer checkpoint fine-tuned on MUAD.
import gdown

file_id = "1vwpfy-5SGH0vyDcV3AckAFqsu4TOkppY"
url = f"https://drive.google.com/uc?id={file_id}"
model_name = "segformer_pretrained_muad.pt"
gdown.download(url, model_name, quiet=False)
# map_location: a checkpoint saved from a GPU must still load on a CPU-only
# machine. weights_only: the file comes from an external download, so refuse
# arbitrary pickled objects (a plain state dict only holds tensors).
segformer.load_state_dict(torch.load(model_name, map_location=device, weights_only=True))

# Freeze SegFormer so that only the Neo heads added below are trainable.
for param in segformer.parameters():
    param.requires_grad = False

wrapper = neo.Wrapper(
    integration_sites=3,
    layer_out_dims=(128, 64),
    node_name_filter=['conv2d_58'],
    kernel_size=3,
    padding=1,
    stride=1,
    add_batch_norm=True,
    pixel_wise=True,
)
wrapped_segformer = wrapper(segformer).to(device)

# Run a dummy forward pass to initialize the wrapped model.
# (Named dummy_input to avoid shadowing the builtin `input`.)
dummy_input = torch.randn(1, 3, 224, 224).to(device)
output = wrapped_segformer(dummy_input)
Device used: cuda:0
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b3 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Detected Segformer — disabling symbolic tracing.
3. Training on Anomaly-Free Data
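We train the Neo-wrapped SegFormer model on the anomaly-free portion of the MUAD dataset. During training, the model minimizes its mean vacuity score, which reflects its uncertainty under normal conditions. Because the SegFormer backbone is frozen, only the Neo heads are updated. They learn to assign low vacuity to in-distribution scenes, so regions unlike the training data should stand out with higher vacuity at test time.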
# Only the Neo heads are trained. The SegFormer parameters had requires_grad
# set to False before wrapping, so they are left out of the optimizer.
trainable_params = [p for p in wrapped_segformer.parameters() if p.requires_grad]
optimizers = optim.Adam(trainable_params, lr=5e-5, weight_decay=1e-5)

NUM_EPOCHS = 8
LOG_EVERY = 100  # iterations per logging / checkpointing window
min_running_loss = np.inf
model_name = 'segformer_neo.pt'

print('Train neo wrapped segformer on anomaly-free dataset ...')
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    # Wrap the loader itself (not enumerate) so tqdm knows the total batch count.
    for i, batch in enumerate(tqdm(train_dataloader)):
        optimizers.zero_grad()
        # Segmentation labels are not used: the objective is vacuity only.
        inputs = batch[0].to(device)
        batch_outputs, vacuity_scores = wrapped_segformer(pixel_values=inputs, return_risk=True)
        # Minimize the mean vacuity on anomaly-free data.
        loss = vacuity_scores.mean()
        loss.backward()
        optimizers.step()
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            # running_loss is the *sum* of the last LOG_EVERY batch losses.
            # All windows have the same size, so sums are directly comparable.
            print(f"Epoch {epoch+1}, iter {i+1} \t loss: {running_loss}")
            if running_loss < min_running_loss:
                torch.save(wrapped_segformer.state_dict(), model_name)
                print(f"Loss decreased: {min_running_loss} -> {running_loss}.")
                print(f"Model saved to {model_name}.")
            min_running_loss = min(min_running_loss, running_loss)
            running_loss = 0.0
Train neo wrapped segformer on anomaly-free dataset ...
Epoch 8, iter 200 loss: 26.032298654317856
200it [01:51, 1.71it/s]
Loss decreased: 27.030073761940002 -> 26.032298654317856.
Model saved to segformer_neo.pt.
Alternatively, you can download pretrained Neo weights and load them to skip training.
# Download pretrained Neo-wrapped SegFormer weights to skip training.
import gdown

file_id = "1fHxezRKXc-wajOw7Rxo8WQsC9v2Tn2pt"
url = f"https://drive.google.com/uc?id={file_id}"
output = "segformer_neo.pt"
gdown.download(url, output, quiet=False)
model_name = output
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Consider weights_only=True once it is confirmed that the Neo
# state dict contains only tensors.
wrapped_segformer.load_state_dict(torch.load(model_name, map_location=device))
Downloading...
From (original): https://drive.google.com/uc?id=1fHxezRKXc-wajOw7Rxo8WQsC9v2Tn2pt
From (redirected): https://drive.google.com/uc?id=1fHxezRKXc-wajOw7Rxo8WQsC9v2Tn2pt&confirm=t&uuid=f116d104-ff17-468c-8c9e-ad4b49fd143e
To: yourfolder/segformer_neo.pt
100%|██████████| 190M/190M [00:01<00:00, 103MB/s]
<All keys matched successfully>
4. Visualizing Anomaly Detection Results
Before running inference, we define a utility function that displays five informative views: the input image, ground truth segmentation, model prediction, and two vacuity maps. This layout helps interpret where and why the model expresses uncertainty in its outputs.
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import cv2
def _label_colormap(num_colors=22):
    """Return a ListedColormap with ``num_colors`` distinct colors, indexed by class id."""
    # tab20 has only 20 colors, so resampling it to 21+ classes forces two
    # classes to share a color. Extend the palette with tab20b instead.
    from matplotlib.colors import ListedColormap

    palette = list(plt.get_cmap("tab20").colors) + list(plt.get_cmap("tab20b").colors)
    return ListedColormap(palette[:num_colors])


def visualize(img, vac_label, vac_region, labels=None, pred_labels=None):
    """Plot an input image, its segmentations and two vacuity maps side by side.

    Panels, left to right:
      1. the input image, min-max rescaled to [0, 1] for display;
      2. ground-truth labels (blank if ``labels`` is None);
      3. predicted labels (blank if ``pred_labels`` is None);
      4. vacuity averaged per ground-truth label, with a colorbar;
      5. vacuity averaged per predicted region, with a colorbar.

    Args:
        img: Image laid out channels-last, as expected by ``imshow``; any value range.
        vac_label: 2-D vacuity map aggregated by ground-truth label.
        vac_region: 2-D vacuity map aggregated by predicted region.
        labels: Optional 2-D integer array of ground-truth class ids.
        pred_labels: Optional 2-D integer array of predicted class ids.
    """
    fig, axs = plt.subplots(1, 5, figsize=(20, 4))
    label_cmap = _label_colormap()

    # 1. Original image. Guard against a constant image, whose min-max range
    # is zero and would otherwise cause a division by zero.
    shifted = img - img.min()
    value_range = img.max() - img.min()
    axs[0].imshow(shifted / value_range if value_range > 0 else shifted)
    axs[0].set_title("Original Image")
    axs[0].axis('off')

    # 2-3. Ground-truth and predicted segmentations.
    for ax, label_map, title in (
        (axs[1], labels, "Segmentation Label"),
        (axs[2], pred_labels, "Segmentation Prediction"),
    ):
        if label_map is not None:
            ax.imshow(label_cmap(label_map))
            ax.set_title(title)
        ax.axis('off')

    # 4-5. Aggregated vacuity maps, each with its own colorbar.
    for ax, vac_map, title in (
        (axs[3], vac_label, "Vacuity by Ground Truth Label"),
        (axs[4], vac_region, "Vacuity by Predicted Region"),
    ):
        im = ax.imshow(vac_map, cmap='jet')
        ax.set_title(title)
        ax.axis('off')
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        plt.colorbar(im, cax=cax)

    plt.tight_layout()
    plt.show()
def average_vacuity_per_label(vacuity_scores, labels):
    """Replace each pixel's vacuity with the mean vacuity of its label.

    All elements sharing a label value are pooled together, across the whole
    tensor including any batch dimension. Each of them receives that pool's
    mean score.

    Args:
        vacuity_scores: Float tensor of per-pixel vacuity scores.
        labels: Label tensor with the same shape as ``vacuity_scores``.

    Returns:
        CPU tensor shaped like ``vacuity_scores`` holding the per-label means.
    """
    scores = vacuity_scores.cpu()
    label_map = labels.cpu()
    averaged = torch.zeros_like(scores)
    for value in torch.unique(label_map):
        # torch.unique only yields values that occur, so the mask is never empty.
        mask = label_map == value
        averaged[mask] = scores[mask].mean()
    return averaged
def average_vacuity_per_region(vacuity_scores, labels):
    """Replace each pixel's vacuity with the mean over its connected region.

    For every label value, the pixels carrying that label are split into
    8-connected components with OpenCV. Every pixel of a component receives
    the component's mean vacuity.

    A leading size-1 dimension of ``vacuity_scores`` is dropped, and the label
    mask is squeezed before being passed to ``cv2.connectedComponents``. In
    practice, only a single image is supported.

    Args:
        vacuity_scores: Vacuity tensor for one image, optionally with a leading
            batch dimension of size 1.
        labels: Label map for the same image.

    Returns:
        CPU tensor shaped like ``vacuity_scores.squeeze(0)`` holding the
        per-region means.
    """
    scores = vacuity_scores.squeeze(0).cpu()
    label_map = labels.cpu()
    averaged = torch.zeros_like(scores)
    for value in torch.unique(label_map):
        # Binary single-channel uint8 mask of this label (every label,
        # including 0, is processed).
        binary = np.squeeze((label_map == value).byte().numpy()).astype(np.uint8)
        n_components, component_ids = cv2.connectedComponents(binary, connectivity=8)
        # Component id 0 is the mask background (pixels of other labels); skip it.
        for cid in range(1, n_components):
            region = torch.tensor(component_ids == cid, dtype=torch.bool, device=scores.device)
            averaged[region] = scores[region].mean()
    return averaged
Once inference is complete, we compute vacuity scores and use the visualization function to inspect a few test samples. The aggregated vacuity views — by ground truth label and by predicted region — offer insights into how the model reacts to out-of-distribution content in the MUAD test set.
from einops import rearrange
wrapped_segformer.eval()

NUM_VISUALIZED = 10  # number of test batches to run and plot

# Inference and visualization
for i, batch in enumerate(test_dataloader):
    # Only the first batches are plotted. Running the remaining ones through
    # the model would just throw the results away.
    if i >= NUM_VISUALIZED:
        break

    inputs = batch[0].to(device)
    labels = batch[1].to(device)
    with torch.no_grad():
        batch_outputs, vacuity_scores = wrapped_segformer(pixel_values=inputs, return_risk=True)

    # Upsample the vacuity map to label resolution, then average over dim 1.
    vacuity_scores = nn.functional.interpolate(
        vacuity_scores.detach().cpu().float(),
        size=labels.shape[-2:],
        mode="bilinear",
        align_corners=False,
    ).mean(dim=1)

    # Upsample the logits to label resolution and take the arg-max class.
    pred_labels = nn.functional.interpolate(
        batch_outputs.logits.float(),
        size=labels.shape[-2:],
        mode="bilinear",
        align_corners=False,
    ).argmax(dim=1)

    img_in = rearrange(inputs.cpu(), 'b c h w -> b h w c')
    for b in range(inputs.shape[0]):
        # Aggregate vacuity per sample so each figure reflects only its own
        # image (identical to batch-level aggregation when batch_size == 1).
        vac_label = average_vacuity_per_label(vacuity_scores[b], labels[b])
        vac_region = average_vacuity_per_region(vacuity_scores[b:b + 1], pred_labels[b:b + 1])
        visualize(
            img=img_in[b],
            vac_label=vac_label,
            vac_region=vac_region,
            labels=labels[b].cpu().numpy(),
            pred_labels=pred_labels[b].cpu().numpy(),
        )










5. Summary and Observations
In this tutorial, we wrapped a pretrained SegFormer model with the Neo wrapper to enable vacuity-based anomaly detection. After training the model on anomaly-free MUAD data, we evaluated it on OOD samples and visualized its outputs.
The results show that the vacuity signal meaningfully highlights regions where the model is less confident, often corresponding to anomalous content in the test data, which are animals and trash cans/bags in this case. Both class-wise and region-wise vacuity maps provide interpretable uncertainty cues, even when the segmentation predictions appear plausible.