Skip to content
Snippets Groups Projects
Commit e86db1f6 authored by Bart Leenheer's avatar Bart Leenheer
Browse files

Added the ED and ES_prediction_example.png

parent 5a14eb63
No related branches found
No related tags found
No related merge requests found
ED_prediction_example.png

114 KiB

ES_prediction_example.png

113 KiB

%% Cell type:code id:df7898a3 tags: %% Cell type:code id:df7898a3 tags:
``` python ``` python
%pip install hausdorff %pip install hausdorff
%pip install numba %pip install numba
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import os import os
import glob import glob
import numpy as np import numpy as np
import nibabel as nib import nibabel as nib
from ACDCUNet import build_dict_images, build_dict_images_pred from ACDCUNet import build_dict_images, build_dict_images_pred
import statistics import statistics
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import LinearSegmentedColormap
from hausdorff import hausdorff_distance from hausdorff import hausdorff_distance
``` ```
%% Cell type:code id:d744133d tags: %% Cell type:code id:d744133d tags:
``` python ``` python
""" """
author: Clément Zotti (clement.zotti@usherbrooke.ca) author: Clément Zotti (clement.zotti@usherbrooke.ca)
date: April 2017 date: April 2017
DESCRIPTION : DESCRIPTION :
The script provide helpers functions to handle nifti image format: The script provide helpers functions to handle nifti image format:
- load_nii() - load_nii()
- save_nii() - save_nii()
to generate metrics for two images: to generate metrics for two images:
- metrics() - metrics()
And it is callable from the command line (see below). And it is callable from the command line (see below).
Each function provided in this script has comments to understand Each function provided in this script has comments to understand
how they works. how they works.
HOW-TO: HOW-TO:
This script was tested for python 3.4. This script was tested for python 3.4.
First, you need to install the required packages with First, you need to install the required packages with
pip install -r requirements.txt pip install -r requirements.txt
After the installation, you have two ways of running this script: After the installation, you have two ways of running this script:
1) python metrics.py ground_truth/patient001_ED.nii.gz prediction/patient001_ED.nii.gz 1) python metrics.py ground_truth/patient001_ED.nii.gz prediction/patient001_ED.nii.gz
2) python metrics.py ground_truth/ prediction/ 2) python metrics.py ground_truth/ prediction/
The first option will print in the console the dice and volume of each class for the given image. The first option will print in the console the dice and volume of each class for the given image.
The second option wiil ouput a csv file where each images will have the dice and volume of each class. The second option wiil ouput a csv file where each images will have the dice and volume of each class.
Link: http://acdc.creatis.insa-lyon.fr Link: http://acdc.creatis.insa-lyon.fr
""" """
HEADER = [ HEADER = [
"Name", "Name",
"Dice LV", "Dice LV",
"Volume LV pred", "Volume LV pred",
"Volume LV GT", "Volume LV GT",
"Err LV(ml)", "Err LV(ml)",
"Dice RV", "Dice RV",
"Volume RV pred", "Volume RV pred",
"Volume RV GT", "Volume RV GT",
"Err RV(ml)", "Err RV(ml)",
"Dice MYO", "Dice MYO",
"Volume MYO pred", "Volume MYO pred",
"Volume MYO GT", "Volume MYO GT",
"Err MYO(ml)", "Err MYO(ml)",
] ]
# #
# Utils functions used to sort strings into a natural order # Utils functions used to sort strings into a natural order
# #
def conv_int(i): def conv_int(i):
return int(i) if i.isdigit() else i return int(i) if i.isdigit() else i
def natural_order(sord): def natural_order(sord):
""" """
Sort a (list,tuple) of strings into natural order. Sort a (list,tuple) of strings into natural order.
Ex: Ex:
['1','10','2'] -> ['1','2','10'] ['1','10','2'] -> ['1','2','10']
['abc1def','ab10d','b2c','ab1d'] -> ['ab1d','ab10d', 'abc1def', 'b2c'] ['abc1def','ab10d','b2c','ab1d'] -> ['ab1d','ab10d', 'abc1def', 'b2c']
""" """
if isinstance(sord, tuple): if isinstance(sord, tuple):
sord = sord[0] sord = sord[0]
return [conv_int(c) for c in re.split(r"(\d+)", sord)] return [conv_int(c) for c in re.split(r"(\d+)", sord)]
# #
# Utils function to load and save nifti files with the nibabel package # Utils function to load and save nifti files with the nibabel package
# #
img_path = "ACDC\database" img_path = "ACDC\database"
def load_nii(img_path): def load_nii(img_path):
""" """
Function to load a 'nii' or 'nii.gz' file, The function returns Function to load a 'nii' or 'nii.gz' file, The function returns
everyting needed to save another 'nii' or 'nii.gz' everyting needed to save another 'nii' or 'nii.gz'
in the same dimensional space, i.e. the affine matrix and the header in the same dimensional space, i.e. the affine matrix and the header
Parameters Parameters
---------- ----------
img_path: string img_path: string
String with the path of the 'nii' or 'nii.gz' image file name. String with the path of the 'nii' or 'nii.gz' image file name.
Returns Returns
------- -------
Three element, the first is a numpy array of the image values, Three element, the first is a numpy array of the image values,
the second is the affine transformation of the image, and the the second is the affine transformation of the image, and the
last one is the header of the image. last one is the header of the image.
""" """
nimg = nib.load(img_path) nimg = nib.load(img_path)
return nimg.get_fdata(), nimg.affine, nimg.header return nimg.get_fdata(), nimg.affine, nimg.header
def save_nii(img_path, data, affine, header): def save_nii(img_path, data, affine, header):
""" """
Function to save a 'nii' or 'nii.gz' file. Function to save a 'nii' or 'nii.gz' file.
Parameters Parameters
---------- ----------
img_path: string img_path: string
Path to save the image should be ending with '.nii' or '.nii.gz'. Path to save the image should be ending with '.nii' or '.nii.gz'.
data: np.array data: np.array
Numpy array of the image data. Numpy array of the image data.
affine: list of list or np.array affine: list of list or np.array
The affine transformation to save with the image. The affine transformation to save with the image.
header: nib.Nifti1Header header: nib.Nifti1Header
The header that define everything about the data The header that define everything about the data
(pleasecheck nibabel documentation). (pleasecheck nibabel documentation).
""" """
nimg = nib.Nifti1Image(data, affine=affine, header=header) nimg = nib.Nifti1Image(data, affine=affine, header=header)
nimg.to_filename(img_path) nimg.to_filename(img_path)
# #
# Functions to process files, directories and metrics # Functions to process files, directories and metrics
# #
def metrics(img_gt, img_pred, voxel_size): def metrics(img_gt, img_pred, voxel_size):
""" """
Function to compute the metrics between two segmentation maps given as input. Function to compute the metrics between two segmentation maps given as input.
Parameters Parameters
---------- ----------
img_gt: np.array img_gt: np.array
Array of the ground truth segmentation map. Array of the ground truth segmentation map.
img_pred: np.array img_pred: np.array
Array of the predicted segmentation map. Array of the predicted segmentation map.
voxel_size: list, tuple or np.array voxel_size: list, tuple or np.array
The size of a voxel of the images used to compute the volumes. The size of a voxel of the images used to compute the volumes.
Return Return
------ ------
A list of metrics in this order, [Dice LV, Volume LV, Volume GT, Err LV(ml), A list of metrics in this order, [Dice LV, Volume LV, Volume GT, Err LV(ml),
Dice RV, Volume RV, Volume GT, Err RV(ml), Dice MYO, Volume MYO, Volume GT, Err MYO(ml)] Dice RV, Volume RV, Volume GT, Err RV(ml), Dice MYO, Volume MYO, Volume GT, Err MYO(ml)]
""" """
if img_gt.ndim != img_pred.ndim: if img_gt.ndim != img_pred.ndim:
raise ValueError( raise ValueError(
"The arrays 'img_gt' and 'img_pred' should have the " "The arrays 'img_gt' and 'img_pred' should have the "
"same dimension, {} against {}".format(img_gt.ndim, img_pred.ndim) "same dimension, {} against {}".format(img_gt.ndim, img_pred.ndim)
) )
res = [] res = []
# Loop on each classes of the input images # Loop on each classes of the input images
for c in [3, 1, 2]: for c in [3, 1, 2]:
# Copy the gt image to not alterate the input # Copy the gt image to not alterate the input
gt_c_i = np.copy(img_gt) gt_c_i = np.copy(img_gt)
gt_c_i[gt_c_i != c] = 0 gt_c_i[gt_c_i != c] = 0
# Copy the pred image to not alterate the input # Copy the pred image to not alterate the input
pred_c_i = np.copy(img_pred) pred_c_i = np.copy(img_pred)
pred_c_i[pred_c_i != c] = 0 pred_c_i[pred_c_i != c] = 0
# Clip the value to compute the volumes # Clip the value to compute the volumes
gt_c_i = np.clip(gt_c_i, 0, 1) gt_c_i = np.clip(gt_c_i, 0, 1)
pred_c_i = np.clip(pred_c_i, 0, 1) pred_c_i = np.clip(pred_c_i, 0, 1)
# Compute the Dice # Compute the Dice
# dice = dc(gt_c_i, pred_c_i) # dice = dc(gt_c_i, pred_c_i)
dice = 1 dice = 1
# Eventueel alternatief # Eventueel alternatief
gt_volume = np.sum(gt_c_i) gt_volume = np.sum(gt_c_i)
pred_volume = np.sum(pred_c_i) pred_volume = np.sum(pred_c_i)
intersect = np.sum(gt_c_i * pred_c_i) intersect = np.sum(gt_c_i * pred_c_i)
dice = (2 * intersect) / (gt_volume + pred_volume) dice = (2 * intersect) / (gt_volume + pred_volume)
# Compute volume # Compute volume
volpred = pred_c_i.sum() * np.prod(voxel_size) / 1000.0 volpred = pred_c_i.sum() * np.prod(voxel_size) / 1000.0
volgt = gt_c_i.sum() * np.prod(voxel_size) / 1000.0 volgt = gt_c_i.sum() * np.prod(voxel_size) / 1000.0
res += [dice, volpred, volgt, volpred - volgt] res += [dice, volpred, volgt, volpred - volgt]
return res return res
def compute_metrics_on_files(path_gt, path_pred): def compute_metrics_on_files(path_gt, path_pred):
""" """
Function to give the metrics for two files Function to give the metrics for two files
Parameters Parameters
---------- ----------
path_gt: string path_gt: string
Path of the ground truth image. Path of the ground truth image.
path_pred: string path_pred: string
Path of the predicted image. Path of the predicted image.
""" """
gt, _, header = load_nii(path_gt) gt, _, header = load_nii(path_gt)
pred, _, _ = load_nii(path_pred) pred, _, _ = load_nii(path_pred)
zooms = header.get_zooms() zooms = header.get_zooms()
name = os.path.basename(path_gt) name = os.path.basename(path_gt)
name = name.split(".")[0] name = name.split(".")[0]
res = metrics(gt, pred, zooms) res = metrics(gt, pred, zooms)
res = ["{:.3f}".format(r) for r in res] res = ["{:.3f}".format(r) for r in res]
formatting = "{:<20}" + "{:>12}" * len(res) formatting = "{:<20}" + "{:>12}" * len(res)
output = formatting.format(name, *res) output = formatting.format(name, *res)
print(formatting.format(*HEADER)) print(formatting.format(*HEADER))
print(output) print(output)
# formatting = "{:>14}, {:>7}, {:>9}, {:>10}, {:>7}, {:>9}, {:>10}, {:>8}, {:>10}, {:>11}" # formatting = "{:>14}, {:>7}, {:>9}, {:>10}, {:>7}, {:>9}, {:>10}, {:>8}, {:>10}, {:>11}"
# print(formatting.format(*HEADER)) # print(formatting.format(*HEADER))
# print(formatting.format(name, *res)) # print(formatting.format(name, *res))
# return [name, *res] # return [name, *res]
return res return res
def compute_metrics_on_directories(dir_gt, dir_pred): def compute_metrics_on_directories(dir_gt, dir_pred):
""" """
Function to generate a csv file for each images of two directories. Function to generate a csv file for each images of two directories.
Parameters Parameters
---------- ----------
path_gt: string path_gt: string
Directory of the ground truth segmentation maps. Directory of the ground truth segmentation maps.
path_pred: string path_pred: string
Directory of the predicted segmentation maps. Directory of the predicted segmentation maps.
""" """
lst_gt = sorted(glob(os.path.join(dir_gt, "*")), key=natural_order) lst_gt = sorted(glob(os.path.join(dir_gt, "*")), key=natural_order)
lst_pred = sorted(glob(os.path.join(dir_pred, "*")), key=natural_order) lst_pred = sorted(glob(os.path.join(dir_pred, "*")), key=natural_order)
res = [] res = []
for p_gt, p_pred in zip(lst_gt, lst_pred): for p_gt, p_pred in zip(lst_gt, lst_pred):
if os.path.basename(p_gt) != os.path.basename(p_pred): if os.path.basename(p_gt) != os.path.basename(p_pred):
raise ValueError( raise ValueError(
"The two files don't have the same name" "The two files don't have the same name"
" {}, {}.".format(os.path.basename(p_gt), os.path.basename(p_pred)) " {}, {}.".format(os.path.basename(p_gt), os.path.basename(p_pred))
) )
gt, _, header = load_nii(p_gt) gt, _, header = load_nii(p_gt)
pred, _, _ = load_nii(p_pred) pred, _, _ = load_nii(p_pred)
zooms = header.get_zooms() zooms = header.get_zooms()
res.append(metrics(gt, pred, zooms)) res.append(metrics(gt, pred, zooms))
lst_name_gt = [os.path.basename(gt).split(".")[0] for gt in lst_gt] lst_name_gt = [os.path.basename(gt).split(".")[0] for gt in lst_gt]
res = [ res = [
[ [
n, n,
] ]
+ r + r
for r, n in zip(res, lst_name_gt) for r, n in zip(res, lst_name_gt)
] ]
df = pd.DataFrame(res, columns=HEADER) df = pd.DataFrame(res, columns=HEADER)
df.to_csv("results_{}.csv".format(time.strftime("%Y%m%d_%H%M%S")), index=False) df.to_csv("results_{}.csv".format(time.strftime("%Y%m%d_%H%M%S")), index=False)
def main(path_gt, path_pred): def main(path_gt, path_pred):
""" """
Main function to select which method to apply on the input parameters. Main function to select which method to apply on the input parameters.
""" """
if os.path.isfile(path_gt) and os.path.isfile(path_pred): if os.path.isfile(path_gt) and os.path.isfile(path_pred):
compute_metrics_on_files(path_gt, path_pred) compute_metrics_on_files(path_gt, path_pred)
elif os.path.isdir(path_gt) and os.path.isdir(path_pred): elif os.path.isdir(path_gt) and os.path.isdir(path_pred):
compute_metrics_on_directories(path_gt, path_pred) compute_metrics_on_directories(path_gt, path_pred)
else: else:
raise ValueError("The paths given needs to be two directories or two files.") raise ValueError("The paths given needs to be two directories or two files.")
# if __name__ == "__main__": # if __name__ == "__main__":
# parser = argparse.ArgumentParser( # parser = argparse.ArgumentParser(
# description="Script to compute ACDC challenge metrics.") # description="Script to compute ACDC challenge metrics.")
# parser.add_argument("GT_IMG", type=str, help="Ground Truth image") # parser.add_argument("GT_IMG", type=str, help="Ground Truth image")
# parser.add_argument("PRED_IMG", type=str, help="Predicted image") # parser.add_argument("PRED_IMG", type=str, help="Predicted image")
# args = parser.parse_args() # args = parser.parse_args()
# main(args.GT_IMG, args.PRED_IMG) # main(args.GT_IMG, args.PRED_IMG)
``` ```
%% Cell type:code id:142caccb tags: %% Cell type:code id:142caccb tags:
``` python ``` python
# Testing # Testing
path_gt = os.path.join('ACDC','database','testing','patient101','patient101_frame01.nii.gz') path_gt = os.path.join('ACDC','database','testing','patient101','patient101_frame01.nii.gz')
dir_gt = os.path.join('ACDC','database','testing','patient101','patient101_frame01_gt.nii.gz') dir_gt = os.path.join('ACDC','database','testing','patient101','patient101_frame01_gt.nii.gz')
path_pred = os.path.join('ACDC','database','testing','patient101','patient101_frame01.nii.gz') path_pred = os.path.join('ACDC','database','testing','patient101','patient101_frame01.nii.gz')
dir_pred = os.path.join('ACDC','database','testing','patient101','patient101_frame01_ml_pred.nii.gz') dir_pred = os.path.join('ACDC','database','testing','patient101','patient101_frame01_ml_pred.nii.gz')
test_files = compute_metrics_on_files(dir_gt, dir_pred) test_files = compute_metrics_on_files(dir_gt, dir_pred)
``` ```
%% Cell type:code id:0451a847 tags: %% Cell type:code id:0451a847 tags:
``` python ``` python
data_path = "ACDC/database" data_path = "ACDC/database"
test_dict_gt = build_dict_images(data_path, mode='testing').ravel() test_dict_gt = build_dict_images(data_path, mode='testing').ravel()
test_dict_pred = build_dict_images_pred(data_path, mode='testing') test_dict_pred = build_dict_images_pred(data_path, mode='testing')
# Ground truth # Ground truth
# First file # First file
print(len(test_dict_gt)) print(len(test_dict_gt))
image_paths_gt_ED = [d['image'] for d in test_dict_gt if d["first_file"]] image_paths_gt_ED = [d['image'] for d in test_dict_gt if d["first_file"]]
label_paths_gt_ED = [d['label'] for d in test_dict_gt if d["first_file"]] label_paths_gt_ED = [d['label'] for d in test_dict_gt if d["first_file"]]
# Last file # Last file
image_paths_gt_ES = [d['image'] for d in test_dict_gt if not d["first_file"]] image_paths_gt_ES = [d['image'] for d in test_dict_gt if not d["first_file"]]
label_paths_gt_ES = [d['label'] for d in test_dict_gt if not d["first_file"]] label_paths_gt_ES = [d['label'] for d in test_dict_gt if not d["first_file"]]
print('Image path gt ED is ', image_paths_gt_ED) print('Image path gt ED is ', image_paths_gt_ED)
print('Label path gt ED is ', label_paths_gt_ED) print('Label path gt ED is ', label_paths_gt_ED)
print('Image path gt ES is ', image_paths_gt_ES) print('Image path gt ES is ', image_paths_gt_ES)
print('Label path gt ES is ', label_paths_gt_ES) print('Label path gt ES is ', label_paths_gt_ES)
# Prediction # Prediction
image_paths_pred_ED = [d['image'] for d in test_dict_pred if d["first_file"]] image_paths_pred_ED = [d['image'] for d in test_dict_pred if d["first_file"]]
label_paths_pred_ED = [d['label'] for d in test_dict_pred if d["first_file"]] label_paths_pred_ED = [d['label'] for d in test_dict_pred if d["first_file"]]
image_paths_pred_ES = [d['image'] for d in test_dict_pred if not d["first_file"]] image_paths_pred_ES = [d['image'] for d in test_dict_pred if not d["first_file"]]
label_paths_pred_ES = [d['label'] for d in test_dict_pred if not d["first_file"]] label_paths_pred_ES = [d['label'] for d in test_dict_pred if not d["first_file"]]
print('Image path prediction ED is ', image_paths_pred_ED) print('Image path prediction ED is ', image_paths_pred_ED)
print('Label path prediction ED is ', label_paths_pred_ED) print('Label path prediction ED is ', label_paths_pred_ED)
print('Image path prediction ES is ', image_paths_pred_ES) print('Image path prediction ES is ', image_paths_pred_ES)
print('Label path prediction ES is ', label_paths_pred_ES) print('Label path prediction ES is ', label_paths_pred_ES)
``` ```
%% Cell type:code id:6c7be742 tags: %% Cell type:code id:6c7be742 tags:
``` python ``` python
test_files_ED = [compute_metrics_on_files(label_paths_gt_ED[i], label_paths_pred_ED[i]) for i in range(len(label_paths_gt_ED))] test_files_ED = [compute_metrics_on_files(label_paths_gt_ED[i], label_paths_pred_ED[i]) for i in range(len(label_paths_gt_ED))]
test_files_ES = [compute_metrics_on_files(label_paths_gt_ES[i], label_paths_pred_ES[i]) for i in range(len(label_paths_gt_ES))] test_files_ES = [compute_metrics_on_files(label_paths_gt_ES[i], label_paths_pred_ES[i]) for i in range(len(label_paths_gt_ES))]
``` ```
%% Cell type:code id:049a0262 tags: %% Cell type:code id:049a0262 tags:
``` python ``` python
# ED # ED
LV_GT_ED = np.array([float(metric[2]) for metric in test_files_ED]) LV_GT_ED = np.array([float(metric[2]) for metric in test_files_ED])
LV_pred_ED = np.array([float(metric[1]) for metric in test_files_ED]) LV_pred_ED = np.array([float(metric[1]) for metric in test_files_ED])
LV_dice_ED = np.array([float(metric[0]) for metric in test_files_ED]) LV_dice_ED = np.array([float(metric[0]) for metric in test_files_ED])
LV_err_ED = np.array([float(metric[3]) for metric in test_files_ED]) LV_err_ED = np.array([float(metric[3]) for metric in test_files_ED])
RV_GT_ED = np.array([float(metric[6]) for metric in test_files_ED]) RV_GT_ED = np.array([float(metric[6]) for metric in test_files_ED])
RV_pred_ED = np.array([float(metric[5]) for metric in test_files_ED]) RV_pred_ED = np.array([float(metric[5]) for metric in test_files_ED])
RV_dice_ED = np.array([float(metric[4]) for metric in test_files_ED]) RV_dice_ED = np.array([float(metric[4]) for metric in test_files_ED])
RV_err_ED = np.array([float(metric[7]) for metric in test_files_ED]) RV_err_ED = np.array([float(metric[7]) for metric in test_files_ED])
MYO_GT_ED = np.array([float(metric[10]) for metric in test_files_ED]) MYO_GT_ED = np.array([float(metric[10]) for metric in test_files_ED])
MYO_pred_ED = np.array([float(metric[9]) for metric in test_files_ED]) MYO_pred_ED = np.array([float(metric[9]) for metric in test_files_ED])
MYO_dice_ED = np.array([float(metric[8]) for metric in test_files_ED]) MYO_dice_ED = np.array([float(metric[8]) for metric in test_files_ED])
MYO_err_ED = np.array([float(metric[11]) for metric in test_files_ED]) MYO_err_ED = np.array([float(metric[11]) for metric in test_files_ED])
# ES # ES
LV_GT_ES = np.array([float(metric[2]) for metric in test_files_ES]) LV_GT_ES = np.array([float(metric[2]) for metric in test_files_ES])
LV_pred_ES = np.array([float(metric[1]) for metric in test_files_ES]) LV_pred_ES = np.array([float(metric[1]) for metric in test_files_ES])
LV_dice_ES = np.array([float(metric[0]) for metric in test_files_ES]) LV_dice_ES = np.array([float(metric[0]) for metric in test_files_ES])
LV_err_ES = np.array([float(metric[3]) for metric in test_files_ES]) LV_err_ES = np.array([float(metric[3]) for metric in test_files_ES])
RV_GT_ES = np.array([float(metric[6]) for metric in test_files_ES]) RV_GT_ES = np.array([float(metric[6]) for metric in test_files_ES])
RV_pred_ES = np.array([float(metric[5]) for metric in test_files_ES]) RV_pred_ES = np.array([float(metric[5]) for metric in test_files_ES])
RV_dice_ES = np.array([float(metric[4]) for metric in test_files_ES]) RV_dice_ES = np.array([float(metric[4]) for metric in test_files_ES])
RV_err_ES = np.array([float(metric[7]) for metric in test_files_ES]) RV_err_ES = np.array([float(metric[7]) for metric in test_files_ES])
MYO_GT_ES = np.array([float(metric[10]) for metric in test_files_ES]) MYO_GT_ES = np.array([float(metric[10]) for metric in test_files_ES])
MYO_pred_ES = np.array([float(metric[9]) for metric in test_files_ES]) MYO_pred_ES = np.array([float(metric[9]) for metric in test_files_ES])
MYO_dice_ES = np.array([float(metric[8]) for metric in test_files_ES]) MYO_dice_ES = np.array([float(metric[8]) for metric in test_files_ES])
MYO_err_ES = np.array([float(metric[11]) for metric in test_files_ES]) MYO_err_ES = np.array([float(metric[11]) for metric in test_files_ES])
``` ```
%% Cell type:markdown id:7c36b2e0 tags: %% Cell type:markdown id:7c36b2e0 tags:
Other ways for visualization Other ways for visualization
> Comparison graphs > Comparison graphs
%% Cell type:code id:6fd62eb7 tags: %% Cell type:code id:6fd62eb7 tags:
``` python ``` python
def show_comparison_graph(gt, pred, dice, err, color_boundaries=20, part="LV", mode="ED"): def show_comparison_graph(gt, pred, dice, err, color_boundaries=20, part="LV", mode="ED"):
plt.clf() plt.clf()
vmin = min(err) vmin = min(err)
vmax = max(err) vmax = max(err)
colors = ['red', 'yellow', 'green', 'yellow', 'red'] colors = ['red', 'yellow', 'green', 'yellow', 'red']
color_positions = [vmin, -1 * color_boundaries, 0, color_boundaries, vmax] color_positions = [vmin, -1 * color_boundaries, 0, color_boundaries, vmax]
norm = plt.Normalize(vmin, vmax) norm = plt.Normalize(vmin, vmax)
print(norm(color_positions)) print(norm(color_positions))
colormap = list(zip(norm(color_positions), colors)) colormap = list(zip(norm(color_positions), colors))
cmap = LinearSegmentedColormap.from_list("custom_cmap", colormap) cmap = LinearSegmentedColormap.from_list("custom_cmap", colormap)
plt.scatter(gt, pred, c=err, cmap=cmap, vmin=vmin, vmax=vmax, s=dice * 100) plt.scatter(gt, pred, c=err, cmap=cmap, vmin=vmin, vmax=vmax, s=dice * 100)
plt.plot([min(gt), max(gt)], [min(gt), max(gt)], 'k--') plt.plot([min(gt), max(gt)], [min(gt), max(gt)], 'k--')
plt.xlabel(f'Ground Truth {part} Volume [ml]') plt.xlabel(f'Ground Truth {part} Volume [ml]')
plt.ylabel(f'Predicted {part} Volume [ml]') plt.ylabel(f'Predicted {part} Volume [ml]')
plt.title(f'Ground Truth vs. Predicted Volume of {part} {mode}') plt.title(f'Ground Truth vs. Predicted Volume of {part} {mode}')
cbar = plt.colorbar() cbar = plt.colorbar()
cbar.set_label('Error Value') cbar.set_label('Error Value')
plt.show() plt.show()
``` ```
%% Cell type:code id:1541570e tags: %% Cell type:code id:1541570e tags:
``` python ``` python
show_comparison_graph(LV_GT_ED, LV_pred_ED, LV_dice_ED, LV_err_ED) show_comparison_graph(LV_GT_ED, LV_pred_ED, LV_dice_ED, LV_err_ED)
``` ```
%% Cell type:code id:93c71590 tags: %% Cell type:code id:93c71590 tags:
``` python ``` python
show_comparison_graph(RV_GT_ED, RV_pred_ED, RV_dice_ED, RV_err_ED, color_boundaries=18, part="RV") show_comparison_graph(RV_GT_ED, RV_pred_ED, RV_dice_ED, RV_err_ED, color_boundaries=18, part="RV")
``` ```
%% Cell type:code id:fa1aa317 tags: %% Cell type:code id:fa1aa317 tags:
``` python ``` python
show_comparison_graph(MYO_GT_ED, MYO_pred_ED, MYO_dice_ED, MYO_err_ED, color_boundaries=18, part="MYO") show_comparison_graph(MYO_GT_ED, MYO_pred_ED, MYO_dice_ED, MYO_err_ED, color_boundaries=18, part="MYO")
``` ```
%% Cell type:markdown id:5b91e382 tags: %% Cell type:markdown id:5b91e382 tags:
> Correlations > Correlations
%% Cell type:code id:a21ffc52 tags: %% Cell type:code id:a21ffc52 tags:
``` python ``` python
# Correlation coefficients # Correlation coefficients
def calculate_correlation(gt, pred, err): def calculate_correlation(gt, pred, err):
gt_mean = statistics.mean(gt) gt_mean = statistics.mean(gt)
gt_std = statistics.stdev(gt) gt_std = statistics.stdev(gt)
pred_mean = statistics.mean(pred) pred_mean = statistics.mean(pred)
pred_std = statistics.stdev(pred) pred_std = statistics.stdev(pred)
covariance = sum((x - gt_mean) * (y - pred_mean) for x, y in zip(gt, pred)) / len(gt) covariance = sum((x - gt_mean) * (y - pred_mean) for x, y in zip(gt, pred)) / len(gt)
corr = covariance / (gt_std * pred_std) corr = covariance / (gt_std * pred_std)
bias = statistics.mean(err) bias = statistics.mean(err)
err_std = statistics.stdev(err) err_std = statistics.stdev(err)
loa = 1.96*err_std loa = 1.96*err_std
return corr, bias, loa return corr, bias, loa
``` ```
%% Cell type:code id:a4ecf214 tags: %% Cell type:code id:a4ecf214 tags:
``` python ``` python
LV_corr, LV_bias, LV_LOA = calculate_correlation(LV_GT_ED, LV_pred_ED, LV_err_ED) LV_corr, LV_bias, LV_LOA = calculate_correlation(LV_GT_ED, LV_pred_ED, LV_err_ED)
print('The correlation coefficient for left ventricle EDV is ', LV_corr) print('The correlation coefficient for left ventricle EDV is ', LV_corr)
print('The bias for the left ventricle EDV is ', LV_bias) print('The bias for the left ventricle EDV is ', LV_bias)
print('The LOA of the left ventricle EDV is ', LV_LOA) print('The LOA of the left ventricle EDV is ', LV_LOA)
``` ```
%% Cell type:code id:12c061d2 tags: %% Cell type:code id:12c061d2 tags:
``` python ``` python
RV_corr, RV_bias, RV_LOA = calculate_correlation(RV_GT_ED, RV_pred_ED, RV_err_ED) RV_corr, RV_bias, RV_LOA = calculate_correlation(RV_GT_ED, RV_pred_ED, RV_err_ED)
print('The correlation coefficient for right ventricle EDV is ', RV_corr) print('The correlation coefficient for right ventricle EDV is ', RV_corr)
print('The bias for the right ventricle EDV is ', RV_bias) print('The bias for the right ventricle EDV is ', RV_bias)
print('The LOA of the right ventricle EDV is ', RV_LOA) print('The LOA of the right ventricle EDV is ', RV_LOA)
``` ```
%% Cell type:code id:ccb53139 tags: %% Cell type:code id:ccb53139 tags:
``` python ``` python
MYO_corr, MYO_bias, MYO_LOA = calculate_correlation(MYO_GT_ED, MYO_pred_ED, MYO_err_ED) MYO_corr, MYO_bias, MYO_LOA = calculate_correlation(MYO_GT_ED, MYO_pred_ED, MYO_err_ED)
print('The correlation coefficient for myocardium EDV is ', MYO_corr) print('The correlation coefficient for myocardium EDV is ', MYO_corr)
print('The bias for the left myocardium EDV is ', MYO_bias) print('The bias for the left myocardium EDV is ', MYO_bias)
print('The LOA of the left myocardium EDV is ', MYO_LOA) print('The LOA of the left myocardium EDV is ', MYO_LOA)
``` ```
%% Cell type:code id:d85b01dd tags: %% Cell type:code id:d85b01dd tags:
``` python ``` python
# Average dice # Average dice
av_LV_dice = statistics.mean(LV_dice_ED) av_LV_dice = statistics.mean(LV_dice_ED)
av_RV_dice = statistics.mean(RV_dice_ED) av_RV_dice = statistics.mean(RV_dice_ED)
av_MYO_dice = statistics.mean(MYO_dice_ED) av_MYO_dice = statistics.mean(MYO_dice_ED)
print('Average dice left ventricle is ', av_LV_dice) print('Average dice left ventricle is ', av_LV_dice)
print('Average dice right ventricle is ', av_RV_dice) print('Average dice right ventricle is ', av_RV_dice)
print('Average dice myocardium is ', av_MYO_dice) print('Average dice myocardium is ', av_MYO_dice)
``` ```
%% Cell type:markdown id:ce64f7e5 tags: %% Cell type:markdown id:ce64f7e5 tags:
ES ES
%% Cell type:code id:437c2b32 tags: %% Cell type:code id:437c2b32 tags:
``` python ``` python
show_comparison_graph(LV_GT_ES, LV_pred_ES, LV_dice_ES, LV_err_ES, color_boundaries=15, part="LV", mode="ES") show_comparison_graph(LV_GT_ES, LV_pred_ES, LV_dice_ES, LV_err_ES, color_boundaries=15, part="LV", mode="ES")
``` ```
%% Cell type:code id:0ff66a64 tags: %% Cell type:code id:0ff66a64 tags:
``` python ``` python
show_comparison_graph(RV_GT_ES, RV_pred_ES, RV_dice_ES, RV_err_ES, color_boundaries=10, part="RV", mode="ES") show_comparison_graph(RV_GT_ES, RV_pred_ES, RV_dice_ES, RV_err_ES, color_boundaries=10, part="RV", mode="ES")
``` ```
%% Cell type:code id:931d0abb tags: %% Cell type:code id:931d0abb tags:
``` python ``` python
show_comparison_graph(MYO_GT_ES, MYO_pred_ES, MYO_dice_ES, MYO_err_ES, color_boundaries=30, part="MYO", mode="ES") show_comparison_graph(MYO_GT_ES, MYO_pred_ES, MYO_dice_ES, MYO_err_ES, color_boundaries=30, part="MYO", mode="ES")
``` ```
%% Cell type:code id:4d30933c tags: %% Cell type:code id:4d30933c tags:
``` python ``` python
LV_corr, LV_bias, LV_LOA = calculate_correlation(LV_GT_ES, LV_pred_ES, LV_err_ES) LV_corr, LV_bias, LV_LOA = calculate_correlation(LV_GT_ES, LV_pred_ES, LV_err_ES)
print('The correlation coefficient for left ventricle ESV is ', LV_corr) print('The correlation coefficient for left ventricle ESV is ', LV_corr)
print('The bias for the left ventricle ESV is ', LV_bias) print('The bias for the left ventricle ESV is ', LV_bias)
print('The LOA of the left ventricle ESV is ', LV_LOA) print('The LOA of the left ventricle ESV is ', LV_LOA)
``` ```
%% Cell type:code id:823112fe tags: %% Cell type:code id:823112fe tags:
``` python ``` python
RV_corr, RV_bias, RV_LOA = calculate_correlation(RV_GT_ES, RV_pred_ES, RV_err_ES) RV_corr, RV_bias, RV_LOA = calculate_correlation(RV_GT_ES, RV_pred_ES, RV_err_ES)
print('The correlation coefficient for right ventricle ESV is ', RV_corr) print('The correlation coefficient for right ventricle ESV is ', RV_corr)
print('The bias for the right ventricle ESV is ', RV_bias) print('The bias for the right ventricle ESV is ', RV_bias)
print('The LOA of the right ventricle ESV is ', RV_LOA) print('The LOA of the right ventricle ESV is ', RV_LOA)
``` ```
%% Cell type:code id:50408bf1 tags: %% Cell type:code id:50408bf1 tags:
``` python ``` python
MYO_corr, MYO_bias, MYO_LOA = calculate_correlation(MYO_GT_ES, MYO_pred_ES, MYO_err_ES) MYO_corr, MYO_bias, MYO_LOA = calculate_correlation(MYO_GT_ES, MYO_pred_ES, MYO_err_ES)
print('The correlation coefficient for myocardium ESV is ', MYO_corr) print('The correlation coefficient for myocardium ESV is ', MYO_corr)
print('The bias for the left myocardium ESV is ', MYO_bias) print('The bias for the left myocardium ESV is ', MYO_bias)
print('The LOA of the left myocardium ESV is ', MYO_LOA) print('The LOA of the left myocardium ESV is ', MYO_LOA)
``` ```
%% Cell type:code id:cf0f0d91 tags: %% Cell type:code id:cf0f0d91 tags:
``` python ``` python
# Average dice # Average dice
av_LV_dice = statistics.mean(LV_dice_ES) av_LV_dice = statistics.mean(LV_dice_ES)
av_RV_dice = statistics.mean(RV_dice_ES) av_RV_dice = statistics.mean(RV_dice_ES)
av_MYO_dice = statistics.mean(MYO_dice_ES) av_MYO_dice = statistics.mean(MYO_dice_ES)
print('Average dice left ventricle ESV is ', av_LV_dice) print('Average dice left ventricle ESV is ', av_LV_dice)
print('Average dice right ventricle ESV is ', av_RV_dice) print('Average dice right ventricle ESV is ', av_RV_dice)
print('Average dice myocardium ESV is ', av_MYO_dice) print('Average dice myocardium ESV is ', av_MYO_dice)
``` ```
%% Cell type:markdown id:ecf8346f tags: %% Cell type:markdown id:ecf8346f tags:
EJ EJ
%% Cell type:code id:c1671753 tags: %% Cell type:code id:c1671753 tags:
``` python ``` python
EJ_LV_GT = [((LV_GT_ED[i] - LV_GT_ES[i])/ LV_GT_ED[i]) * 100 for i in range(len(LV_GT_ED))] EJ_LV_GT = [((LV_GT_ED[i] - LV_GT_ES[i])/ LV_GT_ED[i]) * 100 for i in range(len(LV_GT_ED))]
EJ_RV_GT = [((RV_GT_ED[i] - RV_GT_ES[i])/ RV_GT_ED[i]) * 100 for i in range(len(RV_GT_ED))] EJ_RV_GT = [((RV_GT_ED[i] - RV_GT_ES[i])/ RV_GT_ED[i]) * 100 for i in range(len(RV_GT_ED))]
EJ_LV_pred = [((LV_pred_ED[i] - LV_pred_ES[i])/ (LV_pred_ED[i]+0.1)) * 100 for i in range(len(LV_pred_ED))] EJ_LV_pred = [((LV_pred_ED[i] - LV_pred_ES[i])/ (LV_pred_ED[i]+0.1)) * 100 for i in range(len(LV_pred_ED))]
EJ_RV_pred = [((RV_pred_ED[i] - RV_pred_ES[i])/ (RV_pred_ED[i]+0.1)) * 100 for i in range(len(RV_pred_ED))] EJ_RV_pred = [((RV_pred_ED[i] - RV_pred_ES[i])/ (RV_pred_ED[i]+0.1)) * 100 for i in range(len(RV_pred_ED))]
EJ_LV_err = [EJ_LV_GT[i] - EJ_LV_pred[i] for i in range(len(EJ_LV_GT))] EJ_LV_err = [EJ_LV_GT[i] - EJ_LV_pred[i] for i in range(len(EJ_LV_GT))]
EJ_RV_err = [EJ_RV_GT[i] - EJ_RV_pred[i] for i in range(len(EJ_RV_GT))] EJ_RV_err = [EJ_RV_GT[i] - EJ_RV_pred[i] for i in range(len(EJ_RV_GT))]
EJ_LV_corr, EJ_LV_bias, LV_LOA = calculate_correlation(EJ_LV_GT, EJ_LV_pred, EJ_LV_err) EJ_LV_corr, EJ_LV_bias, LV_LOA = calculate_correlation(EJ_LV_GT, EJ_LV_pred, EJ_LV_err)
print('The correlation of the ejection fraction for left ventricle is ', EJ_LV_corr) print('The correlation of the ejection fraction for left ventricle is ', EJ_LV_corr)
print('The mean bias of the ejection fraction of left ventricle is ', EJ_LV_bias) print('The mean bias of the ejection fraction of left ventricle is ', EJ_LV_bias)
print('The LOA of the ejection fraction of the left ventricle is ', LV_LOA) print('The LOA of the ejection fraction of the left ventricle is ', LV_LOA)
EJ_RV_corr, EJ_RV_bias, RV_LOA = calculate_correlation(EJ_RV_GT, EJ_RV_pred, EJ_RV_err) EJ_RV_corr, EJ_RV_bias, RV_LOA = calculate_correlation(EJ_RV_GT, EJ_RV_pred, EJ_RV_err)
print('The correlation of the ejection fraction for right ventricle is ', EJ_RV_corr) print('The correlation of the ejection fraction for right ventricle is ', EJ_RV_corr)
print('The mean bias of the ejection fraction of right ventricle is ', EJ_RV_bias) print('The mean bias of the ejection fraction of right ventricle is ', EJ_RV_bias)
print('The LOA of the ejection fraction of the right ventricle is ', RV_LOA) print('The LOA of the ejection fraction of the right ventricle is ', RV_LOA)
``` ```
%% Cell type:markdown id:de9c181c tags: %% Cell type:markdown id:de9c181c tags:
Best and worst results Best and worst results
%% Cell type:code id:c9a463de tags: %% Cell type:code id:c9a463de tags:
``` python ``` python
# ES # This determines the best and worst performing predictions based on the dice score of all the segments and both time frames
max_value = max(LV_dice_ES) # Uses the LV_dice_ES list to find back the correct patient, but could be any of the lists
min_value = sorted(LV_dice_ES)[1]
score_sum = {LV_dice_ES[x]: sum([LV_dice_ES[x], RV_dice_ES[x], MYO_dice_ES[x], LV_dice_ED[x], RV_dice_ED[x], MYO_dice_ED[x]]) for x in range(len(LV_dice_ES))}
max_value = max(score_sum, key=score_sum.get)
min_value = min(score_sum, key=score_sum.get)
print(min_value)
del score_sum[min_value]
min_value = min(score_sum, key=score_sum.get)
print(min_value)
max_index = np.where(LV_dice_ES == max_value)[0][0] max_index = np.where(LV_dice_ES == max_value)[0][0]
min_index = np.where(LV_dice_ES == min_value)[0][0] min_index = np.where(LV_dice_ES == min_value)[0][0]
print("Max value:", max_value) print("Max value:", max_value)
print("Max index:", max_index) print("Max index:", max_index)
print("Min value:", min_value) print("Min value:", min_value)
print("Min index:", min_index) print("Min index:", min_index)
# Show slices # Show slices
``` ```
%% Cell type:code id:0ef0ed36 tags: %% Cell type:code id:0ef0ed36 tags:
``` python ``` python
print('Best image path gt ES is ', image_paths_gt_ES[max_index]) print('Best image path gt ES is ', image_paths_gt_ES[max_index])
print('Best label path gt ES is ', label_paths_gt_ES[max_index]) print('Best label path gt ES is ', label_paths_gt_ES[max_index])
print('Best image path prediction ES is ', image_paths_pred_ES[max_index]) print('Best image path prediction ES is ', image_paths_pred_ES[max_index])
print('Best label path prediction ES is ', label_paths_pred_ES[max_index]) print('Best label path prediction ES is ', label_paths_pred_ES[max_index])
``` ```
%% Cell type:code id:35b72c7c tags: %% Cell type:code id:35b72c7c tags:
``` python ``` python
def get_slice(index, image_gt, label_gt, image_pred, label_pred, channel_index=0, frame='ES'): def get_slice(index, image_gt, label_gt, image_pred, label_pred, channel_index=0, frame='ES'):
image_path_gt = image_gt[index] image_path_gt = image_gt[index]
label_path_gt = label_gt[index] label_path_gt = label_gt[index]
image_path_pred = image_pred[index] image_path_pred = image_pred[index]
label_path_pred = label_pred[index] label_path_pred = label_pred[index]
# Load image and label data # Load image and label data
image_gt = nib.load(image_path_gt).get_fdata() image_gt = nib.load(image_path_gt).get_fdata()
label_gt_ES = nib.load(label_path_gt).get_fdata() label_gt_ES = nib.load(label_path_gt).get_fdata()
image_pred = nib.load(image_path_pred).get_fdata() image_pred = nib.load(image_path_pred).get_fdata()
label_pred = nib.load(label_path_pred).get_fdata() label_pred = nib.load(label_path_pred).get_fdata()
return ( return (
[ [
image_gt[..., channel_index], image_gt[..., channel_index],
label_gt_ES[..., channel_index], label_gt_ES[..., channel_index],
image_pred[..., channel_index], image_pred[..., channel_index],
label_pred[..., channel_index] label_pred[..., channel_index]
] ]
) )
def display_slices(min_index, max_index, image_gt, label_gt, image_pred, label_pred, channel_index=0, frame='ES'): def display_slices(min_index, max_index, image_gt, label_gt, image_pred, label_pred, channel_index=0, frame='ES'):
best_image_gt_ES_channel, best_label_gt_ES_channel, best_image_pred_ES_channel, best_label_pred_ES_channel = get_slice(max_index, image_gt, label_gt, image_pred, label_pred, channel_index) best_image_gt_ES_channel, best_label_gt_ES_channel, best_image_pred_ES_channel, best_label_pred_ES_channel = get_slice(max_index, image_gt, label_gt, image_pred, label_pred, channel_index)
worst_image_gt_ES_channel, worst_label_gt_ES_channel, worst_image_pred_ES_channel, worst_label_pred_ES_channel = get_slice(min_index, image_gt, label_gt, image_pred, label_pred, channel_index) worst_image_gt_ES_channel, worst_label_gt_ES_channel, worst_image_pred_ES_channel, worst_label_pred_ES_channel = get_slice(min_index, image_gt, label_gt, image_pred, label_pred, channel_index)
plt.subplot(2, 2, 1) plt.subplot(2, 2, 1)
plt.imshow(best_image_gt_ES_channel, cmap='gray') plt.imshow(best_image_gt_ES_channel, cmap='gray')
plt.title(f"Ground truth best prediction {frame}") plt.title(f"Ground truth best prediction {frame}")
plt.imshow(best_label_gt_ES_channel, alpha=0.5, cmap='Reds') plt.imshow(best_label_gt_ES_channel, alpha=0.5, cmap='Reds')
plt.subplot(2, 2, 2) plt.subplot(2, 2, 2)
plt.imshow(best_image_pred_ES_channel, cmap='gray') plt.imshow(best_image_pred_ES_channel, cmap='gray')
plt.title(f"Best prediction {frame}") plt.title(f"Best prediction {frame}")
plt.imshow(best_label_pred_ES_channel, cmap='Reds', alpha=0.5) plt.imshow(best_label_pred_ES_channel, cmap='Reds', alpha=0.5)
plt.subplot(2, 2, 3) plt.subplot(2, 2, 3)
plt.imshow(worst_image_gt_ES_channel, cmap='gray') plt.imshow(worst_image_gt_ES_channel, cmap='gray')
plt.title(f"Ground truth worst prediction {frame}") plt.title(f"Ground truth worst prediction {frame}")
plt.imshow(worst_label_gt_ES_channel, alpha=0.5, cmap='Reds') plt.imshow(worst_label_gt_ES_channel, alpha=0.5, cmap='Reds')
plt.subplot(2, 2, 4) plt.subplot(2, 2, 4)
plt.imshow(worst_image_pred_ES_channel, cmap='gray') plt.imshow(worst_image_pred_ES_channel, cmap='gray')
plt.title(f"Worst prediction {frame}") plt.title(f"Worst prediction {frame}")
plt.imshow(worst_label_pred_ES_channel, cmap='Reds', alpha=0.5) plt.imshow(worst_label_pred_ES_channel, cmap='Reds', alpha=0.5)
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()
``` ```
%% Cell type:code id:8074a3d4 tags: %% Cell type:code id:8074a3d4 tags:
``` python ``` python
display_slices(min_index, max_index, image_paths_gt_ES, label_paths_gt_ES, image_paths_pred_ES, label_paths_pred_ES, channel_index=0, frame='ES') display_slices(min_index, max_index, image_paths_gt_ES, label_paths_gt_ES, image_paths_pred_ES, label_paths_pred_ES, channel_index=4, frame='ES')
compute_metrics_on_files(label_paths_gt_ED[max_index], label_paths_pred_ED[max_index])
compute_metrics_on_files(label_paths_gt_ES[max_index], label_paths_pred_ES[max_index])
print()
``` ```
%% Cell type:code id:e2d5604a tags: %% Cell type:code id:e2d5604a tags:
``` python ``` python
display_slices(min_index, max_index, image_paths_gt_ED, label_paths_gt_ED, image_paths_pred_ED, label_paths_pred_ED, channel_index=1, frame='ED') display_slices(min_index, max_index, image_paths_gt_ED, label_paths_gt_ED, image_paths_pred_ED, label_paths_pred_ED, channel_index=4, frame='ED')
compute_metrics_on_files(label_paths_gt_ED[min_index], label_paths_pred_ED[min_index])
compute_metrics_on_files(label_paths_gt_ES[min_index], label_paths_pred_ES[min_index])
print()
``` ```
%% Cell type:markdown id:c4ee1d4a tags: %% Cell type:markdown id:c4ee1d4a tags:
Hausdorff distances Hausdorff distances
%% Cell type:code id:a633891f tags: %% Cell type:code id:a633891f tags:
``` python ``` python
# First file # First file
label_paths_gt_ED = [d['label'] for d in test_dict_gt if d["first_file"]] label_paths_gt_ED = [d['label'] for d in test_dict_gt if d["first_file"]]
print(len(label_paths_gt_ED)) print(len(label_paths_gt_ED))
# Last file # Last file
label_paths_gt_ES = [d['label'] for d in test_dict_gt if not d["first_file"]] label_paths_gt_ES = [d['label'] for d in test_dict_gt if not d["first_file"]]
# Prediction # Prediction
label_paths_pred_ED = [d['label'] for d in test_dict_pred if d["first_file"]] label_paths_pred_ED = [d['label'] for d in test_dict_pred if d["first_file"]]
label_paths_pred_ES = [d['label'] for d in test_dict_pred if not d["first_file"]] label_paths_pred_ES = [d['label'] for d in test_dict_pred if not d["first_file"]]
``` ```
%% Cell type:code id:f3a8a4e0 tags: %% Cell type:code id:f3a8a4e0 tags:
``` python ``` python
# GT and pred of ED # GT and pred of ED
dir_gt_ED = [] dir_gt_ED = []
dir_pred_ED = [] dir_pred_ED = []
for label_path_gt_ED, label_path_pred_ED in zip(label_paths_gt_ED, label_paths_pred_ED): for label_path_gt_ED, label_path_pred_ED in zip(label_paths_gt_ED, label_paths_pred_ED):
label_GT_ED = nib.load(label_path_gt_ED).get_fdata() label_GT_ED = nib.load(label_path_gt_ED).get_fdata()
label_pred_ED = nib.load(label_path_pred_ED).get_fdata() label_pred_ED = nib.load(label_path_pred_ED).get_fdata()
dir_gt_ED.append(label_GT_ED) dir_gt_ED.append(label_GT_ED)
dir_pred_ED.append(label_pred_ED) dir_pred_ED.append(label_pred_ED)
# GT and pred of ES # GT and pred of ES
dir_gt_ES = [] dir_gt_ES = []
dir_pred_ES = [] dir_pred_ES = []
for label_path_gt, label_path_pred in zip(label_paths_gt_ES, label_paths_pred_ES): for label_path_gt, label_path_pred in zip(label_paths_gt_ES, label_paths_pred_ES):
label_GT_ES = nib.load(label_path_gt).get_fdata() label_GT_ES = nib.load(label_path_gt).get_fdata()
label_pred = nib.load(label_path_pred).get_fdata() label_pred = nib.load(label_path_pred).get_fdata()
dir_gt_ES.append(label_GT_ES) dir_gt_ES.append(label_GT_ES)
dir_pred_ES.append(label_pred) dir_pred_ES.append(label_pred)
``` ```
%% Cell type:code id:106f8cb8 tags: %% Cell type:code id:106f8cb8 tags:
``` python ``` python
label_gt_ED = [] label_gt_ED = []
label_gt_ED_LV = [] label_gt_ED_LV = []
label_gt_ED_wall = [] label_gt_ED_wall = []
label_gt_ED_RV = [] label_gt_ED_RV = []
for i in range(50): for i in range(50):
label_gt_ED_single = np.array(dir_gt_ED[i], dtype=np.float32) label_gt_ED_single = np.array(dir_gt_ED[i], dtype=np.float32)
label_gt_ED.append(label_gt_ED_single) label_gt_ED.append(label_gt_ED_single)
label_gt_ED_LV.append(np.where(label_gt_ED_single == 1,1,0)) label_gt_ED_LV.append(np.where(label_gt_ED_single == 1,1,0))
label_gt_ED_wall.append(np.where(label_gt_ED_single == 2,1,0)) label_gt_ED_wall.append(np.where(label_gt_ED_single == 2,1,0))
label_gt_ED_RV.append(np.where(label_gt_ED_single == 3,1,0)) label_gt_ED_RV.append(np.where(label_gt_ED_single == 3,1,0))
label_pred_ED = [] label_pred_ED = []
label_pred_ED_LV = [] label_pred_ED_LV = []
label_pred_ED_wall = [] label_pred_ED_wall = []
label_pred_ED_RV = [] label_pred_ED_RV = []
for i in range(50): for i in range(50):
label_pred_ED_single = np.array(dir_pred_ED[i], dtype=np.float32) label_pred_ED_single = np.array(dir_pred_ED[i], dtype=np.float32)
label_pred_ED.append(label_pred_ED_single) label_pred_ED.append(label_pred_ED_single)
label_pred_ED_LV.append(np.where(label_pred_ED_single == 1,1,0)) label_pred_ED_LV.append(np.where(label_pred_ED_single == 1,1,0))
label_pred_ED_wall.append(np.where(label_pred_ED_single == 2,1,0)) label_pred_ED_wall.append(np.where(label_pred_ED_single == 2,1,0))
label_pred_ED_RV.append(np.where(label_pred_ED_single == 3,1,0)) label_pred_ED_RV.append(np.where(label_pred_ED_single == 3,1,0))
label_gt_ES = [] label_gt_ES = []
label_gt_ES_LV = [] label_gt_ES_LV = []
label_gt_ES_wall = [] label_gt_ES_wall = []
label_gt_ES_RV = [] label_gt_ES_RV = []
for i in range(50): for i in range(50):
label_gt_ES_single = np.array(dir_gt_ES[i], dtype=np.float32) label_gt_ES_single = np.array(dir_gt_ES[i], dtype=np.float32)
label_gt_ES.append(label_gt_ES_single) label_gt_ES.append(label_gt_ES_single)
label_gt_ES_LV.append(np.where(label_gt_ES_single == 1,1,0)) label_gt_ES_LV.append(np.where(label_gt_ES_single == 1,1,0))
label_gt_ES_wall.append(np.where(label_gt_ES_single == 2,1,0)) label_gt_ES_wall.append(np.where(label_gt_ES_single == 2,1,0))
label_gt_ES_RV.append(np.where(label_gt_ES_single == 3,1,0)) label_gt_ES_RV.append(np.where(label_gt_ES_single == 3,1,0))
label_pred = [] label_pred = []
label_pred_ES_LV = [] label_pred_ES_LV = []
label_pred_ES_wall = [] label_pred_ES_wall = []
label_pred_ES_RV = [] label_pred_ES_RV = []
for i in range(50): for i in range(50):
label_pred_ES_single = np.array(dir_pred_ES[i], dtype=np.float32) label_pred_ES_single = np.array(dir_pred_ES[i], dtype=np.float32)
label_pred.append(label_pred_ES_single) label_pred.append(label_pred_ES_single)
label_pred_ES_LV.append(np.where(label_pred_ES_single == 1,1,0)) label_pred_ES_LV.append(np.where(label_pred_ES_single == 1,1,0))
label_pred_ES_wall.append(np.where(label_pred_ES_single == 2,1,0)) label_pred_ES_wall.append(np.where(label_pred_ES_single == 2,1,0))
label_pred_ES_RV.append(np.where(label_pred_ES_single == 3,1,0)) label_pred_ES_RV.append(np.where(label_pred_ES_single == 3,1,0))
``` ```
%% Cell type:code id:385e0f03 tags: %% Cell type:code id:385e0f03 tags:
``` python ``` python
print(len(label_pred_ES_RV)) print(len(label_pred_ES_RV))
``` ```
%% Cell type:code id:9a682e84 tags: %% Cell type:code id:9a682e84 tags:
``` python ``` python
hd_LV_ED = [] hd_LV_ED = []
hd_wall_ED = [] hd_wall_ED = []
hd_RV_ED = [] hd_RV_ED = []
hd_LV_ES = [] hd_LV_ES = []
hd_wall_ES = [] hd_wall_ES = []
hd_RV_ES = [] hd_RV_ES = []
for i in range(50): for i in range(50):
# Calculate Hausdorff distance for ED # Calculate Hausdorff distance for ED
hd_lv_ed = max([hausdorff_distance(label_pred_ED_LV[i][:,:,y], label_gt_ED_LV[i][:,:,y], distance='euclidean') for y in range(label_pred_ED_LV[i].shape[2])]) hd_lv_ed = max([hausdorff_distance(label_pred_ED_LV[i][:,:,y], label_gt_ED_LV[i][:,:,y], distance='euclidean') for y in range(label_pred_ED_LV[i].shape[2])])
hd_wall_ed = max([hausdorff_distance(label_pred_ED_wall[i][:,:,y], label_gt_ED_wall[i][:,:,y], distance='euclidean') for y in range(label_pred_ED_wall[i].shape[2])]) hd_wall_ed = max([hausdorff_distance(label_pred_ED_wall[i][:,:,y], label_gt_ED_wall[i][:,:,y], distance='euclidean') for y in range(label_pred_ED_wall[i].shape[2])])
hd_rv_ed = max([hausdorff_distance(label_pred_ED_RV[i][:,:,y], label_gt_ED_RV[i][:,:,y], distance='euclidean') for y in range(label_pred_ED_RV[i].shape[2])]) hd_rv_ed = max([hausdorff_distance(label_pred_ED_RV[i][:,:,y], label_gt_ED_RV[i][:,:,y], distance='euclidean') for y in range(label_pred_ED_RV[i].shape[2])])
hd_LV_ED.append(hd_lv_ed) hd_LV_ED.append(hd_lv_ed)
hd_wall_ED.append(hd_wall_ed) hd_wall_ED.append(hd_wall_ed)
hd_RV_ED.append(hd_rv_ed) hd_RV_ED.append(hd_rv_ed)
# Calculate Hausdorff distance for ES # Calculate Hausdorff distance for ES
hd_lv_es = max([hausdorff_distance(label_pred_ES_LV[i][:,:,y], label_gt_ES_LV[i][:,:,y], distance='euclidean') for y in range(label_pred_ES_LV[i].shape[2])]) hd_lv_es = max([hausdorff_distance(label_pred_ES_LV[i][:,:,y], label_gt_ES_LV[i][:,:,y], distance='euclidean') for y in range(label_pred_ES_LV[i].shape[2])])
hd_wall_es = max([hausdorff_distance(label_pred_ES_wall[i][:,:,y], label_gt_ES_wall[i][:,:,y], distance='euclidean') for y in range(label_pred_ES_wall[i].shape[2])]) hd_wall_es = max([hausdorff_distance(label_pred_ES_wall[i][:,:,y], label_gt_ES_wall[i][:,:,y], distance='euclidean') for y in range(label_pred_ES_wall[i].shape[2])])
hd_rv_es = max([hausdorff_distance(label_pred_ES_RV[i][:,:,y], label_gt_ES_RV[i][:,:,y], distance='euclidean') for y in range(label_pred_ES_RV[i].shape[2])]) hd_rv_es = max([hausdorff_distance(label_pred_ES_RV[i][:,:,y], label_gt_ES_RV[i][:,:,y], distance='euclidean') for y in range(label_pred_ES_RV[i].shape[2])])
hd_LV_ES.append(hd_lv_es) hd_LV_ES.append(hd_lv_es)
hd_wall_ES.append(hd_wall_es) hd_wall_ES.append(hd_wall_es)
hd_RV_ES.append(hd_rv_es) hd_RV_ES.append(hd_rv_es)
``` ```
%% Cell type:code id:5c4c95a1 tags: %% Cell type:code id:5c4c95a1 tags:
``` python ``` python
print(len(hd_LV_ED)) print(len(hd_LV_ED))
print(hd_wall_ES[0]) print(hd_wall_ES[0])
``` ```
%% Cell type:code id:2b9b673e tags: %% Cell type:code id:2b9b673e tags:
``` python ``` python
# Calculations of average HD # Calculations of average HD
# ED # ED
HD_LV_ED_mean = statistics.mean(hd_LV_ED) HD_LV_ED_mean = statistics.mean(hd_LV_ED)
HD_RV_ED_mean = statistics.mean(hd_RV_ED) HD_RV_ED_mean = statistics.mean(hd_RV_ED)
HD_MYO_ED_mean = statistics.mean(hd_wall_ED) HD_MYO_ED_mean = statistics.mean(hd_wall_ED)
print('The average hd of the left ventricle EDV is ', HD_LV_ED_mean) print('The average hd of the left ventricle EDV is ', HD_LV_ED_mean)
print('The average hd of the right ventricle EDV is ', HD_RV_ED_mean) print('The average hd of the right ventricle EDV is ', HD_RV_ED_mean)
print('The average hd of the myocardium EDV is ', HD_MYO_ED_mean) print('The average hd of the myocardium EDV is ', HD_MYO_ED_mean)
HD_LV_ES_mean = statistics.mean(hd_LV_ES) HD_LV_ES_mean = statistics.mean(hd_LV_ES)
HD_RV_ES_mean = statistics.mean(hd_RV_ES) HD_RV_ES_mean = statistics.mean(hd_RV_ES)
HD_MYO_ES_mean = statistics.mean(hd_wall_ES) HD_MYO_ES_mean = statistics.mean(hd_wall_ES)
print('The average hd of the left ventricle ESV is ', HD_LV_ES_mean) print('The average hd of the left ventricle ESV is ', HD_LV_ES_mean)
print('The average hd of the right ventricle ESV is ', HD_RV_ES_mean) print('The average hd of the right ventricle ESV is ', HD_RV_ES_mean)
print('The average hd of the myocardium ESV is ', HD_MYO_ES_mean) print('The average hd of the myocardium ESV is ', HD_MYO_ES_mean)
``` ```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment