Files
Depression_TMS/algorithm_V1/infer_pth.py
2026-06-01 13:18:36 +08:00

562 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
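infer_pth.py

Purpose:
- Automatically read the first .mat EEG file (64-channel or 32-channel) found in a folder.
- If the recording has 64 channels, select 32 of them using the IDX64_TO_32 mapping.
- Extract per-slice features: DE + PSD (approximated from the variance), without Asym.
- Load your trained .pth model (FusionNet architecture).
- Output the HC / MDD decision for that subject.

Command-line usage:
    python infer_pth.py --eeg_dir "D:\\xxx\\folder" --model_path "C:\\xxx\\model.pth"

It can also be imported from another .py file:
    from infer_pth import predict_hc_mdd
    res = predict_hc_mdd(eeg_dir, model_path)
    print(res)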
"""
from __future__ import annotations
import os
import argparse
import numpy as np
import scipy.io
import scipy.signal as signal
import torch
import torch.nn as nn
import torch.nn.functional as F
# =========================================================
# 0) Configuration (edit here as needed)
# =========================================================
# Sampling rate (must match the value used during training).
SAMPLING_RATE = 250

# Sliding-window length and step, in samples (must match training).
WINDOW_SIZE = 500
STRIDE = 250

# Band-pass edges (low, high) per frequency band (must match training).
BANDS = {
    "Delta": (1, 4),
    "Theta": (4, 8),
    "Alpha": (8, 13),
    "Beta": (13, 30),
    "Gamma": (30, 50),
}
# Band order used when laying out features; follows BANDS' insertion order.
BAND_NAMES = list(BANDS)

# Whether to use the extended feature set (DE + PSD).
USE_EXTENDED_FEATURES = True

# Numerical-stability term added before taking logarithms.
EPS = 1e-12

# Compute device: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Channel mapping: column indices (0-based) picked from a 64-channel
# recording to build the 32-channel input. The trailing labels are the
# original author's annotations for each picked index.
IDX64_TO_32 = [
    23,  # C5
    47,  # O1
    39,  # TP7
    6,   # FPZ
    2,   # PO6
    21,  # P4
    35,  # AF7
    57,  # AF3
    1,   # FP2
    37,  # T7
    63,  # F1
    36,  # A1
    18,  # FC4
    31,  # FC5
    14,  # FC2
    48,  # T8
    60,  # P2
    41,  # AF8
    11,  # CP1
    0,   # FP1
    55,  # PO7
    59,  # C1
    22,  # F5
    10,  # CP2
    16,  # C3
    61,  # P1
    27,  # CP5
    17,  # C4
    26,  # CP6
    62,  # F2
    3,   # POZ
    13,  # PO5
]

# Subject-level decision threshold. A "subject_threshold" stored in the
# model checkpoint takes precedence; this value is the fallback.
DEFAULT_SUBJECT_THRESHOLD = 0.5
# =========================================================
# 1) 模型结构
# =========================================================
class SEBlock(nn.Module):
    """Squeeze-and-excitation style gate applied to a feature vector.

    A two-layer bottleneck (``channels -> hidden -> channels``) produces one
    sigmoid gate per feature, and the input is rescaled element-wise by that
    gate. ``hidden`` is ``channels // reduction``, but never less than 1.
    """

    def __init__(self, channels: int, reduction: int = 4) -> None:
        super().__init__()
        bottleneck = max(1, channels // reduction)
        # Attribute names double as state_dict keys, so they must not change.
        self.fc1 = nn.Linear(channels, bottleneck)
        self.fc2 = nn.Linear(bottleneck, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied element-wise by its learned gate."""
        gate = torch.sigmoid(self.fc2(F.relu(self.fc1(x))))
        return x * gate
class ResidualBlock(nn.Module):
    """Fully connected residual block: Linear-BN-ReLU-Dropout-Linear-BN + skip.

    When ``in_features`` differs from ``out_features`` the skip path is a
    Linear + BatchNorm projection; otherwise it is the identity.
    """

    def __init__(self, in_features: int, out_features: int, dropout: float = 0.3) -> None:
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the order in which random init is drawn.
        self.fc1 = nn.Linear(in_features, out_features)
        self.bn1 = nn.BatchNorm1d(out_features)
        self.fc2 = nn.Linear(out_features, out_features)
        self.bn2 = nn.BatchNorm1d(out_features)
        self.dropout = nn.Dropout(dropout)
        if in_features == out_features:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_path(x) + shortcut(x))``."""
        skip = self.shortcut(x)
        hidden = self.dropout(F.relu(self.bn1(self.fc1(x))))
        hidden = self.bn2(self.fc2(hidden))
        return F.relu(hidden + skip)
class FusionNet(nn.Module):
    """MLP classifier over per-slice EEG feature vectors.

    Pipeline: input BatchNorm -> three residual blocks (512, 256, 128) ->
    SE gate -> 64-unit projection, followed by two heads: a classification
    head (``num_classes`` logits) and a regression head (``num_scales``
    outputs).
    """

    def __init__(self, num_classes: int = 2, num_eeg_features: int = 320, num_scales: int = 6) -> None:
        super().__init__()
        # Attribute names are the checkpoint's state_dict keys; keep them.
        self.input_norm = nn.BatchNorm1d(num_eeg_features)
        self.block1 = ResidualBlock(num_eeg_features, 512, dropout=0.4)
        self.block2 = ResidualBlock(512, 256, dropout=0.3)
        self.block3 = ResidualBlock(256, 128, dropout=0.2)
        self.attention = SEBlock(128, reduction=4)
        self.final_fc = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.cls_head = nn.Linear(64, num_classes)
        # The regression head is part of the architecture (and therefore of
        # the checkpoint); inference in this file only uses the class logits.
        self.reg_head = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, num_scales),
        )
        self._init_weights()

    def _init_weights(self) -> None:
        """Kaiming-normal init for Linear layers; unit scale / zero shift for BatchNorm."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor):
        """Return ``(cls_out, reg_out)`` for a batch of feature vectors."""
        trunk = (
            self.input_norm,
            self.block1,
            self.block2,
            self.block3,
            self.attention,
            self.final_fc,
        )
        for stage in trunk:
            x = stage(x)
        cls_out = self.cls_head(x)
        reg_out = self.reg_head(x)
        return cls_out, reg_out
# =========================================================
# 2) mat读取 + 通道裁剪
# =========================================================
def _find_first_mat_file(folder: str) -> str:
    """Return the path of the first ``.mat`` file in ``folder``.

    "First" means first in sorted filename order; the extension match is
    case-insensitive. Only regular files are considered, so a sub-directory
    whose name happens to end in ``.mat`` is never returned.

    Raises:
        RuntimeError: if ``folder`` is not a directory or holds no .mat file.
    """
    if not os.path.isdir(folder):
        raise RuntimeError(f"eeg_dir 不是文件夹: {folder}")
    mat_names = sorted(
        name
        for name in os.listdir(folder)
        if name.lower().endswith(".mat") and os.path.isfile(os.path.join(folder, name))
    )
    if not mat_names:
        raise RuntimeError(f"文件夹内没有 .mat 文件: {folder}")
    return os.path.join(folder, mat_names[0])
import numpy as np
import scipy.io
def _unwrap_singleton(x):
    """Peel single-element ndarray wrappers off ``x`` until none remain.

    Handles shapes such as (1, 1) or (1,), and object arrays holding exactly
    one item. A 0-d non-object array is returned unchanged, as is anything
    that is not an ndarray or that has more than one element.
    """
    while isinstance(x, np.ndarray) and x.size == 1:
        if x.dtype == object:
            # A one-item object array: pull out the wrapped Python object.
            x = x.item()
        elif x.ndim >= 1:
            # A one-element numeric array such as (1, 1) or (1,).
            try:
                x = x.reshape(-1)[0]
            except Exception:
                break
        else:
            break
    return x
def _try_get_struct_field(v, field_name="data"):
    """Fetch ``field_name`` from a MATLAB-struct-like value, or return None.

    Two representations are recognised:
      1) objects exposing ``_fieldnames`` (scipy's mat_struct, as produced by
         ``loadmat(..., struct_as_record=False)``): the attribute is returned;
      2) numpy structured/record arrays (``dtype.names``): ``v[field_name]``
         is returned, which may still need unwrapping by the caller.
    """
    # Case 1: mat_struct-style object.
    if field_name in getattr(v, "_fieldnames", ()):
        return getattr(v, field_name)

    # Case 2: structured array.
    if not isinstance(v, np.ndarray):
        return None
    record_fields = v.dtype.names or ()
    if field_name not in record_fields:
        return None
    try:
        return v[field_name]
    except Exception:
        return None
def load_eeg_from_mat_any_channels(mat_path: str) -> np.ndarray:
    """Load the EEG matrix stored in a .mat file as a float32 ``(T, C)`` array.

    Supported layouts:
      - a top-level 2-D numeric matrix, stored as (T, C) or (C, T);
      - a struct (possibly wrapped in an object array) with a ``data`` field.

    Every variable that looks like a 2-D matrix becomes a candidate. The
    candidates are ranked by how EEG-like their shape is (a 64/32/128/129
    dimension, plus a bonus for size) and tried in that order; a candidate
    that cannot be converted to a 2-D float32 array is skipped instead of
    aborting the whole load.

    Returns:
        float32 array oriented as (T, C) when a dimension equal to a known
        channel count (32, 64, 128, 129) identifies the channel axis.

    Raises:
        RuntimeError: if the file holds no candidate, or none is usable.
    """
    channel_counts = (32, 64, 128, 129)

    # struct_as_record=False exposes struct fields as attributes, and
    # squeeze_me=True drops unit dimensions, which simplifies struct handling.
    mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)

    candidates = []
    for k, v in mat.items():
        if k.startswith("__"):
            continue

        # 1) A plain 2-D numeric matrix.
        if isinstance(v, np.ndarray) and v.ndim == 2 and np.issubdtype(v.dtype, np.number):
            candidates.append((k, v))
            continue

        # 2) A struct/record: prefer its `data` field.
        data_field = _try_get_struct_field(v, "data")
        if data_field is not None:
            data_field = _unwrap_singleton(data_field)
            if isinstance(data_field, np.ndarray) and data_field.ndim == 2:
                # Object arrays are admitted here; whether they really hold
                # numbers is decided by the float32 conversion below.
                if np.issubdtype(data_field.dtype, np.number) or data_field.dtype == object:
                    candidates.append((f"{k}.data", data_field))
                    continue

        # 3) An object array: unwrap it, then look for a matrix or a struct.
        if isinstance(v, np.ndarray) and v.dtype == object:
            vv = _unwrap_singleton(v)
            if isinstance(vv, np.ndarray) and vv.ndim == 2 and np.issubdtype(vv.dtype, np.number):
                candidates.append((k, vv))
                continue
            data2 = _try_get_struct_field(vv, "data")
            if data2 is not None:
                data2 = _unwrap_singleton(data2)
                if isinstance(data2, np.ndarray) and data2.ndim == 2:
                    candidates.append((f"{k}.data", data2))

    if not candidates:
        raise RuntimeError(f"mat里没找到可用EEG二维矩阵或struct.data: {mat_path}")

    def score(arr: np.ndarray) -> int:
        """Rank a candidate: known channel counts first, then sheer size."""
        shape_points = ((64, 10), (32, 9), (128, 8), (129, 7))
        s = sum(points for count, points in shape_points if count in arr.shape)
        return s + int(np.prod(arr.shape) // 100000)  # larger looks more like EEG

    # Try candidates best-first. sorted() is stable, so ties keep file order.
    eeg = None
    rejected = []
    for key, raw in sorted(candidates, key=lambda kv: score(kv[1]), reverse=True):
        try:
            arr = np.asarray(raw, dtype=np.float32)
        except (TypeError, ValueError) as exc:
            # e.g. an object array holding strings or nested arrays.
            rejected.append(f"key={key}: {exc}")
            continue
        if arr.ndim != 2:
            rejected.append(f"key={key}, shape={arr.shape}")
            continue
        eeg = arr
        break

    if eeg is None:
        raise RuntimeError(
            f"解析结果不是二维数值矩阵: {'; '.join(rejected)}, file={mat_path}"
        )

    # Normalise to (T, C). Transpose when the first axis is a known channel
    # count and either the second axis is not, or the first axis is smaller.
    rows, cols = eeg.shape
    if rows in channel_counts and (cols not in channel_counts or rows < cols):
        eeg = eeg.T
    return eeg
def ensure_32_channels(eeg: np.ndarray) -> np.ndarray:
    """Reduce a (T, C) recording to the 32 channels used by this pipeline.

    - C == 32: the input array is returned as is.
    - C == 64: the columns listed in IDX64_TO_32 are selected, in that order.

    Raises:
        RuntimeError: if ``eeg`` is not 2-D, has any other channel count, or
            IDX64_TO_32 contains an index outside 0..63.
    """
    if eeg.ndim != 2:
        raise RuntimeError(f"EEG必须是二维(T,C),但得到: {eeg.shape}")

    C = eeg.shape[1]
    if C == 32:
        return eeg
    if C != 64:
        raise RuntimeError(f"不支持的通道数C={C},当前只支持 64->32 或 32 直推。")

    # 64-channel input: validate the mapping before indexing with it.
    idx = np.asarray(IDX64_TO_32, dtype=np.int64)
    if idx.min() < 0 or idx.max() >= 64:
        raise RuntimeError(f"IDX64_TO_32 越界min={idx.min()}, max={idx.max()} (要求0~63)")
    return eeg[:, idx]
# =========================================================
# 3) 特征提取DE + PSD(var近似)
# =========================================================
class FeatureExtractor32:
    """Sliding-window band features for 32-channel EEG.

    The recording is band-pass filtered once per band in BAND_NAMES, then cut
    into windows of ``window_size`` samples every ``stride`` samples. For
    each window, channel and band the sample variance ``var`` gives:
      - DE  = 0.5 * log(2 * pi * e * (var + EPS))
      - PSD = log(var + EPS)            (a variance-based approximation)

    Output width per window:
      - USE_EXTENDED_FEATURES=True: DE (32*5) + PSD (32*5) = 320
      - otherwise:                  DE (32*5) = 160
    Within each 160-value group the layout is channel-major (all bands of
    channel 0, then all bands of channel 1, ...).
    """

    def __init__(
        self,
        fs: int = SAMPLING_RATE,
        window_size: int = WINDOW_SIZE,
        stride: int = STRIDE,
        filter_order: int = 4,
        zero_phase: bool = False,
    ) -> None:
        """Store the settings and design one Butterworth band-pass per band.

        Args:
            fs: sampling rate passed to the filter design.
            window_size: window length in samples.
            stride: step between consecutive window starts, in samples.
            filter_order: order given to ``scipy.signal.butter``.
            zero_phase: use forward-backward filtering (``sosfiltfilt``)
                instead of causal filtering (``sosfilt``).
        """
        self.fs = fs
        self.window_size = window_size
        self.stride = stride
        self.filter_order = filter_order
        self.zero_phase = zero_phase
        # Second-order-section coefficients per band, designed once up front.
        self._sos = {
            bn: signal.butter(
                self.filter_order,
                list(BANDS[bn]),
                btype="band",
                fs=self.fs,
                output="sos",
            )
            for bn in BAND_NAMES
        }

    def _filter_bands(self, eeg: np.ndarray) -> dict[str, np.ndarray]:
        """Return ``{band: filtered (T, C) float32 array}`` for every band."""
        apply_filter = signal.sosfiltfilt if self.zero_phase else signal.sosfilt
        return {
            bn: apply_filter(self._sos[bn], eeg, axis=0).astype(np.float32)
            for bn in BAND_NAMES
        }

    def extract(self, eeg32: np.ndarray) -> np.ndarray:
        """Compute the feature matrix for one recording.

        Args:
            eeg32: array of shape (T, 32).

        Returns:
            float32 array of shape (N_slices, feat_dim).

        Raises:
            RuntimeError: if the input is not (T, 32) or is shorter than one
                window.
        """
        if eeg32.ndim != 2 or eeg32.shape[1] != 32:
            raise RuntimeError(f"extract需要 (T,32),得到 {eeg32.shape}")

        bands_data = self._filter_bands(eeg32)
        T = eeg32.shape[0]

        feats = []
        # "+ 1" makes the last full window count: a window starting at
        # T - window_size ends exactly at T. Without it a recording of
        # exactly window_size samples would yield no slice at all.
        for start in range(0, T - self.window_size + 1, self.stride):
            end = start + self.window_size
            de_list = []
            psd_list = []
            for bn in BAND_NAMES:
                seg = bands_data[bn][start:end, :]  # (W, 32)
                var = np.var(seg, axis=0, ddof=1)  # (32,)
                de_list.append(0.5 * np.log(2 * np.pi * np.e * (var + EPS)))
                if USE_EXTENDED_FEATURES:
                    psd_list.append(np.log(var + EPS))

            # (5, 32) -> (32, 5) -> flat, i.e. channel-major ordering.
            de_feat = np.stack(de_list, axis=0).T.reshape(-1)  # (32*5,)
            if USE_EXTENDED_FEATURES:
                psd_feat = np.stack(psd_list, axis=0).T.reshape(-1)  # (32*5,)
                f = np.concatenate([de_feat, psd_feat], axis=0).astype(np.float32)
            else:
                f = de_feat.astype(np.float32)
            feats.append(f)

        if not feats:
            raise RuntimeError("EEG长度不足以切片请检查T是否太短或调整WINDOW_SIZE/STRIDE")
        return np.stack(feats, axis=0).astype(np.float32)
# =========================================================
# 4) 模型加载 + 推理接口
# =========================================================
def _safe_torch_load(path: str):
    """Load a checkpoint onto DEVICE, tolerating older torch versions.

    ``weights_only=False`` is requested first; if this torch version rejects
    that keyword with a TypeError, the call is repeated without it.

    NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    Python objects, which can execute code. Only load checkpoints from a
    trusted source.
    """
    common = {"map_location": DEVICE}
    try:
        return torch.load(path, weights_only=False, **common)
    except TypeError:
        return torch.load(path, **common)
def load_model(model_path: str) -> tuple[FusionNet, dict]:
    """Build a FusionNet from a checkpoint and put it in eval mode.

    Two checkpoint layouts are accepted:
      - a dict with a "model_state" entry (optionally "feat_dim" and other
        metadata): the dict is returned as ``ckpt``;
      - a bare state dict: ``ckpt`` is then an empty dict.

    The input width is taken from "feat_dim" when present; otherwise it is
    read from the length of the ``input_norm.weight`` tensor in the state
    dict, and only falls back to 320 if neither is available. This lets a
    160-feature model load without a "feat_dim" entry.

    Returns:
        (model, ckpt_dict), with the model on DEVICE.
    """
    obj = _safe_torch_load(model_path)
    if isinstance(obj, dict) and "model_state" in obj:
        ckpt, state = obj, obj["model_state"]
    else:
        ckpt, state = {}, obj

    feat_dim = ckpt.get("feat_dim")
    if feat_dim is None:
        # input_norm is BatchNorm1d(num_eeg_features), so its weight length
        # is the feature width the model was trained with.
        norm_weight = state.get("input_norm.weight") if isinstance(state, dict) else None
        feat_dim = norm_weight.shape[0] if norm_weight is not None else 320

    model = FusionNet(num_classes=2, num_eeg_features=int(feat_dim)).to(DEVICE)
    model.load_state_dict(state, strict=True)
    model.eval()
    return model, ckpt
def predict_hc_mdd(eeg_dir: str, model_path: str) -> dict:
    """Classify one subject as HC or MDD from the first .mat file in a folder.

    Steps: load the EEG, reduce it to 32 channels, extract per-slice
    features, optionally standardise them with the checkpoint's
    "global_mean"/"global_std", run the model on every slice, and compare the
    mean p(MDD) with the subject threshold ("subject_threshold" from the
    checkpoint, else DEFAULT_SUBJECT_THRESHOLD).

    Args:
        eeg_dir: folder containing the .mat EEG file.
        model_path: path to the trained .pth checkpoint.

    Returns:
        dict with:
          - mat_file: path of the .mat file that was used
          - pred_label: "HC" or "MDD"
          - p_mdd_mean: mean of the per-slice p(MDD)
          - threshold: subject-level decision threshold
          - n_slices: number of slices

    Raises:
        RuntimeError: on unusable input (see the helpers), or if the mean
            probability is not finite.
    """
    import warnings

    def as_flat_float32(stats) -> np.ndarray:
        """Return checkpoint statistics as a 1-D float32 numpy array."""
        if isinstance(stats, torch.Tensor):
            # map_location may have placed the tensor on the GPU.
            stats = stats.detach().cpu().numpy()
        return np.asarray(stats, dtype=np.float32).reshape(-1)

    mat_file = _find_first_mat_file(eeg_dir)

    # 1) Read the EEG as (T, C) and reduce it to 32 channels.
    eeg = load_eeg_from_mat_any_channels(mat_file)  # (T, C)
    eeg32 = ensure_32_channels(eeg)  # (T, 32)

    # 2) Extract features, shape (N, feat_dim).
    extractor = FeatureExtractor32(fs=SAMPLING_RATE, window_size=WINDOW_SIZE, stride=STRIDE)
    feats = extractor.extract(eeg32)

    # 3) Load the model.
    model, ckpt = load_model(model_path)

    # 4) Optional standardisation with the statistics saved at training time.
    mean = ckpt.get("global_mean")
    std = ckpt.get("global_std")
    if mean is not None and std is not None:
        # Flattening accepts stats saved as (D,), (1, D) or (D, 1).
        mean = as_flat_float32(mean)
        std = as_flat_float32(std)
        n_features = feats.shape[1]
        if mean.shape[0] == n_features and std.shape[0] == n_features:
            feats = (feats - mean) / (std + 1e-8)
        else:
            # Best-effort by design: keep going, but do not stay silent,
            # because unnormalised features can change the prediction.
            warnings.warn(
                "checkpoint 中 global_mean/global_std 的维度 "
                f"({mean.shape[0]}, {std.shape[0]}) 与特征维度 {n_features} 不匹配, 已跳过归一化",
                RuntimeWarning,
                stacklevel=2,
            )

    # 5) Inference: p(MDD) for every slice, averaged per subject.
    x = torch.from_numpy(feats).to(DEVICE)
    with torch.no_grad():
        cls_out, _ = model(x)
        prob_mdd = torch.softmax(cls_out, dim=1)[:, 1].detach().cpu().numpy()
    p_mdd_mean = float(np.mean(prob_mdd))

    # NaN >= thr is False, so without this guard corrupt input (NaN/Inf in
    # the EEG or in the statistics) would be reported as "HC".
    if not np.isfinite(p_mdd_mean):
        raise RuntimeError(f"p(MDD) 不是有限值, 请检查EEG数据或归一化参数: {mat_file}")

    thr = float(ckpt.get("subject_threshold", DEFAULT_SUBJECT_THRESHOLD))
    pred_label = "MDD" if p_mdd_mean >= thr else "HC"
    return {
        "mat_file": mat_file,
        "pred_label": pred_label,
        "p_mdd_mean": p_mdd_mean,
        "threshold": thr,
        "n_slices": int(feats.shape[0]),
    }
# =========================================================
# 5) CLI命令行运行入口
# =========================================================
def main():
    """Command-line entry point: parse arguments, run inference, print a report."""
    cli = argparse.ArgumentParser(
        description="Infer HC/MDD from 64ch->32ch EEG mat using a .pth FusionNet model (no Asym)."
    )
    cli.add_argument(
        "--eeg_dir",
        type=str,
        required=True,
        help="包含.mat EEG文件的文件夹自动读取第一个.mat",
    )
    cli.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="训练好的.pth模型路径",
    )
    args = cli.parse_args()

    res = predict_hc_mdd(args.eeg_dir, args.model_path)

    # One entry per printed line, framed by a banner above and below.
    report = (
        "\n========== 推理结果 ==========",
        f"MAT文件: {res['mat_file']}",
        f"切片数量: {res['n_slices']}",
        f"p(MDD)_mean: {res['p_mdd_mean']:.4f}",
        f"阈值thr: {res['threshold']:.4f}",
        f"预测结果: {res['pred_label']}",
        "==============================\n",
    )
    for line in report:
        print(line)


if __name__ == "__main__":
    main()