Skip to content
Snippets Groups Projects
Commit db13e96d authored by Doorn, Nina (UT-TNW)'s avatar Doorn, Nina (UT-TNW)
Browse files

Included all example observations

parent 9b561e37
Branches
No related tags found
No related merge requests found
......@@ -4,34 +4,38 @@ from brian2 import *
from sbi import utils as utils
from sbi import analysis as analysis
import matplotlib.pyplot as plt
from Simulator import MEAnetSimulate, ComputeFeatures
from MakeFigures import rasterplot, Marginaldiffplot
from Simulator import MEAnetsimulate, compute_features
from MakeFigures import rasterplot, marginaldiffplot
from scipy.stats import ks_2samp
numstats = 15 # Number of summary statistics
numparams = 10 # Number of free parameters of the model
example_dir = '../example_observations/' # directory with the example observations
num_stats = 15 # Number of summary statistics
num_params = 10 # Number of free parameters of the model
prior_min = [1.5, 0.5, 0.1, 0.5, 0.05, 0., 0.1, 150, 0.005, 0.]
prior_max = [7, 2, 10, 10, 1, 1, 0.6, 1200, 0.3, 0.005]
prior = utils.BoxUniform(low=torch.tensor(prior_min), high=torch.tensor(prior_max))
priorlimits = [[1.5, 7], [0.5, 2], [0.1, 10], [0.5, 10], [0.05, 1], [0., 1], [0.1, 0.6], [150, 1200], [0.005, 0.3], [0., 0.005]]
numparams = 10
parlabels = ['noise', '$g_{Na}$', '$g_{K}$', '$g_{AHP}$', '$g_{AMPA}$', '$g_{NMDA}$', 'Conn%', r'$\tau_{D}$', 'U (STD)', 'U asyn']
SSlabels = ['MFR', 'NBR', 'NBD', 'PSIB', '#FBs', 'CVIBI', 'mean CC', 'sd CC', 'mean ISI CC', 'sd ISI CC', 'ISI dist', 'mean ISI', 'sd ISI temp', 'sd isi elec', 'MAC']
## LOAD YOUR OWN EXPERIMENTAL DATA TO OBTAIN POSTERIOR
prior_limits = [[1.5, 7], [0.5, 2], [0.1, 10], [0.5, 10], [0.05, 1], [0., 1], [0.1, 0.6], [150, 1200], [0.005, 0.3],
[0., 0.005]]
par_labels = ['noise', '$g_{Na}$', '$g_{K}$', '$g_{AHP}$', '$g_{AMPA}$', '$g_{NMDA}$', 'Conn%', r'$\tau_{D}$',
'U (STD)', 'U asyn']
SS_labels = ['MFR', 'NBR', 'NBD', 'PSIB', '#FBs', 'CVIBI', 'mean CC', 'sd CC', 'mean ISI CC', 'sd ISI CC', 'ISI dist',
'mean ISI', 'sd ISI temp', 'sd isi elec', 'MAC']
# LOAD YOUR OWN EXPERIMENTAL DATA TO INFER POSTERIOR
# Load your own experimental data as APs (first column electrode number, second column AP timestamps):
# location of your experimental files
exp_fileloc = '/home/yourlocation'
# location of your experimental file
exp_fileloc = '/home/Nina/Documents/SBI_project/Output/Paper_Figures_ver1/APs_Fig_5_CACNClonesb3_sim0.csv'
APs_obs = numpy.loadtxt(exp_fileloc, delimiter=",", dtype='int')
recordtime = 165 * second # how long the recording was
fs = 10000 # sampling frequency used for the recording
recordtime = 165 * second # how long the recording was
fs = 10000 # sampling frequency used for the recording
# plot the data
# Plot the data
rasterplot(APs_obs, "observation", 1/fs, 0*second, recordtime, 'black')
# Calculate MEA features
exp_MEAfeatures = ComputeFeatures(APs_obs, recordtime, 5 * second, fs)
exp_MEAfeatures = compute_features(APs_obs, recordtime, 5 * second, fs)
# Load the trainedNDE for the posterior
with open('TrainedNDE', 'rb') as f:
......@@ -45,48 +49,49 @@ modeparams = posterior.map()
samples = posterior.sample((1000,))
_ = analysis.pairplot(samples,
diag='kde',
ticks = priorlimits,
upper='kde', points = modeparams, points_colors=['#EF6F6C'],
ticks=prior_limits,
upper='kde', points=modeparams, points_colors=['#EF6F6C'],
points_offdiag={'markersize': 8},
limits=priorlimits,
figsize=(6,6), labels=parlabels)
limits=prior_limits,
figsize=(6, 6), labels=par_labels)
plt.show()
# run simulations with the mode of the posterior
APs_sim, simtime, transient, fs = MEAnetSimulate(modeparams)
# Run simulations with the mode of the posterior
APs_sim, simtime, transient, fs = MEAnetsimulate(modeparams)
rasterplot(APs_sim, "simulation", 1/fs, transient, simtime, 'black')
## COMPARE TWO POSTERIORS
# calculate or define the MEA features of your two observations
observation1 = torch.tensor(torch.load('SCN_WTC_2410.pt'))
posterior.set_default_x(observation1) # find the maxima of the posterior
# COMPARE TWO POSTERIORS
# Calculate or define the MEA features of your two observations
observation1 = torch.tensor(torch.load(example_dir + 'SCN_WTC_2410.pt'))
posterior.set_default_x(observation1)
obs1_samples = posterior.sample((1000,))
observation2 = torch.tensor(torch.load('SCN_GEFS_2410.pt'))
posterior.set_default_x(observation2) # find the maxima of the posterior
observation2 = torch.tensor(torch.load(example_dir + 'SCN_GEFS_2410.pt'))
posterior.set_default_x(observation2)
obs2_samples = posterior.sample((1000,))
Marginaldiffplot(obs1_samples, obs2_samples, numparams, priorlimits, parlabels, 'WTC_GEFS_diff')
marginaldiffplot(obs1_samples, obs2_samples, num_params, prior_limits, par_labels, 'WTC_GEFS_diff')
#Perform Kolmogorov-Smirnov test to test differences between marginals
observation1 = torch.tensor(torch.load('SCN_WTC_2410.pt'))
posterior.set_default_x(observation1) # find the maxima of the posterior
obs1_samples = posterior.sample((50,))
observation2 = torch.tensor(torch.load('SCN_GEFS_2410.pt'))
posterior.set_default_x(observation2) # find the maxima of the posterior
obs2_samples = posterior.sample((50,))
# Perform Kolmogorov-Smirnov test to test differences between marginals
num_samples = 50 # the number of samples drawn from the posterior to perform KS test
observation1 = torch.tensor(torch.load(example_dir + 'SCN_WTC_2410.pt'))
posterior.set_default_x(observation1)
obs1_samples = posterior.sample((num_samples,))
observation2 = torch.tensor(torch.load(example_dir + 'SCN_GEFS_2410.pt'))
posterior.set_default_x(observation2)
obs2_samples = posterior.sample((num_samples,))
KSs = np.zeros(numparams)
Pvals = np.zeros(numparams)
for i in range(numparams):
KSs = np.zeros(num_params)
Pvals = np.zeros(num_params)
for i in range(num_params):
par = i
KSs[i], Pvals[i] = ks_2samp(obs1_samples[:,par], obs2_samples[:,par])
print(parlabels[i])
KSs[i], Pvals[i] = ks_2samp(obs1_samples[:, par], obs2_samples[:, par])
print(par_labels[i])
print("KS statistic:", KSs[i])
print("P-value:", Pvals[i])
## FIND CONDITIONAL DISTRIBUTIONS AND PEARSON CORRELATIONS
# show a conditional posterior distribution with one sample from the posterior
observation = torch.tensor(torch.load('SCN_DS_2410.pt'))
# FIND CONDITIONAL DISTRIBUTIONS AND PEARSON CORRELATIONS
# Show a conditional posterior distribution with one sample from the posterior
observation = torch.tensor(torch.load(example_dir + 'SCN_DS_2410.pt'))
posterior.set_default_x(observation)
condition = posterior.sample((1,))
......@@ -95,31 +100,31 @@ _ = analysis.conditional_pairplot(
condition=condition,
diag=['kde'],
upper=['kde'],
limits=priorlimits,
figsize=(6,6), labels=parlabels)
limits=prior_limits,
figsize=(6, 6), labels=par_labels)
plt.show()
# Compute the correlation coefficient of every pair of parameters for every posterior sample
numconds = 50
corrcoefs = np.zeros((numconds, 100))
for i in range (numconds):
# Compute the correlation coefficient of every pair of parameters for num_conds conditional distributions
num_conds = 50 # of how many conditional distributions you want to compute the CCs
corrcoefs = np.zeros((num_conds, 100))
for i in range(num_conds):
condition = posterior.sample((1,))
cond_coeff_mat = analysis.conditional_corrcoeff(
density=posterior,
condition=condition,
limits=torch.tensor(priorlimits),
limits=torch.tensor(prior_limits),
)
corrcoefs[i,:] = np.array(torch.flatten(cond_coeff_mat))
corrcoefs[i, :] = np.array(torch.flatten(cond_coeff_mat))
#take the average correlation coefficients
# Take the average correlation coefficients
average_corrcoefs = torch.tensor(np.mean(corrcoefs, axis=0))
average_corrcoefs_pl = torch.unflatten(average_corrcoefs, 0, (10, 10))
#Construct the correlation matrix
# Construct the correlation matrix of the average correlation coefficients
fig, ax = plt.subplots(1, 1, figsize=(3, 3))
im = plt.imshow(average_corrcoefs_pl, clim=[-0.6, 0.6], cmap="RdBu")
ax.set_xticks(range(0,10))
ax.set_xticklabels(parlabels, rotation = 90)
ax.set_yticks(range(0,10), parlabels)
ax.set_xticks(range(0, 10))
ax.set_xticklabels(par_labels, rotation=90)
ax.set_yticks(range(0, 10), par_labels)
_ = fig.colorbar(im)
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment