Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
S
SBI MEA Model
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Doorn, Nina (UT-TNW)
SBI MEA Model
Commits
45bcc21d
Commit
45bcc21d
authored
1 month ago
by
Doorn, Nina (UT-TNW)
Browse files
Options
Downloads
Patches
Plain Diff
Script to perform PPC and calculate PRE of different trained NDEs
parent
af50ab80
Branches
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
PosteriorPredictiveChecks.py
+90
-0
90 additions, 0 deletions
PosteriorPredictiveChecks.py
with
90 additions
and
0 deletions
PosteriorPredictiveChecks.py
0 → 100644
+
90
−
0
View file @
45bcc21d
# To perform and visualeze posterior predictive checks on a trained density estimator
import
torch
from
brian2
import
*
from
sbi
import
analysis
,
utils
import
numpy
as
np
from
FeatureExtraction
import
compute_features
,
compute_spikerate
from
Simulator
import
MEAnetsimulate
from
MakeFigures
import
rasterplot
import
pickle
num_feats
=
15
# Number of MEA features
num_params
=
10
# Number of free parameters of the model
prior_min
=
[
1.5
,
0.5
,
0.1
,
0.5
,
0.05
,
0.
,
0.1
,
150
,
0.005
,
0.
]
prior_max
=
[
7
,
2
,
10
,
10
,
1
,
1
,
0.6
,
1200
,
0.3
,
0.005
]
prior
=
utils
.
BoxUniform
(
low
=
torch
.
tensor
(
prior_min
),
high
=
torch
.
tensor
(
prior_max
))
prior_limits
=
[[
1.5
,
7
],
[
0.5
,
2
],
[
0.1
,
10
],
[
0.5
,
10
],
[
0.05
,
1
],
[
0.
,
1
],
[
0.1
,
0.6
],
[
150
,
1200
],
[
0.005
,
0.3
],
[
0.
,
0.005
]]
par_labels
=
[
'
noise
'
,
'
$g_{Na}$
'
,
'
$g_{K}$
'
,
'
$g_{AHP}$
'
,
'
$g_{AMPA}$
'
,
'
$g_{NMDA}$
'
,
'
Conn%
'
,
r
'
$\tau_{D}$
'
,
'
U (STD)
'
,
'
U asyn
'
]
feat_labels
=
[
'
MFR
'
,
'
NBR
'
,
'
NBD
'
,
'
PSIB
'
,
'
#FBs
'
,
'
CVIBI
'
,
'
mean CC
'
,
'
sd CC
'
,
'
mean ISI CC
'
,
'
sd ISI CC
'
,
'
ISI dist
'
,
'
mean ISI
'
,
'
sd ISI temp
'
,
'
sd isi elec
'
,
'
MAC
'
]
# load your trained density estimator
with
open
(
'
TrainedNDE
'
,
'
rb
'
)
as
f
:
# with open('Posterior_Features', 'rb') as f:
posterior
=
pickle
.
load
(
f
)
embedding_net
=
False
# true means NDE was trained with embedding net on spike rates per electrode
# false mean NDE was trained on 15 MEA features
# Visualize the performance on one set of ground-truth parameters
# set of ground-truth parameters (define yourself):
test_params
=
torch
.
as_tensor
([
4
,
1
,
0.4
,
5
,
0.2
,
0.4
,
0.15
,
200
,
0.1
,
0.0001
])
# or alternatively, draw a random sample from prior:
# test_params = prior.sample((1,))
# perform a simulations with the parameters
APs
,
simtime
,
transient
,
fs
=
MEAnetsimulate
(
test_params
)
rasterplot
(
APs
,
'
PPC1_Feat_pre
'
,
1
/
fs
,
transient
,
simtime
,
'
black
'
)
if
embedding_net
:
numelectrodes
=
12
time_bin
=
100e-3
spikerate
=
compute_spikerate
(
APs
,
simtime
/
second
,
transient
/
second
,
fs
,
time_bin
)
spikeratet
=
torch
.
as_tensor
(
spikerate
)
observation
=
spikeratet
.
reshape
(
1
,
-
1
)
else
:
observation
=
torch
.
as_tensor
(
compute_features
(
APs
,
simtime
,
transient
,
fs
))
posterior
.
set_default_x
(
observation
)
est_params
=
posterior
.
map
()
samples
=
posterior
.
sample
((
1000
,))
_
=
analysis
.
pairplot
(
samples
,
diag
=
'
kde
'
,
ticks
=
prior_limits
,
upper
=
'
kde
'
,
points
=
[
est_params
,
test_params
],
points_colors
=
[
'
#EF6F6C
'
,
'
#6B0504
'
],
points_offdiag
=
{
'
markersize
'
:
8
},
limits
=
prior_limits
,
figsize
=
(
6
,
6
),
labels
=
par_labels
)
plt
.
show
()
# Run a simulation with the MAP of the posterior to see if it matches original
APsres
,
simtime
,
transient
,
fs
=
MEAnetsimulate
(
est_params
)
rasterplot
(
APsres
,
'
PPC1_Feat_post
'
,
1
/
fs
,
transient
,
simtime
,
'
black
'
)
if
not
embedding_net
:
model_prediction
=
torch
.
as_tensor
(
compute_features
(
APsres
,
simtime
,
transient
,
fs
))
# calculate the PRE
def
normalize_parameters
(
params
,
prior_min
,
prior_max
):
if
not
(
params
.
shape
[
-
1
]
==
prior_min
.
shape
[
0
]
==
prior_max
.
shape
[
0
]):
raise
ValueError
(
"
Mismatch between number of parameters and prior range dimensions.
"
)
return
(
params
-
prior_min
)
/
(
prior_max
-
prior_min
)
def
compute_pre
(
posterior_samples
,
ground_truth
):
"""
Computes the Posterior Recovery Error (PRE) for given posterior samples and ground truth parameters.
"""
if
posterior_samples
.
shape
[
1
]
!=
ground_truth
.
shape
[
0
]:
raise
ValueError
(
"
Mismatch between number of parameters in samples and ground truth.
"
)
pre
=
torch
.
mean
((
posterior_samples
-
ground_truth
)
**
2
,
dim
=
0
)
pre
=
pre
.
numpy
()
return
pre
norm_samps
=
normalize_parameters
(
samples
,
np
.
array
(
prior_min
),
np
.
array
(
prior_max
))
norm_GT
=
normalize_parameters
(
test_params
,
np
.
array
(
prior_min
),
np
.
array
(
prior_max
))
pre
=
compute_pre
(
norm_samps
,
norm_GT
)
for
s
,
n
in
zip
(
par_labels
,
pre
):
print
(
f
"
PRE
{
s
}
:
{
n
}
"
)
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment