Files
Depression_TMS/algorithm_V1/runDecoder.py
2026-06-01 13:18:36 +08:00

987 lines
34 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
runDecoder.py - BDF EEG Depression Assessment
功能:
1. 读取 raw_data 文件夹的第一个 .bdf 格式文件
2. 预处理坏通道剔除、50Hz陷波、0.8-40Hz带通、幅值过滤、ICA去伪迹
3. 调用 infer_pth.py 中的 predict_hc_mdd 进行 HC/MDD 分类预测
4. 保存图表EEG、PSD、Topomap
5. 生成 ResultData.txt
"""
import matplotlib
matplotlib.use('Agg')
import numpy as np
import os
import shutil
import scipy.signal as signal
import matplotlib.pyplot as plt
import mne
from mne.preprocessing import ICA
# ==========================
# Config - preprocessing parameters
# ==========================
# Filter parameters
BANDPASS_LOW = 0.8
BANDPASS_HIGH = 40.0
NOTCH_FREQS = [50, 100] # power-line notch frequencies
# Amplitude filter thresholds (μV)
AMPLITUDE_MIN_UV = -200.0
AMPLITUDE_MAX_UV = 200.0
# ICA parameters
ICA_N_COMPONENTS = 20 # absolute component count rather than a variance fraction
ICA_RANDOM_STATE = 97
ICA_MAX_ITER = 800
# Bad-segment detection threshold (μV)
BAD_SEGMENT_THRESHOLD_UV = 350.0
# Default sampling rate
DEFAULT_FS = 250.0
# Plotting parameters
EEG_PLOT_SECONDS = 10
PSD_FMIN, PSD_FMAX = 0.8, 45.0
FIXED_EEG_IDXS = [23, 47, 39, 6, 2, 21, 35, 57] # 0-based index
FIXED_EEG_LABELS = ["C5", "O1", "TP7", "FPZ", "PO6", "P4", "AF7", "AF3"]
# Frequency band definitions
BANDS_METRICS = {
    "Delta": (1.0, 4.0),
    "Theta": (4.0, 8.0),
    "Alpha": (8.0, 13.0),
    "Beta": (13.0, 30.0),
}
TOTAL_POWER_BAND = (1.0, 50.0)
BANDS_TOPOMAP = {
    "delta": (1.0, 4.0),
    "theta": (4.0, 8.0),
    "alpha": (8.0, 13.0),
    "beta": (13.0, 30.0),
    "broad": (1.0, 30.0),
}
# PSD parameters
PSD_NPERSEG = 1024 # FFT window size; larger gives finer frequency resolution
EPS = 1e-12
# Topomap colour-range parameters
# None means automatic range; set both to numbers for a fixed (min, max) range
TOPOMAP_VMIN = None # e.g. -1.0
TOPOMAP_VMAX = None # e.g. 1.0
# Alternatively use a symmetric range expressed in standard deviations around the mean
TOPOMAP_SYM_SCALE = 1.5 # colour range = mean ± std * SYM_SCALE
# Topomap head-circle size (range 0.08 - 0.15)
# Smaller values draw a smaller circle, larger values a larger one
TOPOMAP_SPHERE_RADIUS = 0.12
# Edge-handling parameter
# Seconds of padding around filtering, used to suppress edge ringing
FILTER_PAD_SEC = 1.0
# ==========================
# 数据文件读取
# ==========================
def load_data_file(file_path: str) -> tuple:
    """Load a recording, picking the reader from the file extension.

    Args:
        file_path: Path to a ``.bdf`` or ``.mat`` file (extension is matched
            case-insensitively).

    Returns:
        Whatever the selected reader returns: ``(raw, sfreq, ch_names)``.

    Raises:
        ValueError: If the extension is neither ``.bdf`` nor ``.mat``.
    """
    _, suffix = os.path.splitext(file_path)
    suffix = suffix.lower()
    if suffix not in (".bdf", ".mat"):
        raise ValueError(f"不支持的文件格式: {suffix}")
    reader = load_bdf_file if suffix == ".bdf" else load_mat_file
    return reader(file_path)
def load_bdf_file(bdf_path: str) -> tuple:
    """Read a ``.bdf`` recording fully into memory.

    Args:
        bdf_path: Path to the BDF file.

    Returns:
        ``(raw, sfreq, ch_names)``: the preloaded MNE Raw object, its sampling
        frequency and its channel names.
    """
    print(f"[INFO] Reading BDF file: {bdf_path}")
    raw = mne.io.read_raw_bdf(bdf_path, preload=True, verbose=False)

    # Electrode positions are best-effort; a failure is only reported.
    try:
        raw.set_montage("standard_1020", on_missing="ignore")
    except Exception as e:
        print(f"[WARN] Failed to set standard_1020 montage: {e}")

    sfreq, ch_names = raw.info['sfreq'], raw.ch_names
    span = raw.times[-1] - raw.times[0]
    print(f"[INFO] Channels: {len(ch_names)}, Duration: {span:.2f}s, SFreq: {sfreq:.2f}Hz")
    return raw, sfreq, ch_names
def load_mat_file(mat_path: str) -> tuple:
    """Read a ``.mat`` recording and wrap it in an MNE ``RawArray``.

    The file must contain a struct ``eeg`` with fields ``data``,
    ``sample_rate`` and ``electrode_name``; ``electrode_xyz`` is optional.

    Args:
        mat_path: Path to the MAT file.

    Returns:
        ``(raw, sfreq, ch_names)``.
    """
    print(f"[INFO] Reading MAT file: {mat_path}")
    import scipy.io

    mat = scipy.io.loadmat(mat_path)
    eeg = mat['eeg'][0, 0]

    # Samples: the longer axis is taken to be time, giving (T, C).
    data = eeg['data']
    if data.shape[0] < data.shape[1]:
        data = data.T
    data = data.astype(np.float64)
    # NOTE(review): MNE treats RawArray data as volts and the pipeline's µV
    # thresholds rely on that - confirm the unit stored in the .mat file.

    sfreq = float(eeg['sample_rate'][0, 0])

    # Channel names: bytes are decoded instead of being rendered as "b'..'".
    ch_names_raw = eeg['electrode_name']
    if ch_names_raw.ndim == 2:
        ch_names = [
            ch[0].decode("utf-8", errors="replace") if isinstance(ch[0], bytes) else str(ch[0])
            for ch in ch_names_raw[0]
        ]
    else:
        ch_names = [f"EEG{i+1}" for i in range(data.shape[1])]

    n_channels = data.shape[1]
    n_samples = data.shape[0]
    duration = n_samples / sfreq
    print(f"[INFO] Channels: {n_channels}, Duration: {duration:.2f}s, SFreq: {sfreq:.2f}Hz")

    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=["eeg"] * n_channels)
    raw = mne.io.RawArray(data.T, info, verbose=False)  # (T, C) -> (C, T)

    # Electrode positions: prefer the ones stored in the file, else standard_1020.
    try:
        electrode_xyz = eeg['electrode_xyz']
        if electrode_xyz.shape[0] == n_channels:
            # Divides by 1000 to get metres - assumes the file stores
            # millimetres; TODO confirm.
            ch_pos = {name: electrode_xyz[i] / 1000.0 for i, name in enumerate(ch_names)}
            montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')
            # Set on ``raw`` itself: RawArray keeps its own copy of ``info``,
            # so updating the local ``info`` would not reach the recording.
            raw.set_montage(montage)
            print("[INFO] Applied electrode positions from mat file")
        else:
            raw.set_montage("standard_1020", on_missing="ignore")
    except Exception as e:
        print(f"[WARN] Failed to set montage from mat file: {e}")
        try:
            raw.set_montage("standard_1020", on_missing="ignore")
        except Exception as fallback_err:
            print(f"[WARN] Failed to set standard_1020 montage: {fallback_err}")
    return raw, sfreq, ch_names
# ==========================
# 坏通道检测
# ==========================
def detect_bad_channels(raw: mne.io.RawArray, z_thresh: float = 3.0) -> list:
    """Find bad channels: flat/constant ones plus peak-to-peak outliers.

    A channel is flat when its peak-to-peak range or standard deviation is
    below 1e-12. Outliers are channels whose peak-to-peak range has a
    MAD-based z-score above ``z_thresh``; the median and MAD come from the
    non-flat channels and are only computed when more than two remain.

    Args:
        raw: Recording to inspect.
        z_thresh: Z-score above which a channel counts as an outlier.

    Returns:
        Names of bad channels, flat channels first, then outliers.
    """
    samples = raw.get_data()
    names = raw.ch_names
    ptp = np.ptp(samples, axis=1)
    std = np.std(samples, axis=1)

    flat = (ptp < 1e-12) | (std < 1e-12)
    bad_chs = [names[i] for i in np.flatnonzero(flat)]

    usable = ~flat
    if usable.sum() > 2:
        center = np.median(ptp[usable])
        mad = np.median(np.abs(ptp[usable] - center)) + 1e-30
        z = np.abs(ptp - center) / (mad * 1.4826)
        outliers = (z > z_thresh) & usable
        bad_chs.extend(names[i] for i in np.flatnonzero(outliers))

    if not bad_chs:
        print("[INFO] No bad channels detected")
    else:
        print(f"[INFO] Bad channels detected: {bad_chs}")
    return bad_chs
# ==========================
# 坏段标注
# ==========================
def annotate_bad_segments(raw: mne.io.RawArray, peak_to_peak_uv: float = 350.0):
    """Annotate 1-second windows whose peak-to-peak amplitude is too large.

    Windows are 1 s long and advance by 0.5 s. A window is marked
    ``BAD_SEG`` when any channel's peak-to-peak range exceeds the threshold.
    When at least one window is marked, the annotations are set on ``raw``,
    replacing its current annotations.

    Args:
        raw: Recording to annotate (modified in place).
        peak_to_peak_uv: Threshold in µV; compared against the data after
            scaling by 1e-6.
    """
    sfreq = raw.info["sfreq"]
    peak_to_peak_v = peak_to_peak_uv * 1e-6
    # Clamp to one sample so a very low sampling rate cannot yield a zero
    # window or a zero step (range() rejects step 0).
    win = max(1, int(sfreq * 1.0))
    step = max(1, int(sfreq * 0.5))
    data = raw.get_data()
    n_times = data.shape[1]

    onsets = []
    durations = []
    # "+ 1" keeps the last full window, which starts at n_times - win.
    for start in range(0, n_times - win + 1, step):
        ptp = np.ptp(data[:, start:start + win], axis=1)
        if np.any(ptp > peak_to_peak_v):
            onsets.append(start / sfreq)
            durations.append(win / sfreq)

    if onsets:
        ann = mne.Annotations(onset=onsets, duration=durations,
                              description=["BAD_SEG"] * len(onsets))
        raw.set_annotations(ann)
        print(f"[INFO] Annotated {len(onsets)} bad segments")
# ==========================
# 核心预处理函数
# ==========================
def preprocess_bdf(raw: mne.io.RawArray) -> mne.io.RawArray:
    """Run the full cleaning pipeline on a preloaded recording.

    Steps, in order: crop 2 s from both ends, remove the per-channel mean,
    detect and interpolate bad channels, notch filter, band-pass filter (then
    drop ``FILTER_PAD_SEC`` from both ends), zero out-of-range samples,
    annotate bad segments, ICA with EOG-component exclusion, and a final
    per-channel mean removal.

    Args:
        raw: Preloaded recording; it may be modified by the in-place steps.

    Returns:
        The cleaned recording produced by applying the ICA to a copy.
    """
    print("[INFO] Starting preprocessing pipeline...")

    # 1) Crop 2 s from the start and the end
    crop_sec = 2.0
    t_start = crop_sec
    t_end = raw.times[-1] - crop_sec
    if t_end > t_start:
        raw = raw.crop(tmin=t_start, tmax=t_end)
        print(f"[INFO] Cropped: removed first/last {crop_sec}s")

    # 2) Remove DC offset. The array from get_data() is written back through
    #    the private ``_data`` attribute.
    data = raw.get_data()
    data -= data.mean(axis=1, keepdims=True)
    raw._data = data
    print("[INFO] Removed DC offset")

    # 3) Bad-channel detection and interpolation
    bad_chs = detect_bad_channels(raw)
    if bad_chs:
        raw.info["bads"] = bad_chs
        try:
            raw_tmp = raw.copy()
            raw_tmp.set_montage(raw.get_montage(), on_missing="ignore")
            raw_tmp.interpolate_bads(reset_bads=True, verbose=False)
            raw = raw_tmp
            print(f"[INFO] Bad channels interpolated: {bad_chs}")
        except Exception as e:
            print(f"[WARN] Bad channel interpolation failed: {e}")
            # Interpolation is best-effort: the channels are kept as they are.
            raw.info["bads"] = []

    # 4) Notch filter
    print(f"[INFO] Applying notch filter: {NOTCH_FREQS}Hz")
    raw.notch_filter(NOTCH_FREQS, fir_design="firwin", verbose=False)

    # 5) Band-pass filter; when the recording is long enough the edges are
    #    filtered with reflected padding and then discarded.
    print(f"[INFO] Applying bandpass filter: {BANDPASS_LOW}-{BANDPASS_HIGH}Hz (with {FILTER_PAD_SEC}s padding)")
    pad_sec = FILTER_PAD_SEC
    raw_length = raw.times[-1]
    if raw_length > 2 * max(0, pad_sec) + 1.0:
        raw.filter(BANDPASS_LOW, BANDPASS_HIGH, fir_design="firwin",
                   pad='reflect', verbose=False)
        raw = raw.crop(tmin=pad_sec, tmax=raw_length - pad_sec)
        print(f"[INFO] Removed {pad_sec}s padding from each side after filtering")
    else:
        raw.filter(BANDPASS_LOW, BANDPASS_HIGH, fir_design="firwin", verbose=False)
        print(f"[WARN] Data too short ({raw_length:.1f}s), skipping padding")

    # 6) Amplitude filter: samples outside the configured range are set to 0.
    #    Both bounds are honoured so the behaviour matches the log line above.
    print(f"[INFO] Applying amplitude filter: [{AMPLITUDE_MIN_UV}, {AMPLITUDE_MAX_UV}] μV")
    lower_v = AMPLITUDE_MIN_UV * 1e-6
    upper_v = AMPLITUDE_MAX_UV * 1e-6
    d = raw.get_data()
    mask = (d < lower_v) | (d > upper_v)
    n_clipped = int(mask.sum())
    if n_clipped > 0:
        d[mask] = 0.0
        raw._data = d
        print(f"[INFO] Amplitude clipping: {n_clipped} samples outside "
              f"[{AMPLITUDE_MIN_UV}, {AMPLITUDE_MAX_UV}] μV, set to 0")

    # 7) Bad-segment annotation
    annotate_bad_segments(raw, peak_to_peak_uv=BAD_SEGMENT_THRESHOLD_UV)

    # 8) ICA artifact removal. The component count is capped at the number of
    #    good EEG channels so small montages do not make the fit fail.
    print("[INFO] Running ICA for artifact removal...")
    n_components = ICA_N_COMPONENTS
    n_eeg = len(mne.pick_types(raw.info, meg=False, eeg=True, exclude="bads"))
    if 0 < n_eeg < n_components:
        print(f"[WARN] Only {n_eeg} EEG channels, reducing ICA components from {n_components}")
        n_components = n_eeg
    ica = ICA(n_components=n_components, random_state=ICA_RANDOM_STATE,
              max_iter=ICA_MAX_ITER, method="fastica")
    ica.fit(raw, reject_by_annotation=True, verbose=False)
    try:
        eog_inds, _ = ica.find_bads_eog(raw, verbose=False)
        if eog_inds:
            ica.exclude.extend(eog_inds)
            print(f"[INFO] ICA exclude EOG components: {eog_inds}")
    except Exception as e:
        # Best-effort: EOG detection failures are reported and skipped.
        print(f"[WARN] ICA EOG detection skipped: {e}")
    raw_clean = ica.apply(raw.copy(), verbose=False)

    # 9) Remove the DC offset again after ICA
    d = raw_clean.get_data()
    d -= d.mean(axis=1, keepdims=True)
    raw_clean._data = d

    print("[INFO] Preprocessing completed")
    return raw_clean
# ==========================
# 输出目录管理
# ==========================
def ensure_outdir(out_root: str) -> str:
    """Make sure the output directory exists and holds no stale results.

    A missing directory is created. An existing one has every entry removed
    except ``ResultData.txt``; deletion failures are reported and skipped.

    Args:
        out_root: Output directory path.

    Returns:
        ``out_root`` unchanged.
    """
    if not os.path.exists(out_root):
        os.makedirs(out_root, exist_ok=True)
        return out_root

    for entry in os.listdir(out_root):
        if entry == "ResultData.txt":
            continue
        file_path = os.path.join(out_root, entry)
        is_plain = os.path.isfile(file_path) or os.path.islink(file_path)
        try:
            if is_plain:
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f"[WARN] Failed to delete {file_path}: {e}")
    return out_root
# ==========================
# 通道分区
# ==========================
def _norm_name(s: str) -> str:
    """Normalise a channel name: trimmed, upper-cased, inner spaces removed."""
    trimmed = str(s).strip()
    return trimmed.upper().replace(" ", "")
def build_channel_index_map(ch_names, n_channels: int):
    """Map normalised channel names to their positional index.

    Returns an empty dict when ``ch_names`` is empty/None or its length does
    not equal ``n_channels``. If two names normalise to the same key, the
    later index wins.
    """
    usable = bool(ch_names) and len(ch_names) == n_channels
    if not usable:
        return {}
    index_of = {}
    for position, label in enumerate(ch_names):
        index_of[_norm_name(label)] = position
    return index_of
def pick_indices_by_names(name_to_idx, names):
    """Return the sorted, de-duplicated indices of the names that are known.

    Each entry of ``names`` is normalised before the lookup; names missing
    from ``name_to_idx`` are ignored.
    """
    keys = (_norm_name(n) for n in names)
    return sorted({name_to_idx[k] for k in keys if k in name_to_idx})
def _fallback_region_indices(n_channels: int):
    """Split channel indices into regions by position when names are unusable.

    Returns a 7-tuple of index lists in the order
    (frontal, central, parietal, prefrontal, posterior, left, right).
    The first 33% of indices are frontal, the next 33% central and the rest
    parietal; left/right are the even/odd indices.
    """
    first_cut = int(n_channels * 0.33)
    second_cut = int(n_channels * 0.66)

    frontal = list(range(0, first_cut))
    central = list(range(first_cut, second_cut))
    parietal = list(range(second_cut, n_channels))
    # At least two indices, even if that runs past n_channels.
    prefrontal = list(range(0, max(2, first_cut // 2)))
    posterior = list(range(second_cut, n_channels))
    left = list(range(0, n_channels, 2))
    right = list(range(1, n_channels, 2))
    return (frontal, central, parietal, prefrontal, posterior, left, right)
def get_region_indices(name_to_idx, n_channels: int):
    """Resolve the scalp regions to channel indices.

    Returns a 7-tuple in the order
    (frontal, central, parietal, prefrontal, posterior, left, right).
    With an empty ``name_to_idx`` the positional fallback is returned as is.
    Otherwise regions are looked up by electrode name; if any of the first
    five regions comes back empty, every empty region (left/right included)
    is filled from the positional fallback.
    """
    if not name_to_idx:
        return _fallback_region_indices(n_channels)

    # Same order as the tuple returned by _fallback_region_indices.
    region_names = (
        ["FZ","F1","F2","F3","F4","F5","F6","F7","F8","AF3","AF4","AF7","AF8","FPZ","FP1","FP2","FCZ"],
        ["CZ","C1","C2","C3","C4","C5","C6","CP1","CP2","CP3","CP4","CP5","CP6","FC1","FC2","FC3","FC4","FC5","FC6"],
        ["PZ","P1","P2","P3","P4","P5","P6","POZ","PO3","PO4","PO5","PO6","PO7","PO8","CPZ"],
        ["FP1","FP2","FPZ","AF3","AF4","AF7","AF8"],
        ["O1","O2","OZ","PO7","PO8","PO3","PO4","PZ","P3","P4","P1","P2"],
        ["FP1","AF3","AF7","F3","F5","F7"],
        ["FP2","AF4","AF8","F4","F6","F8"],
    )
    picked = [pick_indices_by_names(name_to_idx, names) for names in region_names]

    # Only the five anatomical regions trigger the fallback.
    if not all(picked[:5]):
        spare = _fallback_region_indices(n_channels)
        picked = [found if found else backup for found, backup in zip(picked, spare)]
    return tuple(picked)
# ==========================
# PSD 和频段功率计算
# ==========================
def welch_psd(eeg_tc, fs):
    """Welch power spectral density along axis 0 of ``eeg_tc``.

    The segment length is ``PSD_NPERSEG`` capped at the number of rows, with
    75% overlap and density scaling.

    Returns:
        ``(freqs, pxx)`` as returned by ``scipy.signal.welch``.
    """
    seg_len = min(PSD_NPERSEG, eeg_tc.shape[0])
    freqs, pxx = signal.welch(
        eeg_tc,
        fs=fs,
        nperseg=seg_len,
        noverlap=int(seg_len * 0.75),
        axis=0,
        scaling="density",
    )
    return freqs, pxx
def band_power_from_psd(freqs, pxx_fc, band):
    """Integrate the PSD over ``[lo, hi)`` for every column of ``pxx_fc``.

    Args:
        freqs: 1-D frequency axis matching axis 0 of ``pxx_fc``.
        pxx_fc: 2-D PSD array indexed as ``[frequency, column]``.
        band: ``(lo, hi)``; ``lo`` is inclusive, ``hi`` exclusive.

    Returns:
        float32 array with one band power per column; all zeros when no
        frequency bin falls inside the band.
    """
    lo, hi = band
    in_band = (freqs >= lo) & (freqs < hi)
    if not np.any(in_band):
        return np.zeros((pxx_fc.shape[1],), dtype=np.float32)
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; the
    # old name is only touched when the new one is absent.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return integrate(pxx_fc[in_band, :], freqs[in_band], axis=0).astype(np.float32)
def region_mean_power(freqs, pxx_fc, idx, band) -> float:
    """Mean band power over the columns listed in ``idx`` (0.0 if empty)."""
    if idx:
        per_channel = band_power_from_psd(freqs, pxx_fc, band)
        return float(np.mean(per_channel[idx]))
    return 0.0
def compute_iaf(freqs, pxx_fc, posterior_idx):
    """Frequency of the largest PSD value inside the Alpha band.

    The spectrum is first averaged over the ``posterior_idx`` columns; both
    band edges are inclusive. Returns 0.0 when no frequency bin lies in the
    band or ``posterior_idx`` is empty.
    """
    alpha_lo, alpha_hi = BANDS_METRICS["Alpha"]
    in_alpha = (freqs >= alpha_lo) & (freqs <= alpha_hi)
    if not (np.any(in_alpha) and posterior_idx):
        return 0.0
    mean_spectrum = np.mean(pxx_fc[:, posterior_idx], axis=1)
    peak = int(np.argmax(mean_spectrum[in_alpha]))
    return float(freqs[in_alpha][peak])
# ==========================
# 画图函数
# ==========================
def plot_eeg_waveforms(data_uv_tc, fs, ch_names, out_dir, seconds=10, t_start_sec=30.0):
    """Plot a short window of the fixed channel set and save it as EEG.png.

    Args:
        data_uv_tc: 2-D array indexed as ``[sample, channel]``; plotted as-is
            with a "uV" axis label.
        fs: Sampling frequency used to convert seconds to samples.
        ch_names: Accepted for interface compatibility; panel titles always
            come from ``FIXED_EEG_LABELS``.
        out_dir: Directory that receives ``EEG.png``.
        seconds: Window length in seconds.
        t_start_sec: Window start; if it lies beyond the data, the last
            ``seconds`` of the recording are plotted.

    Raises:
        RuntimeError: If none of ``FIXED_EEG_IDXS`` is a valid column.
    """
    T, C = data_uv_tc.shape
    start_sample = int(t_start_sec * fs)
    end_sample = int(min(T, start_sample + seconds * fs))
    if start_sample >= T:
        start_sample = max(0, T - int(seconds * fs))
        end_sample = T
        print(f"[WARN] t_start_sec={t_start_sec}s exceeds data, using last {seconds}s")
    seg_samples = end_sample - start_sample
    # The time axis follows the window actually plotted, so it stays correct
    # when the start was moved by the fallback above.
    x = (np.arange(seg_samples) + start_sample) / fs

    # Keep only the fixed indices that exist in this recording.
    idxs = [i for i in FIXED_EEG_IDXS if 0 <= i < C]
    if len(idxs) < len(FIXED_EEG_IDXS):
        missing = [i for i in FIXED_EEG_IDXS if not (0 <= i < C)]
        print(f"[WARN] Some indices out of range (C={C}): {missing}")
    if len(idxs) == 0:
        raise RuntimeError(f"No valid indices for data (C={C})")

    # Titles are the fixed labels paired with FIXED_EEG_IDXS by position.
    picked_names = []
    for idx in idxs:
        pos = FIXED_EEG_IDXS.index(idx)
        picked_names.append(FIXED_EEG_LABELS[pos] if pos < len(FIXED_EEG_LABELS) else f"CH{idx}")

    fig_h = 1.4 * len(idxs) + 1
    fig, axes = plt.subplots(len(idxs), 1, figsize=(10, fig_h), sharex=True)
    if len(idxs) == 1:
        axes = [axes]

    # Shared symmetric y-range from the 1st/99th percentiles, at least ±50.
    seg = data_uv_tc[start_sample:end_sample, idxs].T
    lo = float(np.percentile(seg, 1))
    hi = float(np.percentile(seg, 99))
    m = max(abs(lo), abs(hi), 50.0)

    for ax, ch_idx, nm in zip(axes, idxs, picked_names):
        y = data_uv_tc[start_sample:end_sample, ch_idx]
        ax.plot(x, y, linewidth=1.2)
        ax.set_ylabel("uV")
        ax.set_title(nm, loc="left", fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(-m, m)
    axes[-1].set_xlabel("Time (s)")

    plt.tight_layout()
    out_path = os.path.join(out_dir, "EEG.png")
    plt.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"[OK] EEG waveform saved: {out_path}")
def plot_psd(eeg_uV_tc, fs, ch_names, out_dir):
    """Plot the Welch PSD (in dB) of up to three channels and save psd.png.

    C3, C4 and CZ are used when ``ch_names`` contains them (case-insensitive);
    missing slots are filled with the channels of largest standard deviation.
    Without channel names the three highest-variance channels are used and
    labelled ``CH<index>``.
    """
    C = eeg_uV_tc.shape[1]

    def ranked_by_std():
        # Channel indices ordered by decreasing standard deviation.
        return sorted(range(C), key=lambda i: float(np.std(eeg_uV_tc[:, i])), reverse=True)

    if ch_names:
        lookup = {n.upper(): i for i, n in enumerate(ch_names)}
        chosen_idx = [lookup[p] for p in ("C3", "C4", "CZ") if p in lookup]
        if len(chosen_idx) < 3:
            for i in ranked_by_std():
                if i not in chosen_idx:
                    chosen_idx.append(i)
                if len(chosen_idx) == 3:
                    break
        chosen_name = [ch_names[i] for i in chosen_idx]
    else:
        chosen_idx = ranked_by_std()[:3]
        chosen_name = [f"CH{i}" for i in chosen_idx]

    # Segment length capped at the recording length, 75% overlap.
    nperseg = min(PSD_NPERSEG, eeg_uV_tc.shape[0])
    noverlap = int(nperseg * 0.75)

    fig, ax = plt.subplots(figsize=(7.5, 4.8))
    for idx, nm in zip(chosen_idx, chosen_name):
        f, pxx = signal.welch(eeg_uV_tc[:, idx], fs=fs, nperseg=nperseg, noverlap=noverlap)
        shown = (f >= PSD_FMIN) & (f <= PSD_FMAX)
        # Small offset keeps log10 finite for zero power.
        ax.plot(f[shown], 10 * np.log10(pxx[shown] + 1e-20), linewidth=1.8, label=nm)
    ax.set_xlabel("Hz")
    ax.set_ylabel("Power (dB)")
    ax.set_title("PSD")
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()

    out_path = os.path.join(out_dir, "psd.png")
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"[OK] PSD saved: {out_path}")
def _get_standard_1020_channel_indices(raw):
    """Find the channels of ``raw`` that exist in the standard_1020 montage.

    Matching is case-insensitive and skips the reference names A1, A2, M1,
    M2, LE, RE, LM and RM.

    Returns:
        ``(indices, names)`` where ``names`` use the montage's own spelling,
        or ``(None, None)`` if anything fails.
    """
    try:
        montage = mne.channels.make_standard_montage("standard_1020")
        canonical = {ch.upper(): ch for ch in montage.ch_names}
        skipped = {"A1", "A2", "M1", "M2", "LE", "RE", "LM", "RM"}

        data_ch_names = raw.ch_names
        matches = [
            (i, canonical[name.upper()])
            for i, name in enumerate(data_ch_names)
            if name.upper() in canonical and name.upper() not in skipped
        ]
        valid_indices = [i for i, _ in matches]
        valid_names = [std_name for _, std_name in matches]
        print(f"[INFO] Found {len(valid_indices)}/{len(data_ch_names)} channels matching standard_1020")
        return valid_indices, valid_names
    except Exception as e:
        print(f"[WARN] Failed to get standard_1020 channels: {e}")
        return None, None
def _standard_1020_positions_by_upper_name():
    """Return standard_1020 electrode positions keyed by UPPER-CASE name."""
    montage = mne.channels.make_standard_montage("standard_1020")
    return {name.upper(): pos for name, pos in montage.get_positions()['ch_pos'].items()}
def compute_band_powers_for_topomap(raw, bands):
    """Compute per-channel band power for channels that have a 10-20 position.

    Channel names are matched case-insensitively against the standard_1020
    montage, so mixed-case montage names such as ``Fp1``, ``Fz`` or ``Cz``
    also match upper-case data names. Reference names (A1, A2, M1, M2, LE,
    RE, LM, RM) are skipped.

    Args:
        raw: Recording whose data and channel names are used.
        bands: Mapping ``name -> (fmin, fmax)``; ``fmin`` inclusive, ``fmax``
            exclusive.

    Returns:
        Dict with one array per band (raw power scaled by 1e12) plus the key
        ``"_valid_names"`` listing the channels used, or ``None`` when fewer
        than 8 channels have a position.
    """
    pos_by_upper = _standard_1020_positions_by_upper_name()
    exclude = {"A1", "A2", "M1", "M2", "LE", "RE", "LM", "RM"}

    # Keep only channels with a known position. _create_topomap_raw applies
    # the same name test, so values and positions stay aligned.
    valid_indices = []
    valid_names = []
    for i, name in enumerate(raw.ch_names):
        key = name.upper()
        if key in pos_by_upper and key not in exclude:
            valid_indices.append(i)
            valid_names.append(name)
    if len(valid_indices) < 8:
        return None

    data_standard = raw.get_data()[valid_indices, :]
    fs = raw.info["sfreq"]
    n_fft = min(PSD_NPERSEG, data_standard.shape[1])
    n_overlap = int(n_fft * 0.75)
    psds, freqs = mne.time_frequency.psd_array_welch(
        data_standard, sfreq=fs,
        fmin=min(v[0] for v in bands.values()),
        fmax=max(v[1] for v in bands.values()),
        n_fft=n_fft, n_overlap=n_overlap,
        average="mean", verbose=False
    )

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    out = {"_valid_names": valid_names}
    print(f"[DEBUG] PSD: fs={fs}Hz, n_fft={n_fft}, freq_res={fs/n_fft:.3f}Hz/bin")
    for k, (fmin, fmax) in bands.items():
        idx = np.where((freqs >= fmin) & (freqs < fmax))[0]
        if len(idx) == 0:
            out[k] = np.zeros(len(valid_indices), dtype=np.float32)
            print(f"[DEBUG] {k.upper()}: NO freq bins in [{fmin}-{fmax}]Hz")
            continue
        print(f"[DEBUG] {k.upper()}: freq bins {freqs[idx[0]]:.2f}-{freqs[idx[-1]]:.2f}Hz (bins {idx[0]}-{idx[-1]}, count={len(idx)})")
        # Linear power; the 1e12 factor converts V^2 to uV^2.
        bp = integrate(psds[:, idx], freqs[idx], axis=1) * 1e12
        out[k] = bp
        print(f"[DEBUG] {k.upper()}: power range [{bp.min():.4f}, {bp.max():.4f}] uV^2, mean={bp.mean():.4f}")
    print(f"[INFO] Band powers computed for {len(valid_names)} channels with positions")
    return out
def _create_topomap_raw(ch_names):
    """Build a one-sample dummy Raw that carries only electrode positions.

    Positions come from the standard_1020 montage, matched case-insensitively;
    names without a position are dropped.

    Returns:
        A ``RawArray`` whose ``info`` can be passed to ``plot_topomap``, or
        ``None`` when fewer than 8 names have a position.
    """
    pos_by_upper = _standard_1020_positions_by_upper_name()
    valid_ch_names = []
    valid_positions = []
    for name in ch_names:
        key = name.upper()
        if key in pos_by_upper:
            valid_ch_names.append(name)
            valid_positions.append(pos_by_upper[key])
    if len(valid_ch_names) < 8:
        return None

    ch_pos = dict(zip(valid_ch_names, valid_positions))
    montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')
    # The sampling rate is a placeholder; only the positions are used.
    info = mne.create_info(ch_names=valid_ch_names, sfreq=250.0, ch_types=["eeg"] * len(valid_ch_names))
    info.set_montage(montage)
    dummy_data = np.zeros((len(valid_ch_names), 1))
    return mne.io.RawArray(dummy_data, info, verbose=False)
def plot_average_topomap(band_values, out_dir):
    """Draw the broadband topomap and save it as average_topomap.png.

    Args:
        band_values: Dict holding ``"_valid_names"`` (channel names) and
            ``"broad"`` (one value per channel).
        out_dir: Directory that receives the image.

    Nothing is drawn when there are no valid names or the position-carrying
    dummy Raw cannot be built.
    """
    valid_names = band_values.get("_valid_names", [])
    if not valid_names:
        return
    values = band_values["broad"]
    temp_raw = _create_topomap_raw(valid_names)
    if temp_raw is None:
        return

    vmin, vmax = _compute_topomap_vlim([values])
    fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.6))
    im, _ = mne.viz.plot_topomap(
        values, temp_raw.info, axes=ax, show=False, contours=0,
        sphere=(0, 0, 0, TOPOMAP_SPHERE_RADIUS), extrapolate='head', border='mean',
        cmap='turbo'
    )
    im.set_clim(vmin=vmin, vmax=vmax)
    # Matches the "broad" band (1-30 Hz) and the label used by plot_band_topomaps.
    ax.set_title("1-30 Hz", fontsize=12)
    plt.colorbar(im, ax=ax, shrink=0.85)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "average_topomap.png"), dpi=200)
    plt.close(fig)
    print(f"[OK] average_topomap saved")
def plot_band_topomaps(band_values, out_dir):
    """Draw one topomap per band in a single row and save topomaps.png.

    All five panels share one colour range and one colour bar. Nothing is
    drawn when there are no valid names or the dummy Raw cannot be built.
    """
    valid_names = band_values.get("_valid_names", [])
    if not valid_names:
        return

    panels = (
        ("delta", "δ (1-4Hz)"),
        ("theta", "θ (4-8Hz)"),
        ("alpha", "α (8-13Hz)"),
        ("beta", "β (13-30Hz)"),
        ("broad", "1-30 Hz"),
    )
    temp_raw = _create_topomap_raw(valid_names)
    if temp_raw is None:
        return

    vmin, vmax = _compute_topomap_vlim([band_values[key] for key, _ in panels])

    fig, axes = plt.subplots(1, 5, figsize=(16, 4.2))
    last_im = None
    for ax, (key, title) in zip(axes, panels):
        last_im, _ = mne.viz.plot_topomap(
            band_values[key], temp_raw.info, axes=ax, show=False, contours=0,
            sphere=(0, 0, 0, TOPOMAP_SPHERE_RADIUS), extrapolate='head', border='mean',
            cmap='turbo'
        )
        last_im.set_clim(vmin=vmin, vmax=vmax)
        ax.set_title(title, fontsize=11)

    # Reserve the right margin for the shared colour bar.
    fig.subplots_adjust(left=0.02, right=0.85, top=0.88, bottom=0.05, wspace=0.35)
    colorbar_axes = fig.add_axes([0.87, 0.15, 0.015, 0.7])
    fig.colorbar(last_im, cax=colorbar_axes)
    fig.savefig(os.path.join(out_dir, "topomaps.png"), dpi=200)
    plt.close(fig)
    print(f"[OK] topomaps saved")
def _compute_topomap_vlim(values):
    """Choose the colour range ``(vmin, vmax)`` for the topomaps.

    Priority:
    1. ``TOPOMAP_VMIN``/``TOPOMAP_VMAX`` when both are set: that fixed range.
    2. ``TOPOMAP_SYM_SCALE`` when positive: mean ± std * scale of all values.
    3. Otherwise ``(0, max)`` of all values.

    Args:
        values: A list of arrays (concatenated) or a single array-like.
    """
    v_all = np.concatenate(values) if isinstance(values, list) else np.array(values)
    if TOPOMAP_VMIN is not None and TOPOMAP_VMAX is not None:
        # Honour the configured (min, max) pair as documented in the config.
        return TOPOMAP_VMIN, TOPOMAP_VMAX
    if TOPOMAP_SYM_SCALE is not None and TOPOMAP_SYM_SCALE > 0:
        mean_val = np.mean(v_all)
        half_width = np.std(v_all) * TOPOMAP_SYM_SCALE
        return mean_val - half_width, mean_val + half_width
    # Shared vmax across bands with vmin pinned at 0, so low-power bands sit
    # near the bottom of the colour map and high-power bands stand out.
    return 0, np.max(v_all)
# ==========================
# 预测接口
# ==========================
def _predict_label_by_model(model_path: str, data_path: str) -> dict:
    """Run ``infer_pth.predict_hc_mdd`` on the given recording.

    A ``.mat`` input is passed by its containing directory. A ``.bdf`` input
    is re-read from disk, written (scaled by 1e6, samples on axis 0) to a
    temporary ``.mat`` file, and that temporary directory is passed instead.

    Raises:
        RuntimeError: If ``predict_hc_mdd`` cannot be imported.
        ValueError: If the extension is neither ``.mat`` nor ``.bdf``.
    """
    try:
        from infer_pth import predict_hc_mdd
    except Exception as e:
        raise RuntimeError(f"无法导入 predict_hc_mdd: {e}")
    import tempfile
    import scipy.io

    ext = os.path.splitext(data_path)[1].lower()
    if ext == ".mat":
        # The predictor takes a directory, not a file.
        return predict_hc_mdd(os.path.dirname(data_path), model_path)
    if ext != ".bdf":
        raise ValueError(f"不支持的文件格式: {ext}")

    # NOTE(review): this reads the file from disk again, so the exported data
    # are the unprocessed recording despite the file name below - confirm
    # this is what the model expects.
    raw = mne.io.read_raw_bdf(data_path, preload=True, verbose=False)
    samples = raw[:][0]
    payload = {
        'eeg': {
            'data': (samples * 1e6).T,
            'sample_rate': raw.info['sfreq'],
            'electrode_name': raw.ch_names
        }
    }
    with tempfile.TemporaryDirectory() as temp_dir:
        scipy.io.savemat(os.path.join(temp_dir, "preprocessed_eeg.mat"), payload)
        result = predict_hc_mdd(temp_dir, model_path)
    return result
# ==========================
# 生成 ResultData.txt
# ==========================
def compute_and_save_txt(model_path, bdf_path, out_dir, eeg_uV_tc, fs, ch_names):
    """Run the prediction, compute the spectral metrics and write ResultData.txt.

    Args:
        model_path: Model file handed to the predictor.
        bdf_path: Recording handed to the predictor.
        out_dir: Directory that receives ``ResultData.txt``.
        eeg_uV_tc: 2-D array indexed as ``[sample, channel]`` used for the PSD.
        fs: Sampling frequency.
        ch_names: Channel names used to resolve scalp regions.
    """
    pred_result = _predict_label_by_model(model_path, bdf_path)
    pred_label = pred_result.get("pred_label", "UNKNOWN")
    # NOTE(review): both branches are empty strings, so the recommendation
    # line is always blank - confirm the intended values.
    recommend = "" if pred_label == "MDD" else ""

    C = eeg_uV_tc.shape[1]
    mp = build_channel_index_map(ch_names, C)
    frontal_idx, central_idx, parietal_idx, prefrontal_idx, posterior_idx, left_idx, right_idx = \
        get_region_indices(mp, C)
    freqs, pxx = welch_psd(eeg_uV_tc, fs)

    def region_power(idx, band):
        return region_mean_power(freqs, pxx, idx, band)

    def ratio(numerator, denominator):
        # A non-positive denominator yields 0.0 instead of a huge ratio.
        return (numerator / (denominator + EPS)) if denominator > 0 else 0.0

    alpha, beta, theta = BANDS_METRICS["Alpha"], BANDS_METRICS["Beta"], BANDS_METRICS["Theta"]

    # Alpha/beta and theta/beta ratios per region
    central_beta = region_power(central_idx, beta)
    frontal_beta = region_power(frontal_idx, beta)
    par_beta = region_power(parietal_idx, beta)
    central_ab = ratio(region_power(central_idx, alpha), central_beta)
    frontal_ab = ratio(region_power(frontal_idx, alpha), frontal_beta)
    par_ab = ratio(region_power(parietal_idx, alpha), par_beta)
    central_tb = ratio(region_power(central_idx, theta), central_beta)
    par_tb = ratio(region_power(parietal_idx, theta), par_beta)

    # Alpha asymmetry: ln(right) - ln(left); without named left/right
    # channels the prefrontal indices are split by parity.
    if not left_idx or not right_idx:
        left_idx = [i for i in prefrontal_idx if (i % 2 == 0)]
        right_idx = [i for i in prefrontal_idx if (i % 2 == 1)]
    left_alpha = region_power(left_idx, alpha)
    right_alpha = region_power(right_idx, alpha)
    prefrontal_alpha_asym = float(np.log(right_alpha + EPS) - np.log(left_alpha + EPS))

    iaf = compute_iaf(freqs, pxx, posterior_idx)

    # Prefrontal delta+theta power as a percentage of total power
    pre_td = region_power(prefrontal_idx,
                          (BANDS_METRICS["Delta"][0], BANDS_METRICS["Theta"][1]))
    pre_total = region_power(prefrontal_idx, TOTAL_POWER_BAND)
    pre_td_rel = (pre_td / (pre_total + EPS)) * 100.0 if pre_total > 0 else 0.0

    def f1(x): return f"{x:.1f}"
    txt = (
        f"中央区α/β波比值:{f1(central_ab)}\n"
        f"额区α/β波比值:{f1(frontal_ab)}\n"
        f"顶区α/β波比值:{f1(par_ab)}\n"
        f"中央区θ/β波比值:{f1(central_tb)}\n"
        f"顶区θ/β波比值:{f1(par_tb)}\n"
        f"前额叶α波不对称性:{f1(prefrontal_alpha_asym)}\n"
        f"个体化α峰值频率:{f1(iaf)}\n"
        f"前额叶θ+δ波功率:{f1(pre_td_rel)}\n"
        f"是否推荐治疗:{recommend}\n"
    )
    out_path = os.path.join(out_dir, "ResultData.txt")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(txt)
    print(f"[OK] ResultData.txt saved: {out_path}")

    # The 'N/A' placeholder (or any non-numeric value) cannot take the
    # ".4f" format, so fall back to its plain text instead of raising.
    p_mdd = pred_result.get('p_mdd_mean', 'N/A')
    try:
        p_mdd_text = f"{p_mdd:.4f}"
    except (TypeError, ValueError):
        p_mdd_text = str(p_mdd)

    print(f"\n========== 预测结果 ==========")
    print(f"预测标签: {pred_label}")
    print(f"p(MDD)均值: {p_mdd_text}")
    print(f"切片数量: {pred_result.get('n_slices', 'N/A')}")
    print(f"==============================\n")
# ==========================
# 主函数
# ==========================
def run_all(model_path: str, bdf_dir: str, out_root: str, seconds: int = EEG_PLOT_SECONDS):
    """Process the first recording in ``bdf_dir`` end to end.

    Loads the alphabetically first ``.bdf``/``.mat`` file, preprocesses it,
    saves the PSD, waveform and topomap figures, then runs the prediction and
    writes ``ResultData.txt``.

    Args:
        model_path: Model file for the predictor.
        bdf_dir: Directory containing the input recording.
        out_root: Output directory (cleared of old results first).
        seconds: Length of the waveform plot in seconds.

    Returns:
        The output directory.

    Raises:
        RuntimeError: If ``bdf_dir`` does not exist or holds no supported file.
    """
    if not os.path.exists(bdf_dir):
        raise RuntimeError(f"输入目录不存在: {bdf_dir}")

    # Both .bdf and .mat recordings are accepted.
    candidates = sorted(
        entry for entry in os.listdir(bdf_dir)
        if entry.lower().endswith((".bdf", ".mat"))
    )
    if not candidates:
        raise RuntimeError(f"目录中找不到 .bdf 或 .mat 文件: {bdf_dir}")
    data_path = os.path.join(bdf_dir, candidates[0])
    print(f"[INFO] Processing file: {data_path}")

    out_dir = ensure_outdir(out_root)
    print(f"[INFO] Output directory: {out_dir}")

    raw, sfreq, ch_names = load_data_file(data_path)
    raw_clean = preprocess_bdf(raw)
    try:
        raw_clean.set_montage("standard_1020", on_missing="ignore")
    except Exception as e:
        print(f"[WARN] Failed to re-apply montage: {e}")

    # Scale by 1e6 and put samples on axis 0 for the plotting/metric helpers.
    eeg_uV_tc = (raw_clean.get_data() * 1e6).T.astype(np.float32)
    print(f"[INFO] Preprocessed EEG shape: {eeg_uV_tc.shape}")

    print("[INFO] Generating figures...")
    plot_psd(eeg_uV_tc, sfreq, ch_names, out_dir)
    plot_eeg_waveforms(eeg_uV_tc, sfreq, ch_names, out_dir, seconds=seconds)

    # Topomaps are optional: a failure is reported and the run continues.
    print("[INFO] Generating topomaps...")
    try:
        band_vals = compute_band_powers_for_topomap(raw_clean, BANDS_TOPOMAP)
        if band_vals is not None:
            plot_average_topomap(band_vals, out_dir)
            plot_band_topomaps(band_vals, out_dir)
    except Exception as e:
        print(f"[WARN] Topomap generation failed: {e}")

    print("[INFO] Running prediction...")
    compute_and_save_txt(model_path, data_path, out_dir, eeg_uV_tc, sfreq, ch_names)
    print("[DONE] All tasks completed.")
    return out_dir
# ==========================
# 命令行入口
# ==========================
if __name__ == "__main__":
    import multiprocessing
    multiprocessing.freeze_support()
    import argparse
    import sys

    def get_resource_path(relative_path):
        """Resolve a path relative to the executable (frozen) or this file."""
        if getattr(sys, 'frozen', False):
            base_path = os.path.dirname(sys.executable)
        else:
            base_path = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(base_path, relative_path)

    # Default locations live next to the executable / script.
    DEFAULT_MODEL = get_resource_path(os.path.join("model", "Model_1.pth"))
    if getattr(sys, 'frozen', False):
        EXE_DIR = os.path.dirname(sys.executable)
    else:
        EXE_DIR = os.path.dirname(os.path.abspath(__file__))
    DEFAULT_BDF_DIR = os.path.join(EXE_DIR, "raw_data")
    DEFAULT_OUT = os.path.join(EXE_DIR, "out")

    parser = argparse.ArgumentParser(description="EEG Depression Assessment")
    parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL, help="模型文件路径 (.pth)")
    parser.add_argument("--bdf_dir", type=str, default=DEFAULT_BDF_DIR, help="输入文件夹路径 (包含 .bdf 或 .mat 文件)")
    parser.add_argument("--out_root", type=str, default=DEFAULT_OUT, help="结果输出目录")
    parser.add_argument("--seconds", type=int, default=EEG_PLOT_SECONDS, help="画波形图的秒数")
    args = parser.parse_args()

    print(f"[*] 运行配置:")
    print(f" - Model : {args.model_path}")
    print(f" - Input : {args.bdf_dir}")
    print(f" - Output: {args.out_root}")

    # Report every missing path, then stop before any work is done instead
    # of failing later in the pipeline.
    has_missing_path = False
    if not os.path.exists(args.bdf_dir):
        print(f"[ERROR] 输入目录不存在: {args.bdf_dir}")
        has_missing_path = True
    if not os.path.exists(args.model_path):
        print(f"[ERROR] 模型文件不存在: {args.model_path}")
        has_missing_path = True
    if has_missing_path:
        sys.exit(1)

    run_all(args.model_path, args.bdf_dir, args.out_root, seconds=args.seconds)