# -*- coding: utf-8 -*- """ infer_pth.py 用途: - 从一个文件夹中自动读取第一个 .mat EEG 文件(64通道或32通道) - 若为64通道,则按 idx64_to_32 映射选出32通道 - 提取切片特征(DE + PSD(var近似),不含Asym) - 加载你训练好的 .pth 模型(FusionNet结构) - 输出该受试者的 HC / MDD 判断结果 运行方式(命令行): python infer_pth_from64_to32.py --eeg_dir "D:\\xxx\\folder" --model_path "C:\\xxx\\model.pth" 也可在其他py里import: from infer_pth_from64_to32 import predict_hc_mdd res = predict_hc_mdd(eeg_dir, model_path) print(res) """ from __future__ import annotations import os import argparse import numpy as np import scipy.io import scipy.signal as signal import torch import torch.nn as nn import torch.nn.functional as F # ========================================================= # 0) 配置区(按需改这里) # ========================================================= # 采样率(必须与训练时一致) SAMPLING_RATE = 250 # 滑窗参数(必须与训练时一致) WINDOW_SIZE = 500 STRIDE = 250 # 频段(必须与训练时一致) BAND_NAMES = ["Delta", "Theta", "Alpha", "Beta", "Gamma"] BANDS = { "Delta": (1, 4), "Theta": (4, 8), "Alpha": (8, 13), "Beta": (13, 30), "Gamma": (30, 50), } # 是否使用扩展特征(DE+PSD) USE_EXTENDED_FEATURES = True # 数值稳定项 EPS = 1e-12 # 设备 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 通道映射 IDX64_TO_32 = [ 23, # C5 47, # O1 39, # TP7 6, # FPZ 2, # PO6 21, # P4 35, # AF7 57, # AF3 1, # FP2 37, # T7 63, # F1 36, # A1 18, # FC4 31, # FC5 14, # FC2 48, # T8 60, # P2 41, # AF8 11, # CP1 0, # FP1 55, # PO7 59, # C1 22, # F5 10, # CP2 16, # C3 61, # P1 27, # CP5 17, # C4 26, # CP6 62, # F2 3, # POZ 13, # PO5 ] # 推理阈值:如果模型checkpoint里有 subject_threshold,会优先用它;否则用这个 DEFAULT_SUBJECT_THRESHOLD = 0.5 # ========================================================= # 1) 模型结构 # ========================================================= class SEBlock(nn.Module): def __init__(self, channels: int, reduction: int = 4) -> None: super().__init__() hidden = max(1, channels // reduction) self.fc1 = nn.Linear(channels, hidden) self.fc2 = nn.Linear(hidden, channels) def forward(self, x: torch.Tensor) -> torch.Tensor: se = F.relu(self.fc1(x)) se = torch.sigmoid(self.fc2(se)) return x * se class ResidualBlock(nn.Module): def __init__(self, in_features: int, out_features: int, dropout: float = 0.3) -> None: super().__init__() self.fc1 = nn.Linear(in_features, out_features) self.bn1 = nn.BatchNorm1d(out_features) self.fc2 = nn.Linear(out_features, out_features) self.bn2 = nn.BatchNorm1d(out_features) self.dropout = nn.Dropout(dropout) self.shortcut = nn.Identity() if in_features != out_features: self.shortcut = nn.Sequential( nn.Linear(in_features, out_features), nn.BatchNorm1d(out_features), ) def forward(self, x: torch.Tensor) -> torch.Tensor: identity = self.shortcut(x) out = F.relu(self.bn1(self.fc1(x))) out = self.dropout(out) out = self.bn2(self.fc2(out)) out = F.relu(out + identity) return out class FusionNet(nn.Module): def __init__(self, num_classes: int = 2, num_eeg_features: int = 320, num_scales: int = 6) -> None: super().__init__() self.input_norm = nn.BatchNorm1d(num_eeg_features) self.block1 = ResidualBlock(num_eeg_features, 512, dropout=0.4) self.block2 = ResidualBlock(512, 256, dropout=0.3) self.block3 = ResidualBlock(256, 128, dropout=0.2) self.attention = SEBlock(128, reduction=4) self.final_fc = nn.Sequential( nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.2), ) self.cls_head = nn.Linear(64, num_classes) # 训练时有回归头也没关系(推理只用cls) self.reg_head = nn.Sequential( nn.Linear(64, 32), nn.ReLU(), nn.Dropout(0.1), nn.Linear(32, num_scales), ) self._init_weights() def _init_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Linear): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm1d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def forward(self, x: torch.Tensor): x = self.input_norm(x) x = self.block1(x) x = self.block2(x) x = self.block3(x) x = self.attention(x) features = self.final_fc(x) cls_out = self.cls_head(features) reg_out = self.reg_head(features) return cls_out, reg_out # ========================================================= # 2) mat读取 + 通道裁剪 # ========================================================= def _find_first_mat_file(folder: str) -> str: if not os.path.isdir(folder): raise RuntimeError(f"eeg_dir 不是文件夹: {folder}") mats = sorted([f for f in os.listdir(folder) if f.lower().endswith(".mat")]) if not mats: raise RuntimeError(f"文件夹内没有 .mat 文件: {folder}") return os.path.join(folder, mats[0]) import numpy as np import scipy.io def _unwrap_singleton(x): """ 把 (1,1) / (1,) 这种包裹层一直剥掉,直到不是 singleton。 也处理 object array 的情况。 """ while True: if isinstance(x, np.ndarray): if x.dtype == object and x.size == 1: x = x.item() continue if x.size == 1 and x.ndim >= 1: # 例如 (1,1) 或 (1,) 的数值/对象数组 try: x = x.reshape(-1)[0] continue except Exception: pass break return x def _try_get_struct_field(v, field_name="data"): """ 尝试从以下几种结构中提取字段: 1) scipy 读出的 mat_struct(有 _fieldnames) 2) numpy structured/record array(dtype.names) """ # case 1: mat_struct(推荐 loadmat(..., struct_as_record=False, squeeze_me=True)) if hasattr(v, "_fieldnames") and (field_name in getattr(v, "_fieldnames", [])): return getattr(v, field_name) # case 2: structured array if isinstance(v, np.ndarray) and v.dtype.names and (field_name in v.dtype.names): # 常见是 v[field] 仍然是 ndarray / object,需要 unwrap try: return v[field_name] except Exception: return None return None def load_eeg_from_mat_any_channels(mat_path: str) -> np.ndarray: """ 读取 .mat 中 EEG 数据,支持: - 直接二维矩阵 (T,C) 或 (C,T) - struct 里有字段 data 返回统一为 float32 的 (T, C) """ # 用这两个参数会让 struct 更容易处理:字段变成属性,且自动 squeeze mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True) candidates = [] for k, v in mat.items(): if k.startswith("__"): continue # --- 1) 直接二维数值矩阵 --- if isinstance(v, np.ndarray) and v.ndim == 2 and np.issubdtype(v.dtype, np.number): candidates.append((k, v)) continue # --- 2) struct/record:优先提取 data 字段 --- data_field = _try_get_struct_field(v, "data") if data_field is not None: data_field = _unwrap_singleton(data_field) # data_field 可能仍然被 object 包一层 if isinstance(data_field, np.ndarray) and data_field.dtype == object: data_field = _unwrap_singleton(data_field) if isinstance(data_field, np.ndarray) and data_field.ndim == 2: # 只收数值矩阵 if np.issubdtype(data_field.dtype, np.number) or data_field.dtype == object: candidates.append((f"{k}.data", data_field)) continue # --- 3) object array:尝试 item() 解包后再看是不是二维数值矩阵/struct --- if isinstance(v, np.ndarray) and v.dtype == object: vv = _unwrap_singleton(v) # 解包后若是二维数值矩阵 if isinstance(vv, np.ndarray) and vv.ndim == 2 and np.issubdtype(vv.dtype, np.number): candidates.append((k, vv)) continue # 解包后若是 struct,再取 data data2 = _try_get_struct_field(vv, "data") if data2 is not None: data2 = _unwrap_singleton(data2) if isinstance(data2, np.ndarray) and data2.ndim == 2: candidates.append((f"{k}.data", data2)) continue if not candidates: raise RuntimeError(f"mat里没找到可用EEG二维矩阵或struct.data:{mat_path}") # 选一个最像 EEG 的(优先含32/64通道维度的) def score(arr: np.ndarray) -> int: s = 0 if 64 in arr.shape: s += 10 if 32 in arr.shape: s += 9 if 128 in arr.shape: s += 8 if 129 in arr.shape: s += 7 s += int(np.prod(arr.shape) // 100000) # 大一些更像EEG return s candidates.sort(key=lambda kv: score(kv[1]), reverse=True) key, eeg = candidates[0] eeg = _unwrap_singleton(eeg) # 如果还是 object dtype,尽力转成 float if isinstance(eeg, np.ndarray) and eeg.dtype == object: # 有时 object 里其实是数值 eeg = np.array(eeg, dtype=np.float32) else: eeg = np.asarray(eeg, dtype=np.float32) if eeg.ndim != 2: raise RuntimeError(f"解析结果不是二维矩阵: key={key}, shape={eeg.shape}, file={mat_path}") # 统一为 (T, C) # 常见 (C,T) 或 (T,C),我们用“通道维通常较小”+ “32/64/128/129”判断 if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[1] not in (32, 64, 128, 129): eeg = eeg.T elif eeg.shape[1] in (32, 64, 128, 129): # 如果第一维也是这些数且更小,可能是(C,T) if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[0] < eeg.shape[1]: eeg = eeg.T return eeg def ensure_32_channels(eeg: np.ndarray) -> np.ndarray: """ 输入 (T, C),输出 (T, 32) - 若C=64:按 IDX64_TO_32 选32通道 - 若C=32:直接返回 """ if eeg.ndim != 2: raise RuntimeError(f"EEG必须是二维(T,C),但得到: {eeg.shape}") C = eeg.shape[1] if C == 64: idx = np.asarray(IDX64_TO_32, dtype=np.int64) if idx.min() < 0 or idx.max() >= 64: raise RuntimeError(f"IDX64_TO_32 越界:min={idx.min()}, max={idx.max()} (要求0~63)") return eeg[:, idx] if C == 32: return eeg raise RuntimeError(f"不支持的通道数C={C},当前只支持 64->32 或 32 直推。") # ========================================================= # 3) 特征提取(DE + PSD(var近似)) # ========================================================= class FeatureExtractor32: """ 只针对32通道,输出维度: - USE_EXTENDED_FEATURES=True:DE(32*5) + PSD(32*5) = 320 - 否则:DE(32*5) = 160 """ def __init__( self, fs: int = SAMPLING_RATE, window_size: int = WINDOW_SIZE, stride: int = STRIDE, filter_order: int = 4, zero_phase: bool = False, ) -> None: self.fs = fs self.window_size = window_size self.stride = stride self.filter_order = filter_order self.zero_phase = zero_phase self._sos = {} for bn in BAND_NAMES: low, high = BANDS[bn] self._sos[bn] = signal.butter( self.filter_order, [low, high], btype="band", fs=self.fs, output="sos" ) def _filter_bands(self, eeg: np.ndarray) -> dict[str, np.ndarray]: out = {} for bn in BAND_NAMES: sos = self._sos[bn] if self.zero_phase: out[bn] = signal.sosfiltfilt(sos, eeg, axis=0).astype(np.float32) else: out[bn] = signal.sosfilt(sos, eeg, axis=0).astype(np.float32) return out def extract(self, eeg32: np.ndarray) -> np.ndarray: """ eeg32: (T, 32) return: feats (N_slices, feat_dim) """ if eeg32.ndim != 2 or eeg32.shape[1] != 32: raise RuntimeError(f"extract需要 (T,32),得到 {eeg32.shape}") bands_data = self._filter_bands(eeg32) T = eeg32.shape[0] feats = [] for start in range(0, T - self.window_size, self.stride): end = start + self.window_size de_list = [] psd_list = [] for bn in BAND_NAMES: seg = bands_data[bn][start:end, :] # (W, 32) var = np.var(seg, axis=0, ddof=1) # (32,) # DE de = 0.5 * np.log(2 * np.pi * np.e * (var + EPS)) de_list.append(de) if USE_EXTENDED_FEATURES: # PSD近似:log(var) psd_list.append(np.log(var + EPS)) de_feat = np.stack(de_list, axis=0).T.reshape(-1) # (32*5,) if USE_EXTENDED_FEATURES: psd_feat = np.stack(psd_list, axis=0).T.reshape(-1) # (32*5,) f = np.concatenate([de_feat, psd_feat], axis=0).astype(np.float32) else: f = de_feat.astype(np.float32) feats.append(f) if not feats: raise RuntimeError("EEG长度不足以切片(请检查T是否太短,或调整WINDOW_SIZE/STRIDE)") return np.stack(feats, axis=0).astype(np.float32) # ========================================================= # 4) 模型加载 + 推理接口 # ========================================================= def _safe_torch_load(path: str): try: return torch.load(path, map_location=DEVICE, weights_only=False) except TypeError: return torch.load(path, map_location=DEVICE) def load_model(model_path: str) -> tuple[FusionNet, dict]: """ 返回: (model, ckpt_dict) """ obj = _safe_torch_load(model_path) if isinstance(obj, dict) and "model_state" in obj: ckpt = obj state = obj["model_state"] feat_dim = int(obj.get("feat_dim", 320)) else: ckpt = {} state = obj feat_dim = 320 model = FusionNet(num_classes=2, num_eeg_features=feat_dim).to(DEVICE) model.load_state_dict(state, strict=True) model.eval() return model, ckpt def predict_hc_mdd(eeg_dir: str, model_path: str) -> dict: """ 接口:传入 EEG文件夹 和 模型路径,返回判断结果 dict 返回字段: - mat_file: 使用的mat文件 - pred_label: "HC" or "MDD" - p_mdd_mean: 切片p(MDD)均值 - threshold: subject判定阈值 - n_slices: 切片数 """ mat_file = _find_first_mat_file(eeg_dir) # 1) 读EEG (T,C),并保证变成32通道 eeg = load_eeg_from_mat_any_channels(mat_file) # (T,C) eeg32 = ensure_32_channels(eeg) # (T,32) # 2) 提特征 (N,feat_dim) extractor = FeatureExtractor32(fs=SAMPLING_RATE, window_size=WINDOW_SIZE, stride=STRIDE) feats = extractor.extract(eeg32) # (N, dim) # 3) 加载模型 model, ckpt = load_model(model_path) # 4) 可选:归一化(若ckpt里保存的mean/std维度刚好匹配) mean = ckpt.get("global_mean", None) if isinstance(ckpt, dict) else None std = ckpt.get("global_std", None) if isinstance(ckpt, dict) else None if mean is not None and std is not None: mean = np.asarray(mean, dtype=np.float32) std = np.asarray(std, dtype=np.float32) if mean.shape[0] == feats.shape[1] and std.shape[0] == feats.shape[1]: feats = (feats - mean) / (std + 1e-8) # 不匹配就跳过 # 你说“先不用其他步骤或信息”,所以这里按“尽量运行”处理 # 5) 推理:对所有切片算 p(MDD),取均值做subject-level x = torch.from_numpy(feats).to(DEVICE) with torch.no_grad(): cls_out, _ = model(x) prob_mdd = torch.softmax(cls_out, dim=1)[:, 1].detach().cpu().numpy() p_mdd_mean = float(np.mean(prob_mdd)) thr = float(ckpt.get("subject_threshold", DEFAULT_SUBJECT_THRESHOLD)) if isinstance(ckpt, dict) else float(DEFAULT_SUBJECT_THRESHOLD) pred_is_mdd = (p_mdd_mean >= thr) pred_label = "MDD" if pred_is_mdd else "HC" return { "mat_file": mat_file, "pred_label": pred_label, "p_mdd_mean": p_mdd_mean, "threshold": thr, "n_slices": int(feats.shape[0]), } # ========================================================= # 5) CLI:命令行运行入口 # ========================================================= def main(): parser = argparse.ArgumentParser(description="Infer HC/MDD from 64ch->32ch EEG mat using a .pth FusionNet model (no Asym).") parser.add_argument("--eeg_dir", type=str, required=True, help="包含.mat EEG文件的文件夹(自动读取第一个.mat)") parser.add_argument("--model_path", type=str, required=True, help="训练好的.pth模型路径") args = parser.parse_args() res = predict_hc_mdd(args.eeg_dir, args.model_path) print("\n========== 推理结果 ==========") print(f"MAT文件: {res['mat_file']}") print(f"切片数量: {res['n_slices']}") print(f"p(MDD)_mean: {res['p_mdd_mean']:.4f}") print(f"阈值thr: {res['threshold']:.4f}") print(f"预测结果: {res['pred_label']}") print("==============================\n") if __name__ == "__main__": main()