original push

This commit is contained in:
Ivey Song
2026-06-01 13:18:36 +08:00
commit 8426770db6
46 changed files with 341750 additions and 0 deletions

View File

@@ -0,0 +1,96 @@
# -*- mode: python ; coding: utf-8 -*-
import sys
import os
from PyInstaller.utils.hooks import collect_submodules, collect_data_files
# ========================================================
# 1. 工程配置区 (Project Config)
# ========================================================
block_cipher = None
ENTRY_POINT = 'runDecoder.py'
APP_NAME = 'Depression_Decoder'
# ========================================================
# 2. 依赖分析 (Dependency Analysis)
# ========================================================
# 收集 mne, sklearn, scipy 可能遗漏的隐藏导入
hidden_imports = [
'infer_pth', # 你的动态导入模块
'sklearn.utils._cython_blas',
'sklearn.neighbors.typedefs',
'sklearn.neighbors.quad_tree',
'sklearn.tree',
'sklearn.tree._utils',
]
# 自动收集 mne 的子模块
hidden_imports += collect_submodules('mne')
# 收集 torch 相关的隐式导入(虽然 PyInstaller 通常能处理,但显式更安全)
hidden_imports += ['torch', 'torchvision']
# ========================================================
# 3. 资源锚定 (Data Anchoring)
# ========================================================
# Analysis 中的 datas 用于将文件嵌入到内部(运行时在临时目录或 _internal
# 这里我们留空,改为在 COLLECT 阶段通过 Tree 显式复制到 EXE 旁,
# 这样生成的文件夹里能直接看到 model 和 raw_data
datas = []
# 收集 mne 的数据文件(如果需要默认配置)
datas += collect_data_files('mne')
# ========================================================
# 4. 构建流程 (Build Process)
# ========================================================
a = Analysis(
[ENTRY_POINT],
pathex=[],
binaries=[],
datas=datas,
hiddenimports=hidden_imports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=['tkinter', 'PyQt5', 'PySide2', 'IPython'], # 排除 GUI 和交互式库减小体积
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
[],
exclude_binaries=True,
name=APP_NAME,
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=False,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)
# ========================================================
# 5. 打包模式: OneDir (单文件夹) + 资源旁路
# ========================================================
# 使用 Tree 将文件夹原样复制到 dist/APP_NAME/ 下
# 格式: Tree('源路径', prefix='目标子目录')
coll = COLLECT(
exe,
a.binaries,
a.zipfiles,
a.datas,
strip=False,
upx=False,
upx_exclude=[],
name=APP_NAME,
)

View File

@@ -0,0 +1,87 @@
import os
import subprocess
import sys
import shutil
# 确保我们在虚拟环境中运行
if not sys.prefix == sys.base_prefix:
print(f"正在使用虚拟环境: {sys.prefix}")
else:
print("警告:你似乎没有激活虚拟环境!建议在 venv_clean 下运行。")
def build():
entry_point = "runDecoder.py" # 你的入口文件
# 自动清理逻辑优化
output_dir = "dist2"
build_dir = "build2" # Nuitka 默认会在当前目录生成 .build 文件夹
if "--clean" in sys.argv:
print("清理旧构建目录...")
for folder in [output_dir, build_dir, entry_point.replace(".py", ".build")]:
if os.path.exists(folder):
shutil.rmtree(folder, ignore_errors=True)
# Nuitka 命令 - 此时非常清爽
nuitka_cmd = [
sys.executable, "-m", "nuitka",
"--standalone", # 独立运行模式
f"--output-dir={output_dir}", # 输出目录
"--show-progress", # 显示进度
"--assume-yes-for-downloads", # 自动下载依赖(如 ccache, depends 等)
# --- 插件配置 ---
"--enable-plugin=numpy",
"--enable-plugin=matplotlib",
"--enable-plugin=torch", # 处理 PyTorch 及其 CUDA 依赖
# --- 包含包/模块 (Nuitka 2.x 推荐使用 include-package-data 或 include-package) ---
# --collect-all 是 PyInstaller 的参数Nuitka 不支持
"--include-package=sklearn",
"--include-package=scipy",
"--include-package=mne",
# 强制包含 MNE 的数据文件(配置、布局等)
"--include-package-data=mne",
"--include-package=PIL", # Pillow (matplotlib/mne 可能用到)
"--include-package=networkx", # mne 可能用到
"--include-package=decorator", # MNE 核心依赖,防止 KeyError: 'self'
"--include-package=six", # 通用兼容库
# 显式包含本地模块,防止隐式导入丢失
"--include-module=infer_pth",
# --- 数据文件 ---
# 格式: 源路径=目标路径 (相对 dist 目录)
"--include-data-dir=model=model",
"--include-data-dir=raw_data=raw_data",
# --- 排除干扰以减小体积/提高稳定性 ---
"--nofollow-import-to=pytest",
"--nofollow-import-to=unittest",
"--nofollow-import-to=pdb",
"--nofollow-import-to=tkinter", # 如果不用 GUI 界面
"--nofollow-import-to=sympy", # 除非明确用到符号计算
# --- 内存与性能 ---
"--low-memory", # 降低打包时的内存消耗
# --- Windows 特定 ---
# "--disable-console", # 如果不需要黑框,取消注释这一行
]
nuitka_cmd.append(entry_point)
print("开始打包...")
try:
subprocess.check_call(nuitka_cmd)
print("\n打包成功!")
print(f"请在 dist2/runDecoder.dist 目录下运行 exe 进行测试。")
except subprocess.CalledProcessError as e:
print(f"打包失败,错误码: {e.returncode}")
if __name__ == "__main__":
build()

View File

@@ -0,0 +1,77 @@
import os
import shutil
import subprocess
import sys
def main():
# 1. 定义路径
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DIST_DIR = os.path.join(BASE_DIR, 'dist')
APP_NAME = 'Depression_Decoder'
TARGET_DIR = os.path.join(DIST_DIR, APP_NAME)
MODEL_SRC = os.path.join(BASE_DIR, 'model')
RAW_DATA_SRC = os.path.join(BASE_DIR, 'raw_data')
MODEL_DST = os.path.join(TARGET_DIR, 'model')
RAW_DATA_DST = os.path.join(TARGET_DIR, 'raw_data')
# 2. 清理旧构建
print("[1/3] Cleaning up old builds...")
if os.path.exists(DIST_DIR):
try:
shutil.rmtree(DIST_DIR)
print(" Cleaned dist/")
except Exception as e:
print(f" Warning: Could not clean dist/: {e}")
BUILD_DIR = os.path.join(BASE_DIR, 'build')
if os.path.exists(BUILD_DIR):
try:
shutil.rmtree(BUILD_DIR)
print(" Cleaned build/")
except Exception as e:
print(f" Warning: Could not clean build/: {e}")
# 3. 运行 PyInstaller
print("[2/3] Running PyInstaller...")
# 注意:我们这里不传 --noupx因为已经在 spec 文件里把 upx=False 写死了
cmd = [
"pyinstaller",
"build_algorithm.spec",
"--clean"
]
try:
subprocess.check_call(cmd, shell=True)
except subprocess.CalledProcessError:
print("Error: PyInstaller failed.")
sys.exit(1)
# 4. 复制外部资源文件夹
print("[3/3] Copying external resources...")
# 复制 model 文件夹
if os.path.exists(MODEL_SRC):
if os.path.exists(MODEL_DST):
shutil.rmtree(MODEL_DST)
shutil.copytree(MODEL_SRC, MODEL_DST)
print(f" Copied: model -> {MODEL_DST}")
else:
print(f" Warning: Source model dir not found at {MODEL_SRC}")
# 复制 raw_data 文件夹
if os.path.exists(RAW_DATA_SRC):
if os.path.exists(RAW_DATA_DST):
shutil.rmtree(RAW_DATA_DST)
shutil.copytree(RAW_DATA_SRC, RAW_DATA_DST)
print(f" Copied: raw_data -> {RAW_DATA_DST}")
else:
print(f" Warning: Source raw_data dir not found at {RAW_DATA_SRC}")
print("\n" + "="*50)
print(f"SUCCESS! Build artifacts are in: {TARGET_DIR}")
print("="*50)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,38 @@
import os
import sys
import scipy
import numpy
print(f"Python executable: {sys.executable}")
print(f"Scipy version: {scipy.__version__}")
print(f"Scipy path: {scipy.__file__}")
scipy_dir = os.path.dirname(scipy.__file__)
parent_dir = os.path.dirname(scipy_dir)
scipy_libs = os.path.join(parent_dir, "scipy.libs")
print(f"Checking for scipy.libs at: {scipy_libs}")
if os.path.exists(scipy_libs):
print("scipy.libs FOUND.")
for root, dirs, files in os.walk(scipy_libs):
for f in files:
print(f" - {f}")
else:
print("scipy.libs NOT FOUND.")
print("-" * 20)
print(f"Numpy version: {numpy.__version__}")
print(f"Numpy path: {numpy.__file__}")
numpy_dir = os.path.dirname(numpy.__file__)
numpy_libs = os.path.join(numpy_dir, ".libs") # numpy 往往在内部
if not os.path.exists(numpy_libs):
# try parent
numpy_libs = os.path.join(os.path.dirname(numpy_dir), "numpy.libs")
print(f"Checking for numpy libs at: {numpy_libs}")
if os.path.exists(numpy_libs):
print("numpy libs FOUND.")
for root, dirs, files in os.walk(numpy_libs):
for f in files:
print(f" - {f}")
else:
print("numpy libs NOT FOUND.")

View File

@@ -0,0 +1,561 @@
# -*- coding: utf-8 -*-
"""
infer_pth.py
用途:
- 从一个文件夹中自动读取第一个 .mat EEG 文件64通道或32通道
- 若为64通道则按 idx64_to_32 映射选出32通道
- 提取切片特征DE + PSD(var近似)不含Asym
- 加载你训练好的 .pth 模型FusionNet结构
- 输出该受试者的 HC / MDD 判断结果
运行方式(命令行):
python infer_pth_from64_to32.py --eeg_dir "D:\\xxx\\folder" --model_path "C:\\xxx\\model.pth"
也可在其他py里import
from infer_pth_from64_to32 import predict_hc_mdd
res = predict_hc_mdd(eeg_dir, model_path)
print(res)
"""
from __future__ import annotations
import os
import argparse
import numpy as np
import scipy.io
import scipy.signal as signal
import torch
import torch.nn as nn
import torch.nn.functional as F
# =========================================================
# 0) 配置区(按需改这里)
# =========================================================
# 采样率(必须与训练时一致)
SAMPLING_RATE = 250
# 滑窗参数(必须与训练时一致)
WINDOW_SIZE = 500
STRIDE = 250
# 频段(必须与训练时一致)
BAND_NAMES = ["Delta", "Theta", "Alpha", "Beta", "Gamma"]
BANDS = {
"Delta": (1, 4),
"Theta": (4, 8),
"Alpha": (8, 13),
"Beta": (13, 30),
"Gamma": (30, 50),
}
# 是否使用扩展特征DE+PSD
USE_EXTENDED_FEATURES = True
# 数值稳定项
EPS = 1e-12
# 设备
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 通道映射
IDX64_TO_32 = [
23, # C5
47, # O1
39, # TP7
6, # FPZ
2, # PO6
21, # P4
35, # AF7
57, # AF3
1, # FP2
37, # T7
63, # F1
36, # A1
18, # FC4
31, # FC5
14, # FC2
48, # T8
60, # P2
41, # AF8
11, # CP1
0, # FP1
55, # PO7
59, # C1
22, # F5
10, # CP2
16, # C3
61, # P1
27, # CP5
17, # C4
26, # CP6
62, # F2
3, # POZ
13, # PO5
]
# 推理阈值如果模型checkpoint里有 subject_threshold会优先用它否则用这个
DEFAULT_SUBJECT_THRESHOLD = 0.5
# =========================================================
# 1) 模型结构
# =========================================================
class SEBlock(nn.Module):
def __init__(self, channels: int, reduction: int = 4) -> None:
super().__init__()
hidden = max(1, channels // reduction)
self.fc1 = nn.Linear(channels, hidden)
self.fc2 = nn.Linear(hidden, channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
se = F.relu(self.fc1(x))
se = torch.sigmoid(self.fc2(se))
return x * se
class ResidualBlock(nn.Module):
def __init__(self, in_features: int, out_features: int, dropout: float = 0.3) -> None:
super().__init__()
self.fc1 = nn.Linear(in_features, out_features)
self.bn1 = nn.BatchNorm1d(out_features)
self.fc2 = nn.Linear(out_features, out_features)
self.bn2 = nn.BatchNorm1d(out_features)
self.dropout = nn.Dropout(dropout)
self.shortcut = nn.Identity()
if in_features != out_features:
self.shortcut = nn.Sequential(
nn.Linear(in_features, out_features),
nn.BatchNorm1d(out_features),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = self.shortcut(x)
out = F.relu(self.bn1(self.fc1(x)))
out = self.dropout(out)
out = self.bn2(self.fc2(out))
out = F.relu(out + identity)
return out
class FusionNet(nn.Module):
def __init__(self, num_classes: int = 2, num_eeg_features: int = 320, num_scales: int = 6) -> None:
super().__init__()
self.input_norm = nn.BatchNorm1d(num_eeg_features)
self.block1 = ResidualBlock(num_eeg_features, 512, dropout=0.4)
self.block2 = ResidualBlock(512, 256, dropout=0.3)
self.block3 = ResidualBlock(256, 128, dropout=0.2)
self.attention = SEBlock(128, reduction=4)
self.final_fc = nn.Sequential(
nn.Linear(128, 64),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.Dropout(0.2),
)
self.cls_head = nn.Linear(64, num_classes)
# 训练时有回归头也没关系推理只用cls
self.reg_head = nn.Sequential(
nn.Linear(64, 32),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(32, num_scales),
)
self._init_weights()
def _init_weights(self) -> None:
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor):
x = self.input_norm(x)
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
x = self.attention(x)
features = self.final_fc(x)
cls_out = self.cls_head(features)
reg_out = self.reg_head(features)
return cls_out, reg_out
# =========================================================
# 2) mat读取 + 通道裁剪
# =========================================================
def _find_first_mat_file(folder: str) -> str:
if not os.path.isdir(folder):
raise RuntimeError(f"eeg_dir 不是文件夹: {folder}")
mats = sorted([f for f in os.listdir(folder) if f.lower().endswith(".mat")])
if not mats:
raise RuntimeError(f"文件夹内没有 .mat 文件: {folder}")
return os.path.join(folder, mats[0])
import numpy as np
import scipy.io
def _unwrap_singleton(x):
"""
把 (1,1) / (1,) 这种包裹层一直剥掉,直到不是 singleton。
也处理 object array 的情况。
"""
while True:
if isinstance(x, np.ndarray):
if x.dtype == object and x.size == 1:
x = x.item()
continue
if x.size == 1 and x.ndim >= 1:
# 例如 (1,1) 或 (1,) 的数值/对象数组
try:
x = x.reshape(-1)[0]
continue
except Exception:
pass
break
return x
def _try_get_struct_field(v, field_name="data"):
"""
尝试从以下几种结构中提取字段:
1) scipy 读出的 mat_struct有 _fieldnames
2) numpy structured/record arraydtype.names
"""
# case 1: mat_struct推荐 loadmat(..., struct_as_record=False, squeeze_me=True)
if hasattr(v, "_fieldnames") and (field_name in getattr(v, "_fieldnames", [])):
return getattr(v, field_name)
# case 2: structured array
if isinstance(v, np.ndarray) and v.dtype.names and (field_name in v.dtype.names):
# 常见是 v[field] 仍然是 ndarray / object需要 unwrap
try:
return v[field_name]
except Exception:
return None
return None
def load_eeg_from_mat_any_channels(mat_path: str) -> np.ndarray:
"""
读取 .mat 中 EEG 数据,支持:
- 直接二维矩阵 (T,C) 或 (C,T)
- struct 里有字段 data
返回统一为 float32 的 (T, C)
"""
# 用这两个参数会让 struct 更容易处理:字段变成属性,且自动 squeeze
mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
candidates = []
for k, v in mat.items():
if k.startswith("__"):
continue
# --- 1) 直接二维数值矩阵 ---
if isinstance(v, np.ndarray) and v.ndim == 2 and np.issubdtype(v.dtype, np.number):
candidates.append((k, v))
continue
# --- 2) struct/record优先提取 data 字段 ---
data_field = _try_get_struct_field(v, "data")
if data_field is not None:
data_field = _unwrap_singleton(data_field)
# data_field 可能仍然被 object 包一层
if isinstance(data_field, np.ndarray) and data_field.dtype == object:
data_field = _unwrap_singleton(data_field)
if isinstance(data_field, np.ndarray) and data_field.ndim == 2:
# 只收数值矩阵
if np.issubdtype(data_field.dtype, np.number) or data_field.dtype == object:
candidates.append((f"{k}.data", data_field))
continue
# --- 3) object array尝试 item() 解包后再看是不是二维数值矩阵/struct ---
if isinstance(v, np.ndarray) and v.dtype == object:
vv = _unwrap_singleton(v)
# 解包后若是二维数值矩阵
if isinstance(vv, np.ndarray) and vv.ndim == 2 and np.issubdtype(vv.dtype, np.number):
candidates.append((k, vv))
continue
# 解包后若是 struct再取 data
data2 = _try_get_struct_field(vv, "data")
if data2 is not None:
data2 = _unwrap_singleton(data2)
if isinstance(data2, np.ndarray) and data2.ndim == 2:
candidates.append((f"{k}.data", data2))
continue
if not candidates:
raise RuntimeError(f"mat里没找到可用EEG二维矩阵或struct.data{mat_path}")
# 选一个最像 EEG 的优先含32/64通道维度的
def score(arr: np.ndarray) -> int:
s = 0
if 64 in arr.shape: s += 10
if 32 in arr.shape: s += 9
if 128 in arr.shape: s += 8
if 129 in arr.shape: s += 7
s += int(np.prod(arr.shape) // 100000) # 大一些更像EEG
return s
candidates.sort(key=lambda kv: score(kv[1]), reverse=True)
key, eeg = candidates[0]
eeg = _unwrap_singleton(eeg)
# 如果还是 object dtype尽力转成 float
if isinstance(eeg, np.ndarray) and eeg.dtype == object:
# 有时 object 里其实是数值
eeg = np.array(eeg, dtype=np.float32)
else:
eeg = np.asarray(eeg, dtype=np.float32)
if eeg.ndim != 2:
raise RuntimeError(f"解析结果不是二维矩阵: key={key}, shape={eeg.shape}, file={mat_path}")
# 统一为 (T, C)
# 常见 (C,T) 或 (T,C),我们用“通道维通常较小”+ “32/64/128/129”判断
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[1] not in (32, 64, 128, 129):
eeg = eeg.T
elif eeg.shape[1] in (32, 64, 128, 129):
# 如果第一维也是这些数且更小,可能是(C,T)
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[0] < eeg.shape[1]:
eeg = eeg.T
return eeg
def ensure_32_channels(eeg: np.ndarray) -> np.ndarray:
"""
输入 (T, C),输出 (T, 32)
- 若C=64按 IDX64_TO_32 选32通道
- 若C=32直接返回
"""
if eeg.ndim != 2:
raise RuntimeError(f"EEG必须是二维(T,C),但得到: {eeg.shape}")
C = eeg.shape[1]
if C == 64:
idx = np.asarray(IDX64_TO_32, dtype=np.int64)
if idx.min() < 0 or idx.max() >= 64:
raise RuntimeError(f"IDX64_TO_32 越界min={idx.min()}, max={idx.max()} (要求0~63)")
return eeg[:, idx]
if C == 32:
return eeg
raise RuntimeError(f"不支持的通道数C={C},当前只支持 64->32 或 32 直推。")
# =========================================================
# 3) 特征提取DE + PSD(var近似)
# =========================================================
class FeatureExtractor32:
"""
只针对32通道输出维度
- USE_EXTENDED_FEATURES=TrueDE(32*5) + PSD(32*5) = 320
- 否则DE(32*5) = 160
"""
def __init__(
self,
fs: int = SAMPLING_RATE,
window_size: int = WINDOW_SIZE,
stride: int = STRIDE,
filter_order: int = 4,
zero_phase: bool = False,
) -> None:
self.fs = fs
self.window_size = window_size
self.stride = stride
self.filter_order = filter_order
self.zero_phase = zero_phase
self._sos = {}
for bn in BAND_NAMES:
low, high = BANDS[bn]
self._sos[bn] = signal.butter(
self.filter_order, [low, high],
btype="band", fs=self.fs, output="sos"
)
def _filter_bands(self, eeg: np.ndarray) -> dict[str, np.ndarray]:
out = {}
for bn in BAND_NAMES:
sos = self._sos[bn]
if self.zero_phase:
out[bn] = signal.sosfiltfilt(sos, eeg, axis=0).astype(np.float32)
else:
out[bn] = signal.sosfilt(sos, eeg, axis=0).astype(np.float32)
return out
def extract(self, eeg32: np.ndarray) -> np.ndarray:
"""
eeg32: (T, 32)
return: feats (N_slices, feat_dim)
"""
if eeg32.ndim != 2 or eeg32.shape[1] != 32:
raise RuntimeError(f"extract需要 (T,32),得到 {eeg32.shape}")
bands_data = self._filter_bands(eeg32)
T = eeg32.shape[0]
feats = []
for start in range(0, T - self.window_size, self.stride):
end = start + self.window_size
de_list = []
psd_list = []
for bn in BAND_NAMES:
seg = bands_data[bn][start:end, :] # (W, 32)
var = np.var(seg, axis=0, ddof=1) # (32,)
# DE
de = 0.5 * np.log(2 * np.pi * np.e * (var + EPS))
de_list.append(de)
if USE_EXTENDED_FEATURES:
# PSD近似log(var)
psd_list.append(np.log(var + EPS))
de_feat = np.stack(de_list, axis=0).T.reshape(-1) # (32*5,)
if USE_EXTENDED_FEATURES:
psd_feat = np.stack(psd_list, axis=0).T.reshape(-1) # (32*5,)
f = np.concatenate([de_feat, psd_feat], axis=0).astype(np.float32)
else:
f = de_feat.astype(np.float32)
feats.append(f)
if not feats:
raise RuntimeError("EEG长度不足以切片请检查T是否太短或调整WINDOW_SIZE/STRIDE")
return np.stack(feats, axis=0).astype(np.float32)
# =========================================================
# 4) 模型加载 + 推理接口
# =========================================================
def _safe_torch_load(path: str):
try:
return torch.load(path, map_location=DEVICE, weights_only=False)
except TypeError:
return torch.load(path, map_location=DEVICE)
def load_model(model_path: str) -> tuple[FusionNet, dict]:
"""
返回: (model, ckpt_dict)
"""
obj = _safe_torch_load(model_path)
if isinstance(obj, dict) and "model_state" in obj:
ckpt = obj
state = obj["model_state"]
feat_dim = int(obj.get("feat_dim", 320))
else:
ckpt = {}
state = obj
feat_dim = 320
model = FusionNet(num_classes=2, num_eeg_features=feat_dim).to(DEVICE)
model.load_state_dict(state, strict=True)
model.eval()
return model, ckpt
def predict_hc_mdd(eeg_dir: str, model_path: str) -> dict:
"""
接口:传入 EEG文件夹 和 模型路径,返回判断结果 dict
返回字段:
- mat_file: 使用的mat文件
- pred_label: "HC" or "MDD"
- p_mdd_mean: 切片p(MDD)均值
- threshold: subject判定阈值
- n_slices: 切片数
"""
mat_file = _find_first_mat_file(eeg_dir)
# 1) 读EEG (T,C)并保证变成32通道
eeg = load_eeg_from_mat_any_channels(mat_file) # (T,C)
eeg32 = ensure_32_channels(eeg) # (T,32)
# 2) 提特征 (N,feat_dim)
extractor = FeatureExtractor32(fs=SAMPLING_RATE, window_size=WINDOW_SIZE, stride=STRIDE)
feats = extractor.extract(eeg32) # (N, dim)
# 3) 加载模型
model, ckpt = load_model(model_path)
# 4) 可选归一化若ckpt里保存的mean/std维度刚好匹配
mean = ckpt.get("global_mean", None) if isinstance(ckpt, dict) else None
std = ckpt.get("global_std", None) if isinstance(ckpt, dict) else None
if mean is not None and std is not None:
mean = np.asarray(mean, dtype=np.float32)
std = np.asarray(std, dtype=np.float32)
if mean.shape[0] == feats.shape[1] and std.shape[0] == feats.shape[1]:
feats = (feats - mean) / (std + 1e-8)
# 不匹配就跳过
# 你说“先不用其他步骤或信息”,所以这里按“尽量运行”处理
# 5) 推理:对所有切片算 p(MDD)取均值做subject-level
x = torch.from_numpy(feats).to(DEVICE)
with torch.no_grad():
cls_out, _ = model(x)
prob_mdd = torch.softmax(cls_out, dim=1)[:, 1].detach().cpu().numpy()
p_mdd_mean = float(np.mean(prob_mdd))
thr = float(ckpt.get("subject_threshold", DEFAULT_SUBJECT_THRESHOLD)) if isinstance(ckpt, dict) else float(DEFAULT_SUBJECT_THRESHOLD)
pred_is_mdd = (p_mdd_mean >= thr)
pred_label = "MDD" if pred_is_mdd else "HC"
return {
"mat_file": mat_file,
"pred_label": pred_label,
"p_mdd_mean": p_mdd_mean,
"threshold": thr,
"n_slices": int(feats.shape[0]),
}
# =========================================================
# 5) CLI命令行运行入口
# =========================================================
def main():
parser = argparse.ArgumentParser(description="Infer HC/MDD from 64ch->32ch EEG mat using a .pth FusionNet model (no Asym).")
parser.add_argument("--eeg_dir", type=str, required=True, help="包含.mat EEG文件的文件夹自动读取第一个.mat")
parser.add_argument("--model_path", type=str, required=True, help="训练好的.pth模型路径")
args = parser.parse_args()
res = predict_hc_mdd(args.eeg_dir, args.model_path)
print("\n========== 推理结果 ==========")
print(f"MAT文件: {res['mat_file']}")
print(f"切片数量: {res['n_slices']}")
print(f"p(MDD)_mean: {res['p_mdd_mean']:.4f}")
print(f"阈值thr: {res['threshold']:.4f}")
print(f"预测结果: {res['pred_label']}")
print("==============================\n")
if __name__ == "__main__":
main()

Binary file not shown.

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 306 KiB

View File

@@ -0,0 +1,9 @@
中央区α/β波比值:1.2
额区α/β波比值:1.3
顶区α/β波比值:1.2
中央区θ/β波比值:3.2
顶区θ/β波比值:3.5
前额叶α波不对称性:0.3
个体化α峰值频率:8.5
前额叶θ+δ波功率:93.8
是否推荐治疗:否

Binary file not shown.

After

Width:  |  Height:  |  Size: 268 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 61 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 493 KiB

View File

@@ -0,0 +1,6 @@
numpy
scipy
matplotlib
mne
torch
scikit-learn

View File

@@ -0,0 +1,909 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
"""
run_metrics_and_figs.py
1) 自动读取 mat_dir 中排序后的第一个 .mat
2) 调用模型预测HC/MDD并写 ResultData.txt
3) 同时保存图片EEG.png / psd.png / average_topomap.png / topomaps.png
"""
import matplotlib
matplotlib.use('Agg')
import numpy as np
import os
import shutil
import scipy.io
import scipy.signal as signal
import matplotlib.pyplot as plt
import mne
from mne.preprocessing import ICA
# ==========================
# Config
# ==========================
PREPROCESS_BANDPASS = (0.8, 30.0)
PREPROCESS_NOTCH = [50, 100]
PREPROCESS_ICA_N = 0.99
PREPROCESS_ICA_SEED = 97
PREPROCESS_APPLY_AVG_REF = True
PREPROCESS_BAD_PTP_UV = 350.0 # 坏段阈值 (μV)
DEFAULT_FS = 250.0
EEG_PLOT_SECONDS = 10
PSD_FMIN, PSD_FMAX = 0.8, 45.0
EPS = 1e-12
FIXED_EEG_IDXS = [23, 47, 39, 6, 2, 21, 35, 57] # 0-based index, 按重要性排序
FIXED_EEG_LABELS = ["C5", "O1", "TP7", "FPZ", "PO6", "P4", "AF7", "AF3"]
BANDS_METRICS = {
"Delta": (1.0, 4.0),
"Theta": (4.0, 8.0),
"Alpha": (8.0, 13.0),
"Beta": (13.0, 30.0),
}
TOTAL_POWER_BAND = (1.0, 50.0)
BANDS_TOPOMAP = {
"delta": (0.8, 3.9),
"theta": (4.0, 7.9),
"alpha": (8.0, 12.9),
"beta": (13.0, 30.0),
"broad": (0.8, 30.0),
}
# ==========================
# 预处理逻辑
# ==========================
def annotate_bad_segments(raw, peak_to_peak_uv=250.0):
"""
简单坏段检测:按固定窗口计算峰峰值,超过阈值标为 bad。
"""
peak_to_peak_v = peak_to_peak_uv * 1e-6
win = int(raw.info["sfreq"] * 1.0)
step = int(raw.info["sfreq"] * 0.5)
data = raw.get_data()
n_times = data.shape[1]
onsets = []
durations = []
descriptions = []
for start in range(0, n_times - win, step):
seg = data[:, start:start + win]
ptp = np.ptp(seg, axis=1)
if np.any(ptp > peak_to_peak_v):
onsets.append(start / raw.info["sfreq"])
durations.append(win / raw.info["sfreq"])
descriptions.append("BAD_PTP")
if len(onsets) > 0:
ann = mne.Annotations(onset=onsets, duration=durations, description=descriptions)
raw.set_annotations(ann)
print(f"[INFO] Annotated bad segments: {len(onsets)} windows")
else:
print("[INFO] No bad segments detected by PTP rule")
def run_preprocess_on_raw(raw: mne.io.RawArray) -> mne.io.RawArray:
"""
核心预处理:滤波 + 平均参考 + 坏段标注 + ICA
"""
# 1) 滤波
raw.filter(PREPROCESS_BANDPASS[0], PREPROCESS_BANDPASS[1], fir_design="firwin", verbose=False)
raw.notch_filter(PREPROCESS_NOTCH, fir_design="firwin", verbose=False)
# 2) 平均参考
if PREPROCESS_APPLY_AVG_REF:
raw.set_eeg_reference("average", verbose=False)
# 3) 坏段标注
annotate_bad_segments(raw, peak_to_peak_uv=PREPROCESS_BAD_PTP_UV)
# 4) ICA
ica = ICA(
n_components=PREPROCESS_ICA_N,
random_state=PREPROCESS_ICA_SEED,
max_iter=800,
method="fastica"
)
ica.fit(raw, reject_by_annotation=True, verbose=False)
try:
eog_inds, _ = ica.find_bads_eog(raw, verbose=False)
if eog_inds:
ica.exclude.extend(eog_inds)
print(f"[INFO] ICA exclude EOG comps: {eog_inds}")
except Exception as e:
print(f"[WARN] ICA find_bads_eog skipped: {e}")
raw_clean = ica.apply(raw.copy(), verbose=False)
return raw_clean
def preprocess_mat_file(src_mat_path: str, temp_out_dir: str) -> str:
"""
读取原始mat -> 预处理 -> 保存到 temp_out_dir -> 返回新路径
"""
os.makedirs(temp_out_dir, exist_ok=True)
# 1. 读原始 mat
# 注意:这里我们只要数据部分转成 MNE Raw然后处理再存回
# 复用现有的 load_eeg_from_mat 拿到 ndarray
eeg_uV, fs, ch_names, xyz = load_eeg_from_mat(src_mat_path)
# 转 MNE (注意单位uV -> V)
if not ch_names:
ch_names = [f"CH{i+1}" for i in range(eeg_uV.shape[1])]
info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types=["eeg"] * len(ch_names))
raw = mne.io.RawArray(eeg_uV.T * 1e-6, info, verbose=False)
if xyz is not None and isinstance(xyz, np.ndarray):
# 尝试设 montage虽然对滤波不关键但尽量保留信息
try:
ch_pos = {ch_names[i]: xyz[i, :] for i in range(len(ch_names))}
montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame="head")
raw.set_montage(montage, on_missing="ignore")
except Exception:
pass
# 2. 执行预处理
print(f"[INFO] Start preprocessing: {src_mat_path}")
raw_clean = run_preprocess_on_raw(raw)
# 3. 存回 .mat (保持结构兼容,以便后续 run_all 读取)
# 这里我们需要读取原始 mat 的结构体,把 data 替换掉
try:
mat_struct = scipy.io.loadmat(src_mat_path, struct_as_record=False, squeeze_me=True)
if "eeg" in mat_struct:
eeg_obj = mat_struct["eeg"]
# 替换数据MNE (V) -> uV -> (T, C)
clean_data_uV = (raw_clean.get_data() * 1e6).T
eeg_obj.data = clean_data_uV
base_name = os.path.basename(src_mat_path)
new_path = os.path.join(temp_out_dir, base_name)
scipy.io.savemat(new_path, {"eeg": eeg_obj}, do_compression=True)
print(f"[INFO] Preprocessed file saved to: {new_path}")
return new_path
except Exception as e:
print(f"[WARN] Failed to preserve original struct structure: {e}")
# Fallback: 如果读原始结构失败,就存一个简单的 mat
clean_data_uV = (raw_clean.get_data() * 1e6).T
out_dict = {
"eeg": {
"data": clean_data_uV,
"sample_rate": fs,
"electrode_name": ch_names,
"electrode_xyz": xyz if xyz is not None else []
}
}
base_name = os.path.basename(src_mat_path)
new_path = os.path.join(temp_out_dir, base_name)
scipy.io.savemat(new_path, out_dict, do_compression=True)
print(f"[INFO] Preprocessed file saved (fallback mode) to: {new_path}")
return new_path
# ==========================
# 输出目录
# ==========================
def ensure_outdir(out_root: str) -> str:
"""
确保输出目录存在,并清空除 ResultData.txt 之外的旧文件。
不再创建 timestamp 子文件夹,直接输出到 out_root。
"""
if os.path.exists(out_root):
# 清空目录,但保留 ResultData.txt
for filename in os.listdir(out_root):
if filename == "ResultData.txt":
continue
file_path = os.path.join(out_root, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print(f"[WARN] Failed to delete {file_path}. Reason: {e}")
else:
os.makedirs(out_root, exist_ok=True)
return out_root
# ==========================
# 单位自动识别:统一到 μV
# ==========================
def _auto_scale_to_uV(data_nt_nc: np.ndarray):
data = np.asarray(data_nt_nc)
p95 = float(np.percentile(np.abs(data), 95))
if p95 <= 0.5:
data_uV = data * 1e6
msg = f"[UNIT] p95={p95:.3g} -> assume V, convert to μV by *1e6"
elif p95 > 5000:
data_uV = data * 1e-3
msg = f"[UNIT] p95={p95:.3g} -> assume nV, convert to μV by /1000"
else:
data_uV = data
msg = f"[UNIT] p95={p95:.3g} -> assume μV, no scaling"
p95_uV = float(np.percentile(np.abs(data_uV), 95))
warn = None
if p95_uV > 5000:
warn = f"[WARN] After scaling, p95 still large: {p95_uV:.3g} μV"
elif p95_uV < 0.1:
warn = f"[WARN] After scaling, p95 still small: {p95_uV:.3g} μV"
return data_uV, msg, warn
# ==========================
# mat 读取(支持 struct.data / electrode_name / electrode_xyz / sample_rate
# ==========================
def _unwrap_singleton(x):
while True:
if isinstance(x, np.ndarray):
if x.dtype == object and x.size == 1:
x = x.item()
continue
if x.size == 1 and x.ndim >= 1:
try:
x = x.reshape(-1)[0]
continue
except Exception:
pass
break
return x
def _try_get_struct_field(v, field_name="data"):
if hasattr(v, "_fieldnames") and field_name in getattr(v, "_fieldnames", []):
return getattr(v, field_name)
if isinstance(v, np.ndarray) and v.dtype.names and field_name in v.dtype.names:
try:
return v[field_name]
except Exception:
return None
return None
def _extract_electrode_names(st):
nf = _try_get_struct_field(st, "electrode_name")
if nf is None:
return None
nf = _unwrap_singleton(nf)
if isinstance(nf, (list, tuple)):
names = [str(x).strip() for x in nf]
return names if names else None
if isinstance(nf, np.ndarray):
flat = nf.reshape(-1)
names = [str(_unwrap_singleton(x)).strip() for x in flat]
return names if names else None
s = str(nf).strip()
return [s] if s else None
def _extract_sample_rate(st):
sr = _try_get_struct_field(st, "sample_rate")
if sr is None:
return None
sr = _unwrap_singleton(sr)
try:
return float(sr)
except Exception:
return None
def _extract_xyz(st):
xyz = _try_get_struct_field(st, "electrode_xyz")
if xyz is None:
return None
xyz = _unwrap_singleton(xyz)
try:
xyz = np.asarray(xyz, dtype=float)
if xyz.ndim == 2 and xyz.shape[1] == 3:
return xyz
if xyz.ndim == 2 and xyz.shape[0] == 3:
return xyz.T
return None
except Exception:
return None
def load_eeg_from_mat(mat_path: str):
mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
candidates = []
st_for_meta = None
for k, v in mat.items():
if k.startswith("__"):
continue
if isinstance(v, np.ndarray) and v.ndim == 2 and np.issubdtype(v.dtype, np.number):
candidates.append((k, v, None))
continue
data_field = _try_get_struct_field(v, "data")
if data_field is not None:
data_field = _unwrap_singleton(data_field)
if isinstance(data_field, np.ndarray) and data_field.ndim == 2:
candidates.append((f"{k}.data", data_field, v))
continue
if isinstance(v, np.ndarray) and v.dtype == object:
vv = _unwrap_singleton(v)
if isinstance(vv, np.ndarray) and vv.ndim == 2 and np.issubdtype(vv.dtype, np.number):
candidates.append((k, vv, None))
continue
data2 = _try_get_struct_field(vv, "data")
if data2 is not None:
data2 = _unwrap_singleton(data2)
if isinstance(data2, np.ndarray) and data2.ndim == 2:
candidates.append((f"{k}.data", data2, vv))
continue
if not candidates:
raise RuntimeError(f"mat 里没找到可用 EEG 二维矩阵或 struct.data{mat_path}")
def score(arr: np.ndarray) -> int:
s = 0
if 64 in arr.shape: s += 10
if 32 in arr.shape: s += 9
if 128 in arr.shape: s += 8
if 129 in arr.shape: s += 7
s += int(np.prod(arr.shape) // 100000)
return s
candidates.sort(key=lambda x: score(x[1]), reverse=True)
key, eeg, st = candidates[0]
st_for_meta = st
eeg = np.asarray(_unwrap_singleton(eeg), dtype=np.float32)
if eeg.ndim != 2:
raise RuntimeError(f"解析结果不是二维: key={key}, shape={eeg.shape}, file={mat_path}")
# 统一成 (T, C)
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[1] not in (32, 64, 128, 129):
eeg = eeg.T
elif eeg.shape[1] in (32, 64, 128, 129):
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[0] < eeg.shape[1]:
eeg = eeg.T
fs = DEFAULT_FS
ch_names = None
xyz = None
if st_for_meta is not None:
fs2 = _extract_sample_rate(st_for_meta)
if fs2 is not None and fs2 > 1:
fs = float(fs2)
ch_names = _extract_electrode_names(st_for_meta)
xyz = _extract_xyz(st_for_meta)
eeg_uV, msg, warn = _auto_scale_to_uV(eeg)
print(msg)
if warn:
print(warn)
return eeg_uV.astype(np.float32), float(fs), ch_names, xyz
# ==========================
# 预测接口:导入 predict_hc_mdd
# ==========================
def _predict_label_by_model(model_path: str, mat_dir: str) -> str:
try:
from infer_pth import predict_hc_mdd
except Exception as e:
raise RuntimeError(
"无法导入 predict_hc_mdd请确保 pre.py 或 infer_pth.py 与本文件同目录)。\n"
f"原始错误: {e}"
)
try:
out = predict_hc_mdd(mat_dir, model_path)
except TypeError:
out = predict_hc_mdd(model_path, mat_dir)
label = str(out.get("pred_label", "")).strip().upper()
if label not in ("HC", "MDD"):
raise RuntimeError(f"predict_hc_mdd 返回 pred_label 非法: {label},原始返回: {out}")
return label
# ==========================
# 通道分区
# ==========================
def _norm_name(s: str) -> str:
return str(s).strip().upper().replace(" ", "")
def build_channel_index_map(ch_names, n_channels: int):
if not ch_names or len(ch_names) != n_channels:
return {}
return {_norm_name(nm): i for i, nm in enumerate(ch_names)}
def pick_indices_by_names(name_to_idx, names):
idx = []
for n in names:
nn = _norm_name(n)
if nn in name_to_idx:
idx.append(name_to_idx[nn])
return sorted(list(set(idx)))
def _fallback_region_indices(n_channels: int):
a = int(n_channels * 0.33)
b = int(n_channels * 0.66)
frontal = list(range(0, a))
central = list(range(a, b))
parietal = list(range(b, n_channels))
prefrontal = list(range(0, max(2, a // 2)))
posterior = list(range(b, n_channels))
left = [i for i in range(n_channels) if i % 2 == 0]
right = [i for i in range(n_channels) if i % 2 == 1]
return frontal, central, parietal, prefrontal, posterior, left, right
def get_region_indices(name_to_idx, n_channels: int):
if not name_to_idx:
return _fallback_region_indices(n_channels)
central_names = ["CZ","C1","C2","C3","C4","C5","C6","CP1","CP2","CP3","CP4","CP5","CP6","FC1","FC2","FC3","FC4","FC5","FC6"]
frontal_names = ["FZ","F1","F2","F3","F4","F5","F6","F7","F8","AF3","AF4","AF7","AF8","FPZ","FP1","FP2","FCZ"]
parietal_names = ["PZ","P1","P2","P3","P4","P5","P6","POZ","PO3","PO4","PO5","PO6","PO7","PO8","CPZ"]
prefrontal_names = ["FP1","FP2","FPZ","AF3","AF4","AF7","AF8"]
posterior_names = ["O1","O2","OZ","PO7","PO8","PO3","PO4","PZ","P3","P4","P1","P2"]
central = pick_indices_by_names(name_to_idx, central_names)
frontal = pick_indices_by_names(name_to_idx, frontal_names)
parietal = pick_indices_by_names(name_to_idx, parietal_names)
prefrontal = pick_indices_by_names(name_to_idx, prefrontal_names)
posterior = pick_indices_by_names(name_to_idx, posterior_names)
left_names = ["FP1","AF3","AF7","F3","F5","F7"]
right_names = ["FP2","AF4","AF8","F4","F6","F8"]
left = pick_indices_by_names(name_to_idx, left_names)
right = pick_indices_by_names(name_to_idx, right_names)
if not (central and frontal and parietal and prefrontal and posterior):
fb = _fallback_region_indices(n_channels)
frontal2, central2, parietal2, prefrontal2, posterior2, left2, right2 = fb
frontal = frontal if frontal else frontal2
central = central if central else central2
parietal = parietal if parietal else parietal2
prefrontal = prefrontal if prefrontal else prefrontal2
posterior = posterior if posterior else posterior2
left = left if left else left2
right = right if right else right2
return frontal, central, parietal, prefrontal, posterior, left, right
# ==========================
# Welch PSD + band power
# ==========================
def welch_psd(eeg_tc: np.ndarray, fs: float):
nperseg = min(1024, eeg_tc.shape[0])
if nperseg < 128:
nperseg = min(256, eeg_tc.shape[0])
freqs, pxx = signal.welch(
eeg_tc, fs=fs, nperseg=nperseg, noverlap=nperseg // 2,
axis=0, scaling="density",
)
return freqs, pxx
def band_power_from_psd(freqs, pxx_fc, band):
lo, hi = band
m = (freqs >= lo) & (freqs < hi)
if not np.any(m):
return np.zeros((pxx_fc.shape[1],), dtype=np.float32)
# 兼容处理numpy 2.0+ 推荐使用 trapezoid旧版本用 trapz
if hasattr(np, "trapezoid"):
return np.trapezoid(pxx_fc[m, :], freqs[m], axis=0).astype(np.float32)
else:
return np.trapz(pxx_fc[m, :], freqs[m], axis=0).astype(np.float32)
def region_mean_power(freqs, pxx_fc, idx, band) -> float:
if not idx:
return 0.0
pw = band_power_from_psd(freqs, pxx_fc, band)
return float(np.mean(pw[idx]))
def compute_iaf(freqs, pxx_fc, posterior_idx):
lo, hi = BANDS_METRICS["Alpha"]
m = (freqs >= lo) & (freqs <= hi)
if not np.any(m) or not posterior_idx:
return 0.0
spec = np.mean(pxx_fc[:, posterior_idx], axis=1)
sub = spec[m]
fsub = freqs[m]
return float(fsub[int(np.argmax(sub))])
# ==========================
# 图EEG波形、PSD
# ==========================
def plot_eeg_waveforms(data_uv_tc: np.ndarray, fs: float, ch_names, out_dir: str, seconds: int = 10):
"""
固定用 FIXED_EEG_IDXS 画 EEG.png按重要性排序
data_uv_tc: (T, C) μV
"""
T, C = data_uv_tc.shape
# 1) 过滤越界索引(避免你的数据通道数不足时报错)
idxs = [i for i in FIXED_EEG_IDXS if 0 <= i < C]
if len(idxs) < len(FIXED_EEG_IDXS):
missing = [i for i in FIXED_EEG_IDXS if not (0 <= i < C)]
print(f"[WARN] Some fixed EEG indices out of range (C={C}): {missing}")
if len(idxs) == 0:
raise RuntimeError(f"No valid indices in FIXED_EEG_IDXS for current data (C={C}).")
# 2) 通道显示
picked_names = []
for idx in idxs:
# 找 idx 在 FIXED_EEG_IDXS 的位置,用对应标签
pos = FIXED_EEG_IDXS.index(idx)
std_label = FIXED_EEG_LABELS[pos] if pos < len(FIXED_EEG_LABELS) else f"CH{idx}"
if ch_names and idx < len(ch_names):
picked_names.append(f"{std_label}")
else:
picked_names.append(std_label)
# 3) 截取前 seconds 秒
max_samples = int(min(T, seconds * fs))
x = np.arange(max_samples) / fs
fig_h = 1.4 * len(idxs) + 1
fig, axes = plt.subplots(len(idxs), 1, figsize=(10, fig_h), sharex=True)
if len(idxs) == 1:
axes = [axes]
# 4) 分位数定范围,避免尖峰撑爆
seg = data_uv_tc[:max_samples, idxs].T # (n_ch, samples)
lo = float(np.percentile(seg, 1))
hi = float(np.percentile(seg, 99))
m = max(abs(lo), abs(hi))
m = max(m, 50.0)
for ax, ch_idx, nm in zip(axes, idxs, picked_names):
y = data_uv_tc[:max_samples, ch_idx]
ax.plot(x, y, linewidth=1.2)
ax.set_ylabel("μV")
ax.set_title(nm, loc="left", fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_ylim(-m, m)
axes[-1].set_xlabel("Time (s)")
plt.tight_layout()
out_path = os.path.join(out_dir, "EEG.png")
plt.savefig(out_path, dpi=200)
plt.close(fig)
print(f"[OK] EEG waveform saved: {out_path}")
def plot_psd(eeg_uV_tc, fs, ch_names, out_dir):
C = eeg_uV_tc.shape[1]
chosen_idx = []
if ch_names:
mp = {n.upper(): i for i, n in enumerate(ch_names)}
for p in ["C3","C4","CZ"]:
if p in mp:
chosen_idx.append(mp[p])
if len(chosen_idx) < 3:
stds = [(i, float(np.std(eeg_uV_tc[:, i]))) for i in range(C)]
stds.sort(key=lambda x: x[1], reverse=True)
for i, _ in stds:
if i not in chosen_idx:
chosen_idx.append(i)
if len(chosen_idx) == 3:
break
chosen_name = [ch_names[i] for i in chosen_idx]
else:
stds = [(i, float(np.std(eeg_uV_tc[:, i]))) for i in range(C)]
stds.sort(key=lambda x: x[1], reverse=True)
chosen_idx = [i for i, _ in stds[:3]]
chosen_name = [f"CH{i}" for i in chosen_idx]
fig = plt.figure(figsize=(7.5, 4.8))
for idx, nm in zip(chosen_idx, chosen_name):
f, pxx = signal.welch(eeg_uV_tc[:, idx], fs=fs, nperseg=int(2*fs), noverlap=int(1*fs))
mask = (f >= PSD_FMIN) & (f <= PSD_FMAX)
p_db = 10 * np.log10(pxx[mask] + 1e-20)
plt.plot(f[mask], p_db, linewidth=1.8, label=nm)
plt.xlabel("Hz")
plt.ylabel("Power (dB)")
plt.title("PSD")
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
out_path = os.path.join(out_dir, "psd.png")
plt.savefig(out_path, dpi=200)
plt.close(fig)
print(f"[OK] psd.png -> {out_path}")
# ==========================
# Topomap如果有 xyz
# ==========================
def build_mne_raw_from_uV(eeg_uV_tc, fs, ch_names, xyz):
C = eeg_uV_tc.shape[1]
if not ch_names:
ch_names = [f"CH{i+1}" for i in range(C)]
data_v_ct = eeg_uV_tc.T * 1e-6 # (C,T) V
info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types=["eeg"] * C)
raw = mne.io.RawArray(data_v_ct, info, verbose=False)
if xyz is not None and isinstance(xyz, np.ndarray) and xyz.shape == (C, 3):
try:
ch_pos = {ch_names[i]: xyz[i, :] for i in range(C)}
montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame="head")
raw.set_montage(montage, on_missing="ignore")
except Exception as e:
print(f"[WARN] set_montage failed (ignore): {e}")
else:
print("[WARN] electrode_xyz missing/invalid -> skip topomap")
return raw
def _raw_has_positions(raw):
try:
locs = np.array([ch["loc"][:3] for ch in raw.info["chs"]])
ok = np.isfinite(locs).all() and (np.linalg.norm(locs, axis=1) > 0).any()
return bool(ok)
except Exception:
return False
def compute_band_powers_for_topomap(raw, bands):
data = raw.get_data() # (C,T) V
fs = raw.info["sfreq"]
psds, freqs = mne.time_frequency.psd_array_welch(
data, sfreq=fs,
fmin=min(v[0] for v in bands.values()),
fmax=max(v[1] for v in bands.values()),
n_fft=int(2 * fs),
n_overlap=int(1 * fs),
average="mean",
verbose=False
)
out = {}
for k, (fmin, fmax) in bands.items():
idx = np.where((freqs >= fmin) & (freqs <= fmax))[0]
# 兼容处理numpy 2.0+ 推荐使用 trapezoid旧版本用 trapz
if hasattr(np, "trapezoid"):
bp = np.trapezoid(psds[:, idx], freqs[idx], axis=1) # (C,)
else:
bp = np.trapz(psds[:, idx], freqs[idx], axis=1) # (C,)
v = np.log10(bp + 1e-30)
v = v - np.mean(v)
out[k] = v
return out
def plot_average_topomap(raw, values, out_dir):
fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.6))
im, _ = mne.viz.plot_topomap(values, raw.info, axes=ax, show=False, contours=0,sphere=(0, 0, 0, 0.11))
ax.set_title("0.8-30 Hz", fontsize=12)
plt.colorbar(im, ax=ax, shrink=0.85)
plt.tight_layout()
out_path = os.path.join(out_dir, "average_topomap.png")
plt.savefig(out_path, dpi=200)
plt.close(fig)
print(f"[OK] average_topomap.png -> {out_path}")
def plot_band_topomaps(raw, band_values, out_dir):
order = [
("delta", "δ (0.8-3.9Hz)"),
("theta", "θ (4-7.9Hz)"),
("alpha", "α (8-12.9Hz)"),
("beta", "β (13-30Hz)"),
("broad", "0.8-30 Hz"),
]
fig, axes = plt.subplots(1, 5, figsize=(16, 4.2))
ims = []
for ax, (k, title) in zip(axes, order):
im, _ = mne.viz.plot_topomap(band_values[k], raw.info, axes=ax, show=False, contours=0,extrapolate='head',sphere=(0, 0, 0, 0.11))
ax.set_title(title, fontsize=11)
ims.append(im)
fig.subplots_adjust(left=0.02, right=0.85, top=0.88, bottom=0.05, wspace=0.35)
cax = fig.add_axes([0.87, 0.15, 0.015, 0.7])
fig.colorbar(ims[-1], cax=cax)
out_path = os.path.join(out_dir, "topomaps.png")
plt.savefig(out_path, dpi=200)
plt.close(fig)
print(f"[OK] topomaps.png -> {out_path}")
# ==========================
# 生成 ResultData.txt
# ==========================
def compute_and_save_txt(model_path, mat_dir, out_dir, eeg_uV_tc, fs, ch_names):
pred_label = _predict_label_by_model(model_path, mat_dir)
recommend = "" if pred_label == "MDD" else ""
T, C = eeg_uV_tc.shape
mp = build_channel_index_map(ch_names, C)
frontal_idx, central_idx, parietal_idx, prefrontal_idx, posterior_idx, left_idx, right_idx = \
get_region_indices(mp, C)
freqs, pxx = welch_psd(eeg_uV_tc, fs)
central_alpha = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Alpha"])
central_beta = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Beta"])
frontal_alpha = region_mean_power(freqs, pxx, frontal_idx, BANDS_METRICS["Alpha"])
frontal_beta = region_mean_power(freqs, pxx, frontal_idx, BANDS_METRICS["Beta"])
par_alpha = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Alpha"])
par_beta = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Beta"])
central_ab = (central_alpha / (central_beta + EPS)) if central_beta > 0 else 0.0
frontal_ab = (frontal_alpha / (frontal_beta + EPS)) if frontal_beta > 0 else 0.0
par_ab = (par_alpha / (par_beta + EPS)) if par_beta > 0 else 0.0
central_theta = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Theta"])
par_theta = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Theta"])
central_tb = (central_theta / (central_beta + EPS)) if central_beta > 0 else 0.0
par_tb = (par_theta / (par_beta + EPS)) if par_beta > 0 else 0.0
if not left_idx or not right_idx:
left_idx = [i for i in prefrontal_idx if (i % 2 == 0)]
right_idx = [i for i in prefrontal_idx if (i % 2 == 1)]
left_alpha = region_mean_power(freqs, pxx, left_idx, BANDS_METRICS["Alpha"])
right_alpha = region_mean_power(freqs, pxx, right_idx, BANDS_METRICS["Alpha"])
prefrontal_alpha_asym = float(np.log(right_alpha + EPS) - np.log(left_alpha + EPS))
iaf = compute_iaf(freqs, pxx, posterior_idx)
pre_td = region_mean_power(freqs, pxx, prefrontal_idx, (BANDS_METRICS["Delta"][0], BANDS_METRICS["Theta"][1]))
pre_total = region_mean_power(freqs, pxx, prefrontal_idx, TOTAL_POWER_BAND)
pre_td_rel = (pre_td / (pre_total + EPS)) * 100.0 if pre_total > 0 else 0.0
def f1(x): return f"{x:.1f}"
txt = (
f"中央区α/β波比值:{f1(central_ab)}\n"
f"额区α/β波比值:{f1(frontal_ab)}\n"
f"顶区α/β波比值:{f1(par_ab)}\n"
f"中央区θ/β波比值:{f1(central_tb)}\n"
f"顶区θ/β波比值:{f1(par_tb)}\n"
f"前额叶α波不对称性:{f1(prefrontal_alpha_asym)}\n"
f"个体化α峰值频率:{f1(iaf)}\n"
f"前额叶θ+δ波功率:{f1(pre_td_rel)}\n"
f"是否推荐治疗:{recommend}\n"
)
out_path = os.path.join(out_dir, "ResultData.txt")
with open(out_path, "w", encoding="utf-8") as f:
f.write(txt)
print(f"[OK] ResultData.txt -> {out_path}")
# ==========================
# 一个函数一次性跑完txt + 图片)
# ==========================
def run_all(model_path: str, mat_dir: str, out_root: str, seconds: int = EEG_PLOT_SECONDS):
# 1) 选第一个 mat
if not os.path.exists(mat_dir):
raise RuntimeError(f"输入目录不存在: {mat_dir}")
mats = [f for f in os.listdir(mat_dir) if f.lower().endswith(".mat")]
if not mats:
raise RuntimeError(f"mat_dir 下找不到 .mat: {mat_dir}")
mats.sort()
mat_file = os.path.join(mat_dir, mats[0])
print(f"[INFO] Found mat: {mat_file}")
# 2) 创建输出目录
out_dir = ensure_outdir(out_root)
print(f"[INFO] Output dir: {out_dir}")
# --- 总是进行预处理 (默认模式) ---
print("[INFO] Mode: Raw Data (Default). Running preprocessing...")
temp_dir = os.path.join(out_dir, "temp_preprocessed")
mat_file = preprocess_mat_file(mat_file, temp_dir)
# 更新 mat_dir 指向临时目录(为了传给 compute_and_save_txt 里的 predict 接口)
mat_dir = temp_dir
# 3) 读 EEGμV
eeg_uV_tc, fs, ch_names, xyz = load_eeg_from_mat(mat_file)
print(f"[INFO] eeg shape(T,C)={eeg_uV_tc.shape}, fs={fs}")
# 5) 画图PSD + EEG
plot_psd(eeg_uV_tc, fs, ch_names, out_dir)
plot_eeg_waveforms(eeg_uV_tc, fs, ch_names, out_dir, seconds=seconds)
# 6) topomap有 xyz 才画)
try:
raw = build_mne_raw_from_uV(eeg_uV_tc, fs, ch_names, xyz)
if _raw_has_positions(raw):
band_vals = compute_band_powers_for_topomap(raw, BANDS_TOPOMAP)
plot_average_topomap(raw, band_vals["broad"], out_dir)
plot_band_topomaps(raw, band_vals, out_dir)
else:
print("[WARN] No valid positions -> skip topomap.")
except Exception as e:
print(f"[WARN] topomap failed -> skip. reason: {e}")
# 4) 指标写 txt
compute_and_save_txt(model_path, mat_dir, out_dir, eeg_uV_tc, fs, ch_names)
print("[DONE] txt + figures generated.")
return out_dir
if __name__ == "__main__":
import multiprocessing
multiprocessing.freeze_support()
import argparse
import sys
# 1. 路径锚定:获取资源绝对路径
def get_resource_path(relative_path):
"""
获取资源的绝对路径。
策略优先在当前执行目录EXE所在目录寻找。
这适用于“绿色软件”模式即资源文件model/raw_data直接放在EXE旁边。
"""
if getattr(sys, 'frozen', False):
# PyInstaller 打包后的 EXE 所在目录
base_path = os.path.dirname(sys.executable)
else:
# 开发环境:当前脚本所在目录
base_path = os.path.dirname(os.path.abspath(__file__))
return os.path.join(base_path, relative_path)
# 设置默认路径
DEFAULT_MODEL = get_resource_path(os.path.join("model", "Model_1.pth"))
# 这里我们保持 mat_dir 和 out_root 相对于 EXE 所在目录(或当前工作目录)
if getattr(sys, 'frozen', False):
EXE_DIR = os.path.dirname(sys.executable)
else:
EXE_DIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_MAT = os.path.join(EXE_DIR, "raw_data")
DEFAULT_OUT = os.path.join(EXE_DIR, "out")
# 2. 解析命令行参数
parser = argparse.ArgumentParser(description="EEG Depression Assessment Algorithm Integration")
parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL, help="模型文件的路径 (.pth)")
parser.add_argument("--mat_dir", type=str, default=DEFAULT_MAT, help="输入文件夹路径 (包含原始EEG .mat)")
parser.add_argument("--out_root", type=str, default=DEFAULT_OUT, help="结果输出的根目录")
parser.add_argument("--seconds", type=int, default=10, help="画波形图截取的秒数")
args = parser.parse_args()
# 3. 检查关键路径
if not os.path.exists(args.mat_dir):
print(f"[WARN] 输入文件夹不存在: {args.mat_dir}")
if not os.path.exists(args.model_path):
print(f"[WARN] 模型文件不存在: {args.model_path}")
# 4. 执行主流程
print(f"[*] 运行配置:")
print(f" - Model : {args.model_path}")
print(f" - Input : {args.mat_dir}")
print(f" - Output: {args.out_root}")
print(f" - Mode : RAW (Auto Preprocess)")
run_all(args.model_path, args.mat_dir, args.out_root, seconds=args.seconds)