original push
This commit is contained in:
561
algorithm_V0/algorithm_fromXjtu/infer_pth.py
Normal file
561
algorithm_V0/algorithm_fromXjtu/infer_pth.py
Normal file
@@ -0,0 +1,561 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
infer_pth.py
|
||||
|
||||
用途:
|
||||
- 从一个文件夹中自动读取第一个 .mat EEG 文件(64通道或32通道)
|
||||
- 若为64通道,则按 idx64_to_32 映射选出32通道
|
||||
- 提取切片特征(DE + PSD(var近似),不含Asym)
|
||||
- 加载你训练好的 .pth 模型(FusionNet结构)
|
||||
- 输出该受试者的 HC / MDD 判断结果
|
||||
|
||||
运行方式(命令行):
|
||||
python infer_pth_from64_to32.py --eeg_dir "D:\\xxx\\folder" --model_path "C:\\xxx\\model.pth"
|
||||
|
||||
也可在其他py里import:
|
||||
from infer_pth_from64_to32 import predict_hc_mdd
|
||||
res = predict_hc_mdd(eeg_dir, model_path)
|
||||
print(res)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
import scipy.signal as signal
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# =========================================================
|
||||
# 0) 配置区(按需改这里)
|
||||
# =========================================================
|
||||
|
||||
# 采样率(必须与训练时一致)
|
||||
SAMPLING_RATE = 250
|
||||
|
||||
# 滑窗参数(必须与训练时一致)
|
||||
WINDOW_SIZE = 500
|
||||
STRIDE = 250
|
||||
|
||||
# 频段(必须与训练时一致)
|
||||
BAND_NAMES = ["Delta", "Theta", "Alpha", "Beta", "Gamma"]
|
||||
BANDS = {
|
||||
"Delta": (1, 4),
|
||||
"Theta": (4, 8),
|
||||
"Alpha": (8, 13),
|
||||
"Beta": (13, 30),
|
||||
"Gamma": (30, 50),
|
||||
}
|
||||
|
||||
# 是否使用扩展特征(DE+PSD)
|
||||
USE_EXTENDED_FEATURES = True
|
||||
|
||||
# 数值稳定项
|
||||
EPS = 1e-12
|
||||
|
||||
# 设备
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# 通道映射
|
||||
IDX64_TO_32 = [
|
||||
23, # C5
|
||||
47, # O1
|
||||
39, # TP7
|
||||
6, # FPZ
|
||||
2, # PO6
|
||||
21, # P4
|
||||
35, # AF7
|
||||
57, # AF3
|
||||
1, # FP2
|
||||
37, # T7
|
||||
63, # F1
|
||||
36, # A1
|
||||
18, # FC4
|
||||
31, # FC5
|
||||
14, # FC2
|
||||
48, # T8
|
||||
60, # P2
|
||||
41, # AF8
|
||||
11, # CP1
|
||||
0, # FP1
|
||||
55, # PO7
|
||||
59, # C1
|
||||
22, # F5
|
||||
10, # CP2
|
||||
16, # C3
|
||||
61, # P1
|
||||
27, # CP5
|
||||
17, # C4
|
||||
26, # CP6
|
||||
62, # F2
|
||||
3, # POZ
|
||||
13, # PO5
|
||||
]
|
||||
|
||||
# 推理阈值:如果模型checkpoint里有 subject_threshold,会优先用它;否则用这个
|
||||
DEFAULT_SUBJECT_THRESHOLD = 0.5
|
||||
|
||||
# =========================================================
|
||||
# 1) 模型结构
|
||||
# =========================================================
|
||||
|
||||
class SEBlock(nn.Module):
|
||||
def __init__(self, channels: int, reduction: int = 4) -> None:
|
||||
super().__init__()
|
||||
hidden = max(1, channels // reduction)
|
||||
self.fc1 = nn.Linear(channels, hidden)
|
||||
self.fc2 = nn.Linear(hidden, channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
se = F.relu(self.fc1(x))
|
||||
se = torch.sigmoid(self.fc2(se))
|
||||
return x * se
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, dropout: float = 0.3) -> None:
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(in_features, out_features)
|
||||
self.bn1 = nn.BatchNorm1d(out_features)
|
||||
self.fc2 = nn.Linear(out_features, out_features)
|
||||
self.bn2 = nn.BatchNorm1d(out_features)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.shortcut = nn.Identity()
|
||||
if in_features != out_features:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Linear(in_features, out_features),
|
||||
nn.BatchNorm1d(out_features),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
identity = self.shortcut(x)
|
||||
out = F.relu(self.bn1(self.fc1(x)))
|
||||
out = self.dropout(out)
|
||||
out = self.bn2(self.fc2(out))
|
||||
out = F.relu(out + identity)
|
||||
return out
|
||||
|
||||
|
||||
class FusionNet(nn.Module):
|
||||
def __init__(self, num_classes: int = 2, num_eeg_features: int = 320, num_scales: int = 6) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.input_norm = nn.BatchNorm1d(num_eeg_features)
|
||||
|
||||
self.block1 = ResidualBlock(num_eeg_features, 512, dropout=0.4)
|
||||
self.block2 = ResidualBlock(512, 256, dropout=0.3)
|
||||
self.block3 = ResidualBlock(256, 128, dropout=0.2)
|
||||
|
||||
self.attention = SEBlock(128, reduction=4)
|
||||
|
||||
self.final_fc = nn.Sequential(
|
||||
nn.Linear(128, 64),
|
||||
nn.BatchNorm1d(64),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
)
|
||||
|
||||
self.cls_head = nn.Linear(64, num_classes)
|
||||
|
||||
# 训练时有回归头也没关系(推理只用cls)
|
||||
self.reg_head = nn.Sequential(
|
||||
nn.Linear(64, 32),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(32, num_scales),
|
||||
)
|
||||
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self) -> None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm1d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.input_norm(x)
|
||||
x = self.block1(x)
|
||||
x = self.block2(x)
|
||||
x = self.block3(x)
|
||||
x = self.attention(x)
|
||||
features = self.final_fc(x)
|
||||
cls_out = self.cls_head(features)
|
||||
reg_out = self.reg_head(features)
|
||||
return cls_out, reg_out
|
||||
|
||||
|
||||
# =========================================================
|
||||
# 2) mat读取 + 通道裁剪
|
||||
# =========================================================
|
||||
|
||||
def _find_first_mat_file(folder: str) -> str:
|
||||
if not os.path.isdir(folder):
|
||||
raise RuntimeError(f"eeg_dir 不是文件夹: {folder}")
|
||||
mats = sorted([f for f in os.listdir(folder) if f.lower().endswith(".mat")])
|
||||
if not mats:
|
||||
raise RuntimeError(f"文件夹内没有 .mat 文件: {folder}")
|
||||
return os.path.join(folder, mats[0])
|
||||
|
||||
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
|
||||
def _unwrap_singleton(x):
|
||||
"""
|
||||
把 (1,1) / (1,) 这种包裹层一直剥掉,直到不是 singleton。
|
||||
也处理 object array 的情况。
|
||||
"""
|
||||
while True:
|
||||
if isinstance(x, np.ndarray):
|
||||
if x.dtype == object and x.size == 1:
|
||||
x = x.item()
|
||||
continue
|
||||
if x.size == 1 and x.ndim >= 1:
|
||||
# 例如 (1,1) 或 (1,) 的数值/对象数组
|
||||
try:
|
||||
x = x.reshape(-1)[0]
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
return x
|
||||
|
||||
def _try_get_struct_field(v, field_name="data"):
|
||||
"""
|
||||
尝试从以下几种结构中提取字段:
|
||||
1) scipy 读出的 mat_struct(有 _fieldnames)
|
||||
2) numpy structured/record array(dtype.names)
|
||||
"""
|
||||
# case 1: mat_struct(推荐 loadmat(..., struct_as_record=False, squeeze_me=True))
|
||||
if hasattr(v, "_fieldnames") and (field_name in getattr(v, "_fieldnames", [])):
|
||||
return getattr(v, field_name)
|
||||
|
||||
# case 2: structured array
|
||||
if isinstance(v, np.ndarray) and v.dtype.names and (field_name in v.dtype.names):
|
||||
# 常见是 v[field] 仍然是 ndarray / object,需要 unwrap
|
||||
try:
|
||||
return v[field_name]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def load_eeg_from_mat_any_channels(mat_path: str) -> np.ndarray:
|
||||
"""
|
||||
读取 .mat 中 EEG 数据,支持:
|
||||
- 直接二维矩阵 (T,C) 或 (C,T)
|
||||
- struct 里有字段 data
|
||||
返回统一为 float32 的 (T, C)
|
||||
"""
|
||||
# 用这两个参数会让 struct 更容易处理:字段变成属性,且自动 squeeze
|
||||
mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
|
||||
|
||||
candidates = []
|
||||
|
||||
for k, v in mat.items():
|
||||
if k.startswith("__"):
|
||||
continue
|
||||
|
||||
# --- 1) 直接二维数值矩阵 ---
|
||||
if isinstance(v, np.ndarray) and v.ndim == 2 and np.issubdtype(v.dtype, np.number):
|
||||
candidates.append((k, v))
|
||||
continue
|
||||
|
||||
# --- 2) struct/record:优先提取 data 字段 ---
|
||||
data_field = _try_get_struct_field(v, "data")
|
||||
if data_field is not None:
|
||||
data_field = _unwrap_singleton(data_field)
|
||||
|
||||
# data_field 可能仍然被 object 包一层
|
||||
if isinstance(data_field, np.ndarray) and data_field.dtype == object:
|
||||
data_field = _unwrap_singleton(data_field)
|
||||
|
||||
if isinstance(data_field, np.ndarray) and data_field.ndim == 2:
|
||||
# 只收数值矩阵
|
||||
if np.issubdtype(data_field.dtype, np.number) or data_field.dtype == object:
|
||||
candidates.append((f"{k}.data", data_field))
|
||||
continue
|
||||
|
||||
# --- 3) object array:尝试 item() 解包后再看是不是二维数值矩阵/struct ---
|
||||
if isinstance(v, np.ndarray) and v.dtype == object:
|
||||
vv = _unwrap_singleton(v)
|
||||
|
||||
# 解包后若是二维数值矩阵
|
||||
if isinstance(vv, np.ndarray) and vv.ndim == 2 and np.issubdtype(vv.dtype, np.number):
|
||||
candidates.append((k, vv))
|
||||
continue
|
||||
|
||||
# 解包后若是 struct,再取 data
|
||||
data2 = _try_get_struct_field(vv, "data")
|
||||
if data2 is not None:
|
||||
data2 = _unwrap_singleton(data2)
|
||||
if isinstance(data2, np.ndarray) and data2.ndim == 2:
|
||||
candidates.append((f"{k}.data", data2))
|
||||
continue
|
||||
|
||||
if not candidates:
|
||||
raise RuntimeError(f"mat里没找到可用EEG二维矩阵或struct.data:{mat_path}")
|
||||
|
||||
# 选一个最像 EEG 的(优先含32/64通道维度的)
|
||||
def score(arr: np.ndarray) -> int:
|
||||
s = 0
|
||||
if 64 in arr.shape: s += 10
|
||||
if 32 in arr.shape: s += 9
|
||||
if 128 in arr.shape: s += 8
|
||||
if 129 in arr.shape: s += 7
|
||||
s += int(np.prod(arr.shape) // 100000) # 大一些更像EEG
|
||||
return s
|
||||
|
||||
candidates.sort(key=lambda kv: score(kv[1]), reverse=True)
|
||||
key, eeg = candidates[0]
|
||||
|
||||
eeg = _unwrap_singleton(eeg)
|
||||
|
||||
# 如果还是 object dtype,尽力转成 float
|
||||
if isinstance(eeg, np.ndarray) and eeg.dtype == object:
|
||||
# 有时 object 里其实是数值
|
||||
eeg = np.array(eeg, dtype=np.float32)
|
||||
else:
|
||||
eeg = np.asarray(eeg, dtype=np.float32)
|
||||
|
||||
if eeg.ndim != 2:
|
||||
raise RuntimeError(f"解析结果不是二维矩阵: key={key}, shape={eeg.shape}, file={mat_path}")
|
||||
|
||||
# 统一为 (T, C)
|
||||
# 常见 (C,T) 或 (T,C),我们用“通道维通常较小”+ “32/64/128/129”判断
|
||||
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[1] not in (32, 64, 128, 129):
|
||||
eeg = eeg.T
|
||||
elif eeg.shape[1] in (32, 64, 128, 129):
|
||||
# 如果第一维也是这些数且更小,可能是(C,T)
|
||||
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[0] < eeg.shape[1]:
|
||||
eeg = eeg.T
|
||||
|
||||
return eeg
|
||||
|
||||
|
||||
def ensure_32_channels(eeg: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
输入 (T, C),输出 (T, 32)
|
||||
- 若C=64:按 IDX64_TO_32 选32通道
|
||||
- 若C=32:直接返回
|
||||
"""
|
||||
if eeg.ndim != 2:
|
||||
raise RuntimeError(f"EEG必须是二维(T,C),但得到: {eeg.shape}")
|
||||
|
||||
C = eeg.shape[1]
|
||||
if C == 64:
|
||||
idx = np.asarray(IDX64_TO_32, dtype=np.int64)
|
||||
if idx.min() < 0 or idx.max() >= 64:
|
||||
raise RuntimeError(f"IDX64_TO_32 越界:min={idx.min()}, max={idx.max()} (要求0~63)")
|
||||
return eeg[:, idx]
|
||||
if C == 32:
|
||||
return eeg
|
||||
raise RuntimeError(f"不支持的通道数C={C},当前只支持 64->32 或 32 直推。")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# 3) 特征提取(DE + PSD(var近似))
|
||||
# =========================================================
|
||||
|
||||
class FeatureExtractor32:
|
||||
"""
|
||||
只针对32通道,输出维度:
|
||||
- USE_EXTENDED_FEATURES=True:DE(32*5) + PSD(32*5) = 320
|
||||
- 否则:DE(32*5) = 160
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
fs: int = SAMPLING_RATE,
|
||||
window_size: int = WINDOW_SIZE,
|
||||
stride: int = STRIDE,
|
||||
filter_order: int = 4,
|
||||
zero_phase: bool = False,
|
||||
) -> None:
|
||||
self.fs = fs
|
||||
self.window_size = window_size
|
||||
self.stride = stride
|
||||
self.filter_order = filter_order
|
||||
self.zero_phase = zero_phase
|
||||
|
||||
self._sos = {}
|
||||
for bn in BAND_NAMES:
|
||||
low, high = BANDS[bn]
|
||||
self._sos[bn] = signal.butter(
|
||||
self.filter_order, [low, high],
|
||||
btype="band", fs=self.fs, output="sos"
|
||||
)
|
||||
|
||||
def _filter_bands(self, eeg: np.ndarray) -> dict[str, np.ndarray]:
|
||||
out = {}
|
||||
for bn in BAND_NAMES:
|
||||
sos = self._sos[bn]
|
||||
if self.zero_phase:
|
||||
out[bn] = signal.sosfiltfilt(sos, eeg, axis=0).astype(np.float32)
|
||||
else:
|
||||
out[bn] = signal.sosfilt(sos, eeg, axis=0).astype(np.float32)
|
||||
return out
|
||||
|
||||
def extract(self, eeg32: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
eeg32: (T, 32)
|
||||
return: feats (N_slices, feat_dim)
|
||||
"""
|
||||
if eeg32.ndim != 2 or eeg32.shape[1] != 32:
|
||||
raise RuntimeError(f"extract需要 (T,32),得到 {eeg32.shape}")
|
||||
|
||||
bands_data = self._filter_bands(eeg32)
|
||||
T = eeg32.shape[0]
|
||||
|
||||
feats = []
|
||||
for start in range(0, T - self.window_size, self.stride):
|
||||
end = start + self.window_size
|
||||
|
||||
de_list = []
|
||||
psd_list = []
|
||||
|
||||
for bn in BAND_NAMES:
|
||||
seg = bands_data[bn][start:end, :] # (W, 32)
|
||||
var = np.var(seg, axis=0, ddof=1) # (32,)
|
||||
|
||||
# DE
|
||||
de = 0.5 * np.log(2 * np.pi * np.e * (var + EPS))
|
||||
de_list.append(de)
|
||||
|
||||
if USE_EXTENDED_FEATURES:
|
||||
# PSD近似:log(var)
|
||||
psd_list.append(np.log(var + EPS))
|
||||
|
||||
de_feat = np.stack(de_list, axis=0).T.reshape(-1) # (32*5,)
|
||||
if USE_EXTENDED_FEATURES:
|
||||
psd_feat = np.stack(psd_list, axis=0).T.reshape(-1) # (32*5,)
|
||||
f = np.concatenate([de_feat, psd_feat], axis=0).astype(np.float32)
|
||||
else:
|
||||
f = de_feat.astype(np.float32)
|
||||
|
||||
feats.append(f)
|
||||
|
||||
if not feats:
|
||||
raise RuntimeError("EEG长度不足以切片(请检查T是否太短,或调整WINDOW_SIZE/STRIDE)")
|
||||
|
||||
return np.stack(feats, axis=0).astype(np.float32)
|
||||
|
||||
|
||||
# =========================================================
|
||||
# 4) 模型加载 + 推理接口
|
||||
# =========================================================
|
||||
|
||||
def _safe_torch_load(path: str):
|
||||
try:
|
||||
return torch.load(path, map_location=DEVICE, weights_only=False)
|
||||
except TypeError:
|
||||
return torch.load(path, map_location=DEVICE)
|
||||
|
||||
|
||||
def load_model(model_path: str) -> tuple[FusionNet, dict]:
|
||||
"""
|
||||
返回: (model, ckpt_dict)
|
||||
"""
|
||||
obj = _safe_torch_load(model_path)
|
||||
if isinstance(obj, dict) and "model_state" in obj:
|
||||
ckpt = obj
|
||||
state = obj["model_state"]
|
||||
feat_dim = int(obj.get("feat_dim", 320))
|
||||
else:
|
||||
ckpt = {}
|
||||
state = obj
|
||||
feat_dim = 320
|
||||
|
||||
model = FusionNet(num_classes=2, num_eeg_features=feat_dim).to(DEVICE)
|
||||
model.load_state_dict(state, strict=True)
|
||||
model.eval()
|
||||
return model, ckpt
|
||||
|
||||
|
||||
def predict_hc_mdd(eeg_dir: str, model_path: str) -> dict:
|
||||
"""
|
||||
接口:传入 EEG文件夹 和 模型路径,返回判断结果 dict
|
||||
|
||||
返回字段:
|
||||
- mat_file: 使用的mat文件
|
||||
- pred_label: "HC" or "MDD"
|
||||
- p_mdd_mean: 切片p(MDD)均值
|
||||
- threshold: subject判定阈值
|
||||
- n_slices: 切片数
|
||||
"""
|
||||
mat_file = _find_first_mat_file(eeg_dir)
|
||||
|
||||
# 1) 读EEG (T,C),并保证变成32通道
|
||||
eeg = load_eeg_from_mat_any_channels(mat_file) # (T,C)
|
||||
eeg32 = ensure_32_channels(eeg) # (T,32)
|
||||
|
||||
# 2) 提特征 (N,feat_dim)
|
||||
extractor = FeatureExtractor32(fs=SAMPLING_RATE, window_size=WINDOW_SIZE, stride=STRIDE)
|
||||
feats = extractor.extract(eeg32) # (N, dim)
|
||||
|
||||
# 3) 加载模型
|
||||
model, ckpt = load_model(model_path)
|
||||
|
||||
# 4) 可选:归一化(若ckpt里保存的mean/std维度刚好匹配)
|
||||
mean = ckpt.get("global_mean", None) if isinstance(ckpt, dict) else None
|
||||
std = ckpt.get("global_std", None) if isinstance(ckpt, dict) else None
|
||||
|
||||
if mean is not None and std is not None:
|
||||
mean = np.asarray(mean, dtype=np.float32)
|
||||
std = np.asarray(std, dtype=np.float32)
|
||||
if mean.shape[0] == feats.shape[1] and std.shape[0] == feats.shape[1]:
|
||||
feats = (feats - mean) / (std + 1e-8)
|
||||
# 不匹配就跳过
|
||||
# 你说“先不用其他步骤或信息”,所以这里按“尽量运行”处理
|
||||
|
||||
# 5) 推理:对所有切片算 p(MDD),取均值做subject-level
|
||||
x = torch.from_numpy(feats).to(DEVICE)
|
||||
with torch.no_grad():
|
||||
cls_out, _ = model(x)
|
||||
prob_mdd = torch.softmax(cls_out, dim=1)[:, 1].detach().cpu().numpy()
|
||||
|
||||
p_mdd_mean = float(np.mean(prob_mdd))
|
||||
thr = float(ckpt.get("subject_threshold", DEFAULT_SUBJECT_THRESHOLD)) if isinstance(ckpt, dict) else float(DEFAULT_SUBJECT_THRESHOLD)
|
||||
pred_is_mdd = (p_mdd_mean >= thr)
|
||||
pred_label = "MDD" if pred_is_mdd else "HC"
|
||||
|
||||
return {
|
||||
"mat_file": mat_file,
|
||||
"pred_label": pred_label,
|
||||
"p_mdd_mean": p_mdd_mean,
|
||||
"threshold": thr,
|
||||
"n_slices": int(feats.shape[0]),
|
||||
}
|
||||
|
||||
|
||||
# =========================================================
|
||||
# 5) CLI:命令行运行入口
|
||||
# =========================================================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Infer HC/MDD from 64ch->32ch EEG mat using a .pth FusionNet model (no Asym).")
|
||||
parser.add_argument("--eeg_dir", type=str, required=True, help="包含.mat EEG文件的文件夹(自动读取第一个.mat)")
|
||||
parser.add_argument("--model_path", type=str, required=True, help="训练好的.pth模型路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
res = predict_hc_mdd(args.eeg_dir, args.model_path)
|
||||
print("\n========== 推理结果 ==========")
|
||||
print(f"MAT文件: {res['mat_file']}")
|
||||
print(f"切片数量: {res['n_slices']}")
|
||||
print(f"p(MDD)_mean: {res['p_mdd_mean']:.4f}")
|
||||
print(f"阈值thr: {res['threshold']:.4f}")
|
||||
print(f"预测结果: {res['pred_label']}")
|
||||
print("==============================\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user