From fc0e1cbf2c69d0406d74282a41ba56687f8afcd0 Mon Sep 17 00:00:00 2001
From: "Doorn, Nina (UT-TNW)" <n.doorn-1@utwente.nl>
Date: Mon, 17 Mar 2025 15:36:53 +0100
Subject: [PATCH] Split Simulator.py into the simulator and this
 FeatureExtraction.py, which will also be called from the TrainNDE script.

---
 FeatureExtraction.py | 207 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 207 insertions(+)
 create mode 100644 FeatureExtraction.py

diff --git a/FeatureExtraction.py b/FeatureExtraction.py
new file mode 100644
index 0000000..d1be436
--- /dev/null
+++ b/FeatureExtraction.py
@@ -0,0 +1,207 @@
+# Functions to compute the spikerate (input for embedding network) and the 15 MEA features
+# (input to regular NDE used in the manuscript).
+
+from brian2 import *
+from numpy import *
+from scipy.signal import find_peaks
+from scipy.fft import fft, fftfreq
+from scipy.stats import norm
+from itertools import combinations
+import numpy as np
+
+def compute_spikerate(APs, rec_time, transient, fs, time_bin):
+    """Compute spike rates for input to embedding network or to MEA feature calculation"""
+    APs_wot = APs[APs[:, 1] > transient * fs, :]
+    APs_wot[:, 1] = APs_wot[:, 1] - transient * fs
+    APs=APs_wot
+    electrodes = APs[:, 0]
+    timestamps = APs[:, 1]
+    electrode_ids = np.unique(electrodes)
+    bin_size = int(time_bin * fs)
+    min_time = transient * fs
+    max_time = rec_time * fs
+    num_bins = int((rec_time - transient) / time_bin) + 1
+    spike_counts = np.zeros((len(electrode_ids), num_bins), dtype=np.int16)
+    spike_rates = np.zeros((len(electrode_ids), num_bins), dtype=np.int16)
+
+    for electrode in electrode_ids:
+        electrode_spikes = timestamps[electrodes == electrode]
+        bin_edges = np.arange(0, max_time-min_time + bin_size, bin_size)
+        spike_bins = np.digitize(electrode_spikes, bin_edges) - 1
+        spike_count = np.bincount(spike_bins, minlength=num_bins)
+        spike_counts[electrode, :] = spike_count
+        spike_rates[electrodes, : ] = spike_count / time_bin
+
+    return spike_counts[:,:-1], spike_rates[:,:-1]
+
def compute_features(APs, simtime, transient, fs):
    """Compute the 15 MEA features used as input to the regular NDE.

    Parameters (simtime/transient appear to be brian2 quantities, given the
    divisions by ``second`` — TODO confirm against the caller):
        APs: (n_spikes, 2) array; column 0 electrode id, column 1 spike
            timestamp in samples at rate ``fs``.
        simtime: total simulation time (brian2 quantity).
        transient: initial transient to discard (brian2 quantity).
        fs: sampling rate in Hz.

    Returns a list of 15 scalars, in order:
        [MFR, MNBR, MNBD, PSIB, NFBs, CVIBI, mean_corr, sd_corr,
         mean_ISIcorr, sd_ISIcorr, ISI_distance, mean_ISI, sdmean_ISI,
         sdtime_ISI, MAC]
    """
    num_electrodes = 12  # number of electrodes
    time_bin = 25 * ms  # timebin to compute network firing rate

    # remove transients (boolean-mask indexing copies, so APs itself is
    # not modified; timestamps are shifted to start at zero)
    APs_wot = APs[APs[:, 1] > transient/second * fs, :]
    APs_wot[:, 1] = APs_wot[:, 1] - transient/second * fs

    # Transient already removed above, hence transient=0 is passed here.
    spike_counts, spike_rate = compute_spikerate(APs_wot, (simtime-transient)/second, 0, fs, time_bin/second)

    # Network-wide (summed over electrodes) counts and rate per bin.
    spike_counts_tot = sum(spike_counts, axis=0)
    spike_rate_tot = sum(spike_rate, axis=0)

    # Smoothen the spikerate by convolution with gaussian kernel
    width = 11
    sigma = 3.0
    x = np.arange(0, width, 1, float)
    x = x - width // 2  # center the kernel support on zero
    kernel = norm.pdf(x, scale=sigma)
    kernel /= np.sum(kernel)  # normalize so smoothing preserves total rate
    spikeratesmooth = np.convolve(spike_rate_tot, kernel, mode='same')

    # Detect fragmented bursts on smoothed spikerate
    MB_th = (1 / 16) * max(spike_rate_tot)
    peaks, ph = find_peaks(spikeratesmooth, height=MB_th, prominence=(1 / 20) * max(spike_rate_tot))

    # Set parameters for burst detection
    act_elec = sum(mean(spike_rate, axis=1) > 0.02)  # calculate the number of active electrodes

    start_th = 0.25 * max(spike_rate_tot)  # spikerate threshold to start a burst
    t_th = int((50 * ms) / time_bin)  # how long it has to surpass threshold for to start burst
    e_th = 0.5 * act_elec  # how many electrodes need to be active in the burst
    stop_th = (1 / 50) * max(spike_rate_tot)  # threshold to end a burst

    # Initialize burst detection
    i = 0
    NB_count = 0
    max_NBs = 1000  # maximum amount of to be detected bursts
    # Per detected burst: [start bin, end bin, spike count, fragmented-burst count]
    NBs = zeros((max_NBs, 4))

    # Detect NBs (network bursts): a burst starts when the network rate
    # stays above start_th for t_th consecutive bins AND more than e_th
    # electrodes each fired > t_th spikes in that window; it ends once the
    # rate stays below stop_th for a full 2*t_th-bin lookahead window.
    while (i + t_th) < len(spike_rate_tot):
        if (all(spike_rate_tot[i:i + t_th] > start_th)) \
                & (sum(sum(spike_counts[:, i:i + t_th], axis=1) > t_th) > e_th):
            NBs[NB_count, 2] = NBs[NB_count, 2] + sum(spike_counts_tot[i:i + t_th])
            NBs[NB_count, 0] = i
            i = i + t_th
            while any(spike_rate_tot[i:i + 2 * t_th] > stop_th):
                NBs[NB_count, 2] = NBs[NB_count, 2] + spike_counts_tot[i]
                i = i + 1
            # Fragmented ("mini") bursts whose smoothed-rate peak falls
            # inside this network burst.
            NBs[NB_count, 3] = sum((peaks > NBs[NB_count, 0]) & (peaks < i))
            NBs[NB_count, 1] = i
            NB_count = NB_count + 1
        else:
            i = i + 1

    # Keep only the rows actually filled in.
    NBs = NBs[0:NB_count, :]

    MNBR = NB_count * 60 * second / (simtime - transient)  # mean network burst rate (bursts/min)
    NBdurations = (array(NBs[:, 1]) - array(NBs[:, 0])) * time_bin / second
    MNBD = mean(NBdurations)  # mean network burst duration (s)
    PSIB = sum(NBs[:, 2] / len(APs_wot)) * 100  # percentage of spikes inside bursts
    MFR = len(APs_wot) / ((simtime - transient) / second) / num_electrodes  # mean firing rate per electrode (Hz)
    IBI = (array(NBs[1:, 0]) - array(NBs[0:-1, 1])) * time_bin / second  # inter-burst intervals (s)
    CVIBI = np.std(IBI) / np.mean(IBI)  # coefficient of variation of the IBIs
    if NB_count == 0:
        MNBD = 0.0
        MNMBs = 0.0  # NOTE(review): assigned but never used or returned — dead store?
        NFBs = 0
    else:
        NFBs = sum(NBs[:, 3]) / NB_count  # mean fragmented bursts per network burst

    if NB_count < 2:
        CVIBI = 0.0  # CV undefined with fewer than two bursts

    # Calculate MAC metric as defined by Maheswaranathan
    # NOTE(review): xf is computed but never used.
    yf = fft(spike_rate_tot)
    xf = fftfreq(len(spike_rate_tot), time_bin / second)[:len(spike_rate_tot) // 2]
    # NOTE(review): the max runs over yf[1:], which also includes the
    # negative-frequency half of the spectrum — confirm this is intended.
    MAC = max(np.abs(yf[1:len(spike_rate_tot)])) / np.abs(yf[0])

    # Calculate cross-correlation between binarized spike trains
    # Binarize the spike trains
    all_combinations = list(combinations(list(arange(num_electrodes)), 2))
    trans_timebin = 0.2 * second  # timebin to transform spiketrains to binary
    bin_timeseries = list(range(0, int((simtime - transient) / ms), int(trans_timebin / ms)))
    binary_signal = zeros((num_electrodes, len(bin_timeseries)))
    for i in range(num_electrodes):
        signal = spike_counts[i, :]
        # Re-bin the fine counts into coarser trans_timebin windows, then
        # binarize: 1 if any spike in the window, else 0.
        grosssignal = [sum(signal[x:x + int((trans_timebin / time_bin))]) for x in
                       range(0, len(signal), int(trans_timebin / time_bin))]
        binary_signal[i, :] = [1 if x > 0 else 0 for x in grosssignal]

    # Calculate coefficients between every pair of electrodes
    # (Pearson correlation written out explicitly; identical pairs are
    # skipped to avoid a zero denominator.)
    coefficients = []
    N = len(binary_signal[0, :])
    for i, j in all_combinations:
        signal1 = binary_signal[i, :]
        signal2 = binary_signal[j, :]
        if (i != j) & (not list(signal1) == list(signal2)):
            coefficients.append((N * sum(signal1 * signal2) - sum(signal1) * (sum(signal2)))
                                * ((N * sum(signal1 ** 2) - sum(signal1) ** 2) ** (-0.5))
                                * ((N * sum(signal2 ** 2) - sum(signal2) ** 2) ** (-0.5)))

    # NOTE(review): when coefficients is empty these compute NaN (with a
    # runtime warning) before being overwritten by the fallback below.
    mean_corr = mean(coefficients)
    sd_corr = std(coefficients)

    if not coefficients:
        mean_corr = 1
        sd_corr = 0

    # Compute continuous ISI arrays: for each electrode, a sample-rate time
    # series holding the currently elapsed inter-spike interval (NaN before
    # the first and after the last spike).
    time_vector = np.arange(0, (simtime - transient) / second, 1 / fs)
    isi_arrays = np.zeros((num_electrodes, len(time_vector)))

    for electrode in range(num_electrodes):
        # Extract spike times for the current electrode
        electrode_spike_times = APs_wot[APs_wot[:, 0] == electrode, 1]

        for i in range(len(electrode_spike_times) - 1):
            spike1 = electrode_spike_times[i]
            spike2 = electrode_spike_times[i + 1]
            tisi = (spike2 - spike1) / fs  # ISI in seconds (timestamps are in samples)

            # Fill ISI values in the appropriate range
            # NOTE(review): spike1/spike2 are used directly as slice
            # indices, so they must be integer sample numbers — confirm the
            # dtype of APs; float timestamps would raise here.
            if i == 0:
                isi_arrays[electrode, 0:spike1] = NaN
            isi_arrays[electrode, spike1:spike2] = tisi
            if (i + 1) == (len(electrode_spike_times) - 1):
                isi_arrays[electrode, spike2:] = NaN

    # Compute ISI measures
    meanisi_array = np.nanmean(isi_arrays, axis=0)  # mean ISI across electrodes at each time point
    mean_ISI = np.nanmean(meanisi_array)
    sdmean_ISI = np.nanstd(meanisi_array)  # SD over time of the electrode-mean ISI
    sdtime_ISI = np.nanmean(np.nanstd(isi_arrays, axis=0))  # time-mean of the across-electrode SD

    # Calculate the ISI-distance and ISI correlations
    ISI_distances = np.zeros(len(all_combinations))
    isicoefficients = np.zeros(len(all_combinations))
    N = len(isi_arrays[0, :])
    j = 0

    # Iterate through the electrode combinations
    for electrode1_key, electrode2_key in all_combinations:
        # Get the ISI arrays for the selected electrodes
        isit1_wn = isi_arrays[electrode1_key, :]
        isit2_wn = isi_arrays[electrode2_key, :]
        # Keep only time points where both electrodes have a defined ISI
        # (note isit1 is masked again below by the NaNs of the already-
        # shortened isit2_wn — order-sensitive, left exactly as written).
        isit1 = isit1_wn[~isnan(isit1_wn)]
        isit2_wn = isit2_wn[~isnan(isit1_wn)]
        isit2 = isit2_wn[~isnan(isit2_wn)]
        isit1 = isit1[~isnan(isit2_wn)]

        # Ratio-based distance per time point: |r - 1| when r <= 1, else
        # 1 - 1/r, i.e. a symmetric measure of how far the two ISIs differ.
        isi_diff = isit1 / isit2
        ISI_distances[j] = np.mean(np.where(isi_diff <= 1, abs(isi_diff - 1), -1 * (1 / isi_diff - 1)))

        # NOTE(review): `i` here is stale — it is left over from the ISI
        # filling loop above, not a pair index. This guard looks copy-pasted
        # from the binary-correlation loop; confirm the intended condition.
        if (i != j) & (not list(isit1) == list(isit2)):
            isicoefficients[j] = ((N * sum(isit1 * isit2) - sum(isit1) * (sum(isit2)))
                                  * ((N * sum(isit1 ** 2) - sum(isit1) ** 2) ** (-0.5))
                                  * ((N * sum(isit2 ** 2) - sum(isit2) ** 2) ** (-0.5)))

        j += 1

    mean_ISIcorr = mean(isicoefficients)
    sd_ISIcorr = std(isicoefficients)
    ISI_distance = np.mean(ISI_distances)

    return [MFR, MNBR, MNBD, PSIB, NFBs, CVIBI, mean_corr, sd_corr, mean_ISIcorr, sd_ISIcorr, ISI_distance, mean_ISI,
            sdmean_ISI, sdtime_ISI, MAC]
+
+
-- 
GitLab