Skip to content
Snippets Groups Projects
Commit 45bcc21d authored by Doorn, Nina (UT-TNW)'s avatar Doorn, Nina (UT-TNW)
Browse files

Script to perform PPC and calculate PRE of different trained NDEs

parent af50ab80
Branches
No related tags found
No related merge requests found
# To perform and visualeze posterior predictive checks on a trained density estimator
import torch
from brian2 import *
from sbi import analysis, utils
import numpy as np
from FeatureExtraction import compute_features, compute_spikerate
from Simulator import MEAnetsimulate
from MakeFigures import rasterplot
import pickle
num_feats = 15 # Number of MEA features
num_params = 10 # Number of free parameters of the model
prior_min = [1.5, 0.5, 0.1, 0.5, 0.05, 0., 0.1, 150, 0.005, 0.]
prior_max = [7, 2, 10, 10, 1, 1, 0.6, 1200, 0.3, 0.005]
prior = utils.BoxUniform(low=torch.tensor(prior_min), high=torch.tensor(prior_max))
prior_limits = [[1.5, 7], [0.5, 2], [0.1, 10], [0.5, 10], [0.05, 1], [0., 1], [0.1, 0.6],
[150, 1200], [0.005, 0.3], [0., 0.005]]
par_labels = ['noise', '$g_{Na}$', '$g_{K}$', '$g_{AHP}$', '$g_{AMPA}$', '$g_{NMDA}$', 'Conn%',
r'$\tau_{D}$', 'U (STD)', 'U asyn']
feat_labels = ['MFR', 'NBR', 'NBD', 'PSIB', '#FBs', 'CVIBI', 'mean CC', 'sd CC', 'mean ISI CC', 'sd ISI CC', 'ISI dist',
'mean ISI', 'sd ISI temp', 'sd isi elec', 'MAC']
# load your trained density estimator
with open('TrainedNDE', 'rb') as f: # with open('Posterior_Features', 'rb') as f:
posterior = pickle.load(f)
embedding_net = False # true means NDE was trained with embedding net on spike rates per electrode
# false mean NDE was trained on 15 MEA features
# Visualize the performance on one set of ground-truth parameters
# set of ground-truth parameters (define yourself):
test_params = torch.as_tensor([4, 1, 0.4, 5, 0.2, 0.4, 0.15, 200, 0.1, 0.0001])
# or alternatively, draw a random sample from prior:
# test_params = prior.sample((1,))
# perform a simulations with the parameters
APs, simtime, transient, fs = MEAnetsimulate(test_params)
rasterplot(APs, 'PPC1_Feat_pre', 1/fs, transient, simtime, 'black')
if embedding_net:
numelectrodes = 12
time_bin = 100e-3
spikerate = compute_spikerate(APs, simtime/second, transient/second, fs, time_bin)
spikeratet = torch.as_tensor(spikerate)
observation = spikeratet.reshape(1, -1)
else:
observation = torch.as_tensor(compute_features(APs, simtime, transient, fs))
posterior.set_default_x(observation)
est_params = posterior.map()
samples = posterior.sample((1000,))
_ = analysis.pairplot(samples,
diag='kde',
ticks = prior_limits,
upper='kde', points = [est_params, test_params], points_colors=['#EF6F6C', '#6B0504'],
points_offdiag={'markersize': 8},
limits=prior_limits,
figsize=(6,6), labels=par_labels)
plt.show()
# Run a simulation with the MAP of the posterior to see if it matches original
APsres, simtime, transient, fs = MEAnetsimulate(est_params)
rasterplot(APsres, 'PPC1_Feat_post', 1/fs, transient, simtime, 'black')
if not embedding_net:
model_prediction = torch.as_tensor(compute_features(APsres, simtime, transient, fs))
# calculate the PRE
def normalize_parameters(params, prior_min, prior_max):
if not (params.shape[-1] == prior_min.shape[0] == prior_max.shape[0]):
raise ValueError("Mismatch between number of parameters and prior range dimensions.")
return (params - prior_min) / (prior_max - prior_min)
def compute_pre(posterior_samples, ground_truth):
"""
Computes the Posterior Recovery Error (PRE) for given posterior samples and ground truth parameters.
"""
if posterior_samples.shape[1] != ground_truth.shape[0]:
raise ValueError("Mismatch between number of parameters in samples and ground truth.")
pre = torch.mean((posterior_samples - ground_truth) ** 2, dim=0)
pre = pre.numpy()
return pre
norm_samps = normalize_parameters(samples, np.array(prior_min), np.array(prior_max))
norm_GT = normalize_parameters(test_params, np.array(prior_min), np.array(prior_max))
pre = compute_pre(norm_samps, norm_GT)
for s,n in zip(par_labels, pre):
print(f"PRE {s}: {n}")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment