From db13e96dc33cd885a051076a88e9e8af879ebfe7 Mon Sep 17 00:00:00 2001 From: "Doorn, Nina (UT-TNW)" <n.doorn-1@utwente.nl> Date: Wed, 15 Jan 2025 11:31:37 +0100 Subject: [PATCH] Included all example observations --- FindPosteriors.py | 119 ++++++++++++++++++++++++---------------------- 1 file changed, 62 insertions(+), 57 deletions(-) diff --git a/FindPosteriors.py b/FindPosteriors.py index 7e45f44..29bc2bb 100644 --- a/FindPosteriors.py +++ b/FindPosteriors.py @@ -4,34 +4,38 @@ from brian2 import * from sbi import utils as utils from sbi import analysis as analysis import matplotlib.pyplot as plt -from Simulator import MEAnetSimulate, ComputeFeatures -from MakeFigures import rasterplot, Marginaldiffplot +from Simulator import MEAnetsimulate, compute_features +from MakeFigures import rasterplot, marginaldiffplot from scipy.stats import ks_2samp -numstats = 15 # Number of summary statistics -numparams = 10 # Number of free parameters of the model +example_dir = '../example_observations/' # directory with the example observations + +num_stats = 15 # Number of summary statistics +num_params = 10 # Number of free parameters of the model prior_min = [1.5, 0.5, 0.1, 0.5, 0.05, 0., 0.1, 150, 0.005, 0.] prior_max = [7, 2, 10, 10, 1, 1, 0.6, 1200, 0.3, 0.005] prior = utils.BoxUniform(low=torch.tensor(prior_min), high=torch.tensor(prior_max)) -priorlimits = [[1.5, 7], [0.5, 2], [0.1, 10], [0.5, 10], [0.05, 1], [0., 1], [0.1, 0.6], [150, 1200], [0.005, 0.3], [0., 0.005]] -numparams = 10 -parlabels = ['noise', '$g_{Na}$', '$g_{K}$', '$g_{AHP}$', '$g_{AMPA}$', '$g_{NMDA}$', 'Conn%', r'$\tau_{D}$', 'U (STD)', 'U asyn'] -SSlabels = ['MFR', 'NBR', 'NBD', 'PSIB', '#FBs', 'CVIBI', 'mean CC', 'sd CC', 'mean ISI CC', 'sd ISI CC', 'ISI dist', 'mean ISI', 'sd ISI temp', 'sd isi elec', 'MAC'] - -## LOAD YOUR OWN EXPERIMENTAL DATA TO OBTAIN POSTERIOR +prior_limits = [[1.5, 7], [0.5, 2], [0.1, 10], [0.5, 10], [0.05, 1], [0., 1], [0.1, 0.6], [150, 1200], [0.005, 0.3], + [0., 0.005]] +par_labels = ['noise', '$g_{Na}$', '$g_{K}$', '$g_{AHP}$', '$g_{AMPA}$', '$g_{NMDA}$', 'Conn%', r'$\tau_{D}$', + 'U (STD)', 'U asyn'] +SS_labels = ['MFR', 'NBR', 'NBD', 'PSIB', '#FBs', 'CVIBI', 'mean CC', 'sd CC', 'mean ISI CC', 'sd ISI CC', 'ISI dist', + 'mean ISI', 'sd ISI temp', 'sd isi elec', 'MAC'] + +# LOAD YOUR OWN EXPERIMENTAL DATA TO INFER POSTERIOR # Load your own experimental data as APs (first column electrode number, second column AP timestamps): -# location of your experimental files -exp_fileloc = '/home/yourlocation' +# location of your experimental file +exp_fileloc = '/home/Nina/Documents/SBI_project/Output/Paper_Figures_ver1/APs_Fig_5_CACNClonesb3_sim0.csv' APs_obs = numpy.loadtxt(exp_fileloc, delimiter=",", dtype='int') -recordtime = 165 * second # how long the recording was -fs = 10000 # sampling frequency used for the recording +recordtime = 165 * second # how long the recording was +fs = 10000 # sampling frequency used for the recording -# plot the data +# Plot the data rasterplot(APs_obs, "observation", 1/fs, 0*second, recordtime, 'black') # Calculate MEA features -exp_MEAfeatures = ComputeFeatures(APs_obs, recordtime, 5 * second, fs) +exp_MEAfeatures = compute_features(APs_obs, recordtime, 5 * second, fs) # Load the trainedNDE for the posterior with open('TrainedNDE', 'rb') as f: @@ -45,48 +49,49 @@ modeparams = posterior.map() samples = posterior.sample((1000,)) _ = analysis.pairplot(samples, diag='kde', - ticks = priorlimits, - upper='kde', points = modeparams, points_colors=['#EF6F6C'], + ticks=prior_limits, + upper='kde', points=modeparams, points_colors=['#EF6F6C'], points_offdiag={'markersize': 8}, - limits=priorlimits, - figsize=(6,6), labels=parlabels) + limits=prior_limits, + figsize=(6, 6), labels=par_labels) plt.show() -# run simulations with the mode of the posterior -APs_sim, simtime, transient, fs = MEAnetSimulate(modeparams) +# Run simulations with the mode of the posterior +APs_sim, simtime, transient, fs = MEAnetsimulate(modeparams) rasterplot(APs_sim, "simulation", 1/fs, transient, simtime, 'black') -## COMPARE TWO POSTERIORS -# calculate or define the MEA features of your two observations -observation1 = torch.tensor(torch.load('SCN_WTC_2410.pt')) -posterior.set_default_x(observation1) # find the maxima of the posterior +# COMPARE TWO POSTERIORS +# Calculate or define the MEA features of your two observations +observation1 = torch.tensor(torch.load(example_dir + 'SCN_WTC_2410.pt')) +posterior.set_default_x(observation1) obs1_samples = posterior.sample((1000,)) -observation2 = torch.tensor(torch.load('SCN_GEFS_2410.pt')) -posterior.set_default_x(observation2) # find the maxima of the posterior +observation2 = torch.tensor(torch.load(example_dir + 'SCN_GEFS_2410.pt')) +posterior.set_default_x(observation2) obs2_samples = posterior.sample((1000,)) -Marginaldiffplot(obs1_samples, obs2_samples, numparams, priorlimits, parlabels, 'WTC_GEFS_diff') +marginaldiffplot(obs1_samples, obs2_samples, num_params, prior_limits, par_labels, 'WTC_GEFS_diff') -#Perform Kolmogorov-Smirnov test to test differences between marginals -observation1 = torch.tensor(torch.load('SCN_WTC_2410.pt')) -posterior.set_default_x(observation1) # find the maxima of the posterior -obs1_samples = posterior.sample((50,)) -observation2 = torch.tensor(torch.load('SCN_GEFS_2410.pt')) -posterior.set_default_x(observation2) # find the maxima of the posterior -obs2_samples = posterior.sample((50,)) +# Perform Kolmogorov-Smirnov test to test differences between marginals +num_samples = 50 # the number of samples drawn from the posterior to perform KS test +observation1 = torch.tensor(torch.load(example_dir + 'SCN_WTC_2410.pt')) +posterior.set_default_x(observation1) +obs1_samples = posterior.sample((num_samples,)) +observation2 = torch.tensor(torch.load(example_dir + 'SCN_GEFS_2410.pt')) +posterior.set_default_x(observation2) +obs2_samples = posterior.sample((num_samples,)) -KSs = np.zeros(numparams) -Pvals = np.zeros(numparams) -for i in range(numparams): +KSs = np.zeros(num_params) +Pvals = np.zeros(num_params) +for i in range(num_params): par = i - KSs[i], Pvals[i] = ks_2samp(obs1_samples[:,par], obs2_samples[:,par]) - print(parlabels[i]) + KSs[i], Pvals[i] = ks_2samp(obs1_samples[:, par], obs2_samples[:, par]) + print(par_labels[i]) print("KS statistic:", KSs[i]) print("P-value:", Pvals[i]) -## FIND CONDITIONAL DISTRIBUTIONS AND PEARSON CORRELATIONS -# show a conditional posterior distribution with one sample from the posterior -observation = torch.tensor(torch.load('SCN_DS_2410.pt')) +# FIND CONDITIONAL DISTRIBUTIONS AND PEARSON CORRELATIONS +# Show a conditional posterior distribution with one sample from the posterior +observation = torch.tensor(torch.load(example_dir + 'SCN_DS_2410.pt')) posterior.set_default_x(observation) condition = posterior.sample((1,)) @@ -95,31 +100,31 @@ _ = analysis.conditional_pairplot( condition=condition, diag=['kde'], upper=['kde'], - limits=priorlimits, - figsize=(6,6), labels=parlabels) + limits=prior_limits, + figsize=(6, 6), labels=par_labels) plt.show() -# Compute the correlation coefficient of every pair of parameters for every posterior sample -numconds = 50 -corrcoefs = np.zeros((numconds, 100)) -for i in range (numconds): +# Compute the correlation coefficient of every pair of parameters for num_conds conditional distributions +num_conds = 50 # of how many conditional distributions you want to compute the CCs +corrcoefs = np.zeros((num_conds, 100)) +for i in range(num_conds): condition = posterior.sample((1,)) cond_coeff_mat = analysis.conditional_corrcoeff( density=posterior, condition=condition, - limits=torch.tensor(priorlimits), + limits=torch.tensor(prior_limits), ) - corrcoefs[i,:] = np.array(torch.flatten(cond_coeff_mat)) + corrcoefs[i, :] = np.array(torch.flatten(cond_coeff_mat)) -#take the average correlation coefficients +# Take the average correlation coefficients average_corrcoefs = torch.tensor(np.mean(corrcoefs, axis=0)) average_corrcoefs_pl = torch.unflatten(average_corrcoefs, 0, (10, 10)) -#Construct the correlation matrix +# Construct the correlation matrix of the average correlation coefficients fig, ax = plt.subplots(1, 1, figsize=(3, 3)) im = plt.imshow(average_corrcoefs_pl, clim=[-0.6, 0.6], cmap="RdBu") -ax.set_xticks(range(0,10)) -ax.set_xticklabels(parlabels, rotation = 90) -ax.set_yticks(range(0,10), parlabels) +ax.set_xticks(range(0, 10)) +ax.set_xticklabels(par_labels, rotation=90) +ax.set_yticks(range(0, 10), par_labels) _ = fig.colorbar(im) plt.show() -- GitLab