# Functions to compute the spikerate (input for embedding network) and the 15 MEA features
# (input to regular NDE used in the manuscript).

from brian2 import *
from numpy import *
from scipy.signal import find_peaks
from scipy.fft import fft, fftfreq
from scipy.stats import norm
from itertools import combinations
import numpy as np


def compute_spikerate(APs, rec_time, transient, fs, time_bin):
    """Compute per-electrode binned spike counts and spike rates.

    Parameters
    ----------
    APs : ndarray, shape (n_spikes, 2)
        Column 0: electrode id (assumed 0-based integers usable as row
        indices -- TODO confirm against the simulator's output);
        column 1: spike time in samples.
    rec_time : float
        Total recording duration in seconds.
    transient : float
        Initial transient (seconds) discarded before binning.
    fs : float
        Sampling frequency in Hz.
    time_bin : float
        Bin width in seconds.

    Returns
    -------
    (spike_counts, spike_rates) : ndarray pair, shape (n_electrodes, n_bins)
        Spike counts per bin and the corresponding rates
        (counts / time_bin, in Hz).  The last (possibly partial) bin
        is dropped from both arrays.
    """
    # Discard spikes during the transient and re-reference times to its end.
    # (Boolean indexing copies, so the caller's array is not mutated.)
    APs_wot = APs[APs[:, 1] > transient * fs, :]
    APs_wot[:, 1] = APs_wot[:, 1] - transient * fs
    APs = APs_wot
    electrodes = APs[:, 0]
    timestamps = APs[:, 1]
    electrode_ids = np.unique(electrodes)
    bin_size = int(time_bin * fs)
    min_time = transient * fs
    max_time = rec_time * fs
    num_bins = int((rec_time - transient) / time_bin) + 1
    spike_counts = np.zeros((len(electrode_ids), num_bins), dtype=np.int16)
    # Rates are fractional -> must be float (int16 silently truncated them).
    spike_rates = np.zeros((len(electrode_ids), num_bins))

    # Bin edges are loop-invariant; hoist them out of the electrode loop.
    bin_edges = np.arange(0, max_time - min_time + bin_size, bin_size)
    for electrode in electrode_ids:
        electrode_spikes = timestamps[electrodes == electrode]
        spike_bins = np.digitize(electrode_spikes, bin_edges) - 1
        # minlength guarantees >= num_bins; truncate in case a spike falls
        # past the last edge (bincount may then return a longer array).
        spike_count = np.bincount(spike_bins, minlength=num_bins)[:num_bins]
        row = int(electrode)
        spike_counts[row, :] = spike_count
        # BUG FIX: the original indexed with the full `electrodes` array
        # here (`spike_rates[electrodes, :]`), overwriting every row with
        # the current electrode's rate on each iteration.
        spike_rates[row, :] = spike_count / time_bin

    return spike_counts[:, :-1], spike_rates[:, :-1]


def compute_features(APs, simtime, transient, fs):
    """Compute the 15 MEA features used as input to the regular NDE.

    Parameters
    ----------
    APs : ndarray, shape (n_spikes, 2)
        Column 0: electrode id; column 1: spike time in samples.
    simtime, transient : brian2 Quantity (time)
        Total simulated time and initial transient to discard.
    fs : float
        Sampling frequency in Hz.

    Returns
    -------
    list of 15 floats:
        [MFR, MNBR, MNBD, PSIB, NFBs, CVIBI, mean_corr, sd_corr,
         mean_ISIcorr, sd_ISIcorr, ISI_distance, mean_ISI,
         sdmean_ISI, sdtime_ISI, MAC]
    """
    num_electrodes = 12   # number of electrodes
    time_bin = 25 * ms    # timebin to compute network firing rate

    # Remove transient and re-reference spike times to its end.
    APs_wot = APs[APs[:, 1] > transient / second * fs, :]
    APs_wot[:, 1] = APs_wot[:, 1] - transient / second * fs

    # Transient already removed above, hence transient=0 in this call.
    spike_counts, spike_rate = compute_spikerate(
        APs_wot, (simtime - transient) / second, 0, fs, time_bin / second)

    spike_counts_tot = sum(spike_counts, axis=0)
    spike_rate_tot = sum(spike_rate, axis=0)

    # Smooth the network spikerate by convolution with a gaussian kernel.
    width = 11
    sigma = 3.0
    x = np.arange(0, width, 1, float)
    x = x - width // 2
    kernel = norm.pdf(x, scale=sigma)
    kernel /= np.sum(kernel)
    spikeratesmooth = np.convolve(spike_rate_tot, kernel, mode='same')

    # Detect fragmented bursts (peaks) on the smoothed spikerate.
    MB_th = (1 / 16) * max(spike_rate_tot)
    peaks, ph = find_peaks(spikeratesmooth, height=MB_th,
                           prominence=(1 / 20) * max(spike_rate_tot))

    # Parameters for network-burst (NB) detection.
    act_elec = sum(mean(spike_rate, axis=1) > 0.02)  # number of active electrodes
    start_th = 0.25 * max(spike_rate_tot)  # spikerate threshold to start a burst
    t_th = int((50 * ms) / time_bin)       # bins the rate must exceed start_th to start a burst
    e_th = 0.5 * act_elec                  # electrodes that must participate in the burst
    stop_th = (1 / 50) * max(spike_rate_tot)  # threshold to end a burst

    # Detect NBs.  NBs columns: [start bin, end bin, spikes in burst,
    # number of fragmented peaks inside the burst].
    i = 0
    NB_count = 0
    max_NBs = 1000  # maximum number of bursts to detect
    NBs = zeros((max_NBs, 4))
    while (i + t_th) < len(spike_rate_tot):
        if (all(spike_rate_tot[i:i + t_th] > start_th)) \
                & (sum(sum(spike_counts[:, i:i + t_th], axis=1) > t_th) > e_th):
            NBs[NB_count, 2] = NBs[NB_count, 2] + sum(spike_counts_tot[i:i + t_th])
            NBs[NB_count, 0] = i
            i = i + t_th
            # Extend the burst while the rate stays above stop_th anywhere
            # in the next 2*t_th bins.
            while any(spike_rate_tot[i:i + 2 * t_th] > stop_th):
                NBs[NB_count, 2] = NBs[NB_count, 2] + spike_counts_tot[i]
                i = i + 1
            NBs[NB_count, 3] = sum((peaks > NBs[NB_count, 0]) & (peaks < i))
            NBs[NB_count, 1] = i
            NB_count = NB_count + 1
        else:
            i = i + 1

    NBs = NBs[0:NB_count, :]

    # Burst statistics.  Degenerate cases (no bursts / a single burst) are
    # guarded up front so no mean/std of an empty array is ever taken
    # (the original computed them and then overwrote the resulting nan).
    MNBR = NB_count * 60 * second / (simtime - transient)
    MFR = len(APs_wot) / ((simtime - transient) / second) / num_electrodes
    PSIB = sum(NBs[:, 2] / len(APs_wot)) * 100
    if NB_count == 0:
        MNBD = 0.0
        NFBs = 0
    else:
        NBdurations = (array(NBs[:, 1]) - array(NBs[:, 0])) * time_bin / second
        MNBD = mean(NBdurations)
        NFBs = sum(NBs[:, 3]) / NB_count
    if NB_count < 2:
        CVIBI = 0.0
    else:
        IBI = (array(NBs[1:, 0]) - array(NBs[0:-1, 1])) * time_bin / second
        CVIBI = np.std(IBI) / np.mean(IBI)

    # MAC metric as defined by Maheswaranathan: largest non-DC Fourier
    # amplitude of the network rate relative to the DC component.
    yf = fft(spike_rate_tot)
    MAC = max(np.abs(yf[1:len(spike_rate_tot)])) / np.abs(yf[0])

    # Cross-correlation between binarized spike trains.
    all_combinations = list(combinations(list(arange(num_electrodes)), 2))
    trans_timebin = 0.2 * second  # timebin to transform spiketrains to binary
    bin_timeseries = list(range(0, int((simtime - transient) / ms), int(trans_timebin / ms)))
    binary_signal = zeros((num_electrodes, len(bin_timeseries)))
    for i in range(num_electrodes):
        signal = spike_counts[i, :]
        grosssignal = [sum(signal[x:x + int((trans_timebin / time_bin))]) for x in
                       range(0, len(signal), int(trans_timebin / time_bin))]
        binary_signal[i, :] = [1 if x > 0 else 0 for x in grosssignal]

    # Pearson coefficient between every pair of (non-identical) binary signals.
    coefficients = []
    N = len(binary_signal[0, :])
    for i, j in all_combinations:
        signal1 = binary_signal[i, :]
        signal2 = binary_signal[j, :]
        # Identical signals have zero variance -> correlation undefined; skip.
        if (i != j) & (not list(signal1) == list(signal2)):
            coefficients.append((N * sum(signal1 * signal2) - sum(signal1) * (sum(signal2)))
                                * ((N * sum(signal1 ** 2) - sum(signal1) ** 2) ** (-0.5))
                                * ((N * sum(signal2 ** 2) - sum(signal2) ** 2) ** (-0.5)))
    if coefficients:
        mean_corr = mean(coefficients)
        sd_corr = std(coefficients)
    else:
        mean_corr = 1
        sd_corr = 0

    # Continuous ISI arrays: for every sample, the inter-spike interval the
    # electrode is currently in (NaN before its first / after its last spike).
    time_vector = np.arange(0, (simtime - transient) / second, 1 / fs)
    isi_arrays = np.zeros((num_electrodes, len(time_vector)))
    for electrode in range(num_electrodes):
        electrode_spike_times = APs_wot[APs_wot[:, 0] == electrode, 1]
        for i in range(len(electrode_spike_times) - 1):
            spike1 = electrode_spike_times[i]
            spike2 = electrode_spike_times[i + 1]
            tisi = (spike2 - spike1) / fs
            # BUG FIX: spike times are floats; float indices are an error in
            # modern NumPy, so cast to int for indexing only.
            s1 = int(spike1)
            s2 = int(spike2)
            if i == 0:
                isi_arrays[electrode, 0:s1] = NaN
            isi_arrays[electrode, s1:s2] = tisi
            if (i + 1) == (len(electrode_spike_times) - 1):
                isi_arrays[electrode, s2:] = NaN

    # ISI measures (nan-aware: NaN marks "no ISI defined here").
    meanisi_array = np.nanmean(isi_arrays, axis=0)
    mean_ISI = np.nanmean(meanisi_array)
    sdmean_ISI = np.nanstd(meanisi_array)
    sdtime_ISI = np.nanmean(np.nanstd(isi_arrays, axis=0))

    # ISI-distance and ISI correlations per electrode pair.
    ISI_distances = np.zeros(len(all_combinations))
    isicoefficients = np.zeros(len(all_combinations))
    N = len(isi_arrays[0, :])
    for j, (electrode1_key, electrode2_key) in enumerate(all_combinations):
        isit1_wn = isi_arrays[electrode1_key, :]
        isit2_wn = isi_arrays[electrode2_key, :]
        # Keep only samples where both ISI traces are defined.
        isit1 = isit1_wn[~isnan(isit1_wn)]
        isit2_wn = isit2_wn[~isnan(isit1_wn)]
        isit2 = isit2_wn[~isnan(isit2_wn)]
        isit1 = isit1[~isnan(isit2_wn)]

        isi_diff = isit1 / isit2
        ISI_distances[j] = np.mean(np.where(isi_diff <= 1, abs(isi_diff - 1), -1 * (1 / isi_diff - 1)))

        # BUG FIX: the original guarded with `i != j`, where `i` was a stale
        # loop variable from an earlier loop (so pair index 11 was skipped
        # arbitrarily).  The intended guard is only that the two traces are
        # not identical (identical traces give zero variance -> divide by 0).
        if not list(isit1) == list(isit2):
            isicoefficients[j] = ((N * sum(isit1 * isit2) - sum(isit1) * (sum(isit2)))
                                  * ((N * sum(isit1 ** 2) - sum(isit1) ** 2) ** (-0.5))
                                  * ((N * sum(isit2 ** 2) - sum(isit2) ** 2) ** (-0.5)))

    mean_ISIcorr = mean(isicoefficients)
    sd_ISIcorr = std(isicoefficients)
    ISI_distance = np.mean(ISI_distances)

    return [MFR, MNBR, MNBD, PSIB, NFBs, CVIBI, mean_corr, sd_corr, mean_ISIcorr, sd_ISIcorr, ISI_distance, mean_ISI,
            sdmean_ISI, sdtime_ISI, MAC]