From 488064d33da231b309c9bd14813877df6213878c Mon Sep 17 00:00:00 2001
From: Julius Steiglechner <julius.steiglechner@tuebingen.mpg.de>
Date: Wed, 29 Jan 2025 15:44:12 +0100
Subject: [PATCH] Wrap most important segmentation metrics into one function.

---
 evaluate_segmentation_quality.py | 330 +++++++++++++++++++++++++++++++
 1 file changed, 330 insertions(+)
 create mode 100644 evaluate_segmentation_quality.py

diff --git a/evaluate_segmentation_quality.py b/evaluate_segmentation_quality.py
new file mode 100644
index 0000000..9b7ec00
--- /dev/null
+++ b/evaluate_segmentation_quality.py
@@ -0,0 +1,330 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+This module provides methods to evaluate a segmentation model on a dataset.
+
+Created on Wed Jan 29 10:10:38 2025
+
+@author: jsteiglechner
+"""
+
+from typing import Callable, Dict, List
+
+import numpy as np
+import torch
+
+from segmentation_quality_measures.information_theoretic_based_metrics import (
+    variation_of_information,
+)
+from segmentation_quality_measures.prepare_tensors import (
+    encode_to_batch_times_class_times_spatial,
+)
+from segmentation_quality_measures.spatial_distance_based_metrics import (
+    hausdorff_distance_metric,
+)
+from segmentation_quality_measures.spatial_overlap_based_metrics import (
+    intersection_over_union,
+    boundary_intersection_over_union,
+)
+
+
+def calculate_batch_intersection_over_union(
+    output: torch.Tensor,
+    target: torch.Tensor,
+    num_types: int,
+    considered_types: torch.Tensor,
+):
+    """Calculate the mean intersection over union per batch element."""
+    output_one_hot = encode_to_batch_times_class_times_spatial(
+        output, num_classes=num_types)
+    target_one_hot = encode_to_batch_times_class_times_spatial(
+        target, num_classes=num_types)
+
+    scores = intersection_over_union(
+        output_one_hot,
+        target_one_hot,
+        labels=considered_types,
+        batch_iou=True,
+        reduction='none',
+    )
+
+    return torch.mean(scores, dim=1)
+
+
+def evaluate_prediction_iou(
+    output: torch.Tensor,
+    target: torch.Tensor,
+    predicted_types: Dict[int, str],
+    considered_types: List[int] = None,
+) -> Dict[str, float]:
+    """
+    Calculate intersection over union for multiple labels.
+
+    Parameters
+    ----------
+    output : torch.Tensor
+        Prediction.
+    target : torch.Tensor
+        Reference.
+    predicted_types : Dict[int, str]
+        All values that can occur in output and target, with their names.
+    considered_types : List[int], optional
+        All values that should be evaluated. The default is None.
+
+    Returns
+    -------
+    Dict[str, float]
+        Dictionary mapping metric names to values.
+
+    """
+    iou = intersection_over_union(
+        output=output,
+        target=target,
+        labels=considered_types,
+        non_presence_threshold=torch.prod(
+            torch.div(
+                torch.tensor(output.shape[2:]),
+                100,
+                rounding_mode="floor",
+            )
+        ),
+        reduction="none",
+    )
+    mean_iou = torch.nanmean(iou)
+
+    metrics = {}
+    metrics["iou_mean"] = mean_iou.cpu().tolist()
+    for i, label in enumerate(considered_types):
+        metrics["iou_" + predicted_types[int(label)]] = iou[i].cpu().tolist()
+
+    return metrics
+
+
+def evaluate_prediction_overlap_based(
+    output: torch.Tensor,
+    target: torch.Tensor,
+    predicted_types: Dict[int, str],
+    considered_types: List[int] = None,
+) -> Dict[str, float]:
+    """
+    Evaluate a prediction with overlap-based segmentation metrics.
+
+    Notes
+    -----
+    Metrics that are used:
+
+    - Intersection over union
+    - Boundary intersection over union
+
+    Parameters
+    ----------
+    output : torch.Tensor
+        Tensor with one-hot encoded labels of shape 1 x C x SP.
+    target : torch.Tensor
+        Tensor with one-hot encoded labels of shape 1 x C x SP.
+    predicted_types : Dict[int, str]
+        Types that were predicted, with their names.
+    considered_types : List[int], optional
+        Types that should be selected if the metric allows selection. The
+        default is None.
+
+    Returns
+    -------
+    metrics : Dict[str, float]
+        Dictionary mapping metric names per considered label to values.
+
+    """
+    metrics = {}
+
+    iou = intersection_over_union(
+        output=output,
+        target=target,
+        labels=considered_types,
+        non_presence_threshold=torch.prod(
+            torch.div(
+                torch.tensor(output.shape[2:]),
+                100,
+                rounding_mode="floor",
+            )
+        ),
+        reduction="none",
+    )
+    mean_iou = torch.nanmean(iou)
+
+    boundary_iou = boundary_intersection_over_union(
+        output=output,
+        target=target,
+        boundary_width=1,
+        kernel_type="box",
+        labels=considered_types,
+        non_presence_threshold=8,
+        reduction="none",
+    )
+    mean_boundary_iou = torch.nanmean(boundary_iou)
+
+    metrics["iou_mean"] = mean_iou.cpu().tolist()
+    metrics["iou_boundary_mean"] = mean_boundary_iou.cpu().tolist()
+
+    for i, label in enumerate(considered_types):
+        metrics["iou_" + predicted_types[int(label)]] = iou[i].cpu().tolist()
+        metrics["iou_boundary_" + predicted_types[int(label)]
+                ] = boundary_iou[i].cpu().tolist()
+
+    return metrics
+
+
+def evaluate_prediction(
+    output: torch.Tensor,
+    target: torch.Tensor,
+    predicted_types: Dict[int, str],
+    considered_types: List[int] = None,
+) -> Dict[str, float]:
+    """
+    Evaluate a prediction with relevant segmentation metrics.
+
+    Notes
+    -----
+    Metrics that are used:
+
+    - Intersection over union
+    - 95th percentile of Hausdorff distance (surface distance)
+    - Variation of information
+
+    Parameters
+    ----------
+    output : torch.Tensor
+        Tensor with one-hot encoded labels of shape 1 x C x SP.
+    target : torch.Tensor
+        Tensor with one-hot encoded labels of shape 1 x C x SP.
+    predicted_types : Dict[int, str]
+        Types that were predicted, with their names.
+    considered_types : List[int], optional
+        Types that should be selected if the metric allows selection. The
+        default is None.
+
+    Returns
+    -------
+    metrics : Dict[str, float]
+        Dictionary mapping metric names per considered label to values.
+
+    """
+    metrics = {}
+
+    iou = intersection_over_union(
+        output=output,
+        target=target,
+        labels=considered_types,
+        non_presence_threshold=torch.prod(
+            torch.div(
+                torch.tensor(output.shape[2:]),
+                100,
+                rounding_mode="floor",
+            )
+        ),
+        reduction="none",
+    )
+    mean_iou = torch.nanmean(iou)
+
+    boundary_iou = boundary_intersection_over_union(
+        output=output,
+        target=target,
+        boundary_width=1,
+        kernel_type="box",
+        labels=considered_types,
+        non_presence_threshold=8,
+        reduction="none",
+    )
+    mean_boundary_iou = torch.nanmean(boundary_iou)
+
+    hd = hausdorff_distance_metric(
+        output=output,
+        target=target,
+        percentile=95,
+        labels=considered_types,
+        label_reduction="none",
+        reduction="none",
+        directed=False,
+    ).squeeze()
+    mean_hd = torch.nanmean(hd)
+
+    voi = variation_of_information(
+        output=output,
+        target=target,
+        labels=None,  # label selection not implemented yet
+        normalization=False,
+        reduction="none",
+    )
+
+    metrics["iou_mean"] = mean_iou.cpu().tolist()
+    metrics["iou_boundary_mean"] = mean_boundary_iou.cpu().tolist()
+    metrics["hd_mean"] = mean_hd.cpu().tolist()
+    metrics["voi"] = voi.cpu().tolist()
+
+    for i, label in enumerate(considered_types):
+        metrics["iou_" + predicted_types[int(label)]] = iou[i].cpu().tolist()
+        metrics["iou_boundary_" + predicted_types[int(label)]
+                ] = boundary_iou[i].cpu().tolist()
+        metrics["hd_" + predicted_types[int(label)]
+                ] = hd[i].cpu().tolist()
+
+    return metrics
+
+
+def evaluate_prediction_from_array(
+    y_pred: np.ndarray,
+    target: np.ndarray,
+    device: torch.device,
+    predicted_types: Dict[int, str],
+    considered_types: List[int] = None,
+    accuracy_fn: Callable = None,
+) -> Dict[str, float]:
+    """
+    Evaluate a predicted result against a reference.
+
+    Parameters
+    ----------
+    y_pred : np.ndarray
+        Prediction array.
+    target : np.ndarray
+        Reference array.
+    device : torch.device
+        Device on which to compute.
+    predicted_types : Dict[int, str]
+        Types that were predicted, with their names.
+    considered_types : List[int], optional
+        Types that should be selected if the metric allows selection. The
+        default is None.
+    accuracy_fn : Callable, optional
+        Performance metric. If None, the standard evaluation process is used.
+        The default is None.
+
+    Returns
+    -------
+    metrics : Dict[str, float]
+        Dictionary mapping metric names per considered label to values.
+
+    """
+    with torch.no_grad():
+        y_pred = torch.from_numpy(y_pred).to(torch.long).to(device)
+        target = torch.from_numpy(target).to(torch.long).to(device)
+
+        y_pred = encode_to_batch_times_class_times_spatial(
+            y_pred,
+            len(predicted_types),
+        )
+        target = encode_to_batch_times_class_times_spatial(
+            target,
+            len(predicted_types),
+        )
+
+        if accuracy_fn is None:
+            metrics = evaluate_prediction(
+                y_pred,
+                target,
+                predicted_types,
+                considered_types,
+            )
+        else:
+            metrics = accuracy_fn(output=y_pred, target=target)
+
+    return metrics
--
GitLab
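Usage note (reviewer sketch, not part of the patch): the snippet below shows how the new wrapper evaluate_prediction_from_array might be called on two integer label volumes. The array shapes, label values, and label names are made-up placeholders, and it assumes the segmentation_quality_measures functions accept the considered_types list as shown; treat it as an illustrative sketch rather than a tested example.

import numpy as np
import torch

from evaluate_segmentation_quality import evaluate_prediction_from_array

# Made-up 3-class label volumes; the values must match the keys of predicted_types.
rng = np.random.default_rng(seed=0)
reference = rng.integers(0, 3, size=(64, 64, 64))
prediction = reference.copy()
prediction[:8] = 0  # perturb part of the prediction so the metrics become non-trivial

predicted_types = {0: "background", 1: "gray_matter", 2: "white_matter"}
considered_types = [1, 2]  # evaluate everything except background

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

metrics = evaluate_prediction_from_array(
    y_pred=prediction,
    target=reference,
    device=device,
    predicted_types=predicted_types,
    considered_types=considered_types,
)
print(metrics)  # e.g. keys like "iou_mean", "iou_boundary_gray_matter", "hd_white_matter", "voi"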