diff --git a/PosteriorPredictiveChecks.py b/PosteriorPredictiveChecks.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8835dde879f4d73e0eea4dddf486b88de0ebbdb
--- /dev/null
+++ b/PosteriorPredictiveChecks.py
@@ -0,0 +1,111 @@
+# To perform and visualize posterior predictive checks on a trained density estimator
+import torch
+import matplotlib.pyplot as plt
+from brian2 import *
+from sbi import analysis, utils
+import numpy as np
+from FeatureExtraction import compute_features, compute_spikerate
+from Simulator import MEAnetsimulate
+from MakeFigures import rasterplot
+import pickle
+
+num_feats = 15   # Number of MEA features
+num_params = 10  # Number of free parameters of the model
+prior_min = [1.5, 0.5, 0.1, 0.5, 0.05, 0., 0.1, 150, 0.005, 0.]
+prior_max = [7, 2, 10, 10, 1, 1, 0.6, 1200, 0.3, 0.005]
+prior = utils.BoxUniform(low=torch.tensor(prior_min), high=torch.tensor(prior_max))
+prior_limits = [[1.5, 7], [0.5, 2], [0.1, 10], [0.5, 10], [0.05, 1], [0., 1], [0.1, 0.6],
+                [150, 1200], [0.005, 0.3], [0., 0.005]]
+par_labels = ['noise', '$g_{Na}$', '$g_{K}$', '$g_{AHP}$', '$g_{AMPA}$', '$g_{NMDA}$', 'Conn%',
+              r'$\tau_{D}$', 'U (STD)', 'U asyn']
+feat_labels = ['MFR', 'NBR', 'NBD', 'PSIB', '#FBs', 'CVIBI', 'mean CC', 'sd CC', 'mean ISI CC', 'sd ISI CC', 'ISI dist',
+               'mean ISI', 'sd ISI temp', 'sd isi elec', 'MAC']
+
+# Load your trained density estimator
+with open('TrainedNDE', 'rb') as f:  # with open('Posterior_Features', 'rb') as f:
+    posterior = pickle.load(f)
+embedding_net = False  # True means the NDE was trained with an embedding net on spike rates per electrode,
+                       # False means the NDE was trained on the 15 MEA features
+
+# Visualize the performance on one set of ground-truth parameters
+# Set of ground-truth parameters (define yourself):
+test_params = torch.as_tensor([4, 1, 0.4, 5, 0.2, 0.4, 0.15, 200, 0.1, 0.0001])
+# or alternatively, draw a random sample from the prior:
+# test_params = prior.sample((1,))
+
+# Perform a simulation with the ground-truth parameters
+APs, simtime, transient, fs = MEAnetsimulate(test_params)
+
+rasterplot(APs, 'PPC1_Feat_pre', 1/fs, transient, simtime, 'black')
+
+if embedding_net:
+    numelectrodes = 12
+    time_bin = 100e-3
+    spikerate = compute_spikerate(APs, simtime/second, transient/second, fs, time_bin)
+    spikeratet = torch.as_tensor(spikerate)
+    observation = spikeratet.reshape(1, -1)
+else:
+    observation = torch.as_tensor(compute_features(APs, simtime, transient, fs))
+
+posterior.set_default_x(observation)
+est_params = posterior.map()
+samples = posterior.sample((1000,))
+_ = analysis.pairplot(samples,
+                      diag='kde',
+                      ticks=prior_limits,
+                      upper='kde', points=[est_params, test_params], points_colors=['#EF6F6C', '#6B0504'],
+                      points_offdiag={'markersize': 8},
+                      limits=prior_limits,
+                      figsize=(6, 6), labels=par_labels)
+plt.show()
+
+# Run a simulation with the MAP of the posterior to see if it matches the original
+APsres, simtime, transient, fs = MEAnetsimulate(est_params)
+rasterplot(APsres, 'PPC1_Feat_post', 1/fs, transient, simtime, 'black')
+if not embedding_net:
+    model_prediction = torch.as_tensor(compute_features(APsres, simtime, transient, fs))
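+
+# Optional check: when the NDE was trained on the 15 MEA features, compare the summary
+# features of the MAP simulation with the observed features (this sketch assumes that
+# compute_features returns the features in the same order as feat_labels)
+if not embedding_net:
+    obs_feats = observation.flatten().numpy()
+    map_feats = model_prediction.flatten().numpy()
+    for lab, o, m in zip(feat_labels, obs_feats, map_feats):
+        print(f"{lab}: observed {o:.3f}, MAP simulation {m:.3f}, difference {m - o:.3f}")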
+
+# Calculate the posterior recovery error (PRE): the mean squared error per parameter
+# between the posterior samples and the ground truth, after normalizing both to the prior range
+def normalize_parameters(params, prior_min, prior_max):
+    if not (params.shape[-1] == prior_min.shape[0] == prior_max.shape[0]):
+        raise ValueError("Mismatch between number of parameters and prior range dimensions.")
+
+    return (params - prior_min) / (prior_max - prior_min)
+
+
+def compute_pre(posterior_samples, ground_truth):
+    """
+    Computes the Posterior Recovery Error (PRE) for the given posterior samples
+    and ground-truth parameters.
+    """
+    if posterior_samples.shape[1] != ground_truth.shape[0]:
+        raise ValueError("Mismatch between number of parameters in samples and ground truth.")
+
+    pre = torch.mean((posterior_samples - ground_truth) ** 2, dim=0)
+    pre = pre.numpy()
+    return pre
+
+
+norm_samps = normalize_parameters(samples, np.array(prior_min), np.array(prior_max))
+norm_GT = normalize_parameters(test_params, np.array(prior_min), np.array(prior_max))
+pre = compute_pre(norm_samps, norm_GT)
+for s, n in zip(par_labels, pre):
+    print(f"PRE {s}: {n}")
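+
+# Optionally, report how far the MAP point estimate itself lies from the ground truth in
+# the normalized parameter space used for the PRE (a simple point-estimate check, in
+# addition to the sample-based PRE above)
+norm_map = normalize_parameters(est_params, np.array(prior_min), np.array(prior_max))
+map_err = torch.abs(norm_map - norm_GT).detach().flatten()
+for s, e in zip(par_labels, map_err):
+    print(f"Normalized MAP error {s}: {float(e):.3f}")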