diff --git a/recognition/unet_s4741911/README.md b/recognition/unet_s4741911/README.md new file mode 100644 index 000000000..735aa7e32 --- /dev/null +++ b/recognition/unet_s4741911/README.md @@ -0,0 +1,120 @@ +# Prostate Segmentation on HipMRI Using 2D U-Net + +By Shannon Dela Pena, 47419111 + +## Description of the U-Net Algorithm + +This project focuses on prostate segmentation from medical MRI scans using a 2D U-Net deep learning architecture. + +The U-Net algorithm is a convolutional neural network (CNN) designed specifically for biomedical image segmentation. [^1] It consists of a contracting path that captures contextual and semantic information through downsampling, and an expanding path that restores spatial resolution through upsampling and skip connections. This structure allows for pixel-level classification of each region, making U-Net highly effective in medical image segmentation, as it can adapt to variations in organ size, shape, and position across different patients. + +In this project, the same core architecture is adapted and optimised for 2D prostate MRI slices from the HipMRI dataset. + +## Objective of Task + +The goal of this project is to automatically segment the prostate gland from MRI scans, targeting a Dice Similarity Coefficient (DSC) of >= 0.75, indicating a high degree of overlap between predicted and ground-truth segmentations. + +## Methodology + +The MRI slice data consisted of images with a resolution of 256 x 128 pixels (H x W). Each corresponding segmentation mask contained six distinct labels: + +- 0 = Background +- 1 = Body +- 2 = Bones +- 3 = Bladder +- 4 = Rectum +- 5 = Prostate + +**Data Preparation and Augmentation** + +To improve generalisation and prevent overfitting, spatial augmentations such as random rotations and horizontal flips were applied to the MRI slices during training. 
A study on _Data Augmentations for Prostate Cancer Detection_ [^2] found that small rotational perturbations reflect real-world positional variability while preserving anatomical realism, helping the model generalise better. Horizontal and vertical flips further enhanced robustness by teaching the model that mirrored structures can represent the same anatomical class. + +**The U-Net Architecture** + +The encoder path uses repeated `DoubleConv` blocks (two 3x3 convolutions, batch normalisation, and ReLU) followed by max-pooling for downsampling. The original U-Net used feature sizes of 64, 128, 256, and 512, but these were reduced to 32, 64, 128, and 256 to better accommodate noisy MRI data. + +A `DoubleConv` block in the bottleneck expands the feature maps from 256 to 512 channels, allowing the model to capture broader information such as relative positioning and overall structure of organs. + +The decoder mirrors the downsampling process with transposed convolutions for upsampling and skip connections from the encoder. This reconstruction restores spatial information lost during pooling, producing a multi-class segmentation map that is trained with the combined Focal and Dice loss. + +**Loss Functions** + +A combination of Focal and Dice Loss was used to address class imbalance and improve segmentation of smaller structures. Focal Loss, implemented as an extension of Cross-Entropy Loss, down-weights well-classified pixels while emphasising harder, misclassified ones. This is particularly relevant for the HipMRI dataset, which is highly imbalanced, as most slices lack a visible prostate region. Dice Loss complements this by focusing on region-level overlap and improving boundary precision between different organs. + +To compensate for the underrepresentation of smaller regions such as the rectum and prostate, per-class weights were applied to both the focal and dice components. 
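This weighting idea can be sketched in NumPy (a minimal illustration, not the tuned implementation used in training; the weight values and toy arrays here are hypothetical):

```python
import numpy as np

def weighted_dice_loss(probs, one_hot, weights, smooth=1e-5):
    """Class-weighted Dice loss; probs and one_hot have shape [B, C, N]."""
    intersection = (probs * one_hot).sum(-1)
    total = probs.sum(-1) + one_hot.sum(-1)
    dice = (2 * intersection + smooth) / (total + smooth)  # Dice per class, shape [B, C]
    per_class = 1 - dice.mean(axis=0)                      # average over the batch
    return (per_class * weights).sum() / weights.sum()     # weighted mean over classes

# Toy example: a dominant class (weight 1) and a rare class (weight 10)
weights = np.array([1.0, 10.0])
probs = np.array([[[0.9, 0.8], [0.1, 0.2]]])    # B=1, C=2, N=2 predicted probabilities
one_hot = np.array([[[1.0, 1.0], [0.0, 0.0]]])  # ground truth: only class 0 present
loss = weighted_dice_loss(probs, one_hot, weights)
```

Because the rare class carries ten times the weight, its poor overlap dominates the loss value, steering optimisation toward the small structures.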
This assigned higher importance to smaller structures while reducing the influence of dominant classes like the background and body. + +**Metrics** + +The Dice coefficient was computed per class after each epoch to evaluate overlap accuracy. + +## Results + +### Train + +![image](visualisation_train/dice_per_class_plot.png) + +Over 100 epochs, the model achieved consistent convergence across all classes, with Dice coefficients above 0.9 for major regions and stabilising around 0.75 for smaller regions. The slower convergence of classes 4 and 5 highlights the importance of loss weighting and the combined Focal-Dice loss to address class imbalance. The plateauing of Dice scores across all classes indicates that training reached stability and consistent performance. + +![image](visualisation/loss_plot_combined_FL_DL.png) + +The plot shows a steep decline in both training and validation losses during the early epochs, indicating that the model quickly learned to distinguish between major tissue regions. After this phase, both curves gradually flatten between epochs 15 and 50. The training loss continues to decrease steadily to approximately 0.45, while the validation loss stabilises around 0.65. The consistent gap between them suggests mild overfitting, likely due to the model's ability to memorise larger structures more effectively than less frequent, smaller regions. + +Example predictions at different epochs can be viewed in the `visualisation_train` folder. + +### Test + +The trained U-Net model was evaluated on a test set of 540 MRI slices. The model achieved a mean Dice coefficient of 0.9137, indicating strong overall segmentation across all regions. 
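For reference, the Dice Similarity Coefficient reported here reduces, for a single binary mask, to 2|A∩B| / (|A| + |B|); a minimal sketch with hypothetical masks:

```python
import numpy as np

def dice_coefficient(pred, target, eps=1e-6):
    """Binary Dice: 2 * |intersection| / (|pred| + |target|)."""
    intersection = np.logical_and(pred, target).sum()
    return (2.0 * intersection) / (pred.sum() + target.sum() + eps)

pred = np.array([[1, 1, 0], [0, 1, 0]])    # 3 predicted foreground pixels
target = np.array([[1, 1, 0], [0, 0, 0]])  # 2 ground-truth foreground pixels
score = dice_coefficient(pred, target)     # 2*2 / (3 + 2) = 0.8
```

The multi-class score quoted above is the mean of this quantity over the six classes.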
Below are class-wise Dice scores: + +![image](visualisation_test/dice_per_class_test.png) + +``` +Dice per class: {'Background': 0.969, 'Body': 0.968, 'Bones': 0.854, 'Bladder': 0.91, 'Rectum': 0.913, 'Prostate': 0.868} +Mean Dice coefficient: 0.9137 +Dice loss (1 - mean Dice): 0.0863 +``` + +Notably, the prostate region, despite being the smallest and least frequent, achieved a Dice score of 0.868. + +Visualisations of test predictions are provided in the `visualisation_test` folder. + +## Dependencies + +The project was conducted on the Rangpur A100 GPU cluster, using a fixed random seed of 7 for reproducibility. The environment was built with the following dependencies: + +``` +torch==2.7.1+cu118 +torchvision==0.22.1+cu118 +torchaudio==2.7.1+cu118 +numpy==2.1.2 +matplotlib==3.10.6 +nibabel==5.3.2 +scikit-image==0.25.2 +tqdm==4.67.1 +pillow==11.0.0 +scipy==1.16.2 +``` + +## For Future Implementations + +Limitations + +- The model was trained using only 2D slices, losing some anatomical context + +Future work + +- Extend the approach to 3D volumetric segmentation +- Incorporate pre-trained encoders for faster convergence +- Apply advanced data augmentation techniques, such as elastic deformations, as demonstrated in the original U-Net paper, to improve generalisation +- Apply early stopping: the plateau after approximately 60 epochs suggests training could be shortened without loss of performance + +## References + +1. Aladdin Persson. PyTorch Image Segmentation Tutorial with U-NET [Video]. YouTube. https://www.youtube.com/watch?v=IHq1t7NxS8k + - Provided a practical walkthrough of U-Net architecture implementation and training in PyTorch +2. DigitalSreeni. 208 - Multiclass semantic segmentation using U-Net [Video]. YouTube. https://www.youtube.com/watch?v=XyX5HNuv-xE + - Provided a practical walkthrough of U-Net architecture implementation with multi-class segmentation in PyTorch + +### Footnotes + +[^1]: Ronneberger, O., Fischer, P., & Brox, T. (2015). 
_U-Net: Convolutional Networks for Biomedical Image Segmentation_. [https://arxiv.org/abs/1505.04597](https://arxiv.org/abs/1505.04597) +[^2]: Hao, R., Namdar, K., Liu, L., Haider, M. A., & Khalvati, F. (2021). A comprehensive study of data augmentation strategies for prostate cancer detection in diffusion-weighted MRI using convolutional neural networks. Journal of Digital Imaging, 34(4), 862–876. https://doi.org/10.1007/s10278-021-00478-7 diff --git a/recognition/unet_s4741911/dataset.py b/recognition/unet_s4741911/dataset.py new file mode 100644 index 000000000..7dfd18a59 --- /dev/null +++ b/recognition/unet_s4741911/dataset.py @@ -0,0 +1,95 @@ +""" +dataset.py + +Defines dataset class for loading and preprocessing 2D Hip MRI prostate Nifti slices. +It supports optional transformations, normalization, categorical conversion, and resizing. +""" + +import torch +from torch.utils.data import Dataset +from utils import load_data_2D +import os +from torchvision.tv_tensors import Image, Mask + + +class HipMRIProstateDataset(Dataset): + """ + Dataset for Hip MRI Prostate 2D Nifti slices. + + Handles loading, normalisation, categorical conversion, and resizing. 
+ """ + + def __init__( + self, + data_dir="/home/groups/comp3710/HipMRI_Study_open/keras_slices_data", + split="train", + transform=None, + subset_size=None, + normImage=True, + categorical=True, + resize_to=(256, 128), + ): + self.data_dir = data_dir + self.split = split + self.transform = transform + self.normImage = normImage + self.categorical = categorical + self.resize_to = resize_to + + # Image and segmentation directories + img_dir = os.path.join(data_dir, f"keras_slices_{split}") + seg_dir = os.path.join(data_dir, f"keras_slices_seg_{split}") + + # File names for images and segments + self.image_files = sorted( + [os.path.join(img_dir, f) for f in os.listdir(img_dir)] + ) + self.seg_files = sorted([os.path.join(seg_dir, f) for f in os.listdir(seg_dir)]) + + total_size = len(self.image_files) + + # Handle subset size before loading + if subset_size is not None and subset_size < total_size: + self.subset_size = subset_size + self.image_files = self.image_files[:subset_size] + self.seg_files = self.seg_files[:subset_size] + print( + f"Using subset of {subset_size} samples from {split} split (out of {total_size})." + ) + else: + self.subset_size = total_size + print(f"Using all {total_size} samples from {split} split.") + + # Load resized, normalised arrays + print(f"Loading {self.subset_size} samples from {split} split...") + self.images = load_data_2D( + self.image_files, + normImage=self.normImage, + categorical=False, + resize_to=self.resize_to, + ) + self.seg = load_data_2D( + self.seg_files, + normImage=False, + categorical=self.categorical, + resize_to=self.resize_to, + ) + + print( + f"Finished loading {len(self.images)} images and {len(self.seg)} segmentations." 
+ ) + + def __len__(self): + return self.subset_size + + def __getitem__(self, idx): + image = self.images[idx] + seg = self.seg[idx] + + if self.transform: + image = Image(torch.tensor(image, dtype=torch.float32)) + seg = Mask(torch.tensor(seg, dtype=torch.int64)) + image, seg = self.transform(image, seg) + + if self.categorical and seg.ndim == 3: + seg = seg.permute(2, 0, 1) # Change seg shape to (C, H, W) + + return image, seg diff --git a/recognition/unet_s4741911/modules.py b/recognition/unet_s4741911/modules.py new file mode 100644 index 000000000..074de4953 --- /dev/null +++ b/recognition/unet_s4741911/modules.py @@ -0,0 +1,93 @@ +""" +modules.py + +Contains core model architecture for UNET-based segmentation of 2D Hip MRI prostate images. +Uses Double Convolution blocks, Encoder-Decoder structure with skip connections. +""" + +import torch +import torch.nn as nn +import torchvision.transforms.functional as TF + + +# Double Convolution Block +class DoubleConv(nn.Module): + def __init__(self, in_channels, out_channels, dropout_prob=0.2): + super(DoubleConv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, # Bias is not needed with BatchNorm, as it will be cancelled out + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Dropout2d(p=dropout_prob), + ) + + def forward(self, x): + return self.conv(x) + + +class UNET(nn.Module): + def __init__( + self, in_channels=1, out_channels=6, features=[32, 64, 128, 256] + ): # Reduced features compared to original UNET + super(UNET, self).__init__() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + + # Encoder part of UNET + for feature in features: + 
self.downs.append(DoubleConv(in_channels, feature)) + in_channels = feature + + # Decoder part of UNET + for feature in reversed(features): + self.ups.append( + nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2) + ) + self.ups.append(DoubleConv(feature * 2, feature)) + + self.bottleneck = DoubleConv(features[-1], features[-1] * 2) + self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) + + def forward(self, x): + skip_connections = [] + + for down in self.downs: + x = down(x) + skip_connections.append(x) + x = self.pool(x) + + x = self.bottleneck(x) + skip_connections = skip_connections[::-1] + + for idx in range(0, len(self.ups), 2): + x = self.ups[idx](x) + skip_connection = skip_connections[idx // 2] + + if x.shape != skip_connection.shape: + # Resize x to match skip_connection size + x = TF.resize(x, size=skip_connection.shape[2:]) + + # Concatenate along channel dimension + concat_skip = torch.cat((skip_connection, x), dim=1) + x = self.ups[idx + 1](concat_skip) + + x = self.final_conv(x) + return x # Raw logits output diff --git a/recognition/unet_s4741911/predict.py b/recognition/unet_s4741911/predict.py new file mode 100644 index 000000000..c363e452d --- /dev/null +++ b/recognition/unet_s4741911/predict.py @@ -0,0 +1,130 @@ +import os +import random +import torch +import numpy as np +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +from dataset import HipMRIProstateDataset +from modules import UNET +from utils import load_model_checkpoint +from train import DiceLoss +import torchvision.transforms.v2 as v2 + +# Configuration +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +BATCH_SIZE = 16 +NUM_CLASSES = 6 +IMAGE_HEIGHT = 256 +IMAGE_WIDTH = 128 +CHECKPOINT_PATH = "checkpoint.pth.tar" + +# Class names for visualisation +CLASS_NAMES = ["Background", "Body", "Bones", "Bladder", "Rectum", "Prostate"] + + +# Set random seeds for reproducibility +torch.manual_seed(55) +np.random.seed(55) 
+random.seed(55) +if torch.cuda.is_available(): + torch.cuda.manual_seed(55) + + +def evaluate(model, dataloader, device): + """Evaluate model on test set using Dice coefficient.""" + model.eval() + dice_fn = DiceLoss(num_classes=NUM_CLASSES) + dice_scores = [] + + with torch.no_grad(): + for images, masks in dataloader: + images, masks = images.to(device), masks.to(device) + outputs = model(images) + dice_fn(outputs, masks) + dice_per_class = dice_fn.get_last_dice_coeff() + dice_scores.append(dice_per_class) + + dice_scores = np.mean(np.vstack(dice_scores), axis=0) + mean_dice = np.mean(dice_scores) + return dice_scores, mean_dice + + +def visualize_predictions(model, dataset, save_dir="predictions", num_samples=5): + """Save example predictions as image triplets: input, ground truth, prediction.""" + os.makedirs(save_dir, exist_ok=True) + model.eval() + + indices = random.sample(range(len(dataset)), num_samples) + + with torch.no_grad(): + for idx in indices: + image, mask = dataset[idx] + input_tensor = image.unsqueeze(0).to(DEVICE) + pred = torch.argmax(model(input_tensor), dim=1).squeeze(0).cpu().numpy() + + fig, axes = plt.subplots(1, 3, figsize=(10, 4)) + axes[0].imshow(image.squeeze(), cmap="gray") + axes[0].set_title("Input Image") + axes[1].imshow(mask.cpu().numpy(), cmap="jet") + axes[1].set_title("Ground Truth") + axes[2].imshow(pred, cmap="jet") + axes[2].set_title("Prediction") + + for ax in axes: + ax.axis("off") + + plt.tight_layout() + plt.savefig(os.path.join(save_dir, f"sample_{idx}.png")) + plt.close() + + +def main(): + # Define transform for test data + test_transform = v2.Compose( + [ + v2.ToDtype(torch.float32, scale=True), + ] + ) + + # Load test dataset + test_dataset = HipMRIProstateDataset( + data_dir="/home/groups/comp3710/HipMRI_Study_open/keras_slices_data", + split="test", + transform=test_transform, + categorical=False, + normImage=True, + resize_to=(IMAGE_HEIGHT, IMAGE_WIDTH), + ) + + test_loader = DataLoader(test_dataset, 
batch_size=BATCH_SIZE, shuffle=False) + + # Load model and checkpoint + model = UNET(in_channels=1, out_channels=NUM_CLASSES).to(DEVICE) + load_model_checkpoint(torch.load(CHECKPOINT_PATH), model) + print("Loaded model checkpoint successfully.") + + # Evaluate model + dice_scores, mean_dice = evaluate(model, test_loader, DEVICE) + print(f"\nDice per class: {dict(zip(CLASS_NAMES, dice_scores.round(3)))}") + print(f"Mean Dice coefficient: {mean_dice:.4f}") + print(f"Dice loss (1 - mean Dice): {1 - mean_dice:.4f}") + + # Plot Dice per class + plt.figure(figsize=(10, 6)) + plt.bar(CLASS_NAMES, dice_scores, color="steelblue") + plt.ylabel("Dice Coefficient") + plt.title("Dice Coefficient per Class (Test Set)") + plt.ylim(0, 1) + plt.tight_layout() + plt.savefig("dice_per_class_test.png") + plt.close() + + # Save qualitative visualisations + visualize_predictions( + model, test_dataset, save_dir="visualisation_test", num_samples=6 + ) + print("Saved prediction visualisations to 'visualisation_test/'.") + + +if __name__ == "__main__": + main() diff --git a/recognition/unet_s4741911/train.py b/recognition/unet_s4741911/train.py new file mode 100644 index 000000000..c90f71fff --- /dev/null +++ b/recognition/unet_s4741911/train.py @@ -0,0 +1,374 @@ +""" +train.py +Training script for UNET-based segmentation of 2D Hip MRI prostate images. +Includes data loading, model training loop, validation, checkpointing, and visualization. + +Key Features: +- Uses Focal + Dice Loss for handling class imbalance. +- train() function encapsulates the training process. +- main() function sets up datasets, model, optimizer, and starts training. 
+""" + +import random +import os +import torch +from torch.utils.data import DataLoader +from dataset import HipMRIProstateDataset +import torchvision.transforms.v2 as v2 +from utils import ( + plot_loss, + show_epoch_predictions, + save_model_checkpoint, + load_model_checkpoint, + plot_dice_per_class, +) +import torch.optim as optim +from modules import UNET +import numpy as np + + +# Hyperparameters +LEARNING_RATE = 1e-4 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +SUBSET_SIZE_TRAIN = None # Use None to load full dataset +SUBSET_SIZE_VAL = None # Use None to load full dataset +BATCH_SIZE_TRAIN = 32 +BATCH_SIZE_VAL = 64 +NUM_EPOCHS = 20 # 100 +NUM_WORKERS = 1 +IMAGE_HEIGHT = 256 +IMAGE_WIDTH = 128 +NUM_CLASSES = 6 + +# Set random seeds for reproducibility +torch.manual_seed(55) +np.random.seed(55) +random.seed(55) +if torch.cuda.is_available(): + torch.cuda.manual_seed(55) + + +class DiceLoss(torch.nn.Module): + """ + Dice loss implementation for multi-class segmentation. + Computes Dice coefficient per class and averages with optional class weights. + + Args: + smooth (float): Smoothing factor to avoid division by zero. + num_classes (int): Number of segmentation classes. + class_weights (torch.Tensor, optional): Weights for each class to handle imbalance. + + Returns: + torch.Tensor: scalarDice loss value. 
+ """ + + def __init__(self, smooth=1e-5, num_classes=6, class_weights=None): + super(DiceLoss, self).__init__() + self.smooth = smooth + self.num_classes = num_classes + self.last_dice_coeff = None + self.class_weights = class_weights + + def forward(self, inputs, targets): + # Apply softmax for probabilities + inputs = torch.softmax(inputs, dim=1) + + # One-hot encode targets + targets_one_hot = ( + torch.nn.functional.one_hot(targets, num_classes=self.num_classes) + .permute(0, 3, 1, 2) + .float() + ) + + # Flatten for spatial dimensions + inputs = inputs.view(inputs.size(0), inputs.size(1), -1) + targets_one_hot = targets_one_hot.view( + targets_one_hot.size(0), self.num_classes, -1 + ) + + # Compute intersection and union for Dice coefficient + intersection = (inputs * targets_one_hot).sum(-1) + total = inputs.sum(-1) + targets_one_hot.sum(-1) + dice_per_class = (2.0 * intersection + self.smooth) / (total + self.smooth) + + # Mask out classes not present in ground truth + mask = (targets_one_hot.sum(-1) > 0).float() + dice_per_class = (dice_per_class * mask).sum(0) / mask.sum(0).clamp(min=1.0) + + self.last_dice_coeff = dice_per_class.detach().cpu().numpy() + + # Apply class weights if provided + if self.class_weights is not None: + dice_loss = (1 - dice_per_class) * self.class_weights + dice_loss = dice_loss.sum() / self.class_weights.sum() + else: + dice_loss = 1 - dice_per_class.mean() + + return dice_loss + + def get_last_dice_coeff(self): + return self.last_dice_coeff + + +class FocalLoss(torch.nn.Module): + """ + Focal Loss implementation for multi-class segmentation. + Focuses training on hard-to-classify examples. + + Args: + weight (torch.Tensor, optional): Weights for each class. + gamma (float): Focusing parameter. + + Returns: + torch.Tensor: scalar Focal loss value. 
+ """ + + def __init__(self, weight=None, gamma=2.0): + super(FocalLoss, self).__init__() + self.weight = weight + self.gamma = gamma + self.ce = torch.nn.CrossEntropyLoss(weight=weight, reduction="none") + + def forward(self, inputs, targets): + # Compute per-pixel cross entropy loss + ce_loss = self.ce(inputs, targets) # [B, H, W] + pt = torch.exp(-ce_loss) # model confidence for the true class + focal_loss = ((1 - pt) ** self.gamma) * ce_loss + return focal_loss.mean() + + +class FocalDiceLoss(torch.nn.Module): + """ + Combined Focal and Dice Loss for multi-class segmentation. + + Args: + weight (torch.Tensor, optional): Weights for each class. + dice_factor (float): Weighting factor for Dice loss component. + focal_factor (float): Weighting factor for Focal loss component. + num_classes (int): Number of segmentation classes. + + Returns: + torch.Tensor: scalar combined loss value. + """ + + def __init__(self, weight=None, dice_factor=3.0, focal_factor=1.0, num_classes=6): + super(FocalDiceLoss, self).__init__() + self.focal = FocalLoss(weight=weight) + self.dice = DiceLoss(num_classes=num_classes, class_weights=weight) + self.dice_factor = dice_factor + self.focal_factor = focal_factor + self.last_dice_coeff = None + + def forward(self, inputs, targets): + # Compute focal and dice losses + focal = self.focal(inputs, targets) + dice = self.dice(inputs, targets) + self.last_dice_coeff = self.dice.get_last_dice_coeff() + # Returned weighted sum of losses + return self.focal_factor * focal + self.dice_factor * dice + + def get_last_dice_coeff(self): + return self.last_dice_coeff + + +def train( + model, + train_loader, + validation_loader, + visualize_every=20, + criterion=None, + optimizer=None, +): + """ + Main training loop for UNET model. 
+ """ + train_losses = [] + val_losses = [] + dice_scores_all_epochs = [] + + scaler = torch.amp.GradScaler() if DEVICE == "cuda" else None + + # Fixed validation samples for visualisation across epochs + val_dataset = validation_loader.dataset + _vis_n = 3 # number of samples to visualise + _vis_indices = random.sample(range(len(val_dataset)), _vis_n) + + print( + "Starting training with Batch Norm, ReLU, Dropout, using Focal and Dice Loss..." + ) + + # Load from checkpoint if available + checkpoint_path = "checkpoint.pth.tar" + if os.path.exists(checkpoint_path): + try: + load_model_checkpoint(torch.load(checkpoint_path), model) + except Exception as e: + print(f"Could not load checkpoint ({e}). Starting from scratch.") + else: + print("No checkpoint found. Starting from scratch.") + + for epoch in range(NUM_EPOCHS): + model.train() + epoch_loss = 0.0 + for images, masks in train_loader: + images, masks = images.to(DEVICE), masks.to(DEVICE) + + optimizer.zero_grad() + + with torch.amp.autocast(enabled=(scaler is not None), device_type=DEVICE): + outputs = model(images) # [B, C, H, W] + loss = criterion(outputs, masks) + + # Backward pass and optimization + if scaler is not None: + scaler.scale(loss).backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + + epoch_loss += loss.item() + + avg_train_loss = epoch_loss / len(train_loader) + train_losses.append(avg_train_loss) + + # Validate + model.eval() + with torch.no_grad(): + val_loss = 0.0 + dice_scores = [] + + for idx, (val_images, val_masks) in enumerate(validation_loader): + val_images, val_masks = val_images.to(DEVICE), val_masks.to(DEVICE) + + val_outputs = model(val_images) + loss = criterion(val_outputs, val_masks) + val_loss += loss.item() + + # Collect per-class Dice scores + if hasattr(criterion, "get_last_dice_coeff"): + 
dice_per_class = criterion.get_last_dice_coeff() + if dice_per_class is not None: + dice_scores.append(dice_per_class) + + # Compute average Dice per class across validation batches + if dice_scores: + dice_scores = np.mean(np.vstack(dice_scores), axis=0) + dice_scores_all_epochs.append(dice_scores) + print(f"Dice per class at epoch {epoch + 1}: {dice_scores}") + + avg_val_loss = val_loss / len(validation_loader) + val_losses.append(avg_val_loss) + + print( + f"Epoch [{epoch + 1}/{NUM_EPOCHS}], " + f"Train Loss: {avg_train_loss:.4f}, " + f"Validation Loss: {avg_val_loss:.4f}" + ) + + # Save checkpoint if best validation loss improves + if epoch == 0 or avg_val_loss < min(val_losses[:-1]): + checkpoint = { + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + } + save_model_checkpoint(checkpoint) + print( + f"Saved new best model checkpoint at epoch {epoch + 1} (val loss: {avg_val_loss:.4f})" + ) + + # Visualize predictions + if (epoch + 1) % visualize_every == 0: + show_epoch_predictions(model, val_dataset, epoch + 1, indices=_vis_indices) + + print("Training complete with U-NET architecture.") + + plot_dice_per_class(dice_scores_all_epochs) + + return train_losses, val_losses + + +def main(): + + train_transform = v2.Compose( + [ + v2.RandomHorizontalFlip(p=0.5), + v2.RandomVerticalFlip(p=0.5), + v2.RandomRotation(degrees=(-10, 10)), + v2.ToDtype(torch.float32, scale=True), + ] + ) + + val_transform = v2.Compose( + [ + v2.ToDtype(torch.float32, scale=True), + ] + ) + + train_dataset = HipMRIProstateDataset( + data_dir="/home/groups/comp3710/HipMRI_Study_open/keras_slices_data", + split="train", + transform=train_transform, + subset_size=SUBSET_SIZE_TRAIN, + categorical=False, # (256, 128) masks + normImage=True, + resize_to=(IMAGE_HEIGHT, IMAGE_WIDTH), + ) + # Shape checking + sample_image, sample_mask = train_dataset[0] + print(f"Sample image shape: {sample_image.shape}") # ([1, 256, 128]) + print(f"Sample mask shape: {sample_mask.shape}") # ([256, 
128]) + + train_loader = DataLoader( + train_dataset, + batch_size=BATCH_SIZE_TRAIN, + shuffle=True, + num_workers=NUM_WORKERS, + pin_memory=True, + ) + + validate_dataset = HipMRIProstateDataset( + data_dir="/home/groups/comp3710/HipMRI_Study_open/keras_slices_data", + split="validate", + transform=val_transform, + subset_size=SUBSET_SIZE_VAL, + normImage=True, + categorical=False, + resize_to=(IMAGE_HEIGHT, IMAGE_WIDTH), + ) + validate_loader = DataLoader( + validate_dataset, + batch_size=BATCH_SIZE_VAL, + shuffle=False, + num_workers=NUM_WORKERS, + pin_memory=True, + ) + + # Per-class weights for the Focal and Dice loss components (rarer classes weighted higher) + class_weights = torch.tensor([1, 1, 1, 3, 9, 10], dtype=torch.float).to(DEVICE) + + print(f"Class weights: {class_weights}") + + model = UNET(in_channels=1, out_channels=NUM_CLASSES).to(DEVICE) + criterion = FocalDiceLoss( + weight=class_weights, dice_factor=3.0, focal_factor=1.0, num_classes=NUM_CLASSES + ) + + optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE) + + train_losses, val_losses = train( + model, + train_loader, + validate_loader, + visualize_every=10, + criterion=criterion, + optimizer=optimizer, + ) + plot_loss(train_losses, val_losses, loss_type="combined") + + +if __name__ == "__main__": + main() diff --git a/recognition/unet_s4741911/utils.py b/recognition/unet_s4741911/utils.py new file mode 100644 index 000000000..9d9297cef --- /dev/null +++ b/recognition/unet_s4741911/utils.py @@ -0,0 +1,253 @@ +""" +utils.py +Utility functions for data loading, preprocessing, model checkpointing, +and visualization for UNET-based segmentation of 2D Hip MRI prostate images. 
+""" + +import numpy as np +import nibabel as nib +from tqdm import tqdm +from skimage.transform import resize +import random +import matplotlib.pyplot as plt +import torch + + +def to_channels(arr: np.ndarray, dtype=np.uint8, num_classes: int = 6) -> np.ndarray: + """Convert a label array to one-hot encoded channels.""" + res = np.zeros(arr.shape + (num_classes,), dtype=dtype) + + for c in range(num_classes): + res[..., c : c + 1][arr == c] = 1 + + return res + + +# Load medical image functions +def load_data_2D( + imageNames, + normImage=False, + categorical=False, + dtype=np.float32, + getAffines=False, + early_stop=False, + resize_to=(256, 128), +): + """ + Load medical image data from names, cases list provided into a list for each. + + This function pre-allocates 4D arrays for conv2d to avoid excessive memory usage. + + normImage : bool ( normalise the image 0.0 -1.0) + + early_stop : Stop loading pre-maturely, leaves arrays mostly empty, for quick + loading and testing scripts. + """ + affines = [] + + # Get fixed size + num = len(imageNames) + + first_case = nib.load(imageNames[0]).get_fdata(caching="unchanged") + + if len(first_case.shape) == 3: + first_case = first_case[:, :, 0] # Sometimes extra dimension, take first slice + + # # Resize image + # first_case = resize(first_case, resize_to, mode="constant", preserve_range=True) + + if categorical: + first_case = to_channels(first_case, dtype=dtype, num_classes=6) + rows, cols, channels = first_case.shape + images = np.zeros((num, rows, cols, channels), dtype=dtype) + else: + rows, cols = first_case.shape + images = np.zeros((num, rows, cols), dtype=dtype) + + for i, inName in enumerate( + tqdm(imageNames, desc="Loading images", ncols=100, mininterval=1) + ): + niftiImage = nib.load(inName) + inImage = niftiImage.get_fdata(caching="unchanged") # Read disk only + affine = niftiImage.affine + if len(inImage.shape) == 3: + # Sometimes extra dimension in HipMRI, take first slice + inImage = inImage[:, :, 0] + 
inImage = resize( + inImage, + resize_to, + mode="constant", + preserve_range=True, + ) # NOTE: bilinear by default; label masks would ideally use order=0 (nearest) to avoid fractional labels + inImage = inImage.astype(dtype) + + if normImage: + # ~ inImage = inImage / np.linalg.norm(inImage) + # ~ inImage = 255. * inImage / inImage.max() + inImage = (inImage - inImage.mean()) / (inImage.std() + 1e-8) # epsilon guards against blank slices + + if categorical: + inImage = to_channels(inImage, dtype=dtype, num_classes=6) + images[i, :, :, :] = inImage + else: + images[i, :, :] = inImage + + affines.append(affine) + + if i > 20 and early_stop: + break + + if getAffines: + return images, affines + else: + return images + + +def show_epoch_predictions(model, dataset, epoch, n=3, device="cuda", indices=None): + """ + Show model predictions on the validation set after the specified epoch. + Args: + model: Trained UNET model. + dataset: Dataset to visualize predictions on. + epoch: Current epoch number for title. + n: Number of samples to display. + device: Device to run model on. + indices: Specific dataset indices to visualize. If None, random samples are chosen. 
+    """
+    model.eval()
+    fig, axes = plt.subplots(3, n, figsize=(12, 6))
+    fig.suptitle(
+        f"Model Predictions after Epoch {epoch}", fontsize=16, fontweight="bold"
+    )
+
+    if indices is None:
+        indices = random.sample(range(len(dataset)), n)
+
+    with torch.no_grad():
+        for i, idx in enumerate(indices):
+            image, true_mask = dataset[idx]
+
+            # Output of model is logits for multiple classes
+            pred = model(image.unsqueeze(0).to(device))  # Model expects batch dimension
+            # Predicted class per pixel
+            pred_mask = torch.argmax(pred, dim=1)[0].cpu().numpy()
+
+            # Denormalise image for visualisation
+            img_show = (image - image.min()) / (image.max() - image.min() + 1e-8)
+
+            # Transpose from CHW to HWC for plotting
+            img_np = img_show.permute(1, 2, 0).squeeze().cpu().numpy()  # H x W x C
+
+            # Show the original image
+            axes[0, i].imshow(img_np, cmap="gray")
+            axes[0, i].set_title(f"Image {idx}", fontweight="bold")
+            axes[0, i].axis("off")
+
+            # Show the ground truth mask
+            axes[1, i].imshow(true_mask.squeeze(0).cpu(), cmap="gray", vmin=0, vmax=5)
+            axes[1, i].set_title(f"Ground Truth Mask {idx}", fontweight="bold")
+            axes[1, i].axis("off")
+
+            # Show the predicted mask
+            axes[2, i].imshow(pred_mask, cmap="gray", vmin=0, vmax=5)
+
+            # Binary masks for the prostate class (class 5)
+            true_bin = (true_mask.squeeze(0).cpu().numpy() == 5).astype(np.uint8)
+            pred_bin = (pred_mask == 5).astype(np.uint8)
+            prostate_present = true_bin.any()  # Check if prostate is present
+
+            # Dice coefficient for class 5; epsilon avoids division by zero
+            intersection = np.logical_and(pred_bin, true_bin).sum()
+            dice_coeff = (2.0 * intersection) / (pred_bin.sum() + true_bin.sum() + 1e-6)
+
+            presence_text = "Present" if prostate_present else "Absent"
+
+            axes[2, i].set_title(
+                f"Predicted Mask {idx}\nClass 5 Dice Coeff: {dice_coeff:.3f}\nProstate Presence: {presence_text}",
+                fontweight="bold",
+            )
+            axes[2, i].axis("off")
+    plt.tight_layout()
+    plt.savefig(f"epoch_{epoch}_predictions.png")
+    plt.close()
+
+    model.train()  # Switch back to train mode
+    return
+
+
+def plot_loss(train_loss, val_loss, loss_type="dice"):
+    """
+    Plot training and validation loss curves across epochs.
+
+    Args:
+        train_loss (list): List of training loss values per epoch.
+        val_loss (list): List of validation loss values per epoch.
+        loss_type (str): Type of loss for title and filename.
+    """
+    plt.figure(figsize=(8, 5))
+
+    # Plot both loss curves
+    plt.plot(train_loss, "bo-", label="Training Loss", linewidth=2, markersize=6)
+    plt.plot(val_loss, "ro-", label="Validation Loss", linewidth=2, markersize=6)
+
+    title_map = {
+        "ce": "Training vs Validation Loss (CE)",
+        "dice": "Training vs Validation Loss (Dice)",
+        "combined": "Training vs Validation Loss (Combined CE + Dice)",
+    }
+
+    plt.title(
+        title_map.get(loss_type, "Training vs Validation Loss"),
+        fontsize=14,
+        fontweight="bold",
+    )
+    plt.xlabel("Epoch")
+    plt.ylabel("Loss")
+    plt.grid(True, alpha=0.3)
+    plt.legend()
+    plt.savefig(f"loss_plot_{loss_type}.png")
+    plt.close()
+    return
+
+
+def save_model_checkpoint(state, filename="checkpoint.pth.tar"):
+    """Save model checkpoint."""
+    torch.save(state, filename)
+    print(f"=> Saved model checkpoint to {filename}")
+    return
+
+
+def load_model_checkpoint(checkpoint, model):
+    """Load model checkpoint."""
+    model.load_state_dict(checkpoint["state_dict"])
+    print("=> Loaded model checkpoint")
+    return
+
+
+def plot_dice_per_class(dice_scores_all_epochs):
+    """
+    Plot Dice coefficient per class across epochs.
+
+    Args:
+        dice_scores_all_epochs (list of lists): Dice scores per class for each epoch.
+    """
+    dice_scores_all_epochs = np.array(dice_scores_all_epochs)
+    epochs = dice_scores_all_epochs.shape[0]
+    classes = dice_scores_all_epochs.shape[1]
+
+    plt.figure(figsize=(10, 5))
+    for class_idx in range(classes):
+        plt.plot(
+            range(1, epochs + 1),
+            dice_scores_all_epochs[:, class_idx],
+            label=f"Class {class_idx}",
+        )
+    plt.xlabel("Epoch")
+    plt.ylabel("Dice Coefficient")
+    plt.title("Dice Coefficient per Class Across Epochs")
+    plt.legend()
+    plt.savefig("dice_per_class_plot.png")
+    plt.close()
+    return
diff --git a/recognition/unet_s4741911/visualisation_test/dice_per_class_test.png b/recognition/unet_s4741911/visualisation_test/dice_per_class_test.png
new file mode 100644
index 000000000..2b5c9dd14
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_test/dice_per_class_test.png differ
diff --git a/recognition/unet_s4741911/visualisation_test/sample_153.png b/recognition/unet_s4741911/visualisation_test/sample_153.png
new file mode 100644
index 000000000..46f7277f1
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_test/sample_153.png differ
diff --git a/recognition/unet_s4741911/visualisation_test/sample_188.png b/recognition/unet_s4741911/visualisation_test/sample_188.png
new file mode 100644
index 000000000..43a6eb492
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_test/sample_188.png differ
diff --git a/recognition/unet_s4741911/visualisation_test/sample_200.png b/recognition/unet_s4741911/visualisation_test/sample_200.png
new file mode 100644
index 000000000..58b5addfa
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_test/sample_200.png differ
diff --git a/recognition/unet_s4741911/visualisation_test/sample_309.png b/recognition/unet_s4741911/visualisation_test/sample_309.png
new file mode 100644
index 000000000..02331d486
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_test/sample_309.png differ
diff --git a/recognition/unet_s4741911/visualisation_test/sample_81.png b/recognition/unet_s4741911/visualisation_test/sample_81.png
new file mode 100644
index 000000000..cb2b47bc5
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_test/sample_81.png differ
diff --git a/recognition/unet_s4741911/visualisation_test/sample_92.png b/recognition/unet_s4741911/visualisation_test/sample_92.png
new file mode 100644
index 000000000..d7253e5d7
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_test/sample_92.png differ
diff --git a/recognition/unet_s4741911/visualisation_train/dice_per_class_plot.png b/recognition/unet_s4741911/visualisation_train/dice_per_class_plot.png
new file mode 100644
index 000000000..c6511baa6
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_train/dice_per_class_plot.png differ
diff --git a/recognition/unet_s4741911/visualisation_train/epoch_100_predictions.png b/recognition/unet_s4741911/visualisation_train/epoch_100_predictions.png
new file mode 100644
index 000000000..0b9a3739e
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_train/epoch_100_predictions.png differ
diff --git a/recognition/unet_s4741911/visualisation_train/epoch_10_predictions.png b/recognition/unet_s4741911/visualisation_train/epoch_10_predictions.png
new file mode 100644
index 000000000..ae974963d
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_train/epoch_10_predictions.png differ
diff --git a/recognition/unet_s4741911/visualisation_train/epoch_20_predictions.png b/recognition/unet_s4741911/visualisation_train/epoch_20_predictions.png
new file mode 100644
index 000000000..b5d6dfb5a
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_train/epoch_20_predictions.png differ
diff --git a/recognition/unet_s4741911/visualisation_train/epoch_30_predictions.png b/recognition/unet_s4741911/visualisation_train/epoch_30_predictions.png
new file mode 100644
index 000000000..4056b672c
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_train/epoch_30_predictions.png differ
diff --git a/recognition/unet_s4741911/visualisation_train/epoch_40_predictions.png b/recognition/unet_s4741911/visualisation_train/epoch_40_predictions.png
new file mode 100644
index 000000000..ac37cae2a
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_train/epoch_40_predictions.png differ
diff --git a/recognition/unet_s4741911/visualisation_train/epoch_50_predictions.png b/recognition/unet_s4741911/visualisation_train/epoch_50_predictions.png
new file mode 100644
index 000000000..9df4498d2
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_train/epoch_50_predictions.png differ
diff --git a/recognition/unet_s4741911/visualisation_train/epoch_60_predictions.png b/recognition/unet_s4741911/visualisation_train/epoch_60_predictions.png
new file mode 100644
index 000000000..de0fa2562
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_train/epoch_60_predictions.png differ
diff --git a/recognition/unet_s4741911/visualisation_train/epoch_70_predictions.png b/recognition/unet_s4741911/visualisation_train/epoch_70_predictions.png
new file mode 100644
index 000000000..125de8c39
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_train/epoch_70_predictions.png differ
diff --git a/recognition/unet_s4741911/visualisation_train/epoch_80_predictions.png b/recognition/unet_s4741911/visualisation_train/epoch_80_predictions.png
new file mode 100644
index 000000000..d2bf9051c
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_train/epoch_80_predictions.png differ
diff --git a/recognition/unet_s4741911/visualisation_train/epoch_90_predictions.png b/recognition/unet_s4741911/visualisation_train/epoch_90_predictions.png
new file mode 100644
index 000000000..cab1456dc
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_train/epoch_90_predictions.png differ
diff --git a/recognition/unet_s4741911/visualisation_train/loss_plot_combined_FL_DL.png b/recognition/unet_s4741911/visualisation_train/loss_plot_combined_FL_DL.png
new file mode 100644
index 000000000..0f842ea77
Binary files /dev/null and b/recognition/unet_s4741911/visualisation_train/loss_plot_combined_FL_DL.png differ