987 lines
34 KiB
Python
987 lines
34 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
runDecoder.py - BDF EEG Depression Assessment
|
||
|
||
功能:
|
||
1. 读取 raw_data 文件夹的第一个 .bdf 格式文件
|
||
2. 预处理:坏通道剔除、50Hz陷波、0.8-40Hz带通、幅值过滤、ICA去伪迹
|
||
3. 调用 infer_pth.py 中的 predict_hc_mdd 进行 HC/MDD 分类预测
|
||
4. 保存图表(EEG、PSD、Topomap)
|
||
5. 生成 ResultData.txt
|
||
|
||
"""
|
||
import matplotlib
|
||
matplotlib.use('Agg')
|
||
import numpy as np
|
||
import os
|
||
import shutil
|
||
import scipy.signal as signal
|
||
import matplotlib.pyplot as plt
|
||
import mne
|
||
from mne.preprocessing import ICA
|
||
|
||
# ==========================
# Config - preprocessing parameters
# ==========================
# Filter settings (Hz)
BANDPASS_LOW = 0.8
BANDPASS_HIGH = 40.0
NOTCH_FREQS = [50, 100]  # power-line notch frequencies

# Amplitude filter limits (μV)
AMPLITUDE_MIN_UV = -200.0
AMPLITUDE_MAX_UV = 200.0

# ICA settings
ICA_N_COMPONENTS = 20  # an absolute component count, not a variance fraction
ICA_RANDOM_STATE = 97
ICA_MAX_ITER = 800

# Peak-to-peak threshold for bad-segment detection (μV)
BAD_SEGMENT_THRESHOLD_UV = 350.0

# Default sampling rate (Hz)
DEFAULT_FS = 250.0

# Plot settings
EEG_PLOT_SECONDS = 10
PSD_FMIN = 0.8
PSD_FMAX = 45.0
FIXED_EEG_IDXS = [23, 47, 39, 6, 2, 21, 35, 57]  # 0-based channel indices
FIXED_EEG_LABELS = ["C5", "O1", "TP7", "FPZ", "PO6", "P4", "AF7", "AF3"]

# Frequency bands (Hz) used for the ResultData.txt metrics
BANDS_METRICS = {
    "Delta": (1.0, 4.0),
    "Theta": (4.0, 8.0),
    "Alpha": (8.0, 13.0),
    "Beta": (13.0, 30.0),
}
TOTAL_POWER_BAND = (1.0, 50.0)

# Frequency bands (Hz) used for the topographic maps
BANDS_TOPOMAP = {
    "delta": (1.0, 4.0),
    "theta": (4.0, 8.0),
    "alpha": (8.0, 13.0),
    "beta": (13.0, 30.0),
    "broad": (1.0, 30.0),
}

# PSD settings
PSD_NPERSEG = 1024  # FFT window length; larger means finer frequency resolution

EPS = 1e-12

# Topomap colour range.
# None means "choose automatically"; set both to numbers for a fixed range.
TOPOMAP_VMIN = None  # e.g. -1.0
TOPOMAP_VMAX = None  # e.g. 1.0
# Alternatively use a range centred on the mean:
# colour range = mean ± std * TOPOMAP_SYM_SCALE
TOPOMAP_SYM_SCALE = 1.5

# Topomap head-circle radius (suggested range 0.08 - 0.15).
# Smaller values draw a smaller circle, larger values a larger one.
TOPOMAP_SPHERE_RADIUS = 0.12

# Edge handling: seconds trimmed from each end after band-pass filtering
# to discard filter ringing at the boundaries.
FILTER_PAD_SEC = 1.0
|
||
|
||
|
||
# ==========================
|
||
# 数据文件读取
|
||
# ==========================
|
||
def load_data_file(file_path: str) -> tuple:
    """Load a recording, choosing the reader from the file extension.

    Args:
        file_path: Path to a ``.bdf`` or ``.mat`` file; the extension is
            compared case-insensitively.

    Returns:
        Whatever the selected reader returns (``load_bdf_file`` for ``.bdf``,
        ``load_mat_file`` for ``.mat``).

    Raises:
        ValueError: If the extension is neither ``.bdf`` nor ``.mat``.
    """
    _, suffix = os.path.splitext(file_path)
    suffix = suffix.lower()

    if suffix not in (".bdf", ".mat"):
        raise ValueError(f"不支持的文件格式: {suffix}")

    # Look the reader up only after validation so just the matching loader is touched.
    reader = load_bdf_file if suffix == ".bdf" else load_mat_file
    return reader(file_path)
|
||
|
||
|
||
def load_bdf_file(bdf_path: str) -> tuple:
    """Read a ``.bdf`` recording into memory with MNE.

    Args:
        bdf_path: Path to the BDF file.

    Returns:
        A ``(raw, sfreq, ch_names)`` tuple: the preloaded MNE Raw object, its
        sampling frequency and its channel-name list.
    """
    print(f"[INFO] Reading BDF file: {bdf_path}")
    recording = mne.io.read_raw_bdf(bdf_path, preload=True, verbose=False)

    # Montage assignment is best-effort: a failure is only logged.
    try:
        recording.set_montage("standard_1020", on_missing="ignore")
    except Exception as e:
        print(f"[WARN] Failed to set standard_1020 montage: {e}")

    sampling_rate = recording.info['sfreq']
    names = recording.ch_names
    span_sec = recording.times[-1] - recording.times[0]

    print(f"[INFO] Channels: {len(names)}, Duration: {span_sec:.2f}s, SFreq: {sampling_rate:.2f}Hz")

    return recording, sampling_rate, names
|
||
|
||
|
||
def _mat_channel_names(names_field, n_channels: int) -> list:
    """Convert the MAT ``electrode_name`` field into a list of plain strings."""
    if names_field.ndim != 2:
        # Unexpected layout: fall back to generated names.
        return [f"EEG{i+1}" for i in range(n_channels)]

    names = []
    for cell in names_field[0]:
        label = cell[0]
        # str(b"Cz") would yield "b'Cz'", so byte strings are decoded first.
        if isinstance(label, bytes):
            label = label.decode("utf-8", errors="replace")
        names.append(str(label))
    return names


def load_mat_file(mat_path: str) -> tuple:
    """Read a ``.mat`` recording and wrap it in an MNE ``RawArray``.

    The file is expected to hold a struct ``eeg`` with the fields ``data``,
    ``sample_rate`` and ``electrode_name``; ``electrode_xyz`` is optional.

    Args:
        mat_path: Path to the MAT file.

    Returns:
        A ``(raw, sfreq, ch_names)`` tuple.

    Notes:
        * ``data`` is transposed when it has fewer rows than columns, i.e. the
          longer axis is taken to be time.
        * NOTE(review): the samples are handed to MNE unscaled. The rest of the
          pipeline converts thresholds with ``* 1e-6``, which presumes volts;
          confirm the unit stored in the MAT file.
        * NOTE(review): ``electrode_xyz`` is divided by 1000, which presumes
          millimetres - confirm against the file producer.
    """
    print(f"[INFO] Reading MAT file: {mat_path}")
    import scipy.io

    mat = scipy.io.loadmat(mat_path)
    eeg = mat['eeg'][0, 0]

    # Samples, normalised to (T, C) float64.
    data = eeg['data']
    if data.shape[0] < data.shape[1]:
        data = data.T
    data = data.astype(np.float64)

    sfreq = float(eeg['sample_rate'][0, 0])

    n_samples, n_channels = data.shape[0], data.shape[1]
    ch_names = _mat_channel_names(eeg['electrode_name'], n_channels)
    duration = n_samples / sfreq

    print(f"[INFO] Channels: {n_channels}, Duration: {duration:.2f}s, SFreq: {sfreq:.2f}Hz")

    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=["eeg"] * n_channels)
    raw = mne.io.RawArray(data.T, info, verbose=False)  # (T, C) -> (C, T)

    # Electrode positions: prefer the ones stored in the file, otherwise standard_1020.
    try:
        electrode_xyz = eeg['electrode_xyz']
        if electrode_xyz.shape[0] == n_channels:
            ch_pos = {name: electrode_xyz[i] / 1000.0 for i, name in enumerate(ch_names)}
            montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')
            # The montage must go on ``raw``: RawArray works on its own copy of
            # ``info``, so setting it on ``info`` here would never reach the data.
            raw.set_montage(montage)
            print("[INFO] Applied electrode positions from mat file")
        else:
            raw.set_montage("standard_1020", on_missing="ignore")
    except Exception as e:
        print(f"[WARN] Failed to set montage from mat file: {e}")
        try:
            raw.set_montage("standard_1020", on_missing="ignore")
        except Exception as fallback_err:
            # Still best-effort, but no longer silent.
            print(f"[WARN] Failed to set standard_1020 montage: {fallback_err}")

    return raw, sfreq, ch_names
|
||
|
||
|
||
# ==========================
|
||
# 坏通道检测
|
||
# ==========================
|
||
def detect_bad_channels(raw: mne.io.RawArray, z_thresh: float = 3.0) -> list:
    """Find flat channels and peak-to-peak outliers.

    A channel is flagged when its peak-to-peak range or its standard deviation
    is below 1e-12, or when the robust z-score of its peak-to-peak range
    (median/MAD computed over the non-flat channels) exceeds ``z_thresh``.
    The outlier pass only runs when more than two non-flat channels exist.

    Args:
        raw: Object providing ``get_data()`` (channels x samples) and ``ch_names``.
        z_thresh: Robust z-score above which a channel counts as an outlier.

    Returns:
        Channel names: flat channels first, then outliers, each group in
        channel order.
    """
    samples = raw.get_data()
    names = raw.ch_names

    span = np.ptp(samples, axis=1)
    spread = np.std(samples, axis=1)

    # Pass 1: constant / all-zero channels.
    is_flat = (span < 1e-12) | (spread < 1e-12)
    flagged = [names[i] for i in np.flatnonzero(is_flat)]

    # Pass 2: robust z-score of the peak-to-peak range.
    usable = ~is_flat
    if usable.sum() > 2:
        reference = span[usable]
        centre = np.median(reference)
        mad = np.median(np.abs(reference - centre)) + 1e-30
        robust_z = np.abs(span - centre) / (mad * 1.4826)
        is_outlier = (robust_z > z_thresh) & usable
        flagged.extend(names[i] for i in np.flatnonzero(is_outlier))

    if flagged:
        print(f"[INFO] Bad channels detected: {flagged}")
    else:
        print("[INFO] No bad channels detected")

    return flagged
|
||
|
||
|
||
# ==========================
|
||
# 坏段标注
|
||
# ==========================
|
||
def annotate_bad_segments(raw: mne.io.RawArray, peak_to_peak_uv: float = 350.0):
    """Mark 1-second windows whose peak-to-peak amplitude is too large.

    Windows are 1 s long and advance in 0.5 s steps. A window is annotated as
    ``BAD_SEG`` when any channel's peak-to-peak range exceeds the threshold.

    Args:
        raw: Recording to scan; its annotations are replaced in place when at
            least one bad window is found.
        peak_to_peak_uv: Threshold, converted with ``* 1e-6`` before comparing
            against ``raw.get_data()`` (so the data is presumed to be in volts
            and this value in microvolts).

    Returns:
        None.

    Notes:
        ``raw.set_annotations`` replaces whatever annotations ``raw`` already
        carried; that behaviour is unchanged.
    """
    sfreq = raw.info["sfreq"]
    peak_to_peak_v = peak_to_peak_uv * 1e-6
    # Clamp to at least one sample so very low sampling rates cannot produce
    # an empty window or a zero step (range() rejects step == 0).
    win = max(1, int(sfreq * 1.0))
    step = max(1, int(sfreq * 0.5))
    data = raw.get_data()
    n_times = data.shape[1]

    onsets = []
    durations = []

    # "+ 1" so the final full window, starting at n_times - win, is scanned too.
    for start in range(0, n_times - win + 1, step):
        seg = data[:, start:start + win]
        if np.any(np.ptp(seg, axis=1) > peak_to_peak_v):
            onsets.append(start / sfreq)
            durations.append(win / sfreq)

    if onsets:
        ann = mne.Annotations(onset=onsets, duration=durations, description=["BAD_SEG"] * len(onsets))
        raw.set_annotations(ann)
        print(f"[INFO] Annotated {len(onsets)} bad segments")
|
||
|
||
|
||
# ==========================
|
||
# 核心预处理函数
|
||
# ==========================
|
||
def preprocess_bdf(raw: mne.io.RawArray) -> mne.io.RawArray:
    """Run the full preprocessing pipeline on a recording.

    Steps, in order:
        1. Crop the first and last 2 s (skipped when the recording is too short).
        2. Remove the per-channel mean (DC offset).
        3. Detect bad channels and interpolate them (best-effort).
        4. Notch filter at ``NOTCH_FREQS``.
        5. Band-pass ``BANDPASS_LOW``-``BANDPASS_HIGH`` Hz, then trim
           ``FILTER_PAD_SEC`` from each end to discard edge ringing.
        6. Zero every sample outside
           ``[AMPLITUDE_MIN_UV, AMPLITUDE_MAX_UV]``.
        7. Annotate bad segments.
        8. Fit ICA and remove components flagged as EOG (best-effort).
        9. Remove the per-channel mean again.

    Args:
        raw: Preloaded recording. Thresholds are converted with ``* 1e-6``, so
            the data is presumed to be in volts.

    Returns:
        The cleaned recording (a copy produced by the ICA step).

    Notes:
        Cropping and filtering are applied to ``raw`` itself, so the caller's
        object is modified.
        NOTE(review): ICA is fitted with a fixed ``ICA_N_COMPONENTS``;
        recordings with fewer usable channels would presumably fail here.
    """
    print("[INFO] Starting preprocessing pipeline...")

    # 1) Crop 2 s from both ends
    crop_sec = 2.0
    t_start = crop_sec
    t_end = raw.times[-1] - crop_sec
    if t_end > t_start:
        raw = raw.crop(tmin=t_start, tmax=t_end)
        print(f"[INFO] Cropped: removed first/last {crop_sec}s")

    # 2) Remove DC offset. get_data() returns a copy, so the demeaned array is
    #    written back through the private buffer, as the original code did.
    data = raw.get_data()
    data -= data.mean(axis=1, keepdims=True)
    raw._data = data
    print("[INFO] Removed DC offset")

    # 3) Bad-channel detection and interpolation
    bad_chs = detect_bad_channels(raw)
    if bad_chs:
        raw.info["bads"] = bad_chs
        try:
            # Interpolate on a copy so a failure leaves ``raw`` untouched.
            raw_tmp = raw.copy()
            raw_tmp.set_montage(raw.get_montage(), on_missing="ignore")
            raw_tmp.interpolate_bads(reset_bads=True, verbose=False)
            raw = raw_tmp
            print(f"[INFO] Bad channels interpolated: {bad_chs}")
        except Exception as e:
            print(f"[WARN] Bad channel interpolation failed: {e}")
            raw.info["bads"] = []

    # 4) Notch filter
    print(f"[INFO] Applying notch filter: {NOTCH_FREQS}Hz")
    raw.notch_filter(NOTCH_FREQS, fir_design="firwin", verbose=False)

    # 5) Band-pass, then trim the edges where filter ringing is strongest
    print(f"[INFO] Applying bandpass filter: {BANDPASS_LOW}-{BANDPASS_HIGH}Hz (with {FILTER_PAD_SEC}s padding)")
    # A negative setting would otherwise turn into an invalid crop window.
    pad_sec = max(0, FILTER_PAD_SEC)
    raw_length = raw.times[-1]

    if raw_length > 2 * pad_sec + 1.0:
        raw.filter(BANDPASS_LOW, BANDPASS_HIGH, fir_design="firwin",
                   pad='reflect', verbose=False)
        raw = raw.crop(tmin=pad_sec, tmax=raw_length - pad_sec)
        print(f"[INFO] Removed {pad_sec}s padding from each side after filtering")
    else:
        raw.filter(BANDPASS_LOW, BANDPASS_HIGH, fir_design="firwin", verbose=False)
        print(f"[WARN] Data too short ({raw_length:.1f}s), skipping padding")

    # 6) Amplitude filter: zero samples outside the configured limits.
    #    Both limits are honoured (previously only the upper one was used,
    #    mirrored around zero).
    print(f"[INFO] Applying amplitude filter: [{AMPLITUDE_MIN_UV}, {AMPLITUDE_MAX_UV}] μV")
    amplitude_min_v = AMPLITUDE_MIN_UV * 1e-6
    amplitude_max_v = AMPLITUDE_MAX_UV * 1e-6
    d = raw.get_data()
    mask = (d < amplitude_min_v) | (d > amplitude_max_v)
    n_clipped = int(mask.sum())
    if n_clipped > 0:
        d[mask] = 0.0
        raw._data = d
        print(f"[INFO] Amplitude clipping: {n_clipped} samples outside "
              f"[{AMPLITUDE_MIN_UV:g}, {AMPLITUDE_MAX_UV:g}]μV, set to 0")

    # 7) Bad-segment annotation (used by the ICA fit below)
    annotate_bad_segments(raw, peak_to_peak_uv=BAD_SEGMENT_THRESHOLD_UV)

    # 8) ICA artifact removal
    print("[INFO] Running ICA for artifact removal...")
    ica = ICA(n_components=ICA_N_COMPONENTS, random_state=ICA_RANDOM_STATE,
              max_iter=ICA_MAX_ITER, method="fastica")
    ica.fit(raw, reject_by_annotation=True, verbose=False)

    # EOG detection is best-effort: when it fails, no component is excluded.
    try:
        eog_inds, _ = ica.find_bads_eog(raw, verbose=False)
        if eog_inds:
            ica.exclude.extend(eog_inds)
            print(f"[INFO] ICA exclude EOG components: {eog_inds}")
    except Exception as e:
        print(f"[WARN] ICA EOG detection skipped: {e}")

    raw_clean = ica.apply(raw.copy(), verbose=False)

    # 9) Remove DC offset again after ICA
    d = raw_clean.get_data()
    d -= d.mean(axis=1, keepdims=True)
    raw_clean._data = d

    print("[INFO] Preprocessing completed")
    return raw_clean
|
||
|
||
|
||
# ==========================
|
||
# 输出目录管理
|
||
# ==========================
|
||
def ensure_outdir(out_root: str) -> str:
    """Create the output directory, or clear it while keeping ResultData.txt.

    When ``out_root`` already exists, every entry except ``ResultData.txt`` is
    deleted (files and links are unlinked, directories removed recursively).
    A failed deletion is logged and skipped.

    Args:
        out_root: Output directory path.

    Returns:
        ``out_root`` unchanged.
    """
    if not os.path.exists(out_root):
        os.makedirs(out_root, exist_ok=True)
        return out_root

    stale_entries = [name for name in os.listdir(out_root) if name != "ResultData.txt"]
    for name in stale_entries:
        file_path = os.path.join(out_root, name)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f"[WARN] Failed to delete {file_path}: {e}")

    return out_root
|
||
|
||
|
||
# ==========================
|
||
# 通道分区
|
||
# ==========================
|
||
def _norm_name(s: str) -> str:
|
||
return str(s).strip().upper().replace(" ", "")
|
||
|
||
|
||
def build_channel_index_map(ch_names, n_channels: int):
    """Map normalised channel names to their positions.

    Args:
        ch_names: Channel names, in data order.
        n_channels: Number of channels in the data.

    Returns:
        ``{normalised name: index}``, or an empty dict when ``ch_names`` is
        empty or its length differs from ``n_channels``.
    """
    names_usable = bool(ch_names) and len(ch_names) == n_channels
    if not names_usable:
        return {}

    lookup = {}
    for position, label in enumerate(ch_names):
        lookup[_norm_name(label)] = position
    return lookup
|
||
|
||
|
||
def pick_indices_by_names(name_to_idx, names):
    """Return the sorted, de-duplicated indices of the names found in ``name_to_idx``.

    Args:
        name_to_idx: Mapping from normalised channel name to index.
        names: Channel names to look up; each is normalised before the lookup
            and unknown names are ignored.
    """
    wanted_keys = (_norm_name(n) for n in names)
    return sorted({name_to_idx[key] for key in wanted_keys if key in name_to_idx})
|
||
|
||
|
||
def _fallback_region_indices(n_channels: int):
|
||
a = int(n_channels * 0.33)
|
||
b = int(n_channels * 0.66)
|
||
return (
|
||
list(range(0, a)), # frontal
|
||
list(range(a, b)), # central
|
||
list(range(b, n_channels)), # parietal
|
||
list(range(0, max(2, a // 2))), # prefrontal
|
||
list(range(b, n_channels)), # posterior
|
||
[i for i in range(n_channels) if i % 2 == 0], # left
|
||
[i for i in range(n_channels) if i % 2 == 1], # right
|
||
)
|
||
|
||
|
||
def get_region_indices(name_to_idx, n_channels: int):
    """Resolve scalp regions to channel indices.

    Regions are looked up by electrode name. When the name map is empty the
    positional fallback is used for everything. When any of the five core
    regions (frontal, central, parietal, prefrontal, posterior) comes back
    empty, every empty region - including left/right - is filled from the
    positional fallback.

    Args:
        name_to_idx: Mapping from normalised channel name to index.
        n_channels: Number of channels in the data.

    Returns:
        A 7-tuple of index lists:
        ``(frontal, central, parietal, prefrontal, posterior, left, right)``.
    """
    if not name_to_idx:
        return _fallback_region_indices(n_channels)

    # Listed in the order of the returned tuple (and of the fallback tuple).
    region_names = {
        "frontal": ["FZ","F1","F2","F3","F4","F5","F6","F7","F8","AF3","AF4","AF7","AF8","FPZ","FP1","FP2","FCZ"],
        "central": ["CZ","C1","C2","C3","C4","C5","C6","CP1","CP2","CP3","CP4","CP5","CP6","FC1","FC2","FC3","FC4","FC5","FC6"],
        "parietal": ["PZ","P1","P2","P3","P4","P5","P6","POZ","PO3","PO4","PO5","PO6","PO7","PO8","CPZ"],
        "prefrontal": ["FP1","FP2","FPZ","AF3","AF4","AF7","AF8"],
        "posterior": ["O1","O2","OZ","PO7","PO8","PO3","PO4","PZ","P3","P4","P1","P2"],
        "left": ["FP1","AF3","AF7","F3","F5","F7"],
        "right": ["FP2","AF4","AF8","F4","F6","F8"],
    }
    order = tuple(region_names)
    core_regions = order[:5]

    picked = {
        region: pick_indices_by_names(name_to_idx, names)
        for region, names in region_names.items()
    }

    if not all(picked[region] for region in core_regions):
        substitutes = _fallback_region_indices(n_channels)
        for region, substitute in zip(order, substitutes):
            if not picked[region]:
                picked[region] = substitute

    return tuple(picked[region] for region in order)
|
||
|
||
|
||
# ==========================
|
||
# PSD 和频段功率计算
|
||
# ==========================
|
||
def welch_psd(eeg_tc, fs):
    """Estimate the power spectral density along axis 0 with Welch's method.

    Args:
        eeg_tc: Array whose first axis is time.
        fs: Sampling frequency.

    Returns:
        ``(freqs, pxx)`` as returned by ``scipy.signal.welch`` with density
        scaling. The segment length is ``PSD_NPERSEG`` capped at the number of
        samples, with 75% overlap.
    """
    segment_len = min(PSD_NPERSEG, eeg_tc.shape[0])
    overlap = int(segment_len * 0.75)
    return signal.welch(
        eeg_tc,
        fs=fs,
        nperseg=segment_len,
        noverlap=overlap,
        axis=0,
        scaling="density",
    )
|
||
|
||
|
||
def band_power_from_psd(freqs, pxx_fc, band):
    """Integrate a PSD over a frequency band, per channel.

    Args:
        freqs: 1-D array of frequencies.
        pxx_fc: PSD array indexed as ``[frequency, channel]``.
        band: ``(lo, hi)``; bins with ``lo <= f < hi`` are included.

    Returns:
        float32 array with one value per channel; all zeros when no frequency
        bin falls inside the band. A band holding a single bin also integrates
        to zero, because trapezoidal integration needs two points.
    """
    lo, hi = band
    in_band = (freqs >= lo) & (freqs < hi)
    if not np.any(in_band):
        return np.zeros((pxx_fc.shape[1],), dtype=np.float32)

    # NumPy 2.0 renamed trapz to trapezoid and deprecated the old name;
    # prefer the new one and fall back for older NumPy versions.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return integrate(pxx_fc[in_band, :], freqs[in_band], axis=0).astype(np.float32)
|
||
|
||
|
||
def region_mean_power(freqs, pxx_fc, idx, band) -> float:
    """Average the band power over a set of channels.

    Args:
        freqs: 1-D array of frequencies.
        pxx_fc: PSD array indexed as ``[frequency, channel]``.
        idx: Channel indices of the region.
        band: ``(lo, hi)`` frequency band.

    Returns:
        Mean band power across ``idx``, or 0.0 when ``idx`` is empty.
    """
    if idx:
        per_channel = band_power_from_psd(freqs, pxx_fc, band)
        return float(np.mean(per_channel[idx]))
    return 0.0
|
||
|
||
|
||
def compute_iaf(freqs, pxx_fc, posterior_idx):
    """Return the frequency of the alpha-band spectral peak over the posterior channels.

    The PSD is averaged across ``posterior_idx`` and the frequency with the
    largest value inside the Alpha band (both bounds inclusive) is returned.

    Args:
        freqs: 1-D array of frequencies.
        pxx_fc: PSD array indexed as ``[frequency, channel]``.
        posterior_idx: Channel indices to average over.

    Returns:
        The peak frequency as a float, or 0.0 when the band contains no bins
        or ``posterior_idx`` is empty.
    """
    alpha_lo, alpha_hi = BANDS_METRICS["Alpha"]
    in_alpha = (freqs >= alpha_lo) & (freqs <= alpha_hi)
    if not (np.any(in_alpha) and posterior_idx):
        return 0.0

    posterior_spectrum = np.mean(pxx_fc[:, posterior_idx], axis=1)
    peak_pos = int(np.argmax(posterior_spectrum[in_alpha]))
    return float(freqs[in_alpha][peak_pos])
|
||
|
||
|
||
# ==========================
|
||
# 画图函数
|
||
# ==========================
|
||
def plot_eeg_waveforms(data_uv_tc, fs, ch_names, out_dir, seconds=10, t_start_sec=30.0):
    """Plot a stretch of EEG for the fixed channel set and save it as EEG.png.

    Channels come from ``FIXED_EEG_IDXS`` and are titled with the matching
    entry of ``FIXED_EEG_LABELS``; indices outside the data are skipped with a
    warning. All panels share one symmetric y-range derived from the 1st/99th
    percentile of the plotted data (at least ±50).

    Args:
        data_uv_tc: 2-D array indexed as ``[time, channel]``; plotted with a
            "uV" axis label.
        fs: Sampling frequency.
        ch_names: Accepted for interface compatibility; the titles always use
            ``FIXED_EEG_LABELS``.
        out_dir: Directory that receives ``EEG.png``.
        seconds: Length of the plotted window.
        t_start_sec: Start of the window. When it lies beyond the data, the
            last ``seconds`` of the recording are plotted instead.

    Raises:
        RuntimeError: If none of the fixed indices exist in the data, or there
            are no samples to plot.
    """
    T, C = data_uv_tc.shape

    start_sample = int(t_start_sec * fs)
    if start_sample >= T:
        start_sample = max(0, T - int(seconds * fs))
        end_sample = T
        print(f"[WARN] t_start_sec={t_start_sec}s exceeds data, using last {seconds}s")
    else:
        end_sample = int(min(T, start_sample + seconds * fs))

    seg_samples = end_sample - start_sample
    if seg_samples <= 0:
        raise RuntimeError(f"No samples to plot (T={T})")
    # Build the time axis from the window actually used, so it stays correct
    # when the fallback above moved the window.
    x = (start_sample + np.arange(seg_samples)) / fs

    # Keep only the fixed indices that exist in this recording.
    picked = []   # (channel index, title)
    missing = []
    for pos, idx in enumerate(FIXED_EEG_IDXS):
        if 0 <= idx < C:
            title = FIXED_EEG_LABELS[pos] if pos < len(FIXED_EEG_LABELS) else f"CH{idx}"
            picked.append((idx, title))
        else:
            missing.append(idx)

    if missing:
        print(f"[WARN] Some indices out of range (C={C}): {missing}")
    if not picked:
        raise RuntimeError(f"No valid indices for data (C={C})")

    idxs = [idx for idx, _ in picked]

    fig_h = 1.4 * len(picked) + 1
    fig, axes = plt.subplots(len(picked), 1, figsize=(10, fig_h), sharex=True)
    axes = np.atleast_1d(axes)  # a single subplot is returned as a bare Axes

    # One shared, symmetric y-range for every panel.
    seg = data_uv_tc[start_sample:end_sample, idxs]
    lo = float(np.percentile(seg, 1))
    hi = float(np.percentile(seg, 99))
    m = max(abs(lo), abs(hi), 50.0)

    for ax, (ch_idx, title) in zip(axes, picked):
        ax.plot(x, data_uv_tc[start_sample:end_sample, ch_idx], linewidth=1.2)
        ax.set_ylabel("uV")
        ax.set_title(title, loc="left", fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(-m, m)

    axes[-1].set_xlabel("Time (s)")
    fig.tight_layout()

    out_path = os.path.join(out_dir, "EEG.png")
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"[OK] EEG waveform saved: {out_path}")
|
||
|
||
|
||
def plot_psd(eeg_uV_tc, fs, ch_names, out_dir):
    """Plot the PSD (in dB) of up to three channels and save it as psd.png.

    C3, C4 and CZ are used when present in ``ch_names``; remaining slots are
    filled with the channels of highest standard deviation.

    Args:
        eeg_uV_tc: 2-D array indexed as ``[time, channel]``.
        fs: Sampling frequency.
        ch_names: Channel names in data order; may be empty or ``None``.
        out_dir: Directory that receives ``psd.png``.
    """
    C = eeg_uV_tc.shape[1]

    def _label(i):
        # Fall back to a generated label when the name list is missing or short.
        return ch_names[i] if ch_names and i < len(ch_names) else f"CH{i}"

    chosen_idx = []
    if ch_names:
        mp = {n.upper(): i for i, n in enumerate(ch_names)}
        # Ignore names that point past the data.
        chosen_idx = [mp[p] for p in ("C3", "C4", "CZ") if p in mp and mp[p] < C]

    if len(chosen_idx) < 3:
        # Fill the remaining slots with the most variable channels.
        by_std = sorted(range(C), key=lambda i: float(np.std(eeg_uV_tc[:, i])), reverse=True)
        for i in by_std:
            if len(chosen_idx) == 3:
                break
            if i not in chosen_idx:
                chosen_idx.append(i)

    chosen_name = [_label(i) for i in chosen_idx]

    # Same Welch settings as the metric computation: capped window, 75% overlap.
    nperseg = min(PSD_NPERSEG, eeg_uV_tc.shape[0])
    noverlap = int(nperseg * 0.75)

    fig, ax = plt.subplots(figsize=(7.5, 4.8))
    for idx, nm in zip(chosen_idx, chosen_name):
        f, pxx = signal.welch(eeg_uV_tc[:, idx], fs=fs, nperseg=nperseg, noverlap=noverlap)
        mask = (f >= PSD_FMIN) & (f <= PSD_FMAX)
        p_db = 10 * np.log10(pxx[mask] + 1e-20)  # offset avoids log10(0)
        ax.plot(f[mask], p_db, linewidth=1.8, label=nm)

    ax.set_xlabel("Hz")
    ax.set_ylabel("Power (dB)")
    ax.set_title("PSD")
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()

    out_path = os.path.join(out_dir, "psd.png")
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"[OK] PSD saved: {out_path}")
|
||
|
||
|
||
def _get_standard_1020_channel_indices(raw):
    """Find the channels of ``raw`` that exist in the standard_1020 montage.

    Names are compared case-insensitively and a fixed set of reference
    electrode names (A1, A2, M1, M2, LE, RE, LM, RM) is skipped.

    Args:
        raw: Object providing ``ch_names``.

    Returns:
        ``(indices, names)`` where ``names`` uses the montage's own spelling,
        or ``(None, None)`` when anything goes wrong.
    """
    reference_names = {"A1", "A2", "M1", "M2", "LE", "RE", "LM", "RM"}
    try:
        montage = mne.channels.make_standard_montage("standard_1020")
        canonical_by_upper = {ch.upper(): ch for ch in montage.ch_names}
        data_ch_names = raw.ch_names

        matches = [
            (i, canonical_by_upper[name.upper()])
            for i, name in enumerate(data_ch_names)
            if name.upper() in canonical_by_upper and name.upper() not in reference_names
        ]
        valid_indices = [i for i, _ in matches]
        valid_names = [canonical for _, canonical in matches]

        print(f"[INFO] Found {len(valid_indices)}/{len(data_ch_names)} channels matching standard_1020")
        return valid_indices, valid_names
    except Exception as e:
        print(f"[WARN] Failed to get standard_1020 channels: {e}")
        return None, None
|
||
|
||
|
||
def _standard_1020_positions_by_upper_name():
    """Map upper-cased standard_1020 channel names to their montage positions."""
    montage = mne.channels.make_standard_montage("standard_1020")
    ch_pos = montage.get_positions()['ch_pos']
    # The montage keeps its own spelling ("Fp1", "Cz", ...), so the lookup key
    # has to be upper-cased explicitly for case-insensitive matching.
    return {name.upper(): pos for name, pos in ch_pos.items()}


def compute_band_powers_for_topomap(raw, bands):
    """Compute per-channel band power for the channels that have a standard_1020 position.

    Channel names are matched case-insensitively against the standard_1020
    montage; a fixed set of reference electrode names is skipped.

    Args:
        raw: Object providing ``ch_names``, ``get_data()`` and ``info["sfreq"]``.
        bands: Mapping ``name -> (fmin, fmax)``; bins with ``fmin <= f < fmax``
            are integrated.

    Returns:
        ``None`` when fewer than 8 channels match. Otherwise a dict with one
        array per band (power scaled by 1e12, i.e. V^2 -> uV^2 if the data is
        in volts) plus ``"_valid_names"``, the channel names in the same order
        as the arrays.
    """
    pos_by_upper = _standard_1020_positions_by_upper_name()
    exclude = {"A1", "A2", "M1", "M2", "LE", "RE", "LM", "RM"}

    # Keep only channels with a known position.
    valid_indices = []
    valid_names = []
    for i, name in enumerate(raw.ch_names):
        name_upper = name.upper()
        if name_upper in pos_by_upper and name_upper not in exclude:
            valid_indices.append(i)
            valid_names.append(name)

    if len(valid_indices) < 8:
        return None

    data = raw.get_data()
    data_standard = data[valid_indices, :]

    fs = raw.info["sfreq"]
    n_fft = min(PSD_NPERSEG, data_standard.shape[1])
    n_overlap = int(n_fft * 0.75)

    psds, freqs = mne.time_frequency.psd_array_welch(
        data_standard, sfreq=fs,
        fmin=min(v[0] for v in bands.values()),
        fmax=max(v[1] for v in bands.values()),
        n_fft=n_fft, n_overlap=n_overlap,
        average="mean", verbose=False
    )

    # NumPy 2.0 renamed trapz to trapezoid and deprecated the old name.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    out = {"_valid_names": valid_names}
    print(f"[DEBUG] PSD: fs={fs}Hz, n_fft={n_fft}, freq_res={fs/n_fft:.3f}Hz/bin")
    for k, (fmin, fmax) in bands.items():
        idx = np.where((freqs >= fmin) & (freqs < fmax))[0]
        if len(idx) == 0:
            out[k] = np.zeros(len(valid_indices), dtype=np.float32)
            print(f"[DEBUG] {k.upper()}: NO freq bins in [{fmin}-{fmax}]Hz")
            continue
        print(f"[DEBUG] {k.upper()}: freq bins {freqs[idx[0]]:.2f}-{freqs[idx[-1]]:.2f}Hz (bins {idx[0]}-{idx[-1]}, count={len(idx)})")
        # Linear power (V^2 -> uV^2: * 1e12)
        bp = integrate(psds[:, idx], freqs[idx], axis=1) * 1e12
        out[k] = bp
        print(f"[DEBUG] {k.upper()}: power range [{bp.min():.4f}, {bp.max():.4f}] uV^2, mean={bp.mean():.4f}")

    print(f"[INFO] Band powers computed for {len(valid_names)} channels with positions")
    return out


def _create_topomap_raw(ch_names):
    """Build a one-sample dummy Raw that only carries channel positions.

    Uses the same case-insensitive standard_1020 lookup as
    ``compute_band_powers_for_topomap``, so the channel list here lines up
    with the band-power arrays produced there.

    Args:
        ch_names: Channel names to place on the scalp.

    Returns:
        An ``mne.io.RawArray`` holding zeros for the channels that have a
        standard_1020 position, or ``None`` when fewer than 8 do.
    """
    pos_by_upper = _standard_1020_positions_by_upper_name()

    valid_ch_names = []
    valid_positions = []
    for name in ch_names:
        name_upper = name.upper()
        if name_upper in pos_by_upper:
            valid_ch_names.append(name)
            valid_positions.append(pos_by_upper[name_upper])

    if len(valid_ch_names) < 8:
        return None

    ch_pos = dict(zip(valid_ch_names, valid_positions))
    montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')

    # The sampling rate is only a placeholder: the object exists for its
    # ``info`` (channel positions), not for its data.
    info = mne.create_info(ch_names=valid_ch_names, sfreq=DEFAULT_FS, ch_types=["eeg"] * len(valid_ch_names))
    info.set_montage(montage)

    dummy_data = np.zeros((len(valid_ch_names), 1))
    return mne.io.RawArray(dummy_data, info, verbose=False)
|
||
|
||
|
||
def plot_average_topomap(band_values, out_dir):
    """Draw the broadband topographic map and save it as average_topomap.png.

    Args:
        band_values: Dict with ``"_valid_names"`` (channel names) and
            ``"broad"`` (one value per channel, same order).
        out_dir: Directory that receives the image.

    Returns:
        None. Nothing is drawn when there are no valid names or the position
        helper cannot build a channel layout.
    """
    valid_names = band_values.get("_valid_names", [])
    if not valid_names:
        return

    values = band_values["broad"]
    temp_raw = _create_topomap_raw(valid_names)
    if temp_raw is None:
        return

    vmin, vmax = _compute_topomap_vlim([values])

    fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.6))
    im, _ = mne.viz.plot_topomap(
        values, temp_raw.info, axes=ax, show=False, contours=0,
        sphere=(0, 0, 0, TOPOMAP_SPHERE_RADIUS), extrapolate='head', border='mean',
        cmap='turbo'
    )
    im.set_clim(vmin=vmin, vmax=vmax)
    # Take the title from the configured band so it cannot drift from the data
    # (the previous hard-coded "0.8-30 Hz" did not match the 1-30 Hz band).
    broad_lo, broad_hi = BANDS_TOPOMAP["broad"]
    ax.set_title(f"{broad_lo:g}-{broad_hi:g} Hz", fontsize=12)
    fig.colorbar(im, ax=ax, shrink=0.85)
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, "average_topomap.png"), dpi=200)
    plt.close(fig)
    print("[OK] average_topomap saved")
|
||
|
||
|
||
def plot_band_topomaps(band_values, out_dir):
    """Draw one topographic map per band in a single row and save it as topomaps.png.

    All panels share one colour range and a single colour bar.

    Args:
        band_values: Dict with ``"_valid_names"`` plus one per-channel array
            for each of ``delta``, ``theta``, ``alpha``, ``beta`` and ``broad``.
        out_dir: Directory that receives the image.

    Returns:
        None. Nothing is drawn when there are no valid names or the position
        helper cannot build a channel layout.
    """
    channel_names = band_values.get("_valid_names", [])
    if not channel_names:
        return

    panels = [
        ("delta", "δ (1-4Hz)"),
        ("theta", "θ (4-8Hz)"),
        ("alpha", "α (8-13Hz)"),
        ("beta", "β (13-30Hz)"),
        ("broad", "1-30 Hz"),
    ]

    layout = _create_topomap_raw(channel_names)
    if layout is None:
        return

    # One colour range across every band keeps the panels comparable.
    lower, upper = _compute_topomap_vlim([band_values[key] for key, _ in panels])

    fig, axes = plt.subplots(1, len(panels), figsize=(16, 4.2))
    last_image = None
    for ax, (key, caption) in zip(axes, panels):
        last_image, _ = mne.viz.plot_topomap(
            band_values[key], layout.info, axes=ax, show=False, contours=0,
            sphere=(0, 0, 0, TOPOMAP_SPHERE_RADIUS), extrapolate='head', border='mean',
            cmap='turbo'
        )
        last_image.set_clim(vmin=lower, vmax=upper)
        ax.set_title(caption, fontsize=11)

    # Leave room on the right for the shared colour bar.
    fig.subplots_adjust(left=0.02, right=0.85, top=0.88, bottom=0.05, wspace=0.35)
    colorbar_axes = fig.add_axes([0.87, 0.15, 0.015, 0.7])
    fig.colorbar(last_image, cax=colorbar_axes)
    fig.savefig(os.path.join(out_dir, "topomaps.png"), dpi=200)
    plt.close(fig)
    print("[OK] topomaps saved")
|
||
|
||
|
||
def _compute_topomap_vlim(values):
    """Choose the colour range for the topographic maps.

    Precedence:
        1. ``TOPOMAP_VMIN`` / ``TOPOMAP_VMAX`` when both are set.
        2. ``mean ± std * TOPOMAP_SYM_SCALE`` when the scale is positive.
        3. ``0`` to the maximum value.

    Args:
        values: A list of arrays (concatenated) or a single array-like.

    Returns:
        ``(vmin, vmax)``.
    """
    v_all = np.concatenate(values) if isinstance(values, list) else np.array(values)

    if TOPOMAP_VMIN is not None and TOPOMAP_VMAX is not None:
        # Honour the configured fixed range as documented next to the
        # constants (previously TOPOMAP_VMIN was ignored and VMAX - 60 used).
        return TOPOMAP_VMIN, TOPOMAP_VMAX

    if TOPOMAP_SYM_SCALE is not None and TOPOMAP_SYM_SCALE > 0:
        mean_val = np.mean(v_all)
        half_width = np.std(v_all) * TOPOMAP_SYM_SCALE
        return mean_val - half_width, mean_val + half_width

    # Shared scale anchored at zero: low-power bands stay near the bottom of
    # the colour map while high-power bands stand out.
    return 0, np.max(v_all)
|
||
|
||
|
||
# ==========================
|
||
# 预测接口
|
||
# ==========================
|
||
def _predict_label_by_model(model_path: str, data_path: str) -> dict:
|
||
"""调用 infer_pth.py 进行预测"""
|
||
try:
|
||
from infer_pth import predict_hc_mdd
|
||
except Exception as e:
|
||
raise RuntimeError(f"无法导入 predict_hc_mdd: {e}")
|
||
|
||
import tempfile
|
||
import scipy.io
|
||
|
||
ext = os.path.splitext(data_path)[1].lower()
|
||
|
||
if ext == ".mat":
|
||
# 直接使用 mat 文件
|
||
result = predict_hc_mdd(os.path.dirname(data_path), model_path)
|
||
elif ext == ".bdf":
|
||
# 转换为 mat 格式
|
||
raw = mne.io.read_raw_bdf(data_path, preload=True, verbose=False)
|
||
data, times = raw[:]
|
||
sfreq = raw.info['sfreq']
|
||
ch_names = raw.ch_names
|
||
|
||
with tempfile.TemporaryDirectory() as temp_dir:
|
||
mat_path = os.path.join(temp_dir, "preprocessed_eeg.mat")
|
||
scipy.io.savemat(mat_path, {
|
||
'eeg': {
|
||
'data': (data * 1e6).T,
|
||
'sample_rate': sfreq,
|
||
'electrode_name': ch_names
|
||
}
|
||
})
|
||
result = predict_hc_mdd(temp_dir, model_path)
|
||
else:
|
||
raise ValueError(f"不支持的文件格式: {ext}")
|
||
|
||
return result
|
||
|
||
|
||
# ==========================
|
||
# 生成 ResultData.txt
|
||
# ==========================
|
||
def compute_and_save_txt(model_path, bdf_path, out_dir, eeg_uV_tc, fs, ch_names):
    """Compute the spectral metrics, run the classifier and write ResultData.txt.

    The prediction runs first; if it raises, no file is written.

    Args:
        model_path: Path to the model file.
        bdf_path: Path to the recording handed to the classifier.
        out_dir: Directory that receives ``ResultData.txt``.
        eeg_uV_tc: 2-D array indexed as ``[time, channel]``.
        fs: Sampling frequency.
        ch_names: Channel names in data order.

    Returns:
        None.
    """
    def _ratio(numerator, denominator):
        # Shared guard for every ratio: 0.0 when the denominator has no power.
        return (numerator / (denominator + EPS)) if denominator > 0 else 0.0

    # Prediction
    pred_result = _predict_label_by_model(model_path, bdf_path)
    pred_label = pred_result.get("pred_label", "UNKNOWN")
    recommend = "是" if pred_label == "MDD" else "否"

    _, C = eeg_uV_tc.shape
    mp = build_channel_index_map(ch_names, C)
    frontal_idx, central_idx, parietal_idx, prefrontal_idx, posterior_idx, left_idx, right_idx = \
        get_region_indices(mp, C)

    freqs, pxx = welch_psd(eeg_uV_tc, fs)

    def _power(idx, band):
        return region_mean_power(freqs, pxx, idx, band)

    alpha_band = BANDS_METRICS["Alpha"]
    beta_band = BANDS_METRICS["Beta"]
    theta_band = BANDS_METRICS["Theta"]

    # Alpha/beta and theta/beta ratios per region
    central_beta = _power(central_idx, beta_band)
    frontal_beta = _power(frontal_idx, beta_band)
    par_beta = _power(parietal_idx, beta_band)

    central_ab = _ratio(_power(central_idx, alpha_band), central_beta)
    frontal_ab = _ratio(_power(frontal_idx, alpha_band), frontal_beta)
    par_ab = _ratio(_power(parietal_idx, alpha_band), par_beta)

    central_tb = _ratio(_power(central_idx, theta_band), central_beta)
    par_tb = _ratio(_power(parietal_idx, theta_band), par_beta)

    # Prefrontal alpha asymmetry: ln(right) - ln(left). When either side has
    # no named channels, split the prefrontal indices by parity instead.
    if not left_idx or not right_idx:
        left_idx = [i for i in prefrontal_idx if (i % 2 == 0)]
        right_idx = [i for i in prefrontal_idx if (i % 2 == 1)]
    left_alpha = _power(left_idx, alpha_band)
    right_alpha = _power(right_idx, alpha_band)
    prefrontal_alpha_asym = float(np.log(right_alpha + EPS) - np.log(left_alpha + EPS))

    iaf = compute_iaf(freqs, pxx, posterior_idx)

    # Prefrontal delta+theta power as a percentage of total power
    pre_td = _power(prefrontal_idx, (BANDS_METRICS["Delta"][0], BANDS_METRICS["Theta"][1]))
    pre_total = _power(prefrontal_idx, TOTAL_POWER_BAND)
    pre_td_rel = _ratio(pre_td, pre_total) * 100.0

    def f1(x):
        return f"{x:.1f}"

    txt = (
        f"中央区α/β波比值:{f1(central_ab)}\n"
        f"额区α/β波比值:{f1(frontal_ab)}\n"
        f"顶区α/β波比值:{f1(par_ab)}\n"
        f"中央区θ/β波比值:{f1(central_tb)}\n"
        f"顶区θ/β波比值:{f1(par_tb)}\n"
        f"前额叶α波不对称性:{f1(prefrontal_alpha_asym)}\n"
        f"个体化α峰值频率:{f1(iaf)}\n"
        f"前额叶θ+δ波功率:{f1(pre_td_rel)}\n"
        f"是否推荐治疗:{recommend}\n"
    )

    out_path = os.path.join(out_dir, "ResultData.txt")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(txt)
    print(f"[OK] ResultData.txt saved: {out_path}")

    # Formatting a missing or non-numeric value with ".4f" would raise, so
    # fall back to plain text in that case.
    p_mdd = pred_result.get('p_mdd_mean')
    try:
        p_mdd_text = f"{p_mdd:.4f}"
    except (TypeError, ValueError):
        p_mdd_text = "N/A" if p_mdd is None else str(p_mdd)

    # Console summary of the prediction
    print("\n========== 预测结果 ==========")
    print(f"预测标签: {pred_label}")
    print(f"p(MDD)均值: {p_mdd_text}")
    print(f"切片数量: {pred_result.get('n_slices', 'N/A')}")
    print("==============================\n")
|
||
|
||
|
||
# ==========================
|
||
# 主函数
|
||
# ==========================
|
||
def run_all(model_path: str, bdf_dir: str, out_root: str, seconds: int = EEG_PLOT_SECONDS):
    """Run the whole pipeline on the first recording found in ``bdf_dir``.

    The first ``.bdf``/``.mat`` file in sorted name order is loaded,
    preprocessed, plotted (PSD, EEG waveforms, topomaps) and classified; the
    outputs are written to ``out_root``.

    Args:
        model_path: Path to the model file.
        bdf_dir: Directory containing the input recording.
        out_root: Output directory (cleared first, keeping ResultData.txt).
        seconds: Length of the EEG waveform plot in seconds.

    Returns:
        The output directory path.

    Raises:
        RuntimeError: If ``bdf_dir`` does not exist or holds no ``.bdf``/``.mat`` file.
    """
    if not os.path.exists(bdf_dir):
        raise RuntimeError(f"输入目录不存在: {bdf_dir}")

    # Both .bdf and .mat recordings are accepted.
    candidates = sorted(
        name for name in os.listdir(bdf_dir)
        if name.lower().endswith((".bdf", ".mat"))
    )
    if not candidates:
        raise RuntimeError(f"目录中找不到 .bdf 或 .mat 文件: {bdf_dir}")

    data_path = os.path.join(bdf_dir, candidates[0])
    print(f"[INFO] Processing file: {data_path}")

    out_dir = ensure_outdir(out_root)
    print(f"[INFO] Output directory: {out_dir}")

    recording, sfreq, ch_names = load_data_file(data_path)
    cleaned = preprocess_bdf(recording)

    # Re-apply the montage after preprocessing; a failure is only logged.
    try:
        cleaned.set_montage("standard_1020", on_missing="ignore")
    except Exception as e:
        print(f"[WARN] Failed to re-apply montage: {e}")

    # (channels, samples) scaled by 1e6 -> (samples, channels) float32
    eeg_uV_tc = (cleaned.get_data() * 1e6).T.astype(np.float32)
    print(f"[INFO] Preprocessed EEG shape: {eeg_uV_tc.shape}")

    print("[INFO] Generating figures...")
    plot_psd(eeg_uV_tc, sfreq, ch_names, out_dir)
    plot_eeg_waveforms(eeg_uV_tc, sfreq, ch_names, out_dir, seconds=seconds)

    # Topomaps are optional output: any failure is logged and the run continues.
    print("[INFO] Generating topomaps...")
    try:
        band_vals = compute_band_powers_for_topomap(cleaned, BANDS_TOPOMAP)
        if band_vals is not None:
            plot_average_topomap(band_vals, out_dir)
            plot_band_topomaps(band_vals, out_dir)
    except Exception as e:
        print(f"[WARN] Topomap generation failed: {e}")

    print("[INFO] Running prediction...")
    compute_and_save_txt(model_path, data_path, out_dir, eeg_uV_tc, sfreq, ch_names)

    print("[DONE] All tasks completed.")
    return out_dir
|
||
|
||
|
||
# ==========================
|
||
# 命令行入口
|
||
# ==========================
|
||
if __name__ == "__main__":
    # freeze_support() must run first so frozen executables handle
    # multiprocessing child processes correctly.
    import multiprocessing
    multiprocessing.freeze_support()
    import argparse
    import sys

    def _app_base_dir():
        """Directory of the frozen executable, or of this script when not frozen."""
        if getattr(sys, 'frozen', False):
            return os.path.dirname(sys.executable)
        return os.path.dirname(os.path.abspath(__file__))

    def get_resource_path(relative_path):
        """Return the absolute path of a resource located next to the program."""
        return os.path.join(_app_base_dir(), relative_path)

    # Default paths, all relative to the program directory
    EXE_DIR = _app_base_dir()
    DEFAULT_MODEL = get_resource_path(os.path.join("model", "Model_1.pth"))
    DEFAULT_BDF_DIR = os.path.join(EXE_DIR, "raw_data")
    DEFAULT_OUT = os.path.join(EXE_DIR, "out")

    parser = argparse.ArgumentParser(description="EEG Depression Assessment")
    parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL, help="模型文件路径 (.pth)")
    parser.add_argument("--bdf_dir", type=str, default=DEFAULT_BDF_DIR, help="输入文件夹路径 (包含 .bdf 或 .mat 文件)")
    parser.add_argument("--out_root", type=str, default=DEFAULT_OUT, help="结果输出目录")
    parser.add_argument("--seconds", type=int, default=EEG_PLOT_SECONDS, help="画波形图的秒数")

    args = parser.parse_args()

    print("[*] 运行配置:")
    print(f"    - Model : {args.model_path}")
    print(f"    - Input : {args.bdf_dir}")
    print(f"    - Output: {args.out_root}")

    # Report every missing input, then stop: continuing would clear the output
    # directory and run the full preprocessing only to fail later.
    inputs_ok = True
    if not os.path.exists(args.bdf_dir):
        print(f"[ERROR] 输入目录不存在: {args.bdf_dir}")
        inputs_ok = False
    if not os.path.exists(args.model_path):
        print(f"[ERROR] 模型文件不存在: {args.model_path}")
        inputs_ok = False
    if not inputs_ok:
        sys.exit(1)

    run_all(args.model_path, args.bdf_dir, args.out_root, seconds=args.seconds)
|