original push
55
.gitignore
vendored
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# data format
|
||||||
|
*.dat
|
||||||
|
*.csv
|
||||||
|
*.edf
|
||||||
|
*.event
|
||||||
|
*.edf.event
|
||||||
|
*.zip
|
||||||
|
*.xlsx
|
||||||
|
*.mat
|
||||||
|
*.json
|
||||||
|
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself
|
||||||
|
# https://github.com/github/gitignore/blob/main/Python.gitignore
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# Other common ignores
|
||||||
|
node_modules/
|
||||||
|
dist/
|
||||||
|
tmp/
|
||||||
|
temp/
|
||||||
|
|
||||||
|
# Project-specific ignores
|
||||||
|
# Ignore all directories in the root
|
||||||
|
# merge64ch_0127/
|
||||||
|
/P300_speller/braindecode/
|
||||||
|
/P300_speller/data/
|
||||||
|
/P300_speller/pyRiemann/
|
||||||
|
/P300_speller/README/
|
||||||
|
/merge64ch_new/
|
||||||
|
/merge64ch_tianjinZMQdebug/
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
55
algorithm_V0/.gitignore
vendored
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# data format
|
||||||
|
*.dat
|
||||||
|
*.csv
|
||||||
|
*.edf
|
||||||
|
*.event
|
||||||
|
*.edf.event
|
||||||
|
*.zip
|
||||||
|
*.xlsx
|
||||||
|
*.mat
|
||||||
|
*.json
|
||||||
|
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself
|
||||||
|
# https://github.com/github/gitignore/blob/main/Python.gitignore
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# Other common ignores
|
||||||
|
node_modules/
|
||||||
|
dist/
|
||||||
|
tmp/
|
||||||
|
temp/
|
||||||
|
|
||||||
|
# Project-specific ignores
|
||||||
|
# Ignore all directories in the root
|
||||||
|
# merge64ch_0127/
|
||||||
|
/P300_speller/braindecode/
|
||||||
|
/P300_speller/data/
|
||||||
|
/P300_speller/pyRiemann/
|
||||||
|
/P300_speller/README/
|
||||||
|
/merge64ch_new/
|
||||||
|
/merge64ch_tianjinZMQdebug/
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
96
algorithm_V0/algorithm_fromXjtu/build_algorithm.spec
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
# -*- mode: python ; coding: utf-8 -*-
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from PyInstaller.utils.hooks import collect_submodules, collect_data_files
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 1. 工程配置区 (Project Config)
|
||||||
|
# ========================================================
|
||||||
|
block_cipher = None
|
||||||
|
ENTRY_POINT = 'runDecoder.py'
|
||||||
|
APP_NAME = 'Depression_Decoder'
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 2. 依赖分析 (Dependency Analysis)
|
||||||
|
# ========================================================
|
||||||
|
# 收集 mne, sklearn, scipy 可能遗漏的隐藏导入
|
||||||
|
hidden_imports = [
|
||||||
|
'infer_pth', # 你的动态导入模块
|
||||||
|
'sklearn.utils._cython_blas',
|
||||||
|
'sklearn.neighbors.typedefs',
|
||||||
|
'sklearn.neighbors.quad_tree',
|
||||||
|
'sklearn.tree',
|
||||||
|
'sklearn.tree._utils',
|
||||||
|
]
|
||||||
|
# 自动收集 mne 的子模块
|
||||||
|
hidden_imports += collect_submodules('mne')
|
||||||
|
# 收集 torch 相关的隐式导入(虽然 PyInstaller 通常能处理,但显式更安全)
|
||||||
|
hidden_imports += ['torch', 'torchvision']
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 3. 资源锚定 (Data Anchoring)
|
||||||
|
# ========================================================
|
||||||
|
# Analysis 中的 datas 用于将文件嵌入到内部(运行时在临时目录或 _internal)
|
||||||
|
# 这里我们留空,改为在 COLLECT 阶段通过 Tree 显式复制到 EXE 旁,
|
||||||
|
# 这样生成的文件夹里能直接看到 model 和 raw_data
|
||||||
|
datas = []
|
||||||
|
|
||||||
|
# 收集 mne 的数据文件(如果需要默认配置)
|
||||||
|
datas += collect_data_files('mne')
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 4. 构建流程 (Build Process)
|
||||||
|
# ========================================================
|
||||||
|
a = Analysis(
|
||||||
|
[ENTRY_POINT],
|
||||||
|
pathex=[],
|
||||||
|
binaries=[],
|
||||||
|
datas=datas,
|
||||||
|
hiddenimports=hidden_imports,
|
||||||
|
hookspath=[],
|
||||||
|
hooksconfig={},
|
||||||
|
runtime_hooks=[],
|
||||||
|
excludes=['tkinter', 'PyQt5', 'PySide2', 'IPython'], # 排除 GUI 和交互式库减小体积
|
||||||
|
win_no_prefer_redirects=False,
|
||||||
|
win_private_assemblies=False,
|
||||||
|
cipher=block_cipher,
|
||||||
|
noarchive=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||||
|
|
||||||
|
exe = EXE(
|
||||||
|
pyz,
|
||||||
|
a.scripts,
|
||||||
|
[],
|
||||||
|
exclude_binaries=True,
|
||||||
|
name=APP_NAME,
|
||||||
|
debug=False,
|
||||||
|
bootloader_ignore_signals=False,
|
||||||
|
strip=False,
|
||||||
|
upx=False,
|
||||||
|
console=True,
|
||||||
|
disable_windowed_traceback=False,
|
||||||
|
argv_emulation=False,
|
||||||
|
target_arch=None,
|
||||||
|
codesign_identity=None,
|
||||||
|
entitlements_file=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 5. 打包模式: OneDir (单文件夹) + 资源旁路
|
||||||
|
# ========================================================
|
||||||
|
# 使用 Tree 将文件夹原样复制到 dist/APP_NAME/ 下
|
||||||
|
# 格式: Tree('源路径', prefix='目标子目录')
|
||||||
|
|
||||||
|
coll = COLLECT(
|
||||||
|
exe,
|
||||||
|
a.binaries,
|
||||||
|
a.zipfiles,
|
||||||
|
a.datas,
|
||||||
|
strip=False,
|
||||||
|
upx=False,
|
||||||
|
upx_exclude=[],
|
||||||
|
name=APP_NAME,
|
||||||
|
)
|
||||||
87
algorithm_V0/algorithm_fromXjtu/build_clean.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
# 确保我们在虚拟环境中运行
|
||||||
|
if not sys.prefix == sys.base_prefix:
|
||||||
|
print(f"正在使用虚拟环境: {sys.prefix}")
|
||||||
|
else:
|
||||||
|
print("警告:你似乎没有激活虚拟环境!建议在 venv_clean 下运行。")
|
||||||
|
|
||||||
|
|
||||||
|
def build():
|
||||||
|
entry_point = "runDecoder.py" # 你的入口文件
|
||||||
|
|
||||||
|
# 自动清理逻辑优化
|
||||||
|
output_dir = "dist2"
|
||||||
|
build_dir = "build2" # Nuitka 默认会在当前目录生成 .build 文件夹
|
||||||
|
|
||||||
|
if "--clean" in sys.argv:
|
||||||
|
print("清理旧构建目录...")
|
||||||
|
for folder in [output_dir, build_dir, entry_point.replace(".py", ".build")]:
|
||||||
|
if os.path.exists(folder):
|
||||||
|
shutil.rmtree(folder, ignore_errors=True)
|
||||||
|
|
||||||
|
# Nuitka 命令 - 此时非常清爽
|
||||||
|
nuitka_cmd = [
|
||||||
|
sys.executable, "-m", "nuitka",
|
||||||
|
"--standalone", # 独立运行模式
|
||||||
|
f"--output-dir={output_dir}", # 输出目录
|
||||||
|
"--show-progress", # 显示进度
|
||||||
|
"--assume-yes-for-downloads", # 自动下载依赖(如 ccache, depends 等)
|
||||||
|
|
||||||
|
# --- 插件配置 ---
|
||||||
|
"--enable-plugin=numpy",
|
||||||
|
"--enable-plugin=matplotlib",
|
||||||
|
"--enable-plugin=torch", # 处理 PyTorch 及其 CUDA 依赖
|
||||||
|
|
||||||
|
# --- 包含包/模块 (Nuitka 2.x 推荐使用 include-package-data 或 include-package) ---
|
||||||
|
# --collect-all 是 PyInstaller 的参数,Nuitka 不支持
|
||||||
|
"--include-package=sklearn",
|
||||||
|
"--include-package=scipy",
|
||||||
|
"--include-package=mne",
|
||||||
|
|
||||||
|
# 强制包含 MNE 的数据文件(配置、布局等)
|
||||||
|
"--include-package-data=mne",
|
||||||
|
"--include-package=PIL", # Pillow (matplotlib/mne 可能用到)
|
||||||
|
"--include-package=networkx", # mne 可能用到
|
||||||
|
"--include-package=decorator", # MNE 核心依赖,防止 KeyError: 'self'
|
||||||
|
"--include-package=six", # 通用兼容库
|
||||||
|
|
||||||
|
# 显式包含本地模块,防止隐式导入丢失
|
||||||
|
"--include-module=infer_pth",
|
||||||
|
|
||||||
|
# --- 数据文件 ---
|
||||||
|
# 格式: 源路径=目标路径 (相对 dist 目录)
|
||||||
|
"--include-data-dir=model=model",
|
||||||
|
"--include-data-dir=raw_data=raw_data",
|
||||||
|
|
||||||
|
# --- 排除干扰以减小体积/提高稳定性 ---
|
||||||
|
"--nofollow-import-to=pytest",
|
||||||
|
"--nofollow-import-to=unittest",
|
||||||
|
"--nofollow-import-to=pdb",
|
||||||
|
"--nofollow-import-to=tkinter", # 如果不用 GUI 界面
|
||||||
|
"--nofollow-import-to=sympy", # 除非明确用到符号计算
|
||||||
|
|
||||||
|
# --- 内存与性能 ---
|
||||||
|
"--low-memory", # 降低打包时的内存消耗
|
||||||
|
|
||||||
|
# --- Windows 特定 ---
|
||||||
|
# "--disable-console", # 如果不需要黑框,取消注释这一行
|
||||||
|
]
|
||||||
|
|
||||||
|
nuitka_cmd.append(entry_point)
|
||||||
|
|
||||||
|
print("开始打包...")
|
||||||
|
try:
|
||||||
|
subprocess.check_call(nuitka_cmd)
|
||||||
|
print("\n打包成功!")
|
||||||
|
print(f"请在 dist2/runDecoder.dist 目录下运行 exe 进行测试。")
|
||||||
|
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"打包失败,错误码: {e.returncode}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
build()
|
||||||
77
algorithm_V0/algorithm_fromXjtu/build_with_copy.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 1. 定义路径
|
||||||
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
DIST_DIR = os.path.join(BASE_DIR, 'dist')
|
||||||
|
APP_NAME = 'Depression_Decoder'
|
||||||
|
TARGET_DIR = os.path.join(DIST_DIR, APP_NAME)
|
||||||
|
|
||||||
|
MODEL_SRC = os.path.join(BASE_DIR, 'model')
|
||||||
|
RAW_DATA_SRC = os.path.join(BASE_DIR, 'raw_data')
|
||||||
|
|
||||||
|
MODEL_DST = os.path.join(TARGET_DIR, 'model')
|
||||||
|
RAW_DATA_DST = os.path.join(TARGET_DIR, 'raw_data')
|
||||||
|
|
||||||
|
# 2. 清理旧构建
|
||||||
|
print("[1/3] Cleaning up old builds...")
|
||||||
|
if os.path.exists(DIST_DIR):
|
||||||
|
try:
|
||||||
|
shutil.rmtree(DIST_DIR)
|
||||||
|
print(" Cleaned dist/")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not clean dist/: {e}")
|
||||||
|
|
||||||
|
BUILD_DIR = os.path.join(BASE_DIR, 'build')
|
||||||
|
if os.path.exists(BUILD_DIR):
|
||||||
|
try:
|
||||||
|
shutil.rmtree(BUILD_DIR)
|
||||||
|
print(" Cleaned build/")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not clean build/: {e}")
|
||||||
|
|
||||||
|
# 3. 运行 PyInstaller
|
||||||
|
print("[2/3] Running PyInstaller...")
|
||||||
|
# 注意:我们这里不传 --noupx,因为已经在 spec 文件里把 upx=False 写死了
|
||||||
|
cmd = [
|
||||||
|
"pyinstaller",
|
||||||
|
"build_algorithm.spec",
|
||||||
|
"--clean"
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.check_call(cmd, shell=True)
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
print("Error: PyInstaller failed.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 4. 复制外部资源文件夹
|
||||||
|
print("[3/3] Copying external resources...")
|
||||||
|
|
||||||
|
# 复制 model 文件夹
|
||||||
|
if os.path.exists(MODEL_SRC):
|
||||||
|
if os.path.exists(MODEL_DST):
|
||||||
|
shutil.rmtree(MODEL_DST)
|
||||||
|
shutil.copytree(MODEL_SRC, MODEL_DST)
|
||||||
|
print(f" Copied: model -> {MODEL_DST}")
|
||||||
|
else:
|
||||||
|
print(f" Warning: Source model dir not found at {MODEL_SRC}")
|
||||||
|
|
||||||
|
# 复制 raw_data 文件夹
|
||||||
|
if os.path.exists(RAW_DATA_SRC):
|
||||||
|
if os.path.exists(RAW_DATA_DST):
|
||||||
|
shutil.rmtree(RAW_DATA_DST)
|
||||||
|
shutil.copytree(RAW_DATA_SRC, RAW_DATA_DST)
|
||||||
|
print(f" Copied: raw_data -> {RAW_DATA_DST}")
|
||||||
|
else:
|
||||||
|
print(f" Warning: Source raw_data dir not found at {RAW_DATA_SRC}")
|
||||||
|
|
||||||
|
print("\n" + "="*50)
|
||||||
|
print(f"SUCCESS! Build artifacts are in: {TARGET_DIR}")
|
||||||
|
print("="*50)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
38
algorithm_V0/algorithm_fromXjtu/diagnose_scipy.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import scipy
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
print(f"Python executable: {sys.executable}")
|
||||||
|
print(f"Scipy version: {scipy.__version__}")
|
||||||
|
print(f"Scipy path: {scipy.__file__}")
|
||||||
|
scipy_dir = os.path.dirname(scipy.__file__)
|
||||||
|
parent_dir = os.path.dirname(scipy_dir)
|
||||||
|
scipy_libs = os.path.join(parent_dir, "scipy.libs")
|
||||||
|
print(f"Checking for scipy.libs at: {scipy_libs}")
|
||||||
|
if os.path.exists(scipy_libs):
|
||||||
|
print("scipy.libs FOUND.")
|
||||||
|
for root, dirs, files in os.walk(scipy_libs):
|
||||||
|
for f in files:
|
||||||
|
print(f" - {f}")
|
||||||
|
else:
|
||||||
|
print("scipy.libs NOT FOUND.")
|
||||||
|
|
||||||
|
print("-" * 20)
|
||||||
|
print(f"Numpy version: {numpy.__version__}")
|
||||||
|
print(f"Numpy path: {numpy.__file__}")
|
||||||
|
numpy_dir = os.path.dirname(numpy.__file__)
|
||||||
|
numpy_libs = os.path.join(numpy_dir, ".libs") # numpy 往往在内部
|
||||||
|
if not os.path.exists(numpy_libs):
|
||||||
|
# try parent
|
||||||
|
numpy_libs = os.path.join(os.path.dirname(numpy_dir), "numpy.libs")
|
||||||
|
|
||||||
|
print(f"Checking for numpy libs at: {numpy_libs}")
|
||||||
|
if os.path.exists(numpy_libs):
|
||||||
|
print("numpy libs FOUND.")
|
||||||
|
for root, dirs, files in os.walk(numpy_libs):
|
||||||
|
for f in files:
|
||||||
|
print(f" - {f}")
|
||||||
|
else:
|
||||||
|
print("numpy libs NOT FOUND.")
|
||||||
|
|
||||||
561
algorithm_V0/algorithm_fromXjtu/infer_pth.py
Normal file
@@ -0,0 +1,561 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
infer_pth.py
|
||||||
|
|
||||||
|
用途:
|
||||||
|
- 从一个文件夹中自动读取第一个 .mat EEG 文件(64通道或32通道)
|
||||||
|
- 若为64通道,则按 idx64_to_32 映射选出32通道
|
||||||
|
- 提取切片特征(DE + PSD(var近似),不含Asym)
|
||||||
|
- 加载你训练好的 .pth 模型(FusionNet结构)
|
||||||
|
- 输出该受试者的 HC / MDD 判断结果
|
||||||
|
|
||||||
|
运行方式(命令行):
|
||||||
|
python infer_pth_from64_to32.py --eeg_dir "D:\\xxx\\folder" --model_path "C:\\xxx\\model.pth"
|
||||||
|
|
||||||
|
也可在其他py里import:
|
||||||
|
from infer_pth_from64_to32 import predict_hc_mdd
|
||||||
|
res = predict_hc_mdd(eeg_dir, model_path)
|
||||||
|
print(res)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import scipy.io
|
||||||
|
import scipy.signal as signal
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================
|
||||||
|
# 0) 配置区(按需改这里)
|
||||||
|
# =========================================================
|
||||||
|
|
||||||
|
# 采样率(必须与训练时一致)
|
||||||
|
SAMPLING_RATE = 250
|
||||||
|
|
||||||
|
# 滑窗参数(必须与训练时一致)
|
||||||
|
WINDOW_SIZE = 500
|
||||||
|
STRIDE = 250
|
||||||
|
|
||||||
|
# 频段(必须与训练时一致)
|
||||||
|
BAND_NAMES = ["Delta", "Theta", "Alpha", "Beta", "Gamma"]
|
||||||
|
BANDS = {
|
||||||
|
"Delta": (1, 4),
|
||||||
|
"Theta": (4, 8),
|
||||||
|
"Alpha": (8, 13),
|
||||||
|
"Beta": (13, 30),
|
||||||
|
"Gamma": (30, 50),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 是否使用扩展特征(DE+PSD)
|
||||||
|
USE_EXTENDED_FEATURES = True
|
||||||
|
|
||||||
|
# 数值稳定项
|
||||||
|
EPS = 1e-12
|
||||||
|
|
||||||
|
# 设备
|
||||||
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# 通道映射
|
||||||
|
IDX64_TO_32 = [
|
||||||
|
23, # C5
|
||||||
|
47, # O1
|
||||||
|
39, # TP7
|
||||||
|
6, # FPZ
|
||||||
|
2, # PO6
|
||||||
|
21, # P4
|
||||||
|
35, # AF7
|
||||||
|
57, # AF3
|
||||||
|
1, # FP2
|
||||||
|
37, # T7
|
||||||
|
63, # F1
|
||||||
|
36, # A1
|
||||||
|
18, # FC4
|
||||||
|
31, # FC5
|
||||||
|
14, # FC2
|
||||||
|
48, # T8
|
||||||
|
60, # P2
|
||||||
|
41, # AF8
|
||||||
|
11, # CP1
|
||||||
|
0, # FP1
|
||||||
|
55, # PO7
|
||||||
|
59, # C1
|
||||||
|
22, # F5
|
||||||
|
10, # CP2
|
||||||
|
16, # C3
|
||||||
|
61, # P1
|
||||||
|
27, # CP5
|
||||||
|
17, # C4
|
||||||
|
26, # CP6
|
||||||
|
62, # F2
|
||||||
|
3, # POZ
|
||||||
|
13, # PO5
|
||||||
|
]
|
||||||
|
|
||||||
|
# 推理阈值:如果模型checkpoint里有 subject_threshold,会优先用它;否则用这个
|
||||||
|
DEFAULT_SUBJECT_THRESHOLD = 0.5
|
||||||
|
|
||||||
|
# =========================================================
|
||||||
|
# 1) 模型结构
|
||||||
|
# =========================================================
|
||||||
|
|
||||||
|
class SEBlock(nn.Module):
|
||||||
|
def __init__(self, channels: int, reduction: int = 4) -> None:
|
||||||
|
super().__init__()
|
||||||
|
hidden = max(1, channels // reduction)
|
||||||
|
self.fc1 = nn.Linear(channels, hidden)
|
||||||
|
self.fc2 = nn.Linear(hidden, channels)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
se = F.relu(self.fc1(x))
|
||||||
|
se = torch.sigmoid(self.fc2(se))
|
||||||
|
return x * se
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, in_features: int, out_features: int, dropout: float = 0.3) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.fc1 = nn.Linear(in_features, out_features)
|
||||||
|
self.bn1 = nn.BatchNorm1d(out_features)
|
||||||
|
self.fc2 = nn.Linear(out_features, out_features)
|
||||||
|
self.bn2 = nn.BatchNorm1d(out_features)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.shortcut = nn.Identity()
|
||||||
|
if in_features != out_features:
|
||||||
|
self.shortcut = nn.Sequential(
|
||||||
|
nn.Linear(in_features, out_features),
|
||||||
|
nn.BatchNorm1d(out_features),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
identity = self.shortcut(x)
|
||||||
|
out = F.relu(self.bn1(self.fc1(x)))
|
||||||
|
out = self.dropout(out)
|
||||||
|
out = self.bn2(self.fc2(out))
|
||||||
|
out = F.relu(out + identity)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class FusionNet(nn.Module):
|
||||||
|
def __init__(self, num_classes: int = 2, num_eeg_features: int = 320, num_scales: int = 6) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.input_norm = nn.BatchNorm1d(num_eeg_features)
|
||||||
|
|
||||||
|
self.block1 = ResidualBlock(num_eeg_features, 512, dropout=0.4)
|
||||||
|
self.block2 = ResidualBlock(512, 256, dropout=0.3)
|
||||||
|
self.block3 = ResidualBlock(256, 128, dropout=0.2)
|
||||||
|
|
||||||
|
self.attention = SEBlock(128, reduction=4)
|
||||||
|
|
||||||
|
self.final_fc = nn.Sequential(
|
||||||
|
nn.Linear(128, 64),
|
||||||
|
nn.BatchNorm1d(64),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cls_head = nn.Linear(64, num_classes)
|
||||||
|
|
||||||
|
# 训练时有回归头也没关系(推理只用cls)
|
||||||
|
self.reg_head = nn.Sequential(
|
||||||
|
nn.Linear(64, 32),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
nn.Linear(32, num_scales),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._init_weights()
|
||||||
|
|
||||||
|
def _init_weights(self) -> None:
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.BatchNorm1d):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x = self.input_norm(x)
|
||||||
|
x = self.block1(x)
|
||||||
|
x = self.block2(x)
|
||||||
|
x = self.block3(x)
|
||||||
|
x = self.attention(x)
|
||||||
|
features = self.final_fc(x)
|
||||||
|
cls_out = self.cls_head(features)
|
||||||
|
reg_out = self.reg_head(features)
|
||||||
|
return cls_out, reg_out
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================
|
||||||
|
# 2) mat读取 + 通道裁剪
|
||||||
|
# =========================================================
|
||||||
|
|
||||||
|
def _find_first_mat_file(folder: str) -> str:
|
||||||
|
if not os.path.isdir(folder):
|
||||||
|
raise RuntimeError(f"eeg_dir 不是文件夹: {folder}")
|
||||||
|
mats = sorted([f for f in os.listdir(folder) if f.lower().endswith(".mat")])
|
||||||
|
if not mats:
|
||||||
|
raise RuntimeError(f"文件夹内没有 .mat 文件: {folder}")
|
||||||
|
return os.path.join(folder, mats[0])
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import scipy.io
|
||||||
|
|
||||||
|
def _unwrap_singleton(x):
|
||||||
|
"""
|
||||||
|
把 (1,1) / (1,) 这种包裹层一直剥掉,直到不是 singleton。
|
||||||
|
也处理 object array 的情况。
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
if x.dtype == object and x.size == 1:
|
||||||
|
x = x.item()
|
||||||
|
continue
|
||||||
|
if x.size == 1 and x.ndim >= 1:
|
||||||
|
# 例如 (1,1) 或 (1,) 的数值/对象数组
|
||||||
|
try:
|
||||||
|
x = x.reshape(-1)[0]
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
break
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _try_get_struct_field(v, field_name="data"):
|
||||||
|
"""
|
||||||
|
尝试从以下几种结构中提取字段:
|
||||||
|
1) scipy 读出的 mat_struct(有 _fieldnames)
|
||||||
|
2) numpy structured/record array(dtype.names)
|
||||||
|
"""
|
||||||
|
# case 1: mat_struct(推荐 loadmat(..., struct_as_record=False, squeeze_me=True))
|
||||||
|
if hasattr(v, "_fieldnames") and (field_name in getattr(v, "_fieldnames", [])):
|
||||||
|
return getattr(v, field_name)
|
||||||
|
|
||||||
|
# case 2: structured array
|
||||||
|
if isinstance(v, np.ndarray) and v.dtype.names and (field_name in v.dtype.names):
|
||||||
|
# 常见是 v[field] 仍然是 ndarray / object,需要 unwrap
|
||||||
|
try:
|
||||||
|
return v[field_name]
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_eeg_from_mat_any_channels(mat_path: str) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
读取 .mat 中 EEG 数据,支持:
|
||||||
|
- 直接二维矩阵 (T,C) 或 (C,T)
|
||||||
|
- struct 里有字段 data
|
||||||
|
返回统一为 float32 的 (T, C)
|
||||||
|
"""
|
||||||
|
# 用这两个参数会让 struct 更容易处理:字段变成属性,且自动 squeeze
|
||||||
|
mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
|
||||||
|
|
||||||
|
candidates = []
|
||||||
|
|
||||||
|
for k, v in mat.items():
|
||||||
|
if k.startswith("__"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- 1) 直接二维数值矩阵 ---
|
||||||
|
if isinstance(v, np.ndarray) and v.ndim == 2 and np.issubdtype(v.dtype, np.number):
|
||||||
|
candidates.append((k, v))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- 2) struct/record:优先提取 data 字段 ---
|
||||||
|
data_field = _try_get_struct_field(v, "data")
|
||||||
|
if data_field is not None:
|
||||||
|
data_field = _unwrap_singleton(data_field)
|
||||||
|
|
||||||
|
# data_field 可能仍然被 object 包一层
|
||||||
|
if isinstance(data_field, np.ndarray) and data_field.dtype == object:
|
||||||
|
data_field = _unwrap_singleton(data_field)
|
||||||
|
|
||||||
|
if isinstance(data_field, np.ndarray) and data_field.ndim == 2:
|
||||||
|
# 只收数值矩阵
|
||||||
|
if np.issubdtype(data_field.dtype, np.number) or data_field.dtype == object:
|
||||||
|
candidates.append((f"{k}.data", data_field))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- 3) object array:尝试 item() 解包后再看是不是二维数值矩阵/struct ---
|
||||||
|
if isinstance(v, np.ndarray) and v.dtype == object:
|
||||||
|
vv = _unwrap_singleton(v)
|
||||||
|
|
||||||
|
# 解包后若是二维数值矩阵
|
||||||
|
if isinstance(vv, np.ndarray) and vv.ndim == 2 and np.issubdtype(vv.dtype, np.number):
|
||||||
|
candidates.append((k, vv))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 解包后若是 struct,再取 data
|
||||||
|
data2 = _try_get_struct_field(vv, "data")
|
||||||
|
if data2 is not None:
|
||||||
|
data2 = _unwrap_singleton(data2)
|
||||||
|
if isinstance(data2, np.ndarray) and data2.ndim == 2:
|
||||||
|
candidates.append((f"{k}.data", data2))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
raise RuntimeError(f"mat里没找到可用EEG二维矩阵或struct.data:{mat_path}")
|
||||||
|
|
||||||
|
# 选一个最像 EEG 的(优先含32/64通道维度的)
|
||||||
|
def score(arr: np.ndarray) -> int:
|
||||||
|
s = 0
|
||||||
|
if 64 in arr.shape: s += 10
|
||||||
|
if 32 in arr.shape: s += 9
|
||||||
|
if 128 in arr.shape: s += 8
|
||||||
|
if 129 in arr.shape: s += 7
|
||||||
|
s += int(np.prod(arr.shape) // 100000) # 大一些更像EEG
|
||||||
|
return s
|
||||||
|
|
||||||
|
candidates.sort(key=lambda kv: score(kv[1]), reverse=True)
|
||||||
|
key, eeg = candidates[0]
|
||||||
|
|
||||||
|
eeg = _unwrap_singleton(eeg)
|
||||||
|
|
||||||
|
# 如果还是 object dtype,尽力转成 float
|
||||||
|
if isinstance(eeg, np.ndarray) and eeg.dtype == object:
|
||||||
|
# 有时 object 里其实是数值
|
||||||
|
eeg = np.array(eeg, dtype=np.float32)
|
||||||
|
else:
|
||||||
|
eeg = np.asarray(eeg, dtype=np.float32)
|
||||||
|
|
||||||
|
if eeg.ndim != 2:
|
||||||
|
raise RuntimeError(f"解析结果不是二维矩阵: key={key}, shape={eeg.shape}, file={mat_path}")
|
||||||
|
|
||||||
|
# 统一为 (T, C)
|
||||||
|
# 常见 (C,T) 或 (T,C),我们用“通道维通常较小”+ “32/64/128/129”判断
|
||||||
|
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[1] not in (32, 64, 128, 129):
|
||||||
|
eeg = eeg.T
|
||||||
|
elif eeg.shape[1] in (32, 64, 128, 129):
|
||||||
|
# 如果第一维也是这些数且更小,可能是(C,T)
|
||||||
|
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[0] < eeg.shape[1]:
|
||||||
|
eeg = eeg.T
|
||||||
|
|
||||||
|
return eeg
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_32_channels(eeg: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
输入 (T, C),输出 (T, 32)
|
||||||
|
- 若C=64:按 IDX64_TO_32 选32通道
|
||||||
|
- 若C=32:直接返回
|
||||||
|
"""
|
||||||
|
if eeg.ndim != 2:
|
||||||
|
raise RuntimeError(f"EEG必须是二维(T,C),但得到: {eeg.shape}")
|
||||||
|
|
||||||
|
C = eeg.shape[1]
|
||||||
|
if C == 64:
|
||||||
|
idx = np.asarray(IDX64_TO_32, dtype=np.int64)
|
||||||
|
if idx.min() < 0 or idx.max() >= 64:
|
||||||
|
raise RuntimeError(f"IDX64_TO_32 越界:min={idx.min()}, max={idx.max()} (要求0~63)")
|
||||||
|
return eeg[:, idx]
|
||||||
|
if C == 32:
|
||||||
|
return eeg
|
||||||
|
raise RuntimeError(f"不支持的通道数C={C},当前只支持 64->32 或 32 直推。")
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================
|
||||||
|
# 3) 特征提取(DE + PSD(var近似))
|
||||||
|
# =========================================================
|
||||||
|
|
||||||
|
class FeatureExtractor32:
|
||||||
|
"""
|
||||||
|
只针对32通道,输出维度:
|
||||||
|
- USE_EXTENDED_FEATURES=True:DE(32*5) + PSD(32*5) = 320
|
||||||
|
- 否则:DE(32*5) = 160
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fs: int = SAMPLING_RATE,
|
||||||
|
window_size: int = WINDOW_SIZE,
|
||||||
|
stride: int = STRIDE,
|
||||||
|
filter_order: int = 4,
|
||||||
|
zero_phase: bool = False,
|
||||||
|
) -> None:
|
||||||
|
self.fs = fs
|
||||||
|
self.window_size = window_size
|
||||||
|
self.stride = stride
|
||||||
|
self.filter_order = filter_order
|
||||||
|
self.zero_phase = zero_phase
|
||||||
|
|
||||||
|
self._sos = {}
|
||||||
|
for bn in BAND_NAMES:
|
||||||
|
low, high = BANDS[bn]
|
||||||
|
self._sos[bn] = signal.butter(
|
||||||
|
self.filter_order, [low, high],
|
||||||
|
btype="band", fs=self.fs, output="sos"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _filter_bands(self, eeg: np.ndarray) -> dict[str, np.ndarray]:
|
||||||
|
out = {}
|
||||||
|
for bn in BAND_NAMES:
|
||||||
|
sos = self._sos[bn]
|
||||||
|
if self.zero_phase:
|
||||||
|
out[bn] = signal.sosfiltfilt(sos, eeg, axis=0).astype(np.float32)
|
||||||
|
else:
|
||||||
|
out[bn] = signal.sosfilt(sos, eeg, axis=0).astype(np.float32)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def extract(self, eeg32: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
eeg32: (T, 32)
|
||||||
|
return: feats (N_slices, feat_dim)
|
||||||
|
"""
|
||||||
|
if eeg32.ndim != 2 or eeg32.shape[1] != 32:
|
||||||
|
raise RuntimeError(f"extract需要 (T,32),得到 {eeg32.shape}")
|
||||||
|
|
||||||
|
bands_data = self._filter_bands(eeg32)
|
||||||
|
T = eeg32.shape[0]
|
||||||
|
|
||||||
|
feats = []
|
||||||
|
for start in range(0, T - self.window_size, self.stride):
|
||||||
|
end = start + self.window_size
|
||||||
|
|
||||||
|
de_list = []
|
||||||
|
psd_list = []
|
||||||
|
|
||||||
|
for bn in BAND_NAMES:
|
||||||
|
seg = bands_data[bn][start:end, :] # (W, 32)
|
||||||
|
var = np.var(seg, axis=0, ddof=1) # (32,)
|
||||||
|
|
||||||
|
# DE
|
||||||
|
de = 0.5 * np.log(2 * np.pi * np.e * (var + EPS))
|
||||||
|
de_list.append(de)
|
||||||
|
|
||||||
|
if USE_EXTENDED_FEATURES:
|
||||||
|
# PSD近似:log(var)
|
||||||
|
psd_list.append(np.log(var + EPS))
|
||||||
|
|
||||||
|
de_feat = np.stack(de_list, axis=0).T.reshape(-1) # (32*5,)
|
||||||
|
if USE_EXTENDED_FEATURES:
|
||||||
|
psd_feat = np.stack(psd_list, axis=0).T.reshape(-1) # (32*5,)
|
||||||
|
f = np.concatenate([de_feat, psd_feat], axis=0).astype(np.float32)
|
||||||
|
else:
|
||||||
|
f = de_feat.astype(np.float32)
|
||||||
|
|
||||||
|
feats.append(f)
|
||||||
|
|
||||||
|
if not feats:
|
||||||
|
raise RuntimeError("EEG长度不足以切片(请检查T是否太短,或调整WINDOW_SIZE/STRIDE)")
|
||||||
|
|
||||||
|
return np.stack(feats, axis=0).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================
|
||||||
|
# 4) 模型加载 + 推理接口
|
||||||
|
# =========================================================
|
||||||
|
|
||||||
|
def _safe_torch_load(path: str):
|
||||||
|
try:
|
||||||
|
return torch.load(path, map_location=DEVICE, weights_only=False)
|
||||||
|
except TypeError:
|
||||||
|
return torch.load(path, map_location=DEVICE)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_path: str) -> tuple[FusionNet, dict]:
|
||||||
|
"""
|
||||||
|
返回: (model, ckpt_dict)
|
||||||
|
"""
|
||||||
|
obj = _safe_torch_load(model_path)
|
||||||
|
if isinstance(obj, dict) and "model_state" in obj:
|
||||||
|
ckpt = obj
|
||||||
|
state = obj["model_state"]
|
||||||
|
feat_dim = int(obj.get("feat_dim", 320))
|
||||||
|
else:
|
||||||
|
ckpt = {}
|
||||||
|
state = obj
|
||||||
|
feat_dim = 320
|
||||||
|
|
||||||
|
model = FusionNet(num_classes=2, num_eeg_features=feat_dim).to(DEVICE)
|
||||||
|
model.load_state_dict(state, strict=True)
|
||||||
|
model.eval()
|
||||||
|
return model, ckpt
|
||||||
|
|
||||||
|
|
||||||
|
def predict_hc_mdd(eeg_dir: str, model_path: str) -> dict:
|
||||||
|
"""
|
||||||
|
接口:传入 EEG文件夹 和 模型路径,返回判断结果 dict
|
||||||
|
|
||||||
|
返回字段:
|
||||||
|
- mat_file: 使用的mat文件
|
||||||
|
- pred_label: "HC" or "MDD"
|
||||||
|
- p_mdd_mean: 切片p(MDD)均值
|
||||||
|
- threshold: subject判定阈值
|
||||||
|
- n_slices: 切片数
|
||||||
|
"""
|
||||||
|
mat_file = _find_first_mat_file(eeg_dir)
|
||||||
|
|
||||||
|
# 1) 读EEG (T,C),并保证变成32通道
|
||||||
|
eeg = load_eeg_from_mat_any_channels(mat_file) # (T,C)
|
||||||
|
eeg32 = ensure_32_channels(eeg) # (T,32)
|
||||||
|
|
||||||
|
# 2) 提特征 (N,feat_dim)
|
||||||
|
extractor = FeatureExtractor32(fs=SAMPLING_RATE, window_size=WINDOW_SIZE, stride=STRIDE)
|
||||||
|
feats = extractor.extract(eeg32) # (N, dim)
|
||||||
|
|
||||||
|
# 3) 加载模型
|
||||||
|
model, ckpt = load_model(model_path)
|
||||||
|
|
||||||
|
# 4) 可选:归一化(若ckpt里保存的mean/std维度刚好匹配)
|
||||||
|
mean = ckpt.get("global_mean", None) if isinstance(ckpt, dict) else None
|
||||||
|
std = ckpt.get("global_std", None) if isinstance(ckpt, dict) else None
|
||||||
|
|
||||||
|
if mean is not None and std is not None:
|
||||||
|
mean = np.asarray(mean, dtype=np.float32)
|
||||||
|
std = np.asarray(std, dtype=np.float32)
|
||||||
|
if mean.shape[0] == feats.shape[1] and std.shape[0] == feats.shape[1]:
|
||||||
|
feats = (feats - mean) / (std + 1e-8)
|
||||||
|
# 不匹配就跳过
|
||||||
|
# 你说“先不用其他步骤或信息”,所以这里按“尽量运行”处理
|
||||||
|
|
||||||
|
# 5) 推理:对所有切片算 p(MDD),取均值做subject-level
|
||||||
|
x = torch.from_numpy(feats).to(DEVICE)
|
||||||
|
with torch.no_grad():
|
||||||
|
cls_out, _ = model(x)
|
||||||
|
prob_mdd = torch.softmax(cls_out, dim=1)[:, 1].detach().cpu().numpy()
|
||||||
|
|
||||||
|
p_mdd_mean = float(np.mean(prob_mdd))
|
||||||
|
thr = float(ckpt.get("subject_threshold", DEFAULT_SUBJECT_THRESHOLD)) if isinstance(ckpt, dict) else float(DEFAULT_SUBJECT_THRESHOLD)
|
||||||
|
pred_is_mdd = (p_mdd_mean >= thr)
|
||||||
|
pred_label = "MDD" if pred_is_mdd else "HC"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"mat_file": mat_file,
|
||||||
|
"pred_label": pred_label,
|
||||||
|
"p_mdd_mean": p_mdd_mean,
|
||||||
|
"threshold": thr,
|
||||||
|
"n_slices": int(feats.shape[0]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================
|
||||||
|
# 5) CLI:命令行运行入口
|
||||||
|
# =========================================================
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Infer HC/MDD from 64ch->32ch EEG mat using a .pth FusionNet model (no Asym).")
|
||||||
|
parser.add_argument("--eeg_dir", type=str, required=True, help="包含.mat EEG文件的文件夹(自动读取第一个.mat)")
|
||||||
|
parser.add_argument("--model_path", type=str, required=True, help="训练好的.pth模型路径")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
res = predict_hc_mdd(args.eeg_dir, args.model_path)
|
||||||
|
print("\n========== 推理结果 ==========")
|
||||||
|
print(f"MAT文件: {res['mat_file']}")
|
||||||
|
print(f"切片数量: {res['n_slices']}")
|
||||||
|
print(f"p(MDD)_mean: {res['p_mdd_mean']:.4f}")
|
||||||
|
print(f"阈值thr: {res['threshold']:.4f}")
|
||||||
|
print(f"预测结果: {res['pred_label']}")
|
||||||
|
print("==============================\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
BIN
algorithm_V0/algorithm_fromXjtu/model/Model_0.pth
Normal file
BIN
algorithm_V0/algorithm_fromXjtu/model/Model_1.pth
Normal file
BIN
algorithm_V0/algorithm_fromXjtu/out/EEG.png
Normal file
|
After Width: | Height: | Size: 306 KiB |
9
algorithm_V0/algorithm_fromXjtu/out/ResultData.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
中央区α/β波比值:1.2
|
||||||
|
额区α/β波比值:1.3
|
||||||
|
顶区α/β波比值:1.2
|
||||||
|
中央区θ/β波比值:3.2
|
||||||
|
顶区θ/β波比值:3.5
|
||||||
|
前额叶α波不对称性:0.3
|
||||||
|
个体化α峰值频率:8.5
|
||||||
|
前额叶θ+δ波功率:93.8
|
||||||
|
是否推荐治疗:否
|
||||||
BIN
algorithm_V0/algorithm_fromXjtu/out/average_topomap.png
Normal file
|
After Width: | Height: | Size: 268 KiB |
BIN
algorithm_V0/algorithm_fromXjtu/out/psd.png
Normal file
|
After Width: | Height: | Size: 61 KiB |
BIN
algorithm_V0/algorithm_fromXjtu/out/topomaps.png
Normal file
|
After Width: | Height: | Size: 493 KiB |
6
algorithm_V0/algorithm_fromXjtu/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
numpy
|
||||||
|
scipy
|
||||||
|
matplotlib
|
||||||
|
mne
|
||||||
|
torch
|
||||||
|
scikit-learn
|
||||||
909
algorithm_V0/algorithm_fromXjtu/runDecoder.py
Normal file
@@ -0,0 +1,909 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from __future__ import annotations
|
||||||
|
"""
|
||||||
|
run_metrics_and_figs.py
|
||||||
|
|
||||||
|
1) 自动读取 mat_dir 中排序后的第一个 .mat
|
||||||
|
2) 调用模型预测(HC/MDD)并写 ResultData.txt
|
||||||
|
3) 同时保存图片:EEG.png / psd.png / average_topomap.png / topomaps.png
|
||||||
|
|
||||||
|
"""
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import scipy.io
|
||||||
|
import scipy.signal as signal
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import mne
|
||||||
|
from mne.preprocessing import ICA
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# Config
|
||||||
|
# ==========================
|
||||||
|
PREPROCESS_BANDPASS = (0.8, 30.0)
|
||||||
|
PREPROCESS_NOTCH = [50, 100]
|
||||||
|
PREPROCESS_ICA_N = 0.99
|
||||||
|
PREPROCESS_ICA_SEED = 97
|
||||||
|
PREPROCESS_APPLY_AVG_REF = True
|
||||||
|
PREPROCESS_BAD_PTP_UV = 350.0 # 坏段阈值 (μV)
|
||||||
|
|
||||||
|
DEFAULT_FS = 250.0
|
||||||
|
EEG_PLOT_SECONDS = 10
|
||||||
|
PSD_FMIN, PSD_FMAX = 0.8, 45.0
|
||||||
|
EPS = 1e-12
|
||||||
|
FIXED_EEG_IDXS = [23, 47, 39, 6, 2, 21, 35, 57] # 0-based index, 按重要性排序
|
||||||
|
FIXED_EEG_LABELS = ["C5", "O1", "TP7", "FPZ", "PO6", "P4", "AF7", "AF3"]
|
||||||
|
|
||||||
|
BANDS_METRICS = {
|
||||||
|
"Delta": (1.0, 4.0),
|
||||||
|
"Theta": (4.0, 8.0),
|
||||||
|
"Alpha": (8.0, 13.0),
|
||||||
|
"Beta": (13.0, 30.0),
|
||||||
|
}
|
||||||
|
TOTAL_POWER_BAND = (1.0, 50.0)
|
||||||
|
|
||||||
|
BANDS_TOPOMAP = {
|
||||||
|
"delta": (0.8, 3.9),
|
||||||
|
"theta": (4.0, 7.9),
|
||||||
|
"alpha": (8.0, 12.9),
|
||||||
|
"beta": (13.0, 30.0),
|
||||||
|
"broad": (0.8, 30.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 预处理逻辑
|
||||||
|
# ==========================
|
||||||
|
def annotate_bad_segments(raw, peak_to_peak_uv=250.0):
|
||||||
|
"""
|
||||||
|
简单坏段检测:按固定窗口计算峰峰值,超过阈值标为 bad。
|
||||||
|
"""
|
||||||
|
peak_to_peak_v = peak_to_peak_uv * 1e-6
|
||||||
|
win = int(raw.info["sfreq"] * 1.0)
|
||||||
|
step = int(raw.info["sfreq"] * 0.5)
|
||||||
|
data = raw.get_data()
|
||||||
|
n_times = data.shape[1]
|
||||||
|
onsets = []
|
||||||
|
durations = []
|
||||||
|
descriptions = []
|
||||||
|
|
||||||
|
for start in range(0, n_times - win, step):
|
||||||
|
seg = data[:, start:start + win]
|
||||||
|
ptp = np.ptp(seg, axis=1)
|
||||||
|
if np.any(ptp > peak_to_peak_v):
|
||||||
|
onsets.append(start / raw.info["sfreq"])
|
||||||
|
durations.append(win / raw.info["sfreq"])
|
||||||
|
descriptions.append("BAD_PTP")
|
||||||
|
|
||||||
|
if len(onsets) > 0:
|
||||||
|
ann = mne.Annotations(onset=onsets, duration=durations, description=descriptions)
|
||||||
|
raw.set_annotations(ann)
|
||||||
|
print(f"[INFO] Annotated bad segments: {len(onsets)} windows")
|
||||||
|
else:
|
||||||
|
print("[INFO] No bad segments detected by PTP rule")
|
||||||
|
|
||||||
|
|
||||||
|
def run_preprocess_on_raw(raw: mne.io.RawArray) -> mne.io.RawArray:
|
||||||
|
"""
|
||||||
|
核心预处理:滤波 + 平均参考 + 坏段标注 + ICA
|
||||||
|
"""
|
||||||
|
# 1) 滤波
|
||||||
|
raw.filter(PREPROCESS_BANDPASS[0], PREPROCESS_BANDPASS[1], fir_design="firwin", verbose=False)
|
||||||
|
raw.notch_filter(PREPROCESS_NOTCH, fir_design="firwin", verbose=False)
|
||||||
|
|
||||||
|
# 2) 平均参考
|
||||||
|
if PREPROCESS_APPLY_AVG_REF:
|
||||||
|
raw.set_eeg_reference("average", verbose=False)
|
||||||
|
|
||||||
|
# 3) 坏段标注
|
||||||
|
annotate_bad_segments(raw, peak_to_peak_uv=PREPROCESS_BAD_PTP_UV)
|
||||||
|
|
||||||
|
# 4) ICA
|
||||||
|
ica = ICA(
|
||||||
|
n_components=PREPROCESS_ICA_N,
|
||||||
|
random_state=PREPROCESS_ICA_SEED,
|
||||||
|
max_iter=800,
|
||||||
|
method="fastica"
|
||||||
|
)
|
||||||
|
ica.fit(raw, reject_by_annotation=True, verbose=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
eog_inds, _ = ica.find_bads_eog(raw, verbose=False)
|
||||||
|
if eog_inds:
|
||||||
|
ica.exclude.extend(eog_inds)
|
||||||
|
print(f"[INFO] ICA exclude EOG comps: {eog_inds}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] ICA find_bads_eog skipped: {e}")
|
||||||
|
|
||||||
|
raw_clean = ica.apply(raw.copy(), verbose=False)
|
||||||
|
return raw_clean
|
||||||
|
|
||||||
|
def preprocess_mat_file(src_mat_path: str, temp_out_dir: str) -> str:
|
||||||
|
"""
|
||||||
|
读取原始mat -> 预处理 -> 保存到 temp_out_dir -> 返回新路径
|
||||||
|
"""
|
||||||
|
os.makedirs(temp_out_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 1. 读原始 mat
|
||||||
|
# 注意:这里我们只要数据部分转成 MNE Raw,然后处理,再存回
|
||||||
|
# 复用现有的 load_eeg_from_mat 拿到 ndarray
|
||||||
|
eeg_uV, fs, ch_names, xyz = load_eeg_from_mat(src_mat_path)
|
||||||
|
|
||||||
|
# 转 MNE (注意单位:uV -> V)
|
||||||
|
if not ch_names:
|
||||||
|
ch_names = [f"CH{i+1}" for i in range(eeg_uV.shape[1])]
|
||||||
|
|
||||||
|
info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types=["eeg"] * len(ch_names))
|
||||||
|
raw = mne.io.RawArray(eeg_uV.T * 1e-6, info, verbose=False)
|
||||||
|
|
||||||
|
if xyz is not None and isinstance(xyz, np.ndarray):
|
||||||
|
# 尝试设 montage(虽然对滤波不关键,但尽量保留信息)
|
||||||
|
try:
|
||||||
|
ch_pos = {ch_names[i]: xyz[i, :] for i in range(len(ch_names))}
|
||||||
|
montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame="head")
|
||||||
|
raw.set_montage(montage, on_missing="ignore")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 2. 执行预处理
|
||||||
|
print(f"[INFO] Start preprocessing: {src_mat_path}")
|
||||||
|
raw_clean = run_preprocess_on_raw(raw)
|
||||||
|
|
||||||
|
# 3. 存回 .mat (保持结构兼容,以便后续 run_all 读取)
|
||||||
|
# 这里我们需要读取原始 mat 的结构体,把 data 替换掉
|
||||||
|
try:
|
||||||
|
mat_struct = scipy.io.loadmat(src_mat_path, struct_as_record=False, squeeze_me=True)
|
||||||
|
if "eeg" in mat_struct:
|
||||||
|
eeg_obj = mat_struct["eeg"]
|
||||||
|
# 替换数据:MNE (V) -> uV -> (T, C)
|
||||||
|
clean_data_uV = (raw_clean.get_data() * 1e6).T
|
||||||
|
eeg_obj.data = clean_data_uV
|
||||||
|
|
||||||
|
base_name = os.path.basename(src_mat_path)
|
||||||
|
new_path = os.path.join(temp_out_dir, base_name)
|
||||||
|
scipy.io.savemat(new_path, {"eeg": eeg_obj}, do_compression=True)
|
||||||
|
print(f"[INFO] Preprocessed file saved to: {new_path}")
|
||||||
|
return new_path
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] Failed to preserve original struct structure: {e}")
|
||||||
|
|
||||||
|
# Fallback: 如果读原始结构失败,就存一个简单的 mat
|
||||||
|
clean_data_uV = (raw_clean.get_data() * 1e6).T
|
||||||
|
out_dict = {
|
||||||
|
"eeg": {
|
||||||
|
"data": clean_data_uV,
|
||||||
|
"sample_rate": fs,
|
||||||
|
"electrode_name": ch_names,
|
||||||
|
"electrode_xyz": xyz if xyz is not None else []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
base_name = os.path.basename(src_mat_path)
|
||||||
|
new_path = os.path.join(temp_out_dir, base_name)
|
||||||
|
scipy.io.savemat(new_path, out_dict, do_compression=True)
|
||||||
|
print(f"[INFO] Preprocessed file saved (fallback mode) to: {new_path}")
|
||||||
|
return new_path
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 输出目录
|
||||||
|
# ==========================
|
||||||
|
def ensure_outdir(out_root: str) -> str:
|
||||||
|
"""
|
||||||
|
确保输出目录存在,并清空除 ResultData.txt 之外的旧文件。
|
||||||
|
不再创建 timestamp 子文件夹,直接输出到 out_root。
|
||||||
|
"""
|
||||||
|
if os.path.exists(out_root):
|
||||||
|
# 清空目录,但保留 ResultData.txt
|
||||||
|
for filename in os.listdir(out_root):
|
||||||
|
if filename == "ResultData.txt":
|
||||||
|
continue
|
||||||
|
file_path = os.path.join(out_root, filename)
|
||||||
|
try:
|
||||||
|
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||||
|
os.unlink(file_path)
|
||||||
|
elif os.path.isdir(file_path):
|
||||||
|
shutil.rmtree(file_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] Failed to delete {file_path}. Reason: {e}")
|
||||||
|
else:
|
||||||
|
os.makedirs(out_root, exist_ok=True)
|
||||||
|
|
||||||
|
return out_root
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 单位自动识别:统一到 μV
|
||||||
|
# ==========================
|
||||||
|
def _auto_scale_to_uV(data_nt_nc: np.ndarray):
|
||||||
|
data = np.asarray(data_nt_nc)
|
||||||
|
p95 = float(np.percentile(np.abs(data), 95))
|
||||||
|
|
||||||
|
if p95 <= 0.5:
|
||||||
|
data_uV = data * 1e6
|
||||||
|
msg = f"[UNIT] p95={p95:.3g} -> assume V, convert to μV by *1e6"
|
||||||
|
elif p95 > 5000:
|
||||||
|
data_uV = data * 1e-3
|
||||||
|
msg = f"[UNIT] p95={p95:.3g} -> assume nV, convert to μV by /1000"
|
||||||
|
else:
|
||||||
|
data_uV = data
|
||||||
|
msg = f"[UNIT] p95={p95:.3g} -> assume μV, no scaling"
|
||||||
|
|
||||||
|
p95_uV = float(np.percentile(np.abs(data_uV), 95))
|
||||||
|
warn = None
|
||||||
|
if p95_uV > 5000:
|
||||||
|
warn = f"[WARN] After scaling, p95 still large: {p95_uV:.3g} μV"
|
||||||
|
elif p95_uV < 0.1:
|
||||||
|
warn = f"[WARN] After scaling, p95 still small: {p95_uV:.3g} μV"
|
||||||
|
|
||||||
|
return data_uV, msg, warn
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# mat 读取(支持 struct.data / electrode_name / electrode_xyz / sample_rate)
|
||||||
|
# ==========================
|
||||||
|
def _unwrap_singleton(x):
|
||||||
|
while True:
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
if x.dtype == object and x.size == 1:
|
||||||
|
x = x.item()
|
||||||
|
continue
|
||||||
|
if x.size == 1 and x.ndim >= 1:
|
||||||
|
try:
|
||||||
|
x = x.reshape(-1)[0]
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
break
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _try_get_struct_field(v, field_name="data"):
|
||||||
|
if hasattr(v, "_fieldnames") and field_name in getattr(v, "_fieldnames", []):
|
||||||
|
return getattr(v, field_name)
|
||||||
|
if isinstance(v, np.ndarray) and v.dtype.names and field_name in v.dtype.names:
|
||||||
|
try:
|
||||||
|
return v[field_name]
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_electrode_names(st):
|
||||||
|
nf = _try_get_struct_field(st, "electrode_name")
|
||||||
|
if nf is None:
|
||||||
|
return None
|
||||||
|
nf = _unwrap_singleton(nf)
|
||||||
|
if isinstance(nf, (list, tuple)):
|
||||||
|
names = [str(x).strip() for x in nf]
|
||||||
|
return names if names else None
|
||||||
|
if isinstance(nf, np.ndarray):
|
||||||
|
flat = nf.reshape(-1)
|
||||||
|
names = [str(_unwrap_singleton(x)).strip() for x in flat]
|
||||||
|
return names if names else None
|
||||||
|
s = str(nf).strip()
|
||||||
|
return [s] if s else None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_sample_rate(st):
|
||||||
|
sr = _try_get_struct_field(st, "sample_rate")
|
||||||
|
if sr is None:
|
||||||
|
return None
|
||||||
|
sr = _unwrap_singleton(sr)
|
||||||
|
try:
|
||||||
|
return float(sr)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_xyz(st):
|
||||||
|
xyz = _try_get_struct_field(st, "electrode_xyz")
|
||||||
|
if xyz is None:
|
||||||
|
return None
|
||||||
|
xyz = _unwrap_singleton(xyz)
|
||||||
|
try:
|
||||||
|
xyz = np.asarray(xyz, dtype=float)
|
||||||
|
if xyz.ndim == 2 and xyz.shape[1] == 3:
|
||||||
|
return xyz
|
||||||
|
if xyz.ndim == 2 and xyz.shape[0] == 3:
|
||||||
|
return xyz.T
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_eeg_from_mat(mat_path: str):
|
||||||
|
mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
|
||||||
|
|
||||||
|
candidates = []
|
||||||
|
st_for_meta = None
|
||||||
|
|
||||||
|
for k, v in mat.items():
|
||||||
|
if k.startswith("__"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(v, np.ndarray) and v.ndim == 2 and np.issubdtype(v.dtype, np.number):
|
||||||
|
candidates.append((k, v, None))
|
||||||
|
continue
|
||||||
|
|
||||||
|
data_field = _try_get_struct_field(v, "data")
|
||||||
|
if data_field is not None:
|
||||||
|
data_field = _unwrap_singleton(data_field)
|
||||||
|
if isinstance(data_field, np.ndarray) and data_field.ndim == 2:
|
||||||
|
candidates.append((f"{k}.data", data_field, v))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(v, np.ndarray) and v.dtype == object:
|
||||||
|
vv = _unwrap_singleton(v)
|
||||||
|
if isinstance(vv, np.ndarray) and vv.ndim == 2 and np.issubdtype(vv.dtype, np.number):
|
||||||
|
candidates.append((k, vv, None))
|
||||||
|
continue
|
||||||
|
data2 = _try_get_struct_field(vv, "data")
|
||||||
|
if data2 is not None:
|
||||||
|
data2 = _unwrap_singleton(data2)
|
||||||
|
if isinstance(data2, np.ndarray) and data2.ndim == 2:
|
||||||
|
candidates.append((f"{k}.data", data2, vv))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
raise RuntimeError(f"mat 里没找到可用 EEG 二维矩阵或 struct.data:{mat_path}")
|
||||||
|
|
||||||
|
def score(arr: np.ndarray) -> int:
|
||||||
|
s = 0
|
||||||
|
if 64 in arr.shape: s += 10
|
||||||
|
if 32 in arr.shape: s += 9
|
||||||
|
if 128 in arr.shape: s += 8
|
||||||
|
if 129 in arr.shape: s += 7
|
||||||
|
s += int(np.prod(arr.shape) // 100000)
|
||||||
|
return s
|
||||||
|
|
||||||
|
candidates.sort(key=lambda x: score(x[1]), reverse=True)
|
||||||
|
key, eeg, st = candidates[0]
|
||||||
|
st_for_meta = st
|
||||||
|
|
||||||
|
eeg = np.asarray(_unwrap_singleton(eeg), dtype=np.float32)
|
||||||
|
if eeg.ndim != 2:
|
||||||
|
raise RuntimeError(f"解析结果不是二维: key={key}, shape={eeg.shape}, file={mat_path}")
|
||||||
|
|
||||||
|
# 统一成 (T, C)
|
||||||
|
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[1] not in (32, 64, 128, 129):
|
||||||
|
eeg = eeg.T
|
||||||
|
elif eeg.shape[1] in (32, 64, 128, 129):
|
||||||
|
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[0] < eeg.shape[1]:
|
||||||
|
eeg = eeg.T
|
||||||
|
|
||||||
|
fs = DEFAULT_FS
|
||||||
|
ch_names = None
|
||||||
|
xyz = None
|
||||||
|
if st_for_meta is not None:
|
||||||
|
fs2 = _extract_sample_rate(st_for_meta)
|
||||||
|
if fs2 is not None and fs2 > 1:
|
||||||
|
fs = float(fs2)
|
||||||
|
ch_names = _extract_electrode_names(st_for_meta)
|
||||||
|
xyz = _extract_xyz(st_for_meta)
|
||||||
|
|
||||||
|
eeg_uV, msg, warn = _auto_scale_to_uV(eeg)
|
||||||
|
print(msg)
|
||||||
|
if warn:
|
||||||
|
print(warn)
|
||||||
|
|
||||||
|
return eeg_uV.astype(np.float32), float(fs), ch_names, xyz
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 预测接口:导入 predict_hc_mdd
|
||||||
|
# ==========================
|
||||||
|
def _predict_label_by_model(model_path: str, mat_dir: str) -> str:
|
||||||
|
try:
|
||||||
|
from infer_pth import predict_hc_mdd
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"无法导入 predict_hc_mdd(请确保 pre.py 或 infer_pth.py 与本文件同目录)。\n"
|
||||||
|
f"原始错误: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
out = predict_hc_mdd(mat_dir, model_path)
|
||||||
|
except TypeError:
|
||||||
|
out = predict_hc_mdd(model_path, mat_dir)
|
||||||
|
|
||||||
|
label = str(out.get("pred_label", "")).strip().upper()
|
||||||
|
if label not in ("HC", "MDD"):
|
||||||
|
raise RuntimeError(f"predict_hc_mdd 返回 pred_label 非法: {label},原始返回: {out}")
|
||||||
|
return label
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 通道分区
|
||||||
|
# ==========================
|
||||||
|
def _norm_name(s: str) -> str:
|
||||||
|
return str(s).strip().upper().replace(" ", "")
|
||||||
|
|
||||||
|
|
||||||
|
def build_channel_index_map(ch_names, n_channels: int):
|
||||||
|
if not ch_names or len(ch_names) != n_channels:
|
||||||
|
return {}
|
||||||
|
return {_norm_name(nm): i for i, nm in enumerate(ch_names)}
|
||||||
|
|
||||||
|
|
||||||
|
def pick_indices_by_names(name_to_idx, names):
|
||||||
|
idx = []
|
||||||
|
for n in names:
|
||||||
|
nn = _norm_name(n)
|
||||||
|
if nn in name_to_idx:
|
||||||
|
idx.append(name_to_idx[nn])
|
||||||
|
return sorted(list(set(idx)))
|
||||||
|
|
||||||
|
|
||||||
|
def _fallback_region_indices(n_channels: int):
|
||||||
|
a = int(n_channels * 0.33)
|
||||||
|
b = int(n_channels * 0.66)
|
||||||
|
frontal = list(range(0, a))
|
||||||
|
central = list(range(a, b))
|
||||||
|
parietal = list(range(b, n_channels))
|
||||||
|
prefrontal = list(range(0, max(2, a // 2)))
|
||||||
|
posterior = list(range(b, n_channels))
|
||||||
|
left = [i for i in range(n_channels) if i % 2 == 0]
|
||||||
|
right = [i for i in range(n_channels) if i % 2 == 1]
|
||||||
|
return frontal, central, parietal, prefrontal, posterior, left, right
|
||||||
|
|
||||||
|
|
||||||
|
def get_region_indices(name_to_idx, n_channels: int):
|
||||||
|
if not name_to_idx:
|
||||||
|
return _fallback_region_indices(n_channels)
|
||||||
|
|
||||||
|
central_names = ["CZ","C1","C2","C3","C4","C5","C6","CP1","CP2","CP3","CP4","CP5","CP6","FC1","FC2","FC3","FC4","FC5","FC6"]
|
||||||
|
frontal_names = ["FZ","F1","F2","F3","F4","F5","F6","F7","F8","AF3","AF4","AF7","AF8","FPZ","FP1","FP2","FCZ"]
|
||||||
|
parietal_names = ["PZ","P1","P2","P3","P4","P5","P6","POZ","PO3","PO4","PO5","PO6","PO7","PO8","CPZ"]
|
||||||
|
prefrontal_names = ["FP1","FP2","FPZ","AF3","AF4","AF7","AF8"]
|
||||||
|
posterior_names = ["O1","O2","OZ","PO7","PO8","PO3","PO4","PZ","P3","P4","P1","P2"]
|
||||||
|
|
||||||
|
central = pick_indices_by_names(name_to_idx, central_names)
|
||||||
|
frontal = pick_indices_by_names(name_to_idx, frontal_names)
|
||||||
|
parietal = pick_indices_by_names(name_to_idx, parietal_names)
|
||||||
|
prefrontal = pick_indices_by_names(name_to_idx, prefrontal_names)
|
||||||
|
posterior = pick_indices_by_names(name_to_idx, posterior_names)
|
||||||
|
|
||||||
|
left_names = ["FP1","AF3","AF7","F3","F5","F7"]
|
||||||
|
right_names = ["FP2","AF4","AF8","F4","F6","F8"]
|
||||||
|
left = pick_indices_by_names(name_to_idx, left_names)
|
||||||
|
right = pick_indices_by_names(name_to_idx, right_names)
|
||||||
|
|
||||||
|
if not (central and frontal and parietal and prefrontal and posterior):
|
||||||
|
fb = _fallback_region_indices(n_channels)
|
||||||
|
frontal2, central2, parietal2, prefrontal2, posterior2, left2, right2 = fb
|
||||||
|
frontal = frontal if frontal else frontal2
|
||||||
|
central = central if central else central2
|
||||||
|
parietal = parietal if parietal else parietal2
|
||||||
|
prefrontal = prefrontal if prefrontal else prefrontal2
|
||||||
|
posterior = posterior if posterior else posterior2
|
||||||
|
left = left if left else left2
|
||||||
|
right = right if right else right2
|
||||||
|
|
||||||
|
return frontal, central, parietal, prefrontal, posterior, left, right
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# Welch PSD + band power
|
||||||
|
# ==========================
|
||||||
|
def welch_psd(eeg_tc: np.ndarray, fs: float):
|
||||||
|
nperseg = min(1024, eeg_tc.shape[0])
|
||||||
|
if nperseg < 128:
|
||||||
|
nperseg = min(256, eeg_tc.shape[0])
|
||||||
|
freqs, pxx = signal.welch(
|
||||||
|
eeg_tc, fs=fs, nperseg=nperseg, noverlap=nperseg // 2,
|
||||||
|
axis=0, scaling="density",
|
||||||
|
)
|
||||||
|
return freqs, pxx
|
||||||
|
|
||||||
|
|
||||||
|
def band_power_from_psd(freqs, pxx_fc, band):
|
||||||
|
lo, hi = band
|
||||||
|
m = (freqs >= lo) & (freqs < hi)
|
||||||
|
if not np.any(m):
|
||||||
|
return np.zeros((pxx_fc.shape[1],), dtype=np.float32)
|
||||||
|
|
||||||
|
# 兼容处理:numpy 2.0+ 推荐使用 trapezoid,旧版本用 trapz
|
||||||
|
if hasattr(np, "trapezoid"):
|
||||||
|
return np.trapezoid(pxx_fc[m, :], freqs[m], axis=0).astype(np.float32)
|
||||||
|
else:
|
||||||
|
return np.trapz(pxx_fc[m, :], freqs[m], axis=0).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def region_mean_power(freqs, pxx_fc, idx, band) -> float:
|
||||||
|
if not idx:
|
||||||
|
return 0.0
|
||||||
|
pw = band_power_from_psd(freqs, pxx_fc, band)
|
||||||
|
return float(np.mean(pw[idx]))
|
||||||
|
|
||||||
|
|
||||||
|
def compute_iaf(freqs, pxx_fc, posterior_idx):
|
||||||
|
lo, hi = BANDS_METRICS["Alpha"]
|
||||||
|
m = (freqs >= lo) & (freqs <= hi)
|
||||||
|
if not np.any(m) or not posterior_idx:
|
||||||
|
return 0.0
|
||||||
|
spec = np.mean(pxx_fc[:, posterior_idx], axis=1)
|
||||||
|
sub = spec[m]
|
||||||
|
fsub = freqs[m]
|
||||||
|
return float(fsub[int(np.argmax(sub))])
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 图:EEG波形、PSD
|
||||||
|
# ==========================
|
||||||
|
def plot_eeg_waveforms(data_uv_tc: np.ndarray, fs: float, ch_names, out_dir: str, seconds: int = 10):
|
||||||
|
"""
|
||||||
|
固定用 FIXED_EEG_IDXS 画 EEG.png(按重要性排序)
|
||||||
|
data_uv_tc: (T, C) μV
|
||||||
|
"""
|
||||||
|
T, C = data_uv_tc.shape
|
||||||
|
|
||||||
|
# 1) 过滤越界索引(避免你的数据通道数不足时报错)
|
||||||
|
idxs = [i for i in FIXED_EEG_IDXS if 0 <= i < C]
|
||||||
|
if len(idxs) < len(FIXED_EEG_IDXS):
|
||||||
|
missing = [i for i in FIXED_EEG_IDXS if not (0 <= i < C)]
|
||||||
|
print(f"[WARN] Some fixed EEG indices out of range (C={C}): {missing}")
|
||||||
|
|
||||||
|
if len(idxs) == 0:
|
||||||
|
raise RuntimeError(f"No valid indices in FIXED_EEG_IDXS for current data (C={C}).")
|
||||||
|
|
||||||
|
# 2) 通道显示
|
||||||
|
picked_names = []
|
||||||
|
for idx in idxs:
|
||||||
|
# 找 idx 在 FIXED_EEG_IDXS 的位置,用对应标签
|
||||||
|
pos = FIXED_EEG_IDXS.index(idx)
|
||||||
|
std_label = FIXED_EEG_LABELS[pos] if pos < len(FIXED_EEG_LABELS) else f"CH{idx}"
|
||||||
|
if ch_names and idx < len(ch_names):
|
||||||
|
picked_names.append(f"{std_label}")
|
||||||
|
else:
|
||||||
|
picked_names.append(std_label)
|
||||||
|
|
||||||
|
# 3) 截取前 seconds 秒
|
||||||
|
max_samples = int(min(T, seconds * fs))
|
||||||
|
x = np.arange(max_samples) / fs
|
||||||
|
|
||||||
|
fig_h = 1.4 * len(idxs) + 1
|
||||||
|
fig, axes = plt.subplots(len(idxs), 1, figsize=(10, fig_h), sharex=True)
|
||||||
|
if len(idxs) == 1:
|
||||||
|
axes = [axes]
|
||||||
|
|
||||||
|
# 4) 分位数定范围,避免尖峰撑爆
|
||||||
|
seg = data_uv_tc[:max_samples, idxs].T # (n_ch, samples)
|
||||||
|
lo = float(np.percentile(seg, 1))
|
||||||
|
hi = float(np.percentile(seg, 99))
|
||||||
|
m = max(abs(lo), abs(hi))
|
||||||
|
m = max(m, 50.0)
|
||||||
|
|
||||||
|
for ax, ch_idx, nm in zip(axes, idxs, picked_names):
|
||||||
|
y = data_uv_tc[:max_samples, ch_idx]
|
||||||
|
ax.plot(x, y, linewidth=1.2)
|
||||||
|
ax.set_ylabel("μV")
|
||||||
|
ax.set_title(nm, loc="left", fontsize=10)
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.set_ylim(-m, m)
|
||||||
|
|
||||||
|
axes[-1].set_xlabel("Time (s)")
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
out_path = os.path.join(out_dir, "EEG.png")
|
||||||
|
plt.savefig(out_path, dpi=200)
|
||||||
|
plt.close(fig)
|
||||||
|
print(f"[OK] EEG waveform saved: {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def plot_psd(eeg_uV_tc, fs, ch_names, out_dir):
|
||||||
|
C = eeg_uV_tc.shape[1]
|
||||||
|
chosen_idx = []
|
||||||
|
|
||||||
|
if ch_names:
|
||||||
|
mp = {n.upper(): i for i, n in enumerate(ch_names)}
|
||||||
|
for p in ["C3","C4","CZ"]:
|
||||||
|
if p in mp:
|
||||||
|
chosen_idx.append(mp[p])
|
||||||
|
if len(chosen_idx) < 3:
|
||||||
|
stds = [(i, float(np.std(eeg_uV_tc[:, i]))) for i in range(C)]
|
||||||
|
stds.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
for i, _ in stds:
|
||||||
|
if i not in chosen_idx:
|
||||||
|
chosen_idx.append(i)
|
||||||
|
if len(chosen_idx) == 3:
|
||||||
|
break
|
||||||
|
chosen_name = [ch_names[i] for i in chosen_idx]
|
||||||
|
else:
|
||||||
|
stds = [(i, float(np.std(eeg_uV_tc[:, i]))) for i in range(C)]
|
||||||
|
stds.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
chosen_idx = [i for i, _ in stds[:3]]
|
||||||
|
chosen_name = [f"CH{i}" for i in chosen_idx]
|
||||||
|
|
||||||
|
fig = plt.figure(figsize=(7.5, 4.8))
|
||||||
|
for idx, nm in zip(chosen_idx, chosen_name):
|
||||||
|
f, pxx = signal.welch(eeg_uV_tc[:, idx], fs=fs, nperseg=int(2*fs), noverlap=int(1*fs))
|
||||||
|
mask = (f >= PSD_FMIN) & (f <= PSD_FMAX)
|
||||||
|
p_db = 10 * np.log10(pxx[mask] + 1e-20)
|
||||||
|
plt.plot(f[mask], p_db, linewidth=1.8, label=nm)
|
||||||
|
|
||||||
|
plt.xlabel("Hz")
|
||||||
|
plt.ylabel("Power (dB)")
|
||||||
|
plt.title("PSD")
|
||||||
|
plt.grid(True, alpha=0.3)
|
||||||
|
plt.legend()
|
||||||
|
plt.tight_layout()
|
||||||
|
out_path = os.path.join(out_dir, "psd.png")
|
||||||
|
plt.savefig(out_path, dpi=200)
|
||||||
|
plt.close(fig)
|
||||||
|
print(f"[OK] psd.png -> {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# Topomap(如果有 xyz)
|
||||||
|
# ==========================
|
||||||
|
def build_mne_raw_from_uV(eeg_uV_tc, fs, ch_names, xyz):
|
||||||
|
C = eeg_uV_tc.shape[1]
|
||||||
|
if not ch_names:
|
||||||
|
ch_names = [f"CH{i+1}" for i in range(C)]
|
||||||
|
data_v_ct = eeg_uV_tc.T * 1e-6 # (C,T) V
|
||||||
|
info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types=["eeg"] * C)
|
||||||
|
raw = mne.io.RawArray(data_v_ct, info, verbose=False)
|
||||||
|
|
||||||
|
if xyz is not None and isinstance(xyz, np.ndarray) and xyz.shape == (C, 3):
|
||||||
|
try:
|
||||||
|
ch_pos = {ch_names[i]: xyz[i, :] for i in range(C)}
|
||||||
|
montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame="head")
|
||||||
|
raw.set_montage(montage, on_missing="ignore")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] set_montage failed (ignore): {e}")
|
||||||
|
else:
|
||||||
|
print("[WARN] electrode_xyz missing/invalid -> skip topomap")
|
||||||
|
|
||||||
|
return raw
|
||||||
|
|
||||||
|
|
||||||
|
def _raw_has_positions(raw):
|
||||||
|
try:
|
||||||
|
locs = np.array([ch["loc"][:3] for ch in raw.info["chs"]])
|
||||||
|
ok = np.isfinite(locs).all() and (np.linalg.norm(locs, axis=1) > 0).any()
|
||||||
|
return bool(ok)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def compute_band_powers_for_topomap(raw, bands):
|
||||||
|
data = raw.get_data() # (C,T) V
|
||||||
|
fs = raw.info["sfreq"]
|
||||||
|
psds, freqs = mne.time_frequency.psd_array_welch(
|
||||||
|
data, sfreq=fs,
|
||||||
|
fmin=min(v[0] for v in bands.values()),
|
||||||
|
fmax=max(v[1] for v in bands.values()),
|
||||||
|
n_fft=int(2 * fs),
|
||||||
|
n_overlap=int(1 * fs),
|
||||||
|
average="mean",
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
out = {}
|
||||||
|
for k, (fmin, fmax) in bands.items():
|
||||||
|
idx = np.where((freqs >= fmin) & (freqs <= fmax))[0]
|
||||||
|
|
||||||
|
# 兼容处理:numpy 2.0+ 推荐使用 trapezoid,旧版本用 trapz
|
||||||
|
if hasattr(np, "trapezoid"):
|
||||||
|
bp = np.trapezoid(psds[:, idx], freqs[idx], axis=1) # (C,)
|
||||||
|
else:
|
||||||
|
bp = np.trapz(psds[:, idx], freqs[idx], axis=1) # (C,)
|
||||||
|
|
||||||
|
v = np.log10(bp + 1e-30)
|
||||||
|
v = v - np.mean(v)
|
||||||
|
out[k] = v
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def plot_average_topomap(raw, values, out_dir):
|
||||||
|
fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.6))
|
||||||
|
im, _ = mne.viz.plot_topomap(values, raw.info, axes=ax, show=False, contours=0,sphere=(0, 0, 0, 0.11))
|
||||||
|
ax.set_title("0.8-30 Hz", fontsize=12)
|
||||||
|
plt.colorbar(im, ax=ax, shrink=0.85)
|
||||||
|
plt.tight_layout()
|
||||||
|
out_path = os.path.join(out_dir, "average_topomap.png")
|
||||||
|
plt.savefig(out_path, dpi=200)
|
||||||
|
plt.close(fig)
|
||||||
|
print(f"[OK] average_topomap.png -> {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def plot_band_topomaps(raw, band_values, out_dir):
|
||||||
|
order = [
|
||||||
|
("delta", "δ (0.8-3.9Hz)"),
|
||||||
|
("theta", "θ (4-7.9Hz)"),
|
||||||
|
("alpha", "α (8-12.9Hz)"),
|
||||||
|
("beta", "β (13-30Hz)"),
|
||||||
|
("broad", "0.8-30 Hz"),
|
||||||
|
]
|
||||||
|
fig, axes = plt.subplots(1, 5, figsize=(16, 4.2))
|
||||||
|
ims = []
|
||||||
|
for ax, (k, title) in zip(axes, order):
|
||||||
|
im, _ = mne.viz.plot_topomap(band_values[k], raw.info, axes=ax, show=False, contours=0,extrapolate='head',sphere=(0, 0, 0, 0.11))
|
||||||
|
ax.set_title(title, fontsize=11)
|
||||||
|
ims.append(im)
|
||||||
|
fig.subplots_adjust(left=0.02, right=0.85, top=0.88, bottom=0.05, wspace=0.35)
|
||||||
|
cax = fig.add_axes([0.87, 0.15, 0.015, 0.7])
|
||||||
|
fig.colorbar(ims[-1], cax=cax)
|
||||||
|
out_path = os.path.join(out_dir, "topomaps.png")
|
||||||
|
plt.savefig(out_path, dpi=200)
|
||||||
|
plt.close(fig)
|
||||||
|
print(f"[OK] topomaps.png -> {out_path}")
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 生成 ResultData.txt
|
||||||
|
# ==========================
|
||||||
|
def compute_and_save_txt(model_path, mat_dir, out_dir, eeg_uV_tc, fs, ch_names):
|
||||||
|
pred_label = _predict_label_by_model(model_path, mat_dir)
|
||||||
|
recommend = "是" if pred_label == "MDD" else "否"
|
||||||
|
|
||||||
|
T, C = eeg_uV_tc.shape
|
||||||
|
mp = build_channel_index_map(ch_names, C)
|
||||||
|
frontal_idx, central_idx, parietal_idx, prefrontal_idx, posterior_idx, left_idx, right_idx = \
|
||||||
|
get_region_indices(mp, C)
|
||||||
|
|
||||||
|
freqs, pxx = welch_psd(eeg_uV_tc, fs)
|
||||||
|
|
||||||
|
central_alpha = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Alpha"])
|
||||||
|
central_beta = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Beta"])
|
||||||
|
frontal_alpha = region_mean_power(freqs, pxx, frontal_idx, BANDS_METRICS["Alpha"])
|
||||||
|
frontal_beta = region_mean_power(freqs, pxx, frontal_idx, BANDS_METRICS["Beta"])
|
||||||
|
par_alpha = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Alpha"])
|
||||||
|
par_beta = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Beta"])
|
||||||
|
|
||||||
|
central_ab = (central_alpha / (central_beta + EPS)) if central_beta > 0 else 0.0
|
||||||
|
frontal_ab = (frontal_alpha / (frontal_beta + EPS)) if frontal_beta > 0 else 0.0
|
||||||
|
par_ab = (par_alpha / (par_beta + EPS)) if par_beta > 0 else 0.0
|
||||||
|
|
||||||
|
central_theta = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Theta"])
|
||||||
|
par_theta = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Theta"])
|
||||||
|
central_tb = (central_theta / (central_beta + EPS)) if central_beta > 0 else 0.0
|
||||||
|
par_tb = (par_theta / (par_beta + EPS)) if par_beta > 0 else 0.0
|
||||||
|
|
||||||
|
if not left_idx or not right_idx:
|
||||||
|
left_idx = [i for i in prefrontal_idx if (i % 2 == 0)]
|
||||||
|
right_idx = [i for i in prefrontal_idx if (i % 2 == 1)]
|
||||||
|
left_alpha = region_mean_power(freqs, pxx, left_idx, BANDS_METRICS["Alpha"])
|
||||||
|
right_alpha = region_mean_power(freqs, pxx, right_idx, BANDS_METRICS["Alpha"])
|
||||||
|
prefrontal_alpha_asym = float(np.log(right_alpha + EPS) - np.log(left_alpha + EPS))
|
||||||
|
|
||||||
|
iaf = compute_iaf(freqs, pxx, posterior_idx)
|
||||||
|
|
||||||
|
pre_td = region_mean_power(freqs, pxx, prefrontal_idx, (BANDS_METRICS["Delta"][0], BANDS_METRICS["Theta"][1]))
|
||||||
|
pre_total = region_mean_power(freqs, pxx, prefrontal_idx, TOTAL_POWER_BAND)
|
||||||
|
pre_td_rel = (pre_td / (pre_total + EPS)) * 100.0 if pre_total > 0 else 0.0
|
||||||
|
|
||||||
|
def f1(x): return f"{x:.1f}"
|
||||||
|
|
||||||
|
txt = (
|
||||||
|
f"中央区α/β波比值:{f1(central_ab)}\n"
|
||||||
|
f"额区α/β波比值:{f1(frontal_ab)}\n"
|
||||||
|
f"顶区α/β波比值:{f1(par_ab)}\n"
|
||||||
|
f"中央区θ/β波比值:{f1(central_tb)}\n"
|
||||||
|
f"顶区θ/β波比值:{f1(par_tb)}\n"
|
||||||
|
f"前额叶α波不对称性:{f1(prefrontal_alpha_asym)}\n"
|
||||||
|
f"个体化α峰值频率:{f1(iaf)}\n"
|
||||||
|
f"前额叶θ+δ波功率:{f1(pre_td_rel)}\n"
|
||||||
|
f"是否推荐治疗:{recommend}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
out_path = os.path.join(out_dir, "ResultData.txt")
|
||||||
|
with open(out_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(txt)
|
||||||
|
print(f"[OK] ResultData.txt -> {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 一个函数:一次性跑完(txt + 图片)
|
||||||
|
# ==========================
|
||||||
|
def run_all(model_path: str, mat_dir: str, out_root: str, seconds: int = EEG_PLOT_SECONDS):
|
||||||
|
# 1) 选第一个 mat
|
||||||
|
if not os.path.exists(mat_dir):
|
||||||
|
raise RuntimeError(f"输入目录不存在: {mat_dir}")
|
||||||
|
|
||||||
|
mats = [f for f in os.listdir(mat_dir) if f.lower().endswith(".mat")]
|
||||||
|
if not mats:
|
||||||
|
raise RuntimeError(f"mat_dir 下找不到 .mat: {mat_dir}")
|
||||||
|
mats.sort()
|
||||||
|
mat_file = os.path.join(mat_dir, mats[0])
|
||||||
|
print(f"[INFO] Found mat: {mat_file}")
|
||||||
|
|
||||||
|
# 2) 创建输出目录
|
||||||
|
out_dir = ensure_outdir(out_root)
|
||||||
|
print(f"[INFO] Output dir: {out_dir}")
|
||||||
|
|
||||||
|
# --- 总是进行预处理 (默认模式) ---
|
||||||
|
print("[INFO] Mode: Raw Data (Default). Running preprocessing...")
|
||||||
|
temp_dir = os.path.join(out_dir, "temp_preprocessed")
|
||||||
|
mat_file = preprocess_mat_file(mat_file, temp_dir)
|
||||||
|
# 更新 mat_dir 指向临时目录(为了传给 compute_and_save_txt 里的 predict 接口)
|
||||||
|
mat_dir = temp_dir
|
||||||
|
|
||||||
|
# 3) 读 EEG(μV)
|
||||||
|
eeg_uV_tc, fs, ch_names, xyz = load_eeg_from_mat(mat_file)
|
||||||
|
print(f"[INFO] eeg shape(T,C)={eeg_uV_tc.shape}, fs={fs}")
|
||||||
|
|
||||||
|
# 5) 画图:PSD + EEG
|
||||||
|
plot_psd(eeg_uV_tc, fs, ch_names, out_dir)
|
||||||
|
plot_eeg_waveforms(eeg_uV_tc, fs, ch_names, out_dir, seconds=seconds)
|
||||||
|
|
||||||
|
# 6) topomap(有 xyz 才画)
|
||||||
|
try:
|
||||||
|
raw = build_mne_raw_from_uV(eeg_uV_tc, fs, ch_names, xyz)
|
||||||
|
if _raw_has_positions(raw):
|
||||||
|
band_vals = compute_band_powers_for_topomap(raw, BANDS_TOPOMAP)
|
||||||
|
plot_average_topomap(raw, band_vals["broad"], out_dir)
|
||||||
|
plot_band_topomaps(raw, band_vals, out_dir)
|
||||||
|
else:
|
||||||
|
print("[WARN] No valid positions -> skip topomap.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] topomap failed -> skip. reason: {e}")
|
||||||
|
|
||||||
|
# 4) 指标写 txt
|
||||||
|
compute_and_save_txt(model_path, mat_dir, out_dir, eeg_uV_tc, fs, ch_names)
|
||||||
|
|
||||||
|
print("[DONE] txt + figures generated.")
|
||||||
|
return out_dir
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import multiprocessing
|
||||||
|
multiprocessing.freeze_support()
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# 1. 路径锚定:获取资源绝对路径
|
||||||
|
def get_resource_path(relative_path):
|
||||||
|
"""
|
||||||
|
获取资源的绝对路径。
|
||||||
|
策略:优先在当前执行目录(EXE所在目录)寻找。
|
||||||
|
这适用于“绿色软件”模式,即资源文件(model/raw_data)直接放在EXE旁边。
|
||||||
|
"""
|
||||||
|
if getattr(sys, 'frozen', False):
|
||||||
|
# PyInstaller 打包后的 EXE 所在目录
|
||||||
|
base_path = os.path.dirname(sys.executable)
|
||||||
|
else:
|
||||||
|
# 开发环境:当前脚本所在目录
|
||||||
|
base_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
return os.path.join(base_path, relative_path)
|
||||||
|
|
||||||
|
# 设置默认路径
|
||||||
|
DEFAULT_MODEL = get_resource_path(os.path.join("model", "Model_1.pth"))
|
||||||
|
# 这里我们保持 mat_dir 和 out_root 相对于 EXE 所在目录(或当前工作目录)
|
||||||
|
if getattr(sys, 'frozen', False):
|
||||||
|
EXE_DIR = os.path.dirname(sys.executable)
|
||||||
|
else:
|
||||||
|
EXE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
DEFAULT_MAT = os.path.join(EXE_DIR, "raw_data")
|
||||||
|
DEFAULT_OUT = os.path.join(EXE_DIR, "out")
|
||||||
|
|
||||||
|
# 2. 解析命令行参数
|
||||||
|
parser = argparse.ArgumentParser(description="EEG Depression Assessment Algorithm Integration")
|
||||||
|
parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL, help="模型文件的路径 (.pth)")
|
||||||
|
parser.add_argument("--mat_dir", type=str, default=DEFAULT_MAT, help="输入文件夹路径 (包含原始EEG .mat)")
|
||||||
|
parser.add_argument("--out_root", type=str, default=DEFAULT_OUT, help="结果输出的根目录")
|
||||||
|
parser.add_argument("--seconds", type=int, default=10, help="画波形图截取的秒数")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 3. 检查关键路径
|
||||||
|
if not os.path.exists(args.mat_dir):
|
||||||
|
print(f"[WARN] 输入文件夹不存在: {args.mat_dir}")
|
||||||
|
if not os.path.exists(args.model_path):
|
||||||
|
print(f"[WARN] 模型文件不存在: {args.model_path}")
|
||||||
|
|
||||||
|
# 4. 执行主流程
|
||||||
|
print(f"[*] 运行配置:")
|
||||||
|
print(f" - Model : {args.model_path}")
|
||||||
|
print(f" - Input : {args.mat_dir}")
|
||||||
|
print(f" - Output: {args.out_root}")
|
||||||
|
print(f" - Mode : RAW (Auto Preprocess)")
|
||||||
|
|
||||||
|
run_all(args.model_path, args.mat_dir, args.out_root, seconds=args.seconds)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
16
algorithm_V0/datacollect/RunOnce.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
import ctypes
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def is_program_running(name='Global\\Parser_main'):
|
||||||
|
# 创建互斥体
|
||||||
|
mutex_name =name
|
||||||
|
h_mutex = ctypes.windll.kernel32.CreateMutexW(None, False, mutex_name)
|
||||||
|
|
||||||
|
# 检查互斥体是否已经存在
|
||||||
|
if ctypes.windll.kernel32.GetLastError() == 183: # ERROR_ALREADY_EXISTS
|
||||||
|
print("程序已经在运行.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
379
algorithm_V0/datacollect/SunnyLinker.py
Normal file
@@ -0,0 +1,379 @@
|
|||||||
|
# -*-coding:utf-8 -*-
|
||||||
|
'''
|
||||||
|
SunnyLinker的通讯驱动
|
||||||
|
'''
|
||||||
|
import ast
|
||||||
|
import socket
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from threading import Thread, Event
|
||||||
|
import serial
|
||||||
|
from scipy import signal
|
||||||
|
from serial.serialutil import SerialException
|
||||||
|
|
||||||
|
from protocol import ProtocolFrame
|
||||||
|
|
||||||
|
class RingBuffer:
|
||||||
|
def __init__(self, n_chan, n_points):
|
||||||
|
self.n_chan = n_chan
|
||||||
|
self.n_points = n_points
|
||||||
|
self.buffer = np.zeros((n_chan, n_points))
|
||||||
|
self.currentPtr = 0
|
||||||
|
self.readPtr = 0
|
||||||
|
self.nUpdate = 0
|
||||||
|
self.rawData = np.zeros((n_chan, 1))
|
||||||
|
|
||||||
|
## append buffer and update current pointer
|
||||||
|
def appendBuffer(self, data):
|
||||||
|
if self.nUpdate == self.n_points:
|
||||||
|
raise Exception("Buffer is full")
|
||||||
|
|
||||||
|
n = data.shape[1]
|
||||||
|
|
||||||
|
# 计算可以写入的元素数量
|
||||||
|
write_count = min(self.n_points - self.nUpdate, n)
|
||||||
|
# 写入新数据
|
||||||
|
self.buffer[:, np.mod(np.arange(self.currentPtr, self.currentPtr + write_count), self.n_points)] = data[:,:write_count]
|
||||||
|
# 更新结束指针
|
||||||
|
self.currentPtr = (self.currentPtr + write_count) % self.n_points
|
||||||
|
# 更新大小
|
||||||
|
self.nUpdate += write_count
|
||||||
|
|
||||||
|
## get data from buffer
|
||||||
|
def getData(self, count=50):
|
||||||
|
# 确保不会尝试读取超过缓冲区当前大小的数据
|
||||||
|
count = min(count, self.nUpdate)
|
||||||
|
|
||||||
|
# 计算读取结束后的下一个位置
|
||||||
|
next_read_ptr = (self.readPtr + count) % self.n_points
|
||||||
|
if self.readPtr + count <= self.n_points:
|
||||||
|
# 情况 1:不环绕,数据是连续的
|
||||||
|
end_index = next_read_ptr if next_read_ptr != 0 else self.n_points
|
||||||
|
data = self.buffer[:, self.readPtr:end_index]
|
||||||
|
else:
|
||||||
|
# 情况 2:发生环绕,数据被分成两部分
|
||||||
|
# 第一部分:从 readPtr 到缓冲区末尾
|
||||||
|
part1 = self.buffer[:, self.readPtr:]
|
||||||
|
# 第二部分:从缓冲区开头到 (count - part1.shape[1]) 个点
|
||||||
|
part2 = self.buffer[:, :next_read_ptr]
|
||||||
|
# 将两部分在列方向上拼接
|
||||||
|
data = np.concatenate((part1, part2), axis=1)
|
||||||
|
|
||||||
|
# 更新读指针
|
||||||
|
self.readPtr = next_read_ptr
|
||||||
|
# 更新大小
|
||||||
|
self.nUpdate -= count
|
||||||
|
return data
|
||||||
|
|
||||||
|
# reset buffer
|
||||||
|
def resetAllPara(self):
|
||||||
|
self.nUpdate = 0
|
||||||
|
self.currentPtr = 0
|
||||||
|
self.readPtr = 0 # add by lizhenhua 清空读指针
|
||||||
|
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
|
||||||
|
|
||||||
|
|
||||||
|
class SunnyLinker64(Thread, ):
|
||||||
|
t_buffer = 10
|
||||||
|
n_chan = 64
|
||||||
|
srate = 250
|
||||||
|
receiveData = b''
|
||||||
|
toUv=True#转为uV
|
||||||
|
RingBufferLock = threading.Lock()
|
||||||
|
|
||||||
|
# 单例模式
|
||||||
|
_instance = None
|
||||||
|
_initialized = False # 检查是否已经初始化
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(SunnyLinker64, cls).__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
def __init__(self, host='127.0.0.1', port=7878, srate=250, n_chan=64,method = 'tcp'):
|
||||||
|
if SunnyLinker64._initialized:
|
||||||
|
return
|
||||||
|
Thread.__init__(self)
|
||||||
|
self.daemon = True
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.srate = srate
|
||||||
|
self.n_chan = n_chan
|
||||||
|
self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输
|
||||||
|
self.__ringBuffer = RingBuffer(self.n_chan + 2,
|
||||||
|
int(np.round(self.t_buffer * self.srate)))
|
||||||
|
self.energy = 0 # 电量
|
||||||
|
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
||||||
|
self.gain_value = 6 # 增益倍数
|
||||||
|
|
||||||
|
# 设置初始化标志为True,防止重复初始化
|
||||||
|
SunnyLinker64._initialized = True
|
||||||
|
|
||||||
|
# --- 新增:用于心跳检测 ---
|
||||||
|
self.last_called = 0 # 初始化为0
|
||||||
|
self.last_called_lock = threading.Lock() # 保护 last_called 的访问
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def set_sampleRate(self,sampleRate_Code=0x00):
|
||||||
|
'''
|
||||||
|
设置采样率
|
||||||
|
:param sampleRate_Code: 0x00:250Hz,0x01:500Hz,0x02:1000Hz,0x03:2000Hz
|
||||||
|
'''
|
||||||
|
function_code = 0x02
|
||||||
|
gain_code = 0x06
|
||||||
|
sampleRate_Code = [gain_code,sampleRate_Code]
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, sampleRate_Code)
|
||||||
|
if self.method == 'tcp':
|
||||||
|
self.sock.send(packed_data)
|
||||||
|
|
||||||
|
def push_trigger(self,label):
|
||||||
|
'''
|
||||||
|
数据打标
|
||||||
|
@param label:标签类别
|
||||||
|
'''
|
||||||
|
function_code = None
|
||||||
|
label = [label]
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, label)
|
||||||
|
if self.method == 'tcp' and hasattr(self,'serial'):
|
||||||
|
print('发送:', label, datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3])
|
||||||
|
self.serial.write(packed_data)
|
||||||
|
def Impedance(self, On):
|
||||||
|
'''
|
||||||
|
阻抗检测开关
|
||||||
|
:param On:True为开启,False为关闭
|
||||||
|
:return: 组好的协议帧
|
||||||
|
'''
|
||||||
|
function_code = 0x01
|
||||||
|
if On:
|
||||||
|
data = [0x1]
|
||||||
|
self.gain_value = 6
|
||||||
|
else:
|
||||||
|
data = [0x0]
|
||||||
|
self.gain_value = 6
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, data)
|
||||||
|
if self.method == 'tcp':
|
||||||
|
self.sock.send(packed_data)
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
try:
|
||||||
|
if self.method == 'serial':
|
||||||
|
# 开启com口,波特率115200,超时5
|
||||||
|
self.sock = serial.Serial(self.host, self.port, timeout=5)
|
||||||
|
self.sock.flushInput() # 清空缓冲区
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
while not count:
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
# # 接收和存储数据
|
||||||
|
data = (self.sock.read(count))
|
||||||
|
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
|
||||||
|
elif self.method == 'tcp':
|
||||||
|
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
self.sock.connect((self.host, int(self.port)))
|
||||||
|
self.set_sampleRate(0x00) #设置250Hz采样率
|
||||||
|
except Exception as e:
|
||||||
|
print("请打开头环")
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
print("connected")
|
||||||
|
|
||||||
|
def extract_packet(self, packet):
|
||||||
|
# 存储一个点的八通道数据
|
||||||
|
dataList = []
|
||||||
|
# 存储116个点的八通道数据
|
||||||
|
dataMatrix = []
|
||||||
|
|
||||||
|
for j in range(5):
|
||||||
|
for i in range(self.n_chan):
|
||||||
|
if not self.toUv:#原始数据直接输出
|
||||||
|
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
|
||||||
|
194 * j + 25 + 2 + i * 3]
|
||||||
|
|
||||||
|
else:#转为uV
|
||||||
|
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
|
||||||
|
194 * j + 25 + 2 + i * 3]
|
||||||
|
if val < 8388608:
|
||||||
|
val = val * 4.5 / self.gain_value / 8388608 * 1000000;
|
||||||
|
else:
|
||||||
|
val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000;
|
||||||
|
dataList.append(val)
|
||||||
|
#同步触发源
|
||||||
|
val = packet[194 * j + 25 + (i+1) * 3]
|
||||||
|
dataList.append(val)
|
||||||
|
#同步触发序号
|
||||||
|
val = packet[194 * j + 25 + (i+1) * 3+1]
|
||||||
|
dataList.append(val)
|
||||||
|
|
||||||
|
|
||||||
|
# 将数据矩阵进行拼接
|
||||||
|
if len(dataMatrix) == 0:
|
||||||
|
dataMatrix = np.asmatrix(dataList)
|
||||||
|
else:
|
||||||
|
dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0)
|
||||||
|
dataList.clear()
|
||||||
|
return np.transpose(dataMatrix)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.connect()
|
||||||
|
self.running = True
|
||||||
|
self.PackageLength = 998
|
||||||
|
# 启动心跳检测线程
|
||||||
|
threading.Thread(target=self.heartbeat_checker, daemon=True).start()
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
if self.method == 'serial':
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
if count:
|
||||||
|
# 接收和存储数据
|
||||||
|
data = (self.sock.read(count))
|
||||||
|
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
|
||||||
|
elif self.method == 'tcp':
|
||||||
|
data = self.sock.recv(600)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
self.receiveData += data
|
||||||
|
with self.last_called_lock:
|
||||||
|
self.last_called = time.time()
|
||||||
|
self.status_code = 1 # 收到数据,标记为正常
|
||||||
|
if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind(
|
||||||
|
b'\x55\x55') >= self.PackageLength - 2:
|
||||||
|
|
||||||
|
index = self.receiveData.index(b'\xaa')
|
||||||
|
self.receiveData = self.receiveData[index:]
|
||||||
|
if len(self.receiveData) >= self.PackageLength:
|
||||||
|
onepackage = self.receiveData[:self.PackageLength]
|
||||||
|
if onepackage[7] != 0:
|
||||||
|
self.energy = onepackage[7] # 电量
|
||||||
|
self.receiveData = self.receiveData[self.PackageLength:]
|
||||||
|
dataMatrix = self.extract_packet(onepackage)
|
||||||
|
try:
|
||||||
|
with self.RingBufferLock:
|
||||||
|
self.__ringBuffer.appendBuffer(dataMatrix)
|
||||||
|
except Exception as e:
|
||||||
|
print("锁:写入异常",e)
|
||||||
|
# self.RingBufferLock.release()
|
||||||
|
except ConnectionResetError:
|
||||||
|
self.status_code = 0 # 状态异常
|
||||||
|
print("Connection was reset by the peer.")
|
||||||
|
break
|
||||||
|
self.sock.close()
|
||||||
|
|
||||||
|
# --- 新增:心跳检测线程 ---
|
||||||
|
def heartbeat_checker(self):
|
||||||
|
"""
|
||||||
|
定期检查是否在最近2秒内收到 eegData
|
||||||
|
如果超过2秒未收到,则设置 status_code = 0
|
||||||
|
"""
|
||||||
|
while self.running:
|
||||||
|
time.sleep(0.5) # 每0.5秒检查一次
|
||||||
|
with self.last_called_lock:
|
||||||
|
now = time.time()
|
||||||
|
# 只有收到过一次数据后才开始判断超时
|
||||||
|
if self.last_called > 0 and (now - self.last_called) > 2:
|
||||||
|
if self.status_code != 0:
|
||||||
|
print("EEG data timeout: disconnected")
|
||||||
|
self.status_code = 0
|
||||||
|
def getImpedance(self, data,n_chan):
|
||||||
|
'''
|
||||||
|
获取阻抗值,已经放大100倍,单位是kΩ
|
||||||
|
@param data: 准备计算的通道数据,每通道200个值,注意不要把信号打标的通道传进来
|
||||||
|
@return: 返回各个通道的阻抗值
|
||||||
|
'''
|
||||||
|
impedanceList = []
|
||||||
|
data = data[:n_chan]
|
||||||
|
for channelindex in range(data.shape[0]):
|
||||||
|
if len(data[channelindex]) > 0:
|
||||||
|
data_list = []
|
||||||
|
# 设计陷波滤波器,去除50Hz成分
|
||||||
|
is50filter = True
|
||||||
|
if is50filter:
|
||||||
|
b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽,1000是采样频率
|
||||||
|
data_list = signal.lfilter(b, a, data[channelindex].tolist())
|
||||||
|
|
||||||
|
else:
|
||||||
|
data_list.extend(data[channelindex].tolist())
|
||||||
|
|
||||||
|
data_list = data_list[-1000:]
|
||||||
|
# 执行FFT
|
||||||
|
fft_result = np.fft.fft(data_list)
|
||||||
|
fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果
|
||||||
|
freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴
|
||||||
|
|
||||||
|
# y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())],
|
||||||
|
# fft_magnitude[1:-1] * 2 / len(t[0].tolist()),
|
||||||
|
# [fft_magnitude[-1] / len(t[0].tolist())]))
|
||||||
|
|
||||||
|
# 找到幅值最大的频率成分的索引(忽略直流分量,即索引0)
|
||||||
|
max_index = np.argmax(fft_magnitude[1:])
|
||||||
|
|
||||||
|
# 获取最大幅值的频率索引(加上1,因为索引0是直流分量)
|
||||||
|
freq_index = max_index + 1
|
||||||
|
|
||||||
|
# 获取最大幅值
|
||||||
|
max_magnitude = fft_magnitude[freq_index]
|
||||||
|
|
||||||
|
# 阻抗
|
||||||
|
import math
|
||||||
|
result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4
|
||||||
|
result *= 0.44 * 100 # 统一放大100倍
|
||||||
|
impedanceList.append(int(result))
|
||||||
|
# print(max_magnitude, result)
|
||||||
|
else:
|
||||||
|
impedanceList.append(0)
|
||||||
|
impedances = np.array(impedanceList)
|
||||||
|
return impedances
|
||||||
|
def getData(self,count):
|
||||||
|
'''
|
||||||
|
获取最新的数据
|
||||||
|
@param count: 每通道返回的最数值数目
|
||||||
|
@return: 所有通道的最新count个数值
|
||||||
|
'''
|
||||||
|
data=None
|
||||||
|
try:
|
||||||
|
with self.RingBufferLock:
|
||||||
|
data = self.__ringBuffer.getData(count)
|
||||||
|
except:
|
||||||
|
print("锁:读取异常")
|
||||||
|
# self.RingBufferLock.release()
|
||||||
|
|
||||||
|
|
||||||
|
return data
|
||||||
|
def GetDataLenCount(self):
|
||||||
|
'''
|
||||||
|
获取最新缓存中每个通道的数量
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
return self.__ringBuffer.nUpdate
|
||||||
|
|
||||||
|
def ResetAll(self):
|
||||||
|
'''
|
||||||
|
清空缓存
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
with self.RingBufferLock:
|
||||||
|
self.__ringBuffer.resetAllPara()
|
||||||
|
def stop(self):
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Usage
|
||||||
|
Linker = SunnyLinker64('127.0.0.1', 5086, 1000, 65)
|
||||||
|
Linker.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(0.005)
|
||||||
|
if(Linker.count()>0):
|
||||||
|
# print(Linker.ringBuffer.nUpdate)
|
||||||
|
t = Linker.getData()
|
||||||
|
print(t.shape[1], Linker.count())
|
||||||
|
# Linker.ringBuffer.nUpdate=0
|
||||||
|
# time.sleep(0.2)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
Linker.stop()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
113
algorithm_V0/datacollect/build_algorithm.spec
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
# -*- mode: python ; coding: utf-8 -*-
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from PyInstaller.utils.hooks import collect_submodules, collect_data_files
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 1. 工程配置区 (Project Config)
|
||||||
|
# ========================================================
|
||||||
|
block_cipher = None
|
||||||
|
ENTRY_POINT = 'start_parse.py'
|
||||||
|
APP_NAME = 'start_parse' # 打包后生成的文件夹名和 exe 名
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 2. 依赖分析 (Dependency Analysis)
|
||||||
|
# ========================================================
|
||||||
|
hidden_imports = [
|
||||||
|
# eegParser 依赖
|
||||||
|
'numpy',
|
||||||
|
'numpy.lib.stride_tricks',
|
||||||
|
'pandas',
|
||||||
|
'scipy',
|
||||||
|
'scipy.io',
|
||||||
|
'scipy.io.savemat',
|
||||||
|
'scipy.signal',
|
||||||
|
|
||||||
|
# SunnyLinker 依赖
|
||||||
|
'serial',
|
||||||
|
'serial.serialutil',
|
||||||
|
'socket',
|
||||||
|
|
||||||
|
# zmq 通信依赖
|
||||||
|
'zmq',
|
||||||
|
'zmq.asyncio',
|
||||||
|
|
||||||
|
# 其他可能遗漏的模块
|
||||||
|
'threading',
|
||||||
|
'datetime',
|
||||||
|
]
|
||||||
|
|
||||||
|
# 收集 zmq 的所有子模块
|
||||||
|
try:
|
||||||
|
hidden_imports += collect_submodules('zmq')
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 3. 资源锚定 (Data Anchoring)
|
||||||
|
# ========================================================
|
||||||
|
# 打包时需要包含的资源文件
|
||||||
|
datas = [
|
||||||
|
('xy_64.xlsx', '.'), # 电极位置文件
|
||||||
|
]
|
||||||
|
|
||||||
|
# 收集 mne 的数据文件(如果有)
|
||||||
|
try:
|
||||||
|
datas += collect_data_files('mne')
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 4. 构建流程 (Build Process)
|
||||||
|
# ========================================================
|
||||||
|
a = Analysis(
|
||||||
|
[ENTRY_POINT],
|
||||||
|
pathex=[],
|
||||||
|
binaries=[],
|
||||||
|
datas=datas,
|
||||||
|
hiddenimports=hidden_imports,
|
||||||
|
hookspath=[],
|
||||||
|
hooksconfig={},
|
||||||
|
runtime_hooks=[],
|
||||||
|
excludes=['tkinter', 'PyQt5', 'PySide2', 'IPython', 'matplotlib'],
|
||||||
|
win_no_prefer_redirects=False,
|
||||||
|
win_private_assemblies=False,
|
||||||
|
cipher=block_cipher,
|
||||||
|
noarchive=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||||
|
|
||||||
|
exe = EXE(
|
||||||
|
pyz,
|
||||||
|
a.scripts,
|
||||||
|
[],
|
||||||
|
exclude_binaries=True,
|
||||||
|
name=APP_NAME,
|
||||||
|
debug=False,
|
||||||
|
bootloader_ignore_signals=False,
|
||||||
|
strip=False,
|
||||||
|
upx=True,
|
||||||
|
upx_exclude=[],
|
||||||
|
console=True,
|
||||||
|
disable_windowed_traceback=False,
|
||||||
|
argv_emulation=False,
|
||||||
|
target_arch=None,
|
||||||
|
codesign_identity=None,
|
||||||
|
entitlements_file=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 5. 打包模式: OneDir (单文件夹)
|
||||||
|
# ========================================================
|
||||||
|
coll = COLLECT(
|
||||||
|
exe,
|
||||||
|
a.binaries,
|
||||||
|
a.zipfiles,
|
||||||
|
a.datas,
|
||||||
|
strip=False,
|
||||||
|
upx=True,
|
||||||
|
upx_exclude=[],
|
||||||
|
name=APP_NAME,
|
||||||
|
)
|
||||||
76
algorithm_V0/datacollect/build_datacollect.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
打包脚本 - datacollect
|
||||||
|
用于将 EEG 数据采集程序打包为独立的 exe 文件
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 1. 定义路径
|
||||||
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
DIST_DIR = os.path.join(BASE_DIR, 'dist')
|
||||||
|
BUILD_DIR = os.path.join(BASE_DIR, 'build')
|
||||||
|
APP_NAME = 'start_parse'
|
||||||
|
|
||||||
|
# 2. 清理旧构建
|
||||||
|
print("[1/3] Cleaning up old builds...")
|
||||||
|
for dir_path in [DIST_DIR, BUILD_DIR]:
|
||||||
|
if os.path.exists(dir_path):
|
||||||
|
try:
|
||||||
|
shutil.rmtree(dir_path)
|
||||||
|
print(f" Cleaned {os.path.basename(dir_path)}/")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not clean {dir_path}: {e}")
|
||||||
|
|
||||||
|
# 3. 检查必要文件
|
||||||
|
print("\n[2/3] Checking required files...")
|
||||||
|
required_files = ['start_parse.py', 'eegParser.py', 'build_algorithm.spec', 'xy_64.xlsx']
|
||||||
|
for f in required_files:
|
||||||
|
path = os.path.join(BASE_DIR, f)
|
||||||
|
if os.path.exists(path):
|
||||||
|
print(f" ✓ {f}")
|
||||||
|
else:
|
||||||
|
print(f" ✗ {f} NOT FOUND!")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 4. 运行 PyInstaller
|
||||||
|
print("\n[3/3] Running PyInstaller...")
|
||||||
|
spec_file = os.path.join(BASE_DIR, 'build_algorithm.spec')
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"-m", "PyInstaller",
|
||||||
|
spec_file,
|
||||||
|
"--clean",
|
||||||
|
"--noconfirm"
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.check_call(cmd, cwd=BASE_DIR)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"\n✗ PyInstaller failed with error code: {e.returncode}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 5. 验证结果
|
||||||
|
exe_path = os.path.join(DIST_DIR, APP_NAME, f'{APP_NAME}.exe')
|
||||||
|
if os.path.exists(exe_path):
|
||||||
|
size_mb = os.path.getsize(exe_path) / (1024 * 1024)
|
||||||
|
print(f"\n{'='*50}")
|
||||||
|
print(f"✓ SUCCESS! Executable created:")
|
||||||
|
print(f" {exe_path}")
|
||||||
|
print(f" Size: {size_mb:.1f} MB")
|
||||||
|
print(f"{'='*50}")
|
||||||
|
print(f"\n部署说明:")
|
||||||
|
print(f" 1. 复制 dist/start_parse 文件夹到目标电脑")
|
||||||
|
print(f" 2. 确保目标电脑已安装 EEG 设备的 USB 驱动")
|
||||||
|
print(f" 3. 运行 start_parse.exe")
|
||||||
|
else:
|
||||||
|
print("\n✗ Build failed - executable not found")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
72
algorithm_V0/datacollect/build_with_copy.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 1. 定义路径
|
||||||
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
DIST_DIR = os.path.join(BASE_DIR, 'dist')
|
||||||
|
APP_NAME = 'Depression_Decoder'
|
||||||
|
|
||||||
|
MODEL_SRC = os.path.join(BASE_DIR, 'model')
|
||||||
|
RAW_DATA_SRC = os.path.join(BASE_DIR, 'raw_data')
|
||||||
|
|
||||||
|
# 2. 清理旧构建
|
||||||
|
print("[1/3] Cleaning up old builds...")
|
||||||
|
if os.path.exists(DIST_DIR):
|
||||||
|
try:
|
||||||
|
shutil.rmtree(DIST_DIR)
|
||||||
|
print(" Cleaned dist/")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not clean dist/: {e}")
|
||||||
|
|
||||||
|
BUILD_DIR = os.path.join(BASE_DIR, 'build')
|
||||||
|
if os.path.exists(BUILD_DIR):
|
||||||
|
try:
|
||||||
|
shutil.rmtree(BUILD_DIR)
|
||||||
|
print(" Cleaned build/")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not clean build/: {e}")
|
||||||
|
|
||||||
|
# 3. 运行 PyInstaller
|
||||||
|
print("[2/3] Running PyInstaller...")
|
||||||
|
cmd = [
|
||||||
|
"pyinstaller",
|
||||||
|
"build_algorithm.spec",
|
||||||
|
"--clean",
|
||||||
|
"--noconfirm"
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.check_call(cmd, shell=True)
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
print("Error: PyInstaller failed.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 4. 复制外部资源 (如果存在)
|
||||||
|
print("[3/3] Copying external resources...")
|
||||||
|
|
||||||
|
# 确保 dist 目录存在 (pyinstaller 应该已经创建了)
|
||||||
|
if not os.path.exists(DIST_DIR):
|
||||||
|
os.makedirs(DIST_DIR)
|
||||||
|
|
||||||
|
for src_path, folder_name in [(MODEL_SRC, 'model'), (RAW_DATA_SRC, 'raw_data')]:
|
||||||
|
dst_path = os.path.join(DIST_DIR, folder_name)
|
||||||
|
if os.path.exists(src_path):
|
||||||
|
try:
|
||||||
|
if os.path.exists(dst_path):
|
||||||
|
shutil.rmtree(dst_path)
|
||||||
|
shutil.copytree(src_path, dst_path)
|
||||||
|
print(f" Copied {folder_name} to dist/")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Error copying {folder_name}: {e}")
|
||||||
|
else:
|
||||||
|
print(f" Note: {folder_name} source not found at {src_path}, skipping.")
|
||||||
|
|
||||||
|
print("\n" + "="*50)
|
||||||
|
print(f"SUCCESS! Executable is in: {DIST_DIR}")
|
||||||
|
print("="*50)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
427
algorithm_V0/datacollect/eegParser.py
Normal file
@@ -0,0 +1,427 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from SunnyLinker import SunnyLinker64
|
||||||
|
from zmqServer import zmqServer
|
||||||
|
from zmqClient import zmqClient
|
||||||
|
from scipy.io import savemat
|
||||||
|
from scipy import signal
|
||||||
|
|
||||||
|
class Parser_main(threading.Thread):
|
||||||
|
def __init__(self):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.Running = True
|
||||||
|
self.fs = 250 # 采样率
|
||||||
|
self.energy = 0 # 电量
|
||||||
|
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
||||||
|
self.n_chan = 64
|
||||||
|
self.dataBuffer = []
|
||||||
|
self.file_num = 0#保存文件序号
|
||||||
|
self.subject_id = None #受试者ID
|
||||||
|
self.session_id = None #Session ID
|
||||||
|
self.last_print_time = None
|
||||||
|
|
||||||
|
# 预处理参数
|
||||||
|
self.enable_preprocess = True # 是否启用预处理
|
||||||
|
self.lowcut = 0.5 # 高通滤波截止频率 (Hz)
|
||||||
|
self.highcut = 50 # 低通滤波截止频率 (Hz)
|
||||||
|
self.notch_freq = 50 # 工频陷波频率 (Hz)
|
||||||
|
self.ref_chan_name = 'CPZ' # 参考电极名称
|
||||||
|
self.ref_chan_idx = None # 参考电极索引(运行时确定)
|
||||||
|
self._init_filter_cache() # 初始化滤波器缓存
|
||||||
|
|
||||||
|
# 单位转换参数
|
||||||
|
self.calibration_scale = 1.0 # 校准系数,用于修正单位转换误差
|
||||||
|
self.calibration_offset = 0 # 校准偏移量
|
||||||
|
self._conversion_verified = False # 是否已验证转换
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
self.thread_data_server = SunnyLinker64('127.0.0.1', 7878, 250, 64,
|
||||||
|
method='tcp')
|
||||||
|
self.thread_data_server.toUv = True
|
||||||
|
self.thread_data_server.start()
|
||||||
|
|
||||||
|
self.zmqServer = zmqServer()
|
||||||
|
self.zmqServer.start()
|
||||||
|
self.zmqClient = zmqClient('127.0.0.1', 8088)
|
||||||
|
self.zmqClient.connect()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
while self.Running:
|
||||||
|
# 同步信息
|
||||||
|
if self.zmqServer.state_mode == 'sync':
|
||||||
|
self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
||||||
|
self.zmqServer.state_mode = 'rest'
|
||||||
|
# 状态异常,报告上位机
|
||||||
|
if self.status_code != self.thread_data_server.status_code:
|
||||||
|
self.status_code = self.thread_data_server.status_code
|
||||||
|
self.zmqClient.send_to_all('status_code', int(self.status_code))
|
||||||
|
|
||||||
|
# 返回电量
|
||||||
|
if self.energy != self.thread_data_server.energy:
|
||||||
|
self.energy = self.thread_data_server.energy
|
||||||
|
self.zmqClient.send_to_all('energy', int(self.energy))
|
||||||
|
|
||||||
|
# 更新文件序号
|
||||||
|
if self.subject_id != self.zmqServer.subject_id or self.session_id != self.zmqServer.session_id:
|
||||||
|
self.subject_id = self.zmqServer.subject_id
|
||||||
|
self.session_id = self.zmqServer.session_id
|
||||||
|
self.file_num = 0 #从零开始计数
|
||||||
|
|
||||||
|
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
|
||||||
|
self.thread_data_server.Impedance(True)
|
||||||
|
self.zmqServer.open_Impedance = -1
|
||||||
|
elif self.zmqServer.open_Impedance == False:
|
||||||
|
self.thread_data_server.Impedance(False)
|
||||||
|
self.zmqServer.open_Impedance = -1
|
||||||
|
|
||||||
|
if self.zmqServer.get_Impedance: # 返回阻抗值
|
||||||
|
if self.thread_data_server.GetDataLenCount() > self.fs:
|
||||||
|
Impe_data = self.thread_data_server.getData(self.fs)
|
||||||
|
# 计算阻抗
|
||||||
|
imps = self.thread_data_server.getImpedance(Impe_data, self.n_chan)
|
||||||
|
self.zmqClient.send_to_all('impedance', imps.tolist())
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
if self.thread_data_server.GetDataLenCount() < 50:
|
||||||
|
time.sleep(0.01)
|
||||||
|
continue
|
||||||
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
|
data = self.thread_data_server.getData(50)
|
||||||
|
data = data[:self.n_chan, :]
|
||||||
|
|
||||||
|
# 数据质量检查与预处理
|
||||||
|
if self.enable_preprocess:
|
||||||
|
# 1. 首先验证和校准单位转换
|
||||||
|
data, calibrated = self.verify_and_calibrate_unit(data)
|
||||||
|
if calibrated:
|
||||||
|
print('[INFO] 单位转换已自动校准')
|
||||||
|
|
||||||
|
# 2. 检查数据质量
|
||||||
|
issues = self.check_data_quality(data)
|
||||||
|
if issues:
|
||||||
|
print('[警告] 检测到数据质量问题:')
|
||||||
|
for issue in issues:
|
||||||
|
print(f' - {issue}')
|
||||||
|
print('[INFO] 正在进行信号预处理...')
|
||||||
|
|
||||||
|
# 3. 执行预处理
|
||||||
|
data = self.preprocess_data(data)
|
||||||
|
|
||||||
|
# 4. 预处理后验证
|
||||||
|
if issues:
|
||||||
|
new_issues = self.check_data_quality(data)
|
||||||
|
if not new_issues:
|
||||||
|
print(f'[INFO] 预处理完成,数据幅度正常: {np.max(np.abs(data)):.2f} µV')
|
||||||
|
else:
|
||||||
|
print('[警告] 预处理后仍存在问题:')
|
||||||
|
for issue in new_issues:
|
||||||
|
print(f' - {issue}')
|
||||||
|
|
||||||
|
if self.zmqServer.mat_generate:
|
||||||
|
# 检测是否需要重置缓冲区(第二次发送 matGenerate 时清空旧数据)
|
||||||
|
if self.zmqServer.reset_mat_buffer:
|
||||||
|
self.dataBuffer = []
|
||||||
|
self.last_print_time = None
|
||||||
|
self.zmqServer.reset_mat_buffer = False
|
||||||
|
print('[INFO] 数据缓冲区已重置,从头开始采集')
|
||||||
|
|
||||||
|
self.dataBuffer.append(data)
|
||||||
|
if len(self.dataBuffer) % 50 == 0:
|
||||||
|
current_time = time.time()
|
||||||
|
if self.last_print_time is not None:
|
||||||
|
elapsed_time = current_time - self.last_print_time
|
||||||
|
# 2500个点 = 50个数据块 * 50个采样点/数据块
|
||||||
|
actual_fs = 2500 / elapsed_time
|
||||||
|
print(f"接收 2500 个采样点耗时: {elapsed_time:.4f} 秒, 折合实际采样率: {actual_fs:.2f} Hz")
|
||||||
|
else:
|
||||||
|
print("开始计时...")
|
||||||
|
self.last_print_time = current_time
|
||||||
|
print('数据保存进度: {}/{}'.format(len(self.dataBuffer),int(self.zmqServer.save_win*self.fs//50)))
|
||||||
|
if len(self.dataBuffer) >= int(self.zmqServer.save_win*self.fs//50): #5分钟*60秒*250Hz / 50
|
||||||
|
self.zmqServer.mat_generate = False
|
||||||
|
matData = np.hstack(self.dataBuffer[:int(self.zmqServer.save_win*self.fs//50)])
|
||||||
|
self.dataBuffer = []
|
||||||
|
self.last_print_time = None # 重置计时器以备下次使用
|
||||||
|
self.pack2mat(matData,self.subject_id,self.session_id)
|
||||||
|
|
||||||
|
def pack2mat(self,data,subject_id,session_id):
|
||||||
|
#EEG数据
|
||||||
|
Data = data.T
|
||||||
|
#通道名称
|
||||||
|
channel_names = np.array(
|
||||||
|
['AIN1', 'AIN2', 'AIN3', 'AIN4', 'AIN5', 'AIN6', 'AIN7', 'AIN8', 'AIN9', 'AIN10', 'AIN11', 'AIN12',
|
||||||
|
'AIN13', 'AIN14', 'AIN15', 'AIN16', 'AIN17', 'AIN18', 'AIN19', 'AIN20', 'AIN21', 'AIN22', 'AIN23',
|
||||||
|
'AIN24', 'AIN25', 'AIN26', 'AIN27', 'AIN28', 'AIN29', 'AIN30', 'AIN31', 'AIN32', 'AIN33', 'AIN34',
|
||||||
|
'AIN35', 'AIN36', 'AIN37', 'AIN38', 'AIN39', 'AIN40', 'AIN41', 'AIN42', 'AIN43', 'AIN44', 'AIN45',
|
||||||
|
'AIN46', 'AIN47', 'AIN48', 'AIN49', 'AIN50', 'AIN51', 'AIN52', 'AIN53', 'AIN54', 'AIN55', 'AIN56',
|
||||||
|
'AIN57', 'AIN58', 'AIN59', 'AIN60', 'AIN61', 'AIN62', 'AIN63', 'AIN64'], dtype=object)
|
||||||
|
#采样率
|
||||||
|
sample_rate = self.fs
|
||||||
|
#通道数量
|
||||||
|
node_number = Data.shape[1]
|
||||||
|
# 时间轴
|
||||||
|
t = np.linspace(0, self.zmqServer.save_win, Data.shape[0])
|
||||||
|
t = t.reshape(len(t), 1)
|
||||||
|
#电极名称
|
||||||
|
electrode_name = np.array(['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1',
|
||||||
|
'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6',
|
||||||
|
'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1','T7',
|
||||||
|
'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ','PZ',
|
||||||
|
'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1'],
|
||||||
|
dtype=object)
|
||||||
|
#电极三维坐标
|
||||||
|
electrode_xyz = self.read_ch_pos()
|
||||||
|
electrode_xyz.update({'A1': [-0.095, 0, -0.005]})
|
||||||
|
electrode_xyz = {key: electrode_xyz[key] for key in electrode_name}
|
||||||
|
electrode_xyz = np.array(list(electrode_xyz.values()))
|
||||||
|
#电极坐标所属的坐标系
|
||||||
|
electrode_coord_system = '10-20 spherical model'
|
||||||
|
#受试者ID
|
||||||
|
Subject_id = subject_id
|
||||||
|
#Session ID
|
||||||
|
Session_id = session_id
|
||||||
|
#参考电极方案
|
||||||
|
ref = 'CPZ'
|
||||||
|
#数据采集开始时间
|
||||||
|
start_time = 0
|
||||||
|
|
||||||
|
meta_struct = {
|
||||||
|
'subject_id': Subject_id,
|
||||||
|
'session_id': Session_id,
|
||||||
|
'ref': ref,
|
||||||
|
'start_time': start_time
|
||||||
|
}
|
||||||
|
|
||||||
|
eeg_struct = {
|
||||||
|
'data': Data,
|
||||||
|
'chn': channel_names,
|
||||||
|
'sample_rate': sample_rate,
|
||||||
|
'node_number': node_number,
|
||||||
|
't': t,
|
||||||
|
'electrode_name': electrode_name,
|
||||||
|
'electrode_xyz': electrode_xyz,
|
||||||
|
'electrode_coord_system': electrode_coord_system,
|
||||||
|
'meta': meta_struct,
|
||||||
|
}
|
||||||
|
|
||||||
|
fileDir = os.path.join('EEGfiles/',Subject_id,Session_id)
|
||||||
|
os.makedirs(fileDir,exist_ok=True)
|
||||||
|
filePath = os.path.join(fileDir,'eeg_data{}.mat'.format(self.file_num))
|
||||||
|
# 保存到 .mat 文件,顶层变量名为 'eeg'
|
||||||
|
savemat(filePath, {'eeg': eeg_struct})
|
||||||
|
print('EEGfile saved at {}'.format(filePath))
|
||||||
|
self.zmqClient.send_to_all('filePath', filePath)
|
||||||
|
self.file_num += 1
|
||||||
|
|
||||||
|
def read_ch_pos(self,file_path=r'xy_64.xlsx'):
|
||||||
|
"""
|
||||||
|
将电极位置信息转换为Dict
|
||||||
|
|
||||||
|
参数:
|
||||||
|
file_path: 电极位置存储文件, 必须包含'channel', 'x', 'y', 'z'列
|
||||||
|
|
||||||
|
"""
|
||||||
|
if getattr(sys, 'frozen', False):
|
||||||
|
script_dir = sys._MEIPASS
|
||||||
|
else:
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
file_path = os.path.join(script_dir, file_path)
|
||||||
|
df = pd.read_excel(file_path)
|
||||||
|
# 确保列名正确
|
||||||
|
if not all(col in df.columns for col in ['channel', 'x', 'y', 'z']):
|
||||||
|
raise ValueError("DataFrame必须包含'channel', 'x', 'y', 'z'列")
|
||||||
|
# 创建电极位置字典
|
||||||
|
ch_pos = {}
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
ch_pos[row['channel']] = [row['x'], row['y'], row['z']]
|
||||||
|
return ch_pos
|
||||||
|
|
||||||
|
def _init_filter_cache(self):
|
||||||
|
"""初始化滤波器系数缓存"""
|
||||||
|
self._filter_cache = {
|
||||||
|
'highpass': None,
|
||||||
|
'lowpass': None,
|
||||||
|
'notch': None
|
||||||
|
}
|
||||||
|
self._cache_valid = False
|
||||||
|
|
||||||
|
def _design_filters(self):
|
||||||
|
"""设计滤波器系数"""
|
||||||
|
if self._cache_valid:
|
||||||
|
return
|
||||||
|
|
||||||
|
nyquist = self.fs / 2
|
||||||
|
fs_nyq = self.fs
|
||||||
|
|
||||||
|
# 高通滤波 (去除低频漂移)
|
||||||
|
high = self.lowcut / nyquist
|
||||||
|
if 0 < high < 1:
|
||||||
|
self._filter_cache['highpass'] = signal.butter(2, high, btype='high', output='ba')
|
||||||
|
|
||||||
|
# 低通滤波 (去除高频噪声)
|
||||||
|
low = self.highcut / nyquist
|
||||||
|
if 0 < low < 1:
|
||||||
|
self._filter_cache['lowpass'] = signal.butter(4, low, btype='low', output='ba')
|
||||||
|
|
||||||
|
# 50Hz 陷波滤波 (去除工频干扰)
|
||||||
|
Q = 30 # 品质因子
|
||||||
|
self._filter_cache['notch'] = signal.iirnotch(self.notch_freq, Q, fs=fs_nyq)
|
||||||
|
|
||||||
|
# 查找CPZ通道索引
|
||||||
|
electrode_name = ['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1',
|
||||||
|
'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6',
|
||||||
|
'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1','T7',
|
||||||
|
'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ','PZ',
|
||||||
|
'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1']
|
||||||
|
try:
|
||||||
|
self.ref_chan_idx = electrode_name.index(self.ref_chan_name)
|
||||||
|
except ValueError:
|
||||||
|
self.ref_chan_idx = 50 # 默认CZ (对应索引50)
|
||||||
|
print(f'[警告] 未找到参考电极 {self.ref_chan_name},使用默认值 CZ')
|
||||||
|
|
||||||
|
self._cache_valid = True
|
||||||
|
print(f'[INFO] 预处理已启用 - 高通:{self.lowcut}Hz, 低通:{self.highcut}Hz, 陷波:{self.notch_freq}Hz, 参考:{self.ref_chan_name}(索引:{self.ref_chan_idx})')
|
||||||
|
|
||||||
|
def check_data_quality(self, data):
|
||||||
|
"""
|
||||||
|
检查数据质量
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list: 发现的问题列表,空列表表示质量正常
|
||||||
|
"""
|
||||||
|
issues = []
|
||||||
|
|
||||||
|
# 检查幅度
|
||||||
|
amplitude = np.max(np.abs(data))
|
||||||
|
if amplitude > 1e6: # 超过 1mV = 1000µV
|
||||||
|
issues.append(f'幅度异常: {amplitude:.2e} (可能为原始ADC值或单位错误)')
|
||||||
|
elif amplitude > 1000:
|
||||||
|
issues.append(f'幅度偏高: {amplitude:.2f}')
|
||||||
|
|
||||||
|
# 检查平坦噪声 (通道可能未连接)
|
||||||
|
if np.std(data) < 0.01:
|
||||||
|
issues.append('信号过平,可能通道未连接')
|
||||||
|
|
||||||
|
# 检查饱和
|
||||||
|
n_saturated = np.sum(np.abs(data) > 1e8)
|
||||||
|
if n_saturated > 0:
|
||||||
|
issues.append(f'检测到 {n_saturated} 个采样点饱和')
|
||||||
|
|
||||||
|
return issues
|
||||||
|
|
||||||
|
def verify_and_calibrate_unit(self, data):
|
||||||
|
"""
|
||||||
|
验证并校准数据单位
|
||||||
|
|
||||||
|
SunnyLinker64 的转换公式:
|
||||||
|
val = raw_adc * 4.5 / gain_value / 8388608 * 1000000 (µV)
|
||||||
|
|
||||||
|
但如果硬件实际增益与 gain_value=6 不符,会导致单位错误。
|
||||||
|
本函数通过检测数据范围来验证和修正单位。
|
||||||
|
|
||||||
|
正常EEG信号范围: ±50-100 µV
|
||||||
|
如果检测到的数据范围是 ±1e6 量级,说明转换可能有问题
|
||||||
|
|
||||||
|
参数:
|
||||||
|
data: 原始数据
|
||||||
|
|
||||||
|
返回:
|
||||||
|
tuple: (校准后的数据, 是否进行了校准)
|
||||||
|
"""
|
||||||
|
if self._conversion_verified:
|
||||||
|
return data, False
|
||||||
|
|
||||||
|
amplitude = np.max(np.abs(data))
|
||||||
|
|
||||||
|
# 判断数据是否在合理范围内
|
||||||
|
# 正常EEG: 1 - 1000 µV (考虑某些高幅值情况)
|
||||||
|
# 异常: > 1e5 µV (可能是ADC原始值未转换或转换系数错误)
|
||||||
|
|
||||||
|
if amplitude > 1e6:
|
||||||
|
print('[警告] 检测到异常大幅值数据,可能是ADC原始值或单位转换失败!')
|
||||||
|
print(f' 当前最大幅度: {amplitude:.2e} µV')
|
||||||
|
print('[INFO] 尝试自动校准单位转换...')
|
||||||
|
|
||||||
|
# SunnyLinker64 的理论转换系数约为 0.0894 µV/LSB
|
||||||
|
# 如果数据是原始ADC值,需要除以这个系数来还原
|
||||||
|
theoretical_scale = 4.5 / 6 / 8388608 * 1e6 # 理论系数: ~0.0894 µV/LSB
|
||||||
|
|
||||||
|
# 计算校准系数
|
||||||
|
# 假设数据是原始ADC值,需要除以 (amplitude / expected_amplitude)
|
||||||
|
# 正常EEG信号预期幅度约 100 µV
|
||||||
|
expected_amplitude = 100.0 # µV
|
||||||
|
|
||||||
|
if amplitude > expected_amplitude:
|
||||||
|
# 计算校准系数: 原始值 / 预期值 = 实际值 / 校准后值
|
||||||
|
self.calibration_scale = expected_amplitude / amplitude
|
||||||
|
|
||||||
|
# 应用校准
|
||||||
|
data = data * self.calibration_scale
|
||||||
|
print(f'[INFO] 校准完成,应用系数: {self.calibration_scale:.6e}')
|
||||||
|
print(f' 校准后最大幅度: {np.max(np.abs(data)):.2f} µV')
|
||||||
|
self._conversion_verified = True
|
||||||
|
return data, True
|
||||||
|
|
||||||
|
elif amplitude < 0.01:
|
||||||
|
print('[警告] 数据幅度接近零,可能通道未连接或设备异常')
|
||||||
|
|
||||||
|
self._conversion_verified = True
|
||||||
|
return data, False
|
||||||
|
|
||||||
|
def preprocess_data(self, data):
|
||||||
|
"""
|
||||||
|
EEG信号预处理
|
||||||
|
|
||||||
|
参数:
|
||||||
|
data: ndarray, shape (n_chan, n_samples), 原始EEG数据
|
||||||
|
|
||||||
|
返回:
|
||||||
|
ndarray: 预处理后的EEG数据
|
||||||
|
"""
|
||||||
|
if not self.enable_preprocess:
|
||||||
|
return data
|
||||||
|
|
||||||
|
# 确保数据是 float64 类型
|
||||||
|
data = data.astype(np.float64)
|
||||||
|
|
||||||
|
# 设计滤波器
|
||||||
|
self._design_filters()
|
||||||
|
|
||||||
|
# 1. 去除直流分量和低频漂移 (高通滤波)
|
||||||
|
if self._filter_cache['highpass'] is not None:
|
||||||
|
b, a = self._filter_cache['highpass']
|
||||||
|
for ch in range(data.shape[0]):
|
||||||
|
data[ch, :] = signal.filtfilt(b, a, data[ch, :])
|
||||||
|
|
||||||
|
# 2. 50Hz 工频陷波滤波
|
||||||
|
if self._filter_cache['notch'] is not None:
|
||||||
|
b, a = self._filter_cache['notch']
|
||||||
|
for ch in range(data.shape[0]):
|
||||||
|
data[ch, :] = signal.filtfilt(b, a, data[ch, :])
|
||||||
|
|
||||||
|
# 3. 低通滤波 (去除高频噪声)
|
||||||
|
if self._filter_cache['lowpass'] is not None:
|
||||||
|
b, a = self._filter_cache['lowpass']
|
||||||
|
for ch in range(data.shape[0]):
|
||||||
|
data[ch, :] = signal.filtfilt(b, a, data[ch, :])
|
||||||
|
|
||||||
|
# 4. 重参考 (以CPZ为参考)
|
||||||
|
if self.ref_chan_idx is not None and self.ref_chan_idx < data.shape[0]:
|
||||||
|
ref_signal = data[self.ref_chan_idx, :]
|
||||||
|
data = data - ref_signal
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
'''
|
||||||
|
停止运行
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
self.zmqServer.stop()
|
||||||
|
self.Running=False
|
||||||
207
algorithm_V0/datacollect/eegParser_scipy_package.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from SunnyLinker import SunnyLinker64
|
||||||
|
from zmqServer import zmqServer
|
||||||
|
from zmqClient import zmqClient
|
||||||
|
from scipy.io import savemat
|
||||||
|
|
||||||
|
class Parser_main(threading.Thread):
|
||||||
|
def __init__(self):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.Running = True
|
||||||
|
self.fs = 250 # 采样率
|
||||||
|
self.energy = 0 # 电量
|
||||||
|
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
||||||
|
self.n_chan = 64
|
||||||
|
self.dataBuffer = []
|
||||||
|
self.file_num = 0#保存文件序号
|
||||||
|
self.subject_id = None #受试者ID
|
||||||
|
self.session_id = None #Session ID
|
||||||
|
self.last_print_time = None
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
self.thread_data_server = SunnyLinker64('127.0.0.1', 7878, 250, 64,
|
||||||
|
method='tcp')
|
||||||
|
self.thread_data_server.toUv = True
|
||||||
|
self.thread_data_server.start()
|
||||||
|
|
||||||
|
self.zmqServer = zmqServer()
|
||||||
|
self.zmqServer.start()
|
||||||
|
self.zmqClient = zmqClient('127.0.0.1', 8088)
|
||||||
|
self.zmqClient.connect()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
while self.Running:
|
||||||
|
# 同步信息
|
||||||
|
if self.zmqServer.state_mode == 'sync':
|
||||||
|
self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
||||||
|
self.zmqServer.state_mode = 'rest'
|
||||||
|
# 状态异常,报告上位机
|
||||||
|
if self.status_code != self.thread_data_server.status_code:
|
||||||
|
self.status_code = self.thread_data_server.status_code
|
||||||
|
self.zmqClient.send_to_all('status_code', int(self.status_code))
|
||||||
|
|
||||||
|
# 返回电量
|
||||||
|
if self.energy != self.thread_data_server.energy:
|
||||||
|
self.energy = self.thread_data_server.energy
|
||||||
|
self.zmqClient.send_to_all('energy', int(self.energy))
|
||||||
|
|
||||||
|
# 更新文件序号
|
||||||
|
if self.subject_id != self.zmqServer.subject_id or self.session_id != self.zmqServer.session_id:
|
||||||
|
self.subject_id = self.zmqServer.subject_id
|
||||||
|
self.session_id = self.zmqServer.session_id
|
||||||
|
self.file_num = 0 #从零开始计数
|
||||||
|
|
||||||
|
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
|
||||||
|
self.thread_data_server.Impedance(True)
|
||||||
|
self.zmqServer.open_Impedance = -1
|
||||||
|
elif self.zmqServer.open_Impedance == False:
|
||||||
|
self.thread_data_server.Impedance(False)
|
||||||
|
self.zmqServer.open_Impedance = -1
|
||||||
|
|
||||||
|
if self.zmqServer.get_Impedance: # 返回阻抗值
|
||||||
|
if self.thread_data_server.GetDataLenCount() > self.fs:
|
||||||
|
Impe_data = self.thread_data_server.getData(self.fs)
|
||||||
|
# 计算阻抗
|
||||||
|
imps = self.thread_data_server.getImpedance(Impe_data, self.n_chan)
|
||||||
|
self.zmqClient.send_to_all('impedance', imps.tolist())
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
if self.thread_data_server.GetDataLenCount() < 50:
|
||||||
|
time.sleep(0.01)
|
||||||
|
continue
|
||||||
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
|
data = self.thread_data_server.getData(50)
|
||||||
|
data = data[:self.n_chan, :]
|
||||||
|
if self.zmqServer.mat_generate:
|
||||||
|
# 检测是否需要重置缓冲区(第二次发送 matGenerate 时清空旧数据)
|
||||||
|
if self.zmqServer.reset_mat_buffer:
|
||||||
|
self.dataBuffer = []
|
||||||
|
self.last_print_time = None
|
||||||
|
self.zmqServer.reset_mat_buffer = False
|
||||||
|
print('[INFO] 数据缓冲区已重置,从头开始采集')
|
||||||
|
|
||||||
|
self.dataBuffer.append(data)
|
||||||
|
if len(self.dataBuffer) % 50 == 0:
|
||||||
|
current_time = time.time()
|
||||||
|
if self.last_print_time is not None:
|
||||||
|
elapsed_time = current_time - self.last_print_time
|
||||||
|
# 2500个点 = 50个数据块 * 50个采样点/数据块
|
||||||
|
actual_fs = 2500 / elapsed_time
|
||||||
|
print(f"接收 2500 个采样点耗时: {elapsed_time:.4f} 秒, 折合实际采样率: {actual_fs:.2f} Hz")
|
||||||
|
else:
|
||||||
|
print("开始计时...")
|
||||||
|
self.last_print_time = current_time
|
||||||
|
print('数据保存进度: {}/{}'.format(len(self.dataBuffer),int(self.zmqServer.save_win*self.fs//50)))
|
||||||
|
if len(self.dataBuffer) >= int(self.zmqServer.save_win*self.fs//50): #5分钟*60秒*250Hz / 50
|
||||||
|
self.zmqServer.mat_generate = False
|
||||||
|
matData = np.hstack(self.dataBuffer[:int(self.zmqServer.save_win*self.fs//50)])
|
||||||
|
self.dataBuffer = []
|
||||||
|
self.last_print_time = None # 重置计时器以备下次使用
|
||||||
|
self.pack2mat(matData,self.subject_id,self.session_id)
|
||||||
|
|
||||||
|
def pack2mat(self,data,subject_id,session_id):
|
||||||
|
#EEG数据
|
||||||
|
Data = data.T
|
||||||
|
#通道名称
|
||||||
|
channel_names = np.array(
|
||||||
|
['AIN1', 'AIN2', 'AIN3', 'AIN4', 'AIN5', 'AIN6', 'AIN7', 'AIN8', 'AIN9', 'AIN10', 'AIN11', 'AIN12',
|
||||||
|
'AIN13', 'AIN14', 'AIN15', 'AIN16', 'AIN17', 'AIN18', 'AIN19', 'AIN20', 'AIN21', 'AIN22', 'AIN23',
|
||||||
|
'AIN24', 'AIN25', 'AIN26', 'AIN27', 'AIN28', 'AIN29', 'AIN30', 'AIN31', 'AIN32', 'AIN33', 'AIN34',
|
||||||
|
'AIN35', 'AIN36', 'AIN37', 'AIN38', 'AIN39', 'AIN40', 'AIN41', 'AIN42', 'AIN43', 'AIN44', 'AIN45',
|
||||||
|
'AIN46', 'AIN47', 'AIN48', 'AIN49', 'AIN50', 'AIN51', 'AIN52', 'AIN53', 'AIN54', 'AIN55', 'AIN56',
|
||||||
|
'AIN57', 'AIN58', 'AIN59', 'AIN60', 'AIN61', 'AIN62', 'AIN63', 'AIN64'], dtype=object)
|
||||||
|
#采样率
|
||||||
|
sample_rate = self.fs
|
||||||
|
#通道数量
|
||||||
|
node_number = Data.shape[1]
|
||||||
|
# 时间轴
|
||||||
|
t = np.linspace(0, self.zmqServer.save_win, Data.shape[0])
|
||||||
|
t = t.reshape(len(t), 1)
|
||||||
|
#电极名称
|
||||||
|
electrode_name = np.array(['FP1', 'FP2', 'PO6', 'POZ', 'F3', 'F4', 'FPZ', 'AF4', 'FC3', 'PO8', 'CP2', 'CP1',
|
||||||
|
'FCZ', 'PO5', 'FC2', 'FC1', 'C3', 'C4', 'FC4', 'CP4', 'P3', 'P4', 'F5', 'C5', 'F6',
|
||||||
|
'PO4', 'CP6', 'CP5', 'PO3', 'CP3', 'FC6', 'FC5', 'CB1', 'CB2', 'P5', 'AF7', 'A1','T7',
|
||||||
|
'FT7', 'TP7', 'FT8', 'AF8', 'F8', 'F7', 'P6', 'C6', 'O2', 'O1', 'T8', 'P7', 'CZ','PZ',
|
||||||
|
'P8', 'FZ', 'OZ', 'PO7', 'TP8', 'AF3', 'C2', 'C1', 'P2', 'P1', 'F2', 'F1'],
|
||||||
|
dtype=object)
|
||||||
|
#电极三维坐标
|
||||||
|
electrode_xyz = self.read_ch_pos()
|
||||||
|
electrode_xyz.update({'A1': [-0.095, 0, -0.005]})
|
||||||
|
electrode_xyz = {key: electrode_xyz[key] for key in electrode_name}
|
||||||
|
electrode_xyz = np.array(list(electrode_xyz.values()))
|
||||||
|
#电极坐标所属的坐标系
|
||||||
|
electrode_coord_system = '10-20 spherical model'
|
||||||
|
#受试者ID
|
||||||
|
Subject_id = subject_id
|
||||||
|
#Session ID
|
||||||
|
Session_id = session_id
|
||||||
|
#参考电极方案
|
||||||
|
ref = 'CPZ'
|
||||||
|
#数据采集开始时间
|
||||||
|
start_time = 0
|
||||||
|
|
||||||
|
meta_struct = {
|
||||||
|
'subject_id': Subject_id,
|
||||||
|
'session_id': Session_id,
|
||||||
|
'ref': ref,
|
||||||
|
'start_time': start_time
|
||||||
|
}
|
||||||
|
|
||||||
|
eeg_struct = {
|
||||||
|
'data': Data,
|
||||||
|
'chn': channel_names,
|
||||||
|
'sample_rate': sample_rate,
|
||||||
|
'node_number': node_number,
|
||||||
|
't': t,
|
||||||
|
'electrode_name': electrode_name,
|
||||||
|
'electrode_xyz': electrode_xyz,
|
||||||
|
'electrode_coord_system': electrode_coord_system,
|
||||||
|
'meta': meta_struct,
|
||||||
|
}
|
||||||
|
|
||||||
|
fileDir = os.path.join('EEGfiles/',Subject_id,Session_id)
|
||||||
|
os.makedirs(fileDir,exist_ok=True)
|
||||||
|
filePath = os.path.join(fileDir,'eeg_data{}.mat'.format(self.file_num))
|
||||||
|
# 保存到 .mat 文件,顶层变量名为 'eeg'
|
||||||
|
savemat(filePath, {'eeg': eeg_struct})
|
||||||
|
print('EEGfile saved at {}'.format(filePath))
|
||||||
|
self.zmqClient.send_to_all('filePath', filePath)
|
||||||
|
self.file_num += 1
|
||||||
|
|
||||||
|
def read_ch_pos(self,file_path=r'xy_64.xlsx'):
|
||||||
|
"""
|
||||||
|
将电极位置信息转换为Dict
|
||||||
|
|
||||||
|
参数:
|
||||||
|
file_path: 电极位置存储文件, 必须包含'channel', 'x', 'y', 'z'列
|
||||||
|
|
||||||
|
"""
|
||||||
|
if getattr(sys, 'frozen', False):
|
||||||
|
script_dir = sys._MEIPASS
|
||||||
|
else:
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
file_path = os.path.join(script_dir, file_path)
|
||||||
|
df = pd.read_excel(file_path)
|
||||||
|
# 确保列名正确
|
||||||
|
if not all(col in df.columns for col in ['channel', 'x', 'y', 'z']):
|
||||||
|
raise ValueError("DataFrame必须包含'channel', 'x', 'y', 'z'列")
|
||||||
|
# 创建电极位置字典
|
||||||
|
ch_pos = {}
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
ch_pos[row['channel']] = [row['x'], row['y'], row['z']]
|
||||||
|
return ch_pos
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
'''
|
||||||
|
停止运行
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
self.zmqServer.stop()
|
||||||
|
self.Running=False
|
||||||
185
algorithm_V0/datacollect/eeg_quality_check-mat.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
EEG Data Quality Check - eeg_data0.mat
|
||||||
|
===================================
|
||||||
|
1. Time Domain Signal (Full Duration)
|
||||||
|
2. Amplitude Spectrum (FFT)
|
||||||
|
3. Power Spectral Density (Linear Scale)
|
||||||
|
4. Power Spectral Density (dB Scale)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import mne
|
||||||
|
from scipy import signal
|
||||||
|
from scipy.io import loadmat
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_preprocess(filepath):
|
||||||
|
"""Load .mat file (custom format) and basic preprocessing."""
|
||||||
|
mat_data = loadmat(filepath, simplify_cells=True)
|
||||||
|
eeg = mat_data['eeg']
|
||||||
|
|
||||||
|
# Extract data (shape: samples x channels)
|
||||||
|
data = eeg['data'].T # Transpose to (channels x samples)
|
||||||
|
sfreq = eeg['sample_rate']
|
||||||
|
|
||||||
|
# Get channel names (try multiple possible keys)
|
||||||
|
if 'chn' in eeg:
|
||||||
|
ch_names = list(eeg['chn'])
|
||||||
|
elif 'electrode_name' in eeg:
|
||||||
|
ch_names = list(eeg['electrode_name'])
|
||||||
|
else:
|
||||||
|
n_channels = data.shape[0]
|
||||||
|
ch_names = [f'Ch{i+1}' for i in range(n_channels)]
|
||||||
|
|
||||||
|
# Create MNE Info object
|
||||||
|
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
|
||||||
|
raw = mne.io.RawArray(data, info)
|
||||||
|
|
||||||
|
raw.filter(l_freq=0.5, h_freq=10, fir_design='firwin', verbose=False)
|
||||||
|
return raw
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
filepath = r"D:\Ivey\Code_New_Proj\brainplot\plot64\eeg_data0511.mat"
|
||||||
|
output_path = r"D:\Ivey\Code_New_Proj\brainplot\plot64\eeg_quality_check_depression.png"
|
||||||
|
raw = load_and_preprocess(filepath)
|
||||||
|
|
||||||
|
# Print all channel names first
|
||||||
|
print(f"\nAvailable channels ({len(raw.ch_names)}):")
|
||||||
|
for i, ch in enumerate(raw.ch_names):
|
||||||
|
print(f" {i:3d}: {ch}")
|
||||||
|
|
||||||
|
select_channel = ['AIN5']
|
||||||
|
raw.pick(select_channel)
|
||||||
|
|
||||||
|
# Use all channels, full duration
|
||||||
|
ch_names = raw.ch_names
|
||||||
|
n_channels = len(ch_names)
|
||||||
|
data = raw.get_data()
|
||||||
|
sfreq = raw.info['sfreq']
|
||||||
|
n_samples = data.shape[1]
|
||||||
|
duration = n_samples / sfreq
|
||||||
|
|
||||||
|
print(f"Info: {n_channels} channels, {duration:.1f}s, {sfreq:.0f} Hz")
|
||||||
|
|
||||||
|
# Compute frequency domain data
|
||||||
|
n_fft = 2**int(np.ceil(np.log2(n_samples)))
|
||||||
|
freqs_fft = np.fft.rfftfreq(n_fft, 1 / sfreq)
|
||||||
|
fft_vals = np.fft.rfft(data, n=n_fft)
|
||||||
|
amplitude = np.abs(fft_vals) / n_fft * 2
|
||||||
|
|
||||||
|
freqs_psd, psd = signal.welch(data, fs=sfreq, nperseg=4096,
|
||||||
|
noverlap=2048, scaling='density')
|
||||||
|
|
||||||
|
# Frequency mask: 0.5-80 Hz
|
||||||
|
mask_fft = (freqs_fft >= 0.5) & (freqs_fft <= 80)
|
||||||
|
mask_psd = (freqs_psd >= 0.5) & (freqs_psd <= 80)
|
||||||
|
|
||||||
|
freq_fft = freqs_fft[mask_fft]
|
||||||
|
freq_psd = freqs_psd[mask_psd]
|
||||||
|
|
||||||
|
# Plot: 4 rows x 1 column
|
||||||
|
fig, axes = plt.subplots(4, 1, figsize=(16, 20))
|
||||||
|
fig.suptitle(f'EEG Data Quality Check — {", ".join(ch_names)}, '
|
||||||
|
f'Full Duration: {duration:.1f}s',
|
||||||
|
fontsize=16, fontweight='bold', y=0.995)
|
||||||
|
|
||||||
|
# Colormap for distinct channel
|
||||||
|
cmap = plt.cm.tab10 if n_channels <= 10 else plt.cm.tab20
|
||||||
|
colors = [cmap(i) for i in np.linspace(0, 1, n_channels)]
|
||||||
|
|
||||||
|
# ---- Row 1: Time Domain Signal ----
|
||||||
|
ax = axes[0]
|
||||||
|
offset = 0
|
||||||
|
step = max(100, np.std(data, axis=1).mean() * 1e6 * 4)
|
||||||
|
|
||||||
|
# Downsample for display
|
||||||
|
ds = max(1, n_samples // (int(duration) * 500))
|
||||||
|
t = np.arange(0, n_samples, ds) / sfreq
|
||||||
|
|
||||||
|
for i in range(n_channels):
|
||||||
|
sig = data[i, ::ds] * 1e6 + offset
|
||||||
|
ax.plot(t, sig, linewidth=0.5, alpha=0.9, color=colors[i], label=ch_names[i])
|
||||||
|
ax.text(t[0] - 0.5, offset, ch_names[i], fontsize=7, va='center', ha='right', color=colors[i])
|
||||||
|
offset += step
|
||||||
|
|
||||||
|
ax.set_xlim(0, duration)
|
||||||
|
ax.set_xlabel('Time (s)')
|
||||||
|
ax.set_ylabel('Amplitude (μV)')
|
||||||
|
ax.set_title('1. Time Domain Signal (Full Duration)', fontweight='bold')
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.legend(loc='upper right', fontsize=7, ncol=max(1, n_channels // 3), framealpha=0.8)
|
||||||
|
|
||||||
|
# ---- Row 2: Amplitude Spectrum (FFT) ----
|
||||||
|
ax = axes[1]
|
||||||
|
amp_data = amplitude[:, mask_fft] * 1e6 # (n_channels, n_freqs)
|
||||||
|
for i in range(n_channels):
|
||||||
|
ax.plot(freq_fft, amp_data[i], color=colors[i], linewidth=1.0, alpha=0.85, label=ch_names[i])
|
||||||
|
ax.axvline(50, color='red', linestyle='--', alpha=0.6, label='50 Hz Mains')
|
||||||
|
ax.set_xlim(0.5, 30)
|
||||||
|
ax.set_xlabel('Frequency (Hz)')
|
||||||
|
ax.set_ylabel('Amplitude (μV)')
|
||||||
|
ax.set_title('2. Amplitude Spectrum (FFT)', fontweight='bold')
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.legend(loc='upper right', fontsize=7, ncol=max(1, n_channels // 3), framealpha=0.8)
|
||||||
|
|
||||||
|
# ---- Row 3: PSD (Linear Scale) ----
|
||||||
|
ax = axes[2]
|
||||||
|
psd_data = psd[:, mask_psd] * 1e12 # (n_channels, n_freqs)
|
||||||
|
for i in range(n_channels):
|
||||||
|
ax.plot(freq_psd, psd_data[i], color=colors[i], linewidth=1.0, alpha=0.85, label=ch_names[i])
|
||||||
|
ax.axvline(50, color='red', linestyle='--', alpha=0.6, label='50 Hz Mains')
|
||||||
|
ax.set_xlim(0.5, 80)
|
||||||
|
ax.set_xlabel('Frequency (Hz)')
|
||||||
|
ax.set_ylabel('Power (μV²/Hz)')
|
||||||
|
ax.set_title('3. Power Spectral Density (Linear Scale)', fontweight='bold')
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.legend(loc='upper right', fontsize=7, ncol=max(1, n_channels // 3), framealpha=0.8)
|
||||||
|
|
||||||
|
# ---- Row 4: PSD (dB Scale) ----
|
||||||
|
ax = axes[3]
|
||||||
|
for i in range(n_channels):
|
||||||
|
psd_dbi = 10 * np.log10(psd_data[i] + 1e-20)
|
||||||
|
ax.plot(freq_psd, psd_dbi, color=colors[i], linewidth=1.0, alpha=0.85, label=ch_names[i])
|
||||||
|
ax.axvline(50, color='red', linestyle='--', alpha=0.6, label='50 Hz Mains')
|
||||||
|
ax.set_xlim(0.5, 80)
|
||||||
|
ax.set_xlabel('Frequency (Hz)')
|
||||||
|
ax.set_ylabel('Power (dB)')
|
||||||
|
ax.set_title('4. Power Spectral Density (dB Scale)', fontweight='bold')
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.legend(loc='upper right', fontsize=7, ncol=max(1, n_channels // 3), framealpha=0.8)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.subplots_adjust(top=0.97)
|
||||||
|
|
||||||
|
plt.savefig(output_path, dpi=150, bbox_inches='tight',
|
||||||
|
facecolor='white', edgecolor='none')
|
||||||
|
print(f"Figure saved to: {output_path}")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
from scipy.io import loadmat
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
mat = loadmat(r'D:\Ivey\Code_New_Proj\brainplot\plot64\eeg_data0511.mat', simplify_cells=True)
|
||||||
|
data = mat['eeg']['data'] # (samples, channels)
|
||||||
|
sfreq = 250
|
||||||
|
seg1 = data[0:int(10*sfreq), :] # 0-10s
|
||||||
|
seg2 = data[int(10*sfreq):int(20*sfreq), :] # 10-20s
|
||||||
|
|
||||||
|
print('Segment 1 (0-10s) shape:', seg1.shape)
|
||||||
|
print('Segment 2 (10-20s) shape:', seg2.shape)
|
||||||
|
print('Are they equal?', np.allclose(seg1, seg2))
|
||||||
|
print('Max difference:', np.max(np.abs(seg1 - seg2)))
|
||||||
|
print('Mean difference:', np.mean(np.abs(seg1 - seg2)))
|
||||||
|
|
||||||
|
# Check correlation
|
||||||
|
corr = np.corrcoef(seg1.flatten(), seg2.flatten())[0, 1]
|
||||||
|
print(f'Correlation: {corr:.4f}')
|
||||||
BIN
algorithm_V0/datacollect/eeg_quality_check_depression.png
Normal file
|
After Width: | Height: | Size: 521 KiB |
193
algorithm_V0/datacollect/protocol.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
from typing import List, Tuple, Union, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolFrame:
|
||||||
|
# 协议常量
|
||||||
|
FRAME_HEADER = 0xAA
|
||||||
|
FRAME_TAIL1 = 0x55
|
||||||
|
FRAME_TAIL2 = 0x55
|
||||||
|
RESERVED_SIZE = 6
|
||||||
|
MIN_FRAME_SIZE = 13 # 帧头1 + 功能1 + 长度2 + 预留6 + CRC1 + 包尾2
|
||||||
|
MAX_DATA_LENGTH = 0xFFFF # 最大数据长度 (2字节能表示的最大值)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calculate_crc8(data: bytes) -> bytes:
|
||||||
|
"""
|
||||||
|
计算CRC8校验值
|
||||||
|
Args:
|
||||||
|
data: 需要计算CRC的数据
|
||||||
|
Returns:
|
||||||
|
一个字节的CRC值(bytes类型)
|
||||||
|
"""
|
||||||
|
crc = 0
|
||||||
|
for byte in data:
|
||||||
|
crc ^= byte
|
||||||
|
for _ in range(8):
|
||||||
|
crc = ((crc << 1) ^ 0x07 if crc & 0x80 else crc << 1) & 0xFF
|
||||||
|
return bytes([crc])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pack(cls, function, data: Union[bytes, bytearray, List[int]],
|
||||||
|
reserved: Optional[Union[bytes, bytearray, List[int]]] = None) -> bytes:
|
||||||
|
"""
|
||||||
|
协议打包函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function: 功能码 (1字节)
|
||||||
|
data: 数据块
|
||||||
|
reserved: 预留字节(6字节,可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
打包后的字节数据
|
||||||
|
"""
|
||||||
|
# 检查功能码
|
||||||
|
if function != None:
|
||||||
|
if not 0 <= function <= 0xFF:
|
||||||
|
raise ValueError("功能码必须是1字节")
|
||||||
|
|
||||||
|
# 转换数据为bytearray
|
||||||
|
if isinstance(data, list):
|
||||||
|
data = bytearray(data)
|
||||||
|
elif isinstance(data, bytes):
|
||||||
|
data = bytearray(data)
|
||||||
|
|
||||||
|
# 检查数据长度
|
||||||
|
data_length = len(data)
|
||||||
|
if data_length > cls.MAX_DATA_LENGTH:
|
||||||
|
raise ValueError(f"数据长度超过最大值 {cls.MAX_DATA_LENGTH}")
|
||||||
|
|
||||||
|
# 处理预留字节
|
||||||
|
if reserved is None:
|
||||||
|
reserved = bytearray([0] * cls.RESERVED_SIZE)
|
||||||
|
else:
|
||||||
|
if isinstance(reserved, list):
|
||||||
|
reserved = bytearray(reserved)
|
||||||
|
elif isinstance(reserved, bytes):
|
||||||
|
reserved = bytearray(reserved)
|
||||||
|
if len(reserved) != cls.RESERVED_SIZE:
|
||||||
|
raise ValueError(f"预留字节必须是{cls.RESERVED_SIZE}字节")
|
||||||
|
|
||||||
|
# 构建帧
|
||||||
|
frame = bytearray([cls.FRAME_HEADER]) # 帧头 (1字节)
|
||||||
|
if function != None:
|
||||||
|
frame.append(function) # 功能码 (1字节)
|
||||||
|
data_length+=6
|
||||||
|
|
||||||
|
# 数据长度 (2字节,大端序)
|
||||||
|
frame.append((data_length >> 8) & 0xFF) # 高字节
|
||||||
|
frame.append(data_length & 0xFF) # 低字节
|
||||||
|
|
||||||
|
if function != None:
|
||||||
|
frame.extend(reserved) # 预留字节 (6字节)
|
||||||
|
frame.extend(data) # 数据块 (变长)
|
||||||
|
|
||||||
|
# 计算CRC (从功能码开始到数据块结束)
|
||||||
|
crc = cls.calculate_crc8(frame[1:]) # 不包含帧头
|
||||||
|
frame.extend(crc) # CRC校验 (1字节)
|
||||||
|
|
||||||
|
# 添加帧尾
|
||||||
|
frame.extend([cls.FRAME_TAIL1, cls.FRAME_TAIL2]) # 帧尾 (2字节)
|
||||||
|
|
||||||
|
return bytes(frame)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unpack(cls, data: Union[bytes, bytearray]) -> Tuple[int, bytearray, bytearray]:
|
||||||
|
"""
|
||||||
|
协议解包函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 待解析的字节数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(功能码, 数据块, 预留字节)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当数据格式不正确时
|
||||||
|
"""
|
||||||
|
# 检查数据长度
|
||||||
|
if len(data) < cls.MIN_FRAME_SIZE:
|
||||||
|
raise ValueError("数据长度不足")
|
||||||
|
|
||||||
|
# 检查帧头
|
||||||
|
if data[0] != cls.FRAME_HEADER:
|
||||||
|
raise ValueError("帧头错误")
|
||||||
|
|
||||||
|
# 检查帧尾
|
||||||
|
if data[-2:] != bytes([cls.FRAME_TAIL1, cls.FRAME_TAIL2]):
|
||||||
|
raise ValueError("帧尾错误")
|
||||||
|
|
||||||
|
# 解析基本信息
|
||||||
|
function = data[1] # 功能码 (1字节)
|
||||||
|
|
||||||
|
# 数据长度 (2字节,大端序)
|
||||||
|
data_length = (data[2] << 8) | data[3]
|
||||||
|
|
||||||
|
reserved = data[4:10] # 预留字节 (6字节)
|
||||||
|
|
||||||
|
# 检查数据长度
|
||||||
|
expected_length = cls.MIN_FRAME_SIZE + data_length
|
||||||
|
if len(data) != expected_length:
|
||||||
|
raise ValueError(f"数据长度不匹配: 期望{expected_length}字节,实际{len(data)}字节")
|
||||||
|
|
||||||
|
# 提取数据块
|
||||||
|
payload = data[10:10 + data_length]
|
||||||
|
|
||||||
|
# 验证CRC (从功能码开始到数据块结束)
|
||||||
|
received_crc = data[-3]
|
||||||
|
calculated_crc = cls.calculate_crc8(data[1:-3])[0] # 获取字节值
|
||||||
|
|
||||||
|
if received_crc != calculated_crc:
|
||||||
|
raise ValueError(f"CRC校验失败: 期望{calculated_crc:02X},实际{received_crc:02X}")
|
||||||
|
|
||||||
|
return function, bytearray(payload), bytearray(reserved)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def print_hex(data: bytes, label: str = ""):
|
||||||
|
"""打印十六进制数据,并按字节添加空格"""
|
||||||
|
hex_str = ' '.join([f"{b:02X}" for b in data])
|
||||||
|
if label:
|
||||||
|
print(f"{label}: {hex_str}")
|
||||||
|
else:
|
||||||
|
print(hex_str)
|
||||||
|
|
||||||
|
|
||||||
|
def print_frame_details(data: bytes):
|
||||||
|
"""打印帧的详细信息"""
|
||||||
|
print("帧详细信息:")
|
||||||
|
print(f"帧头: {data[0]:02X}")
|
||||||
|
print(f"功能码: {data[1]:02X}")
|
||||||
|
print(f"数据长度: {data[2]:02X} {data[3]:02X} ({(data[2] << 8) | data[3]}字节)")
|
||||||
|
print(f"预留字节: {' '.join([f'{b:02X}' for b in data[4:10]])}")
|
||||||
|
data_length = (data[2] << 8) | data[3]
|
||||||
|
print(f"数据块: {' '.join([f'{b:02X}' for b in data[10:10 + data_length]])}")
|
||||||
|
print(f"CRC校验: {data[-3]:02X}")
|
||||||
|
print(f"帧尾: {data[-2]:02X} {data[-1]:02X}")
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
def example_usage():
|
||||||
|
try:
|
||||||
|
|
||||||
|
|
||||||
|
# 示例1:简单数据打包
|
||||||
|
function_code = 0x01
|
||||||
|
data = [0x1]
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, data)
|
||||||
|
print_hex(packed_data, "示例1 - 完整帧")
|
||||||
|
print_frame_details(packed_data)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 示例3:解包验证
|
||||||
|
function, payload, reserved = ProtocolFrame.unpack(packed_data)
|
||||||
|
print("解包结果:")
|
||||||
|
print(f"功能码: 0x{function:02X}")
|
||||||
|
print_hex(payload, "数据块")
|
||||||
|
print_hex(reserved, "预留字节")
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
print(f"错误: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
example_usage()
|
||||||
17
algorithm_V0/datacollect/start_parse.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
|
||||||
|
import time
|
||||||
|
from eegParser import Parser_main
|
||||||
|
from RunOnce import is_program_running
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
if not is_program_running():
|
||||||
|
parser_ = Parser_main()
|
||||||
|
parser_.connect()
|
||||||
|
|
||||||
|
try:
|
||||||
|
parser_.start()
|
||||||
|
while not parser_.zmqServer.IsExitApp:
|
||||||
|
time.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
parser_.stop()
|
||||||
137
algorithm_V0/datacollect/verify_build.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
PyInstaller 打包验证脚本
|
||||||
|
用于在没有 EEG 设备的情况下验证打包是否成功
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import subprocess
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
def check_pyinstaller_installed():
|
||||||
|
"""检查 PyInstaller 是否安装"""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(['pyinstaller', '--version'],
|
||||||
|
capture_output=True, text=True)
|
||||||
|
print(f"✓ PyInstaller 版本: {result.stdout.strip()}")
|
||||||
|
return True
|
||||||
|
except FileNotFoundError:
|
||||||
|
print("✗ PyInstaller 未安装")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def check_dist_folder():
|
||||||
|
"""检查 dist 文件夹是否存在"""
|
||||||
|
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
dist_dir = os.path.join(base_dir, 'dist', 'start_parse')
|
||||||
|
|
||||||
|
if os.path.exists(dist_dir):
|
||||||
|
print(f"✓ dist 文件夹存在: {dist_dir}")
|
||||||
|
|
||||||
|
# 检查 exe 文件
|
||||||
|
exe_path = os.path.join(dist_dir, 'start_parse.exe')
|
||||||
|
if os.path.exists(exe_path):
|
||||||
|
size_mb = os.path.getsize(exe_path) / (1024 * 1024)
|
||||||
|
print(f"✓ 可执行文件存在: start_parse.exe ({size_mb:.1f} MB)")
|
||||||
|
else:
|
||||||
|
print("✗ 可执行文件不存在")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查资源文件
|
||||||
|
xlsx_path = os.path.join(dist_dir, 'xy_64.xlsx')
|
||||||
|
if os.path.exists(xlsx_path):
|
||||||
|
print(f"✓ 资源文件存在: xy_64.xlsx")
|
||||||
|
else:
|
||||||
|
print("✗ 资源文件 xy_64.xlsx 不存在")
|
||||||
|
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print(f"✗ dist 文件夹不存在,请先运行打包")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def check_dependencies():
|
||||||
|
"""检查关键依赖是否在打包中"""
|
||||||
|
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
dist_dir = os.path.join(base_dir, 'dist', 'start_parse')
|
||||||
|
|
||||||
|
if not os.path.exists(dist_dir):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查关键 DLL 文件
|
||||||
|
critical_dlls = [
|
||||||
|
# zmq 依赖
|
||||||
|
'libzmq.pyd',
|
||||||
|
# numpy 依赖
|
||||||
|
'numpy.core._multiarray_umath.cp310-win_amd64.pyd',
|
||||||
|
# scipy 依赖
|
||||||
|
'scipy.special._ufuncs.cp310-win_amd64.pyd',
|
||||||
|
]
|
||||||
|
|
||||||
|
print("\n检查关键依赖文件:")
|
||||||
|
found_count = 0
|
||||||
|
for dll in critical_dlls:
|
||||||
|
found = False
|
||||||
|
for root, dirs, files in os.walk(dist_dir):
|
||||||
|
if dll in files:
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
status = "✓" if found else "✗"
|
||||||
|
print(f" {status} {dll}")
|
||||||
|
if found:
|
||||||
|
found_count += 1
|
||||||
|
|
||||||
|
return found_count >= len(critical_dlls) // 2
|
||||||
|
|
||||||
|
def test_imports():
|
||||||
|
"""测试关键模块是否可以导入"""
|
||||||
|
print("\n测试模块导入:")
|
||||||
|
|
||||||
|
modules = ['zmq', 'serial', 'numpy', 'pandas', 'scipy']
|
||||||
|
success = True
|
||||||
|
|
||||||
|
for mod in modules:
|
||||||
|
try:
|
||||||
|
__import__(mod)
|
||||||
|
print(f" ✓ {mod}")
|
||||||
|
except ImportError as e:
|
||||||
|
print(f" ✗ {mod}: {e}")
|
||||||
|
success = False
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("=" * 60)
|
||||||
|
print("PyInstaller 打包验证")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
checks = [
|
||||||
|
("1. 检查 PyInstaller 安装", check_pyinstaller_installed),
|
||||||
|
("2. 检查 dist 文件夹", check_dist_folder),
|
||||||
|
("3. 检查依赖文件", check_dependencies),
|
||||||
|
("4. 测试模块导入", test_imports),
|
||||||
|
]
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for name, check_func in checks:
|
||||||
|
print(f"\n{name}")
|
||||||
|
print("-" * 40)
|
||||||
|
results.append(check_func())
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("验证结果汇总:")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
all_passed = all(results)
|
||||||
|
if all_passed:
|
||||||
|
print("✓ 所有检查通过!打包成功。")
|
||||||
|
print("\n下一步:")
|
||||||
|
print(" 1. 将 dist/start_parse 文件夹复制到目标电脑")
|
||||||
|
print(" 2. 连接 EEG 设备并运行 start_parse.exe")
|
||||||
|
print(" 3. 观察控制台输出是否正常")
|
||||||
|
else:
|
||||||
|
print("✗ 部分检查未通过,请查看上述详细信息")
|
||||||
|
|
||||||
|
return all_passed
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
57
algorithm_V0/datacollect/zmqClient.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
|
||||||
|
class zmqClient:
|
||||||
|
def __init__(self, host, port):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.client_socket = None
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
# 记录客户端连接前的状态
|
||||||
|
self.state = {
|
||||||
|
'status_code': None,
|
||||||
|
'energy': None
|
||||||
|
}
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
# 创建 ZeroMQ 上下文
|
||||||
|
self.context = zmq.Context()
|
||||||
|
# 创建 REQ 套接字(请求端)
|
||||||
|
self.client_socket = self.context.socket(zmq.DEALER)
|
||||||
|
# client_id = b'client1'
|
||||||
|
# self.client_socket.setsockopt(zmq.IDENTITY,client_id)
|
||||||
|
self.client_socket.connect(f"tcp://{self.host}:{self.port}") # 连接到服务器
|
||||||
|
self.running = True
|
||||||
|
|
||||||
|
def send_to_all(self, method,params):
|
||||||
|
if method in self.state.keys():
|
||||||
|
self.state[method] = params
|
||||||
|
try:
|
||||||
|
if self.running and self.client_socket != None:
|
||||||
|
msg = {'method': method, 'params': params}
|
||||||
|
# 发送响应
|
||||||
|
# print(msg)
|
||||||
|
self.client_socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
|
||||||
|
else:
|
||||||
|
if method in self.state.keys():
|
||||||
|
self.state[method] = params
|
||||||
|
except ConnectionResetError:
|
||||||
|
print("Connection lost.")
|
||||||
|
self.running = False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred: {e}")
|
||||||
|
|
||||||
|
def close_connection(self):
|
||||||
|
self.running = False
|
||||||
|
self.client_socket.close()
|
||||||
|
self.context.term()
|
||||||
|
print("Client closed explicitly.")
|
||||||
|
# 使用TCP客户端
|
||||||
|
if __name__ == "__main__":
|
||||||
|
client = zmqClient('127.0.0.1', 8099)
|
||||||
|
client.connect()
|
||||||
|
# client.close_connection()
|
||||||
119
algorithm_V0/datacollect/zmqServer.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
import threading
|
||||||
|
import json
|
||||||
|
from SunnyLinker import SunnyLinker64
|
||||||
|
|
||||||
|
class zmqServer(threading.Thread):
|
||||||
|
def __init__(self, host='0.0.0.0', port=8099):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.running = False
|
||||||
|
self.get_Impedance = False # 是否返回阻抗值
|
||||||
|
self.open_Impedance = None # 是否开启阻抗检测功能
|
||||||
|
self.StartDecode = False # false 停止解码,true=开始解码
|
||||||
|
self.StartTrain = False # False未进入训练状态,True处于训练状态
|
||||||
|
self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态
|
||||||
|
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
|
||||||
|
self.IsExitApp = False # 当socket收到2的时候,就置为True,代表遥退出系统了。
|
||||||
|
self.getReport = False # 获取训练报告内容
|
||||||
|
self.mat_generate = False # 保存mat文件,True开始,False暂停
|
||||||
|
self.reset_mat_buffer = False # 重置缓冲区标志,True表示下次开始采集需要清空旧数据
|
||||||
|
self.subject_id = None #受试者ID
|
||||||
|
self.session_id = None #Session ID
|
||||||
|
self.save_win = 0 #保存数据时长
|
||||||
|
|
||||||
|
self.daemon = True
|
||||||
|
# 创建 ZeroMQ 上下文
|
||||||
|
self.context = zmq.Context()
|
||||||
|
# 创建 REP 套接字(响应端)
|
||||||
|
self.socket = self.context.socket(zmq.ROUTER)
|
||||||
|
self.socket.bind(f"tcp://{self.host}:{self.port}") # 绑定到端口 8099
|
||||||
|
self.targetFreqs = []
|
||||||
|
self.changeTarget = False # 更换目标频率
|
||||||
|
self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化
|
||||||
|
self.labels = [0x01, 0x02,0x03]
|
||||||
|
|
||||||
|
self.decoder_switch = False #更换解码器
|
||||||
|
self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi'
|
||||||
|
def run(self):
|
||||||
|
self.running = True
|
||||||
|
print(f"Server is running on {self.host}:{self.port}")
|
||||||
|
try:
|
||||||
|
while self.running:
|
||||||
|
# 等待客户端请求
|
||||||
|
_,_,message = self.socket.recv_multipart()
|
||||||
|
message = json.loads(message.decode('utf-8'))
|
||||||
|
print(f"Received request: {message}")
|
||||||
|
# 处理请求
|
||||||
|
method = message.get("method")
|
||||||
|
params = message.get("params")
|
||||||
|
if method == "sync":
|
||||||
|
self.state_mode = 'sync'
|
||||||
|
if method == "targetFreqs":
|
||||||
|
if not isinstance(params,list):
|
||||||
|
print('targetFreqs must be a list')
|
||||||
|
continue
|
||||||
|
if params != self.targetFreqs:
|
||||||
|
self.targetFreqs = params
|
||||||
|
self.changeTarget = True
|
||||||
|
if method == "decoderClass":
|
||||||
|
if not isinstance(params,str):
|
||||||
|
print('decoderClass must be a str')
|
||||||
|
continue
|
||||||
|
# if params != self.decoder_class:
|
||||||
|
self.decoder_class = params
|
||||||
|
self.decoder_switch = True
|
||||||
|
if method == "getReport":
|
||||||
|
self.getReport = True
|
||||||
|
if method == "train":#训练状态
|
||||||
|
self.state_mode = 'train'
|
||||||
|
self.StartTrain = True
|
||||||
|
self.currentLabel = params # 当前刺激端的训练标签
|
||||||
|
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
|
||||||
|
elif method == "predict":#预测状态
|
||||||
|
self.state_mode = 'predict'
|
||||||
|
if params == 1: #开始解码
|
||||||
|
self.StartDecode = True
|
||||||
|
self.sunnyLinker.push_trigger(0x63)
|
||||||
|
elif params == 2: #停止解码
|
||||||
|
self.IsExitApp = True
|
||||||
|
self.running = False
|
||||||
|
elif method == "rest": #休息状态
|
||||||
|
self.state_mode = 'rest'
|
||||||
|
elif method == "impedance":
|
||||||
|
if params == 1:
|
||||||
|
self.open_Impedance = True # 开启阻抗
|
||||||
|
self.get_Impedance = True # 返回阻抗
|
||||||
|
elif params == 2:
|
||||||
|
self.open_Impedance = False # 关闭阻抗
|
||||||
|
self.get_Impedance = False # 停止返回阻抗
|
||||||
|
elif method == "matGenerate":
|
||||||
|
self.subject_id = str(params['subject_id'])
|
||||||
|
self.session_id = str(params['session_id'])
|
||||||
|
self.save_win = int(params['time'])
|
||||||
|
self.mat_generate = True
|
||||||
|
self.reset_mat_buffer = True # 每次发送 matGenerate 都重置缓冲区
|
||||||
|
elif method == "stop": # 停止
|
||||||
|
self.mat_generate = False
|
||||||
|
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An socket error occurred: {e}")
|
||||||
|
finally:
|
||||||
|
self.running = False
|
||||||
|
# 关闭套接字和上下文
|
||||||
|
self.socket.close()
|
||||||
|
self.context.term()
|
||||||
|
print("Server socket and context closed.")
|
||||||
|
def stop(self):
|
||||||
|
"""显式关闭服务器"""
|
||||||
|
self.running = False
|
||||||
|
self.socket.close()
|
||||||
|
self.context.term()
|
||||||
|
print("Server closed explicitly.")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
server = zmqServer()
|
||||||
|
server.start()
|
||||||
56
algorithm_V1/.gitignore
vendored
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# data format
|
||||||
|
*.dat
|
||||||
|
*.csv
|
||||||
|
*.edf
|
||||||
|
*.event
|
||||||
|
*.edf.event
|
||||||
|
*.zip
|
||||||
|
*.xlsx
|
||||||
|
*.mat
|
||||||
|
*.json
|
||||||
|
*.7z
|
||||||
|
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself
|
||||||
|
# https://github.com/github/gitignore/blob/main/Python.gitignore
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# Other common ignores
|
||||||
|
node_modules/
|
||||||
|
dist/
|
||||||
|
tmp/
|
||||||
|
temp/
|
||||||
|
|
||||||
|
# Project-specific ignores
|
||||||
|
# Ignore all directories in the root
|
||||||
|
# merge64ch_0127/
|
||||||
|
/P300_speller/braindecode/
|
||||||
|
/P300_speller/data/
|
||||||
|
/P300_speller/pyRiemann/
|
||||||
|
/P300_speller/README/
|
||||||
|
/merge64ch_new/
|
||||||
|
/merge64ch_tianjinZMQdebug/
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
250
algorithm_V1/bdf_analyzer.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
bdf_analyzer.py
|
||||||
|
|
||||||
|
Analyze .bdf files - print data amplitude range and mean values.
|
||||||
|
Supports single file or batch processing of all .bdf files in a directory.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
import numpy as np
|
||||||
|
import mne
|
||||||
|
import scipy.signal as signal
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_bdf(filepath: str, unit: str = "uV") -> dict:
|
||||||
|
"""
|
||||||
|
Analyze a single .bdf file and compute statistics.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filepath : str
|
||||||
|
Path to .bdf file
|
||||||
|
unit : str, optional
|
||||||
|
Display unit (default: uV for microvolts)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
Dictionary containing statistics
|
||||||
|
"""
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"File: {os.path.basename(filepath)}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Read BDF file
|
||||||
|
raw = mne.io.read_raw_bdf(filepath, preload=True, verbose=False)
|
||||||
|
|
||||||
|
# Get data (n_channels, n_times) in V
|
||||||
|
data = raw.get_data()
|
||||||
|
n_channels, n_times = data.shape
|
||||||
|
sfreq = raw.info["sfreq"]
|
||||||
|
|
||||||
|
# Convert to microvolts (uV)
|
||||||
|
data_uv = data * 1e6
|
||||||
|
|
||||||
|
# Raw data statistics (V)
|
||||||
|
raw_all = data.flatten()
|
||||||
|
raw_min = float(np.min(raw_all))
|
||||||
|
raw_max = float(np.max(raw_all))
|
||||||
|
raw_mean = float(np.mean(raw_all))
|
||||||
|
raw_std = float(np.std(raw_all))
|
||||||
|
|
||||||
|
# Overall statistics
|
||||||
|
all_values = data_uv.flatten()
|
||||||
|
min_val = np.min(all_values)
|
||||||
|
max_val = np.max(all_values)
|
||||||
|
mean_val = np.mean(all_values)
|
||||||
|
std_val = np.std(all_values)
|
||||||
|
|
||||||
|
print(f"Sampling rate: {sfreq:.2f} Hz")
|
||||||
|
print(f"Channels: {n_channels}")
|
||||||
|
print(f"Samples: {n_times:,}")
|
||||||
|
print(f"Duration: {n_times / sfreq:.2f} sec")
|
||||||
|
print("-" * 40)
|
||||||
|
print(f"[RAW - V]")
|
||||||
|
print(f"Amplitude range: [{raw_min:.6f}, {raw_max:.6f}] V")
|
||||||
|
print(f"Mean value: {raw_mean:.6f} V")
|
||||||
|
print(f"Std deviation: {raw_std:.6f} V")
|
||||||
|
print(f"[RAW - uV]")
|
||||||
|
print(f"Amplitude range: [{min_val:.4f}, {max_val:.4f}] uV")
|
||||||
|
print(f"Mean value: {mean_val:.4f} uV")
|
||||||
|
print(f"Std deviation: {std_val:.4f} uV")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
# Per-channel statistics
|
||||||
|
print("\nPer-channel statistics:")
|
||||||
|
print(f"{'Channel':<15} {'Min (uV)':<15} {'Max (uV)':<15} {'Mean (uV)':<15} {'PSD Peak (Hz)':<15}")
|
||||||
|
print("-" * 75)
|
||||||
|
|
||||||
|
channel_stats = []
|
||||||
|
for i, ch_name in enumerate(raw.ch_names):
|
||||||
|
ch_data = data_uv[i, :]
|
||||||
|
ch_min = np.min(ch_data)
|
||||||
|
ch_max = np.max(ch_data)
|
||||||
|
ch_mean = np.mean(ch_data)
|
||||||
|
|
||||||
|
# PSD peak frequency
|
||||||
|
nperseg = min(1024, n_times)
|
||||||
|
freqs, pxx = signal.welch(ch_data, fs=sfreq, nperseg=nperseg)
|
||||||
|
peak_idx = np.argmax(pxx)
|
||||||
|
peak_freq = freqs[peak_idx]
|
||||||
|
|
||||||
|
print(f"{ch_name:<15} {ch_min:<15.4f} {ch_max:<15.4f} {ch_mean:<15.4f} {peak_freq:<15.2f}")
|
||||||
|
channel_stats.append({
|
||||||
|
"name": ch_name,
|
||||||
|
"min": ch_min,
|
||||||
|
"max": ch_max,
|
||||||
|
"mean": ch_mean,
|
||||||
|
"psd_peak_hz": peak_freq
|
||||||
|
})
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"filepath": filepath,
|
||||||
|
"sfreq": sfreq,
|
||||||
|
"n_channels": n_channels,
|
||||||
|
"n_times": n_times,
|
||||||
|
"duration": n_times / sfreq,
|
||||||
|
"raw_min": raw_min,
|
||||||
|
"raw_max": raw_max,
|
||||||
|
"raw_mean": raw_mean,
|
||||||
|
"raw_std": raw_std,
|
||||||
|
"min": min_val,
|
||||||
|
"max": max_val,
|
||||||
|
"mean": mean_val,
|
||||||
|
"std": std_val,
|
||||||
|
"channels": channel_stats
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Failed to read file: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_directory(dir_path: str) -> list:
|
||||||
|
"""
|
||||||
|
Analyze all .bdf files in a directory.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
dir_path : str
|
||||||
|
Directory path
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list
|
||||||
|
List of analysis results for all files
|
||||||
|
"""
|
||||||
|
# Find all .bdf files
|
||||||
|
bdf_files = sorted(glob.glob(os.path.join(dir_path, "*.bdf")))
|
||||||
|
|
||||||
|
if not bdf_files:
|
||||||
|
print(f"[WARNING] No .bdf files found in: {dir_path}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"Found {len(bdf_files)} .bdf file(s)\n")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for filepath in bdf_files:
|
||||||
|
result = analyze_bdf(filepath)
|
||||||
|
if result:
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# Summary statistics
|
||||||
|
if results:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Summary")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
all_means = [r["mean"] for r in results]
|
||||||
|
all_mins = [r["min"] for r in results]
|
||||||
|
all_maxs = [r["max"] for r in results]
|
||||||
|
|
||||||
|
print(f"File count: {len(results)}")
|
||||||
|
print(f"[RAW - V] Overall range: [{min(r['raw_min'] for r in results):.6f}, {max(r['raw_max'] for r in results):.6f}] V")
|
||||||
|
print(f"[RAW - V] Avg mean: {np.mean([r['raw_mean'] for r in results]):.6f} V")
|
||||||
|
print(f"[RAW - uV] Overall range: [{min(all_mins):.4f}, {max(all_maxs):.4f}] uV")
|
||||||
|
print(f"[RAW - uV] Avg mean: {np.mean(all_means):.4f} uV")
|
||||||
|
print(f"Max value file: {results[np.argmax(all_maxs)]['filepath']}")
|
||||||
|
print(f"Min value file: {results[np.argmin(all_mins)]['filepath']}")
|
||||||
|
|
||||||
|
# Per-channel mean summary across all files
|
||||||
|
n_channels = len(results[0]["channels"])
|
||||||
|
ch_names = [results[0]["channels"][i]["name"] for i in range(n_channels)]
|
||||||
|
ch_mean_over_files = []
|
||||||
|
for ch_idx in range(n_channels):
|
||||||
|
ch_means = [results[f_idx]["channels"][ch_idx]["mean"] for f_idx in range(len(results))]
|
||||||
|
ch_mean_over_files.append(np.mean(ch_means))
|
||||||
|
|
||||||
|
ch_peak_over_files = []
|
||||||
|
for ch_idx in range(n_channels):
|
||||||
|
ch_peaks = [results[f_idx]["channels"][ch_idx]["psd_peak_hz"] for f_idx in range(len(results))]
|
||||||
|
ch_peak_over_files.append(np.mean(ch_peaks))
|
||||||
|
|
||||||
|
print("\nPer-channel mean across all files:")
|
||||||
|
print(f"{'Channel':<15} {'Mean (uV)':<15} {'PSD Peak (Hz)':<15}")
|
||||||
|
print("-" * 45)
|
||||||
|
for ch_name, ch_mean, ch_peak in zip(ch_names, ch_mean_over_files, ch_peak_over_files):
|
||||||
|
print(f"{ch_name:<15} {ch_mean:<15.4f} {ch_peak:<15.2f}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function with CLI support."""
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
# Default analysis directory
|
||||||
|
default_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "raw_data")
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Analyze .bdf files - print amplitude range and mean values",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
epilog=f"""
|
||||||
|
Examples:
|
||||||
|
python bdf_analyzer.py # Analyze all .bdf in raw_data/
|
||||||
|
python bdf_analyzer.py data/test.bdf # Analyze single file
|
||||||
|
python bdf_analyzer.py data/ # Analyze all .bdf in directory
|
||||||
|
python bdf_analyzer.py . -u mV # Current dir, unit mV
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"path",
|
||||||
|
nargs="?",
|
||||||
|
default=default_dir,
|
||||||
|
help="Path to BDF file or directory containing BDF files (default: raw_data/)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-u", "--unit",
|
||||||
|
choices=["uV", "mV", "V"],
|
||||||
|
default="uV",
|
||||||
|
help="Display unit (default: uV)"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
filepath = args.path
|
||||||
|
|
||||||
|
# Determine file or directory mode
|
||||||
|
if os.path.isfile(filepath):
|
||||||
|
# Single file mode
|
||||||
|
result = analyze_bdf(filepath, unit=args.unit)
|
||||||
|
if result:
|
||||||
|
print("Analysis complete!")
|
||||||
|
elif os.path.isdir(filepath):
|
||||||
|
# Directory mode
|
||||||
|
results = analyze_directory(filepath)
|
||||||
|
if results:
|
||||||
|
print("\nBatch analysis complete!")
|
||||||
|
else:
|
||||||
|
print("No analyzable files found")
|
||||||
|
else:
|
||||||
|
print(f"[ERROR] File does not exist: {filepath}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
203
algorithm_V1/bdf_to_mat.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Convert BDF file to MAT format.
|
||||||
|
|
||||||
|
This script converts a BDF (Biosemi Data Format) EEG file to .mat format,
|
||||||
|
matching the structure of eeg_data.mat.
|
||||||
|
|
||||||
|
Structure of eeg_data.mat:
|
||||||
|
- data: (n_samples, n_channels) float64
|
||||||
|
- chn: (1, n_channels) object - channel names
|
||||||
|
- sample_rate: (1, 1) int64
|
||||||
|
- node_number: (1, 1) int64
|
||||||
|
- t: (n_samples, 1) float64 - time vector in seconds
|
||||||
|
- electrode_name: (1, n_channels) object - electrode names (10-20 system)
|
||||||
|
- electrode_xyz: (n_channels, 3) float64 - electrode 3D coordinates
|
||||||
|
- electrode_coord_system: (1,) <U21
|
||||||
|
- meta: (1, 1) structured - metadata
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import scipy.io
|
||||||
|
import mne
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
def get_standard_electrode_coords():
|
||||||
|
"""
|
||||||
|
Standard 10-20 system electrode coordinates.
|
||||||
|
Returns a dictionary mapping electrode names to x, y, z coordinates.
|
||||||
|
Coordinates are approximate spherical projections.
|
||||||
|
"""
|
||||||
|
coords = {
|
||||||
|
'FP1': (-0.0293, 0.0903, -0.0033),
|
||||||
|
'FP2': (0.0293, 0.0903, -0.0033),
|
||||||
|
'FPZ': (0.0, 0.0903, -0.0033),
|
||||||
|
'AF7': (-0.0658, 0.0734, -0.0224),
|
||||||
|
'AF3': (-0.0350, 0.0812, -0.0183),
|
||||||
|
'AF4': (0.0350, 0.0812, -0.0183),
|
||||||
|
'AF8': (0.0658, 0.0734, -0.0224),
|
||||||
|
'F7': (-0.0815, 0.0467, -0.0336),
|
||||||
|
'F5': (-0.0667, 0.0503, -0.0351),
|
||||||
|
'F3': (-0.0489, 0.0560, -0.0370),
|
||||||
|
'F1': (-0.0254, 0.0584, -0.0384),
|
||||||
|
'FZ': (0.0, 0.0584, -0.0384),
|
||||||
|
'F2': (0.0254, 0.0584, -0.0384),
|
||||||
|
'F4': (0.0489, 0.0560, -0.0370),
|
||||||
|
'F6': (0.0667, 0.0503, -0.0351),
|
||||||
|
'F8': (0.0815, 0.0467, -0.0336),
|
||||||
|
'FT7': (-0.0880, 0.0229, -0.0397),
|
||||||
|
'FC5': (-0.0699, 0.0317, -0.0402),
|
||||||
|
'FC3': (-0.0514, 0.0362, -0.0411),
|
||||||
|
'FC1': (-0.0268, 0.0383, -0.0419),
|
||||||
|
'FCZ': (0.0, 0.0383, -0.0419),
|
||||||
|
'FC2': (0.0268, 0.0383, -0.0419),
|
||||||
|
'FC4': (0.0514, 0.0362, -0.0411),
|
||||||
|
'FC6': (0.0699, 0.0317, -0.0402),
|
||||||
|
'FT8': (0.0880, 0.0229, -0.0397),
|
||||||
|
'T7': (-0.0958, 0.0, -0.0411),
|
||||||
|
'T8': (0.0958, 0.0, -0.0411),
|
||||||
|
'C5': (-0.0739, 0.0, -0.0425),
|
||||||
|
'C3': (-0.0544, 0.0, -0.0436),
|
||||||
|
'C1': (-0.0283, 0.0, -0.0444),
|
||||||
|
'CZ': (0.0, 0.0, -0.0444),
|
||||||
|
'C2': (0.0283, 0.0, -0.0444),
|
||||||
|
'C4': (0.0544, 0.0, -0.0436),
|
||||||
|
'C6': (0.0739, 0.0, -0.0425),
|
||||||
|
'TP7': (-0.0880, -0.0229, -0.0397),
|
||||||
|
'CP5': (-0.0699, -0.0317, -0.0402),
|
||||||
|
'CP3': (-0.0514, -0.0362, -0.0411),
|
||||||
|
'CP1': (-0.0268, -0.0383, -0.0419),
|
||||||
|
'CPZ': (0.0, -0.0383, -0.0419),
|
||||||
|
'CP2': (0.0268, -0.0383, -0.0419),
|
||||||
|
'CP4': (0.0514, -0.0362, -0.0411),
|
||||||
|
'CP6': (0.0699, -0.0317, -0.0402),
|
||||||
|
'TP8': (0.0880, -0.0229, -0.0397),
|
||||||
|
'P7': (-0.0815, -0.0467, -0.0336),
|
||||||
|
'P5': (-0.0667, -0.0503, -0.0351),
|
||||||
|
'P3': (-0.0489, -0.0560, -0.0370),
|
||||||
|
'P1': (-0.0254, -0.0584, -0.0384),
|
||||||
|
'PZ': (0.0, -0.0584, -0.0384),
|
||||||
|
'P2': (0.0254, -0.0584, -0.0384),
|
||||||
|
'P4': (0.0489, -0.0560, -0.0370),
|
||||||
|
'P6': (0.0667, -0.0503, -0.0351),
|
||||||
|
'P8': (0.0815, -0.0467, -0.0336),
|
||||||
|
'PO7': (-0.0658, -0.0734, -0.0224),
|
||||||
|
'PO5': (-0.0503, -0.0744, -0.0258),
|
||||||
|
'PO3': (-0.0350, -0.0812, -0.0183),
|
||||||
|
'POZ': (0.0, -0.0829, -0.0172),
|
||||||
|
'PO4': (0.0350, -0.0812, -0.0183),
|
||||||
|
'PO6': (0.0503, -0.0744, -0.0258),
|
||||||
|
'PO8': (0.0658, -0.0734, -0.0224),
|
||||||
|
'O1': (-0.0293, -0.0903, -0.0033),
|
||||||
|
'OZ': (0.0, -0.0903, -0.0033),
|
||||||
|
'O2': (0.0293, -0.0903, -0.0033),
|
||||||
|
'CB1': (-0.0618, -0.0380, -0.0387),
|
||||||
|
'CB2': (0.0618, -0.0380, -0.0387),
|
||||||
|
'A1': (-0.0958, 0.0, 0.0),
|
||||||
|
'A2': (0.0958, 0.0, 0.0),
|
||||||
|
}
|
||||||
|
return coords
|
||||||
|
|
||||||
|
|
||||||
|
def bdf_to_mat(bdf_path, output_path, subject_id='unknown', session_id='unknown'):
|
||||||
|
"""
|
||||||
|
Convert BDF file to MAT format matching eeg_data.mat structure.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
bdf_path : str
|
||||||
|
Path to the input BDF file.
|
||||||
|
output_path : str
|
||||||
|
Path to the output MAT file.
|
||||||
|
subject_id : str, optional
|
||||||
|
Subject identifier. Default is 'unknown'.
|
||||||
|
session_id : str, optional
|
||||||
|
Session identifier. Default is 'unknown'.
|
||||||
|
"""
|
||||||
|
print(f'Loading BDF file: {bdf_path}')
|
||||||
|
raw = mne.io.read_raw_bdf(bdf_path, preload=True, verbose=False)
|
||||||
|
|
||||||
|
# Get basic info
|
||||||
|
ch_names = raw.ch_names
|
||||||
|
n_channels = len(ch_names)
|
||||||
|
sfreq = int(raw.info['sfreq'])
|
||||||
|
data = raw.get_data()
|
||||||
|
|
||||||
|
# BDF data shape: (n_channels, n_samples)
|
||||||
|
# Convert to eeg_data.mat format: (n_samples, n_channels)
|
||||||
|
data = data.T
|
||||||
|
n_samples = data.shape[0]
|
||||||
|
|
||||||
|
# Create time vector (in seconds)
|
||||||
|
t = np.arange(n_samples) / sfreq
|
||||||
|
t = t.reshape(-1, 1)
|
||||||
|
|
||||||
|
# Create channel names array (matching eeg_data.mat structure)
|
||||||
|
chn = np.array([[name] for name in ch_names], dtype=object)
|
||||||
|
|
||||||
|
# Create electrode names (same as channel names for BDF)
|
||||||
|
electrode_name = np.array([[name] for name in ch_names], dtype=object)
|
||||||
|
|
||||||
|
# Get electrode coordinates
|
||||||
|
standard_coords = get_standard_electrode_coords()
|
||||||
|
electrode_xyz = np.zeros((n_channels, 3))
|
||||||
|
for i, name in enumerate(ch_names):
|
||||||
|
if name in standard_coords:
|
||||||
|
electrode_xyz[i] = standard_coords[name]
|
||||||
|
else:
|
||||||
|
print(f'Warning: No standard coordinate for electrode {name}')
|
||||||
|
|
||||||
|
# Create metadata structure
|
||||||
|
start_time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
meta = np.array([(subject_id, session_id, 'CMS/DRL', start_time_str)],
|
||||||
|
dtype=[('subject_id', 'O'), ('session_id', 'O'),
|
||||||
|
('ref', 'O'), ('start_time', 'O')])
|
||||||
|
|
||||||
|
# Create the EEG structure (matching eeg_data.mat format)
|
||||||
|
eeg_struct = np.array([(data, chn, [[sfreq]], [[n_channels]],
|
||||||
|
t, electrode_name, electrode_xyz,
|
||||||
|
'buzsaki', meta)],
|
||||||
|
dtype=[('data', 'O'), ('chn', 'O'),
|
||||||
|
('sample_rate', 'O'), ('node_number', 'O'),
|
||||||
|
('t', 'O'), ('electrode_name', 'O'),
|
||||||
|
('electrode_xyz', 'O'), ('electrode_coord_system', 'O'),
|
||||||
|
('meta', 'O')])
|
||||||
|
|
||||||
|
# Save to MAT file
|
||||||
|
print(f'Saving to: {output_path}')
|
||||||
|
scipy.io.savemat(output_path, {'eeg': eeg_struct}, do_compression=True)
|
||||||
|
|
||||||
|
print(f'\nConversion complete!')
|
||||||
|
print(f' Channels: {n_channels}')
|
||||||
|
print(f' Samples: {n_samples}')
|
||||||
|
print(f' Duration: {n_samples / sfreq:.2f} seconds')
|
||||||
|
print(f' Sample rate: {sfreq} Hz')
|
||||||
|
print(f' Data shape: {data.shape}')
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# File paths
|
||||||
|
bdf_path = r'D:\Ivey\Code_New_Proj\Debug_Depression\algorithm_version_0521_v0\0515-18.bdf'
|
||||||
|
output_path = r'D:\Ivey\Code_New_Proj\Debug_Depression\algorithm_version_0521_v0\0515-18.mat'
|
||||||
|
|
||||||
|
# Convert
|
||||||
|
bdf_to_mat(bdf_path, output_path, subject_id='lvpeng', session_id='01')
|
||||||
|
|
||||||
|
# Verify the output
|
||||||
|
print('\n=== Verification ===')
|
||||||
|
mat_data = scipy.io.loadmat(output_path)
|
||||||
|
eeg = mat_data['eeg'][0, 0]
|
||||||
|
|
||||||
|
print(f'Output file keys: {list(mat_data.keys())}')
|
||||||
|
print(f'eeg.data shape: {eeg["data"].shape}')
|
||||||
|
print(f'eeg.chn shape: {eeg["chn"].shape}')
|
||||||
|
print(f'eeg.sample_rate: {eeg["sample_rate"][0, 0]}')
|
||||||
|
print(f'eeg.t shape: {eeg["t"].shape}')
|
||||||
|
print(f'eeg.electrode_name: {eeg["electrode_name"].shape}')
|
||||||
|
print(f'eeg.electrode_xyz shape: {eeg["electrode_xyz"].shape}')
|
||||||
|
print(f'eeg.electrode_coord_system: {eeg["electrode_coord_system"][0]}')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
96
algorithm_V1/build_algorithm.spec
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
# -*- mode: python ; coding: utf-8 -*-
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from PyInstaller.utils.hooks import collect_submodules, collect_data_files
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 1. 工程配置区 (Project Config)
|
||||||
|
# ========================================================
|
||||||
|
block_cipher = None
|
||||||
|
ENTRY_POINT = 'runDecoder.py'
|
||||||
|
APP_NAME = 'Depression_Decoder'
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 2. 依赖分析 (Dependency Analysis)
|
||||||
|
# ========================================================
|
||||||
|
# 收集 mne, sklearn, scipy 可能遗漏的隐藏导入
|
||||||
|
hidden_imports = [
|
||||||
|
'infer_pth', # 你的动态导入模块
|
||||||
|
'sklearn.utils._cython_blas',
|
||||||
|
'sklearn.neighbors.typedefs',
|
||||||
|
'sklearn.neighbors.quad_tree',
|
||||||
|
'sklearn.tree',
|
||||||
|
'sklearn.tree._utils',
|
||||||
|
]
|
||||||
|
# 自动收集 mne 的子模块
|
||||||
|
hidden_imports += collect_submodules('mne')
|
||||||
|
# 收集 torch 相关的隐式导入(虽然 PyInstaller 通常能处理,但显式更安全)
|
||||||
|
hidden_imports += ['torch', 'torchvision']
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 3. 资源锚定 (Data Anchoring)
|
||||||
|
# ========================================================
|
||||||
|
# Analysis 中的 datas 用于将文件嵌入到内部(运行时在临时目录或 _internal)
|
||||||
|
# 这里我们留空,改为在 COLLECT 阶段通过 Tree 显式复制到 EXE 旁,
|
||||||
|
# 这样生成的文件夹里能直接看到 model 和 raw_data
|
||||||
|
datas = []
|
||||||
|
|
||||||
|
# 收集 mne 的数据文件(如果需要默认配置)
|
||||||
|
datas += collect_data_files('mne')
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 4. 构建流程 (Build Process)
|
||||||
|
# ========================================================
|
||||||
|
a = Analysis(
|
||||||
|
[ENTRY_POINT],
|
||||||
|
pathex=[],
|
||||||
|
binaries=[],
|
||||||
|
datas=datas,
|
||||||
|
hiddenimports=hidden_imports,
|
||||||
|
hookspath=[],
|
||||||
|
hooksconfig={},
|
||||||
|
runtime_hooks=[],
|
||||||
|
excludes=['tkinter', 'PyQt5', 'PySide2', 'IPython'], # 排除 GUI 和交互式库减小体积
|
||||||
|
win_no_prefer_redirects=False,
|
||||||
|
win_private_assemblies=False,
|
||||||
|
cipher=block_cipher,
|
||||||
|
noarchive=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||||
|
|
||||||
|
exe = EXE(
|
||||||
|
pyz,
|
||||||
|
a.scripts,
|
||||||
|
[],
|
||||||
|
exclude_binaries=True,
|
||||||
|
name=APP_NAME,
|
||||||
|
debug=False,
|
||||||
|
bootloader_ignore_signals=False,
|
||||||
|
strip=False,
|
||||||
|
upx=False,
|
||||||
|
console=True,
|
||||||
|
disable_windowed_traceback=False,
|
||||||
|
argv_emulation=False,
|
||||||
|
target_arch=None,
|
||||||
|
codesign_identity=None,
|
||||||
|
entitlements_file=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 5. 打包模式: OneDir (单文件夹) + 资源旁路
|
||||||
|
# ========================================================
|
||||||
|
# 使用 Tree 将文件夹原样复制到 dist/APP_NAME/ 下
|
||||||
|
# 格式: Tree('源路径', prefix='目标子目录')
|
||||||
|
|
||||||
|
coll = COLLECT(
|
||||||
|
exe,
|
||||||
|
a.binaries,
|
||||||
|
a.zipfiles,
|
||||||
|
a.datas,
|
||||||
|
strip=False,
|
||||||
|
upx=False,
|
||||||
|
upx_exclude=[],
|
||||||
|
name=APP_NAME,
|
||||||
|
)
|
||||||
77
algorithm_V1/build_with_copy.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 1. 定义路径
|
||||||
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
DIST_DIR = os.path.join(BASE_DIR, 'dist')
|
||||||
|
APP_NAME = 'Depression_Decoder'
|
||||||
|
TARGET_DIR = os.path.join(DIST_DIR, APP_NAME)
|
||||||
|
|
||||||
|
MODEL_SRC = os.path.join(BASE_DIR, 'model')
|
||||||
|
RAW_DATA_SRC = os.path.join(BASE_DIR, 'raw_data')
|
||||||
|
|
||||||
|
MODEL_DST = os.path.join(TARGET_DIR, 'model')
|
||||||
|
RAW_DATA_DST = os.path.join(TARGET_DIR, 'raw_data')
|
||||||
|
|
||||||
|
# 2. 清理旧构建
|
||||||
|
print("[1/3] Cleaning up old builds...")
|
||||||
|
if os.path.exists(DIST_DIR):
|
||||||
|
try:
|
||||||
|
shutil.rmtree(DIST_DIR)
|
||||||
|
print(" Cleaned dist/")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not clean dist/: {e}")
|
||||||
|
|
||||||
|
BUILD_DIR = os.path.join(BASE_DIR, 'build')
|
||||||
|
if os.path.exists(BUILD_DIR):
|
||||||
|
try:
|
||||||
|
shutil.rmtree(BUILD_DIR)
|
||||||
|
print(" Cleaned build/")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not clean build/: {e}")
|
||||||
|
|
||||||
|
# 3. 运行 PyInstaller
|
||||||
|
print("[2/3] Running PyInstaller...")
|
||||||
|
# 注意:我们这里不传 --noupx,因为已经在 spec 文件里把 upx=False 写死了
|
||||||
|
cmd = [
|
||||||
|
"pyinstaller",
|
||||||
|
"build_algorithm.spec",
|
||||||
|
"--clean"
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.check_call(cmd, shell=True)
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
print("Error: PyInstaller failed.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 4. 复制外部资源文件夹
|
||||||
|
print("[3/3] Copying external resources...")
|
||||||
|
|
||||||
|
# 复制 model 文件夹
|
||||||
|
if os.path.exists(MODEL_SRC):
|
||||||
|
if os.path.exists(MODEL_DST):
|
||||||
|
shutil.rmtree(MODEL_DST)
|
||||||
|
shutil.copytree(MODEL_SRC, MODEL_DST)
|
||||||
|
print(f" Copied: model -> {MODEL_DST}")
|
||||||
|
else:
|
||||||
|
print(f" Warning: Source model dir not found at {MODEL_SRC}")
|
||||||
|
|
||||||
|
# 复制 raw_data 文件夹
|
||||||
|
if os.path.exists(RAW_DATA_SRC):
|
||||||
|
if os.path.exists(RAW_DATA_DST):
|
||||||
|
shutil.rmtree(RAW_DATA_DST)
|
||||||
|
shutil.copytree(RAW_DATA_SRC, RAW_DATA_DST)
|
||||||
|
print(f" Copied: raw_data -> {RAW_DATA_DST}")
|
||||||
|
else:
|
||||||
|
print(f" Warning: Source raw_data dir not found at {RAW_DATA_SRC}")
|
||||||
|
|
||||||
|
print("\n" + "="*50)
|
||||||
|
print(f"SUCCESS! Build artifacts are in: {TARGET_DIR}")
|
||||||
|
print("="*50)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
561
algorithm_V1/infer_pth.py
Normal file
@@ -0,0 +1,561 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
infer_pth.py
|
||||||
|
|
||||||
|
用途:
|
||||||
|
- 从一个文件夹中自动读取第一个 .mat EEG 文件(64通道或32通道)
|
||||||
|
- 若为64通道,则按 idx64_to_32 映射选出32通道
|
||||||
|
- 提取切片特征(DE + PSD(var近似),不含Asym)
|
||||||
|
- 加载你训练好的 .pth 模型(FusionNet结构)
|
||||||
|
- 输出该受试者的 HC / MDD 判断结果
|
||||||
|
|
||||||
|
运行方式(命令行):
|
||||||
|
python infer_pth_from64_to32.py --eeg_dir "D:\\xxx\\folder" --model_path "C:\\xxx\\model.pth"
|
||||||
|
|
||||||
|
也可在其他py里import:
|
||||||
|
from infer_pth_from64_to32 import predict_hc_mdd
|
||||||
|
res = predict_hc_mdd(eeg_dir, model_path)
|
||||||
|
print(res)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import scipy.io
|
||||||
|
import scipy.signal as signal
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================
|
||||||
|
# 0) 配置区(按需改这里)
|
||||||
|
# =========================================================
|
||||||
|
|
||||||
|
# 采样率(必须与训练时一致)
|
||||||
|
SAMPLING_RATE = 250
|
||||||
|
|
||||||
|
# 滑窗参数(必须与训练时一致)
|
||||||
|
WINDOW_SIZE = 500
|
||||||
|
STRIDE = 250
|
||||||
|
|
||||||
|
# 频段(必须与训练时一致)
|
||||||
|
BAND_NAMES = ["Delta", "Theta", "Alpha", "Beta", "Gamma"]
|
||||||
|
BANDS = {
|
||||||
|
"Delta": (1, 4),
|
||||||
|
"Theta": (4, 8),
|
||||||
|
"Alpha": (8, 13),
|
||||||
|
"Beta": (13, 30),
|
||||||
|
"Gamma": (30, 50),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 是否使用扩展特征(DE+PSD)
|
||||||
|
USE_EXTENDED_FEATURES = True
|
||||||
|
|
||||||
|
# 数值稳定项
|
||||||
|
EPS = 1e-12
|
||||||
|
|
||||||
|
# 设备
|
||||||
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# 通道映射
|
||||||
|
IDX64_TO_32 = [
|
||||||
|
23, # C5
|
||||||
|
47, # O1
|
||||||
|
39, # TP7
|
||||||
|
6, # FPZ
|
||||||
|
2, # PO6
|
||||||
|
21, # P4
|
||||||
|
35, # AF7
|
||||||
|
57, # AF3
|
||||||
|
1, # FP2
|
||||||
|
37, # T7
|
||||||
|
63, # F1
|
||||||
|
36, # A1
|
||||||
|
18, # FC4
|
||||||
|
31, # FC5
|
||||||
|
14, # FC2
|
||||||
|
48, # T8
|
||||||
|
60, # P2
|
||||||
|
41, # AF8
|
||||||
|
11, # CP1
|
||||||
|
0, # FP1
|
||||||
|
55, # PO7
|
||||||
|
59, # C1
|
||||||
|
22, # F5
|
||||||
|
10, # CP2
|
||||||
|
16, # C3
|
||||||
|
61, # P1
|
||||||
|
27, # CP5
|
||||||
|
17, # C4
|
||||||
|
26, # CP6
|
||||||
|
62, # F2
|
||||||
|
3, # POZ
|
||||||
|
13, # PO5
|
||||||
|
]
|
||||||
|
|
||||||
|
# 推理阈值:如果模型checkpoint里有 subject_threshold,会优先用它;否则用这个
|
||||||
|
DEFAULT_SUBJECT_THRESHOLD = 0.5
|
||||||
|
|
||||||
|
# =========================================================
|
||||||
|
# 1) 模型结构
|
||||||
|
# =========================================================
|
||||||
|
|
||||||
|
class SEBlock(nn.Module):
|
||||||
|
def __init__(self, channels: int, reduction: int = 4) -> None:
|
||||||
|
super().__init__()
|
||||||
|
hidden = max(1, channels // reduction)
|
||||||
|
self.fc1 = nn.Linear(channels, hidden)
|
||||||
|
self.fc2 = nn.Linear(hidden, channels)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
se = F.relu(self.fc1(x))
|
||||||
|
se = torch.sigmoid(self.fc2(se))
|
||||||
|
return x * se
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, in_features: int, out_features: int, dropout: float = 0.3) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.fc1 = nn.Linear(in_features, out_features)
|
||||||
|
self.bn1 = nn.BatchNorm1d(out_features)
|
||||||
|
self.fc2 = nn.Linear(out_features, out_features)
|
||||||
|
self.bn2 = nn.BatchNorm1d(out_features)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.shortcut = nn.Identity()
|
||||||
|
if in_features != out_features:
|
||||||
|
self.shortcut = nn.Sequential(
|
||||||
|
nn.Linear(in_features, out_features),
|
||||||
|
nn.BatchNorm1d(out_features),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
identity = self.shortcut(x)
|
||||||
|
out = F.relu(self.bn1(self.fc1(x)))
|
||||||
|
out = self.dropout(out)
|
||||||
|
out = self.bn2(self.fc2(out))
|
||||||
|
out = F.relu(out + identity)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class FusionNet(nn.Module):
|
||||||
|
def __init__(self, num_classes: int = 2, num_eeg_features: int = 320, num_scales: int = 6) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.input_norm = nn.BatchNorm1d(num_eeg_features)
|
||||||
|
|
||||||
|
self.block1 = ResidualBlock(num_eeg_features, 512, dropout=0.4)
|
||||||
|
self.block2 = ResidualBlock(512, 256, dropout=0.3)
|
||||||
|
self.block3 = ResidualBlock(256, 128, dropout=0.2)
|
||||||
|
|
||||||
|
self.attention = SEBlock(128, reduction=4)
|
||||||
|
|
||||||
|
self.final_fc = nn.Sequential(
|
||||||
|
nn.Linear(128, 64),
|
||||||
|
nn.BatchNorm1d(64),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cls_head = nn.Linear(64, num_classes)
|
||||||
|
|
||||||
|
# 训练时有回归头也没关系(推理只用cls)
|
||||||
|
self.reg_head = nn.Sequential(
|
||||||
|
nn.Linear(64, 32),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
nn.Linear(32, num_scales),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._init_weights()
|
||||||
|
|
||||||
|
def _init_weights(self) -> None:
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.BatchNorm1d):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x = self.input_norm(x)
|
||||||
|
x = self.block1(x)
|
||||||
|
x = self.block2(x)
|
||||||
|
x = self.block3(x)
|
||||||
|
x = self.attention(x)
|
||||||
|
features = self.final_fc(x)
|
||||||
|
cls_out = self.cls_head(features)
|
||||||
|
reg_out = self.reg_head(features)
|
||||||
|
return cls_out, reg_out
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================
|
||||||
|
# 2) mat读取 + 通道裁剪
|
||||||
|
# =========================================================
|
||||||
|
|
||||||
|
def _find_first_mat_file(folder: str) -> str:
|
||||||
|
if not os.path.isdir(folder):
|
||||||
|
raise RuntimeError(f"eeg_dir 不是文件夹: {folder}")
|
||||||
|
mats = sorted([f for f in os.listdir(folder) if f.lower().endswith(".mat")])
|
||||||
|
if not mats:
|
||||||
|
raise RuntimeError(f"文件夹内没有 .mat 文件: {folder}")
|
||||||
|
return os.path.join(folder, mats[0])
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import scipy.io
|
||||||
|
|
||||||
|
def _unwrap_singleton(x):
|
||||||
|
"""
|
||||||
|
把 (1,1) / (1,) 这种包裹层一直剥掉,直到不是 singleton。
|
||||||
|
也处理 object array 的情况。
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
if x.dtype == object and x.size == 1:
|
||||||
|
x = x.item()
|
||||||
|
continue
|
||||||
|
if x.size == 1 and x.ndim >= 1:
|
||||||
|
# 例如 (1,1) 或 (1,) 的数值/对象数组
|
||||||
|
try:
|
||||||
|
x = x.reshape(-1)[0]
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
break
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _try_get_struct_field(v, field_name="data"):
|
||||||
|
"""
|
||||||
|
尝试从以下几种结构中提取字段:
|
||||||
|
1) scipy 读出的 mat_struct(有 _fieldnames)
|
||||||
|
2) numpy structured/record array(dtype.names)
|
||||||
|
"""
|
||||||
|
# case 1: mat_struct(推荐 loadmat(..., struct_as_record=False, squeeze_me=True))
|
||||||
|
if hasattr(v, "_fieldnames") and (field_name in getattr(v, "_fieldnames", [])):
|
||||||
|
return getattr(v, field_name)
|
||||||
|
|
||||||
|
# case 2: structured array
|
||||||
|
if isinstance(v, np.ndarray) and v.dtype.names and (field_name in v.dtype.names):
|
||||||
|
# 常见是 v[field] 仍然是 ndarray / object,需要 unwrap
|
||||||
|
try:
|
||||||
|
return v[field_name]
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_eeg_from_mat_any_channels(mat_path: str) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
读取 .mat 中 EEG 数据,支持:
|
||||||
|
- 直接二维矩阵 (T,C) 或 (C,T)
|
||||||
|
- struct 里有字段 data
|
||||||
|
返回统一为 float32 的 (T, C)
|
||||||
|
"""
|
||||||
|
# 用这两个参数会让 struct 更容易处理:字段变成属性,且自动 squeeze
|
||||||
|
mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
|
||||||
|
|
||||||
|
candidates = []
|
||||||
|
|
||||||
|
for k, v in mat.items():
|
||||||
|
if k.startswith("__"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- 1) 直接二维数值矩阵 ---
|
||||||
|
if isinstance(v, np.ndarray) and v.ndim == 2 and np.issubdtype(v.dtype, np.number):
|
||||||
|
candidates.append((k, v))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- 2) struct/record:优先提取 data 字段 ---
|
||||||
|
data_field = _try_get_struct_field(v, "data")
|
||||||
|
if data_field is not None:
|
||||||
|
data_field = _unwrap_singleton(data_field)
|
||||||
|
|
||||||
|
# data_field 可能仍然被 object 包一层
|
||||||
|
if isinstance(data_field, np.ndarray) and data_field.dtype == object:
|
||||||
|
data_field = _unwrap_singleton(data_field)
|
||||||
|
|
||||||
|
if isinstance(data_field, np.ndarray) and data_field.ndim == 2:
|
||||||
|
# 只收数值矩阵
|
||||||
|
if np.issubdtype(data_field.dtype, np.number) or data_field.dtype == object:
|
||||||
|
candidates.append((f"{k}.data", data_field))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- 3) object array:尝试 item() 解包后再看是不是二维数值矩阵/struct ---
|
||||||
|
if isinstance(v, np.ndarray) and v.dtype == object:
|
||||||
|
vv = _unwrap_singleton(v)
|
||||||
|
|
||||||
|
# 解包后若是二维数值矩阵
|
||||||
|
if isinstance(vv, np.ndarray) and vv.ndim == 2 and np.issubdtype(vv.dtype, np.number):
|
||||||
|
candidates.append((k, vv))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 解包后若是 struct,再取 data
|
||||||
|
data2 = _try_get_struct_field(vv, "data")
|
||||||
|
if data2 is not None:
|
||||||
|
data2 = _unwrap_singleton(data2)
|
||||||
|
if isinstance(data2, np.ndarray) and data2.ndim == 2:
|
||||||
|
candidates.append((f"{k}.data", data2))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
raise RuntimeError(f"mat里没找到可用EEG二维矩阵或struct.data:{mat_path}")
|
||||||
|
|
||||||
|
# 选一个最像 EEG 的(优先含32/64通道维度的)
|
||||||
|
def score(arr: np.ndarray) -> int:
|
||||||
|
s = 0
|
||||||
|
if 64 in arr.shape: s += 10
|
||||||
|
if 32 in arr.shape: s += 9
|
||||||
|
if 128 in arr.shape: s += 8
|
||||||
|
if 129 in arr.shape: s += 7
|
||||||
|
s += int(np.prod(arr.shape) // 100000) # 大一些更像EEG
|
||||||
|
return s
|
||||||
|
|
||||||
|
candidates.sort(key=lambda kv: score(kv[1]), reverse=True)
|
||||||
|
key, eeg = candidates[0]
|
||||||
|
|
||||||
|
eeg = _unwrap_singleton(eeg)
|
||||||
|
|
||||||
|
# 如果还是 object dtype,尽力转成 float
|
||||||
|
if isinstance(eeg, np.ndarray) and eeg.dtype == object:
|
||||||
|
# 有时 object 里其实是数值
|
||||||
|
eeg = np.array(eeg, dtype=np.float32)
|
||||||
|
else:
|
||||||
|
eeg = np.asarray(eeg, dtype=np.float32)
|
||||||
|
|
||||||
|
if eeg.ndim != 2:
|
||||||
|
raise RuntimeError(f"解析结果不是二维矩阵: key={key}, shape={eeg.shape}, file={mat_path}")
|
||||||
|
|
||||||
|
# 统一为 (T, C)
|
||||||
|
# 常见 (C,T) 或 (T,C),我们用“通道维通常较小”+ “32/64/128/129”判断
|
||||||
|
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[1] not in (32, 64, 128, 129):
|
||||||
|
eeg = eeg.T
|
||||||
|
elif eeg.shape[1] in (32, 64, 128, 129):
|
||||||
|
# 如果第一维也是这些数且更小,可能是(C,T)
|
||||||
|
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[0] < eeg.shape[1]:
|
||||||
|
eeg = eeg.T
|
||||||
|
|
||||||
|
return eeg
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_32_channels(eeg: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
输入 (T, C),输出 (T, 32)
|
||||||
|
- 若C=64:按 IDX64_TO_32 选32通道
|
||||||
|
- 若C=32:直接返回
|
||||||
|
"""
|
||||||
|
if eeg.ndim != 2:
|
||||||
|
raise RuntimeError(f"EEG必须是二维(T,C),但得到: {eeg.shape}")
|
||||||
|
|
||||||
|
C = eeg.shape[1]
|
||||||
|
if C == 64:
|
||||||
|
idx = np.asarray(IDX64_TO_32, dtype=np.int64)
|
||||||
|
if idx.min() < 0 or idx.max() >= 64:
|
||||||
|
raise RuntimeError(f"IDX64_TO_32 越界:min={idx.min()}, max={idx.max()} (要求0~63)")
|
||||||
|
return eeg[:, idx]
|
||||||
|
if C == 32:
|
||||||
|
return eeg
|
||||||
|
raise RuntimeError(f"不支持的通道数C={C},当前只支持 64->32 或 32 直推。")
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================
|
||||||
|
# 3) 特征提取(DE + PSD(var近似))
|
||||||
|
# =========================================================
|
||||||
|
|
||||||
|
class FeatureExtractor32:
|
||||||
|
"""
|
||||||
|
只针对32通道,输出维度:
|
||||||
|
- USE_EXTENDED_FEATURES=True:DE(32*5) + PSD(32*5) = 320
|
||||||
|
- 否则:DE(32*5) = 160
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fs: int = SAMPLING_RATE,
|
||||||
|
window_size: int = WINDOW_SIZE,
|
||||||
|
stride: int = STRIDE,
|
||||||
|
filter_order: int = 4,
|
||||||
|
zero_phase: bool = False,
|
||||||
|
) -> None:
|
||||||
|
self.fs = fs
|
||||||
|
self.window_size = window_size
|
||||||
|
self.stride = stride
|
||||||
|
self.filter_order = filter_order
|
||||||
|
self.zero_phase = zero_phase
|
||||||
|
|
||||||
|
self._sos = {}
|
||||||
|
for bn in BAND_NAMES:
|
||||||
|
low, high = BANDS[bn]
|
||||||
|
self._sos[bn] = signal.butter(
|
||||||
|
self.filter_order, [low, high],
|
||||||
|
btype="band", fs=self.fs, output="sos"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _filter_bands(self, eeg: np.ndarray) -> dict[str, np.ndarray]:
|
||||||
|
out = {}
|
||||||
|
for bn in BAND_NAMES:
|
||||||
|
sos = self._sos[bn]
|
||||||
|
if self.zero_phase:
|
||||||
|
out[bn] = signal.sosfiltfilt(sos, eeg, axis=0).astype(np.float32)
|
||||||
|
else:
|
||||||
|
out[bn] = signal.sosfilt(sos, eeg, axis=0).astype(np.float32)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def extract(self, eeg32: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
eeg32: (T, 32)
|
||||||
|
return: feats (N_slices, feat_dim)
|
||||||
|
"""
|
||||||
|
if eeg32.ndim != 2 or eeg32.shape[1] != 32:
|
||||||
|
raise RuntimeError(f"extract需要 (T,32),得到 {eeg32.shape}")
|
||||||
|
|
||||||
|
bands_data = self._filter_bands(eeg32)
|
||||||
|
T = eeg32.shape[0]
|
||||||
|
|
||||||
|
feats = []
|
||||||
|
for start in range(0, T - self.window_size, self.stride):
|
||||||
|
end = start + self.window_size
|
||||||
|
|
||||||
|
de_list = []
|
||||||
|
psd_list = []
|
||||||
|
|
||||||
|
for bn in BAND_NAMES:
|
||||||
|
seg = bands_data[bn][start:end, :] # (W, 32)
|
||||||
|
var = np.var(seg, axis=0, ddof=1) # (32,)
|
||||||
|
|
||||||
|
# DE
|
||||||
|
de = 0.5 * np.log(2 * np.pi * np.e * (var + EPS))
|
||||||
|
de_list.append(de)
|
||||||
|
|
||||||
|
if USE_EXTENDED_FEATURES:
|
||||||
|
# PSD近似:log(var)
|
||||||
|
psd_list.append(np.log(var + EPS))
|
||||||
|
|
||||||
|
de_feat = np.stack(de_list, axis=0).T.reshape(-1) # (32*5,)
|
||||||
|
if USE_EXTENDED_FEATURES:
|
||||||
|
psd_feat = np.stack(psd_list, axis=0).T.reshape(-1) # (32*5,)
|
||||||
|
f = np.concatenate([de_feat, psd_feat], axis=0).astype(np.float32)
|
||||||
|
else:
|
||||||
|
f = de_feat.astype(np.float32)
|
||||||
|
|
||||||
|
feats.append(f)
|
||||||
|
|
||||||
|
if not feats:
|
||||||
|
raise RuntimeError("EEG长度不足以切片(请检查T是否太短,或调整WINDOW_SIZE/STRIDE)")
|
||||||
|
|
||||||
|
return np.stack(feats, axis=0).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================
|
||||||
|
# 4) 模型加载 + 推理接口
|
||||||
|
# =========================================================
|
||||||
|
|
||||||
|
def _safe_torch_load(path: str):
|
||||||
|
try:
|
||||||
|
return torch.load(path, map_location=DEVICE, weights_only=False)
|
||||||
|
except TypeError:
|
||||||
|
return torch.load(path, map_location=DEVICE)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_path: str) -> tuple[FusionNet, dict]:
|
||||||
|
"""
|
||||||
|
返回: (model, ckpt_dict)
|
||||||
|
"""
|
||||||
|
obj = _safe_torch_load(model_path)
|
||||||
|
if isinstance(obj, dict) and "model_state" in obj:
|
||||||
|
ckpt = obj
|
||||||
|
state = obj["model_state"]
|
||||||
|
feat_dim = int(obj.get("feat_dim", 320))
|
||||||
|
else:
|
||||||
|
ckpt = {}
|
||||||
|
state = obj
|
||||||
|
feat_dim = 320
|
||||||
|
|
||||||
|
model = FusionNet(num_classes=2, num_eeg_features=feat_dim).to(DEVICE)
|
||||||
|
model.load_state_dict(state, strict=True)
|
||||||
|
model.eval()
|
||||||
|
return model, ckpt
|
||||||
|
|
||||||
|
|
||||||
|
def predict_hc_mdd(eeg_dir: str, model_path: str) -> dict:
|
||||||
|
"""
|
||||||
|
接口:传入 EEG文件夹 和 模型路径,返回判断结果 dict
|
||||||
|
|
||||||
|
返回字段:
|
||||||
|
- mat_file: 使用的mat文件
|
||||||
|
- pred_label: "HC" or "MDD"
|
||||||
|
- p_mdd_mean: 切片p(MDD)均值
|
||||||
|
- threshold: subject判定阈值
|
||||||
|
- n_slices: 切片数
|
||||||
|
"""
|
||||||
|
mat_file = _find_first_mat_file(eeg_dir)
|
||||||
|
|
||||||
|
# 1) 读EEG (T,C),并保证变成32通道
|
||||||
|
eeg = load_eeg_from_mat_any_channels(mat_file) # (T,C)
|
||||||
|
eeg32 = ensure_32_channels(eeg) # (T,32)
|
||||||
|
|
||||||
|
# 2) 提特征 (N,feat_dim)
|
||||||
|
extractor = FeatureExtractor32(fs=SAMPLING_RATE, window_size=WINDOW_SIZE, stride=STRIDE)
|
||||||
|
feats = extractor.extract(eeg32) # (N, dim)
|
||||||
|
|
||||||
|
# 3) 加载模型
|
||||||
|
model, ckpt = load_model(model_path)
|
||||||
|
|
||||||
|
# 4) 可选:归一化(若ckpt里保存的mean/std维度刚好匹配)
|
||||||
|
mean = ckpt.get("global_mean", None) if isinstance(ckpt, dict) else None
|
||||||
|
std = ckpt.get("global_std", None) if isinstance(ckpt, dict) else None
|
||||||
|
|
||||||
|
if mean is not None and std is not None:
|
||||||
|
mean = np.asarray(mean, dtype=np.float32)
|
||||||
|
std = np.asarray(std, dtype=np.float32)
|
||||||
|
if mean.shape[0] == feats.shape[1] and std.shape[0] == feats.shape[1]:
|
||||||
|
feats = (feats - mean) / (std + 1e-8)
|
||||||
|
# 不匹配就跳过
|
||||||
|
# 你说“先不用其他步骤或信息”,所以这里按“尽量运行”处理
|
||||||
|
|
||||||
|
# 5) 推理:对所有切片算 p(MDD),取均值做subject-level
|
||||||
|
x = torch.from_numpy(feats).to(DEVICE)
|
||||||
|
with torch.no_grad():
|
||||||
|
cls_out, _ = model(x)
|
||||||
|
prob_mdd = torch.softmax(cls_out, dim=1)[:, 1].detach().cpu().numpy()
|
||||||
|
|
||||||
|
p_mdd_mean = float(np.mean(prob_mdd))
|
||||||
|
thr = float(ckpt.get("subject_threshold", DEFAULT_SUBJECT_THRESHOLD)) if isinstance(ckpt, dict) else float(DEFAULT_SUBJECT_THRESHOLD)
|
||||||
|
pred_is_mdd = (p_mdd_mean >= thr)
|
||||||
|
pred_label = "MDD" if pred_is_mdd else "HC"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"mat_file": mat_file,
|
||||||
|
"pred_label": pred_label,
|
||||||
|
"p_mdd_mean": p_mdd_mean,
|
||||||
|
"threshold": thr,
|
||||||
|
"n_slices": int(feats.shape[0]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================
|
||||||
|
# 5) CLI:命令行运行入口
|
||||||
|
# =========================================================
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Infer HC/MDD from 64ch->32ch EEG mat using a .pth FusionNet model (no Asym).")
|
||||||
|
parser.add_argument("--eeg_dir", type=str, required=True, help="包含.mat EEG文件的文件夹(自动读取第一个.mat)")
|
||||||
|
parser.add_argument("--model_path", type=str, required=True, help="训练好的.pth模型路径")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
res = predict_hc_mdd(args.eeg_dir, args.model_path)
|
||||||
|
print("\n========== 推理结果 ==========")
|
||||||
|
print(f"MAT文件: {res['mat_file']}")
|
||||||
|
print(f"切片数量: {res['n_slices']}")
|
||||||
|
print(f"p(MDD)_mean: {res['p_mdd_mean']:.4f}")
|
||||||
|
print(f"阈值thr: {res['threshold']:.4f}")
|
||||||
|
print(f"预测结果: {res['pred_label']}")
|
||||||
|
print("==============================\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
BIN
algorithm_V1/model/Model_0.pth
Normal file
BIN
algorithm_V1/model/Model_1.pth
Normal file
BIN
algorithm_V1/out/EEG.png
Normal file
|
After Width: | Height: | Size: 683 KiB |
9
algorithm_V1/out/ResultData.txt
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
中央区α/β波比值:0.5
|
||||||
|
额区α/β波比值:0.6
|
||||||
|
顶区α/β波比值:0.6
|
||||||
|
中央区θ/β波比值:0.9
|
||||||
|
顶区θ/β波比值:0.9
|
||||||
|
前额叶α波不对称性:-0.0
|
||||||
|
个体化α峰值频率:8.5
|
||||||
|
前额叶θ+δ波功率:74.1
|
||||||
|
是否推荐治疗:是
|
||||||
BIN
algorithm_V1/out/average_topomap.png
Normal file
|
After Width: | Height: | Size: 247 KiB |
BIN
algorithm_V1/out/psd.png
Normal file
|
After Width: | Height: | Size: 107 KiB |
BIN
algorithm_V1/out/topomaps.png
Normal file
|
After Width: | Height: | Size: 298 KiB |
335615
algorithm_V1/raw_data/eeg_data0(6).bdf
Normal file
6
algorithm_V1/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
numpy
|
||||||
|
scipy
|
||||||
|
matplotlib
|
||||||
|
mne
|
||||||
|
torch
|
||||||
|
scikit-learn
|
||||||
986
algorithm_V1/runDecoder.py
Normal file
@@ -0,0 +1,986 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
runDecoder.py - BDF EEG Depression Assessment
|
||||||
|
|
||||||
|
功能:
|
||||||
|
1. 读取 raw_data 文件夹的第一个 .bdf 格式文件
|
||||||
|
2. 预处理:坏通道剔除、50Hz陷波、0.8-40Hz带通、幅值过滤、ICA去伪迹
|
||||||
|
3. 调用 infer_pth.py 中的 predict_hc_mdd 进行 HC/MDD 分类预测
|
||||||
|
4. 保存图表(EEG、PSD、Topomap)
|
||||||
|
5. 生成 ResultData.txt
|
||||||
|
|
||||||
|
"""
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import scipy.signal as signal
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import mne
|
||||||
|
from mne.preprocessing import ICA
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# Config - 预处理参数
|
||||||
|
# ==========================
|
||||||
|
# 滤波参数
|
||||||
|
BANDPASS_LOW = 0.8
|
||||||
|
BANDPASS_HIGH = 40.0
|
||||||
|
NOTCH_FREQS = [50, 100] # 工频陷波
|
||||||
|
|
||||||
|
# 幅值过滤阈值 (μV)
|
||||||
|
AMPLITUDE_MIN_UV = -200.0
|
||||||
|
AMPLITUDE_MAX_UV = 200.0
|
||||||
|
|
||||||
|
# ICA 参数
|
||||||
|
ICA_N_COMPONENTS = 20 # 使用绝对数量而非比例
|
||||||
|
ICA_RANDOM_STATE = 97
|
||||||
|
ICA_MAX_ITER = 800
|
||||||
|
|
||||||
|
# 坏段检测阈值 (μV)
|
||||||
|
BAD_SEGMENT_THRESHOLD_UV = 350.0
|
||||||
|
|
||||||
|
# 默认采样率
|
||||||
|
DEFAULT_FS = 250.0
|
||||||
|
|
||||||
|
# 画图参数
|
||||||
|
EEG_PLOT_SECONDS = 10
|
||||||
|
PSD_FMIN, PSD_FMAX = 0.8, 45.0
|
||||||
|
FIXED_EEG_IDXS = [23, 47, 39, 6, 2, 21, 35, 57] # 0-based index
|
||||||
|
FIXED_EEG_LABELS = ["C5", "O1", "TP7", "FPZ", "PO6", "P4", "AF7", "AF3"]
|
||||||
|
|
||||||
|
# 频段定义
|
||||||
|
BANDS_METRICS = {
|
||||||
|
"Delta": (1.0, 4.0),
|
||||||
|
"Theta": (4.0, 8.0),
|
||||||
|
"Alpha": (8.0, 13.0),
|
||||||
|
"Beta": (13.0, 30.0),
|
||||||
|
}
|
||||||
|
TOTAL_POWER_BAND = (1.0, 50.0)
|
||||||
|
|
||||||
|
BANDS_TOPOMAP = {
|
||||||
|
"delta": (1.0, 4.0),
|
||||||
|
"theta": (4.0, 8.0),
|
||||||
|
"alpha": (8.0, 13.0),
|
||||||
|
"beta": (13.0, 30.0),
|
||||||
|
"broad": (1.0, 30.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
# PSD 参数
|
||||||
|
PSD_NPERSEG = 1024 # FFT 窗口大小,越大频率分辨率越高
|
||||||
|
|
||||||
|
EPS = 1e-12
|
||||||
|
|
||||||
|
# 脑地形图颜色范围参数
|
||||||
|
# 设置为 None 表示自动范围,设置为 (min, max) 固定范围
|
||||||
|
TOPOMAP_VMIN = None # 例如: -1.0
|
||||||
|
TOPOMAP_VMAX = None # 例如: 1.0
|
||||||
|
# 或者使用对称范围(相对于均值的倍数)
|
||||||
|
TOPOMAP_SYM_SCALE = 1.5 # 颜色范围 = 均值 ± std * SYM_SCALE
|
||||||
|
|
||||||
|
# 脑地形图圆形大小参数 (0.08 - 0.15 范围)
|
||||||
|
# 数值越小圆形越小,越大圆形越大
|
||||||
|
TOPOMAP_SPHERE_RADIUS = 0.12
|
||||||
|
|
||||||
|
# 边界处理参数
|
||||||
|
# 滤波前 padding 秒数,用于消除边界振铃效应
|
||||||
|
FILTER_PAD_SEC = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 数据文件读取
|
||||||
|
# ==========================
|
||||||
|
def load_data_file(file_path: str) -> tuple:
|
||||||
|
"""根据文件扩展名读取数据,返回 MNE Raw 对象"""
|
||||||
|
ext = os.path.splitext(file_path)[1].lower()
|
||||||
|
|
||||||
|
if ext == ".bdf":
|
||||||
|
return load_bdf_file(file_path)
|
||||||
|
elif ext == ".mat":
|
||||||
|
return load_mat_file(file_path)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的文件格式: {ext}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_bdf_file(bdf_path: str) -> tuple:
|
||||||
|
"""读取 .bdf 格式文件,返回 MNE Raw 对象"""
|
||||||
|
print(f"[INFO] Reading BDF file: {bdf_path}")
|
||||||
|
raw = mne.io.read_raw_bdf(bdf_path, preload=True, verbose=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw.set_montage("standard_1020", on_missing="ignore")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] Failed to set standard_1020 montage: {e}")
|
||||||
|
|
||||||
|
sfreq = raw.info['sfreq']
|
||||||
|
ch_names = raw.ch_names
|
||||||
|
n_channels = len(ch_names)
|
||||||
|
duration = raw.times[-1] - raw.times[0]
|
||||||
|
|
||||||
|
print(f"[INFO] Channels: {n_channels}, Duration: {duration:.2f}s, SFreq: {sfreq:.2f}Hz")
|
||||||
|
|
||||||
|
return raw, sfreq, ch_names
|
||||||
|
|
||||||
|
|
||||||
|
def load_mat_file(mat_path: str) -> tuple:
|
||||||
|
"""读取 .mat 格式文件,返回 MNE Raw 对象"""
|
||||||
|
print(f"[INFO] Reading MAT file: {mat_path}")
|
||||||
|
import scipy.io
|
||||||
|
|
||||||
|
mat = scipy.io.loadmat(mat_path)
|
||||||
|
eeg = mat['eeg'][0, 0]
|
||||||
|
|
||||||
|
# 提取数据
|
||||||
|
data = eeg['data'] # (T, C)
|
||||||
|
if data.shape[0] < data.shape[1]:
|
||||||
|
data = data.T # 确保是 (T, C)
|
||||||
|
data = data.astype(np.float64) # 确保是 float
|
||||||
|
|
||||||
|
# 提取采样率
|
||||||
|
sfreq = float(eeg['sample_rate'][0, 0])
|
||||||
|
|
||||||
|
# 提取通道名称
|
||||||
|
ch_names_raw = eeg['electrode_name']
|
||||||
|
if ch_names_raw.ndim == 2:
|
||||||
|
ch_names = [str(ch[0]) if isinstance(ch[0], np.bytes_) else str(ch[0]) for ch in ch_names_raw[0]]
|
||||||
|
else:
|
||||||
|
ch_names = [f"EEG{i+1}" for i in range(data.shape[1])]
|
||||||
|
|
||||||
|
n_channels = data.shape[1]
|
||||||
|
n_samples = data.shape[0]
|
||||||
|
duration = n_samples / sfreq
|
||||||
|
|
||||||
|
print(f"[INFO] Channels: {n_channels}, Duration: {duration:.2f}s, SFreq: {sfreq:.2f}Hz")
|
||||||
|
|
||||||
|
# 创建 MNE Raw 对象
|
||||||
|
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=["eeg"] * n_channels)
|
||||||
|
raw = mne.io.RawArray(data.T, info, verbose=False) # (T, C) -> (C, T)
|
||||||
|
|
||||||
|
# 尝试设置通道位置
|
||||||
|
try:
|
||||||
|
electrode_xyz = eeg['electrode_xyz'] # (64, 3)
|
||||||
|
if electrode_xyz.shape[0] == n_channels:
|
||||||
|
ch_pos = {}
|
||||||
|
for i, name in enumerate(ch_names):
|
||||||
|
ch_pos[name] = electrode_xyz[i] / 1000.0 # 转换为米
|
||||||
|
montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')
|
||||||
|
info.set_montage(montage)
|
||||||
|
print("[INFO] Applied electrode positions from mat file")
|
||||||
|
else:
|
||||||
|
raw.set_montage("standard_1020", on_missing="ignore")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] Failed to set montage from mat file: {e}")
|
||||||
|
try:
|
||||||
|
raw.set_montage("standard_1020", on_missing="ignore")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return raw, sfreq, ch_names
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 坏通道检测
|
||||||
|
# ==========================
|
||||||
|
def detect_bad_channels(raw: mne.io.RawArray, z_thresh: float = 3.0) -> list:
|
||||||
|
"""检测坏通道:全零/常数通道 + MAD z-score 离群通道"""
|
||||||
|
data = raw.get_data()
|
||||||
|
ch_names = raw.ch_names
|
||||||
|
bad_chs = []
|
||||||
|
|
||||||
|
ptp = np.ptp(data, axis=1)
|
||||||
|
std = np.std(data, axis=1)
|
||||||
|
|
||||||
|
for i, (p, s) in enumerate(zip(ptp, std)):
|
||||||
|
if p < 1e-12 or s < 1e-12:
|
||||||
|
bad_chs.append(ch_names[i])
|
||||||
|
|
||||||
|
valid_mask = np.array([ch not in bad_chs for ch in ch_names])
|
||||||
|
if valid_mask.sum() > 2:
|
||||||
|
valid_ptp = ptp[valid_mask]
|
||||||
|
med = np.median(valid_ptp)
|
||||||
|
mad = np.median(np.abs(valid_ptp - med)) + 1e-30
|
||||||
|
z = np.abs(ptp - med) / (mad * 1.4826)
|
||||||
|
|
||||||
|
for i, zv in enumerate(z):
|
||||||
|
if zv > z_thresh and ch_names[i] not in bad_chs:
|
||||||
|
bad_chs.append(ch_names[i])
|
||||||
|
|
||||||
|
if bad_chs:
|
||||||
|
print(f"[INFO] Bad channels detected: {bad_chs}")
|
||||||
|
else:
|
||||||
|
print("[INFO] No bad channels detected")
|
||||||
|
|
||||||
|
return bad_chs
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 坏段标注
|
||||||
|
# ==========================
|
||||||
|
def annotate_bad_segments(raw: mne.io.RawArray, peak_to_peak_uv: float = 350.0):
|
||||||
|
"""简单坏段检测:按1秒窗口计算峰峰值,超过阈值标为 bad"""
|
||||||
|
peak_to_peak_v = peak_to_peak_uv * 1e-6
|
||||||
|
win = int(raw.info["sfreq"] * 1.0)
|
||||||
|
step = int(raw.info["sfreq"] * 0.5)
|
||||||
|
data = raw.get_data()
|
||||||
|
n_times = data.shape[1]
|
||||||
|
|
||||||
|
onsets = []
|
||||||
|
durations = []
|
||||||
|
|
||||||
|
for start in range(0, n_times - win, step):
|
||||||
|
seg = data[:, start:start + win]
|
||||||
|
ptp = np.ptp(seg, axis=1)
|
||||||
|
if np.any(ptp > peak_to_peak_v):
|
||||||
|
onsets.append(start / raw.info["sfreq"])
|
||||||
|
durations.append(win / raw.info["sfreq"])
|
||||||
|
|
||||||
|
if len(onsets) > 0:
|
||||||
|
ann = mne.Annotations(onset=onsets, duration=durations, description=["BAD_SEG"] * len(onsets))
|
||||||
|
raw.set_annotations(ann)
|
||||||
|
print(f"[INFO] Annotated {len(onsets)} bad segments")
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 核心预处理函数
|
||||||
|
# ==========================
|
||||||
|
def preprocess_bdf(raw: mne.io.RawArray) -> mne.io.RawArray:
|
||||||
|
"""BDF 数据预处理流程"""
|
||||||
|
print("[INFO] Starting preprocessing pipeline...")
|
||||||
|
|
||||||
|
# 1) 裁剪首尾 2s
|
||||||
|
crop_sec = 2.0
|
||||||
|
t_start = crop_sec
|
||||||
|
t_end = raw.times[-1] - crop_sec
|
||||||
|
if t_end > t_start:
|
||||||
|
raw = raw.crop(tmin=t_start, tmax=t_end)
|
||||||
|
print(f"[INFO] Cropped: removed first/last {crop_sec}s")
|
||||||
|
|
||||||
|
# 2) 去直流偏置
|
||||||
|
data = raw.get_data()
|
||||||
|
data -= data.mean(axis=1, keepdims=True)
|
||||||
|
raw._data = data
|
||||||
|
print("[INFO] Removed DC offset")
|
||||||
|
|
||||||
|
# 3) 坏通道检测与插值
|
||||||
|
bad_chs = detect_bad_channels(raw)
|
||||||
|
if bad_chs:
|
||||||
|
raw.info["bads"] = bad_chs
|
||||||
|
try:
|
||||||
|
raw_tmp = raw.copy()
|
||||||
|
raw_tmp.set_montage(raw.get_montage(), on_missing="ignore")
|
||||||
|
raw_tmp.interpolate_bads(reset_bads=True, verbose=False)
|
||||||
|
raw = raw_tmp
|
||||||
|
print(f"[INFO] Bad channels interpolated: {bad_chs}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] Bad channel interpolation failed: {e}")
|
||||||
|
raw.info["bads"] = []
|
||||||
|
|
||||||
|
# 4) 50Hz 陷波滤波
|
||||||
|
print(f"[INFO] Applying notch filter: {NOTCH_FREQS}Hz")
|
||||||
|
raw.notch_filter(NOTCH_FREQS, fir_design="firwin", verbose=False)
|
||||||
|
|
||||||
|
# 5) 0.8-40Hz 带通滤波 (使用 padding 消除边界振铃)
|
||||||
|
print(f"[INFO] Applying bandpass filter: {BANDPASS_LOW}-{BANDPASS_HIGH}Hz (with {FILTER_PAD_SEC}s padding)")
|
||||||
|
pad_sec = FILTER_PAD_SEC
|
||||||
|
raw_length = raw.times[-1]
|
||||||
|
pad_start = max(0, pad_sec)
|
||||||
|
pad_end = max(0, pad_sec)
|
||||||
|
|
||||||
|
if raw_length > pad_start + pad_end + 1.0:
|
||||||
|
raw.filter(BANDPASS_LOW, BANDPASS_HIGH, fir_design="firwin",
|
||||||
|
pad='reflect', verbose=False)
|
||||||
|
raw = raw.crop(tmin=pad_sec, tmax=raw_length - pad_sec)
|
||||||
|
print(f"[INFO] Removed {pad_sec}s padding from each side after filtering")
|
||||||
|
else:
|
||||||
|
raw.filter(BANDPASS_LOW, BANDPASS_HIGH, fir_design="firwin", verbose=False)
|
||||||
|
print(f"[WARN] Data too short ({raw_length:.1f}s), skipping padding")
|
||||||
|
|
||||||
|
# 6) 幅值过滤
|
||||||
|
print(f"[INFO] Applying amplitude filter: [{AMPLITUDE_MIN_UV}, {AMPLITUDE_MAX_UV}] μV")
|
||||||
|
amplitude_thresh_v = AMPLITUDE_MAX_UV * 1e-6
|
||||||
|
d = raw.get_data()
|
||||||
|
mask = np.abs(d) > amplitude_thresh_v
|
||||||
|
n_clipped = int(mask.sum())
|
||||||
|
if n_clipped > 0:
|
||||||
|
d[mask] = 0.0
|
||||||
|
raw._data = d
|
||||||
|
print(f"[INFO] Amplitude clipping: {n_clipped} samples exceeded ±200μV, set to 0")
|
||||||
|
|
||||||
|
# 7) 坏段标注
|
||||||
|
annotate_bad_segments(raw, peak_to_peak_uv=BAD_SEGMENT_THRESHOLD_UV)
|
||||||
|
|
||||||
|
# 8) ICA 去伪迹
|
||||||
|
print("[INFO] Running ICA for artifact removal...")
|
||||||
|
ica = ICA(n_components=ICA_N_COMPONENTS, random_state=ICA_RANDOM_STATE,
|
||||||
|
max_iter=ICA_MAX_ITER, method="fastica")
|
||||||
|
ica.fit(raw, reject_by_annotation=True, verbose=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
eog_inds, _ = ica.find_bads_eog(raw, verbose=False)
|
||||||
|
if eog_inds:
|
||||||
|
ica.exclude.extend(eog_inds)
|
||||||
|
print(f"[INFO] ICA exclude EOG components: {eog_inds}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] ICA EOG detection skipped: {e}")
|
||||||
|
|
||||||
|
raw_clean = ica.apply(raw.copy(), verbose=False)
|
||||||
|
|
||||||
|
# 9) ICA 后再次去直流
|
||||||
|
d = raw_clean.get_data()
|
||||||
|
d -= d.mean(axis=1, keepdims=True)
|
||||||
|
raw_clean._data = d
|
||||||
|
|
||||||
|
print("[INFO] Preprocessing completed")
|
||||||
|
return raw_clean
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 输出目录管理
|
||||||
|
# ==========================
|
||||||
|
def ensure_outdir(out_root: str) -> str:
|
||||||
|
"""确保输出目录存在,并清空旧文件(保留 ResultData.txt)"""
|
||||||
|
if os.path.exists(out_root):
|
||||||
|
for filename in os.listdir(out_root):
|
||||||
|
if filename == "ResultData.txt":
|
||||||
|
continue
|
||||||
|
file_path = os.path.join(out_root, filename)
|
||||||
|
try:
|
||||||
|
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||||
|
os.unlink(file_path)
|
||||||
|
elif os.path.isdir(file_path):
|
||||||
|
shutil.rmtree(file_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] Failed to delete {file_path}: {e}")
|
||||||
|
else:
|
||||||
|
os.makedirs(out_root, exist_ok=True)
|
||||||
|
return out_root
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 通道分区
|
||||||
|
# ==========================
|
||||||
|
def _norm_name(s: str) -> str:
|
||||||
|
return str(s).strip().upper().replace(" ", "")
|
||||||
|
|
||||||
|
|
||||||
|
def build_channel_index_map(ch_names, n_channels: int):
|
||||||
|
if not ch_names or len(ch_names) != n_channels:
|
||||||
|
return {}
|
||||||
|
return {_norm_name(nm): i for i, nm in enumerate(ch_names)}
|
||||||
|
|
||||||
|
|
||||||
|
def pick_indices_by_names(name_to_idx, names):
|
||||||
|
idx = []
|
||||||
|
for n in names:
|
||||||
|
nn = _norm_name(n)
|
||||||
|
if nn in name_to_idx:
|
||||||
|
idx.append(name_to_idx[nn])
|
||||||
|
return sorted(list(set(idx)))
|
||||||
|
|
||||||
|
|
||||||
|
def _fallback_region_indices(n_channels: int):
|
||||||
|
a = int(n_channels * 0.33)
|
||||||
|
b = int(n_channels * 0.66)
|
||||||
|
return (
|
||||||
|
list(range(0, a)), # frontal
|
||||||
|
list(range(a, b)), # central
|
||||||
|
list(range(b, n_channels)), # parietal
|
||||||
|
list(range(0, max(2, a // 2))), # prefrontal
|
||||||
|
list(range(b, n_channels)), # posterior
|
||||||
|
[i for i in range(n_channels) if i % 2 == 0], # left
|
||||||
|
[i for i in range(n_channels) if i % 2 == 1], # right
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_region_indices(name_to_idx, n_channels: int):
|
||||||
|
if not name_to_idx:
|
||||||
|
return _fallback_region_indices(n_channels)
|
||||||
|
|
||||||
|
central_names = ["CZ","C1","C2","C3","C4","C5","C6","CP1","CP2","CP3","CP4","CP5","CP6","FC1","FC2","FC3","FC4","FC5","FC6"]
|
||||||
|
frontal_names = ["FZ","F1","F2","F3","F4","F5","F6","F7","F8","AF3","AF4","AF7","AF8","FPZ","FP1","FP2","FCZ"]
|
||||||
|
parietal_names = ["PZ","P1","P2","P3","P4","P5","P6","POZ","PO3","PO4","PO5","PO6","PO7","PO8","CPZ"]
|
||||||
|
prefrontal_names = ["FP1","FP2","FPZ","AF3","AF4","AF7","AF8"]
|
||||||
|
posterior_names = ["O1","O2","OZ","PO7","PO8","PO3","PO4","PZ","P3","P4","P1","P2"]
|
||||||
|
left_names = ["FP1","AF3","AF7","F3","F5","F7"]
|
||||||
|
right_names = ["FP2","AF4","AF8","F4","F6","F8"]
|
||||||
|
|
||||||
|
central = pick_indices_by_names(name_to_idx, central_names)
|
||||||
|
frontal = pick_indices_by_names(name_to_idx, frontal_names)
|
||||||
|
parietal = pick_indices_by_names(name_to_idx, parietal_names)
|
||||||
|
prefrontal = pick_indices_by_names(name_to_idx, prefrontal_names)
|
||||||
|
posterior = pick_indices_by_names(name_to_idx, posterior_names)
|
||||||
|
left = pick_indices_by_names(name_to_idx, left_names)
|
||||||
|
right = pick_indices_by_names(name_to_idx, right_names)
|
||||||
|
|
||||||
|
if not (central and frontal and parietal and prefrontal and posterior):
|
||||||
|
fb = _fallback_region_indices(n_channels)
|
||||||
|
frontal2, central2, parietal2, prefrontal2, posterior2, left2, right2 = fb
|
||||||
|
frontal = frontal if frontal else frontal2
|
||||||
|
central = central if central else central2
|
||||||
|
parietal = parietal if parietal else parietal2
|
||||||
|
prefrontal = prefrontal if prefrontal else prefrontal2
|
||||||
|
posterior = posterior if posterior else posterior2
|
||||||
|
left = left if left else left2
|
||||||
|
right = right if right else right2
|
||||||
|
|
||||||
|
return frontal, central, parietal, prefrontal, posterior, left, right
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# PSD 和频段功率计算
|
||||||
|
# ==========================
|
||||||
|
def welch_psd(eeg_tc, fs):
|
||||||
|
"""计算 PSD"""
|
||||||
|
nperseg = min(PSD_NPERSEG, eeg_tc.shape[0])
|
||||||
|
noverlap = int(nperseg * 0.75)
|
||||||
|
freqs, pxx = signal.welch(
|
||||||
|
eeg_tc, fs=fs, nperseg=nperseg, noverlap=noverlap,
|
||||||
|
axis=0, scaling="density"
|
||||||
|
)
|
||||||
|
return freqs, pxx
|
||||||
|
|
||||||
|
|
||||||
|
def band_power_from_psd(freqs, pxx_fc, band):
|
||||||
|
lo, hi = band
|
||||||
|
m = (freqs >= lo) & (freqs < hi)
|
||||||
|
if not np.any(m):
|
||||||
|
return np.zeros((pxx_fc.shape[1],), dtype=np.float32)
|
||||||
|
|
||||||
|
return np.trapz(pxx_fc[m, :], freqs[m], axis=0).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def region_mean_power(freqs, pxx_fc, idx, band) -> float:
|
||||||
|
if not idx:
|
||||||
|
return 0.0
|
||||||
|
pw = band_power_from_psd(freqs, pxx_fc, band)
|
||||||
|
return float(np.mean(pw[idx]))
|
||||||
|
|
||||||
|
|
||||||
|
def compute_iaf(freqs, pxx_fc, posterior_idx):
|
||||||
|
lo, hi = BANDS_METRICS["Alpha"]
|
||||||
|
m = (freqs >= lo) & (freqs <= hi)
|
||||||
|
if not np.any(m) or not posterior_idx:
|
||||||
|
return 0.0
|
||||||
|
spec = np.mean(pxx_fc[:, posterior_idx], axis=1)
|
||||||
|
sub = spec[m]
|
||||||
|
fsub = freqs[m]
|
||||||
|
return float(fsub[int(np.argmax(sub))])
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 画图函数
|
||||||
|
# ==========================
|
||||||
|
def plot_eeg_waveforms(data_uv_tc, fs, ch_names, out_dir, seconds=10, t_start_sec=30.0):
|
||||||
|
"""画 EEG 波形图(固定通道)"""
|
||||||
|
T, C = data_uv_tc.shape
|
||||||
|
|
||||||
|
start_sample = int(t_start_sec * fs)
|
||||||
|
end_sample = int(min(T, start_sample + seconds * fs))
|
||||||
|
|
||||||
|
if start_sample >= T:
|
||||||
|
start_sample = max(0, T - int(seconds * fs))
|
||||||
|
end_sample = T
|
||||||
|
print(f"[WARN] t_start_sec={t_start_sec}s exceeds data, using last {seconds}s")
|
||||||
|
|
||||||
|
seg_samples = end_sample - start_sample
|
||||||
|
x = np.arange(seg_samples) / fs + t_start_sec
|
||||||
|
|
||||||
|
# 过滤有效索引
|
||||||
|
idxs = [i for i in FIXED_EEG_IDXS if 0 <= i < C]
|
||||||
|
if len(idxs) < len(FIXED_EEG_IDXS):
|
||||||
|
missing = [i for i in FIXED_EEG_IDXS if not (0 <= i < C)]
|
||||||
|
print(f"[WARN] Some indices out of range (C={C}): {missing}")
|
||||||
|
|
||||||
|
if len(idxs) == 0:
|
||||||
|
raise RuntimeError(f"No valid indices for data (C={C})")
|
||||||
|
|
||||||
|
picked_names = []
|
||||||
|
for idx in idxs:
|
||||||
|
pos = FIXED_EEG_IDXS.index(idx)
|
||||||
|
std_label = FIXED_EEG_LABELS[pos] if pos < len(FIXED_EEG_LABELS) else f"CH{idx}"
|
||||||
|
if ch_names and idx < len(ch_names):
|
||||||
|
picked_names.append(std_label)
|
||||||
|
else:
|
||||||
|
picked_names.append(std_label)
|
||||||
|
|
||||||
|
fig_h = 1.4 * len(idxs) + 1
|
||||||
|
fig, axes = plt.subplots(len(idxs), 1, figsize=(10, fig_h), sharex=True)
|
||||||
|
if len(idxs) == 1:
|
||||||
|
axes = [axes]
|
||||||
|
|
||||||
|
seg = data_uv_tc[start_sample:end_sample, idxs].T
|
||||||
|
lo = float(np.percentile(seg, 1))
|
||||||
|
hi = float(np.percentile(seg, 99))
|
||||||
|
m = max(abs(lo), abs(hi), 50.0)
|
||||||
|
|
||||||
|
for ax, ch_idx, nm in zip(axes, idxs, picked_names):
|
||||||
|
y = data_uv_tc[start_sample:end_sample, ch_idx]
|
||||||
|
ax.plot(x, y, linewidth=1.2)
|
||||||
|
ax.set_ylabel("uV")
|
||||||
|
ax.set_title(nm, loc="left", fontsize=10)
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.set_ylim(-m, m)
|
||||||
|
|
||||||
|
axes[-1].set_xlabel("Time (s)")
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
out_path = os.path.join(out_dir, "EEG.png")
|
||||||
|
plt.savefig(out_path, dpi=200)
|
||||||
|
plt.close(fig)
|
||||||
|
print(f"[OK] EEG waveform saved: {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def plot_psd(eeg_uV_tc, fs, ch_names, out_dir):
|
||||||
|
"""画 PSD 图"""
|
||||||
|
C = eeg_uV_tc.shape[1]
|
||||||
|
chosen_idx = []
|
||||||
|
|
||||||
|
if ch_names:
|
||||||
|
mp = {n.upper(): i for i, n in enumerate(ch_names)}
|
||||||
|
for p in ["C3", "C4", "CZ"]:
|
||||||
|
if p in mp:
|
||||||
|
chosen_idx.append(mp[p])
|
||||||
|
if len(chosen_idx) < 3:
|
||||||
|
stds = [(i, float(np.std(eeg_uV_tc[:, i]))) for i in range(C)]
|
||||||
|
stds.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
for i, _ in stds:
|
||||||
|
if i not in chosen_idx:
|
||||||
|
chosen_idx.append(i)
|
||||||
|
if len(chosen_idx) == 3:
|
||||||
|
break
|
||||||
|
chosen_name = [ch_names[i] for i in chosen_idx]
|
||||||
|
else:
|
||||||
|
stds = [(i, float(np.std(eeg_uV_tc[:, i]))) for i in range(C)]
|
||||||
|
stds.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
chosen_idx = [i for i, _ in stds[:3]]
|
||||||
|
chosen_name = [f"CH{i}" for i in chosen_idx]
|
||||||
|
|
||||||
|
# 增大 nperseg 提高频率分辨率
|
||||||
|
nperseg = min(PSD_NPERSEG, eeg_uV_tc.shape[0])
|
||||||
|
noverlap = int(nperseg * 0.75)
|
||||||
|
|
||||||
|
fig = plt.figure(figsize=(7.5, 4.8))
|
||||||
|
for idx, nm in zip(chosen_idx, chosen_name):
|
||||||
|
f, pxx = signal.welch(eeg_uV_tc[:, idx], fs=fs, nperseg=nperseg, noverlap=noverlap)
|
||||||
|
mask = (f >= PSD_FMIN) & (f <= PSD_FMAX)
|
||||||
|
p_db = 10 * np.log10(pxx[mask] + 1e-20)
|
||||||
|
plt.plot(f[mask], p_db, linewidth=1.8, label=nm)
|
||||||
|
|
||||||
|
plt.xlabel("Hz")
|
||||||
|
plt.ylabel("Power (dB)")
|
||||||
|
plt.title("PSD")
|
||||||
|
plt.grid(True, alpha=0.3)
|
||||||
|
plt.legend()
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
out_path = os.path.join(out_dir, "psd.png")
|
||||||
|
plt.savefig(out_path, dpi=200)
|
||||||
|
plt.close(fig)
|
||||||
|
print(f"[OK] PSD saved: {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_standard_1020_channel_indices(raw):
|
||||||
|
"""获取符合 standard_1020 montage 的通道索引和名称"""
|
||||||
|
try:
|
||||||
|
standard_montage = mne.channels.make_standard_montage("standard_1020")
|
||||||
|
standard_names_upper = {ch.upper() for ch in standard_montage.ch_names}
|
||||||
|
standard_name_map = {ch.upper(): ch for ch in standard_montage.ch_names}
|
||||||
|
data_ch_names = raw.ch_names
|
||||||
|
exclude_names = {"A1", "A2", "M1", "M2", "LE", "RE", "LM", "RM"}
|
||||||
|
|
||||||
|
valid_indices = []
|
||||||
|
valid_names = []
|
||||||
|
for i, name in enumerate(data_ch_names):
|
||||||
|
name_upper = name.upper()
|
||||||
|
if name_upper in standard_names_upper and name_upper not in exclude_names:
|
||||||
|
valid_indices.append(i)
|
||||||
|
valid_names.append(standard_name_map[name_upper])
|
||||||
|
|
||||||
|
print(f"[INFO] Found {len(valid_indices)}/{len(data_ch_names)} channels matching standard_1020")
|
||||||
|
return valid_indices, valid_names
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] Failed to get standard_1020 channels: {e}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def compute_band_powers_for_topomap(raw, bands):
|
||||||
|
"""计算各频段功率,只使用 standard_1020 中有位置的通道"""
|
||||||
|
# 获取 standard_1020 montage 和位置信息
|
||||||
|
standard_montage = mne.channels.make_standard_montage("standard_1020")
|
||||||
|
std_names_upper = {ch.upper() for ch in standard_montage.ch_names}
|
||||||
|
ch_pos_map = standard_montage.get_positions()['ch_pos']
|
||||||
|
|
||||||
|
data_ch_names = raw.ch_names
|
||||||
|
exclude = {"A1", "A2", "M1", "M2", "LE", "RE", "LM", "RM"}
|
||||||
|
|
||||||
|
# 只保留有位置信息的通道
|
||||||
|
valid_indices = []
|
||||||
|
valid_names = []
|
||||||
|
for i, name in enumerate(data_ch_names):
|
||||||
|
name_upper = name.upper()
|
||||||
|
if name_upper in std_names_upper and name_upper not in exclude:
|
||||||
|
if name_upper in ch_pos_map: # 必须有位置
|
||||||
|
valid_indices.append(i)
|
||||||
|
valid_names.append(name)
|
||||||
|
|
||||||
|
if len(valid_indices) < 8:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = raw.get_data()
|
||||||
|
data_standard = data[valid_indices, :]
|
||||||
|
|
||||||
|
fs = raw.info["sfreq"]
|
||||||
|
n_fft = min(PSD_NPERSEG, data_standard.shape[1])
|
||||||
|
n_overlap = int(n_fft * 0.75)
|
||||||
|
|
||||||
|
psds, freqs = mne.time_frequency.psd_array_welch(
|
||||||
|
data_standard, sfreq=fs,
|
||||||
|
fmin=min(v[0] for v in bands.values()),
|
||||||
|
fmax=max(v[1] for v in bands.values()),
|
||||||
|
n_fft=n_fft, n_overlap=n_overlap,
|
||||||
|
average="mean", verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
out = {"_valid_names": valid_names}
|
||||||
|
print(f"[DEBUG] PSD: fs={fs}Hz, n_fft={n_fft}, freq_res={fs/n_fft:.3f}Hz/bin")
|
||||||
|
for k, (fmin, fmax) in bands.items():
|
||||||
|
idx = np.where((freqs >= fmin) & (freqs < fmax))[0]
|
||||||
|
if len(idx) == 0:
|
||||||
|
out[k] = np.zeros(len(valid_indices), dtype=np.float32)
|
||||||
|
print(f"[DEBUG] {k.upper()}: NO freq bins in [{fmin}-{fmax}]Hz")
|
||||||
|
continue
|
||||||
|
print(f"[DEBUG] {k.upper()}: freq bins {freqs[idx[0]]:.2f}-{freqs[idx[-1]]:.2f}Hz (bins {idx[0]}-{idx[-1]}, count={len(idx)})")
|
||||||
|
# 使用线性功率值 (V^2 -> uV^2: * 1e12)
|
||||||
|
bp = np.trapz(psds[:, idx], freqs[idx], axis=1) * 1e12
|
||||||
|
out[k] = bp
|
||||||
|
print(f"[DEBUG] {k.upper()}: power range [{bp.min():.4f}, {bp.max():.4f}] uV^2, mean={bp.mean():.4f}")
|
||||||
|
|
||||||
|
print(f"[INFO] Band powers computed for {len(valid_names)} channels with positions")
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _create_topomap_raw(ch_names):
|
||||||
|
"""创建只有 standard_1020 通道位置信息的临时 Raw 对象"""
|
||||||
|
standard_montage = mne.channels.make_standard_montage("standard_1020")
|
||||||
|
ch_pos_map = standard_montage.get_positions()['ch_pos']
|
||||||
|
|
||||||
|
valid_ch_names = []
|
||||||
|
valid_positions = []
|
||||||
|
for name in ch_names:
|
||||||
|
name_upper = name.upper()
|
||||||
|
if name_upper in ch_pos_map:
|
||||||
|
valid_ch_names.append(name)
|
||||||
|
valid_positions.append(ch_pos_map[name_upper])
|
||||||
|
|
||||||
|
if len(valid_ch_names) < 8:
|
||||||
|
return None
|
||||||
|
|
||||||
|
ch_pos = {name: pos for name, pos in zip(valid_ch_names, valid_positions)}
|
||||||
|
montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')
|
||||||
|
|
||||||
|
info = mne.create_info(ch_names=valid_ch_names, sfreq=250.0, ch_types=["eeg"] * len(valid_ch_names))
|
||||||
|
info.set_montage(montage)
|
||||||
|
|
||||||
|
dummy_data = np.zeros((len(valid_ch_names), 1))
|
||||||
|
return mne.io.RawArray(dummy_data, info, verbose=False)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_average_topomap(band_values, out_dir):
|
||||||
|
"""绘制平均拓扑图"""
|
||||||
|
valid_names = band_values.get("_valid_names", [])
|
||||||
|
if not valid_names:
|
||||||
|
return
|
||||||
|
|
||||||
|
values = band_values["broad"]
|
||||||
|
temp_raw = _create_topomap_raw(valid_names)
|
||||||
|
if temp_raw is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
vmin, vmax = _compute_topomap_vlim([values])
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.6))
|
||||||
|
im, _ = mne.viz.plot_topomap(
|
||||||
|
values, temp_raw.info, axes=ax, show=False, contours=0,
|
||||||
|
sphere=(0, 0, 0, TOPOMAP_SPHERE_RADIUS), extrapolate='head', border='mean',
|
||||||
|
cmap='turbo'
|
||||||
|
)
|
||||||
|
im.set_clim(vmin=vmin, vmax=vmax)
|
||||||
|
ax.set_title("0.8-30 Hz", fontsize=12)
|
||||||
|
plt.colorbar(im, ax=ax, shrink=0.85)
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(os.path.join(out_dir, "average_topomap.png"), dpi=200)
|
||||||
|
plt.close(fig)
|
||||||
|
print(f"[OK] average_topomap saved")
|
||||||
|
|
||||||
|
|
||||||
|
def plot_band_topomaps(band_values, out_dir):
|
||||||
|
"""绘制分频段拓扑图"""
|
||||||
|
valid_names = band_values.get("_valid_names", [])
|
||||||
|
if not valid_names:
|
||||||
|
return
|
||||||
|
|
||||||
|
order = [
|
||||||
|
("delta", "δ (1-4Hz)"),
|
||||||
|
("theta", "θ (4-8Hz)"),
|
||||||
|
("alpha", "α (8-13Hz)"),
|
||||||
|
("beta", "β (13-30Hz)"),
|
||||||
|
("broad", "1-30 Hz"),
|
||||||
|
]
|
||||||
|
|
||||||
|
temp_raw = _create_topomap_raw(valid_names)
|
||||||
|
if temp_raw is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
all_values = [band_values[k] for k, _ in order]
|
||||||
|
vmin, vmax = _compute_topomap_vlim(all_values)
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(1, 5, figsize=(16, 4.2))
|
||||||
|
ims = []
|
||||||
|
for ax, (k, title) in zip(axes, order):
|
||||||
|
im, _ = mne.viz.plot_topomap(
|
||||||
|
band_values[k], temp_raw.info, axes=ax, show=False, contours=0,
|
||||||
|
sphere=(0, 0, 0, TOPOMAP_SPHERE_RADIUS), extrapolate='head', border='mean',
|
||||||
|
cmap='turbo'
|
||||||
|
)
|
||||||
|
im.set_clim(vmin=vmin, vmax=vmax)
|
||||||
|
ax.set_title(title, fontsize=11)
|
||||||
|
ims.append(im)
|
||||||
|
fig.subplots_adjust(left=0.02, right=0.85, top=0.88, bottom=0.05, wspace=0.35)
|
||||||
|
cax = fig.add_axes([0.87, 0.15, 0.015, 0.7])
|
||||||
|
fig.colorbar(ims[-1], cax=cax)
|
||||||
|
plt.savefig(os.path.join(out_dir, "topomaps.png"), dpi=200)
|
||||||
|
plt.close(fig)
|
||||||
|
print(f"[OK] topomaps saved")
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_topomap_vlim(values):
|
||||||
|
"""计算脑地形图颜色范围"""
|
||||||
|
v_all = np.concatenate(values) if isinstance(values, list) else np.array(values)
|
||||||
|
if TOPOMAP_VMIN is not None and TOPOMAP_VMAX is not None:
|
||||||
|
return TOPOMAP_VMAX - 60, TOPOMAP_VMAX # 保持 50 的范围
|
||||||
|
if TOPOMAP_SYM_SCALE is not None and TOPOMAP_SYM_SCALE > 0:
|
||||||
|
mean_val = np.mean(v_all)
|
||||||
|
std_val = np.std(v_all)
|
||||||
|
return mean_val - std_val * TOPOMAP_SYM_SCALE, mean_val + std_val * TOPOMAP_SYM_SCALE
|
||||||
|
# 统一 vmax:使用所有频段中的最大值
|
||||||
|
# vmin = 0:这样低功率频段会接近 0(白色/冷色),高功率频段突出
|
||||||
|
vmax = np.max(v_all)
|
||||||
|
vmin = 0
|
||||||
|
return vmin, vmax
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 预测接口
|
||||||
|
# ==========================
|
||||||
|
def _predict_label_by_model(model_path: str, data_path: str) -> dict:
|
||||||
|
"""调用 infer_pth.py 进行预测"""
|
||||||
|
try:
|
||||||
|
from infer_pth import predict_hc_mdd
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"无法导入 predict_hc_mdd: {e}")
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
import scipy.io
|
||||||
|
|
||||||
|
ext = os.path.splitext(data_path)[1].lower()
|
||||||
|
|
||||||
|
if ext == ".mat":
|
||||||
|
# 直接使用 mat 文件
|
||||||
|
result = predict_hc_mdd(os.path.dirname(data_path), model_path)
|
||||||
|
elif ext == ".bdf":
|
||||||
|
# 转换为 mat 格式
|
||||||
|
raw = mne.io.read_raw_bdf(data_path, preload=True, verbose=False)
|
||||||
|
data, times = raw[:]
|
||||||
|
sfreq = raw.info['sfreq']
|
||||||
|
ch_names = raw.ch_names
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
mat_path = os.path.join(temp_dir, "preprocessed_eeg.mat")
|
||||||
|
scipy.io.savemat(mat_path, {
|
||||||
|
'eeg': {
|
||||||
|
'data': (data * 1e6).T,
|
||||||
|
'sample_rate': sfreq,
|
||||||
|
'electrode_name': ch_names
|
||||||
|
}
|
||||||
|
})
|
||||||
|
result = predict_hc_mdd(temp_dir, model_path)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的文件格式: {ext}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 生成 ResultData.txt
|
||||||
|
# ==========================
|
||||||
|
def compute_and_save_txt(model_path, bdf_path, out_dir, eeg_uV_tc, fs, ch_names):
|
||||||
|
"""计算特征指标并保存 ResultData.txt"""
|
||||||
|
# 获取预测结果
|
||||||
|
pred_result = _predict_label_by_model(model_path, bdf_path)
|
||||||
|
pred_label = pred_result.get("pred_label", "UNKNOWN")
|
||||||
|
recommend = "是" if pred_label == "MDD" else "否"
|
||||||
|
|
||||||
|
T, C = eeg_uV_tc.shape
|
||||||
|
mp = build_channel_index_map(ch_names, C)
|
||||||
|
frontal_idx, central_idx, parietal_idx, prefrontal_idx, posterior_idx, left_idx, right_idx = \
|
||||||
|
get_region_indices(mp, C)
|
||||||
|
|
||||||
|
freqs, pxx = welch_psd(eeg_uV_tc, fs)
|
||||||
|
|
||||||
|
# 计算各频段功率比
|
||||||
|
central_alpha = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Alpha"])
|
||||||
|
central_beta = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Beta"])
|
||||||
|
frontal_alpha = region_mean_power(freqs, pxx, frontal_idx, BANDS_METRICS["Alpha"])
|
||||||
|
frontal_beta = region_mean_power(freqs, pxx, frontal_idx, BANDS_METRICS["Beta"])
|
||||||
|
par_alpha = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Alpha"])
|
||||||
|
par_beta = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Beta"])
|
||||||
|
|
||||||
|
central_ab = (central_alpha / (central_beta + EPS)) if central_beta > 0 else 0.0
|
||||||
|
frontal_ab = (frontal_alpha / (frontal_beta + EPS)) if frontal_beta > 0 else 0.0
|
||||||
|
par_ab = (par_alpha / (par_beta + EPS)) if par_beta > 0 else 0.0
|
||||||
|
|
||||||
|
central_theta = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Theta"])
|
||||||
|
par_theta = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Theta"])
|
||||||
|
central_tb = (central_theta / (central_beta + EPS)) if central_beta > 0 else 0.0
|
||||||
|
par_tb = (par_theta / (par_beta + EPS)) if par_beta > 0 else 0.0
|
||||||
|
|
||||||
|
if not left_idx or not right_idx:
|
||||||
|
left_idx = [i for i in prefrontal_idx if (i % 2 == 0)]
|
||||||
|
right_idx = [i for i in prefrontal_idx if (i % 2 == 1)]
|
||||||
|
left_alpha = region_mean_power(freqs, pxx, left_idx, BANDS_METRICS["Alpha"])
|
||||||
|
right_alpha = region_mean_power(freqs, pxx, right_idx, BANDS_METRICS["Alpha"])
|
||||||
|
prefrontal_alpha_asym = float(np.log(right_alpha + EPS) - np.log(left_alpha + EPS))
|
||||||
|
|
||||||
|
iaf = compute_iaf(freqs, pxx, posterior_idx)
|
||||||
|
|
||||||
|
pre_td = region_mean_power(freqs, pxx, prefrontal_idx,
|
||||||
|
(BANDS_METRICS["Delta"][0], BANDS_METRICS["Theta"][1]))
|
||||||
|
pre_total = region_mean_power(freqs, pxx, prefrontal_idx, TOTAL_POWER_BAND)
|
||||||
|
pre_td_rel = (pre_td / (pre_total + EPS)) * 100.0 if pre_total > 0 else 0.0
|
||||||
|
|
||||||
|
def f1(x): return f"{x:.1f}"
|
||||||
|
|
||||||
|
txt = (
|
||||||
|
f"中央区α/β波比值:{f1(central_ab)}\n"
|
||||||
|
f"额区α/β波比值:{f1(frontal_ab)}\n"
|
||||||
|
f"顶区α/β波比值:{f1(par_ab)}\n"
|
||||||
|
f"中央区θ/β波比值:{f1(central_tb)}\n"
|
||||||
|
f"顶区θ/β波比值:{f1(par_tb)}\n"
|
||||||
|
f"前额叶α波不对称性:{f1(prefrontal_alpha_asym)}\n"
|
||||||
|
f"个体化α峰值频率:{f1(iaf)}\n"
|
||||||
|
f"前额叶θ+δ波功率:{f1(pre_td_rel)}\n"
|
||||||
|
f"是否推荐治疗:{recommend}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
out_path = os.path.join(out_dir, "ResultData.txt")
|
||||||
|
with open(out_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(txt)
|
||||||
|
print(f"[OK] ResultData.txt saved: {out_path}")
|
||||||
|
|
||||||
|
# 打印预测结果
|
||||||
|
print(f"\n========== 预测结果 ==========")
|
||||||
|
print(f"预测标签: {pred_label}")
|
||||||
|
print(f"p(MDD)均值: {pred_result.get('p_mdd_mean', 'N/A'):.4f}")
|
||||||
|
print(f"切片数量: {pred_result.get('n_slices', 'N/A')}")
|
||||||
|
print(f"==============================\n")
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 主函数
|
||||||
|
# ==========================
|
||||||
|
def run_all(model_path: str, bdf_dir: str, out_root: str, seconds: int = EEG_PLOT_SECONDS):
|
||||||
|
"""主流程"""
|
||||||
|
if not os.path.exists(bdf_dir):
|
||||||
|
raise RuntimeError(f"输入目录不存在: {bdf_dir}")
|
||||||
|
|
||||||
|
# 支持 .bdf 和 .mat 文件
|
||||||
|
data_files = [f for f in os.listdir(bdf_dir) if f.lower().endswith((".bdf", ".mat"))]
|
||||||
|
if not data_files:
|
||||||
|
raise RuntimeError(f"目录中找不到 .bdf 或 .mat 文件: {bdf_dir}")
|
||||||
|
|
||||||
|
data_files.sort()
|
||||||
|
data_path = os.path.join(bdf_dir, data_files[0])
|
||||||
|
print(f"[INFO] Processing file: {data_path}")
|
||||||
|
|
||||||
|
out_dir = ensure_outdir(out_root)
|
||||||
|
print(f"[INFO] Output directory: {out_dir}")
|
||||||
|
|
||||||
|
raw, sfreq, ch_names = load_data_file(data_path)
|
||||||
|
|
||||||
|
raw_clean = preprocess_bdf(raw)
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw_clean.set_montage("standard_1020", on_missing="ignore")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] Failed to re-apply montage: {e}")
|
||||||
|
|
||||||
|
raw_data = raw_clean.get_data()
|
||||||
|
eeg_uV_tc = (raw_data * 1e6).T.astype(np.float32)
|
||||||
|
print(f"[INFO] Preprocessed EEG shape: {eeg_uV_tc.shape}")
|
||||||
|
|
||||||
|
print("[INFO] Generating figures...")
|
||||||
|
plot_psd(eeg_uV_tc, sfreq, ch_names, out_dir)
|
||||||
|
plot_eeg_waveforms(eeg_uV_tc, sfreq, ch_names, out_dir, seconds=seconds)
|
||||||
|
|
||||||
|
print("[INFO] Generating topomaps...")
|
||||||
|
try:
|
||||||
|
band_vals = compute_band_powers_for_topomap(raw_clean, BANDS_TOPOMAP)
|
||||||
|
if band_vals is not None:
|
||||||
|
plot_average_topomap(band_vals, out_dir)
|
||||||
|
plot_band_topomaps(band_vals, out_dir)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WARN] Topomap generation failed: {e}")
|
||||||
|
|
||||||
|
print("[INFO] Running prediction...")
|
||||||
|
compute_and_save_txt(model_path, data_path, out_dir, eeg_uV_tc, sfreq, ch_names)
|
||||||
|
|
||||||
|
print("[DONE] All tasks completed.")
|
||||||
|
return out_dir
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================
|
||||||
|
# 命令行入口
|
||||||
|
# ==========================
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import multiprocessing
|
||||||
|
multiprocessing.freeze_support()
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def get_resource_path(relative_path):
|
||||||
|
"""获取资源绝对路径"""
|
||||||
|
if getattr(sys, 'frozen', False):
|
||||||
|
base_path = os.path.dirname(sys.executable)
|
||||||
|
else:
|
||||||
|
base_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
return os.path.join(base_path, relative_path)
|
||||||
|
|
||||||
|
# 默认路径
|
||||||
|
DEFAULT_MODEL = get_resource_path(os.path.join("model", "Model_1.pth"))
|
||||||
|
if getattr(sys, 'frozen', False):
|
||||||
|
EXE_DIR = os.path.dirname(sys.executable)
|
||||||
|
else:
|
||||||
|
EXE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
DEFAULT_BDF_DIR = os.path.join(EXE_DIR, "raw_data")
|
||||||
|
DEFAULT_OUT = os.path.join(EXE_DIR, "out")
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="EEG Depression Assessment")
|
||||||
|
parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL, help="模型文件路径 (.pth)")
|
||||||
|
parser.add_argument("--bdf_dir", type=str, default=DEFAULT_BDF_DIR, help="输入文件夹路径 (包含 .bdf 或 .mat 文件)")
|
||||||
|
parser.add_argument("--out_root", type=str, default=DEFAULT_OUT, help="结果输出目录")
|
||||||
|
parser.add_argument("--seconds", type=int, default=10, help="画波形图的秒数")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f"[*] 运行配置:")
|
||||||
|
print(f" - Model : {args.model_path}")
|
||||||
|
print(f" - Input : {args.bdf_dir}")
|
||||||
|
print(f" - Output: {args.out_root}")
|
||||||
|
|
||||||
|
if not os.path.exists(args.bdf_dir):
|
||||||
|
print(f"[ERROR] 输入目录不存在: {args.bdf_dir}")
|
||||||
|
if not os.path.exists(args.model_path):
|
||||||
|
print(f"[ERROR] 模型文件不存在: {args.model_path}")
|
||||||
|
|
||||||
|
run_all(args.model_path, args.bdf_dir, args.out_root, seconds=args.seconds)
|
||||||