From 488064d33da231b309c9bd14813877df6213878c Mon Sep 17 00:00:00 2001
From: Julius Steiglechner <julius.steiglechner@tuebingen.mpg.de>
Date: Wed, 29 Jan 2025 15:44:12 +0100
Subject: [PATCH] Wrap most important segmentation metrics into one function.

---
 evaluate_segmentation_quality.py | 330 +++++++++++++++++++++++++++++++
 1 file changed, 330 insertions(+)
 create mode 100644 evaluate_segmentation_quality.py

diff --git a/evaluate_segmentation_quality.py b/evaluate_segmentation_quality.py
new file mode 100644
index 0000000..9b7ec00
--- /dev/null
+++ b/evaluate_segmentation_quality.py
@@ -0,0 +1,330 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+The aim of this module is to provide methods to evaluate the quality of
+predicted segmentations against reference segmentations.
+
+Created on Wed Jan 29 10:10:38 2025
+
+@author: jsteiglechner
+"""
+
+from typing import Callable, Dict, Optional
+
+import numpy as np
+import torch
+
+from segmentation_quality_measures.information_theoretic_based_metrics import (
+    variation_of_information,
+)
+from segmentation_quality_measures.prepare_tensors import (
+    encode_to_batch_times_class_times_spatial,
+)
+from segmentation_quality_measures.spatial_distance_based_metrics import (
+    hausdorff_distance_metric,
+)
+from segmentation_quality_measures.spatial_overlap_based_metrics import (
+    intersection_over_union,
+    boundary_intersection_over_union,
+)
+
+
+def calculate_batch_intersection_over_union(
+        output: torch.Tensor,
+        target: torch.Tensor,
+        num_types: int,
+        considered_types: torch.Tensor,
+) -> torch.Tensor:
+    """Calculate the mean IoU over considered types per batch element."""
+    output_one_hot = encode_to_batch_times_class_times_spatial(
+        output, num_classes=num_types)
+    target_one_hot = encode_to_batch_times_class_times_spatial(
+        target, num_classes=num_types)
+
+    scores = intersection_over_union(
+        output_one_hot,
+        target_one_hot,
+        labels=considered_types,
+        batch_iou=True,
+        reduction='none',
+    )
+
+    return torch.mean(scores, dim=1)
+
+
+def evaluate_prediction_iou(
+        output: torch.Tensor,
+        target: torch.Tensor,
+        predicted_types: Dict[int, str],
+        considered_types: Optional[torch.Tensor] = None,
+) -> Dict[str, float]:
+    """
+    Calculate Intersection over Union for multiple labels.
+
+    Parameters
+    ----------
+    output : torch.Tensor
+        Prediction, one-hot encoded with shape 1 x C x SP.
+    target : torch.Tensor
+        Reference, one-hot encoded with shape 1 x C x SP.
+    predicted_types : Dict[int, str]
+        All values that can occur in output and target, with their names.
+    considered_types : torch.Tensor, optional
+        All label values that should be evaluated. The default is None.
+
+    Returns
+    -------
+    Dict[str, float]
+        Dictionary that maps metric names to values.
+
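+    Examples
+    --------
+    A minimal sketch; ``pred`` and ``ref`` are assumed to be one-hot encoded
+    tensors of shape ``1 x C x SP`` and the class names are illustrative:
+
+    >>> names = {0: "background", 1: "gray_matter", 2: "white_matter"}
+    >>> metrics = evaluate_prediction_iou(
+    ...     pred, ref, names, torch.tensor([1, 2]),
+    ... )  # doctest: +SKIP
+    >>> sorted(metrics)  # doctest: +SKIP
+    ['iou_gray_matter', 'iou_mean', 'iou_white_matter']
+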
+    """
+    iou = intersection_over_union(
+        output=output,
+        target=target,
+        labels=considered_types,
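+        # Volume-dependent non-presence threshold: prod(spatial_dim // 100).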
+        non_presence_threshold=torch.prod(
+            torch.div(
+                torch.tensor(output.shape[2:]),
+                100,
+                rounding_mode="floor",
+            )
+        ),
+        reduction="none",
+    )
+    mean_iou = torch.nanmean(iou)
+
+    metrics = {}
+    metrics["iou_mean"] = mean_iou.cpu().tolist()
+    for i, label in enumerate(considered_types):
+        metrics["iou_" + predicted_types[label.item()]] = iou[i].cpu().tolist()
+
+    return metrics
+
+
+def evaluate_prediction_overlap_based(
+    output: torch.Tensor,
+    target: torch.Tensor,
+    predicted_types: Dict[int, str],
+    considered_types: Optional[torch.Tensor] = None,
+) -> Dict[str, float]:
+    """
+    Evaluate a prediction with overlap-based segmentation metrics.
+
+    Notes
+    -----
+    Metrics that are used:
+
+    - Intersection over union
+    - Boundary intersection over union
+
+    Parameters
+    ----------
+    output : torch.Tensor
+        Tensor with labels one-hot encoded with shape 1 x C x SP.
+    target : torch.Tensor
+        Tensor with labels one-hot encoded with shape 1 x C x SP.
+    predicted_types : Dict[int, str]
+        Types that were predicted, with their corresponding names.
+    considered_types : torch.Tensor, optional
+        Types that should be selected if a metric allows for selection. The
+        default is None.
+
+    Returns
+    -------
+    metrics : Dict[str, float]
+        Dictionary that maps metric names (per considered label) to values.
+
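+    Examples
+    --------
+    A hypothetical call; ``pred`` and ``ref`` stand for one-hot encoded
+    tensors of shape ``1 x C x SP``:
+
+    >>> metrics = evaluate_prediction_overlap_based(
+    ...     pred, ref, {0: "background", 1: "lesion"}, torch.tensor([1]),
+    ... )  # doctest: +SKIP
+    >>> sorted(metrics)  # doctest: +SKIP
+    ['iou_boundary_lesion', 'iou_boundary_mean', 'iou_lesion', 'iou_mean']
+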
+    """
+    metrics = {}
+
+    iou = intersection_over_union(
+        output=output,
+        target=target,
+        labels=considered_types,
+        non_presence_threshold=torch.prod(
+            torch.div(
+                torch.tensor(output.shape[2:]),
+                100,
+                rounding_mode="floor",
+            )
+        ),
+        reduction="none",
+    )
+    mean_iou = torch.nanmean(iou)
+
+    boundary_iou = boundary_intersection_over_union(
+        output=output,
+        target=target,
+        boundary_width=1,
+        kernel_type="box",
+        labels=considered_types,
+        non_presence_threshold=8,
+        reduction="none",
+    )
+    mean_boundary_iou = torch.nanmean(boundary_iou)
+
+    metrics["iou_mean"] = mean_iou.cpu().tolist()
+    metrics["iou_boundary_mean"] = mean_boundary_iou.cpu().tolist()
+
+    for i, label in enumerate(considered_types):
+        metrics["iou_" + predicted_types[label.item()]] = iou[i].cpu().tolist()
+        metrics["iou_boundary_" + predicted_types[label.item()]
+                ] = boundary_iou[i].cpu().tolist()
+
+    return metrics
+
+
+def evaluate_prediction(
+    output: torch.Tensor,
+    target: torch.Tensor,
+    predicted_types: Dict[int, str],
+    considered_types: Optional[torch.Tensor] = None,
+) -> Dict[str, float]:
+    """
+    Evaluate a prediction with the most relevant segmentation metrics.
+
+    Notes
+    -----
+    Metrics that are used:
+
+    - Intersection over union
+    - 95th percentile of Hausdorff distance (surface distance)
+    - Variation of information
+
+    Parameters
+    ----------
+    output : torch.Tensor
+        Tensor with labels one-hot encoded with shape 1 x C x SP.
+    target : torch.Tensor
+        Tensor with labels one-hot encoded with shape 1 x C x SP.
+    predicted_types : Dict[int, str]
+        Types that were predicted, with their corresponding names.
+    considered_types : torch.Tensor, optional
+        Types that should be selected if a metric allows for selection. The
+        default is None.
+
+    Returns
+    -------
+    metrics : Dict[str, float]
+        Dictionary that maps metric names (per considered label) to values.
+
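+    Examples
+    --------
+    A minimal sketch; ``pred`` and ``ref`` are assumed to be one-hot encoded
+    tensors of shape ``1 x C x SP`` and the class names are illustrative:
+
+    >>> names = {0: "background", 1: "gray_matter", 2: "white_matter"}
+    >>> metrics = evaluate_prediction(
+    ...     pred, ref, names, torch.tensor([1, 2]),
+    ... )  # doctest: +SKIP
+    >>> "hd_mean" in metrics and "voi" in metrics  # doctest: +SKIP
+    True
+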
+    """
+    metrics = {}
+
+    iou = intersection_over_union(
+        output=output,
+        target=target,
+        labels=considered_types,
+        non_presence_threshold=torch.prod(
+            torch.div(
+                torch.tensor(output.shape[2:]),
+                100,
+                rounding_mode="floor",
+            )
+        ),
+        reduction="none",
+    )
+    mean_iou = torch.nanmean(iou)
+
+    boundary_iou = boundary_intersection_over_union(
+        output=output,
+        target=target,
+        boundary_width=1,
+        kernel_type="box",
+        labels=considered_types,
+        non_presence_threshold=8,
+        reduction="none",
+    )
+    mean_boundary_iou = torch.nanmean(boundary_iou)
+
+    hd = hausdorff_distance_metric(
+        output=output,
+        target=target,
+        percentile=95,
+        labels=considered_types,
+        label_reduction="none",
+        reduction="none",
+        directed=False,
+    ).squeeze()
+    mean_hd = torch.nanmean(hd)
+
+    voi = variation_of_information(
+        output=output,
+        target=target,
+        labels=None,  # not implemented yet
+        normalization=False,
+        reduction="none",
+    )
+
+    metrics["iou_mean"] = mean_iou.cpu().tolist()
+    metrics["iou_boundary_mean"] = mean_boundary_iou.cpu().tolist()
+    metrics["hd_mean"] = mean_hd.cpu().tolist()
+    metrics["voi"] = voi.cpu().tolist()
+
+    for i, label in enumerate(considered_types):
+        metrics["iou_" + predicted_types[label.item()]] = iou[i].cpu().tolist()
+        metrics["iou_boundary_" + predicted_types[label.item()]
+                ] = boundary_iou[i].cpu().tolist()
+        metrics["hd_" + predicted_types[label.item()]
+                ] = hd[i].cpu().tolist()
+
+    return metrics
+
+
+def evaluate_prediction_from_array(
+    y_pred: np.ndarray,
+    target: np.ndarray,
+    device: torch.device,
+    predicted_types: Dict[int, str],
+    considered_types: Optional[torch.Tensor] = None,
+    accuracy_fn: Callable = None,
+) -> Dict[str, float]:
+    """
+    Evaluate a predicted segmentation against a reference.
+
+    Parameters
+    ----------
+    y_pred : np.ndarray
+        Predicted label array.
+    target : np.ndarray
+        Reference label array.
+    device : torch.device
+        Device on which the evaluation is computed.
+    predicted_types : Dict[int, str]
+        Types that were predicted, with their corresponding names.
+    considered_types : torch.Tensor, optional
+        Types that should be selected if a metric allows for selection. The
+        default is None.
+    accuracy_fn : Callable, optional
+        Custom performance metric taking ``output`` and ``target`` keyword
+        arguments. If None, the standard evaluation ``evaluate_prediction``
+        is used. The default is None.
+
+    Returns
+    -------
+    metrics : Dict[str, float]
+        Dictionary that maps metric names (per considered label) to values.
+
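+    Examples
+    --------
+    A hypothetical call that replaces the standard evaluation with a simple
+    custom metric; ``pred_arr`` and ``ref_arr`` stand for prediction and
+    reference label arrays:
+
+    >>> def voxel_accuracy(output, target):
+    ...     """Fraction of voxels with matching labels (illustrative)."""
+    ...     matches = output.argmax(dim=1) == target.argmax(dim=1)
+    ...     return {"voxel_accuracy": float(matches.float().mean())}
+    >>> metrics = evaluate_prediction_from_array(
+    ...     pred_arr,
+    ...     ref_arr,
+    ...     device=torch.device("cpu"),
+    ...     predicted_types={0: "background", 1: "foreground"},
+    ...     accuracy_fn=voxel_accuracy,
+    ... )  # doctest: +SKIP
+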
+    """
+    with torch.no_grad():
+        y_pred = torch.from_numpy(y_pred).to(torch.long).to(device)
+        target = torch.from_numpy(target).to(torch.long).to(device)
+
+        y_pred = encode_to_batch_times_class_times_spatial(
+            y_pred,
+            len(predicted_types),
+        )
+        target = encode_to_batch_times_class_times_spatial(
+            target,
+            len(predicted_types),
+        )
+
+        if accuracy_fn is None:
+            metrics = evaluate_prediction(
+                y_pred,
+                target,
+                predicted_types,
+                considered_types,
+            )
+        else:
+            metrics = accuracy_fn(output=y_pred, target=target)
+
+    return metrics
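+
+
+if __name__ == "__main__":
+    # Minimal, illustrative smoke test: evaluate random label volumes with
+    # hypothetical class names on the CPU. This only sketches the intended
+    # usage; real evaluations should pass actual predictions and references.
+    rng = np.random.default_rng(seed=0)
+    example_prediction = rng.integers(0, 3, size=(32, 32, 32))
+    example_reference = rng.integers(0, 3, size=(32, 32, 32))
+    example_names = {0: "background", 1: "gray_matter", 2: "white_matter"}
+
+    example_metrics = evaluate_prediction_from_array(
+        y_pred=example_prediction,
+        target=example_reference,
+        device=torch.device("cpu"),
+        predicted_types=example_names,
+        considered_types=torch.tensor([1, 2]),
+    )
+    for metric_name, metric_value in example_metrics.items():
+        print(f"{metric_name}: {metric_value}")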
-- 
GitLab