{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "df7898a3", "metadata": {}, "outputs": [], "source": [ "%pip install hausdorff\n", "%pip install numba\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import os\n", "import glob\n", "import numpy as np\n", "import nibabel as nib\n", "from ACDCUNet import build_dict_images, build_dict_images_pred\n", "\n", "import statistics\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from matplotlib.colors import LinearSegmentedColormap\n", "from hausdorff import hausdorff_distance" ] }, { "cell_type": "code", "execution_count": null, "id": "d744133d", "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", "author: Clément Zotti (clement.zotti@usherbrooke.ca)\n", "date: April 2017\n", "\n", "DESCRIPTION :\n", "The script provide helpers functions to handle nifti image format:\n", " - load_nii()\n", " - save_nii()\n", "\n", "to generate metrics for two images:\n", " - metrics()\n", "\n", "And it is callable from the command line (see below).\n", "Each function provided in this script has comments to understand\n", "how they works.\n", "\n", "HOW-TO:\n", "\n", "This script was tested for python 3.4.\n", "\n", "First, you need to install the required packages with\n", " pip install -r requirements.txt\n", "\n", "After the installation, you have two ways of running this script:\n", " 1) python metrics.py ground_truth/patient001_ED.nii.gz prediction/patient001_ED.nii.gz\n", " 2) python metrics.py ground_truth/ prediction/\n", "\n", "The first option will print in the console the dice and volume of each class for the given image.\n", "The second option wiil ouput a csv file where each images will have the dice and volume of each class.\n", "\n", "\n", "Link: http://acdc.creatis.insa-lyon.fr\n", "\n", "\"\"\"\n", "\n", "HEADER = [\n", " \"Name\",\n", " \"Dice LV\",\n", " \"Volume LV pred\",\n", " \"Volume LV GT\",\n", " \"Err LV(ml)\",\n", " \"Dice RV\",\n", " \"Volume RV 
pred\",\n",
"    \"Volume RV GT\",\n",
"    \"Err RV(ml)\",\n",
"    \"Dice MYO\",\n",
"    \"Volume MYO pred\",\n",
"    \"Volume MYO GT\",\n",
"    \"Err MYO(ml)\",\n",
"]\n",
"\n",
"\n",
"#\n",
"# Utils functions used to sort strings into a natural order\n",
"#\n",
"def conv_int(i):\n",
"    \"\"\"Return int(i) if the string i consists only of digits, otherwise i unchanged.\"\"\"\n",
"    return int(i) if i.isdigit() else i\n",
"\n",
"\n",
"def natural_order(sord):\n",
"    \"\"\"\n",
"    Sort key that orders strings naturally: runs of digits compare as integers.\n",
"\n",
"    Ex:\n",
"\n",
"    sorted(['1','10','2'], key=natural_order) -> ['1','2','10']\n",
"\n",
"    sorted(['abc1def','ab10d','b2c','ab1d'], key=natural_order) -> ['ab1d','ab10d', 'abc1def', 'b2c']\n",
"\n",
"    When sord is a tuple, only its first element is used.\n",
"    \"\"\"\n",
"    # Local import: `re` is not part of the notebook's import cell.\n",
"    import re\n",
"\n",
"    if isinstance(sord, tuple):\n",
"        sord = sord[0]\n",
"    return [conv_int(c) for c in re.split(r\"(\\d+)\", sord)]\n",
"\n",
"\n",
"#\n",
"# Utils function to load and save nifti files with the nibabel package\n",
"#\n",
"\n",
"# Relative path to the ACDC data; os.path.join picks the separator of the running OS.\n",
"# NOTE(review): not used in this notebook (load_nii's parameter has the same name).\n",
"img_path = os.path.join(\"ACDC\", \"database\")\n",
"\n",
"\n",
"def load_nii(img_path):\n",
"    \"\"\"\n",
"    Function to load a 'nii' or 'nii.gz' file, The function returns\n",
"    everything needed to save another 'nii' or 'nii.gz'\n",
"    in the same dimensional space, i.e. 
the affine matrix and the header\n", "\n", " Parameters\n", " ----------\n", "\n", " img_path: string\n", " String with the path of the 'nii' or 'nii.gz' image file name.\n", "\n", " Returns\n", " -------\n", " Three element, the first is a numpy array of the image values,\n", " the second is the affine transformation of the image, and the\n", " last one is the header of the image.\n", " \"\"\"\n", " nimg = nib.load(img_path)\n", " return nimg.get_fdata(), nimg.affine, nimg.header\n", "\n", "\n", "def save_nii(img_path, data, affine, header):\n", " \"\"\"\n", " Function to save a 'nii' or 'nii.gz' file.\n", "\n", " Parameters\n", " ----------\n", "\n", " img_path: string\n", " Path to save the image should be ending with '.nii' or '.nii.gz'.\n", "\n", " data: np.array\n", " Numpy array of the image data.\n", "\n", " affine: list of list or np.array\n", " The affine transformation to save with the image.\n", "\n", " header: nib.Nifti1Header\n", " The header that define everything about the data\n", " (pleasecheck nibabel documentation).\n", " \"\"\"\n", " nimg = nib.Nifti1Image(data, affine=affine, header=header)\n", " nimg.to_filename(img_path)\n", "\n", "\n", "#\n", "# Functions to process files, directories and metrics\n", "#\n", "def metrics(img_gt, img_pred, voxel_size):\n", " \"\"\"\n", " Function to compute the metrics between two segmentation maps given as input.\n", "\n", " Parameters\n", " ----------\n", " img_gt: np.array\n", " Array of the ground truth segmentation map.\n", "\n", " img_pred: np.array\n", " Array of the predicted segmentation map.\n", "\n", " voxel_size: list, tuple or np.array\n", " The size of a voxel of the images used to compute the volumes.\n", "\n", " Return\n", " ------\n", " A list of metrics in this order, [Dice LV, Volume LV, Volume GT, Err LV(ml),\n", " Dice RV, Volume RV, Volume GT, Err RV(ml), Dice MYO, Volume MYO, Volume GT, Err MYO(ml)]\n", " \"\"\"\n", "\n", " if img_gt.ndim != img_pred.ndim:\n", " raise ValueError(\n", " 
\"The arrays 'img_gt' and 'img_pred' should have the \"\n", " \"same dimension, {} against {}\".format(img_gt.ndim, img_pred.ndim)\n", " )\n", "\n", " res = []\n", " # Loop on each classes of the input images\n", " for c in [3, 1, 2]:\n", " # Copy the gt image to not alterate the input\n", " gt_c_i = np.copy(img_gt)\n", " gt_c_i[gt_c_i != c] = 0\n", "\n", " # Copy the pred image to not alterate the input\n", " pred_c_i = np.copy(img_pred)\n", " pred_c_i[pred_c_i != c] = 0\n", "\n", " # Clip the value to compute the volumes\n", " gt_c_i = np.clip(gt_c_i, 0, 1)\n", " pred_c_i = np.clip(pred_c_i, 0, 1)\n", "\n", " # Compute the Dice\n", " # dice = dc(gt_c_i, pred_c_i)\n", " dice = 1\n", "\n", " # Eventueel alternatief\n", " gt_volume = np.sum(gt_c_i)\n", " pred_volume = np.sum(pred_c_i)\n", " intersect = np.sum(gt_c_i * pred_c_i)\n", " dice = (2 * intersect) / (gt_volume + pred_volume)\n", "\n", " # Compute volume\n", " volpred = pred_c_i.sum() * np.prod(voxel_size) / 1000.0\n", " volgt = gt_c_i.sum() * np.prod(voxel_size) / 1000.0\n", "\n", " res += [dice, volpred, volgt, volpred - volgt]\n", "\n", " return res\n", "\n", "\n", "def compute_metrics_on_files(path_gt, path_pred):\n", " \"\"\"\n", " Function to give the metrics for two files\n", "\n", " Parameters\n", " ----------\n", "\n", " path_gt: string\n", " Path of the ground truth image.\n", "\n", " path_pred: string\n", " Path of the predicted image.\n", " \"\"\"\n", " gt, _, header = load_nii(path_gt)\n", " pred, _, _ = load_nii(path_pred)\n", " zooms = header.get_zooms()\n", "\n", " name = os.path.basename(path_gt)\n", " name = name.split(\".\")[0]\n", " res = metrics(gt, pred, zooms)\n", " res = [\"{:.3f}\".format(r) for r in res]\n", "\n", " formatting = \"{:<20}\" + \"{:>12}\" * len(res)\n", " output = formatting.format(name, *res)\n", "\n", " print(formatting.format(*HEADER))\n", " print(output)\n", "\n", " # formatting = \"{:>14}, {:>7}, {:>9}, {:>10}, {:>7}, {:>9}, {:>10}, {:>8}, {:>10}, {:>11}\"\n", " # 
print(formatting.format(*HEADER))\n",
"    # print(formatting.format(name, *res))\n",
"\n",
"    # return [name, *res]\n",
"    return res\n",
"\n",
"\n",
"def compute_metrics_on_directories(dir_gt, dir_pred):\n",
"    \"\"\"\n",
"    Function to generate a csv file for each images of two directories.\n",
"\n",
"    Files are paired by natural sort order and must have identical names. The\n",
"    csv is written to the working directory as results_<timestamp>.csv.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"\n",
"    dir_gt: string\n",
"        Directory of the ground truth segmentation maps.\n",
"\n",
"    dir_pred: string\n",
"        Directory of the predicted segmentation maps.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If the directories hold a different number of files, or if two paired\n",
"        files do not have the same name.\n",
"    \"\"\"\n",
"    # Local imports: neither module is part of the notebook's import cell.\n",
"    import time\n",
"\n",
"    import pandas as pd\n",
"\n",
"    # `glob` is imported as a module in the import cell, so the function is glob.glob.\n",
"    lst_gt = sorted(glob.glob(os.path.join(dir_gt, \"*\")), key=natural_order)\n",
"    lst_pred = sorted(glob.glob(os.path.join(dir_pred, \"*\")), key=natural_order)\n",
"    # zip() below would otherwise silently drop unpaired files.\n",
"    if len(lst_gt) != len(lst_pred):\n",
"        raise ValueError(\n",
"            \"The two directories don't contain the same number of files\"\n",
"            \" {}, {}.\".format(len(lst_gt), len(lst_pred))\n",
"        )\n",
"\n",
"    res = []\n",
"    for p_gt, p_pred in zip(lst_gt, lst_pred):\n",
"        if os.path.basename(p_gt) != os.path.basename(p_pred):\n",
"            raise ValueError(\n",
"                \"The two files don't have the same name\"\n",
"                \" {}, {}.\".format(os.path.basename(p_gt), os.path.basename(p_pred))\n",
"            )\n",
"\n",
"        gt, _, header = load_nii(p_gt)\n",
"        pred, _, _ = load_nii(p_pred)\n",
"        zooms = header.get_zooms()\n",
"        res.append(metrics(gt, pred, zooms))\n",
"\n",
"    lst_name_gt = [os.path.basename(gt).split(\".\")[0] for gt in lst_gt]\n",
"    res = [[n] + r for r, n in zip(res, lst_name_gt)]\n",
"    df = pd.DataFrame(res, columns=HEADER)\n",
"    df.to_csv(\"results_{}.csv\".format(time.strftime(\"%Y%m%d_%H%M%S\")), index=False)\n",
"\n",
"\n",
"def main(path_gt, path_pred):\n",
"    \"\"\"\n",
"    Main function to select which method to apply on the input parameters.\n",
"    \"\"\"\n",
"    if os.path.isfile(path_gt) and os.path.isfile(path_pred):\n",
"        compute_metrics_on_files(path_gt, path_pred)\n",
"    elif os.path.isdir(path_gt) and os.path.isdir(path_pred):\n",
"        compute_metrics_on_directories(path_gt, path_pred)\n",
"    else:\n",
"        raise ValueError(\"The paths given needs to be two directories or two files.\")\n",
"\n",
"\n",
"# if __name__ == \"__main__\":\n",
"#     parser = argparse.ArgumentParser(\n",
"# 
description=\"Script to compute ACDC challenge metrics.\")\n", "# parser.add_argument(\"GT_IMG\", type=str, help=\"Ground Truth image\")\n", "# parser.add_argument(\"PRED_IMG\", type=str, help=\"Predicted image\")\n", "# args = parser.parse_args()\n", "# main(args.GT_IMG, args.PRED_IMG)" ] }, { "cell_type": "code", "execution_count": null, "id": "142caccb", "metadata": {}, "outputs": [], "source": [ "# Testing\n", "path_gt = os.path.join('ACDC','database','testing','patient101','patient101_frame01.nii.gz')\n", "dir_gt = os.path.join('ACDC','database','testing','patient101','patient101_frame01_gt.nii.gz')\n", "\n", "path_pred = os.path.join('ACDC','database','testing','patient101','patient101_frame01.nii.gz')\n", "dir_pred = os.path.join('ACDC','database','testing','patient101','patient101_frame01_ml_pred.nii.gz')\n", "\n", "test_files = compute_metrics_on_files(dir_gt, dir_pred)" ] }, { "cell_type": "code", "execution_count": null, "id": "0451a847", "metadata": {}, "outputs": [], "source": [ "data_path = \"ACDC/database\"\n", "test_dict_gt = build_dict_images(data_path, mode='testing').ravel()\n", "test_dict_pred = build_dict_images_pred(data_path, mode='testing')\n", "\n", "# Ground truth\n", "# First file\n", "print(len(test_dict_gt))\n", "image_paths_gt_ED = [d['image'] for d in test_dict_gt if d[\"first_file\"]]\n", "label_paths_gt_ED = [d['label'] for d in test_dict_gt if d[\"first_file\"]]\n", "\n", "# Last file\n", "image_paths_gt_ES = [d['image'] for d in test_dict_gt if not d[\"first_file\"]]\n", "label_paths_gt_ES = [d['label'] for d in test_dict_gt if not d[\"first_file\"]]\n", "\n", "print('Image path gt ED is ', image_paths_gt_ED) \n", "print('Label path gt ED is ', label_paths_gt_ED)\n", "print('Image path gt ES is ', image_paths_gt_ES)\n", "print('Label path gt ES is ', label_paths_gt_ES)\n", "\n", "# Prediction\n", "image_paths_pred_ED = [d['image'] for d in test_dict_pred if d[\"first_file\"]]\n", "label_paths_pred_ED = [d['label'] for d in 
test_dict_pred if d[\"first_file\"]]\n",
"\n",
"image_paths_pred_ES = [d['image'] for d in test_dict_pred if not d[\"first_file\"]]\n",
"label_paths_pred_ES = [d['label'] for d in test_dict_pred if not d[\"first_file\"]]\n",
"\n",
"print('Image path prediction ED is ', image_paths_pred_ED)\n",
"print('Label path prediction ED is ', label_paths_pred_ED)\n",
"print('Image path prediction ES is ', image_paths_pred_ES)\n",
"print('Label path prediction ES is ', label_paths_pred_ES)\n",
"\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6c7be742", "metadata": {}, "outputs": [], "source": [
"test_files_ED = [compute_metrics_on_files(label_paths_gt_ED[i], label_paths_pred_ED[i]) for i in range(len(label_paths_gt_ED))]\n",
"test_files_ES = [compute_metrics_on_files(label_paths_gt_ES[i], label_paths_pred_ES[i]) for i in range(len(label_paths_gt_ES))]" ] }, { "cell_type": "code", "execution_count": null, "id": "049a0262", "metadata": {}, "outputs": [], "source": [
"# ED\n",
"LV_GT_ED = np.array([float(metric[2]) for metric in test_files_ED])\n",
"LV_pred_ED = np.array([float(metric[1]) for metric in test_files_ED])\n",
"LV_dice_ED = np.array([float(metric[0]) for metric in test_files_ED])\n",
"LV_err_ED = np.array([float(metric[3]) for metric in test_files_ED])\n",
"\n",
"RV_GT_ED = np.array([float(metric[6]) for metric in test_files_ED])\n",
"RV_pred_ED = np.array([float(metric[5]) for metric in test_files_ED])\n",
"RV_dice_ED = np.array([float(metric[4]) for metric in test_files_ED])\n",
"RV_err_ED = np.array([float(metric[7]) for metric in test_files_ED])\n",
"\n",
"MYO_GT_ED = np.array([float(metric[10]) for metric in test_files_ED])\n",
"MYO_pred_ED = np.array([float(metric[9]) for metric in test_files_ED])\n",
"MYO_dice_ED = np.array([float(metric[8]) for metric in test_files_ED])\n",
"MYO_err_ED = np.array([float(metric[11]) for metric in test_files_ED])\n",
"\n",
"# ES\n",
"LV_GT_ES = np.array([float(metric[2]) for metric in test_files_ES])\n",
"LV_pred_ES = np.array([float(metric[1]) for metric in test_files_ES])\n",
"LV_dice_ES = np.array([float(metric[0]) for metric in test_files_ES])\n",
"LV_err_ES = np.array([float(metric[3]) for metric in test_files_ES])\n",
"\n",
"RV_GT_ES = np.array([float(metric[6]) for metric in test_files_ES])\n",
"RV_pred_ES = np.array([float(metric[5]) for metric in test_files_ES])\n",
"RV_dice_ES = np.array([float(metric[4]) for metric in test_files_ES])\n",
"RV_err_ES = np.array([float(metric[7]) for metric in test_files_ES])\n",
"\n",
"MYO_GT_ES = np.array([float(metric[10]) for metric in test_files_ES])\n",
"MYO_pred_ES = np.array([float(metric[9]) for metric in test_files_ES])\n",
"MYO_dice_ES = np.array([float(metric[8]) for metric in test_files_ES])\n",
"MYO_err_ES = np.array([float(metric[11]) for metric in test_files_ES])" ] }, { "attachments": {}, "cell_type": "markdown", "id": "7c36b2e0", "metadata": {}, "source": [ "Other ways for visualization\n", "> Comparison graphs" ] }, { "cell_type": "code", "execution_count": null, "id": "6fd62eb7", "metadata": {}, "outputs": [], "source": [
"def show_comparison_graph(gt, pred, dice, err, color_boundaries=20, part=\"LV\", mode=\"ED\"):\n",
"    \"\"\"Scatter predicted against ground-truth volumes of one structure, one point per patient.\n",
"\n",
"    Point colour is err on a red-yellow-green-yellow-red scale that is green at 0 and yellow\n",
"    at +/- color_boundaries; point size is dice * 100. The dashed line is the identity y = x.\n",
"\n",
"    gt, pred, dice, err: equal-length sequences of per-patient values.\n",
"    color_boundaries: error magnitude at which the colour turns yellow.\n",
"    part, mode: only used in the axis labels and the title.\n",
"    \"\"\"\n",
"    plt.clf()\n",
"    bound = abs(color_boundaries)\n",
"    # The colour stops must be non-decreasing between vmin and vmax, so the colour range is\n",
"    # widened to always contain -bound, 0 and +bound. With the raw min/max of err,\n",
"    # LinearSegmentedColormap.from_list raises unless the errors extend past both boundaries.\n",
"    vmin = min(min(err), -bound)\n",
"    vmax = max(max(err), bound)\n",
"    if vmax <= vmin:\n",
"        # Only reachable when bound is 0 and every error is 0.\n",
"        vmin, vmax = -1.0, 1.0\n",
"\n",
"    colors = ['red', 'yellow', 'green', 'yellow', 'red']\n",
"    color_positions = [vmin, -bound, 0, bound, vmax]\n",
"    norm = plt.Normalize(vmin, vmax)\n",
"\n",
"    colormap = list(zip(norm(color_positions), colors))\n",
"    cmap = LinearSegmentedColormap.from_list(\"custom_cmap\", colormap)\n",
"\n",
"    plt.scatter(gt, pred, c=err, cmap=cmap, vmin=vmin, vmax=vmax, s=np.asarray(dice) * 100)\n",
"    plt.plot([min(gt), max(gt)], [min(gt), max(gt)], 'k--')\n",
"    plt.xlabel(f'Ground Truth {part} Volume [ml]')\n",
"    plt.ylabel(f'Predicted {part} Volume [ml]')\n",
"    plt.title(f'Ground Truth vs. Predicted Volume of {part} {mode}')\n",
"\n",
"    cbar = plt.colorbar()\n",
"    cbar.set_label('Error Value')\n",
"\n",
"    plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "1541570e", "metadata": {}, "outputs": [], "source": [ "show_comparison_graph(LV_GT_ED, LV_pred_ED, LV_dice_ED, LV_err_ED)" ] }, { "cell_type": "code", "execution_count": null, "id": "93c71590", "metadata": {}, "outputs": [], "source": [ "show_comparison_graph(RV_GT_ED, RV_pred_ED, RV_dice_ED, RV_err_ED, color_boundaries=18, part=\"RV\")" ] }, { "cell_type": "code", "execution_count": null, "id": "fa1aa317", "metadata": {}, "outputs": [], "source": [ "show_comparison_graph(MYO_GT_ED, MYO_pred_ED, MYO_dice_ED, MYO_err_ED, color_boundaries=18, part=\"MYO\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "5b91e382", "metadata": {}, "source": [ "> Correlations" ] }, { "cell_type": "code", "execution_count": null, "id": "a21ffc52", "metadata": {}, "outputs": [], "source": [
"# Correlation coefficients\n",
"def calculate_correlation(gt, pred, err):\n",
"    \"\"\"Agreement statistics between ground-truth and predicted values of one quantity.\n",
"\n",
"    Returns (corr, bias, loa):\n",
"      corr -- Pearson correlation coefficient between gt and pred\n",
"      bias -- mean of err\n",
"      loa  -- 1.96 * sample standard deviation of err (half-width of the limits of agreement)\n",
"    Each input needs at least two values (statistics.stdev raises StatisticsError otherwise).\n",
"    \"\"\"\n",
"    gt_mean = statistics.mean(gt)\n",
"    gt_std = statistics.stdev(gt)\n",
"    pred_mean = statistics.mean(pred)\n",
"    pred_std = statistics.stdev(pred)\n",
"    # Sample covariance (divide by n - 1) so it matches statistics.stdev, which is also a\n",
"    # sample estimate; dividing by n would shrink the coefficient by a factor (n - 1) / n.\n",
"    covariance = sum((x - gt_mean) * (y - pred_mean) for x, y in zip(gt, pred)) / (len(gt) - 1)\n",
"    corr = covariance / (gt_std * pred_std)\n",
"    bias = statistics.mean(err)\n",
"    err_std = statistics.stdev(err)\n",
"    loa = 1.96 * err_std\n",
"    return corr, bias, loa" ] }, { "cell_type": "code", "execution_count": null, "id": "a4ecf214", "metadata": {}, "outputs": [], "source": [
"LV_corr, LV_bias, LV_LOA = calculate_correlation(LV_GT_ED, LV_pred_ED, LV_err_ED)\n",
"print('The correlation coefficient for left ventricle EDV is ', LV_corr)\n",
"print('The bias for the left ventricle EDV is ', LV_bias)\n",
"print('The LOA of the left ventricle EDV is ', LV_LOA)" ] }, { "cell_type": "code", "execution_count": null, "id": "12c061d2", "metadata": {}, "outputs": [], "source": [ "RV_corr, 
RV_bias, RV_LOA = calculate_correlation(RV_GT_ED, RV_pred_ED, RV_err_ED)\n", "print('The correlation coefficient for right ventricle EDV is ', RV_corr)\n", "print('The bias for the right ventricle EDV is ', RV_bias)\n", "print('The LOA of the right ventricle EDV is ', RV_LOA)" ] }, { "cell_type": "code", "execution_count": null, "id": "ccb53139", "metadata": {}, "outputs": [], "source": [ "MYO_corr, MYO_bias, MYO_LOA = calculate_correlation(MYO_GT_ED, MYO_pred_ED, MYO_err_ED)\n", "print('The correlation coefficient for myocardium EDV is ', MYO_corr)\n", "print('The bias for the left myocardium EDV is ', MYO_bias)\n", "print('The LOA of the left myocardium EDV is ', MYO_LOA)" ] }, { "cell_type": "code", "execution_count": null, "id": "d85b01dd", "metadata": {}, "outputs": [], "source": [ "# Average dice\n", "\n", "av_LV_dice = statistics.mean(LV_dice_ED)\n", "av_RV_dice = statistics.mean(RV_dice_ED)\n", "av_MYO_dice = statistics.mean(MYO_dice_ED)\n", "print('Average dice left ventricle is ', av_LV_dice)\n", "print('Average dice right ventricle is ', av_RV_dice)\n", "print('Average dice myocardium is ', av_MYO_dice)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ce64f7e5", "metadata": {}, "source": [ "ES" ] }, { "cell_type": "code", "execution_count": null, "id": "437c2b32", "metadata": {}, "outputs": [], "source": [ "show_comparison_graph(LV_GT_ES, LV_pred_ES, LV_dice_ES, LV_err_ES, color_boundaries=15, part=\"LV\", mode=\"ES\")" ] }, { "cell_type": "code", "execution_count": null, "id": "0ff66a64", "metadata": {}, "outputs": [], "source": [ "show_comparison_graph(RV_GT_ES, RV_pred_ES, RV_dice_ES, RV_err_ES, color_boundaries=10, part=\"RV\", mode=\"ES\")" ] }, { "cell_type": "code", "execution_count": null, "id": "931d0abb", "metadata": {}, "outputs": [], "source": [ "show_comparison_graph(MYO_GT_ES, MYO_pred_ES, MYO_dice_ES, MYO_err_ES, color_boundaries=30, part=\"MYO\", mode=\"ES\")" ] }, { "cell_type": "code", "execution_count": null, "id": 
"4d30933c", "metadata": {}, "outputs": [], "source": [ "LV_corr, LV_bias, LV_LOA = calculate_correlation(LV_GT_ES, LV_pred_ES, LV_err_ES)\n", "print('The correlation coefficient for left ventricle ESV is ', LV_corr)\n", "print('The bias for the left ventricle ESV is ', LV_bias)\n", "print('The LOA of the left ventricle ESV is ', LV_LOA)" ] }, { "cell_type": "code", "execution_count": null, "id": "823112fe", "metadata": {}, "outputs": [], "source": [ "RV_corr, RV_bias, RV_LOA = calculate_correlation(RV_GT_ES, RV_pred_ES, RV_err_ES)\n", "print('The correlation coefficient for right ventricle ESV is ', RV_corr)\n", "print('The bias for the right ventricle ESV is ', RV_bias)\n", "print('The LOA of the right ventricle ESV is ', RV_LOA)" ] }, { "cell_type": "code", "execution_count": null, "id": "50408bf1", "metadata": {}, "outputs": [], "source": [ "MYO_corr, MYO_bias, MYO_LOA = calculate_correlation(MYO_GT_ES, MYO_pred_ES, MYO_err_ES)\n", "print('The correlation coefficient for myocardium ESV is ', MYO_corr)\n", "print('The bias for the left myocardium ESV is ', MYO_bias)\n", "print('The LOA of the left myocardium ESV is ', MYO_LOA)" ] }, { "cell_type": "code", "execution_count": null, "id": "cf0f0d91", "metadata": {}, "outputs": [], "source": [ "# Average dice\n", "av_LV_dice = statistics.mean(LV_dice_ES)\n", "av_RV_dice = statistics.mean(RV_dice_ES)\n", "av_MYO_dice = statistics.mean(MYO_dice_ES)\n", "print('Average dice left ventricle ESV is ', av_LV_dice)\n", "print('Average dice right ventricle ESV is ', av_RV_dice)\n", "print('Average dice myocardium ESV is ', av_MYO_dice)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ecf8346f", "metadata": {}, "source": [ "EJ" ] }, { "cell_type": "code", "execution_count": null, "id": "c1671753", "metadata": {}, "outputs": [], "source": [ "EJ_LV_GT = [((LV_GT_ED[i] - LV_GT_ES[i])/ LV_GT_ED[i]) * 100 for i in range(len(LV_GT_ED))]\n", "EJ_RV_GT = [((RV_GT_ED[i] - RV_GT_ES[i])/ RV_GT_ED[i]) * 100 for i in 
range(len(RV_GT_ED))]\n",
"\n",
"# Same formula as the ground-truth EF above (no offset in the denominator), so GT and\n",
"# predicted EF are directly comparable. A zero predicted ED volume gives inf/nan here.\n",
"EJ_LV_pred = [((LV_pred_ED[i] - LV_pred_ES[i])/ LV_pred_ED[i]) * 100 for i in range(len(LV_pred_ED))]\n",
"EJ_RV_pred = [((RV_pred_ED[i] - RV_pred_ES[i])/ RV_pred_ED[i]) * 100 for i in range(len(RV_pred_ED))]\n",
"\n",
"# Error = prediction - ground truth, the same sign convention as the volume errors\n",
"# (Volume pred - Volume GT): a positive bias means the EF is over-estimated.\n",
"EJ_LV_err = [EJ_LV_pred[i] - EJ_LV_GT[i] for i in range(len(EJ_LV_GT))]\n",
"EJ_RV_err = [EJ_RV_pred[i] - EJ_RV_GT[i] for i in range(len(EJ_RV_GT))]\n",
"\n",
"# EF-specific LOA names, so the LV_LOA / RV_LOA volume results above are not overwritten.\n",
"EJ_LV_corr, EJ_LV_bias, EJ_LV_LOA = calculate_correlation(EJ_LV_GT, EJ_LV_pred, EJ_LV_err)\n",
"print('The correlation of the ejection fraction for left ventricle is ', EJ_LV_corr)\n",
"print('The mean bias of the ejection fraction of left ventricle is ', EJ_LV_bias)\n",
"print('The LOA of the ejection fraction of the left ventricle is ', EJ_LV_LOA)\n",
"\n",
"EJ_RV_corr, EJ_RV_bias, EJ_RV_LOA = calculate_correlation(EJ_RV_GT, EJ_RV_pred, EJ_RV_err)\n",
"print('The correlation of the ejection fraction for right ventricle is ', EJ_RV_corr)\n",
"print('The mean bias of the ejection fraction of right ventricle is ', EJ_RV_bias)\n",
"print('The LOA of the ejection fraction of the right ventricle is ', EJ_RV_LOA)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "de9c181c", "metadata": {}, "source": [ "Best and worst results" ] }, { "cell_type": "code", "execution_count": null, "id": "c9a463de", "metadata": {}, "outputs": [], "source": [
"# This determines the best and worst performing predictions based on the summed Dice score\n",
"# of all three structures over both time frames.\n",
"# Patients are identified by their position in the per-patient arrays (the same order as\n",
"# the *_paths_* lists), not by a Dice value: the Dice values are parsed from strings\n",
"# rounded to 3 decimals, so two patients can share the same value.\n",
"score_sum = LV_dice_ES + RV_dice_ES + MYO_dice_ES + LV_dice_ED + RV_dice_ED + MYO_dice_ED\n",
"ranking = np.argsort(score_sum, kind='stable')  # ascending: lowest summed Dice first\n",
"\n",
"max_index = int(np.argmax(score_sum))\n",
"# The single lowest-scoring patient is skipped on purpose, so the 'worst' case shown is the\n",
"# second lowest. NOTE(review): record why (e.g. a failed/outlier case) or drop the skip.\n",
"print(\"Skipped lowest index:\", int(ranking[0]), \"summed Dice:\", score_sum[ranking[0]])\n",
"min_index = int(ranking[1])\n",
"max_value = score_sum[max_index]\n",
"min_value = score_sum[min_index]\n",
"\n",
"print(\"Max value (summed Dice):\", max_value)\n",
"print(\"Max index:\", max_index)\n",
"\n",
"print(\"Min value (summed Dice):\", min_value)\n",
"print(\"Min index:\", min_index)\n",
"\n",
"# Show slices" ] }, { "cell_type": "code", "execution_count": null, "id": "0ef0ed36", "metadata": {}, "outputs": [], "source": [
"print('Best image path gt ES is ', image_paths_gt_ES[max_index])\n",
"print('Best label path gt ES is ', label_paths_gt_ES[max_index])\n",
"\n",
"print('Best image path prediction ES is ', image_paths_pred_ES[max_index])\n",
"print('Best label path prediction ES is ', label_paths_pred_ES[max_index])\n" ] }, { "cell_type": "code", "execution_count": null, "id": "35b72c7c", "metadata": {}, "outputs": [], "source": [
"def get_slice(index, image_gt, label_gt, image_pred, label_pred, channel_index=0, frame='ES'):\n",
"    image_path_gt = image_gt[index]\n",
"    label_path_gt = label_gt[index]\n",
"    image_path_pred = image_pred[index]\n",
"    label_path_pred = label_pred[index]\n",
"\n",
"    # Load image and label data\n",
"    image_gt = nib.load(image_path_gt).get_fdata()\n",
"    label_gt_ES = nib.load(label_path_gt).get_fdata()\n",
"    image_pred = nib.load(image_path_pred).get_fdata()\n",
"    label_pred = nib.load(label_path_pred).get_fdata()\n",
"    return (\n",
"        [\n",
"            image_gt[..., channel_index],\n",
"            label_gt_ES[..., channel_index],\n",
"            image_pred[..., channel_index],\n",
"            label_pred[..., channel_index]\n",
"        ]\n",
"    )\n",
"def display_slices(min_index, max_index, image_gt, label_gt, image_pred, label_pred, channel_index=0, frame='ES'):\n",
"    best_image_gt_ES_channel, best_label_gt_ES_channel, best_image_pred_ES_channel, best_label_pred_ES_channel = get_slice(max_index, image_gt, label_gt, image_pred, label_pred, channel_index)\n",
"    worst_image_gt_ES_channel, worst_label_gt_ES_channel, worst_image_pred_ES_channel, 
worst_label_pred_ES_channel = get_slice(min_index, image_gt, label_gt, image_pred, label_pred, channel_index)\n",
"\n",
"    plt.subplot(2, 2, 1)\n",
"    plt.imshow(best_image_gt_ES_channel, cmap='gray')\n",
"    plt.title(f\"Ground truth best prediction {frame}\")\n",
"    plt.imshow(best_label_gt_ES_channel, alpha=0.5, cmap='Reds')\n",
"\n",
"\n",
"    plt.subplot(2, 2, 2)\n",
"    plt.imshow(best_image_pred_ES_channel, cmap='gray')\n",
"    plt.title(f\"Best prediction {frame}\")\n",
"    plt.imshow(best_label_pred_ES_channel, cmap='Reds', alpha=0.5)\n",
"\n",
"    plt.subplot(2, 2, 3)\n",
"    plt.imshow(worst_image_gt_ES_channel, cmap='gray')\n",
"    plt.title(f\"Ground truth worst prediction {frame}\")\n",
"    plt.imshow(worst_label_gt_ES_channel, alpha=0.5, cmap='Reds')\n",
"\n",
"    plt.subplot(2, 2, 4)\n",
"    plt.imshow(worst_image_pred_ES_channel, cmap='gray')\n",
"    plt.title(f\"Worst prediction {frame}\")\n",
"    plt.imshow(worst_label_pred_ES_channel, cmap='Reds', alpha=0.5)\n",
"    plt.tight_layout()\n",
"    plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "8074a3d4", "metadata": {}, "outputs": [], "source": [
"display_slices(min_index, max_index, image_paths_gt_ES, label_paths_gt_ES, image_paths_pred_ES, label_paths_pred_ES, channel_index=4, frame='ES')\n",
"compute_metrics_on_files(label_paths_gt_ED[max_index], label_paths_pred_ED[max_index])\n",
"compute_metrics_on_files(label_paths_gt_ES[max_index], label_paths_pred_ES[max_index])\n",
"print()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "e2d5604a", "metadata": {}, "outputs": [], "source": [
"display_slices(min_index, max_index, image_paths_gt_ED, label_paths_gt_ED, image_paths_pred_ED, label_paths_pred_ED, channel_index=4, frame='ED')\n",
"compute_metrics_on_files(label_paths_gt_ED[min_index], label_paths_pred_ED[min_index])\n",
"compute_metrics_on_files(label_paths_gt_ES[min_index], label_paths_pred_ES[min_index])\n",
"print()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c4ee1d4a", "metadata": {}, "source": [ "Hausdorff distances" ] }, { "cell_type": "code", "execution_count": null, "id": "a633891f", "metadata": {}, "outputs": [], "source": [
"# First file\n",
"label_paths_gt_ED = [d['label'] for d in test_dict_gt if d[\"first_file\"]]\n",
"print(len(label_paths_gt_ED))\n",
"\n",
"# Last file\n",
"label_paths_gt_ES = [d['label'] for d in test_dict_gt if not d[\"first_file\"]]\n",
"\n",
"# Prediction\n",
"label_paths_pred_ED = [d['label'] for d in test_dict_pred if d[\"first_file\"]]\n",
"\n",
"label_paths_pred_ES = [d['label'] for d in test_dict_pred if not d[\"first_file\"]]" ] }, { "cell_type": "code", "execution_count": null, "id": "f3a8a4e0", "metadata": {}, "outputs": [], "source": [
"# Load the GT and predicted label volumes. The in-plane pixel spacing of each GT volume\n",
"# (NIfTI zooms of the first two array axes) is kept so distances can be reported in mm.\n",
"# NOTE(review): assumes each prediction was saved on the same grid as its ground truth.\n",
"# GT and pred of ED\n",
"dir_gt_ED = []\n",
"dir_pred_ED = []\n",
"spacing_ED = []\n",
"for label_path_gt_ED, label_path_pred_ED in zip(label_paths_gt_ED, label_paths_pred_ED):\n",
"    nii_gt_ED = nib.load(label_path_gt_ED)\n",
"    dir_gt_ED.append(nii_gt_ED.get_fdata())\n",
"    dir_pred_ED.append(nib.load(label_path_pred_ED).get_fdata())\n",
"    spacing_ED.append(nii_gt_ED.header.get_zooms()[:2])\n",
"\n",
"# GT and pred of ES\n",
"dir_gt_ES = []\n",
"dir_pred_ES = []\n",
"spacing_ES = []\n",
"for label_path_gt, label_path_pred in zip(label_paths_gt_ES, label_paths_pred_ES):\n",
"    nii_gt_ES = nib.load(label_path_gt)\n",
"    dir_gt_ES.append(nii_gt_ES.get_fdata())\n",
"    dir_pred_ES.append(nib.load(label_path_pred).get_fdata())\n",
"    spacing_ES.append(nii_gt_ES.header.get_zooms()[:2])" ] }, { "cell_type": "code", "execution_count": null, "id": "106f8cb8", "metadata": {}, "outputs": [], "source": [
"# Split every label volume into one binary mask per structure. The label values match\n",
"# metrics() above, which reports class 3 as LV, class 1 as RV and class 2 as MYO ('wall').\n",
"LV_LABEL, RV_LABEL, MYO_LABEL = 3, 1, 2\n",
"\n",
"\n",
"def split_structures(volumes):\n",
"    \"\"\"Return (labels, LV masks, MYO masks, RV masks) for a list of label volumes.\n",
"\n",
"    labels holds each volume as float32; every mask is an integer array that is 1 where\n",
"    the volume carries that structure's label and 0 elsewhere.\n",
"    \"\"\"\n",
"    labels = [np.array(volume, dtype=np.float32) for volume in volumes]\n",
"    lv_masks = [np.where(label == LV_LABEL, 1, 0) for label in labels]\n",
"    wall_masks = [np.where(label == MYO_LABEL, 1, 0) for label in labels]\n",
"    rv_masks = [np.where(label == RV_LABEL, 1, 0) for label in labels]\n",
"    return labels, lv_masks, wall_masks, rv_masks\n",
"\n",
"\n",
"label_gt_ED, label_gt_ED_LV, label_gt_ED_wall, label_gt_ED_RV = split_structures(dir_gt_ED)\n",
"label_pred_ED, label_pred_ED_LV, label_pred_ED_wall, label_pred_ED_RV = split_structures(dir_pred_ED)\n",
"label_gt_ES, label_gt_ES_LV, label_gt_ES_wall, label_gt_ES_RV = split_structures(dir_gt_ES)\n",
"label_pred, label_pred_ES_LV, label_pred_ES_wall, label_pred_ES_RV = split_structures(dir_pred_ES)" ] }, { "cell_type": "code", "execution_count": null, "id": "385e0f03", "metadata": {}, "outputs": [], "source": [ "print(len(label_pred_ES_RV))" ] }, { "cell_type": "code", "execution_count": null, "id": "9a682e84", "metadata": {}, "outputs": [], "source": [
"def slice_hausdorff(pred_mask, gt_mask, spacing=(1.0, 1.0)):\n",
"    \"\"\"Hausdorff distance between the foreground pixels of two 2-D binary masks.\n",
"\n",
"    hausdorff_distance() takes point sets of shape (n_points, n_dims), not images, so each\n",
"    mask is first converted to the (row, col) coordinates of its non-zero pixels, scaled by\n",
"    the pixel spacing. Returns None when either mask is empty (the distance is undefined).\n",
"    \"\"\"\n",
"    scale = np.asarray(spacing, dtype=np.float64)\n",
"    pred_points = np.argwhere(pred_mask > 0) * scale\n",
"    gt_points = np.argwhere(gt_mask > 0) * scale\n",
"    if len(pred_points) == 0 or len(gt_points) == 0:\n",
"        return None\n",
"    return hausdorff_distance(pred_points, gt_points, distance='euclidean')\n",
"\n",
"\n",
"def volume_hausdorff(pred_volume, gt_volume, spacing=(1.0, 1.0)):\n",
"    \"\"\"Largest slice_hausdorff over the slices along the last axis of two mask volumes.\n",
"\n",
"    Slices where either mask is empty are skipped. NOTE(review): a structure that is missed\n",
"    or hallucinated on a whole slice is therefore not penalised. Returns nan when no slice\n",
"    has foreground in both masks.\n",
"    \"\"\"\n",
"    distances = [\n",
"        slice_hausdorff(pred_volume[:, :, z], gt_volume[:, :, z], spacing)\n",
"        for z in range(pred_volume.shape[2])\n",
"    ]\n",
"    distances = [d for d in distances if d is not None]\n",
"    return max(distances) if distances else float('nan')\n",
"\n",
"\n",
"hd_LV_ED = []\n",
"hd_wall_ED = []\n",
"hd_RV_ED = []\n",
"hd_LV_ES = []\n",
"hd_wall_ES = []\n",
"hd_RV_ES = []\n",
"\n",
"# Calculate Hausdorff distance for ED\n",
"for i in range(len(label_gt_ED)):\n",
"    hd_LV_ED.append(volume_hausdorff(label_pred_ED_LV[i], label_gt_ED_LV[i], spacing_ED[i]))\n",
"    hd_wall_ED.append(volume_hausdorff(label_pred_ED_wall[i], label_gt_ED_wall[i], spacing_ED[i]))\n",
"    hd_RV_ED.append(volume_hausdorff(label_pred_ED_RV[i], label_gt_ED_RV[i], spacing_ED[i]))\n",
"\n",
"# Calculate Hausdorff distance for ES\n",
"for i in range(len(label_gt_ES)):\n",
"    hd_LV_ES.append(volume_hausdorff(label_pred_ES_LV[i], label_gt_ES_LV[i], spacing_ES[i]))\n",
"    hd_wall_ES.append(volume_hausdorff(label_pred_ES_wall[i], label_gt_ES_wall[i], spacing_ES[i]))\n",
"    hd_RV_ES.append(volume_hausdorff(label_pred_ES_RV[i], label_gt_ES_RV[i], spacing_ES[i]))\n" ] }, { "cell_type": "code", "execution_count": null, "id": "5c4c95a1", "metadata": {}, "outputs": [], "source": [
"print(len(hd_LV_ED))\n",
"print(hd_wall_ES[0])" ] }, { "cell_type": "code", "execution_count": null, "id": "2b9b673e", "metadata": {}, "outputs": [], "source": [
"# Calculations of average HD (in the units of the NIfTI zooms, i.e. mm; nan if any patient\n",
"# has no slice where both prediction and ground truth contain the structure)\n",
"# ED\n",
"HD_LV_ED_mean = statistics.mean(hd_LV_ED)\n",
"HD_RV_ED_mean = statistics.mean(hd_RV_ED)\n",
"HD_MYO_ED_mean = statistics.mean(hd_wall_ED)\n",
"\n",
"print('The average hd of the left ventricle EDV is ', HD_LV_ED_mean)\n",
"print('The average hd of the right ventricle EDV is ', HD_RV_ED_mean)\n",
"print('The average hd of the myocardium EDV is ', HD_MYO_ED_mean)\n",
"\n",
"HD_LV_ES_mean = statistics.mean(hd_LV_ES)\n",
"HD_RV_ES_mean = statistics.mean(hd_RV_ES)\n",
"HD_MYO_ES_mean = statistics.mean(hd_wall_ES)\n",
"\n",
"print('The average hd of the left ventricle ESV is ', HD_LV_ES_mean)\n",
"print('The average hd of the right ventricle ESV is ', HD_RV_ES_mean)\n",
"print('The average hd of the myocardium ESV is ', HD_MYO_ES_mean)" ] } ], "metadata": { "kernelspec": { "display_name": "Python", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 5 }