Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
S
SBI MEA Model
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Doorn, Nina (UT-TNW)
SBI MEA Model
Commits
db13e96d
Commit
db13e96d
authored
3 months ago
by
Doorn, Nina (UT-TNW)
Browse files
Options
Downloads
Patches
Plain Diff
Included all example observations
parent
9b561e37
Branches
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
FindPosteriors.py
+62
-57
62 additions, 57 deletions
FindPosteriors.py
with
62 additions
and
57 deletions
FindPosteriors.py
+
62
−
57
View file @
db13e96d
...
...
@@ -4,34 +4,38 @@ from brian2 import *
from
sbi
import
utils
as
utils
from
sbi
import
analysis
as
analysis
import
matplotlib.pyplot
as
plt
from
Simulator
import
MEAnet
S
imulate
,
C
ompute
F
eatures
from
MakeFigures
import
rasterplot
,
M
arginaldiffplot
from
Simulator
import
MEAnet
s
imulate
,
c
ompute
_f
eatures
from
MakeFigures
import
rasterplot
,
m
arginaldiffplot
from
scipy.stats
import
ks_2samp
numstats
=
15
# Number of summary statistics
numparams
=
10
# Number of free parameters of the model
example_dir
=
'
../example_observations/
'
# directory with the example observations
num_stats
=
15
# Number of summary statistics
num_params
=
10
# Number of free parameters of the model
prior_min
=
[
1.5
,
0.5
,
0.1
,
0.5
,
0.05
,
0.
,
0.1
,
150
,
0.005
,
0.
]
prior_max
=
[
7
,
2
,
10
,
10
,
1
,
1
,
0.6
,
1200
,
0.3
,
0.005
]
prior
=
utils
.
BoxUniform
(
low
=
torch
.
tensor
(
prior_min
),
high
=
torch
.
tensor
(
prior_max
))
priorlimits
=
[[
1.5
,
7
],
[
0.5
,
2
],
[
0.1
,
10
],
[
0.5
,
10
],
[
0.05
,
1
],
[
0.
,
1
],
[
0.1
,
0.6
],
[
150
,
1200
],
[
0.005
,
0.3
],
[
0.
,
0.005
]]
numparams
=
10
parlabels
=
[
'
noise
'
,
'
$g_{Na}$
'
,
'
$g_{K}$
'
,
'
$g_{AHP}$
'
,
'
$g_{AMPA}$
'
,
'
$g_{NMDA}$
'
,
'
Conn%
'
,
r
'
$\tau_{D}$
'
,
'
U (STD)
'
,
'
U asyn
'
]
SSlabels
=
[
'
MFR
'
,
'
NBR
'
,
'
NBD
'
,
'
PSIB
'
,
'
#FBs
'
,
'
CVIBI
'
,
'
mean CC
'
,
'
sd CC
'
,
'
mean ISI CC
'
,
'
sd ISI CC
'
,
'
ISI dist
'
,
'
mean ISI
'
,
'
sd ISI temp
'
,
'
sd isi elec
'
,
'
MAC
'
]
## LOAD YOUR OWN EXPERIMENTAL DATA TO OBTAIN POSTERIOR
prior_limits
=
[[
1.5
,
7
],
[
0.5
,
2
],
[
0.1
,
10
],
[
0.5
,
10
],
[
0.05
,
1
],
[
0.
,
1
],
[
0.1
,
0.6
],
[
150
,
1200
],
[
0.005
,
0.3
],
[
0.
,
0.005
]]
par_labels
=
[
'
noise
'
,
'
$g_{Na}$
'
,
'
$g_{K}$
'
,
'
$g_{AHP}$
'
,
'
$g_{AMPA}$
'
,
'
$g_{NMDA}$
'
,
'
Conn%
'
,
r
'
$\tau_{D}$
'
,
'
U (STD)
'
,
'
U asyn
'
]
SS_labels
=
[
'
MFR
'
,
'
NBR
'
,
'
NBD
'
,
'
PSIB
'
,
'
#FBs
'
,
'
CVIBI
'
,
'
mean CC
'
,
'
sd CC
'
,
'
mean ISI CC
'
,
'
sd ISI CC
'
,
'
ISI dist
'
,
'
mean ISI
'
,
'
sd ISI temp
'
,
'
sd isi elec
'
,
'
MAC
'
]
# LOAD YOUR OWN EXPERIMENTAL DATA TO INFER POSTERIOR
# Load your own experimental data as APs (first column electrode number, second column AP timestamps):
# location of your experimental file
s
exp_fileloc
=
'
/home/
yourlocation
'
# location of your experimental file
exp_fileloc
=
'
/home/
Nina/Documents/SBI_project/Output/Paper_Figures_ver1/APs_Fig_5_CACNClonesb3_sim0.csv
'
APs_obs
=
numpy
.
loadtxt
(
exp_fileloc
,
delimiter
=
"
,
"
,
dtype
=
'
int
'
)
recordtime
=
165
*
second
# how long the recording was
fs
=
10000
# sampling frequency used for the recording
recordtime
=
165
*
second
# how long the recording was
fs
=
10000
# sampling frequency used for the recording
#
p
lot the data
#
P
lot the data
rasterplot
(
APs_obs
,
"
observation
"
,
1
/
fs
,
0
*
second
,
recordtime
,
'
black
'
)
# Calculate MEA features
exp_MEAfeatures
=
C
ompute
F
eatures
(
APs_obs
,
recordtime
,
5
*
second
,
fs
)
exp_MEAfeatures
=
c
ompute
_f
eatures
(
APs_obs
,
recordtime
,
5
*
second
,
fs
)
# Load the trainedNDE for the posterior
with
open
(
'
TrainedNDE
'
,
'
rb
'
)
as
f
:
...
...
@@ -45,48 +49,49 @@ modeparams = posterior.map()
samples
=
posterior
.
sample
((
1000
,))
_
=
analysis
.
pairplot
(
samples
,
diag
=
'
kde
'
,
ticks
=
priorlimits
,
upper
=
'
kde
'
,
points
=
modeparams
,
points_colors
=
[
'
#EF6F6C
'
],
ticks
=
prior
_
limits
,
upper
=
'
kde
'
,
points
=
modeparams
,
points_colors
=
[
'
#EF6F6C
'
],
points_offdiag
=
{
'
markersize
'
:
8
},
limits
=
priorlimits
,
figsize
=
(
6
,
6
),
labels
=
parlabels
)
limits
=
prior
_
limits
,
figsize
=
(
6
,
6
),
labels
=
par
_
labels
)
plt
.
show
()
#
r
un simulations with the mode of the posterior
APs_sim
,
simtime
,
transient
,
fs
=
MEAnet
S
imulate
(
modeparams
)
#
R
un simulations with the mode of the posterior
APs_sim
,
simtime
,
transient
,
fs
=
MEAnet
s
imulate
(
modeparams
)
rasterplot
(
APs_sim
,
"
simulation
"
,
1
/
fs
,
transient
,
simtime
,
'
black
'
)
#
# COMPARE TWO POSTERIORS
#
c
alculate or define the MEA features of your two observations
observation1
=
torch
.
tensor
(
torch
.
load
(
'
SCN_WTC_2410.pt
'
))
posterior
.
set_default_x
(
observation1
)
# find the maxima of the posterior
# COMPARE TWO POSTERIORS
#
C
alculate or define the MEA features of your two observations
observation1
=
torch
.
tensor
(
torch
.
load
(
example_dir
+
'
SCN_WTC_2410.pt
'
))
posterior
.
set_default_x
(
observation1
)
obs1_samples
=
posterior
.
sample
((
1000
,))
observation2
=
torch
.
tensor
(
torch
.
load
(
'
SCN_GEFS_2410.pt
'
))
posterior
.
set_default_x
(
observation2
)
# find the maxima of the posterior
observation2
=
torch
.
tensor
(
torch
.
load
(
example_dir
+
'
SCN_GEFS_2410.pt
'
))
posterior
.
set_default_x
(
observation2
)
obs2_samples
=
posterior
.
sample
((
1000
,))
M
arginaldiffplot
(
obs1_samples
,
obs2_samples
,
numparams
,
priorlimits
,
parlabels
,
'
WTC_GEFS_diff
'
)
m
arginaldiffplot
(
obs1_samples
,
obs2_samples
,
num
_
params
,
prior
_
limits
,
par
_
labels
,
'
WTC_GEFS_diff
'
)
#Perform Kolmogorov-Smirnov test to test differences between marginals
observation1
=
torch
.
tensor
(
torch
.
load
(
'
SCN_WTC_2410.pt
'
))
posterior
.
set_default_x
(
observation1
)
# find the maxima of the posterior
obs1_samples
=
posterior
.
sample
((
50
,))
observation2
=
torch
.
tensor
(
torch
.
load
(
'
SCN_GEFS_2410.pt
'
))
posterior
.
set_default_x
(
observation2
)
# find the maxima of the posterior
obs2_samples
=
posterior
.
sample
((
50
,))
# Perform Kolmogorov-Smirnov test to test differences between marginals
num_samples
=
50
# the number of samples drawn from the posterior to perform KS test
observation1
=
torch
.
tensor
(
torch
.
load
(
example_dir
+
'
SCN_WTC_2410.pt
'
))
posterior
.
set_default_x
(
observation1
)
obs1_samples
=
posterior
.
sample
((
num_samples
,))
observation2
=
torch
.
tensor
(
torch
.
load
(
example_dir
+
'
SCN_GEFS_2410.pt
'
))
posterior
.
set_default_x
(
observation2
)
obs2_samples
=
posterior
.
sample
((
num_samples
,))
KSs
=
np
.
zeros
(
numparams
)
Pvals
=
np
.
zeros
(
numparams
)
for
i
in
range
(
numparams
):
KSs
=
np
.
zeros
(
num
_
params
)
Pvals
=
np
.
zeros
(
num
_
params
)
for
i
in
range
(
num
_
params
):
par
=
i
KSs
[
i
],
Pvals
[
i
]
=
ks_2samp
(
obs1_samples
[:,
par
],
obs2_samples
[:,
par
])
print
(
parlabels
[
i
])
KSs
[
i
],
Pvals
[
i
]
=
ks_2samp
(
obs1_samples
[:,
par
],
obs2_samples
[:,
par
])
print
(
par
_
labels
[
i
])
print
(
"
KS statistic:
"
,
KSs
[
i
])
print
(
"
P-value:
"
,
Pvals
[
i
])
#
#
FIND CONDITIONAL DISTRIBUTIONS AND PEARSON CORRELATIONS
#
s
how a conditional posterior distribution with one sample from the posterior
observation
=
torch
.
tensor
(
torch
.
load
(
'
SCN_DS_2410.pt
'
))
# FIND CONDITIONAL DISTRIBUTIONS AND PEARSON CORRELATIONS
#
S
how a conditional posterior distribution with one sample from the posterior
observation
=
torch
.
tensor
(
torch
.
load
(
example_dir
+
'
SCN_DS_2410.pt
'
))
posterior
.
set_default_x
(
observation
)
condition
=
posterior
.
sample
((
1
,))
...
...
@@ -95,31 +100,31 @@ _ = analysis.conditional_pairplot(
condition
=
condition
,
diag
=
[
'
kde
'
],
upper
=
[
'
kde
'
],
limits
=
priorlimits
,
figsize
=
(
6
,
6
),
labels
=
parlabels
)
limits
=
prior
_
limits
,
figsize
=
(
6
,
6
),
labels
=
par
_
labels
)
plt
.
show
()
# Compute the correlation coefficient of every pair of parameters for
every po
st
e
ri
or sample
numconds
=
50
corrcoefs
=
np
.
zeros
((
numconds
,
100
))
for
i
in
range
(
numconds
):
# Compute the correlation coefficient of every pair of parameters for
num_conds conditional di
stri
butions
num
_
conds
=
50
# of how many conditional distributions you want to compute the CCs
corrcoefs
=
np
.
zeros
((
num
_
conds
,
100
))
for
i
in
range
(
num
_
conds
):
condition
=
posterior
.
sample
((
1
,))
cond_coeff_mat
=
analysis
.
conditional_corrcoeff
(
density
=
posterior
,
condition
=
condition
,
limits
=
torch
.
tensor
(
priorlimits
),
limits
=
torch
.
tensor
(
prior
_
limits
),
)
corrcoefs
[
i
,:]
=
np
.
array
(
torch
.
flatten
(
cond_coeff_mat
))
corrcoefs
[
i
,
:]
=
np
.
array
(
torch
.
flatten
(
cond_coeff_mat
))
#
t
ake the average correlation coefficients
#
T
ake the average correlation coefficients
average_corrcoefs
=
torch
.
tensor
(
np
.
mean
(
corrcoefs
,
axis
=
0
))
average_corrcoefs_pl
=
torch
.
unflatten
(
average_corrcoefs
,
0
,
(
10
,
10
))
#Construct the correlation matrix
#
Construct the correlation matrix
of the average correlation coefficients
fig
,
ax
=
plt
.
subplots
(
1
,
1
,
figsize
=
(
3
,
3
))
im
=
plt
.
imshow
(
average_corrcoefs_pl
,
clim
=
[
-
0.6
,
0.6
],
cmap
=
"
RdBu
"
)
ax
.
set_xticks
(
range
(
0
,
10
))
ax
.
set_xticklabels
(
parlabels
,
rotation
=
90
)
ax
.
set_yticks
(
range
(
0
,
10
),
parlabels
)
ax
.
set_xticks
(
range
(
0
,
10
))
ax
.
set_xticklabels
(
par
_
labels
,
rotation
=
90
)
ax
.
set_yticks
(
range
(
0
,
10
),
par
_
labels
)
_
=
fig
.
colorbar
(
im
)
plt
.
show
()
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment