Skip to content
Snippets Groups Projects
Commit af50ab80 authored by Doorn, Nina (UT-TNW)'s avatar Doorn, Nina (UT-TNW)
Browse files

Script to test for model misspecification

parent 9f47aaa5
No related branches found
No related tags found
No related merge requests found
# Code to test for model misspecification as proposed by Schmitt et al. 2024,
# https://doi.org/10.48550/arXiv.2406.03154
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import tight_layout
from sklearn.metrics.pairwise import euclidean_distances
from scipy.stats import gaussian_kde
def whiten_summaries(dat_sim, dat_exp):
    """Whiten both datasets using the mean and covariance of the simulations.

    Each dataset is centred on the per-feature mean of ``dat_sim`` and then
    multiplied by ``inv(L).T``, where ``L`` is the Cholesky factor of the
    sample covariance of ``dat_sim``. After the transformation the simulated
    summaries have zero mean and identity sample covariance; the experimental
    summaries are expressed in the same coordinates.

    Parameters
    ----------
    dat_sim : ndarray of shape (n_simulations, n_features)
        Simulated summary statistics, one row per simulation.
    dat_exp : ndarray of shape (n_experiments, n_features)
        Experimental summary statistics, one row per experiment.

    Returns
    -------
    sim_transformed, exp_transformed : ndarray
        The whitened versions of ``dat_sim`` and ``dat_exp``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the sample covariance of ``dat_sim`` is not positive definite.
    """
    mu = np.mean(dat_sim, axis=0)
    # np.cov squeezes its result, so a single feature column yields a 0-d
    # array that cholesky() rejects; promote it back to a 1x1 matrix.
    sigma = np.atleast_2d(np.cov(dat_sim, rowvar=False))
    # With sigma = L @ L.T, the matrix W = inv(L).T gives W.T @ sigma @ W = I.
    whitening = np.linalg.inv(np.linalg.cholesky(sigma)).T
    sim_transformed = (dat_sim - mu) @ whitening
    exp_transformed = (dat_exp - mu) @ whitening
    return sim_transformed, exp_transformed
def compute_mmd(dat_sim, dat_exp, max_samples=10000, h=0.1, c=1, beta=1):
    """Compute the MMD between two samples with an inverse multiquadratic (IMQ) kernel.

    The kernel is ``k(x, y) = (c + ||x - y||^2 / h^2) ** -beta``. The returned
    value is ``mean k(X, X) + mean k(Y, Y) - 2 * mean k(X, Y)``, where each
    mean runs over every pair of rows (self-pairs included).

    Any input with more than ``max_samples`` rows is first reduced to a random
    subset of ``max_samples`` rows (drawn without replacement) to limit the
    cost of the pairwise distance matrices.
    """

    def subsample(points):
        # Leave small inputs untouched; otherwise keep a random row subset.
        n_rows = points.shape[0]
        if n_rows > max_samples:
            keep = np.random.choice(n_rows, max_samples, replace=False)
            return points[keep]
        return points

    def mean_kernel(a, b):
        # Average IMQ kernel value over all row pairs of a and b.
        sq_dist = euclidean_distances(a, b) ** 2
        return ((c + sq_dist / h ** 2) ** -beta).mean()

    x = subsample(dat_sim)
    y = subsample(dat_exp)

    within_x = mean_kernel(x, x)
    within_y = mean_kernel(y, y)
    across = mean_kernel(x, y)
    return within_x + within_y - 2 * across
def estimate_mmd_distributions(simulated_data, real_data_size, num_repeats=1000, max_samples=200):
    """Estimate the distribution of the MMD under H0 from simulations alone.

    On every repeat the simulated rows are randomly permuted and split into
    two disjoint groups: a reference sample of ``max_samples`` rows and a
    pseudo-observed sample of ``real_data_size`` rows (the size of the
    experimental dataset). The MMD between the two groups is one draw from
    the null distribution.

    Parameters
    ----------
    simulated_data : ndarray of shape (n_simulations, n_features)
        Simulated summary statistics, one row per simulation.
    real_data_size : int
        Number of rows in the pseudo-observed sample.
    num_repeats : int, optional
        Number of random splits, i.e. number of null MMD values returned.
    max_samples : int, optional
        Number of rows in the reference sample.

    Returns
    -------
    ndarray of shape (num_repeats,)
        MMD values obtained under H0.

    Raises
    ------
    ValueError
        If either sample size is smaller than 1, or if ``simulated_data`` has
        fewer than ``max_samples + real_data_size`` rows, so that two
        non-overlapping samples of the requested sizes cannot be drawn.
    """
    n = simulated_data.shape[0]
    if max_samples < 1 or real_data_size < 1:
        raise ValueError(
            "max_samples and real_data_size must both be at least 1 "
            f"(got max_samples={max_samples}, real_data_size={real_data_size})"
        )
    split_end = max_samples + real_data_size
    if n < split_end:
        # Slicing past the end would silently shrink (or empty) the second
        # sample, so the null distribution would no longer correspond to the
        # size of the experimental dataset.
        raise ValueError(
            f"need at least {split_end} simulated rows to draw disjoint samples "
            f"of sizes {max_samples} and {real_data_size}, but got {n}"
        )

    mmd_h0_values = np.empty(num_repeats)
    for i in range(num_repeats):
        # Shuffle indices and split into the large and the small sample without overlap
        indices = np.random.permutation(n)
        reference = simulated_data[indices[:max_samples]]
        pseudo_observed = simulated_data[indices[max_samples:split_end]]
        mmd_h0_values[i] = compute_mmd(reference, pseudo_observed, max_samples)
    return mmd_h0_values
def detect_model_misspecification(dat_sim, dat_exp, alpha=0.05, max_samples=10000):
    """Detect model misspecification by comparing real data with simulated data.

    The summaries are whitened with the simulation statistics, the MMD between
    simulated and experimental summaries is computed, and that value is
    compared with the (1 - alpha) percentile of an MMD distribution estimated
    under H0 from the simulations alone. A plot of the null distribution with
    the critical and observed MMD is shown as a side effect.

    Parameters
    ----------
    dat_sim : ndarray of shape (n_simulations, n_features)
        Simulated summary statistics.
    dat_exp : ndarray of shape (n_experiments, n_features)
        Experimental summary statistics.
    alpha : float, optional
        Significance level of the test.
    max_samples : int, optional
        Maximum number of rows per dataset used for the real-vs-simulated MMD.

    Returns
    -------
    dict
        ``"MMD"`` (observed MMD), ``"Critical Value"`` (null percentile),
        ``"Misspecified"`` (observed MMD exceeds the critical value) and
        ``"p-value"`` (fraction of null MMD values above the observed MMD).
    """
    # Size of the reference sample used for every draw of the null distribution.
    null_reference_size = 500

    # Standardize and whiten both sets of observations
    dat_sim, dat_exp = whiten_summaries(dat_sim, dat_exp)

    # Compute MMD between real and simulated data
    mmd_real_vs_sim = compute_mmd(dat_sim, dat_exp, max_samples)

    # Estimate the MMD distribution under H0 (simulations vs. simulations)
    # NOTE(review): the null uses 500 reference rows while the observed MMD may
    # use up to max_samples rows; confirm this size mismatch is intended.
    mmd_h0_distribution = estimate_mmd_distributions(
        dat_sim, dat_exp.shape[0], max_samples=null_reference_size
    )
    critical_value = np.percentile(mmd_h0_distribution, 100 * (1 - alpha))
    misspecified = mmd_real_vs_sim > critical_value

    _plot_mmd_null(mmd_h0_distribution, critical_value, mmd_real_vs_sim)

    return {
        "MMD": mmd_real_vs_sim,
        "Critical Value": critical_value,
        "Misspecified": misspecified,
        "p-value": (mmd_h0_distribution > mmd_real_vs_sim).mean()
    }


def _plot_mmd_null(mmd_h0_distribution, critical_value, mmd_real_vs_sim):
    """Show a KDE of the null MMD values with the critical and observed MMD marked."""
    kde_h0 = gaussian_kde(mmd_h0_distribution)
    x_vals = np.linspace(mmd_h0_distribution.min(), mmd_h0_distribution.max(), 500)
    density = kde_h0(x_vals)  # evaluated once, reused for the fill and the outline

    plt.figure(figsize=(6, 3))
    plt.fill_between(x_vals, density, color='#431853', alpha=0.3)
    plt.plot(x_vals, density, color='#431853', linewidth=2, label='Training Model (H0)')
    plt.axvline(critical_value, color='#5FA5CD', linewidth=2, label='Critical MMD')
    plt.axvline(mmd_real_vs_sim, color='#EF6D6C', linewidth=2, label='MMD Real vs Sim')
    plt.yticks([])
    plt.legend()
    plt.xlabel("MMD")
    tight_layout()
    plt.show()
# Load the summary features of the simulations and of all available experiments.
dat_sim = np.loadtxt("../Simulations_MEAfeatures.csv", delimiter=",")
# Keep only the simulation rows that contain no NaN entries.
dat_sim = dat_sim[np.logical_not(np.isnan(dat_sim).any(axis=1))]
dat_exp = np.loadtxt("AllExperiments_MEAfeatures.csv", delimiter=",")

# Run the misspecification test and report its outcome.
result = detect_model_misspecification(dat_sim, dat_exp)
print("Result:", result)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment