# -*- coding: utf-8 -*-
"""
infer_pth_from64_to32.py

用途:
- 从一个文件夹中自动读取第一个 .mat EEG 文件(64通道或32通道)
- 若为64通道,则按 IDX64_TO_32 映射选出32通道
- 提取切片特征(DE + PSD(以方差近似),不含Asym)
- 加载你训练好的 .pth 模型(FusionNet结构)
- 输出该受试者的 HC / MDD 判断结果

运行方式(命令行):
    python infer_pth_from64_to32.py --eeg_dir "D:\\xxx\\folder" --model_path "C:\\xxx\\model.pth"

也可在其他py里import:
    from infer_pth_from64_to32 import predict_hc_mdd
    res = predict_hc_mdd(eeg_dir, model_path)
    print(res)
"""

from __future__ import annotations

import argparse
import inspect
import os
import warnings

import numpy as np
import scipy.io
import scipy.signal as signal
import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================================================
# 0) Configuration (edit here as needed)
# =========================================================

# Sampling rate in Hz; must match the value used during training.
SAMPLING_RATE = 250

# Sliding-window length and hop, both in samples; must match training.
WINDOW_SIZE = 500
STRIDE = 250

# Frequency bands (Hz, low/high edges of the band-pass filters); must match
# training. The band order used everywhere is the insertion order below.
BANDS = {
    "Delta": (1, 4),
    "Theta": (4, 8),
    "Alpha": (8, 13),
    "Beta": (13, 30),
    "Gamma": (30, 50),
}
BAND_NAMES = list(BANDS)

# When True, log-variance ("PSD" proxy) features are appended to the DE
# features (DE + PSD); otherwise only DE is produced.
USE_EXTENDED_FEATURES = True

# Added to variances before taking log() so a zero variance stays finite.
EPS = 1e-12

# Compute device: the first CUDA GPU when available, else the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Channel mapping: positions in the 64-channel recording that are kept, in
# the order the 32-channel model expects them. The electrode labels are the
# ones annotated in the original mapping; confirm them against the montage
# of the recordings you feed in.
IDX64_TO_32 = [
    23, 47, 39, 6,    # C5,  O1,  TP7, FPZ
    2, 21, 35, 57,    # PO6, P4,  AF7, AF3
    1, 37, 63, 36,    # FP2, T7,  F1,  A1
    18, 31, 14, 48,   # FC4, FC5, FC2, T8
    60, 41, 11, 0,    # P2,  AF8, CP1, FP1
    55, 59, 22, 10,   # PO7, C1,  F5,  CP2
    16, 61, 27, 17,   # C3,  P1,  CP5, C4
    26, 62, 3, 13,    # CP6, F2,  POZ, PO5
]

# Subject-level decision threshold on the mean slice p(MDD). A
# "subject_threshold" stored in the model checkpoint takes precedence.
DEFAULT_SUBJECT_THRESHOLD = 0.5

# =========================================================
# 1) Model architecture
# =========================================================

class SEBlock(nn.Module):
    """Squeeze-and-excitation style gate for flat feature vectors.

    Two linear layers (``channels -> channels // reduction -> channels``,
    hidden width floored at 1) produce a sigmoid gate that rescales the input
    element-wise. The attribute names ``fc1``/``fc2`` determine the
    state_dict keys, which ``load_model`` loads with ``strict=True``.
    """

    def __init__(self, channels: int, reduction: int = 4) -> None:
        super().__init__()
        squeezed = channels // reduction
        if squeezed < 1:
            squeezed = 1
        self.fc1 = nn.Linear(channels, squeezed)
        self.fc2 = nn.Linear(squeezed, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(self.fc2(F.relu(self.fc1(x))))
        return x * gate

class ResidualBlock(nn.Module):
    """Fully connected residual unit with BatchNorm and dropout.

    Computes ``relu(bn2(fc2(dropout(relu(bn1(fc1(x)))))) + shortcut(x))``.
    ``shortcut`` is the identity when the widths match, otherwise a
    Linear + BatchNorm projection to ``out_features``. Attribute names (and
    the Sequential indices inside ``shortcut``) determine the state_dict keys.
    """

    def __init__(self, in_features: int, out_features: int, dropout: float = 0.3) -> None:
        super().__init__()
        # Sub-modules are created in a fixed order; keep it so parameter
        # initialisation consumes the RNG in the same sequence.
        self.fc1 = nn.Linear(in_features, out_features)
        self.bn1 = nn.BatchNorm1d(out_features)
        self.fc2 = nn.Linear(out_features, out_features)
        self.bn2 = nn.BatchNorm1d(out_features)
        self.dropout = nn.Dropout(dropout)

        # Only project the skip path when the residual cannot be added as-is.
        if in_features == out_features:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.shortcut(x)
        hidden = self.dropout(F.relu(self.bn1(self.fc1(x))))
        hidden = self.bn2(self.fc2(hidden))
        return F.relu(hidden + residual)

class FusionNet(nn.Module):
    """Residual MLP classifier with an auxiliary regression head.

    Layout: input BatchNorm -> three ResidualBlocks (-> 512 -> 256 -> 128)
    -> SEBlock gate -> shared 64-d trunk (``final_fc``) -> two heads:
    ``cls_head`` (``num_classes`` logits) and ``reg_head`` (``num_scales``
    outputs). As called from ``predict_hc_mdd``, the input is
    ``(n_slices, num_eeg_features)``; only the classification logits are
    used there. Sub-module names and Sequential indices define the
    state_dict keys and must stay as they are.
    """

    def __init__(self, num_classes: int = 2, num_eeg_features: int = 320, num_scales: int = 6) -> None:
        super().__init__()

        self.input_norm = nn.BatchNorm1d(num_eeg_features)

        self.block1 = ResidualBlock(num_eeg_features, 512, dropout=0.4)
        self.block2 = ResidualBlock(512, 256, dropout=0.3)
        self.block3 = ResidualBlock(256, 128, dropout=0.2)

        self.attention = SEBlock(128, reduction=4)

        self.final_fc = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        self.cls_head = nn.Linear(64, num_classes)

        # Auxiliary head present in the trained weights; inference ignores its
        # output, but it must exist for a strict state_dict load.
        self.reg_head = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, num_scales),
        )

        self._init_weights()

    def _init_weights(self) -> None:
        """Kaiming-normal (fan_out, relu) Linear weights, zero biases; BN weight 1 / bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(class_logits, regression_outputs)`` for a feature batch."""
        x = self.input_norm(x)
        for block in (self.block1, self.block2, self.block3):
            x = block(x)
        features = self.final_fc(self.attention(x))
        return self.cls_head(features), self.reg_head(features)

# =========================================================
# 2) .mat loading + channel selection
# =========================================================

def _find_first_mat_file(folder: str) -> str:
    """Return the path of the alphabetically first ``.mat`` file in *folder*.

    The extension match is case-insensitive and only regular files are
    considered, so a sub-directory whose name ends in ``.mat`` is never
    returned (it would only fail later inside ``scipy.io.loadmat``).

    Raises:
        RuntimeError: if *folder* is not a directory or holds no ``.mat`` file.
    """
    if not os.path.isdir(folder):
        raise RuntimeError(f"eeg_dir 不是文件夹: {folder}")

    mats = sorted(
        name
        for name in os.listdir(folder)
        if name.lower().endswith(".mat") and os.path.isfile(os.path.join(folder, name))
    )
    if not mats:
        raise RuntimeError(f"文件夹内没有 .mat 文件: {folder}")
    return os.path.join(folder, mats[0])

# NOTE(review): these repeat the top-of-file imports; harmless but redundant.
import numpy as np
import scipy.io


def _unwrap_singleton(x):
    """Peel single-element wrappers off *x* until it is no longer one.

    Repeatedly:
    - a size-1 object array is replaced by its stored element (``item()``);
    - any other size-1 array with ndim >= 1 (e.g. shape (1, 1) or (1,)) is
      replaced by its single element.
    Stops at anything that is not an ndarray, an ndarray with more than one
    element, or a 0-d non-object array (returned unchanged).
    """
    while isinstance(x, np.ndarray) and x.size == 1:
        if x.dtype == object:
            x = x.item()
        elif x.ndim >= 1:
            x = x.reshape(-1)[0]
        else:
            break
    return x

def _try_get_struct_field(v, field_name="data"):
    """Fetch *field_name* from a MATLAB struct as loaded by scipy, or return None.

    Two shapes are recognised:
    1) objects exposing ``_fieldnames`` (scipy's mat_struct, produced by
       ``loadmat(..., struct_as_record=False)``): the attribute is returned;
    2) numpy structured/record arrays whose dtype lists the field: the field
       view ``v[field_name]`` is returned (it may still need unwrapping).
    Anything else yields None.
    """
    if hasattr(v, "_fieldnames") and field_name in v._fieldnames:
        return getattr(v, field_name)

    if not isinstance(v, np.ndarray):
        return None
    names = v.dtype.names
    if not names or field_name not in names:
        return None
    try:
        return v[field_name]
    except Exception:
        return None

# Dimension sizes treated as "looks like a channel axis" by the heuristics below.
_EEG_CHANNEL_COUNTS = (32, 64, 128, 129)


def _collect_eeg_candidates(mat: dict) -> list[tuple[str, np.ndarray]]:
    """Return ``(name, 2-D array)`` pairs from a ``loadmat`` dict that may hold EEG."""
    candidates = []
    for key, value in mat.items():
        if key.startswith("__"):
            continue  # loadmat metadata entries (__header__, __version__, ...)

        # 1) plain 2-D numeric matrix
        if isinstance(value, np.ndarray) and value.ndim == 2 and np.issubdtype(value.dtype, np.number):
            candidates.append((key, value))
            continue

        # 2) struct / record array: take its `data` field.
        #    _unwrap_singleton is idempotent, so one call suffices.
        field = _try_get_struct_field(value, "data")
        if field is not None:
            field = _unwrap_singleton(field)
            if (
                isinstance(field, np.ndarray)
                and field.ndim == 2
                and (np.issubdtype(field.dtype, np.number) or field.dtype == object)
            ):
                candidates.append((f"{key}.data", field))
                continue

        # 3) object (cell) array wrapping either a matrix or a struct
        if isinstance(value, np.ndarray) and value.dtype == object:
            inner = _unwrap_singleton(value)
            if isinstance(inner, np.ndarray) and inner.ndim == 2 and np.issubdtype(inner.dtype, np.number):
                candidates.append((key, inner))
                continue
            field = _try_get_struct_field(inner, "data")
            if field is not None:
                field = _unwrap_singleton(field)
                if isinstance(field, np.ndarray) and field.ndim == 2:
                    candidates.append((f"{key}.data", field))
    return candidates


def _eeg_candidate_score(arr: np.ndarray) -> int:
    """Rank a candidate: typical channel-count dimensions first, then array size."""
    score = 0
    for count, bonus in ((64, 10), (32, 9), (128, 8), (129, 7)):
        if count in arr.shape:
            score += bonus
    # One extra point per 100k elements: larger arrays are preferred.
    return score + int(np.prod(arr.shape) // 100000)


def _to_time_major(eeg: np.ndarray) -> np.ndarray:
    """Return the 2-D array *eeg* oriented as (T, C).

    With K = _EEG_CHANNEL_COUNTS:
    - rows in K, cols not in K -> transpose (input was (C, T));
    - rows and cols both in K  -> transpose only when rows < cols;
    - only cols in K           -> unchanged;
    - neither in K             -> the smaller dimension is taken as channels
      (transpose when rows < cols). Without this rule a (C, T) recording with
      an unlisted channel count was passed on as if it had T channels.
    """
    rows, cols = eeg.shape
    rows_known = rows in _EEG_CHANNEL_COUNTS
    cols_known = cols in _EEG_CHANNEL_COUNTS
    if rows_known and not cols_known:
        return eeg.T
    if rows_known and cols_known:
        return eeg.T if rows < cols else eeg
    if not cols_known and rows < cols:
        return eeg.T
    return eeg


def load_eeg_from_mat_any_channels(mat_path: str) -> np.ndarray:
    """Read the EEG matrix from a .mat file as float32 ``(T, C)``.

    Supported layouts:
    - a top-level 2-D numeric matrix, stored (T, C) or (C, T);
    - a struct (or record array) whose ``data`` field holds the matrix;
    - either of the above wrapped in a cell/object array.

    If several candidates exist, the highest ``_eeg_candidate_score`` wins;
    ties keep the order in which loadmat returned the variables. Orientation
    is decided by ``_to_time_major``.

    NOTE: ``scipy.io.loadmat`` does not read MATLAB v7.3 (HDF5) files; such
    files fail inside loadmat.

    Raises:
        RuntimeError: if no usable 2-D candidate is found, or the selected
            candidate is not 2-D after conversion (defensive).
    """
    # struct_as_record=False turns MATLAB structs into attribute-style
    # objects; squeeze_me=True drops singleton dimensions.
    mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)

    candidates = _collect_eeg_candidates(mat)
    if not candidates:
        raise RuntimeError(f"mat里没找到可用EEG二维矩阵或struct.data:{mat_path}")

    # max() returns the first maximal element, i.e. ties keep file order.
    key, raw = max(candidates, key=lambda kv: _eeg_candidate_score(kv[1]))

    # Object arrays (cells of numbers) are converted element-wise here.
    eeg = np.asarray(raw, dtype=np.float32)
    if eeg.ndim != 2:
        raise RuntimeError(f"解析结果不是二维矩阵: key={key}, shape={eeg.shape}, file={mat_path}")

    return _to_time_major(eeg)

def ensure_32_channels(eeg: np.ndarray) -> np.ndarray:
    """Reduce a (T, C) recording to the 32 model channels.

    - C == 32: the input array itself is returned.
    - C == 64: columns are picked with ``IDX64_TO_32`` (bounds-checked to 0..63).
    - anything else, or a non-2-D input: RuntimeError.
    """
    if eeg.ndim != 2:
        raise RuntimeError(f"EEG必须是二维(T,C),但得到: {eeg.shape}")

    n_channels = eeg.shape[1]
    if n_channels == 32:
        return eeg
    if n_channels != 64:
        raise RuntimeError(f"不支持的通道数C={n_channels},当前只支持 64->32 或 32 直推。")

    picks = np.asarray(IDX64_TO_32, dtype=np.int64)
    lo, hi = picks.min(), picks.max()
    if lo < 0 or hi >= 64:
        raise RuntimeError(f"IDX64_TO_32 越界:min={lo}, max={hi} (要求0~63)")
    return eeg[:, picks]

# =========================================================
# 3) Feature extraction (DE + PSD approximated by log-variance)
# =========================================================

class FeatureExtractor32:
    """Sliding-window band features for 32-channel EEG.

    The whole recording is band-pass filtered once per band in ``BAND_NAMES``
    (Butterworth, second-order sections; causal ``sosfilt`` unless
    ``zero_phase``), then cut into windows of ``window_size`` samples every
    ``stride`` samples. For each window, band and channel the sample variance
    (ddof=1) yields:

    * DE  = 0.5 * log(2 * pi * e * (var + EPS))
    * PSD = log(var + EPS), a band-power proxy, only when
      ``USE_EXTENDED_FEATURES`` is True.

    Per-window vector: the DE block followed by the PSD block, each laid out
    channel-major (ch0 band0..band4, ch1 band0..band4, ...). Width is
    32*5 + 32*5 = 320 with extended features, 32*5 = 160 otherwise.
    """

    def __init__(
        self,
        fs: int = SAMPLING_RATE,
        window_size: int = WINDOW_SIZE,
        stride: int = STRIDE,
        filter_order: int = 4,
        zero_phase: bool = False,
    ) -> None:
        self.fs = fs
        self.window_size = window_size
        self.stride = stride
        self.filter_order = filter_order
        self.zero_phase = zero_phase

        # Design each band-pass filter once; reused for every recording.
        self._sos = {}
        for bn in BAND_NAMES:
            low, high = BANDS[bn]
            self._sos[bn] = signal.butter(
                self.filter_order, [low, high],
                btype="band", fs=self.fs, output="sos"
            )

    def _filter_bands(self, eeg: np.ndarray) -> dict[str, np.ndarray]:
        """Return ``{band: filtered copy of eeg}`` (float32), filtering along axis 0."""
        apply = signal.sosfiltfilt if self.zero_phase else signal.sosfilt
        return {bn: apply(self._sos[bn], eeg, axis=0).astype(np.float32) for bn in BAND_NAMES}

    def extract(self, eeg32: np.ndarray) -> np.ndarray:
        """Compute per-window features.

        Args:
            eeg32: array of shape (T, 32).

        Returns:
            float32 array of shape (n_slices, feat_dim).

        Raises:
            RuntimeError: if the input is not (T, 32), or T < window_size.
        """
        if eeg32.ndim != 2 or eeg32.shape[1] != 32:
            raise RuntimeError(f"extract需要 (T,32),得到 {eeg32.shape}")

        bands_data = self._filter_bands(eeg32)
        n_samples = eeg32.shape[0]

        feats = []
        # Valid starts satisfy start + window_size <= n_samples, hence the +1.
        # The former exclusive bound (n_samples - window_size) silently
        # dropped the last full window and rejected recordings exactly one
        # window long.
        for start in range(0, n_samples - self.window_size + 1, self.stride):
            end = start + self.window_size

            de_list = []
            psd_list = []
            for bn in BAND_NAMES:
                seg = bands_data[bn][start:end, :]   # (window_size, 32)
                var = np.var(seg, axis=0, ddof=1)    # (32,)
                de_list.append(0.5 * np.log(2 * np.pi * np.e * (var + EPS)))
                if USE_EXTENDED_FEATURES:
                    psd_list.append(np.log(var + EPS))

            # (bands, 32) -> (32, bands) -> flat, i.e. channel-major order.
            de_feat = np.stack(de_list, axis=0).T.reshape(-1)
            if USE_EXTENDED_FEATURES:
                psd_feat = np.stack(psd_list, axis=0).T.reshape(-1)
                f = np.concatenate([de_feat, psd_feat], axis=0).astype(np.float32)
            else:
                f = de_feat.astype(np.float32)
            feats.append(f)

        if not feats:
            raise RuntimeError("EEG长度不足以切片(请检查T是否太短,或调整WINDOW_SIZE/STRIDE)")

        return np.stack(feats, axis=0).astype(np.float32)

# =========================================================
# 4) Model loading + inference API
# =========================================================

def _safe_torch_load(path: str):
    """Load a checkpoint with ``map_location=DEVICE``, across torch versions.

    Newer torch exposes a ``weights_only`` parameter; it is passed explicitly
    as False because the checkpoint may be a dict holding non-tensor metadata
    (see ``load_model``). Older torch lacks the parameter and is called
    without it. Support is detected from the signature instead of retrying
    on TypeError, so a genuine TypeError raised while unpickling is no
    longer masked by a second, differently-configured load attempt.

    SECURITY: ``weights_only=False`` fully unpickles the file, which can
    execute arbitrary code. Only load .pth files from a trusted source.
    """
    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(path, map_location=DEVICE, weights_only=False)
    return torch.load(path, map_location=DEVICE)

def load_model(model_path: str) -> tuple[FusionNet, dict]:
    """Build a FusionNet on DEVICE and load its weights from *model_path*.

    Accepted file contents:
    - a checkpoint dict containing ``"model_state"`` (optionally also
      ``"feat_dim"`` and other metadata, returned as ``ckpt``);
    - a bare state_dict;
    - a pickled ``nn.Module`` (its ``state_dict()`` is used).

    Input width: ``ckpt["feat_dim"]`` when present, otherwise the length of
    ``input_norm.weight`` in the state dict, otherwise 320. The regression
    head width comes from ``reg_head.3.weight`` (default 6). Inferring these
    lets models trained without extended features (160-d) or with a
    different number of scales load under ``strict=True``.

    Returns:
        ``(model, ckpt)``: the model in eval mode and the checkpoint dict
        (empty when the file was not a checkpoint dict).
    """
    obj = _safe_torch_load(model_path)
    if isinstance(obj, dict) and "model_state" in obj:
        ckpt = obj
        state = obj["model_state"]
    elif isinstance(obj, nn.Module):
        # torch.save(model) was used: take the weights from the module itself.
        ckpt = {}
        state = obj.state_dict()
    else:
        ckpt = {}
        state = obj

    def _leading_dim(key: str, default: int) -> int:
        """First dimension of state[key] if present, else *default*."""
        tensor = state.get(key) if isinstance(state, dict) else None
        shape = getattr(tensor, "shape", None)
        return int(shape[0]) if shape else default

    if "feat_dim" in ckpt:
        feat_dim = int(ckpt["feat_dim"])
    else:
        feat_dim = _leading_dim("input_norm.weight", 320)
    num_scales = _leading_dim("reg_head.3.weight", 6)

    model = FusionNet(num_classes=2, num_eeg_features=feat_dim, num_scales=num_scales).to(DEVICE)
    model.load_state_dict(state, strict=True)
    model.eval()
    return model, ckpt

def predict_hc_mdd(eeg_dir: str, model_path: str) -> dict:
    """Classify one subject as HC or MDD from the first .mat file in *eeg_dir*.

    Pipeline: load EEG -> keep 32 channels -> sliding-window features ->
    optional z-normalisation with the checkpoint's ``global_mean`` /
    ``global_std`` -> per-slice p(MDD) (softmax column 1) -> mean over
    slices compared against the threshold (checkpoint ``subject_threshold``
    if present, else DEFAULT_SUBJECT_THRESHOLD; ``>=`` means MDD).

    Returns:
        dict with keys
        - ``mat_file``: path of the .mat file used
        - ``pred_label``: ``"HC"`` or ``"MDD"``
        - ``p_mdd_mean``: mean slice p(MDD)
        - ``threshold``: subject-level threshold applied
        - ``n_slices``: number of slices scored

    Raises:
        RuntimeError: from the loading/feature helpers; when the feature
            width differs from the model input width; or when the features
            contain NaN/Inf (previously NaN made the comparison False and the
            subject was silently labelled "HC").
    """
    mat_file = _find_first_mat_file(eeg_dir)

    # 1) EEG as (T, C), reduced to 32 channels
    eeg = load_eeg_from_mat_any_channels(mat_file)
    eeg32 = ensure_32_channels(eeg)

    # 2) features, (n_slices, feat_dim)
    extractor = FeatureExtractor32(fs=SAMPLING_RATE, window_size=WINDOW_SIZE, stride=STRIDE)
    feats = extractor.extract(eeg32)

    # 3) model
    model, ckpt = load_model(model_path)
    if not isinstance(ckpt, dict):
        ckpt = {}
    model_dim = model.input_norm.num_features
    if feats.shape[1] != model_dim:
        raise RuntimeError(
            f"特征维度 {feats.shape[1]} 与模型输入维度 {model_dim} 不一致"
            f"(请检查 USE_EXTENDED_FEATURES 是否与训练时一致)"
        )

    # 4) optional normalisation with statistics stored in the checkpoint
    def _flat_float32(value) -> np.ndarray:
        # torch.load(map_location=DEVICE) may have put tensors on the GPU,
        # where np.asarray cannot read them.
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        return np.asarray(value, dtype=np.float32).reshape(-1)

    mean = ckpt.get("global_mean", None)
    std = ckpt.get("global_std", None)
    if mean is not None and std is not None:
        mean = _flat_float32(mean)
        std = _flat_float32(std)
        if mean.size == feats.shape[1] and std.size == feats.shape[1]:
            feats = (feats - mean) / (std + 1e-8)
        else:
            # Best effort: keep running, but do not hide that the model now
            # sees un-normalised features.
            warnings.warn(
                f"ckpt中 global_mean/global_std 长度({mean.size}/{std.size})"
                f"与特征维度 {feats.shape[1]} 不匹配,已跳过归一化,预测结果可能不可靠",
                RuntimeWarning,
                stacklevel=2,
            )

    if not np.isfinite(feats).all():
        raise RuntimeError("特征中含 NaN/Inf(请检查 .mat 中的EEG数据是否含缺失值或异常值)")

    # 5) per-slice p(MDD), averaged into a subject-level score
    x = torch.from_numpy(np.ascontiguousarray(feats, dtype=np.float32)).to(DEVICE)
    with torch.no_grad():
        cls_out, _ = model(x)
        prob_mdd = torch.softmax(cls_out, dim=1)[:, 1].cpu().numpy()

    p_mdd_mean = float(np.mean(prob_mdd))
    thr = float(ckpt.get("subject_threshold", DEFAULT_SUBJECT_THRESHOLD))
    pred_label = "MDD" if p_mdd_mean >= thr else "HC"

    return {
        "mat_file": mat_file,
        "pred_label": pred_label,
        "p_mdd_mean": p_mdd_mean,
        "threshold": thr,
        "n_slices": int(feats.shape[0]),
    }

# =========================================================
# 5) CLI entry point
# =========================================================

def main():
    """Parse the command line, run ``predict_hc_mdd`` and print a summary."""
    parser = argparse.ArgumentParser(
        description="Infer HC/MDD from 64ch->32ch EEG mat using a .pth FusionNet model (no Asym)."
    )
    cli_options = (
        ("--eeg_dir", "包含.mat EEG文件的文件夹(自动读取第一个.mat)"),
        ("--model_path", "训练好的.pth模型路径"),
    )
    for flag, help_text in cli_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)
    args = parser.parse_args()

    result = predict_hc_mdd(args.eeg_dir, args.model_path)
    report = (
        "\n========== 推理结果 ==========",
        f"MAT文件: {result['mat_file']}",
        f"切片数量: {result['n_slices']}",
        f"p(MDD)_mean: {result['p_mdd_mean']:.4f}",
        f"阈值thr: {result['threshold']:.4f}",
        f"预测结果: {result['pred_label']}",
        "==============================\n",
    )
    for line in report:
        print(line)


if __name__ == "__main__":
    main()