{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df7898a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install hausdorff\n",
    "%pip install numba\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import glob\n",
    "import re\n",
    "import time\n",
    "import numpy as np\n",
    "import nibabel as nib\n",
    "import pandas as pd\n",
    "from ACDCUNet import build_dict_images, build_dict_images_pred\n",
    "\n",
    "import statistics\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "from hausdorff import hausdorff_distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d744133d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "author: Clément Zotti (clement.zotti@usherbrooke.ca)\n",
    "date: April 2017\n",
    "\n",
    "DESCRIPTION :\n",
    "This script provides helper functions to handle the NIfTI image format:\n",
    "    - load_nii()\n",
    "    - save_nii()\n",
    "\n",
    "to generate metrics for two images:\n",
    "    - metrics()\n",
    "\n",
    "And it is callable from the command line (see below).\n",
    "Each function provided in this script has comments to understand\n",
    "how they work.\n",
    "\n",
    "HOW-TO:\n",
    "\n",
    "This script was tested for python 3.4.\n",
    "\n",
    "First, you need to install the required packages with\n",
    "    pip install -r requirements.txt\n",
    "\n",
    "After the installation, you have two ways of running this script:\n",
    "    1) python metrics.py ground_truth/patient001_ED.nii.gz prediction/patient001_ED.nii.gz\n",
    "    2) python metrics.py ground_truth/ prediction/\n",
    "\n",
    "The first option will print in the console the dice and volume of each class for the given image.\n",
    "The second option will output a csv file where each image will have the dice and volume of each class.\n",
    "\n",
    "\n",
    "Link: http://acdc.creatis.insa-lyon.fr\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "# Column names of the metrics table: the case name, then for each structure\n",
    "# (LV, RV, MYO) the four values that metrics() returns per class.\n",
    "HEADER = [\n",
    "    \"Name\",\n",
    "    \"Dice LV\",\n",
    "    \"Volume LV pred\",\n",
    "    \"Volume LV GT\",\n",
    "    \"Err LV(ml)\",\n",
    "    \"Dice RV\",\n",
    "    \"Volume RV pred\",\n",
    "    \"Volume RV GT\",\n",
    "    \"Err RV(ml)\",\n",
    "    \"Dice MYO\",\n",
    "    \"Volume MYO pred\",\n",
    "    \"Volume MYO GT\",\n",
    "    \"Err MYO(ml)\",\n",
    "]\n",
    "\n",
    "\n",
    "#\n",
    "# Utils functions used to sort strings into a natural order\n",
    "#\n",
    "def conv_int(i):\n",
    "    return int(i) if i.isdigit() else i\n",
    "\n",
    "\n",
    "def natural_order(sord):\n",
    "    \"\"\"\n",
    "    Sort a (list,tuple) of strings into natural order.\n",
    "\n",
    "    Ex:\n",
    "\n",
    "    ['1','10','2'] -> ['1','2','10']\n",
    "\n",
    "    ['abc1def','ab10d','b2c','ab1d'] -> ['ab1d','ab10d', 'abc1def', 'b2c']\n",
    "\n",
    "    \"\"\"\n",
    "    if isinstance(sord, tuple):\n",
    "        sord = sord[0]\n",
    "    return [conv_int(c) for c in re.split(r\"(\\d+)\", sord)]\n",
    "\n",
    "\n",
    "#\n",
    "# Utils function to load and save nifti files with the nibabel package\n",
    "#\n",
    "\n",
    "img_path = \"ACDC\\database\"\n",
    "\n",
    "\n",
    "def load_nii(img_path):\n",
    "    \"\"\"\n",
    "    Load a 'nii' or 'nii.gz' file.\n",
    "\n",
    "    Besides the voxel data, the affine matrix and the header are returned so\n",
    "    that another image can be written in the same space with save_nii().\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "\n",
    "    img_path: string\n",
    "    Path of the 'nii' or 'nii.gz' image file.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A tuple (data, affine, header): the image values as a numpy array\n",
    "    (from get_fdata), the affine transformation and the image header.\n",
    "    \"\"\"\n",
    "    nifti_image = nib.load(img_path)\n",
    "    voxels = nifti_image.get_fdata()\n",
    "    return voxels, nifti_image.affine, nifti_image.header\n",
    "\n",
    "\n",
    "def save_nii(img_path, data, affine, header):\n",
    "    \"\"\"\n",
    "    Write ``data`` to a 'nii' or 'nii.gz' file.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "\n",
    "    img_path: string\n",
    "    Destination path, ending with '.nii' or '.nii.gz'.\n",
    "\n",
    "    data: np.array\n",
    "    Voxel values to store.\n",
    "\n",
    "    affine: list of list or np.array\n",
    "    Affine transformation stored with the image.\n",
    "\n",
    "    header: nib.Nifti1Header\n",
    "    Header describing the data (see the nibabel documentation).\n",
    "    \"\"\"\n",
    "    nib.Nifti1Image(data, affine=affine, header=header).to_filename(img_path)\n",
    "\n",
    "\n",
    "#\n",
    "# Functions to process files, directories and metrics\n",
    "#\n",
    "def metrics(img_gt, img_pred, voxel_size):\n",
    "    \"\"\"\n",
    "    Function to compute the metrics between two segmentation maps given as input.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    img_gt: np.array\n",
    "    Array of the ground truth segmentation map.\n",
    "\n",
    "    img_pred: np.array\n",
    "    Array of the predicted segmentation map.\n",
    "\n",
    "    voxel_size: list, tuple or np.array\n",
    "    The size of a voxel of the images used to compute the volumes.\n",
    "\n",
    "    Return\n",
    "    ------\n",
    "    A list of metrics in this order, [Dice LV, Volume LV, Volume GT, Err LV(ml),\n",
    "    Dice RV, Volume RV, Volume GT, Err RV(ml), Dice MYO, Volume MYO, Volume GT, Err MYO(ml)]\n",
    "    \"\"\"\n",
    "\n",
    "    if img_gt.ndim != img_pred.ndim:\n",
    "        raise ValueError(\n",
    "            \"The arrays 'img_gt' and 'img_pred' should have the \"\n",
    "            \"same dimension, {} against {}\".format(img_gt.ndim, img_pred.ndim)\n",
    "        )\n",
    "\n",
    "    res = []\n",
    "    # Loop on each classes of the input images\n",
    "    for c in [3, 1, 2]:\n",
    "        # Copy the gt image to not alterate the input\n",
    "        gt_c_i = np.copy(img_gt)\n",
    "        gt_c_i[gt_c_i != c] = 0\n",
    "\n",
    "        # Copy the pred image to not alterate the input\n",
    "        pred_c_i = np.copy(img_pred)\n",
    "        pred_c_i[pred_c_i != c] = 0\n",
    "\n",
    "        # Clip the value to compute the volumes\n",
    "        gt_c_i = np.clip(gt_c_i, 0, 1)\n",
    "        pred_c_i = np.clip(pred_c_i, 0, 1)\n",
    "\n",
    "        # Compute the Dice\n",
    "        # dice = dc(gt_c_i, pred_c_i)\n",
    "        dice = 1\n",
    "\n",
    "        # Eventueel alternatief\n",
    "        gt_volume = np.sum(gt_c_i)\n",
    "        pred_volume = np.sum(pred_c_i)\n",
    "        intersect = np.sum(gt_c_i * pred_c_i)\n",
    "        dice = (2 * intersect) / (gt_volume + pred_volume)\n",
    "\n",
    "        # Compute volume\n",
    "        volpred = pred_c_i.sum() * np.prod(voxel_size) / 1000.0\n",
    "        volgt = gt_c_i.sum() * np.prod(voxel_size) / 1000.0\n",
    "\n",
    "        res += [dice, volpred, volgt, volpred - volgt]\n",
    "\n",
    "    return res\n",
    "\n",
    "\n",
    "def compute_metrics_on_files(path_gt, path_pred, verbose=True):\n",
    "    \"\"\"\n",
    "    Function to give the metrics for two files\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "\n",
    "    path_gt: string\n",
    "    Path of the ground truth image.\n",
    "\n",
    "    path_pred: string\n",
    "    Path of the predicted image.\n",
    "\n",
    "    verbose: bool, optional\n",
    "    If True (default), print a header row and the formatted metrics row.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The 12 values returned by metrics(), formatted as strings with three\n",
    "    decimals, in the column order of HEADER[1:].\n",
    "    \"\"\"\n",
    "    gt, _, header = load_nii(path_gt)\n",
    "    pred, _, _ = load_nii(path_pred)\n",
    "    # The voxel size of the ground truth is used for both volumes.\n",
    "    zooms = header.get_zooms()\n",
    "\n",
    "    # File name without extension, e.g. 'patient101_frame01_gt'\n",
    "    name = os.path.basename(path_gt).split(\".\")[0]\n",
    "    res = [\"{:.3f}\".format(r) for r in metrics(gt, pred, zooms)]\n",
    "\n",
    "    if verbose:\n",
    "        formatting = \"{:<20}\" + \"{:>12}\" * len(res)\n",
    "        print(formatting.format(*HEADER))\n",
    "        print(formatting.format(name, *res))\n",
    "\n",
    "    return res\n",
    "\n",
    "\n",
    "def compute_metrics_on_directories(dir_gt, dir_pred):\n",
    "    \"\"\"\n",
    "    Compute the metrics for every pair of files of two directories and save them to csv.\n",
    "\n",
    "    Files are paired in natural sort order and paired files must have the same\n",
    "    name. The table (columns = HEADER) is written to results_<YYYYmmdd_HHMMSS>.csv\n",
    "    in the current working directory.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "\n",
    "    dir_gt: string\n",
    "    Directory of the ground truth segmentation maps.\n",
    "\n",
    "    dir_pred: string\n",
    "    Directory of the predicted segmentation maps.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The path of the written csv file.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "    If the directories hold a different number of files or two paired files\n",
    "    do not have the same name.\n",
    "    \"\"\"\n",
    "    # `glob` is the module imported at the top of the notebook, not the function.\n",
    "    lst_gt = sorted(glob.glob(os.path.join(dir_gt, \"*\")), key=natural_order)\n",
    "    lst_pred = sorted(glob.glob(os.path.join(dir_pred, \"*\")), key=natural_order)\n",
    "    if len(lst_gt) != len(lst_pred):\n",
    "        raise ValueError(\n",
    "            \"The two directories don't contain the same number of files\"\n",
    "            \" {}, {}.\".format(len(lst_gt), len(lst_pred))\n",
    "        )\n",
    "\n",
    "    res = []\n",
    "    for p_gt, p_pred in zip(lst_gt, lst_pred):\n",
    "        if os.path.basename(p_gt) != os.path.basename(p_pred):\n",
    "            raise ValueError(\n",
    "                \"The two files don't have the same name\"\n",
    "                \" {}, {}.\".format(os.path.basename(p_gt), os.path.basename(p_pred))\n",
    "            )\n",
    "\n",
    "        gt, _, header = load_nii(p_gt)\n",
    "        pred, _, _ = load_nii(p_pred)\n",
    "        zooms = header.get_zooms()\n",
    "        name = os.path.basename(p_gt).split(\".\")[0]\n",
    "        res.append([name] + metrics(gt, pred, zooms))\n",
    "\n",
    "    df = pd.DataFrame(res, columns=HEADER)\n",
    "    csv_path = \"results_{}.csv\".format(time.strftime(\"%Y%m%d_%H%M%S\"))\n",
    "    df.to_csv(csv_path, index=False)\n",
    "    return csv_path\n",
    "\n",
    "\n",
    "def main(path_gt, path_pred):\n",
    "    \"\"\"\n",
    "    Main function to select which method to apply on the input parameters.\n",
    "    \"\"\"\n",
    "    if os.path.isfile(path_gt) and os.path.isfile(path_pred):\n",
    "        compute_metrics_on_files(path_gt, path_pred)\n",
    "    elif os.path.isdir(path_gt) and os.path.isdir(path_pred):\n",
    "        compute_metrics_on_directories(path_gt, path_pred)\n",
    "    else:\n",
    "        raise ValueError(\"The paths given needs to be two directories or two files.\")\n",
    "\n",
    "\n",
    "# if __name__ == \"__main__\":\n",
    "#     parser = argparse.ArgumentParser(\n",
    "#         description=\"Script to compute ACDC challenge metrics.\")\n",
    "#     parser.add_argument(\"GT_IMG\", type=str, help=\"Ground Truth image\")\n",
    "#     parser.add_argument(\"PRED_IMG\", type=str, help=\"Predicted image\")\n",
    "#     args = parser.parse_args()\n",
    "#     main(args.GT_IMG, args.PRED_IMG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "142caccb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Testing: metrics of a single frame (patient 101, frame 01).\n",
    "# NOTE: despite their names, `dir_gt` / `dir_pred` are the ground-truth and predicted\n",
    "# label *files*; `path_gt` / `path_pred` both point to the same image file and are\n",
    "# not used in this cell.\n",
    "path_gt = path_pred = os.path.join('ACDC', 'database', 'testing', 'patient101', 'patient101_frame01.nii.gz')\n",
    "dir_gt = os.path.join('ACDC', 'database', 'testing', 'patient101', 'patient101_frame01_gt.nii.gz')\n",
    "dir_pred = os.path.join('ACDC', 'database', 'testing', 'patient101', 'patient101_frame01_ml_pred.nii.gz')\n",
    "\n",
    "test_files = compute_metrics_on_files(dir_gt, dir_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0451a847",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = \"ACDC/database\"\n",
    "test_dict_gt = build_dict_images(data_path, mode='testing').ravel()\n",
    "test_dict_pred = build_dict_images_pred(data_path, mode='testing')\n",
    "\n",
    "# Entries with \"first_file\" set are used as the ED frame, the others as the ES frame.\n",
    "# Ground truth\n",
    "print(len(test_dict_gt))\n",
    "image_paths_gt_ED = [d['image'] for d in test_dict_gt if d[\"first_file\"]]\n",
    "label_paths_gt_ED = [d['label'] for d in test_dict_gt if d[\"first_file\"]]\n",
    "image_paths_gt_ES = [d['image'] for d in test_dict_gt if not d[\"first_file\"]]\n",
    "label_paths_gt_ES = [d['label'] for d in test_dict_gt if not d[\"first_file\"]]\n",
    "\n",
    "# Prediction\n",
    "image_paths_pred_ED = [d['image'] for d in test_dict_pred if d[\"first_file\"]]\n",
    "label_paths_pred_ED = [d['label'] for d in test_dict_pred if d[\"first_file\"]]\n",
    "image_paths_pred_ES = [d['image'] for d in test_dict_pred if not d[\"first_file\"]]\n",
    "label_paths_pred_ES = [d['label'] for d in test_dict_pred if not d[\"first_file\"]]\n",
    "\n",
    "# Later cells pair ground truth and prediction by list position, so the lists must line up.\n",
    "assert len(label_paths_gt_ED) == len(label_paths_pred_ED), \"ED: ground truth / prediction count mismatch\"\n",
    "assert len(label_paths_gt_ES) == len(label_paths_pred_ES), \"ES: ground truth / prediction count mismatch\"\n",
    "\n",
    "# Compact summary instead of printing every path (the full lists are in the variables above)\n",
    "print('Number of ED / ES cases (gt):', len(label_paths_gt_ED), '/', len(label_paths_gt_ES))\n",
    "print('Number of ED / ES cases (prediction):', len(label_paths_pred_ED), '/', len(label_paths_pred_ES))\n",
    "print('First label path gt ED is ', label_paths_gt_ED[:1])\n",
    "print('First label path prediction ED is ', label_paths_pred_ED[:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c7be742",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground truth and prediction are paired by list position: refuse to silently\n",
    "# drop or mis-pair cases when the lists differ in length.\n",
    "if len(label_paths_gt_ED) != len(label_paths_pred_ED) or len(label_paths_gt_ES) != len(label_paths_pred_ES):\n",
    "    raise ValueError(\"Ground truth and prediction path lists have different lengths\")\n",
    "\n",
    "test_files_ED = [compute_metrics_on_files(p_gt, p_pred) for p_gt, p_pred in zip(label_paths_gt_ED, label_paths_pred_ED)]\n",
    "test_files_ES = [compute_metrics_on_files(p_gt, p_pred) for p_gt, p_pred in zip(label_paths_gt_ES, label_paths_pred_ES)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "049a0262",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each row of test_files_* holds the 12 strings returned by compute_metrics_on_files:\n",
    "# (dice, predicted volume, GT volume, error) for LV, RV and MYO, in that order.\n",
    "# Convert to a float table and unpack its columns (one 1-D array per metric).\n",
    "\n",
    "# ED\n",
    "(LV_dice_ED, LV_pred_ED, LV_GT_ED, LV_err_ED,\n",
    " RV_dice_ED, RV_pred_ED, RV_GT_ED, RV_err_ED,\n",
    " MYO_dice_ED, MYO_pred_ED, MYO_GT_ED, MYO_err_ED) = np.array(test_files_ED, dtype=float).reshape(-1, 12).T.copy()\n",
    "\n",
    "# ES\n",
    "(LV_dice_ES, LV_pred_ES, LV_GT_ES, LV_err_ES,\n",
    " RV_dice_ES, RV_pred_ES, RV_GT_ES, RV_err_ES,\n",
    " MYO_dice_ES, MYO_pred_ES, MYO_GT_ES, MYO_err_ES) = np.array(test_files_ES, dtype=float).reshape(-1, 12).T.copy()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7c36b2e0",
   "metadata": {},
   "source": [
    "Other ways of visualizing the results\n",
    "> Comparison graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fd62eb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_comparison_graph(gt, pred, dice, err, color_boundaries=20, part=\"LV\", mode=\"ED\"):\n",
    "    \"\"\"\n",
    "    Scatter plot of predicted vs. ground-truth volumes for one structure and frame.\n",
    "\n",
    "    Marker colour encodes the volume error (green at 0, yellow at\n",
    "    +/- color_boundaries, red at the largest absolute error) and marker size\n",
    "    is dice * 100. The dashed line is the identity (perfect prediction).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    gt, pred : array-like\n",
    "        Ground-truth and predicted volumes [ml], one value per case.\n",
    "    dice : array-like\n",
    "        Dice score per case.\n",
    "    err : array-like\n",
    "        Volume error per case, used for the colour scale.\n",
    "    color_boundaries : float, optional\n",
    "        Absolute error drawn in yellow.\n",
    "    part, mode : str, optional\n",
    "        Structure (e.g. 'LV') and frame (e.g. 'ED') used in the labels.\n",
    "    \"\"\"\n",
    "    gt = np.asarray(gt, dtype=float)\n",
    "    pred = np.asarray(pred, dtype=float)\n",
    "    dice = np.asarray(dice, dtype=float)\n",
    "    err = np.asarray(err, dtype=float)\n",
    "    boundary = abs(color_boundaries)\n",
    "\n",
    "    # LinearSegmentedColormap.from_list needs anchor positions that increase\n",
    "    # from 0 to 1. The raw error range (min(err), max(err)) breaks this whenever\n",
    "    # all errors lie within +/- boundary or on one side of 0, so use a range\n",
    "    # symmetric around 0 that always contains +/- boundary.\n",
    "    vlim = max(abs(np.nanmin(err)), abs(np.nanmax(err)), boundary)\n",
    "    if vlim == 0:\n",
    "        vlim = 1.0  # every error is exactly 0; avoid a zero-width colour range\n",
    "    vmin, vmax = -vlim, vlim\n",
    "\n",
    "    colors = ['red', 'yellow', 'green', 'yellow', 'red']\n",
    "    color_positions = [vmin, -boundary, 0, boundary, vmax]\n",
    "    norm = plt.Normalize(vmin, vmax)\n",
    "    cmap = LinearSegmentedColormap.from_list(\"custom_cmap\", list(zip(norm(color_positions), colors)))\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    points = ax.scatter(gt, pred, c=err, cmap=cmap, vmin=vmin, vmax=vmax, s=dice * 100)\n",
    "    lo, hi = np.nanmin(gt), np.nanmax(gt)\n",
    "    ax.plot([lo, hi], [lo, hi], 'k--')\n",
    "    ax.set_xlabel(f'Ground Truth {part} Volume [ml]')\n",
    "    ax.set_ylabel(f'Predicted {part} Volume [ml]')\n",
    "    ax.set_title(f'Ground Truth vs. Predicted Volume of {part} {mode}')\n",
    "\n",
    "    cbar = fig.colorbar(points, ax=ax)\n",
    "    cbar.set_label('Error Value')\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1541570e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left ventricle, end-diastole\n",
    "show_comparison_graph(gt=LV_GT_ED, pred=LV_pred_ED, dice=LV_dice_ED, err=LV_err_ED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93c71590",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Right ventricle, end-diastole\n",
    "show_comparison_graph(gt=RV_GT_ED, pred=RV_pred_ED, dice=RV_dice_ED, err=RV_err_ED, color_boundaries=18, part=\"RV\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa1aa317",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Myocardium, end-diastole\n",
    "show_comparison_graph(gt=MYO_GT_ED, pred=MYO_pred_ED, dice=MYO_dice_ED, err=MYO_err_ED, color_boundaries=18, part=\"MYO\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5b91e382",
   "metadata": {},
   "source": [
    "> Correlations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a21ffc52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correlation coefficients\n",
    "def calculate_correlation(gt, pred, err):\n",
    "    gt_mean = statistics.mean(gt)\n",
    "    gt_std = statistics.stdev(gt)\n",
    "    pred_mean = statistics.mean(pred)\n",
    "    pred_std = statistics.stdev(pred)\n",
    "    covariance = sum((x - gt_mean) * (y - pred_mean) for x, y in zip(gt, pred)) / len(gt)\n",
    "    corr = covariance / (gt_std * pred_std)\n",
    "    bias = statistics.mean(err)\n",
    "    err_std = statistics.stdev(err)\n",
    "    loa = 1.96*err_std\n",
    "    return corr, bias, loa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4ecf214",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Agreement of the left-ventricle volume at end-diastole\n",
    "LV_corr, LV_bias, LV_LOA = calculate_correlation(gt=LV_GT_ED, pred=LV_pred_ED, err=LV_err_ED)\n",
    "print('The correlation coefficient for left ventricle EDV is ', LV_corr)\n",
    "print('The bias for the left ventricle EDV is ', LV_bias)\n",
    "print('The LOA of the left ventricle EDV is ', LV_LOA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12c061d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Agreement of the right-ventricle volume at end-diastole\n",
    "RV_corr, RV_bias, RV_LOA = calculate_correlation(gt=RV_GT_ED, pred=RV_pred_ED, err=RV_err_ED)\n",
    "print('The correlation coefficient for right ventricle EDV is ', RV_corr)\n",
    "print('The bias for the right ventricle EDV is ', RV_bias)\n",
    "print('The LOA of the right ventricle EDV is ', RV_LOA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccb53139",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Agreement of the myocardial volume at end-diastole\n",
    "MYO_corr, MYO_bias, MYO_LOA = calculate_correlation(MYO_GT_ED, MYO_pred_ED, MYO_err_ED)\n",
    "print('The correlation coefficient for myocardium EDV is ', MYO_corr)\n",
    "print('The bias for the myocardium EDV is ', MYO_bias)\n",
    "print('The LOA of the myocardium EDV is ', MYO_LOA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d85b01dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average Dice per structure over the ED test cases\n",
    "av_LV_dice, av_RV_dice, av_MYO_dice = (\n",
    "    statistics.mean(scores) for scores in (LV_dice_ED, RV_dice_ED, MYO_dice_ED)\n",
    ")\n",
    "print('Average dice left ventricle is ', av_LV_dice)\n",
    "print('Average dice right ventricle is ', av_RV_dice)\n",
    "print('Average dice myocardium is ', av_MYO_dice)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ce64f7e5",
   "metadata": {},
   "source": [
    "End-systolic (ES) results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "437c2b32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left ventricle, end-systole\n",
    "show_comparison_graph(gt=LV_GT_ES, pred=LV_pred_ES, dice=LV_dice_ES, err=LV_err_ES, color_boundaries=15, part=\"LV\", mode=\"ES\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ff66a64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Right ventricle, end-systole\n",
    "show_comparison_graph(gt=RV_GT_ES, pred=RV_pred_ES, dice=RV_dice_ES, err=RV_err_ES, color_boundaries=10, part=\"RV\", mode=\"ES\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "931d0abb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Myocardium, end-systole\n",
    "show_comparison_graph(gt=MYO_GT_ES, pred=MYO_pred_ES, dice=MYO_dice_ES, err=MYO_err_ES, color_boundaries=30, part=\"MYO\", mode=\"ES\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d30933c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Agreement of the left-ventricle volume at end-systole\n",
    "LV_corr, LV_bias, LV_LOA = calculate_correlation(gt=LV_GT_ES, pred=LV_pred_ES, err=LV_err_ES)\n",
    "print('The correlation coefficient for left ventricle ESV is ', LV_corr)\n",
    "print('The bias for the left ventricle ESV is ', LV_bias)\n",
    "print('The LOA of the left ventricle ESV is ', LV_LOA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "823112fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Agreement of the right-ventricle volume at end-systole\n",
    "RV_corr, RV_bias, RV_LOA = calculate_correlation(gt=RV_GT_ES, pred=RV_pred_ES, err=RV_err_ES)\n",
    "print('The correlation coefficient for right ventricle ESV is ', RV_corr)\n",
    "print('The bias for the right ventricle ESV is ', RV_bias)\n",
    "print('The LOA of the right ventricle ESV is ', RV_LOA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50408bf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Agreement of the myocardial volume at end-systole\n",
    "MYO_corr, MYO_bias, MYO_LOA = calculate_correlation(MYO_GT_ES, MYO_pred_ES, MYO_err_ES)\n",
    "print('The correlation coefficient for myocardium ESV is ', MYO_corr)\n",
    "print('The bias for the myocardium ESV is ', MYO_bias)\n",
    "print('The LOA of the myocardium ESV is ', MYO_LOA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf0f0d91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average Dice per structure over the ES test cases\n",
    "av_LV_dice, av_RV_dice, av_MYO_dice = (\n",
    "    statistics.mean(scores) for scores in (LV_dice_ES, RV_dice_ES, MYO_dice_ES)\n",
    ")\n",
    "print('Average dice left ventricle ESV is ', av_LV_dice)\n",
    "print('Average dice right ventricle ESV is ', av_RV_dice)\n",
    "print('Average dice myocardium ESV is ', av_MYO_dice)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ecf8346f",
   "metadata": {},
   "source": [
    "Ejection fraction (EJ)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1671753",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ejection fraction (%) = (EDV - ESV) / EDV * 100, with the same formula for\n",
    "# ground truth and prediction. A case with a zero EDV has an undefined EF\n",
    "# (NaN/inf) and is excluded from the statistics.\n",
    "with np.errstate(divide='ignore', invalid='ignore'):\n",
    "    EJ_LV_GT = (LV_GT_ED - LV_GT_ES) / LV_GT_ED * 100\n",
    "    EJ_RV_GT = (RV_GT_ED - RV_GT_ES) / RV_GT_ED * 100\n",
    "    EJ_LV_pred = (LV_pred_ED - LV_pred_ES) / LV_pred_ED * 100\n",
    "    EJ_RV_pred = (RV_pred_ED - RV_pred_ES) / RV_pred_ED * 100\n",
    "\n",
    "# Error = prediction - ground truth, the sign convention of the volume errors from metrics()\n",
    "EJ_LV_err = EJ_LV_pred - EJ_LV_GT\n",
    "EJ_RV_err = EJ_RV_pred - EJ_RV_GT\n",
    "\n",
    "EJ_LV_valid = np.isfinite(EJ_LV_GT) & np.isfinite(EJ_LV_pred)\n",
    "EJ_RV_valid = np.isfinite(EJ_RV_GT) & np.isfinite(EJ_RV_pred)\n",
    "print('Cases excluded from the EF statistics (zero EDV): LV', np.count_nonzero(~EJ_LV_valid), '/ RV', np.count_nonzero(~EJ_RV_valid))\n",
    "\n",
    "# EF-specific LOA names, so the volume LOA variables (LV_LOA, RV_LOA) are not overwritten\n",
    "EJ_LV_corr, EJ_LV_bias, EJ_LV_LOA = calculate_correlation(EJ_LV_GT[EJ_LV_valid], EJ_LV_pred[EJ_LV_valid], EJ_LV_err[EJ_LV_valid])\n",
    "print('The correlation of the ejection fraction for left ventricle is ', EJ_LV_corr)\n",
    "print('The mean bias of the ejection fraction of left ventricle is ', EJ_LV_bias)\n",
    "print('The LOA of the ejection fraction of the left ventricle is ', EJ_LV_LOA)\n",
    "\n",
    "EJ_RV_corr, EJ_RV_bias, EJ_RV_LOA = calculate_correlation(EJ_RV_GT[EJ_RV_valid], EJ_RV_pred[EJ_RV_valid], EJ_RV_err[EJ_RV_valid])\n",
    "print('The correlation of the ejection fraction for right ventricle is ', EJ_RV_corr)\n",
    "print('The mean bias of the ejection fraction of right ventricle is ', EJ_RV_bias)\n",
    "print('The LOA of the ejection fraction of the right ventricle is ', EJ_RV_LOA)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "de9c181c",
   "metadata": {},
   "source": [
    "Best and worst results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9a463de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best and worst cases, ranked by the sum of the Dice scores of LV, RV and MYO over\n",
    "# both frames (ES and ED). Scores are indexed by patient position, so two patients\n",
    "# with the same LV ES Dice can no longer be merged or confused with each other.\n",
    "dice_total = (LV_dice_ES + RV_dice_ES + MYO_dice_ES\n",
    "              + LV_dice_ED + RV_dice_ED + MYO_dice_ED)\n",
    "\n",
    "# Patient indices from lowest to highest total Dice; undefined (NaN) totals are ignored.\n",
    "ranking = np.argsort(dice_total, kind='stable')\n",
    "ranking = ranking[~np.isnan(dice_total[ranking])]\n",
    "\n",
    "# The lowest-scoring case is skipped and the next one is reported as the worst.\n",
    "# NOTE(review): confirm why this exclusion is needed; set to 0 to use the true minimum.\n",
    "N_SKIP_WORST = 1\n",
    "\n",
    "max_index = ranking[-1]\n",
    "min_index = ranking[N_SKIP_WORST]\n",
    "max_value = dice_total[max_index]\n",
    "min_value = dice_total[min_index]\n",
    "\n",
    "print(\"Skipped worst index:\", ranking[:N_SKIP_WORST])\n",
    "print(\"Max value:\", max_value)\n",
    "print(\"Max index:\", max_index)\n",
    "\n",
    "print(\"Min value:\", min_value)\n",
    "print(\"Min index:\", min_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ef0ed36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# File paths of the best-scoring case (ES frame)\n",
    "for _caption, _paths in (\n",
    "    ('Best image path gt ES is ', image_paths_gt_ES),\n",
    "    ('Best label path gt ES is ', label_paths_gt_ES),\n",
    "    ('Best image path prediction ES is ', image_paths_pred_ES),\n",
    "    ('Best label path prediction ES is ', label_paths_pred_ES),\n",
    "):\n",
    "    print(_caption, _paths[max_index])\n",
    "del _caption, _paths\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35b72c7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_slice(index, image_gt, label_gt, image_pred, label_pred,  channel_index=0, frame='ES'):\n",
    "    \"\"\"\n",
    "    Load case `index` from four parallel path lists and return one slice of each volume.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    index : int\n",
    "        Position of the case in the path lists.\n",
    "    image_gt, label_gt, image_pred, label_pred : sequence of str\n",
    "        Paths of the ground-truth image/label and predicted image/label files.\n",
    "    channel_index : int, optional\n",
    "        Index taken along the last axis of every volume (presumably the slice\n",
    "        axis - TODO confirm for the ACDC volumes).\n",
    "    frame : str, optional\n",
    "        Unused; accepted for symmetry with display_slices.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list\n",
    "        [gt image slice, gt label slice, predicted image slice, predicted label slice]\n",
    "    \"\"\"\n",
    "    paths = (image_gt[index], label_gt[index], image_pred[index], label_pred[index])\n",
    "    volumes = [nib.load(path).get_fdata() for path in paths]\n",
    "    return [volume[..., channel_index] for volume in volumes]\n",
    "def display_slices(min_index, max_index,  image_gt, label_gt, image_pred, label_pred,  channel_index=0, frame='ES'):\n",
    "    \"\"\"\n",
    "    Show one slice of the best and the worst case in a 2x2 grid.\n",
    "\n",
    "    Top row: best case (max_index); bottom row: worst case (min_index).\n",
    "    Left column: ground truth; right column: prediction. Each panel shows the\n",
    "    image in grey with its label map overlaid in red.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    min_index, max_index : int\n",
    "        Positions of the worst and best cases in the path lists.\n",
    "    image_gt, label_gt, image_pred, label_pred : sequence of str\n",
    "        Parallel path lists, as accepted by get_slice.\n",
    "    channel_index : int, optional\n",
    "        Index along the last axis of the volumes, passed to get_slice.\n",
    "    frame : str, optional\n",
    "        Frame name used in the panel titles.\n",
    "    \"\"\"\n",
    "    best = get_slice(max_index, image_gt, label_gt, image_pred, label_pred, channel_index)\n",
    "    worst = get_slice(min_index, image_gt, label_gt, image_pred, label_pred, channel_index)\n",
    "\n",
    "    panels = [\n",
    "        (best[0], best[1], f\"Ground truth best prediction {frame}\"),\n",
    "        (best[2], best[3], f\"Best prediction {frame}\"),\n",
    "        (worst[0], worst[1], f\"Ground truth worst prediction {frame}\"),\n",
    "        (worst[2], worst[3], f\"Worst prediction {frame}\"),\n",
    "    ]\n",
    "\n",
    "    fig, axes = plt.subplots(2, 2)\n",
    "    for ax, (image, label, title) in zip(axes.ravel(), panels):\n",
    "        ax.imshow(image, cmap='gray')\n",
    "        # Mask the background (label 0) so only labelled structures are tinted, and\n",
    "        # keep the colour range at 0..max label so every class stays visible.\n",
    "        ax.imshow(np.ma.masked_equal(label, 0), cmap='Reds', alpha=0.5,\n",
    "                  vmin=0, vmax=max(float(np.max(label)), 1.0))\n",
    "        ax.set_title(title)\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8074a3d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ES slices of the best and worst cases, then the metrics of the best case (ED, then ES)\n",
    "display_slices(\n",
    "    min_index, max_index,\n",
    "    image_gt=image_paths_gt_ES, label_gt=label_paths_gt_ES,\n",
    "    image_pred=image_paths_pred_ES, label_pred=label_paths_pred_ES,\n",
    "    channel_index=4, frame='ES',\n",
    ")\n",
    "compute_metrics_on_files(path_gt=label_paths_gt_ED[max_index], path_pred=label_paths_pred_ED[max_index])\n",
    "compute_metrics_on_files(path_gt=label_paths_gt_ES[max_index], path_pred=label_paths_pred_ES[max_index])\n",
    "print()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2d5604a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ED slices of the best and worst cases, then the metrics of the worst case (ED, then ES)\n",
    "display_slices(\n",
    "    min_index, max_index,\n",
    "    image_gt=image_paths_gt_ED, label_gt=label_paths_gt_ED,\n",
    "    image_pred=image_paths_pred_ED, label_pred=label_paths_pred_ED,\n",
    "    channel_index=4, frame='ED',\n",
    ")\n",
    "compute_metrics_on_files(path_gt=label_paths_gt_ED[min_index], path_pred=label_paths_pred_ED[min_index])\n",
    "compute_metrics_on_files(path_gt=label_paths_gt_ES[min_index], path_pred=label_paths_pred_ES[min_index])\n",
    "print()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c4ee1d4a",
   "metadata": {},
   "source": [
    "Hausdorff distances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a633891f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label paths split by frame (\"first_file\" entries are ED, the others ES); these are\n",
    "# the same comprehensions as in the path-list cell near the top of the notebook.\n",
    "label_paths_gt_ED = [entry['label'] for entry in test_dict_gt if entry[\"first_file\"]]\n",
    "print(len(label_paths_gt_ED))\n",
    "label_paths_gt_ES = [entry['label'] for entry in test_dict_gt if not entry[\"first_file\"]]\n",
    "\n",
    "# Prediction\n",
    "label_paths_pred_ED = [entry['label'] for entry in test_dict_pred if entry[\"first_file\"]]\n",
    "label_paths_pred_ES = [entry['label'] for entry in test_dict_pred if not entry[\"first_file\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3a8a4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the ground-truth and predicted label maps. The lists are paired by position,\n",
    "# so refuse mismatched lengths instead of letting zip() silently drop cases.\n",
    "if len(label_paths_gt_ED) != len(label_paths_pred_ED) or len(label_paths_gt_ES) != len(label_paths_pred_ES):\n",
    "    raise ValueError(\"Ground truth and prediction label path lists have different lengths\")\n",
    "\n",
    "# GT and pred of ED\n",
    "dir_gt_ED = []\n",
    "dir_pred_ED = []\n",
    "for label_path_gt_ED, label_path_pred_ED in zip(label_paths_gt_ED, label_paths_pred_ED):\n",
    "    label_GT_ED = nib.load(label_path_gt_ED).get_fdata()\n",
    "    label_pred_ED = nib.load(label_path_pred_ED).get_fdata()\n",
    "    dir_gt_ED.append(label_GT_ED)\n",
    "    dir_pred_ED.append(label_pred_ED)\n",
    "\n",
    "# GT and pred of ES\n",
    "dir_gt_ES = []\n",
    "dir_pred_ES = []\n",
    "for label_path_gt, label_path_pred in zip(label_paths_gt_ES, label_paths_pred_ES):\n",
    "    label_GT_ES = nib.load(label_path_gt).get_fdata()\n",
    "    label_pred = nib.load(label_path_pred).get_fdata()\n",
    "    dir_gt_ES.append(label_GT_ES)\n",
    "    dir_pred_ES.append(label_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "106f8cb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ACDC label convention (see the ACDC dataset description):\n",
    "# 1 = right ventricle (RV) cavity, 2 = myocardium, 3 = left ventricle (LV) cavity.\n",
    "ACDC_LABEL_RV = 1\n",
    "ACDC_LABEL_MYO = 2\n",
    "ACDC_LABEL_LV = 3\n",
    "\n",
    "\n",
    "def _split_structures(volumes):\n",
    "    \"\"\"Split multi-label volumes into per-structure binary masks.\n",
    "\n",
    "    Returns four lists with one entry per input volume, in input order:\n",
    "    (float32 copies of the volumes, LV masks, myocardium masks, RV masks).\n",
    "    Each mask holds 1 inside the structure and 0 elsewhere.\n",
    "    \"\"\"\n",
    "    float_volumes, lv_masks, myo_masks, rv_masks = [], [], [], []\n",
    "    for volume in volumes:\n",
    "        volume = np.array(volume, dtype=np.float32)\n",
    "        float_volumes.append(volume)\n",
    "        lv_masks.append(np.where(volume == ACDC_LABEL_LV, 1, 0))\n",
    "        myo_masks.append(np.where(volume == ACDC_LABEL_MYO, 1, 0))\n",
    "        rv_masks.append(np.where(volume == ACDC_LABEL_RV, 1, 0))\n",
    "    return float_volumes, lv_masks, myo_masks, rv_masks\n",
    "\n",
    "\n",
    "# Every loaded case is processed (no fixed case count), so any test-set size works.\n",
    "label_gt_ED, label_gt_ED_LV, label_gt_ED_wall, label_gt_ED_RV = _split_structures(dir_gt_ED)\n",
    "label_pred_ED, label_pred_ED_LV, label_pred_ED_wall, label_pred_ED_RV = _split_structures(dir_pred_ED)\n",
    "label_gt_ES, label_gt_ES_LV, label_gt_ES_wall, label_gt_ES_RV = _split_structures(dir_gt_ES)\n",
    "# The ES prediction volumes keep their original name, label_pred.\n",
    "label_pred, label_pred_ES_LV, label_pred_ES_wall, label_pred_ES_RV = _split_structures(dir_pred_ES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "385e0f03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: within each phase, every GT and prediction mask list must hold one entry per case.\n",
    "for phase, mask_lists in (\n",
    "    ('ED', (label_gt_ED_LV, label_gt_ED_wall, label_gt_ED_RV, label_pred_ED_LV, label_pred_ED_wall, label_pred_ED_RV)),\n",
    "    ('ES', (label_gt_ES_LV, label_gt_ES_wall, label_gt_ES_RV, label_pred_ES_LV, label_pred_ES_wall, label_pred_ES_RV)),\n",
    "):\n",
    "    lengths = {len(masks) for masks in mask_lists}\n",
    "    assert len(lengths) == 1, f'{phase}: mask lists have different lengths {sorted(lengths)}'\n",
    "print(len(label_pred_ES_RV))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a682e84",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _slice_hausdorff(pred_slice, gt_slice):\n",
    "    \"\"\"Hausdorff distance, in pixels, between the foregrounds of two 2-D masks.\n",
    "\n",
    "    hausdorff_distance() compares point sets of shape (n_points, n_dims), so each\n",
    "    mask is first converted to the coordinates of its non-zero pixels (passing the\n",
    "    raw mask would treat every image row as a point). Returns None when both\n",
    "    slices are empty and inf when exactly one of them is empty.\n",
    "    \"\"\"\n",
    "    pred_points = np.ascontiguousarray(np.argwhere(pred_slice), dtype=np.float64)\n",
    "    gt_points = np.ascontiguousarray(np.argwhere(gt_slice), dtype=np.float64)\n",
    "    if len(pred_points) == 0 and len(gt_points) == 0:\n",
    "        return None\n",
    "    if len(pred_points) == 0 or len(gt_points) == 0:\n",
    "        # Structure missed or spurious on this slice: the distance is unbounded,\n",
    "        # which makes this case's HD (and any mean that includes it) inf.\n",
    "        return float('inf')\n",
    "    return hausdorff_distance(pred_points, gt_points, distance='euclidean')\n",
    "\n",
    "\n",
    "def _volume_hausdorff(pred_volume, gt_volume):\n",
    "    \"\"\"Largest slice-wise Hausdorff distance over the last axis of two 3-D masks.\n",
    "\n",
    "    Slices where both masks are empty are skipped; returns 0.0 if every slice is.\n",
    "    Raises ValueError if the two volumes differ in shape.\n",
    "    \"\"\"\n",
    "    if pred_volume.shape != gt_volume.shape:\n",
    "        raise ValueError(f'shape mismatch: prediction {pred_volume.shape} vs ground truth {gt_volume.shape}')\n",
    "    slice_distances = (\n",
    "        _slice_hausdorff(pred_volume[:, :, z], gt_volume[:, :, z])\n",
    "        for z in range(pred_volume.shape[2])\n",
    "    )\n",
    "    return max((d for d in slice_distances if d is not None), default=0.0)\n",
    "\n",
    "\n",
    "def _hausdorff_per_case(pred_masks, gt_masks):\n",
    "    \"\"\"Apply _volume_hausdorff to each (prediction, ground truth) pair, in order.\"\"\"\n",
    "    if len(pred_masks) != len(gt_masks):\n",
    "        raise ValueError(f'{len(pred_masks)} predictions but {len(gt_masks)} ground-truth masks')\n",
    "    return [_volume_hausdorff(pred, gt) for pred, gt in zip(pred_masks, gt_masks)]\n",
    "\n",
    "\n",
    "# Calculate Hausdorff distance for ED\n",
    "hd_LV_ED = _hausdorff_per_case(label_pred_ED_LV, label_gt_ED_LV)\n",
    "hd_wall_ED = _hausdorff_per_case(label_pred_ED_wall, label_gt_ED_wall)\n",
    "hd_RV_ED = _hausdorff_per_case(label_pred_ED_RV, label_gt_ED_RV)\n",
    "\n",
    "# Calculate Hausdorff distance for ES\n",
    "hd_LV_ES = _hausdorff_per_case(label_pred_ES_LV, label_gt_ES_LV)\n",
    "hd_wall_ES = _hausdorff_per_case(label_pred_ES_wall, label_gt_ES_wall)\n",
    "hd_RV_ES = _hausdorff_per_case(label_pred_ES_RV, label_gt_ES_RV)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c4c95a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick check: number of evaluated cases and the ES myocardium HD of the first case.\n",
    "n_cases = len(hd_LV_ED)\n",
    "print(n_cases)\n",
    "first_case_wall_es = hd_wall_ES[0]\n",
    "print(first_case_wall_es)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b9b673e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average HD over all cases, per structure; ED first, then ES.\n",
    "HD_LV_ED_mean, HD_RV_ED_mean, HD_MYO_ED_mean = map(statistics.mean, (hd_LV_ED, hd_RV_ED, hd_wall_ED))\n",
    "for message, value in (\n",
    "    ('The average hd of the left ventricle EDV is ', HD_LV_ED_mean),\n",
    "    ('The average hd of the right ventricle EDV is ', HD_RV_ED_mean),\n",
    "    ('The average hd of the myocardium EDV is ', HD_MYO_ED_mean),\n",
    "):\n",
    "    print(message, value)\n",
    "\n",
    "HD_LV_ES_mean, HD_RV_ES_mean, HD_MYO_ES_mean = map(statistics.mean, (hd_LV_ES, hd_RV_ES, hd_wall_ES))\n",
    "for message, value in (\n",
    "    ('The average hd of the left ventricle ESV is ', HD_LV_ES_mean),\n",
    "    ('The average hd of the right ventricle ESV is ', HD_RV_ES_mean),\n",
    "    ('The average hd of the myocardium ESV is ', HD_MYO_ES_mean),\n",
    "):\n",
    "    print(message, value)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}