05 - Image segmentation

Advanced Image Processing

Poznan University of Technology, Institute of Robotics and Machine Intelligence

Laboratory 5: Image segmentation

Introduction

In previous laboratories you focused on image classification: predicting a single label for an entire image using convolutional networks and Vision Transformers. In many real‑world applications, however, we need a much more detailed understanding of the scene. Instead of answering “what is in the image?” we want to know “which pixel belongs to which object?”.

This is the purpose of image segmentation. In this task, the model assigns a class label to every pixel in the image. For example, in a medical scan each pixel might be labeled as tumor or healthy tissue; in autonomous driving, pixels can represent road, car, pedestrian, sky, etc.

There are several related tasks:

  - Semantic segmentation – assigns a class label to every pixel, without distinguishing between object instances.
  - Instance segmentation – additionally separates individual objects of the same class.
  - Panoptic segmentation – combines both: every pixel receives a class label, and object instances are distinguished.

In this lab, we will focus on semantic segmentation using a convolutional encoder–decoder architecture (U‑Net‑like model). You will learn how to prepare datasets with masks, implement a segmentation network, choose appropriate losses and metrics, and evaluate predictions both quantitatively and visually.


Classification vs detection vs segmentation (illustrative example, source: kaggle.com).

Segmentation is a key technology in many domains, including medical imaging (e.g. tumor delineation), autonomous driving (road, vehicle and pedestrian understanding), satellite and aerial image analysis, and robotics.

In this laboratory, you will move from global image‑level reasoning to dense prediction, where every pixel matters.


Goals

The objectives of this laboratory are to:

  - prepare a segmentation dataset with images and pixel-level masks,
  - implement a U‑Net‑like encoder–decoder segmentation network,
  - choose appropriate loss functions (cross‑entropy, Dice) and metrics (IoU, Dice),
  - evaluate predictions both quantitatively and visually.


Prerequisites

Install dependencies

PyTorch
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126
Lightning
pip install lightning
Additional packages (optional but recommended)
pip install matplotlib pillow

Imports

Below is a typical set of imports you will need during this lab.

import os
from pathlib import Path

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from torchvision import transforms
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt

Dataset

In this lab we will work with a binary foreground/background segmentation problem. Each input image has a corresponding mask of the same spatial size, where each pixel belongs to one of two classes:

  - 0 – background,
  - 1 – foreground (object of interest).

Masks are stored as grayscale images: pixel values 0 and 255 (or another constant) are mapped to class indices 0 and 1. In practice, multi‑class segmentation simply extends this idea to more labels.

A typical folder structure for a segmentation dataset looks like this:

data/
    train/
        images/
            image_0001.png
            image_0002.png
            ...
        masks/
            image_0001.png
            image_0002.png
            ...
    val/
        images/
        masks/
    test/
        images/
        masks/

Each image file in images/ has a corresponding mask in masks/ with the same file name.


Example of an input image (left) and its corresponding segmentation mask (right).


Segmentation fundamentals

Pixel‑wise classification view

Semantic segmentation can be seen as performing classification for each pixel independently, while also using information from surrounding pixels via convolutions.

For an image of size H × W and C classes, the model outputs a tensor of shape:

\[ C \times H \times W, \]

(or N × C × H × W for a batch of N images), where each spatial position (h, w) has a vector of C logits. After applying softmax over the channel dimension, we obtain a probability distribution over classes for every pixel.

The ground truth mask can be represented in two ways:

  - a label map of shape H × W, where each pixel stores an integer class index,
  - a one‑hot encoding of shape C × H × W, with one binary channel per class.

For most loss functions in PyTorch (e.g. CrossEntropyLoss), we use the label map representation.
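
As a quick illustration, the snippet below (a minimal sketch with example tensor sizes) shows how the logits and label-map shapes fit together with cross-entropy, assuming the imports listed earlier:

N, C, H, W = 2, 2, 64, 64                 # batch, classes, height, width (example values)
logits = torch.randn(N, C, H, W)          # model output: one C-dimensional logit vector per pixel
targets = torch.randint(0, C, (N, H, W))  # label map: integer class index per pixel

loss = F.cross_entropy(logits, targets)   # pixel-wise cross-entropy, averaged over all pixels
print(loss.item())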

Loss functions for segmentation

The simplest loss function is pixel‑wise cross‑entropy. For each pixel, we compute the cross‑entropy between the predicted class distribution and the true class, and then average over all pixels in the batch.

However, segmentation often suffers from class imbalance (few foreground pixels compared to background). In such cases, overlap‑based losses such as Dice loss are very popular.

The Dice coefficient for a single class is defined as:

\[ \text{Dice}(P, G) = \frac{2|P \cap G|}{|P| + |G|}, \]

where \(P\) is the set of predicted foreground pixels and \(G\) is the set of ground truth foreground pixels. The Dice loss is then:

\[ \mathcal{L}_{\text{Dice}} = 1 - \text{Dice}. \]

In practice, Dice can be computed in a differentiable way using probabilities or logits.
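
A possible differentiable implementation for the binary case (a sketch, assuming logits of shape (N, 2, H, W) and integer masks of shape (N, H, W)) could look like this:

def soft_dice_loss(logits: torch.Tensor, targets: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Differentiable Dice loss for binary segmentation.

    Uses foreground probabilities instead of hard predictions, so gradients
    can flow through the softmax.
    """
    probs = torch.softmax(logits, dim=1)[:, 1]   # foreground probability map, shape (N, H, W)
    targets = targets.float()

    intersection = (probs * targets).sum()
    denom = probs.sum() + targets.sum()
    dice = (2.0 * intersection + eps) / (denom + eps)
    return 1.0 - dice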

Metrics: IoU and Dice

To evaluate segmentation quality, we often report Intersection over Union (IoU) and Dice score.

For a given class, IoU is defined as:

\[ \text{IoU}(P, G) = \frac{|P \cap G|}{|P \cup G|}. \]

Dice and IoU are monotonically related (for the same prediction, Dice = 2·IoU / (1 + IoU)), so they rank predictions in the same order; Dice is always numerically higher and penalizes small errors on small objects less harshly.


Comparison between Dice and Intersection over Union (IoU) – also used in object detection.

For binary segmentation, we often report foreground Dice and foreground IoU. For multi‑class problems, we can compute these metrics per class and then average.


Segmentation model architectures

From classification CNNs to segmentation networks

In classification tasks, CNNs progressively downsample the feature maps using pooling or strided convolutions and finally apply a global pooling and a fully connected layer to produce a single prediction per image.

For segmentation, we need dense predictions with the same spatial resolution (or almost the same) as the input image. A common solution is to use an encoder–decoder architecture:

  - the encoder progressively downsamples the input while increasing the number of channels, extracting increasingly abstract features,
  - the decoder progressively upsamples the feature maps back to the input resolution, producing a per‑pixel prediction.

U‑Net and skip connections

One of the most popular segmentation architectures is U‑Net. Its key idea is to combine coarse, high‑level features from deep layers with fine, high‑resolution features from shallow layers using skip connections.

The U‑Net consists of:

  1. Encoder (contracting path) – sequence of convolutional blocks and pooling operations.
  2. Decoder (expanding path) – upsampling layers that increase spatial resolution.
  3. Skip connections – feature maps from the encoder are concatenated with decoder features at matching resolutions, which helps recover fine details.


Simplified U‑Net style encoder–decoder architecture with skip connections.

Note: In this lab, you will implement a small U‑Net‑like model from scratch.


Dataset and dataloaders

Segmentation dataset implementation

We start by implementing a custom PyTorch Dataset that returns (image, mask) pairs for training.

class SegmentationDataset(Dataset):
    """Custom dataset for image segmentation.

    Each sample consists of an RGB image and a corresponding mask with
    integer labels (0 – background, 1 – foreground).
    """
    def __init__(self, images_dir: str, masks_dir: str, transform=None):
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.transform = transform

        # Collect image paths and ensure masks exist
        self.image_paths = sorted(list(self.images_dir.glob("*.png")))
        if len(self.image_paths) == 0:
            raise RuntimeError(f"No images found in {self.images_dir}")

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        image_path = self.image_paths[idx]
        mask_path = self.masks_dir / image_path.name

        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        image = np.array(image)
        mask = np.array(mask)

        # Map mask to {0, 1}
        mask = (mask > 0).astype(np.uint8)

        if self.transform is not None:
            # Albumentations-style callable: applies the same geometric
            # transforms to both image and mask and returns a dict
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        # Convert to tensors
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        mask = torch.from_numpy(mask).long()

        return image, mask

In practice, you can use libraries like Albumentations to conveniently apply the same geometric transforms to both images and masks. For simplicity, the example above assumes a generic callable transform that returns a dictionary with "image" and "mask" keys.


💥 Task 1 💥

Implement your own SegmentationDataset class based on the template above. Make sure that:

  - every image in images/ has a matching mask in masks/ with the same file name,
  - images are loaded as RGB and masks as single-channel label maps,
  - mask values are mapped to class indices {0, 1}.

Create training, validation and test datasets pointing to the corresponding folders. You will need to split the train folder into training and validation sets (e.g. 80% train, 20% val), as shown in the sketch below.
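
One possible way to do the split (a sketch, assuming the data/ layout above; note that with random_split both subsets share the same transform):

from torch.utils.data import random_split

full_train = SegmentationDataset("data/train/images", "data/train/masks")
n_val = int(0.2 * len(full_train))
n_train = len(full_train) - n_val

# Fixed generator seed so the split is reproducible across runs
train_set, val_set = random_split(
    full_train, [n_train, n_val], generator=torch.Generator().manual_seed(42)
)
test_set = SegmentationDataset("data/test/images", "data/test/masks")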


Visualizing images and masks

Before training the model, it is very important to verify that images and masks are loaded correctly.


💥 Task 2 💥

Implement the visualization function to display a few random samples from your training dataset. Check that:

  - images are displayed with correct colors and orientation,
  - masks align spatially with the objects in the images,
  - mask values are exactly {0, 1} (no other labels appear).

If something looks wrong, fix your dataset implementation before continuing.

def show_samples(dataset, num_samples: int = 4):
    indices = np.random.choice(len(dataset), size=num_samples, replace=False)

    fig, axes = plt.subplots(2, num_samples, figsize=(4 * num_samples, 6))
    ############# TODO: Student code #####################
    for i, idx in enumerate(indices):
        pass

    ######################################################

    plt.tight_layout()
    plt.show()

Baseline segmentation model

U‑Net‑like architecture

In this section you will implement a simple U‑Net‑like encoder–decoder architecture. Below is a minimal implementation sketch of a U‑Net‑style network for binary segmentation.

class DoubleConv(nn.Module):
    """(Conv2d => ReLU) * 2 block."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class UNetSmall(nn.Module):
    def __init__(self, in_channels: int = 3, num_classes: int = 2):
        super().__init__()
        self.down1 = DoubleConv(in_channels, 64)
        self.down2 = DoubleConv(64, 128)
        self.down3 = DoubleConv(128, 256)

        self.pool = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(256, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(128, 64)

        self.head = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        ############# TODO: Student code #####################
        # Input -> Encoder -> Bottleneck -> Decoder with skip connections -> Final classification layer -> Output

        ######################################################

💥 Task 3 💥

Implement your own small U‑Net‑like architecture using the above sketch. Ensure that:

  - the output has shape (N, num_classes, H, W) for an input of shape (N, 3, H, W),
  - encoder feature maps are concatenated with the matching decoder feature maps via the skip connections,
  - input spatial dimensions are divisible by 8 (three pooling stages), otherwise the skip-connection shapes will not match.

Test your model with a random input tensor and verify that shapes match.
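
A quick sanity check might look like this (a sketch; the input size 256 is an arbitrary choice divisible by 8):

model = UNetSmall(in_channels=3, num_classes=2)
x = torch.randn(1, 3, 256, 256)        # dummy batch of one RGB image
logits = model(x)
assert logits.shape == (1, 2, 256, 256), logits.shape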


Lightning module and training loop

To keep the training code clean and consistent with previous labs, we will wrap the segmentation model in a LightningModule.
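
The name LitSegmentationModel used in the tasks below is not defined in this handout; a minimal sketch of what it might look like (assuming the UNetSmall model above and pixel-wise cross-entropy) is shown here:

class LitSegmentationModel(L.LightningModule):
    """Minimal Lightning wrapper around a segmentation network (sketch)."""

    def __init__(self, num_classes: int = 2, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.model = UNetSmall(num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, masks = batch
        logits = self(images)
        loss = F.cross_entropy(logits, masks)   # pixel-wise cross-entropy
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, masks = batch
        logits = self(images)
        loss = F.cross_entropy(logits, masks)
        self.log("val_loss", loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)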


💥 Task 4 💥

Create a training script (e.g. lab05.py) that:

  1. Creates dataloaders.
  2. Instantiates LitSegmentationModel.
  3. Creates a Trainer and runs training and validation: trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader).
  4. Evaluates the model on the test set using trainer.test.

Monitor training and validation loss. Verify that the model is able to overfit a very small subset of the training data (e.g. 10 images) as a sanity check.
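
Two convenient ways to run this sanity check (a sketch; Lightning's overfit_batches option repeatedly trains on a fixed number of batches):

# Option 1: let Lightning overfit a fixed small number of training batches
trainer = L.Trainer(max_epochs=200, overfit_batches=2)
trainer.fit(model, train_dataloaders=train_loader)

# Option 2: build a tiny dataset by hand (here: the first 10 training images)
from torch.utils.data import Subset
tiny_set = Subset(train_set, list(range(10)))
tiny_loader = DataLoader(tiny_set, batch_size=2, shuffle=True)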


Dice and IoU metrics

Dice score

Below is a simple implementation of the Dice score for binary segmentation.

def dice_score(preds: torch.Tensor, targets: torch.Tensor, eps: float = 1e-6) -> float:
    """Computes Dice score for binary segmentation.

    Parameters
    ----------
    preds : torch.Tensor
        Logits or probabilities with shape (N, 2, H, W) or predicted labels with shape (N, H, W).
    targets : torch.Tensor
        Ground truth labels with shape (N, H, W) where values are 0 or 1.
    """
    if preds.dim() == 4:
        preds = torch.argmax(preds, dim=1)

    preds = preds.float()
    targets = targets.float()

    intersection = (preds * targets).sum()
    denom = preds.sum() + targets.sum()  # |P| + |G|, the Dice denominator (not a set union)
    dice = (2.0 * intersection + eps) / (denom + eps)
    return dice.item()

IoU

def iou_score(preds: torch.Tensor, targets: torch.Tensor, eps: float = 1e-6) -> float:
    if preds.dim() == 4:
        preds = torch.argmax(preds, dim=1)

    preds = preds.float()
    targets = targets.float()

    intersection = (preds * targets).sum()
    union = preds.sum() + targets.sum() - intersection
    iou = (intersection + eps) / (union + eps)
    return iou.item()

💥 Task 5 💥

Integrate Dice score and IoU metrics into your LitSegmentationModel:

  - compute both metrics in the validation and test steps,
  - log them with self.log so they appear alongside the loss curves (see the sketch below).

Compare the behavior of Dice and IoU during training. Which one is more stable on your dataset?
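
A possible extension of the validation step from the earlier LitSegmentationModel sketch (this method goes inside the class):

    def validation_step(self, batch, batch_idx):
        images, masks = batch
        logits = self(images)
        loss = F.cross_entropy(logits, masks)

        # Metrics computed on hard predictions (argmax happens inside dice_score / iou_score)
        dice = dice_score(logits, masks)
        iou = iou_score(logits, masks)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_dice", dice, prog_bar=True)
        self.log("val_iou", iou)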


Visualizing predictions

Numerical metrics are important, but for segmentation it is critical to look at the predicted masks.

@torch.no_grad()
def visualize_predictions(model, dataset, num_samples: int = 4,
                          device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    model.eval().to(device)

    indices = np.random.choice(len(dataset), size=num_samples, replace=False)
    fig, axes = plt.subplots(3, num_samples, figsize=(4 * num_samples, 8))

    for i, idx in enumerate(indices):
        image, mask = dataset[idx]
        image = image.to(device).unsqueeze(0)

        logits = model(image)
        preds = torch.argmax(logits, dim=1).squeeze(0).cpu()

        image_np = image.squeeze(0).cpu().permute(1, 2, 0).numpy()
        mask_np = mask.numpy()
        pred_np = preds.numpy()

        axes[0, i].imshow(image_np)
        axes[0, i].set_title("Image")
        axes[0, i].axis("off")

        axes[1, i].imshow(mask_np, cmap="gray")
        axes[1, i].set_title("Ground truth")
        axes[1, i].axis("off")

        axes[2, i].imshow(pred_np, cmap="gray")
        axes[2, i].set_title("Prediction")
        axes[2, i].axis("off")

    plt.tight_layout()
    plt.show()

💥 Task 6 💥

Use the visualization function above to inspect predictions on the validation or test set. Answer the following questions:

  - Where does the model make mistakes: mostly at object boundaries, on small objects, or over entire regions?
  - Are the errors systematic (e.g. foreground consistently confused with background in certain conditions)?
  - Do the numerical metrics (Dice, IoU) agree with your visual impression of the quality?


Experiment tracking

As in previous labs, you can use an experiment tracking tool (e.g. MLflow) to log:

  - training and validation losses,
  - Dice and IoU metrics,
  - hyperparameters (learning rate, batch size, augmentations),
  - example prediction visualizations.


💥 Task 7 💥

Integrate experiment tracking into your training script to facilitate comparison of different runs (with/without augmentations, different architectures, etc.).


Data augmentations for segmentation

In Lab 2 you worked with various image augmentations for classification (random crops, flips, color jitter, etc.). For segmentation, we must be more careful:

  - geometric transforms (flips, crops, rotations) must be applied identically to the image and its mask,
  - photometric transforms (color jitter, blur, normalization) should be applied to the image only,
  - masks must be resized with nearest‑neighbor interpolation so that class indices are not corrupted by interpolation.

Libraries like Albumentations provide convenient primitives for such joint transformations:

import albumentations as A
from albumentations.pytorch import ToTensorV2

train_transform = A.Compose([
    A.VerticalFlip(p=0.5),
    A.RandomResizedCrop(height=256, width=256),  # note: newer Albumentations versions use size=(256, 256)
    A.ColorJitter(p=0.5),
    A.GaussianBlur(p=0.2),
    A.Normalize(),
    ToTensorV2(),
])

You can then pass this transform to SegmentationDataset and use the "image" / "mask" entries produced by Albumentations directly. Note that with ToTensorV2 in the pipeline, the transform already returns tensors, so the manual NumPy‑to‑tensor conversion in __getitem__ should be skipped in that case.


💥 Task 8 💥

Add geometric augmentations to your training dataset. At minimum:

  - random horizontal and/or vertical flips,
  - random crops or resized crops,

applied jointly to images and masks.

Train two models:

  1. without augmentations.
  2. with augmentations.

Compare Dice and IoU on the validation set and discuss the impact of augmentations on generalization.


Advanced usage

Transfer learning for segmentation

Instead of training the encoder from scratch, you can reuse a pretrained classification backbone (e.g. ResNet‑18) and attach a decoder on top.
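
A possible starting point (a sketch, assuming torchvision's resnet18; the stage grouping below is one of several reasonable choices):

import torchvision

class ResNetEncoder(nn.Module):
    """Wraps a pretrained ResNet-18 and exposes intermediate feature maps
    that can be fed to a U-Net-style decoder via skip connections."""

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet18(weights="IMAGENET1K_V1")
        self.stem = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)  # /2, 64 ch
        self.pool = backbone.maxpool                                            # /4
        self.layer1 = backbone.layer1   # /4,  64 channels
        self.layer2 = backbone.layer2   # /8, 128 channels
        self.layer3 = backbone.layer3   # /16, 256 channels
        self.layer4 = backbone.layer4   # /32, 512 channels

    def forward(self, x):
        x0 = self.stem(x)
        x1 = self.layer1(self.pool(x0))
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        return [x0, x1, x2, x3, x4]   # skip-connection candidates, coarsest last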


💥 Task 9 💥

Replace the encoder of your U‑Net with a pretrained backbone (e.g. ResNet‑18) and attach a decoder with skip connections on top. Compare training speed and validation Dice/IoU against your model trained from scratch.


Ready-to-use segmentation models

You can also experiment with established segmentation architectures from libraries like Segmentation Models PyTorch (SMP).

pip3 install segmentation-models-pytorch

💥 Task 10 💥

Use SMP to create a segmentation model (e.g. U‑Net with ResNet‑34 backbone) and train it on your dataset. Compare results with your custom implementation.
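
Creating such a model might look like this (a sketch; the encoder weights and class count below are example choices):

import segmentation_models_pytorch as smp

# U-Net with a ResNet-34 encoder pretrained on ImageNet
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=2,           # background + foreground
)

logits = model(torch.randn(1, 3, 256, 256))   # -> shape (1, 2, 256, 256)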


Multi‑class segmentation

If your dataset contains more than two classes, you can extend the model and metrics to multi‑class segmentation. To do this you should:

  - set num_classes to the number of classes in the dataset,
  - keep CrossEntropyLoss with integer label maps (values 0 … num_classes − 1),
  - compute Dice and IoU per class and average them (e.g. mean IoU).


💥 Task 11 💥

Find and download a segmentation dataset with more than two classes, for example from Kaggle Datasets, and implement multi‑class segmentation.