Skip to content
Snippets Groups Projects
Commit e86db1f6 authored by Bart Leenheer's avatar Bart Leenheer
Browse files

Added the ED and ES_prediction_example.png

parent 5a14eb63
Branches master
No related tags found
No related merge requests found
ED_prediction_example.png

114 KiB

ES_prediction_example.png

113 KiB

%% Cell type:code id:df7898a3 tags:
``` python
%pip install hausdorff
%pip install numba
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import numpy as np
import nibabel as nib
from ACDCUNet import build_dict_images, build_dict_images_pred
import statistics
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from hausdorff import hausdorff_distance
```
%% Cell type:code id:d744133d tags:
``` python
"""
author: Clément Zotti (clement.zotti@usherbrooke.ca)
date: April 2017
DESCRIPTION :
The script provide helpers functions to handle nifti image format:
- load_nii()
- save_nii()
to generate metrics for two images:
- metrics()
And it is callable from the command line (see below).
Each function provided in this script has comments to understand
how they works.
HOW-TO:
This script was tested for python 3.4.
First, you need to install the required packages with
pip install -r requirements.txt
After the installation, you have two ways of running this script:
1) python metrics.py ground_truth/patient001_ED.nii.gz prediction/patient001_ED.nii.gz
2) python metrics.py ground_truth/ prediction/
The first option will print in the console the dice and volume of each class for the given image.
The second option wiil ouput a csv file where each images will have the dice and volume of each class.
Link: http://acdc.creatis.insa-lyon.fr
"""
HEADER = [
"Name",
"Dice LV",
"Volume LV pred",
"Volume LV GT",
"Err LV(ml)",
"Dice RV",
"Volume RV pred",
"Volume RV GT",
"Err RV(ml)",
"Dice MYO",
"Volume MYO pred",
"Volume MYO GT",
"Err MYO(ml)",
]
#
# Utils functions used to sort strings into a natural order
#
def conv_int(i):
return int(i) if i.isdigit() else i
def natural_order(sord):
"""
Sort a (list,tuple) of strings into natural order.
Ex:
['1','10','2'] -> ['1','2','10']
['abc1def','ab10d','b2c','ab1d'] -> ['ab1d','ab10d', 'abc1def', 'b2c']
"""
if isinstance(sord, tuple):
sord = sord[0]
return [conv_int(c) for c in re.split(r"(\d+)", sord)]
#
# Utils function to load and save nifti files with the nibabel package
#
img_path = "ACDC\database"
def load_nii(img_path):
"""
Function to load a 'nii' or 'nii.gz' file, The function returns
everyting needed to save another 'nii' or 'nii.gz'
in the same dimensional space, i.e. the affine matrix and the header
Parameters
----------
img_path: string
String with the path of the 'nii' or 'nii.gz' image file name.
Returns
-------
Three element, the first is a numpy array of the image values,
the second is the affine transformation of the image, and the
last one is the header of the image.
"""
nimg = nib.load(img_path)
return nimg.get_fdata(), nimg.affine, nimg.header
def save_nii(img_path, data, affine, header):
"""
Function to save a 'nii' or 'nii.gz' file.
Parameters
----------
img_path: string
Path to save the image should be ending with '.nii' or '.nii.gz'.
data: np.array
Numpy array of the image data.
affine: list of list or np.array
The affine transformation to save with the image.
header: nib.Nifti1Header
The header that define everything about the data
(pleasecheck nibabel documentation).
"""
nimg = nib.Nifti1Image(data, affine=affine, header=header)
nimg.to_filename(img_path)
#
# Functions to process files, directories and metrics
#
def metrics(img_gt, img_pred, voxel_size):
"""
Function to compute the metrics between two segmentation maps given as input.
Parameters
----------
img_gt: np.array
Array of the ground truth segmentation map.
img_pred: np.array
Array of the predicted segmentation map.
voxel_size: list, tuple or np.array
The size of a voxel of the images used to compute the volumes.
Return
------
A list of metrics in this order, [Dice LV, Volume LV, Volume GT, Err LV(ml),
Dice RV, Volume RV, Volume GT, Err RV(ml), Dice MYO, Volume MYO, Volume GT, Err MYO(ml)]
"""
if img_gt.ndim != img_pred.ndim:
raise ValueError(
"The arrays 'img_gt' and 'img_pred' should have the "
"same dimension, {} against {}".format(img_gt.ndim, img_pred.ndim)
)
res = []
# Loop on each classes of the input images
for c in [3, 1, 2]:
# Copy the gt image to not alterate the input
gt_c_i = np.copy(img_gt)
gt_c_i[gt_c_i != c] = 0
# Copy the pred image to not alterate the input
pred_c_i = np.copy(img_pred)
pred_c_i[pred_c_i != c] = 0
# Clip the value to compute the volumes
gt_c_i = np.clip(gt_c_i, 0, 1)
pred_c_i = np.clip(pred_c_i, 0, 1)
# Compute the Dice
# dice = dc(gt_c_i, pred_c_i)
dice = 1
# Eventueel alternatief
gt_volume = np.sum(gt_c_i)
pred_volume = np.sum(pred_c_i)
intersect = np.sum(gt_c_i * pred_c_i)
dice = (2 * intersect) / (gt_volume + pred_volume)
# Compute volume
volpred = pred_c_i.sum() * np.prod(voxel_size) / 1000.0
volgt = gt_c_i.sum() * np.prod(voxel_size) / 1000.0
res += [dice, volpred, volgt, volpred - volgt]
return res
def compute_metrics_on_files(path_gt, path_pred):
"""
Function to give the metrics for two files
Parameters
----------
path_gt: string
Path of the ground truth image.
path_pred: string
Path of the predicted image.
"""
gt, _, header = load_nii(path_gt)
pred, _, _ = load_nii(path_pred)
zooms = header.get_zooms()
name = os.path.basename(path_gt)
name = name.split(".")[0]
res = metrics(gt, pred, zooms)
res = ["{:.3f}".format(r) for r in res]
formatting = "{:<20}" + "{:>12}" * len(res)
output = formatting.format(name, *res)
print(formatting.format(*HEADER))
print(output)
# formatting = "{:>14}, {:>7}, {:>9}, {:>10}, {:>7}, {:>9}, {:>10}, {:>8}, {:>10}, {:>11}"
# print(formatting.format(*HEADER))
# print(formatting.format(name, *res))
# return [name, *res]
return res
def compute_metrics_on_directories(dir_gt, dir_pred):
"""
Function to generate a csv file for each images of two directories.
Parameters
----------
path_gt: string
Directory of the ground truth segmentation maps.
path_pred: string
Directory of the predicted segmentation maps.
"""
lst_gt = sorted(glob(os.path.join(dir_gt, "*")), key=natural_order)
lst_pred = sorted(glob(os.path.join(dir_pred, "*")), key=natural_order)
res = []
for p_gt, p_pred in zip(lst_gt, lst_pred):
if os.path.basename(p_gt) != os.path.basename(p_pred):
raise ValueError(
"The two files don't have the same name"
" {}, {}.".format(os.path.basename(p_gt), os.path.basename(p_pred))
)
gt, _, header = load_nii(p_gt)
pred, _, _ = load_nii(p_pred)
zooms = header.get_zooms()
res.append(metrics(gt, pred, zooms))
lst_name_gt = [os.path.basename(gt).split(".")[0] for gt in lst_gt]
res = [
[
n,
]
+ r
for r, n in zip(res, lst_name_gt)
]
df = pd.DataFrame(res, columns=HEADER)
df.to_csv("results_{}.csv".format(time.strftime("%Y%m%d_%H%M%S")), index=False)
def main(path_gt, path_pred):
"""
Main function to select which method to apply on the input parameters.
"""
if os.path.isfile(path_gt) and os.path.isfile(path_pred):
compute_metrics_on_files(path_gt, path_pred)
elif os.path.isdir(path_gt) and os.path.isdir(path_pred):
compute_metrics_on_directories(path_gt, path_pred)
else:
raise ValueError("The paths given needs to be two directories or two files.")
# if __name__ == "__main__":
# parser = argparse.ArgumentParser(
# description="Script to compute ACDC challenge metrics.")
# parser.add_argument("GT_IMG", type=str, help="Ground Truth image")
# parser.add_argument("PRED_IMG", type=str, help="Predicted image")
# args = parser.parse_args()
# main(args.GT_IMG, args.PRED_IMG)
```
%% Cell type:code id:142caccb tags:
``` python
# Testing
path_gt = os.path.join('ACDC','database','testing','patient101','patient101_frame01.nii.gz')
dir_gt = os.path.join('ACDC','database','testing','patient101','patient101_frame01_gt.nii.gz')
path_pred = os.path.join('ACDC','database','testing','patient101','patient101_frame01.nii.gz')
dir_pred = os.path.join('ACDC','database','testing','patient101','patient101_frame01_ml_pred.nii.gz')
test_files = compute_metrics_on_files(dir_gt, dir_pred)
```
%% Cell type:code id:0451a847 tags:
``` python
data_path = "ACDC/database"
test_dict_gt = build_dict_images(data_path, mode='testing').ravel()
test_dict_pred = build_dict_images_pred(data_path, mode='testing')
# Ground truth
# First file
print(len(test_dict_gt))
image_paths_gt_ED = [d['image'] for d in test_dict_gt if d["first_file"]]
label_paths_gt_ED = [d['label'] for d in test_dict_gt if d["first_file"]]
# Last file
image_paths_gt_ES = [d['image'] for d in test_dict_gt if not d["first_file"]]
label_paths_gt_ES = [d['label'] for d in test_dict_gt if not d["first_file"]]
print('Image path gt ED is ', image_paths_gt_ED)
print('Label path gt ED is ', label_paths_gt_ED)
print('Image path gt ES is ', image_paths_gt_ES)
print('Label path gt ES is ', label_paths_gt_ES)
# Prediction
image_paths_pred_ED = [d['image'] for d in test_dict_pred if d["first_file"]]
label_paths_pred_ED = [d['label'] for d in test_dict_pred if d["first_file"]]
image_paths_pred_ES = [d['image'] for d in test_dict_pred if not d["first_file"]]
label_paths_pred_ES = [d['label'] for d in test_dict_pred if not d["first_file"]]
print('Image path prediction ED is ', image_paths_pred_ED)
print('Label path prediction ED is ', label_paths_pred_ED)
print('Image path prediction ES is ', image_paths_pred_ES)
print('Label path prediction ES is ', label_paths_pred_ES)
```
%% Cell type:code id:6c7be742 tags:
``` python
test_files_ED = [compute_metrics_on_files(label_paths_gt_ED[i], label_paths_pred_ED[i]) for i in range(len(label_paths_gt_ED))]
test_files_ES = [compute_metrics_on_files(label_paths_gt_ES[i], label_paths_pred_ES[i]) for i in range(len(label_paths_gt_ES))]
```
%% Cell type:code id:049a0262 tags:
``` python
# ED
LV_GT_ED = np.array([float(metric[2]) for metric in test_files_ED])
LV_pred_ED = np.array([float(metric[1]) for metric in test_files_ED])
LV_dice_ED = np.array([float(metric[0]) for metric in test_files_ED])
LV_err_ED = np.array([float(metric[3]) for metric in test_files_ED])
RV_GT_ED = np.array([float(metric[6]) for metric in test_files_ED])
RV_pred_ED = np.array([float(metric[5]) for metric in test_files_ED])
RV_dice_ED = np.array([float(metric[4]) for metric in test_files_ED])
RV_err_ED = np.array([float(metric[7]) for metric in test_files_ED])
MYO_GT_ED = np.array([float(metric[10]) for metric in test_files_ED])
MYO_pred_ED = np.array([float(metric[9]) for metric in test_files_ED])
MYO_dice_ED = np.array([float(metric[8]) for metric in test_files_ED])
MYO_err_ED = np.array([float(metric[11]) for metric in test_files_ED])
# ES
LV_GT_ES = np.array([float(metric[2]) for metric in test_files_ES])
LV_pred_ES = np.array([float(metric[1]) for metric in test_files_ES])
LV_dice_ES = np.array([float(metric[0]) for metric in test_files_ES])
LV_err_ES = np.array([float(metric[3]) for metric in test_files_ES])
RV_GT_ES = np.array([float(metric[6]) for metric in test_files_ES])
RV_pred_ES = np.array([float(metric[5]) for metric in test_files_ES])
RV_dice_ES = np.array([float(metric[4]) for metric in test_files_ES])
RV_err_ES = np.array([float(metric[7]) for metric in test_files_ES])
MYO_GT_ES = np.array([float(metric[10]) for metric in test_files_ES])
MYO_pred_ES = np.array([float(metric[9]) for metric in test_files_ES])
MYO_dice_ES = np.array([float(metric[8]) for metric in test_files_ES])
MYO_err_ES = np.array([float(metric[11]) for metric in test_files_ES])
```
%% Cell type:markdown id:7c36b2e0 tags:
Other ways for visualization
> Comparison graphs
%% Cell type:code id:6fd62eb7 tags:
``` python
def show_comparison_graph(gt, pred, dice, err, color_boundaries=20, part="LV", mode="ED"):
plt.clf()
vmin = min(err)
vmax = max(err)
colors = ['red', 'yellow', 'green', 'yellow', 'red']
color_positions = [vmin, -1 * color_boundaries, 0, color_boundaries, vmax]
norm = plt.Normalize(vmin, vmax)
print(norm(color_positions))
colormap = list(zip(norm(color_positions), colors))
cmap = LinearSegmentedColormap.from_list("custom_cmap", colormap)
plt.scatter(gt, pred, c=err, cmap=cmap, vmin=vmin, vmax=vmax, s=dice * 100)
plt.plot([min(gt), max(gt)], [min(gt), max(gt)], 'k--')
plt.xlabel(f'Ground Truth {part} Volume [ml]')
plt.ylabel(f'Predicted {part} Volume [ml]')
plt.title(f'Ground Truth vs. Predicted Volume of {part} {mode}')
cbar = plt.colorbar()
cbar.set_label('Error Value')
plt.show()
```
%% Cell type:code id:1541570e tags:
``` python
show_comparison_graph(LV_GT_ED, LV_pred_ED, LV_dice_ED, LV_err_ED)
```
%% Cell type:code id:93c71590 tags:
``` python
show_comparison_graph(RV_GT_ED, RV_pred_ED, RV_dice_ED, RV_err_ED, color_boundaries=18, part="RV")
```
%% Cell type:code id:fa1aa317 tags:
``` python
show_comparison_graph(MYO_GT_ED, MYO_pred_ED, MYO_dice_ED, MYO_err_ED, color_boundaries=18, part="MYO")
```
%% Cell type:markdown id:5b91e382 tags:
> Correlations
%% Cell type:code id:a21ffc52 tags:
``` python
# Correlation coefficients
def calculate_correlation(gt, pred, err):
gt_mean = statistics.mean(gt)
gt_std = statistics.stdev(gt)
pred_mean = statistics.mean(pred)
pred_std = statistics.stdev(pred)
covariance = sum((x - gt_mean) * (y - pred_mean) for x, y in zip(gt, pred)) / len(gt)
corr = covariance / (gt_std * pred_std)
bias = statistics.mean(err)
err_std = statistics.stdev(err)
loa = 1.96*err_std
return corr, bias, loa
```
%% Cell type:code id:a4ecf214 tags:
``` python
LV_corr, LV_bias, LV_LOA = calculate_correlation(LV_GT_ED, LV_pred_ED, LV_err_ED)
print('The correlation coefficient for left ventricle EDV is ', LV_corr)
print('The bias for the left ventricle EDV is ', LV_bias)
print('The LOA of the left ventricle EDV is ', LV_LOA)
```
%% Cell type:code id:12c061d2 tags:
``` python
RV_corr, RV_bias, RV_LOA = calculate_correlation(RV_GT_ED, RV_pred_ED, RV_err_ED)
print('The correlation coefficient for right ventricle EDV is ', RV_corr)
print('The bias for the right ventricle EDV is ', RV_bias)
print('The LOA of the right ventricle EDV is ', RV_LOA)
```
%% Cell type:code id:ccb53139 tags:
``` python
MYO_corr, MYO_bias, MYO_LOA = calculate_correlation(MYO_GT_ED, MYO_pred_ED, MYO_err_ED)
print('The correlation coefficient for myocardium EDV is ', MYO_corr)
print('The bias for the left myocardium EDV is ', MYO_bias)
print('The LOA of the left myocardium EDV is ', MYO_LOA)
```
%% Cell type:code id:d85b01dd tags:
``` python
# Average dice
av_LV_dice = statistics.mean(LV_dice_ED)
av_RV_dice = statistics.mean(RV_dice_ED)
av_MYO_dice = statistics.mean(MYO_dice_ED)
print('Average dice left ventricle is ', av_LV_dice)
print('Average dice right ventricle is ', av_RV_dice)
print('Average dice myocardium is ', av_MYO_dice)
```
%% Cell type:markdown id:ce64f7e5 tags:
ES
%% Cell type:code id:437c2b32 tags:
``` python
show_comparison_graph(LV_GT_ES, LV_pred_ES, LV_dice_ES, LV_err_ES, color_boundaries=15, part="LV", mode="ES")
```
%% Cell type:code id:0ff66a64 tags:
``` python
show_comparison_graph(RV_GT_ES, RV_pred_ES, RV_dice_ES, RV_err_ES, color_boundaries=10, part="RV", mode="ES")
```
%% Cell type:code id:931d0abb tags:
``` python
show_comparison_graph(MYO_GT_ES, MYO_pred_ES, MYO_dice_ES, MYO_err_ES, color_boundaries=30, part="MYO", mode="ES")
```
%% Cell type:code id:4d30933c tags:
``` python
LV_corr, LV_bias, LV_LOA = calculate_correlation(LV_GT_ES, LV_pred_ES, LV_err_ES)
print('The correlation coefficient for left ventricle ESV is ', LV_corr)
print('The bias for the left ventricle ESV is ', LV_bias)
print('The LOA of the left ventricle ESV is ', LV_LOA)
```
%% Cell type:code id:823112fe tags:
``` python
RV_corr, RV_bias, RV_LOA = calculate_correlation(RV_GT_ES, RV_pred_ES, RV_err_ES)
print('The correlation coefficient for right ventricle ESV is ', RV_corr)
print('The bias for the right ventricle ESV is ', RV_bias)
print('The LOA of the right ventricle ESV is ', RV_LOA)
```
%% Cell type:code id:50408bf1 tags:
``` python
MYO_corr, MYO_bias, MYO_LOA = calculate_correlation(MYO_GT_ES, MYO_pred_ES, MYO_err_ES)
print('The correlation coefficient for myocardium ESV is ', MYO_corr)
print('The bias for the left myocardium ESV is ', MYO_bias)
print('The LOA of the left myocardium ESV is ', MYO_LOA)
```
%% Cell type:code id:cf0f0d91 tags:
``` python
# Average dice
av_LV_dice = statistics.mean(LV_dice_ES)
av_RV_dice = statistics.mean(RV_dice_ES)
av_MYO_dice = statistics.mean(MYO_dice_ES)
print('Average dice left ventricle ESV is ', av_LV_dice)
print('Average dice right ventricle ESV is ', av_RV_dice)
print('Average dice myocardium ESV is ', av_MYO_dice)
```
%% Cell type:markdown id:ecf8346f tags:
EJ
%% Cell type:code id:c1671753 tags:
``` python
EJ_LV_GT = [((LV_GT_ED[i] - LV_GT_ES[i])/ LV_GT_ED[i]) * 100 for i in range(len(LV_GT_ED))]
EJ_RV_GT = [((RV_GT_ED[i] - RV_GT_ES[i])/ RV_GT_ED[i]) * 100 for i in range(len(RV_GT_ED))]
EJ_LV_pred = [((LV_pred_ED[i] - LV_pred_ES[i])/ (LV_pred_ED[i]+0.1)) * 100 for i in range(len(LV_pred_ED))]
EJ_RV_pred = [((RV_pred_ED[i] - RV_pred_ES[i])/ (RV_pred_ED[i]+0.1)) * 100 for i in range(len(RV_pred_ED))]
EJ_LV_err = [EJ_LV_GT[i] - EJ_LV_pred[i] for i in range(len(EJ_LV_GT))]
EJ_RV_err = [EJ_RV_GT[i] - EJ_RV_pred[i] for i in range(len(EJ_RV_GT))]
EJ_LV_corr, EJ_LV_bias, LV_LOA = calculate_correlation(EJ_LV_GT, EJ_LV_pred, EJ_LV_err)
print('The correlation of the ejection fraction for left ventricle is ', EJ_LV_corr)
print('The mean bias of the ejection fraction of left ventricle is ', EJ_LV_bias)
print('The LOA of the ejection fraction of the left ventricle is ', LV_LOA)
EJ_RV_corr, EJ_RV_bias, RV_LOA = calculate_correlation(EJ_RV_GT, EJ_RV_pred, EJ_RV_err)
print('The correlation of the ejection fraction for right ventricle is ', EJ_RV_corr)
print('The mean bias of the ejection fraction of right ventricle is ', EJ_RV_bias)
print('The LOA of the ejection fraction of the right ventricle is ', RV_LOA)
```
%% Cell type:markdown id:de9c181c tags:
Best and worst results
%% Cell type:code id:c9a463de tags:
``` python
# ES
max_value = max(LV_dice_ES)
min_value = sorted(LV_dice_ES)[1]
# This determines the best and worst performing predictions based on the dice score of all the segments and both time frames
# Uses the LV_dice_ES list to find back the correct patient, but could be any of the lists
score_sum = {LV_dice_ES[x]: sum([LV_dice_ES[x], RV_dice_ES[x], MYO_dice_ES[x], LV_dice_ED[x], RV_dice_ED[x], MYO_dice_ED[x]]) for x in range(len(LV_dice_ES))}
max_value = max(score_sum, key=score_sum.get)
min_value = min(score_sum, key=score_sum.get)
print(min_value)
del score_sum[min_value]
min_value = min(score_sum, key=score_sum.get)
print(min_value)
max_index = np.where(LV_dice_ES == max_value)[0][0]
min_index = np.where(LV_dice_ES == min_value)[0][0]
print("Max value:", max_value)
print("Max index:", max_index)
print("Min value:", min_value)
print("Min index:", min_index)
# Show slices
```
%% Cell type:code id:0ef0ed36 tags:
``` python
print('Best image path gt ES is ', image_paths_gt_ES[max_index])
print('Best label path gt ES is ', label_paths_gt_ES[max_index])
print('Best image path prediction ES is ', image_paths_pred_ES[max_index])
print('Best label path prediction ES is ', label_paths_pred_ES[max_index])
```
%% Cell type:code id:35b72c7c tags:
``` python
def get_slice(index, image_gt, label_gt, image_pred, label_pred, channel_index=0, frame='ES'):
image_path_gt = image_gt[index]
label_path_gt = label_gt[index]
image_path_pred = image_pred[index]
label_path_pred = label_pred[index]
# Load image and label data
image_gt = nib.load(image_path_gt).get_fdata()
label_gt_ES = nib.load(label_path_gt).get_fdata()
image_pred = nib.load(image_path_pred).get_fdata()
label_pred = nib.load(label_path_pred).get_fdata()
return (
[
image_gt[..., channel_index],
label_gt_ES[..., channel_index],
image_pred[..., channel_index],
label_pred[..., channel_index]
]
)
def display_slices(min_index, max_index, image_gt, label_gt, image_pred, label_pred, channel_index=0, frame='ES'):
best_image_gt_ES_channel, best_label_gt_ES_channel, best_image_pred_ES_channel, best_label_pred_ES_channel = get_slice(max_index, image_gt, label_gt, image_pred, label_pred, channel_index)
worst_image_gt_ES_channel, worst_label_gt_ES_channel, worst_image_pred_ES_channel, worst_label_pred_ES_channel = get_slice(min_index, image_gt, label_gt, image_pred, label_pred, channel_index)
plt.subplot(2, 2, 1)
plt.imshow(best_image_gt_ES_channel, cmap='gray')
plt.title(f"Ground truth best prediction {frame}")
plt.imshow(best_label_gt_ES_channel, alpha=0.5, cmap='Reds')
plt.subplot(2, 2, 2)
plt.imshow(best_image_pred_ES_channel, cmap='gray')
plt.title(f"Best prediction {frame}")
plt.imshow(best_label_pred_ES_channel, cmap='Reds', alpha=0.5)
plt.subplot(2, 2, 3)
plt.imshow(worst_image_gt_ES_channel, cmap='gray')
plt.title(f"Ground truth worst prediction {frame}")
plt.imshow(worst_label_gt_ES_channel, alpha=0.5, cmap='Reds')
plt.subplot(2, 2, 4)
plt.imshow(worst_image_pred_ES_channel, cmap='gray')
plt.title(f"Worst prediction {frame}")
plt.imshow(worst_label_pred_ES_channel, cmap='Reds', alpha=0.5)
plt.tight_layout()
plt.show()
```
%% Cell type:code id:8074a3d4 tags:
``` python
display_slices(min_index, max_index, image_paths_gt_ES, label_paths_gt_ES, image_paths_pred_ES, label_paths_pred_ES, channel_index=0, frame='ES')
display_slices(min_index, max_index, image_paths_gt_ES, label_paths_gt_ES, image_paths_pred_ES, label_paths_pred_ES, channel_index=4, frame='ES')
compute_metrics_on_files(label_paths_gt_ED[max_index], label_paths_pred_ED[max_index])
compute_metrics_on_files(label_paths_gt_ES[max_index], label_paths_pred_ES[max_index])
print()
```
%% Cell type:code id:e2d5604a tags:
``` python
display_slices(min_index, max_index, image_paths_gt_ED, label_paths_gt_ED, image_paths_pred_ED, label_paths_pred_ED, channel_index=1, frame='ED')
display_slices(min_index, max_index, image_paths_gt_ED, label_paths_gt_ED, image_paths_pred_ED, label_paths_pred_ED, channel_index=4, frame='ED')
compute_metrics_on_files(label_paths_gt_ED[min_index], label_paths_pred_ED[min_index])
compute_metrics_on_files(label_paths_gt_ES[min_index], label_paths_pred_ES[min_index])
print()
```
%% Cell type:markdown id:c4ee1d4a tags:
Hausdorff distances
%% Cell type:code id:a633891f tags:
``` python
# First file
label_paths_gt_ED = [d['label'] for d in test_dict_gt if d["first_file"]]
print(len(label_paths_gt_ED))
# Last file
label_paths_gt_ES = [d['label'] for d in test_dict_gt if not d["first_file"]]
# Prediction
label_paths_pred_ED = [d['label'] for d in test_dict_pred if d["first_file"]]
label_paths_pred_ES = [d['label'] for d in test_dict_pred if not d["first_file"]]
```
%% Cell type:code id:f3a8a4e0 tags:
``` python
# GT and pred of ED
dir_gt_ED = []
dir_pred_ED = []
for label_path_gt_ED, label_path_pred_ED in zip(label_paths_gt_ED, label_paths_pred_ED):
label_GT_ED = nib.load(label_path_gt_ED).get_fdata()
label_pred_ED = nib.load(label_path_pred_ED).get_fdata()
dir_gt_ED.append(label_GT_ED)
dir_pred_ED.append(label_pred_ED)
# GT and pred of ES
dir_gt_ES = []
dir_pred_ES = []
for label_path_gt, label_path_pred in zip(label_paths_gt_ES, label_paths_pred_ES):
label_GT_ES = nib.load(label_path_gt).get_fdata()
label_pred = nib.load(label_path_pred).get_fdata()
dir_gt_ES.append(label_GT_ES)
dir_pred_ES.append(label_pred)
```
%% Cell type:code id:106f8cb8 tags:
``` python
label_gt_ED = []
label_gt_ED_LV = []
label_gt_ED_wall = []
label_gt_ED_RV = []
for i in range(50):
label_gt_ED_single = np.array(dir_gt_ED[i], dtype=np.float32)
label_gt_ED.append(label_gt_ED_single)
label_gt_ED_LV.append(np.where(label_gt_ED_single == 1,1,0))
label_gt_ED_wall.append(np.where(label_gt_ED_single == 2,1,0))
label_gt_ED_RV.append(np.where(label_gt_ED_single == 3,1,0))
label_pred_ED = []
label_pred_ED_LV = []
label_pred_ED_wall = []
label_pred_ED_RV = []
for i in range(50):
label_pred_ED_single = np.array(dir_pred_ED[i], dtype=np.float32)
label_pred_ED.append(label_pred_ED_single)
label_pred_ED_LV.append(np.where(label_pred_ED_single == 1,1,0))
label_pred_ED_wall.append(np.where(label_pred_ED_single == 2,1,0))
label_pred_ED_RV.append(np.where(label_pred_ED_single == 3,1,0))
label_gt_ES = []
label_gt_ES_LV = []
label_gt_ES_wall = []
label_gt_ES_RV = []
for i in range(50):
label_gt_ES_single = np.array(dir_gt_ES[i], dtype=np.float32)
label_gt_ES.append(label_gt_ES_single)
label_gt_ES_LV.append(np.where(label_gt_ES_single == 1,1,0))
label_gt_ES_wall.append(np.where(label_gt_ES_single == 2,1,0))
label_gt_ES_RV.append(np.where(label_gt_ES_single == 3,1,0))
label_pred = []
label_pred_ES_LV = []
label_pred_ES_wall = []
label_pred_ES_RV = []
for i in range(50):
label_pred_ES_single = np.array(dir_pred_ES[i], dtype=np.float32)
label_pred.append(label_pred_ES_single)
label_pred_ES_LV.append(np.where(label_pred_ES_single == 1,1,0))
label_pred_ES_wall.append(np.where(label_pred_ES_single == 2,1,0))
label_pred_ES_RV.append(np.where(label_pred_ES_single == 3,1,0))
```
%% Cell type:code id:385e0f03 tags:
``` python
print(len(label_pred_ES_RV))
```
%% Cell type:code id:9a682e84 tags:
``` python
hd_LV_ED = []
hd_wall_ED = []
hd_RV_ED = []
hd_LV_ES = []
hd_wall_ES = []
hd_RV_ES = []
for i in range(50):
# Calculate Hausdorff distance for ED
hd_lv_ed = max([hausdorff_distance(label_pred_ED_LV[i][:,:,y], label_gt_ED_LV[i][:,:,y], distance='euclidean') for y in range(label_pred_ED_LV[i].shape[2])])
hd_wall_ed = max([hausdorff_distance(label_pred_ED_wall[i][:,:,y], label_gt_ED_wall[i][:,:,y], distance='euclidean') for y in range(label_pred_ED_wall[i].shape[2])])
hd_rv_ed = max([hausdorff_distance(label_pred_ED_RV[i][:,:,y], label_gt_ED_RV[i][:,:,y], distance='euclidean') for y in range(label_pred_ED_RV[i].shape[2])])
hd_LV_ED.append(hd_lv_ed)
hd_wall_ED.append(hd_wall_ed)
hd_RV_ED.append(hd_rv_ed)
# Calculate Hausdorff distance for ES
hd_lv_es = max([hausdorff_distance(label_pred_ES_LV[i][:,:,y], label_gt_ES_LV[i][:,:,y], distance='euclidean') for y in range(label_pred_ES_LV[i].shape[2])])
hd_wall_es = max([hausdorff_distance(label_pred_ES_wall[i][:,:,y], label_gt_ES_wall[i][:,:,y], distance='euclidean') for y in range(label_pred_ES_wall[i].shape[2])])
hd_rv_es = max([hausdorff_distance(label_pred_ES_RV[i][:,:,y], label_gt_ES_RV[i][:,:,y], distance='euclidean') for y in range(label_pred_ES_RV[i].shape[2])])
hd_LV_ES.append(hd_lv_es)
hd_wall_ES.append(hd_wall_es)
hd_RV_ES.append(hd_rv_es)
```
%% Cell type:code id:5c4c95a1 tags:
``` python
print(len(hd_LV_ED))
print(hd_wall_ES[0])
```
%% Cell type:code id:2b9b673e tags:
``` python
# Calculations of average HD
# ED
HD_LV_ED_mean = statistics.mean(hd_LV_ED)
HD_RV_ED_mean = statistics.mean(hd_RV_ED)
HD_MYO_ED_mean = statistics.mean(hd_wall_ED)
print('The average hd of the left ventricle EDV is ', HD_LV_ED_mean)
print('The average hd of the right ventricle EDV is ', HD_RV_ED_mean)
print('The average hd of the myocardium EDV is ', HD_MYO_ED_mean)
HD_LV_ES_mean = statistics.mean(hd_LV_ES)
HD_RV_ES_mean = statistics.mean(hd_RV_ES)
HD_MYO_ES_mean = statistics.mean(hd_wall_ES)
print('The average hd of the left ventricle ESV is ', HD_LV_ES_mean)
print('The average hd of the right ventricle ESV is ', HD_RV_ES_mean)
print('The average hd of the myocardium ESV is ', HD_MYO_ES_mean)
```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment