Commit 9f47aaa5 authored by Doorn, Nina (UT-TNW)

Script showing how NDE was trained in the manuscript with either 15 MEA features or using an embedding network.
parent fc0e1cbf
# Script to train an NDE, either on the 15 MEA features or using an embedding network
# You will need a newer version of the sbi toolkit (v0.23.0) to use the embedding network
import os
import pickle
import re

import numpy as np
import torch
import torch.nn as nn
from brian2 import *
from sbi import utils as utils
from sbi.inference import NPE
from sbi.neural_nets import posterior_nn

from FeatureExtraction import compute_features, compute_spikerate
embedding_net = False  # if True, train with the embedding network; if False, use the 15 MEA features
calculate_feats = False  # if True, analyze your own raw spike trains instead of the provided simulations
if calculate_feats:
    # specify where your AP files/spike trains are saved
    results_dir = '/path/to/APs'
    sim_time = 185  # length of your simulation in seconds
    transient = 5  # initial transient (in seconds) discarded from the analysis
    fs = 10000  # sampling frequency in Hz
    num_electrodes = 12  # number of electrodes
    num_sims = 100000  # number of simulations whose results are in results_dir
num_feats = 15  # number of MEA features
num_params = 10  # number of free parameters of the model
prior_min = [1.5, 0.5, 0.1, 0.5, 0.05, 0., 0.1, 150, 0.005, 0.]
prior_max = [7, 2, 10, 10, 1, 1, 0.6, 1200, 0.3, 0.005]
prior = utils.BoxUniform(low=torch.tensor(prior_min), high=torch.tensor(prior_max))
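# For reference, a parameter matrix like 'Theta100000.pt' can be drawn from this
# prior (a minimal sketch; the exact sampling code used for the manuscript is not
# part of this script):
#   theta_draws = prior.sample((100000,))  # shape: (100000, num_params)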
if not embedding_net:
    """ Train the NDE on the 15 MEA features, as in the paper """
    if calculate_feats:  # calculate the features of your own spike trains
        results = torch.zeros(num_sims, num_feats)
        false_sims = []
        for filename in os.listdir(results_dir):
            f = os.path.join(results_dir, filename)
            # simulation index parsed from the filename; assumed to match the row
            # order of 'Theta100000.pt' (os.listdir order is not guaranteed)
            ind = int(re.findall(r'\d+', filename)[0])
            try:
                APs = torch.load(f)
                feats = compute_features(APs, sim_time * second, transient * second, fs)
                results[ind, :] = torch.as_tensor(feats)
            except Exception as e:
                # print the file that caused the error
                print(f"Error loading file: {filename}")
                print(f"Error message: {e}")
                false_sims.append(ind)  # save which simulations were not processed
        mask = np.ones(num_sims, dtype=bool)
        mask[false_sims] = False
        x = results[mask, :]  # only take simulations that were processed
        theta = torch.load('Theta100000.pt')
        theta = torch.as_tensor(np.float32(theta[mask, :]))
    else:
        # use the provided simulations from the manuscript
        theta = torch.tensor(np.loadtxt('Simulations_modelparameters.csv', delimiter=',')).to(torch.float32)
        x = torch.tensor(np.loadtxt('Simulations_MEAfeatures.csv', delimiter=',')).to(torch.float32)
    inference = NPE(prior)
    inference = inference.append_simulations(theta, x)
    density_estimator = inference.train()
    posterior = inference.build_posterior(density_estimator)

    # save the trained posterior
    with open('Posterior_Features', 'wb') as f:
        pickle.dump(posterior, f)
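    # Usage sketch (assumption, not part of the original training run): to do
    # inference later, load the pickled posterior and condition it on an
    # observation x_o holding the same 15 MEA features, e.g. computed with
    # compute_features on a recording; x_o is a placeholder name:
    #   with open('Posterior_Features', 'rb') as f:
    #       posterior = pickle.load(f)
    #   samples = posterior.sample((10000,), x=x_o)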
else:
    """ Train the NDE with an embedding network on the firing rate per electrode, as in the rebuttal """
    if calculate_feats:  # calculate the firing rates of your own spike trains
        time_bin = 100e-3  # bin width in seconds (100 ms)
        # one flattened row per simulation: num_electrodes x time bins
        results = torch.zeros((num_sims, 1, int((sim_time - transient) / time_bin * num_electrodes)))
        false_sims = []
        for filename in os.listdir(results_dir):
            f = os.path.join(results_dir, filename)
            # simulation index parsed from the filename; assumed to match the row
            # order of 'Theta100000.pt'
            ind = int(re.findall(r'\d+', filename)[0])
            try:
                APs = torch.load(f)
                spikerate, _ = compute_spikerate(APs, sim_time, transient, fs, time_bin)
                spikeratet = torch.as_tensor(spikerate).reshape(1, -1)  # flatten electrodes x bins
                results[ind, :] = spikeratet
            except Exception as e:
                # print the file that caused the error
                print(f"Error loading file: {filename}")
                print(f"Error message: {e}")
                false_sims.append(ind)  # save which simulations were not processed
        mask = np.ones(num_sims, dtype=bool)
        mask[false_sims] = False
        results = results[mask, :]
        theta = torch.load('Theta100000.pt')
        theta = torch.as_tensor(np.float32(theta[mask, :]))
    else:  # use the 100,000 simulations from the rebuttal
        results = torch.load('ResultsSpikerate100000.pt')
        theta = torch.tensor(np.transpose(torch.load('Theta100000mask.pt')))
    # the embedding network used for the rebuttal
    class CNNEmbedding(nn.Module):
        def __init__(self, input_shape=(12, 1800), output_dim=20):
            super(CNNEmbedding, self).__init__()
            self.input_shape = input_shape
            # local feature extraction (spatial and short-term temporal)
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 5), padding=(1, 2))
            self.relu1 = nn.ReLU()
            self.pool1 = nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4))  # reduce temporal resolution
            # temporal feature extraction per electrode (depthwise convolution)
            self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(1, 5), padding=(0, 2), groups=16)
            self.relu2 = nn.ReLU()
            # global electrode interaction (1x1 convolution)
            self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(1, 1))
            self.relu3 = nn.ReLU()
            # dilated convolutions for long-term dependencies
            self.conv4 = nn.Conv2d(in_channels=32, out_channels=48, kernel_size=(1, 5), padding=(0, 4), dilation=(1, 2))
            self.relu4 = nn.ReLU()
            self.pool2 = nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2))  # further reduce temporal resolution
            self.conv5 = nn.Conv2d(in_channels=48, out_channels=64, kernel_size=(1, 5), padding=(0, 8), dilation=(1, 4))
            self.relu5 = nn.ReLU()
            # fully connected layers to reduce to the embedding dimension
            self.flatten = nn.Flatten()
            self.fc1 = nn.Linear(self.compute_flattened_size(), 256)
            self.relu_fc1 = nn.ReLU()
            self.fc2 = nn.Linear(256, output_dim)

        def compute_flattened_size(self):
            # compute the flattened size after the convolutions and pooling for the
            # given input_shape, mirroring the layer order used in forward
            dummy_input = torch.zeros(1, 1, *self.input_shape)
            x = self.conv1(dummy_input)
            x = self.pool1(x)
            x = self.conv2(x)
            x = self.conv3(x)
            x = self.conv4(x)
            x = self.pool2(x)
            x = self.conv5(x)
            return x.numel() // x.shape[0]

        def forward(self, x):
            x = x.view(x.shape[0], 1, *self.input_shape)  # ensure correct input shape
            x = self.relu1(self.conv1(x))
            x = self.pool1(x)
            x = self.relu2(self.conv2(x))
            x = self.relu3(self.conv3(x))
            x = self.relu4(self.conv4(x))
            x = self.pool2(x)
            x = self.relu5(self.conv5(x))
            x = self.flatten(x)
            x = self.relu_fc1(self.fc1(x))
            x = self.fc2(x)
            return x
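    # Optional shape check (a minimal sketch, not part of the original script):
    # a dummy batch of two flattened traces should embed to (2, output_dim):
    #   emb = CNNEmbedding(input_shape=(12, 1800), output_dim=20)
    #   print(emb(torch.zeros(2, 1, 12 * 1800)).shape)  # torch.Size([2, 20])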
    if calculate_feats:
        # input_shape is (electrodes, time bins); the temporal dimension is
        # (sim_time - transient) / time_bin, i.e. 1800 bins for the settings above
        embedding_net = CNNEmbedding(input_shape=(num_electrodes, int((sim_time - transient) / time_bin)),
                                     output_dim=20)
    else:
        embedding_net = CNNEmbedding(input_shape=(12, 1800), output_dim=20)
    neural_posterior = posterior_nn(model='maf', embedding_net=embedding_net)
    inferer = NPE(prior=prior, density_estimator=neural_posterior)
    density_estimator = inferer.append_simulations(theta, results).train()
    posteriorEMB = inferer.build_posterior(density_estimator)

    # save the trained posterior
    with open('Posterior_Embeddingnet', 'wb') as f:
        pickle.dump(posteriorEMB, f)
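    # Usage sketch (assumption, not part of the original training run): condition
    # the embedding-net posterior on an observed firing-rate trace x_o shaped like
    # one row of results, i.e. (1, num_electrodes * time bins); x_o is a placeholder:
    #   with open('Posterior_Embeddingnet', 'rb') as f:
    #       posteriorEMB = pickle.load(f)
    #   samples = posteriorEMB.sample((10000,), x=x_o)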