original push
This commit is contained in:
56
algorithm_V1/.gitignore
vendored
Normal file
56
algorithm_V1/.gitignore
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
||||
# Distribution / packaging
|
||||
build/
|
||||
dist/
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# data format
|
||||
*.dat
|
||||
*.csv
|
||||
*.edf
|
||||
*.event
|
||||
*.edf.event
|
||||
*.zip
|
||||
*.xlsx
|
||||
*.mat
|
||||
*.json
|
||||
*.7z
|
||||
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself
|
||||
# https://github.com/github/gitignore/blob/main/Python.gitignore
|
||||
.idea/
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
|
||||
# Other common ignores
|
||||
node_modules/
|
||||
dist/
|
||||
tmp/
|
||||
temp/
|
||||
|
||||
# Project-specific ignores
|
||||
# Ignore all directories in the root
|
||||
# merge64ch_0127/
|
||||
/P300_speller/braindecode/
|
||||
/P300_speller/data/
|
||||
/P300_speller/pyRiemann/
|
||||
/P300_speller/README/
|
||||
/merge64ch_new/
|
||||
/merge64ch_tianjinZMQdebug/
|
||||
|
||||
|
||||
|
||||
|
||||
250
algorithm_V1/bdf_analyzer.py
Normal file
250
algorithm_V1/bdf_analyzer.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
bdf_analyzer.py
|
||||
|
||||
Analyze .bdf files - print data amplitude range and mean values.
|
||||
Supports single file or batch processing of all .bdf files in a directory.
|
||||
"""
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
import mne
|
||||
import scipy.signal as signal
|
||||
|
||||
|
||||
def analyze_bdf(filepath: str, unit: str = "uV") -> dict:
|
||||
"""
|
||||
Analyze a single .bdf file and compute statistics.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filepath : str
|
||||
Path to .bdf file
|
||||
unit : str, optional
|
||||
Display unit (default: uV for microvolts)
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
Dictionary containing statistics
|
||||
"""
|
||||
print("=" * 60)
|
||||
print(f"File: {os.path.basename(filepath)}")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# Read BDF file
|
||||
raw = mne.io.read_raw_bdf(filepath, preload=True, verbose=False)
|
||||
|
||||
# Get data (n_channels, n_times) in V
|
||||
data = raw.get_data()
|
||||
n_channels, n_times = data.shape
|
||||
sfreq = raw.info["sfreq"]
|
||||
|
||||
# Convert to microvolts (uV)
|
||||
data_uv = data * 1e6
|
||||
|
||||
# Raw data statistics (V)
|
||||
raw_all = data.flatten()
|
||||
raw_min = float(np.min(raw_all))
|
||||
raw_max = float(np.max(raw_all))
|
||||
raw_mean = float(np.mean(raw_all))
|
||||
raw_std = float(np.std(raw_all))
|
||||
|
||||
# Overall statistics
|
||||
all_values = data_uv.flatten()
|
||||
min_val = np.min(all_values)
|
||||
max_val = np.max(all_values)
|
||||
mean_val = np.mean(all_values)
|
||||
std_val = np.std(all_values)
|
||||
|
||||
print(f"Sampling rate: {sfreq:.2f} Hz")
|
||||
print(f"Channels: {n_channels}")
|
||||
print(f"Samples: {n_times:,}")
|
||||
print(f"Duration: {n_times / sfreq:.2f} sec")
|
||||
print("-" * 40)
|
||||
print(f"[RAW - V]")
|
||||
print(f"Amplitude range: [{raw_min:.6f}, {raw_max:.6f}] V")
|
||||
print(f"Mean value: {raw_mean:.6f} V")
|
||||
print(f"Std deviation: {raw_std:.6f} V")
|
||||
print(f"[RAW - uV]")
|
||||
print(f"Amplitude range: [{min_val:.4f}, {max_val:.4f}] uV")
|
||||
print(f"Mean value: {mean_val:.4f} uV")
|
||||
print(f"Std deviation: {std_val:.4f} uV")
|
||||
print("-" * 40)
|
||||
|
||||
# Per-channel statistics
|
||||
print("\nPer-channel statistics:")
|
||||
print(f"{'Channel':<15} {'Min (uV)':<15} {'Max (uV)':<15} {'Mean (uV)':<15} {'PSD Peak (Hz)':<15}")
|
||||
print("-" * 75)
|
||||
|
||||
channel_stats = []
|
||||
for i, ch_name in enumerate(raw.ch_names):
|
||||
ch_data = data_uv[i, :]
|
||||
ch_min = np.min(ch_data)
|
||||
ch_max = np.max(ch_data)
|
||||
ch_mean = np.mean(ch_data)
|
||||
|
||||
# PSD peak frequency
|
||||
nperseg = min(1024, n_times)
|
||||
freqs, pxx = signal.welch(ch_data, fs=sfreq, nperseg=nperseg)
|
||||
peak_idx = np.argmax(pxx)
|
||||
peak_freq = freqs[peak_idx]
|
||||
|
||||
print(f"{ch_name:<15} {ch_min:<15.4f} {ch_max:<15.4f} {ch_mean:<15.4f} {peak_freq:<15.2f}")
|
||||
channel_stats.append({
|
||||
"name": ch_name,
|
||||
"min": ch_min,
|
||||
"max": ch_max,
|
||||
"mean": ch_mean,
|
||||
"psd_peak_hz": peak_freq
|
||||
})
|
||||
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
return {
|
||||
"filepath": filepath,
|
||||
"sfreq": sfreq,
|
||||
"n_channels": n_channels,
|
||||
"n_times": n_times,
|
||||
"duration": n_times / sfreq,
|
||||
"raw_min": raw_min,
|
||||
"raw_max": raw_max,
|
||||
"raw_mean": raw_mean,
|
||||
"raw_std": raw_std,
|
||||
"min": min_val,
|
||||
"max": max_val,
|
||||
"mean": mean_val,
|
||||
"std": std_val,
|
||||
"channels": channel_stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Failed to read file: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def analyze_directory(dir_path: str) -> list:
|
||||
"""
|
||||
Analyze all .bdf files in a directory.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dir_path : str
|
||||
Directory path
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
List of analysis results for all files
|
||||
"""
|
||||
# Find all .bdf files
|
||||
bdf_files = sorted(glob.glob(os.path.join(dir_path, "*.bdf")))
|
||||
|
||||
if not bdf_files:
|
||||
print(f"[WARNING] No .bdf files found in: {dir_path}")
|
||||
return []
|
||||
|
||||
print(f"Found {len(bdf_files)} .bdf file(s)\n")
|
||||
|
||||
results = []
|
||||
for filepath in bdf_files:
|
||||
result = analyze_bdf(filepath)
|
||||
if result:
|
||||
results.append(result)
|
||||
|
||||
# Summary statistics
|
||||
if results:
|
||||
print("\n" + "=" * 60)
|
||||
print("Summary")
|
||||
print("=" * 60)
|
||||
|
||||
all_means = [r["mean"] for r in results]
|
||||
all_mins = [r["min"] for r in results]
|
||||
all_maxs = [r["max"] for r in results]
|
||||
|
||||
print(f"File count: {len(results)}")
|
||||
print(f"[RAW - V] Overall range: [{min(r['raw_min'] for r in results):.6f}, {max(r['raw_max'] for r in results):.6f}] V")
|
||||
print(f"[RAW - V] Avg mean: {np.mean([r['raw_mean'] for r in results]):.6f} V")
|
||||
print(f"[RAW - uV] Overall range: [{min(all_mins):.4f}, {max(all_maxs):.4f}] uV")
|
||||
print(f"[RAW - uV] Avg mean: {np.mean(all_means):.4f} uV")
|
||||
print(f"Max value file: {results[np.argmax(all_maxs)]['filepath']}")
|
||||
print(f"Min value file: {results[np.argmin(all_mins)]['filepath']}")
|
||||
|
||||
# Per-channel mean summary across all files
|
||||
n_channels = len(results[0]["channels"])
|
||||
ch_names = [results[0]["channels"][i]["name"] for i in range(n_channels)]
|
||||
ch_mean_over_files = []
|
||||
for ch_idx in range(n_channels):
|
||||
ch_means = [results[f_idx]["channels"][ch_idx]["mean"] for f_idx in range(len(results))]
|
||||
ch_mean_over_files.append(np.mean(ch_means))
|
||||
|
||||
ch_peak_over_files = []
|
||||
for ch_idx in range(n_channels):
|
||||
ch_peaks = [results[f_idx]["channels"][ch_idx]["psd_peak_hz"] for f_idx in range(len(results))]
|
||||
ch_peak_over_files.append(np.mean(ch_peaks))
|
||||
|
||||
print("\nPer-channel mean across all files:")
|
||||
print(f"{'Channel':<15} {'Mean (uV)':<15} {'PSD Peak (Hz)':<15}")
|
||||
print("-" * 45)
|
||||
for ch_name, ch_mean, ch_peak in zip(ch_names, ch_mean_over_files, ch_peak_over_files):
|
||||
print(f"{ch_name:<15} {ch_mean:<15.4f} {ch_peak:<15.2f}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function with CLI support."""
|
||||
import argparse
|
||||
|
||||
# Default analysis directory
|
||||
default_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "raw_data")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Analyze .bdf files - print amplitude range and mean values",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=f"""
|
||||
Examples:
|
||||
python bdf_analyzer.py # Analyze all .bdf in raw_data/
|
||||
python bdf_analyzer.py data/test.bdf # Analyze single file
|
||||
python bdf_analyzer.py data/ # Analyze all .bdf in directory
|
||||
python bdf_analyzer.py . -u mV # Current dir, unit mV
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
"path",
|
||||
nargs="?",
|
||||
default=default_dir,
|
||||
help="Path to BDF file or directory containing BDF files (default: raw_data/)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-u", "--unit",
|
||||
choices=["uV", "mV", "V"],
|
||||
default="uV",
|
||||
help="Display unit (default: uV)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
filepath = args.path
|
||||
|
||||
# Determine file or directory mode
|
||||
if os.path.isfile(filepath):
|
||||
# Single file mode
|
||||
result = analyze_bdf(filepath, unit=args.unit)
|
||||
if result:
|
||||
print("Analysis complete!")
|
||||
elif os.path.isdir(filepath):
|
||||
# Directory mode
|
||||
results = analyze_directory(filepath)
|
||||
if results:
|
||||
print("\nBatch analysis complete!")
|
||||
else:
|
||||
print("No analyzable files found")
|
||||
else:
|
||||
print(f"[ERROR] File does not exist: {filepath}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
203
algorithm_V1/bdf_to_mat.py
Normal file
203
algorithm_V1/bdf_to_mat.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Convert BDF file to MAT format.
|
||||
|
||||
This script converts a BDF (Biosemi Data Format) EEG file to .mat format,
|
||||
matching the structure of eeg_data.mat.
|
||||
|
||||
Structure of eeg_data.mat:
|
||||
- data: (n_samples, n_channels) float64
|
||||
- chn: (1, n_channels) object - channel names
|
||||
- sample_rate: (1, 1) int64
|
||||
- node_number: (1, 1) int64
|
||||
- t: (n_samples, 1) float64 - time vector in seconds
|
||||
- electrode_name: (1, n_channels) object - electrode names (10-20 system)
|
||||
- electrode_xyz: (n_channels, 3) float64 - electrode 3D coordinates
|
||||
- electrode_coord_system: (1,) <U21
|
||||
- meta: (1, 1) structured - metadata
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
import mne
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def get_standard_electrode_coords():
|
||||
"""
|
||||
Standard 10-20 system electrode coordinates.
|
||||
Returns a dictionary mapping electrode names to x, y, z coordinates.
|
||||
Coordinates are approximate spherical projections.
|
||||
"""
|
||||
coords = {
|
||||
'FP1': (-0.0293, 0.0903, -0.0033),
|
||||
'FP2': (0.0293, 0.0903, -0.0033),
|
||||
'FPZ': (0.0, 0.0903, -0.0033),
|
||||
'AF7': (-0.0658, 0.0734, -0.0224),
|
||||
'AF3': (-0.0350, 0.0812, -0.0183),
|
||||
'AF4': (0.0350, 0.0812, -0.0183),
|
||||
'AF8': (0.0658, 0.0734, -0.0224),
|
||||
'F7': (-0.0815, 0.0467, -0.0336),
|
||||
'F5': (-0.0667, 0.0503, -0.0351),
|
||||
'F3': (-0.0489, 0.0560, -0.0370),
|
||||
'F1': (-0.0254, 0.0584, -0.0384),
|
||||
'FZ': (0.0, 0.0584, -0.0384),
|
||||
'F2': (0.0254, 0.0584, -0.0384),
|
||||
'F4': (0.0489, 0.0560, -0.0370),
|
||||
'F6': (0.0667, 0.0503, -0.0351),
|
||||
'F8': (0.0815, 0.0467, -0.0336),
|
||||
'FT7': (-0.0880, 0.0229, -0.0397),
|
||||
'FC5': (-0.0699, 0.0317, -0.0402),
|
||||
'FC3': (-0.0514, 0.0362, -0.0411),
|
||||
'FC1': (-0.0268, 0.0383, -0.0419),
|
||||
'FCZ': (0.0, 0.0383, -0.0419),
|
||||
'FC2': (0.0268, 0.0383, -0.0419),
|
||||
'FC4': (0.0514, 0.0362, -0.0411),
|
||||
'FC6': (0.0699, 0.0317, -0.0402),
|
||||
'FT8': (0.0880, 0.0229, -0.0397),
|
||||
'T7': (-0.0958, 0.0, -0.0411),
|
||||
'T8': (0.0958, 0.0, -0.0411),
|
||||
'C5': (-0.0739, 0.0, -0.0425),
|
||||
'C3': (-0.0544, 0.0, -0.0436),
|
||||
'C1': (-0.0283, 0.0, -0.0444),
|
||||
'CZ': (0.0, 0.0, -0.0444),
|
||||
'C2': (0.0283, 0.0, -0.0444),
|
||||
'C4': (0.0544, 0.0, -0.0436),
|
||||
'C6': (0.0739, 0.0, -0.0425),
|
||||
'TP7': (-0.0880, -0.0229, -0.0397),
|
||||
'CP5': (-0.0699, -0.0317, -0.0402),
|
||||
'CP3': (-0.0514, -0.0362, -0.0411),
|
||||
'CP1': (-0.0268, -0.0383, -0.0419),
|
||||
'CPZ': (0.0, -0.0383, -0.0419),
|
||||
'CP2': (0.0268, -0.0383, -0.0419),
|
||||
'CP4': (0.0514, -0.0362, -0.0411),
|
||||
'CP6': (0.0699, -0.0317, -0.0402),
|
||||
'TP8': (0.0880, -0.0229, -0.0397),
|
||||
'P7': (-0.0815, -0.0467, -0.0336),
|
||||
'P5': (-0.0667, -0.0503, -0.0351),
|
||||
'P3': (-0.0489, -0.0560, -0.0370),
|
||||
'P1': (-0.0254, -0.0584, -0.0384),
|
||||
'PZ': (0.0, -0.0584, -0.0384),
|
||||
'P2': (0.0254, -0.0584, -0.0384),
|
||||
'P4': (0.0489, -0.0560, -0.0370),
|
||||
'P6': (0.0667, -0.0503, -0.0351),
|
||||
'P8': (0.0815, -0.0467, -0.0336),
|
||||
'PO7': (-0.0658, -0.0734, -0.0224),
|
||||
'PO5': (-0.0503, -0.0744, -0.0258),
|
||||
'PO3': (-0.0350, -0.0812, -0.0183),
|
||||
'POZ': (0.0, -0.0829, -0.0172),
|
||||
'PO4': (0.0350, -0.0812, -0.0183),
|
||||
'PO6': (0.0503, -0.0744, -0.0258),
|
||||
'PO8': (0.0658, -0.0734, -0.0224),
|
||||
'O1': (-0.0293, -0.0903, -0.0033),
|
||||
'OZ': (0.0, -0.0903, -0.0033),
|
||||
'O2': (0.0293, -0.0903, -0.0033),
|
||||
'CB1': (-0.0618, -0.0380, -0.0387),
|
||||
'CB2': (0.0618, -0.0380, -0.0387),
|
||||
'A1': (-0.0958, 0.0, 0.0),
|
||||
'A2': (0.0958, 0.0, 0.0),
|
||||
}
|
||||
return coords
|
||||
|
||||
|
||||
def bdf_to_mat(bdf_path, output_path, subject_id='unknown', session_id='unknown'):
|
||||
"""
|
||||
Convert BDF file to MAT format matching eeg_data.mat structure.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bdf_path : str
|
||||
Path to the input BDF file.
|
||||
output_path : str
|
||||
Path to the output MAT file.
|
||||
subject_id : str, optional
|
||||
Subject identifier. Default is 'unknown'.
|
||||
session_id : str, optional
|
||||
Session identifier. Default is 'unknown'.
|
||||
"""
|
||||
print(f'Loading BDF file: {bdf_path}')
|
||||
raw = mne.io.read_raw_bdf(bdf_path, preload=True, verbose=False)
|
||||
|
||||
# Get basic info
|
||||
ch_names = raw.ch_names
|
||||
n_channels = len(ch_names)
|
||||
sfreq = int(raw.info['sfreq'])
|
||||
data = raw.get_data()
|
||||
|
||||
# BDF data shape: (n_channels, n_samples)
|
||||
# Convert to eeg_data.mat format: (n_samples, n_channels)
|
||||
data = data.T
|
||||
n_samples = data.shape[0]
|
||||
|
||||
# Create time vector (in seconds)
|
||||
t = np.arange(n_samples) / sfreq
|
||||
t = t.reshape(-1, 1)
|
||||
|
||||
# Create channel names array (matching eeg_data.mat structure)
|
||||
chn = np.array([[name] for name in ch_names], dtype=object)
|
||||
|
||||
# Create electrode names (same as channel names for BDF)
|
||||
electrode_name = np.array([[name] for name in ch_names], dtype=object)
|
||||
|
||||
# Get electrode coordinates
|
||||
standard_coords = get_standard_electrode_coords()
|
||||
electrode_xyz = np.zeros((n_channels, 3))
|
||||
for i, name in enumerate(ch_names):
|
||||
if name in standard_coords:
|
||||
electrode_xyz[i] = standard_coords[name]
|
||||
else:
|
||||
print(f'Warning: No standard coordinate for electrode {name}')
|
||||
|
||||
# Create metadata structure
|
||||
start_time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
meta = np.array([(subject_id, session_id, 'CMS/DRL', start_time_str)],
|
||||
dtype=[('subject_id', 'O'), ('session_id', 'O'),
|
||||
('ref', 'O'), ('start_time', 'O')])
|
||||
|
||||
# Create the EEG structure (matching eeg_data.mat format)
|
||||
eeg_struct = np.array([(data, chn, [[sfreq]], [[n_channels]],
|
||||
t, electrode_name, electrode_xyz,
|
||||
'buzsaki', meta)],
|
||||
dtype=[('data', 'O'), ('chn', 'O'),
|
||||
('sample_rate', 'O'), ('node_number', 'O'),
|
||||
('t', 'O'), ('electrode_name', 'O'),
|
||||
('electrode_xyz', 'O'), ('electrode_coord_system', 'O'),
|
||||
('meta', 'O')])
|
||||
|
||||
# Save to MAT file
|
||||
print(f'Saving to: {output_path}')
|
||||
scipy.io.savemat(output_path, {'eeg': eeg_struct}, do_compression=True)
|
||||
|
||||
print(f'\nConversion complete!')
|
||||
print(f' Channels: {n_channels}')
|
||||
print(f' Samples: {n_samples}')
|
||||
print(f' Duration: {n_samples / sfreq:.2f} seconds')
|
||||
print(f' Sample rate: {sfreq} Hz')
|
||||
print(f' Data shape: {data.shape}')
|
||||
|
||||
|
||||
def main():
|
||||
# File paths
|
||||
bdf_path = r'D:\Ivey\Code_New_Proj\Debug_Depression\algorithm_version_0521_v0\0515-18.bdf'
|
||||
output_path = r'D:\Ivey\Code_New_Proj\Debug_Depression\algorithm_version_0521_v0\0515-18.mat'
|
||||
|
||||
# Convert
|
||||
bdf_to_mat(bdf_path, output_path, subject_id='lvpeng', session_id='01')
|
||||
|
||||
# Verify the output
|
||||
print('\n=== Verification ===')
|
||||
mat_data = scipy.io.loadmat(output_path)
|
||||
eeg = mat_data['eeg'][0, 0]
|
||||
|
||||
print(f'Output file keys: {list(mat_data.keys())}')
|
||||
print(f'eeg.data shape: {eeg["data"].shape}')
|
||||
print(f'eeg.chn shape: {eeg["chn"].shape}')
|
||||
print(f'eeg.sample_rate: {eeg["sample_rate"][0, 0]}')
|
||||
print(f'eeg.t shape: {eeg["t"].shape}')
|
||||
print(f'eeg.electrode_name: {eeg["electrode_name"].shape}')
|
||||
print(f'eeg.electrode_xyz shape: {eeg["electrode_xyz"].shape}')
|
||||
print(f'eeg.electrode_coord_system: {eeg["electrode_coord_system"][0]}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
96
algorithm_V1/build_algorithm.spec
Normal file
96
algorithm_V1/build_algorithm.spec
Normal file
@@ -0,0 +1,96 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
|
||||
import sys
|
||||
import os
|
||||
from PyInstaller.utils.hooks import collect_submodules, collect_data_files
|
||||
|
||||
# ========================================================
|
||||
# 1. 工程配置区 (Project Config)
|
||||
# ========================================================
|
||||
block_cipher = None
|
||||
ENTRY_POINT = 'runDecoder.py'
|
||||
APP_NAME = 'Depression_Decoder'
|
||||
|
||||
# ========================================================
|
||||
# 2. 依赖分析 (Dependency Analysis)
|
||||
# ========================================================
|
||||
# 收集 mne, sklearn, scipy 可能遗漏的隐藏导入
|
||||
hidden_imports = [
|
||||
'infer_pth', # 你的动态导入模块
|
||||
'sklearn.utils._cython_blas',
|
||||
'sklearn.neighbors.typedefs',
|
||||
'sklearn.neighbors.quad_tree',
|
||||
'sklearn.tree',
|
||||
'sklearn.tree._utils',
|
||||
]
|
||||
# 自动收集 mne 的子模块
|
||||
hidden_imports += collect_submodules('mne')
|
||||
# 收集 torch 相关的隐式导入(虽然 PyInstaller 通常能处理,但显式更安全)
|
||||
hidden_imports += ['torch', 'torchvision']
|
||||
|
||||
# ========================================================
|
||||
# 3. 资源锚定 (Data Anchoring)
|
||||
# ========================================================
|
||||
# Analysis 中的 datas 用于将文件嵌入到内部(运行时在临时目录或 _internal)
|
||||
# 这里我们留空,改为在 COLLECT 阶段通过 Tree 显式复制到 EXE 旁,
|
||||
# 这样生成的文件夹里能直接看到 model 和 raw_data
|
||||
datas = []
|
||||
|
||||
# 收集 mne 的数据文件(如果需要默认配置)
|
||||
datas += collect_data_files('mne')
|
||||
|
||||
# ========================================================
|
||||
# 4. 构建流程 (Build Process)
|
||||
# ========================================================
|
||||
a = Analysis(
|
||||
[ENTRY_POINT],
|
||||
pathex=[],
|
||||
binaries=[],
|
||||
datas=datas,
|
||||
hiddenimports=hidden_imports,
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=['tkinter', 'PyQt5', 'PySide2', 'IPython'], # 排除 GUI 和交互式库减小体积
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
)
|
||||
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
[],
|
||||
exclude_binaries=True,
|
||||
name=APP_NAME,
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=False,
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
|
||||
# ========================================================
|
||||
# 5. 打包模式: OneDir (单文件夹) + 资源旁路
|
||||
# ========================================================
|
||||
# 使用 Tree 将文件夹原样复制到 dist/APP_NAME/ 下
|
||||
# 格式: Tree('源路径', prefix='目标子目录')
|
||||
|
||||
coll = COLLECT(
|
||||
exe,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
strip=False,
|
||||
upx=False,
|
||||
upx_exclude=[],
|
||||
name=APP_NAME,
|
||||
)
|
||||
77
algorithm_V1/build_with_copy.py
Normal file
77
algorithm_V1/build_with_copy.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
def main():
|
||||
# 1. 定义路径
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
DIST_DIR = os.path.join(BASE_DIR, 'dist')
|
||||
APP_NAME = 'Depression_Decoder'
|
||||
TARGET_DIR = os.path.join(DIST_DIR, APP_NAME)
|
||||
|
||||
MODEL_SRC = os.path.join(BASE_DIR, 'model')
|
||||
RAW_DATA_SRC = os.path.join(BASE_DIR, 'raw_data')
|
||||
|
||||
MODEL_DST = os.path.join(TARGET_DIR, 'model')
|
||||
RAW_DATA_DST = os.path.join(TARGET_DIR, 'raw_data')
|
||||
|
||||
# 2. 清理旧构建
|
||||
print("[1/3] Cleaning up old builds...")
|
||||
if os.path.exists(DIST_DIR):
|
||||
try:
|
||||
shutil.rmtree(DIST_DIR)
|
||||
print(" Cleaned dist/")
|
||||
except Exception as e:
|
||||
print(f" Warning: Could not clean dist/: {e}")
|
||||
|
||||
BUILD_DIR = os.path.join(BASE_DIR, 'build')
|
||||
if os.path.exists(BUILD_DIR):
|
||||
try:
|
||||
shutil.rmtree(BUILD_DIR)
|
||||
print(" Cleaned build/")
|
||||
except Exception as e:
|
||||
print(f" Warning: Could not clean build/: {e}")
|
||||
|
||||
# 3. 运行 PyInstaller
|
||||
print("[2/3] Running PyInstaller...")
|
||||
# 注意:我们这里不传 --noupx,因为已经在 spec 文件里把 upx=False 写死了
|
||||
cmd = [
|
||||
"pyinstaller",
|
||||
"build_algorithm.spec",
|
||||
"--clean"
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.check_call(cmd, shell=True)
|
||||
except subprocess.CalledProcessError:
|
||||
print("Error: PyInstaller failed.")
|
||||
sys.exit(1)
|
||||
|
||||
# 4. 复制外部资源文件夹
|
||||
print("[3/3] Copying external resources...")
|
||||
|
||||
# 复制 model 文件夹
|
||||
if os.path.exists(MODEL_SRC):
|
||||
if os.path.exists(MODEL_DST):
|
||||
shutil.rmtree(MODEL_DST)
|
||||
shutil.copytree(MODEL_SRC, MODEL_DST)
|
||||
print(f" Copied: model -> {MODEL_DST}")
|
||||
else:
|
||||
print(f" Warning: Source model dir not found at {MODEL_SRC}")
|
||||
|
||||
# 复制 raw_data 文件夹
|
||||
if os.path.exists(RAW_DATA_SRC):
|
||||
if os.path.exists(RAW_DATA_DST):
|
||||
shutil.rmtree(RAW_DATA_DST)
|
||||
shutil.copytree(RAW_DATA_SRC, RAW_DATA_DST)
|
||||
print(f" Copied: raw_data -> {RAW_DATA_DST}")
|
||||
else:
|
||||
print(f" Warning: Source raw_data dir not found at {RAW_DATA_SRC}")
|
||||
|
||||
print("\n" + "="*50)
|
||||
print(f"SUCCESS! Build artifacts are in: {TARGET_DIR}")
|
||||
print("="*50)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
561
algorithm_V1/infer_pth.py
Normal file
561
algorithm_V1/infer_pth.py
Normal file
@@ -0,0 +1,561 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
infer_pth.py
|
||||
|
||||
用途:
|
||||
- 从一个文件夹中自动读取第一个 .mat EEG 文件(64通道或32通道)
|
||||
- 若为64通道,则按 idx64_to_32 映射选出32通道
|
||||
- 提取切片特征(DE + PSD(var近似),不含Asym)
|
||||
- 加载你训练好的 .pth 模型(FusionNet结构)
|
||||
- 输出该受试者的 HC / MDD 判断结果
|
||||
|
||||
运行方式(命令行):
|
||||
python infer_pth_from64_to32.py --eeg_dir "D:\\xxx\\folder" --model_path "C:\\xxx\\model.pth"
|
||||
|
||||
也可在其他py里import:
|
||||
from infer_pth_from64_to32 import predict_hc_mdd
|
||||
res = predict_hc_mdd(eeg_dir, model_path)
|
||||
print(res)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
import scipy.signal as signal
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# =========================================================
|
||||
# 0) 配置区(按需改这里)
|
||||
# =========================================================
|
||||
|
||||
# 采样率(必须与训练时一致)
|
||||
SAMPLING_RATE = 250
|
||||
|
||||
# 滑窗参数(必须与训练时一致)
|
||||
WINDOW_SIZE = 500
|
||||
STRIDE = 250
|
||||
|
||||
# 频段(必须与训练时一致)
|
||||
BAND_NAMES = ["Delta", "Theta", "Alpha", "Beta", "Gamma"]
|
||||
BANDS = {
|
||||
"Delta": (1, 4),
|
||||
"Theta": (4, 8),
|
||||
"Alpha": (8, 13),
|
||||
"Beta": (13, 30),
|
||||
"Gamma": (30, 50),
|
||||
}
|
||||
|
||||
# 是否使用扩展特征(DE+PSD)
|
||||
USE_EXTENDED_FEATURES = True
|
||||
|
||||
# 数值稳定项
|
||||
EPS = 1e-12
|
||||
|
||||
# 设备
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# 通道映射
|
||||
IDX64_TO_32 = [
|
||||
23, # C5
|
||||
47, # O1
|
||||
39, # TP7
|
||||
6, # FPZ
|
||||
2, # PO6
|
||||
21, # P4
|
||||
35, # AF7
|
||||
57, # AF3
|
||||
1, # FP2
|
||||
37, # T7
|
||||
63, # F1
|
||||
36, # A1
|
||||
18, # FC4
|
||||
31, # FC5
|
||||
14, # FC2
|
||||
48, # T8
|
||||
60, # P2
|
||||
41, # AF8
|
||||
11, # CP1
|
||||
0, # FP1
|
||||
55, # PO7
|
||||
59, # C1
|
||||
22, # F5
|
||||
10, # CP2
|
||||
16, # C3
|
||||
61, # P1
|
||||
27, # CP5
|
||||
17, # C4
|
||||
26, # CP6
|
||||
62, # F2
|
||||
3, # POZ
|
||||
13, # PO5
|
||||
]
|
||||
|
||||
# 推理阈值:如果模型checkpoint里有 subject_threshold,会优先用它;否则用这个
|
||||
DEFAULT_SUBJECT_THRESHOLD = 0.5
|
||||
|
||||
# =========================================================
|
||||
# 1) 模型结构
|
||||
# =========================================================
|
||||
|
||||
class SEBlock(nn.Module):
|
||||
def __init__(self, channels: int, reduction: int = 4) -> None:
|
||||
super().__init__()
|
||||
hidden = max(1, channels // reduction)
|
||||
self.fc1 = nn.Linear(channels, hidden)
|
||||
self.fc2 = nn.Linear(hidden, channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
se = F.relu(self.fc1(x))
|
||||
se = torch.sigmoid(self.fc2(se))
|
||||
return x * se
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, dropout: float = 0.3) -> None:
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(in_features, out_features)
|
||||
self.bn1 = nn.BatchNorm1d(out_features)
|
||||
self.fc2 = nn.Linear(out_features, out_features)
|
||||
self.bn2 = nn.BatchNorm1d(out_features)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.shortcut = nn.Identity()
|
||||
if in_features != out_features:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Linear(in_features, out_features),
|
||||
nn.BatchNorm1d(out_features),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
identity = self.shortcut(x)
|
||||
out = F.relu(self.bn1(self.fc1(x)))
|
||||
out = self.dropout(out)
|
||||
out = self.bn2(self.fc2(out))
|
||||
out = F.relu(out + identity)
|
||||
return out
|
||||
|
||||
|
||||
class FusionNet(nn.Module):
|
||||
def __init__(self, num_classes: int = 2, num_eeg_features: int = 320, num_scales: int = 6) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.input_norm = nn.BatchNorm1d(num_eeg_features)
|
||||
|
||||
self.block1 = ResidualBlock(num_eeg_features, 512, dropout=0.4)
|
||||
self.block2 = ResidualBlock(512, 256, dropout=0.3)
|
||||
self.block3 = ResidualBlock(256, 128, dropout=0.2)
|
||||
|
||||
self.attention = SEBlock(128, reduction=4)
|
||||
|
||||
self.final_fc = nn.Sequential(
|
||||
nn.Linear(128, 64),
|
||||
nn.BatchNorm1d(64),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
)
|
||||
|
||||
self.cls_head = nn.Linear(64, num_classes)
|
||||
|
||||
# 训练时有回归头也没关系(推理只用cls)
|
||||
self.reg_head = nn.Sequential(
|
||||
nn.Linear(64, 32),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(32, num_scales),
|
||||
)
|
||||
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self) -> None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm1d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.input_norm(x)
|
||||
x = self.block1(x)
|
||||
x = self.block2(x)
|
||||
x = self.block3(x)
|
||||
x = self.attention(x)
|
||||
features = self.final_fc(x)
|
||||
cls_out = self.cls_head(features)
|
||||
reg_out = self.reg_head(features)
|
||||
return cls_out, reg_out
|
||||
|
||||
|
||||
# =========================================================
|
||||
# 2) mat读取 + 通道裁剪
|
||||
# =========================================================
|
||||
|
||||
def _find_first_mat_file(folder: str) -> str:
|
||||
if not os.path.isdir(folder):
|
||||
raise RuntimeError(f"eeg_dir 不是文件夹: {folder}")
|
||||
mats = sorted([f for f in os.listdir(folder) if f.lower().endswith(".mat")])
|
||||
if not mats:
|
||||
raise RuntimeError(f"文件夹内没有 .mat 文件: {folder}")
|
||||
return os.path.join(folder, mats[0])
|
||||
|
||||
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
|
||||
def _unwrap_singleton(x):
|
||||
"""
|
||||
把 (1,1) / (1,) 这种包裹层一直剥掉,直到不是 singleton。
|
||||
也处理 object array 的情况。
|
||||
"""
|
||||
while True:
|
||||
if isinstance(x, np.ndarray):
|
||||
if x.dtype == object and x.size == 1:
|
||||
x = x.item()
|
||||
continue
|
||||
if x.size == 1 and x.ndim >= 1:
|
||||
# 例如 (1,1) 或 (1,) 的数值/对象数组
|
||||
try:
|
||||
x = x.reshape(-1)[0]
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
return x
|
||||
|
||||
def _try_get_struct_field(v, field_name="data"):
|
||||
"""
|
||||
尝试从以下几种结构中提取字段:
|
||||
1) scipy 读出的 mat_struct(有 _fieldnames)
|
||||
2) numpy structured/record array(dtype.names)
|
||||
"""
|
||||
# case 1: mat_struct(推荐 loadmat(..., struct_as_record=False, squeeze_me=True))
|
||||
if hasattr(v, "_fieldnames") and (field_name in getattr(v, "_fieldnames", [])):
|
||||
return getattr(v, field_name)
|
||||
|
||||
# case 2: structured array
|
||||
if isinstance(v, np.ndarray) and v.dtype.names and (field_name in v.dtype.names):
|
||||
# 常见是 v[field] 仍然是 ndarray / object,需要 unwrap
|
||||
try:
|
||||
return v[field_name]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def load_eeg_from_mat_any_channels(mat_path: str) -> np.ndarray:
|
||||
"""
|
||||
读取 .mat 中 EEG 数据,支持:
|
||||
- 直接二维矩阵 (T,C) 或 (C,T)
|
||||
- struct 里有字段 data
|
||||
返回统一为 float32 的 (T, C)
|
||||
"""
|
||||
# 用这两个参数会让 struct 更容易处理:字段变成属性,且自动 squeeze
|
||||
mat = scipy.io.loadmat(mat_path, struct_as_record=False, squeeze_me=True)
|
||||
|
||||
candidates = []
|
||||
|
||||
for k, v in mat.items():
|
||||
if k.startswith("__"):
|
||||
continue
|
||||
|
||||
# --- 1) 直接二维数值矩阵 ---
|
||||
if isinstance(v, np.ndarray) and v.ndim == 2 and np.issubdtype(v.dtype, np.number):
|
||||
candidates.append((k, v))
|
||||
continue
|
||||
|
||||
# --- 2) struct/record:优先提取 data 字段 ---
|
||||
data_field = _try_get_struct_field(v, "data")
|
||||
if data_field is not None:
|
||||
data_field = _unwrap_singleton(data_field)
|
||||
|
||||
# data_field 可能仍然被 object 包一层
|
||||
if isinstance(data_field, np.ndarray) and data_field.dtype == object:
|
||||
data_field = _unwrap_singleton(data_field)
|
||||
|
||||
if isinstance(data_field, np.ndarray) and data_field.ndim == 2:
|
||||
# 只收数值矩阵
|
||||
if np.issubdtype(data_field.dtype, np.number) or data_field.dtype == object:
|
||||
candidates.append((f"{k}.data", data_field))
|
||||
continue
|
||||
|
||||
# --- 3) object array:尝试 item() 解包后再看是不是二维数值矩阵/struct ---
|
||||
if isinstance(v, np.ndarray) and v.dtype == object:
|
||||
vv = _unwrap_singleton(v)
|
||||
|
||||
# 解包后若是二维数值矩阵
|
||||
if isinstance(vv, np.ndarray) and vv.ndim == 2 and np.issubdtype(vv.dtype, np.number):
|
||||
candidates.append((k, vv))
|
||||
continue
|
||||
|
||||
# 解包后若是 struct,再取 data
|
||||
data2 = _try_get_struct_field(vv, "data")
|
||||
if data2 is not None:
|
||||
data2 = _unwrap_singleton(data2)
|
||||
if isinstance(data2, np.ndarray) and data2.ndim == 2:
|
||||
candidates.append((f"{k}.data", data2))
|
||||
continue
|
||||
|
||||
if not candidates:
|
||||
raise RuntimeError(f"mat里没找到可用EEG二维矩阵或struct.data:{mat_path}")
|
||||
|
||||
# 选一个最像 EEG 的(优先含32/64通道维度的)
|
||||
def score(arr: np.ndarray) -> int:
|
||||
s = 0
|
||||
if 64 in arr.shape: s += 10
|
||||
if 32 in arr.shape: s += 9
|
||||
if 128 in arr.shape: s += 8
|
||||
if 129 in arr.shape: s += 7
|
||||
s += int(np.prod(arr.shape) // 100000) # 大一些更像EEG
|
||||
return s
|
||||
|
||||
candidates.sort(key=lambda kv: score(kv[1]), reverse=True)
|
||||
key, eeg = candidates[0]
|
||||
|
||||
eeg = _unwrap_singleton(eeg)
|
||||
|
||||
# 如果还是 object dtype,尽力转成 float
|
||||
if isinstance(eeg, np.ndarray) and eeg.dtype == object:
|
||||
# 有时 object 里其实是数值
|
||||
eeg = np.array(eeg, dtype=np.float32)
|
||||
else:
|
||||
eeg = np.asarray(eeg, dtype=np.float32)
|
||||
|
||||
if eeg.ndim != 2:
|
||||
raise RuntimeError(f"解析结果不是二维矩阵: key={key}, shape={eeg.shape}, file={mat_path}")
|
||||
|
||||
# 统一为 (T, C)
|
||||
# 常见 (C,T) 或 (T,C),我们用“通道维通常较小”+ “32/64/128/129”判断
|
||||
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[1] not in (32, 64, 128, 129):
|
||||
eeg = eeg.T
|
||||
elif eeg.shape[1] in (32, 64, 128, 129):
|
||||
# 如果第一维也是这些数且更小,可能是(C,T)
|
||||
if eeg.shape[0] in (32, 64, 128, 129) and eeg.shape[0] < eeg.shape[1]:
|
||||
eeg = eeg.T
|
||||
|
||||
return eeg
|
||||
|
||||
|
||||
def ensure_32_channels(eeg: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
输入 (T, C),输出 (T, 32)
|
||||
- 若C=64:按 IDX64_TO_32 选32通道
|
||||
- 若C=32:直接返回
|
||||
"""
|
||||
if eeg.ndim != 2:
|
||||
raise RuntimeError(f"EEG必须是二维(T,C),但得到: {eeg.shape}")
|
||||
|
||||
C = eeg.shape[1]
|
||||
if C == 64:
|
||||
idx = np.asarray(IDX64_TO_32, dtype=np.int64)
|
||||
if idx.min() < 0 or idx.max() >= 64:
|
||||
raise RuntimeError(f"IDX64_TO_32 越界:min={idx.min()}, max={idx.max()} (要求0~63)")
|
||||
return eeg[:, idx]
|
||||
if C == 32:
|
||||
return eeg
|
||||
raise RuntimeError(f"不支持的通道数C={C},当前只支持 64->32 或 32 直推。")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# 3) 特征提取(DE + PSD(var近似))
|
||||
# =========================================================
|
||||
|
||||
class FeatureExtractor32:
|
||||
"""
|
||||
只针对32通道,输出维度:
|
||||
- USE_EXTENDED_FEATURES=True:DE(32*5) + PSD(32*5) = 320
|
||||
- 否则:DE(32*5) = 160
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
fs: int = SAMPLING_RATE,
|
||||
window_size: int = WINDOW_SIZE,
|
||||
stride: int = STRIDE,
|
||||
filter_order: int = 4,
|
||||
zero_phase: bool = False,
|
||||
) -> None:
|
||||
self.fs = fs
|
||||
self.window_size = window_size
|
||||
self.stride = stride
|
||||
self.filter_order = filter_order
|
||||
self.zero_phase = zero_phase
|
||||
|
||||
self._sos = {}
|
||||
for bn in BAND_NAMES:
|
||||
low, high = BANDS[bn]
|
||||
self._sos[bn] = signal.butter(
|
||||
self.filter_order, [low, high],
|
||||
btype="band", fs=self.fs, output="sos"
|
||||
)
|
||||
|
||||
def _filter_bands(self, eeg: np.ndarray) -> dict[str, np.ndarray]:
|
||||
out = {}
|
||||
for bn in BAND_NAMES:
|
||||
sos = self._sos[bn]
|
||||
if self.zero_phase:
|
||||
out[bn] = signal.sosfiltfilt(sos, eeg, axis=0).astype(np.float32)
|
||||
else:
|
||||
out[bn] = signal.sosfilt(sos, eeg, axis=0).astype(np.float32)
|
||||
return out
|
||||
|
||||
def extract(self, eeg32: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
eeg32: (T, 32)
|
||||
return: feats (N_slices, feat_dim)
|
||||
"""
|
||||
if eeg32.ndim != 2 or eeg32.shape[1] != 32:
|
||||
raise RuntimeError(f"extract需要 (T,32),得到 {eeg32.shape}")
|
||||
|
||||
bands_data = self._filter_bands(eeg32)
|
||||
T = eeg32.shape[0]
|
||||
|
||||
feats = []
|
||||
for start in range(0, T - self.window_size, self.stride):
|
||||
end = start + self.window_size
|
||||
|
||||
de_list = []
|
||||
psd_list = []
|
||||
|
||||
for bn in BAND_NAMES:
|
||||
seg = bands_data[bn][start:end, :] # (W, 32)
|
||||
var = np.var(seg, axis=0, ddof=1) # (32,)
|
||||
|
||||
# DE
|
||||
de = 0.5 * np.log(2 * np.pi * np.e * (var + EPS))
|
||||
de_list.append(de)
|
||||
|
||||
if USE_EXTENDED_FEATURES:
|
||||
# PSD近似:log(var)
|
||||
psd_list.append(np.log(var + EPS))
|
||||
|
||||
de_feat = np.stack(de_list, axis=0).T.reshape(-1) # (32*5,)
|
||||
if USE_EXTENDED_FEATURES:
|
||||
psd_feat = np.stack(psd_list, axis=0).T.reshape(-1) # (32*5,)
|
||||
f = np.concatenate([de_feat, psd_feat], axis=0).astype(np.float32)
|
||||
else:
|
||||
f = de_feat.astype(np.float32)
|
||||
|
||||
feats.append(f)
|
||||
|
||||
if not feats:
|
||||
raise RuntimeError("EEG长度不足以切片(请检查T是否太短,或调整WINDOW_SIZE/STRIDE)")
|
||||
|
||||
return np.stack(feats, axis=0).astype(np.float32)
|
||||
|
||||
|
||||
# =========================================================
|
||||
# 4) 模型加载 + 推理接口
|
||||
# =========================================================
|
||||
|
||||
def _safe_torch_load(path: str):
|
||||
try:
|
||||
return torch.load(path, map_location=DEVICE, weights_only=False)
|
||||
except TypeError:
|
||||
return torch.load(path, map_location=DEVICE)
|
||||
|
||||
|
||||
def load_model(model_path: str) -> tuple[FusionNet, dict]:
|
||||
"""
|
||||
返回: (model, ckpt_dict)
|
||||
"""
|
||||
obj = _safe_torch_load(model_path)
|
||||
if isinstance(obj, dict) and "model_state" in obj:
|
||||
ckpt = obj
|
||||
state = obj["model_state"]
|
||||
feat_dim = int(obj.get("feat_dim", 320))
|
||||
else:
|
||||
ckpt = {}
|
||||
state = obj
|
||||
feat_dim = 320
|
||||
|
||||
model = FusionNet(num_classes=2, num_eeg_features=feat_dim).to(DEVICE)
|
||||
model.load_state_dict(state, strict=True)
|
||||
model.eval()
|
||||
return model, ckpt
|
||||
|
||||
|
||||
def predict_hc_mdd(eeg_dir: str, model_path: str) -> dict:
|
||||
"""
|
||||
接口:传入 EEG文件夹 和 模型路径,返回判断结果 dict
|
||||
|
||||
返回字段:
|
||||
- mat_file: 使用的mat文件
|
||||
- pred_label: "HC" or "MDD"
|
||||
- p_mdd_mean: 切片p(MDD)均值
|
||||
- threshold: subject判定阈值
|
||||
- n_slices: 切片数
|
||||
"""
|
||||
mat_file = _find_first_mat_file(eeg_dir)
|
||||
|
||||
# 1) 读EEG (T,C),并保证变成32通道
|
||||
eeg = load_eeg_from_mat_any_channels(mat_file) # (T,C)
|
||||
eeg32 = ensure_32_channels(eeg) # (T,32)
|
||||
|
||||
# 2) 提特征 (N,feat_dim)
|
||||
extractor = FeatureExtractor32(fs=SAMPLING_RATE, window_size=WINDOW_SIZE, stride=STRIDE)
|
||||
feats = extractor.extract(eeg32) # (N, dim)
|
||||
|
||||
# 3) 加载模型
|
||||
model, ckpt = load_model(model_path)
|
||||
|
||||
# 4) 可选:归一化(若ckpt里保存的mean/std维度刚好匹配)
|
||||
mean = ckpt.get("global_mean", None) if isinstance(ckpt, dict) else None
|
||||
std = ckpt.get("global_std", None) if isinstance(ckpt, dict) else None
|
||||
|
||||
if mean is not None and std is not None:
|
||||
mean = np.asarray(mean, dtype=np.float32)
|
||||
std = np.asarray(std, dtype=np.float32)
|
||||
if mean.shape[0] == feats.shape[1] and std.shape[0] == feats.shape[1]:
|
||||
feats = (feats - mean) / (std + 1e-8)
|
||||
# 不匹配就跳过
|
||||
# 你说“先不用其他步骤或信息”,所以这里按“尽量运行”处理
|
||||
|
||||
# 5) 推理:对所有切片算 p(MDD),取均值做subject-level
|
||||
x = torch.from_numpy(feats).to(DEVICE)
|
||||
with torch.no_grad():
|
||||
cls_out, _ = model(x)
|
||||
prob_mdd = torch.softmax(cls_out, dim=1)[:, 1].detach().cpu().numpy()
|
||||
|
||||
p_mdd_mean = float(np.mean(prob_mdd))
|
||||
thr = float(ckpt.get("subject_threshold", DEFAULT_SUBJECT_THRESHOLD)) if isinstance(ckpt, dict) else float(DEFAULT_SUBJECT_THRESHOLD)
|
||||
pred_is_mdd = (p_mdd_mean >= thr)
|
||||
pred_label = "MDD" if pred_is_mdd else "HC"
|
||||
|
||||
return {
|
||||
"mat_file": mat_file,
|
||||
"pred_label": pred_label,
|
||||
"p_mdd_mean": p_mdd_mean,
|
||||
"threshold": thr,
|
||||
"n_slices": int(feats.shape[0]),
|
||||
}
|
||||
|
||||
|
||||
# =========================================================
|
||||
# 5) CLI:命令行运行入口
|
||||
# =========================================================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Infer HC/MDD from 64ch->32ch EEG mat using a .pth FusionNet model (no Asym).")
|
||||
parser.add_argument("--eeg_dir", type=str, required=True, help="包含.mat EEG文件的文件夹(自动读取第一个.mat)")
|
||||
parser.add_argument("--model_path", type=str, required=True, help="训练好的.pth模型路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
res = predict_hc_mdd(args.eeg_dir, args.model_path)
|
||||
print("\n========== 推理结果 ==========")
|
||||
print(f"MAT文件: {res['mat_file']}")
|
||||
print(f"切片数量: {res['n_slices']}")
|
||||
print(f"p(MDD)_mean: {res['p_mdd_mean']:.4f}")
|
||||
print(f"阈值thr: {res['threshold']:.4f}")
|
||||
print(f"预测结果: {res['pred_label']}")
|
||||
print("==============================\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
BIN
algorithm_V1/model/Model_0.pth
Normal file
BIN
algorithm_V1/model/Model_0.pth
Normal file
Binary file not shown.
BIN
algorithm_V1/model/Model_1.pth
Normal file
BIN
algorithm_V1/model/Model_1.pth
Normal file
Binary file not shown.
BIN
algorithm_V1/out/EEG.png
Normal file
BIN
algorithm_V1/out/EEG.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 683 KiB |
9
algorithm_V1/out/ResultData.txt
Normal file
9
algorithm_V1/out/ResultData.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
中央区α/β波比值:0.5
|
||||
额区α/β波比值:0.6
|
||||
顶区α/β波比值:0.6
|
||||
中央区θ/β波比值:0.9
|
||||
顶区θ/β波比值:0.9
|
||||
前额叶α波不对称性:-0.0
|
||||
个体化α峰值频率:8.5
|
||||
前额叶θ+δ波功率:74.1
|
||||
是否推荐治疗:是
|
||||
BIN
algorithm_V1/out/average_topomap.png
Normal file
BIN
algorithm_V1/out/average_topomap.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 247 KiB |
BIN
algorithm_V1/out/psd.png
Normal file
BIN
algorithm_V1/out/psd.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 107 KiB |
BIN
algorithm_V1/out/topomaps.png
Normal file
BIN
algorithm_V1/out/topomaps.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 298 KiB |
335615
algorithm_V1/raw_data/eeg_data0(6).bdf
Normal file
335615
algorithm_V1/raw_data/eeg_data0(6).bdf
Normal file
File diff suppressed because one or more lines are too long
6
algorithm_V1/requirements.txt
Normal file
6
algorithm_V1/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
numpy
|
||||
scipy
|
||||
matplotlib
|
||||
mne
|
||||
torch
|
||||
scikit-learn
|
||||
986
algorithm_V1/runDecoder.py
Normal file
986
algorithm_V1/runDecoder.py
Normal file
@@ -0,0 +1,986 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
runDecoder.py - BDF EEG Depression Assessment
|
||||
|
||||
功能:
|
||||
1. 读取 raw_data 文件夹的第一个 .bdf 格式文件
|
||||
2. 预处理:坏通道剔除、50Hz陷波、0.8-40Hz带通、幅值过滤、ICA去伪迹
|
||||
3. 调用 infer_pth.py 中的 predict_hc_mdd 进行 HC/MDD 分类预测
|
||||
4. 保存图表(EEG、PSD、Topomap)
|
||||
5. 生成 ResultData.txt
|
||||
|
||||
"""
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
import scipy.signal as signal
|
||||
import matplotlib.pyplot as plt
|
||||
import mne
|
||||
from mne.preprocessing import ICA
|
||||
|
||||
# ==========================
|
||||
# Config - 预处理参数
|
||||
# ==========================
|
||||
# 滤波参数
|
||||
BANDPASS_LOW = 0.8
|
||||
BANDPASS_HIGH = 40.0
|
||||
NOTCH_FREQS = [50, 100] # 工频陷波
|
||||
|
||||
# 幅值过滤阈值 (μV)
|
||||
AMPLITUDE_MIN_UV = -200.0
|
||||
AMPLITUDE_MAX_UV = 200.0
|
||||
|
||||
# ICA 参数
|
||||
ICA_N_COMPONENTS = 20 # 使用绝对数量而非比例
|
||||
ICA_RANDOM_STATE = 97
|
||||
ICA_MAX_ITER = 800
|
||||
|
||||
# 坏段检测阈值 (μV)
|
||||
BAD_SEGMENT_THRESHOLD_UV = 350.0
|
||||
|
||||
# 默认采样率
|
||||
DEFAULT_FS = 250.0
|
||||
|
||||
# 画图参数
|
||||
EEG_PLOT_SECONDS = 10
|
||||
PSD_FMIN, PSD_FMAX = 0.8, 45.0
|
||||
FIXED_EEG_IDXS = [23, 47, 39, 6, 2, 21, 35, 57] # 0-based index
|
||||
FIXED_EEG_LABELS = ["C5", "O1", "TP7", "FPZ", "PO6", "P4", "AF7", "AF3"]
|
||||
|
||||
# 频段定义
|
||||
BANDS_METRICS = {
|
||||
"Delta": (1.0, 4.0),
|
||||
"Theta": (4.0, 8.0),
|
||||
"Alpha": (8.0, 13.0),
|
||||
"Beta": (13.0, 30.0),
|
||||
}
|
||||
TOTAL_POWER_BAND = (1.0, 50.0)
|
||||
|
||||
BANDS_TOPOMAP = {
|
||||
"delta": (1.0, 4.0),
|
||||
"theta": (4.0, 8.0),
|
||||
"alpha": (8.0, 13.0),
|
||||
"beta": (13.0, 30.0),
|
||||
"broad": (1.0, 30.0),
|
||||
}
|
||||
|
||||
# PSD 参数
|
||||
PSD_NPERSEG = 1024 # FFT 窗口大小,越大频率分辨率越高
|
||||
|
||||
EPS = 1e-12
|
||||
|
||||
# 脑地形图颜色范围参数
|
||||
# 设置为 None 表示自动范围,设置为 (min, max) 固定范围
|
||||
TOPOMAP_VMIN = None # 例如: -1.0
|
||||
TOPOMAP_VMAX = None # 例如: 1.0
|
||||
# 或者使用对称范围(相对于均值的倍数)
|
||||
TOPOMAP_SYM_SCALE = 1.5 # 颜色范围 = 均值 ± std * SYM_SCALE
|
||||
|
||||
# 脑地形图圆形大小参数 (0.08 - 0.15 范围)
|
||||
# 数值越小圆形越小,越大圆形越大
|
||||
TOPOMAP_SPHERE_RADIUS = 0.12
|
||||
|
||||
# 边界处理参数
|
||||
# 滤波前 padding 秒数,用于消除边界振铃效应
|
||||
FILTER_PAD_SEC = 1.0
|
||||
|
||||
|
||||
# ==========================
|
||||
# 数据文件读取
|
||||
# ==========================
|
||||
def load_data_file(file_path: str) -> tuple:
|
||||
"""根据文件扩展名读取数据,返回 MNE Raw 对象"""
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
|
||||
if ext == ".bdf":
|
||||
return load_bdf_file(file_path)
|
||||
elif ext == ".mat":
|
||||
return load_mat_file(file_path)
|
||||
else:
|
||||
raise ValueError(f"不支持的文件格式: {ext}")
|
||||
|
||||
|
||||
def load_bdf_file(bdf_path: str) -> tuple:
|
||||
"""读取 .bdf 格式文件,返回 MNE Raw 对象"""
|
||||
print(f"[INFO] Reading BDF file: {bdf_path}")
|
||||
raw = mne.io.read_raw_bdf(bdf_path, preload=True, verbose=False)
|
||||
|
||||
try:
|
||||
raw.set_montage("standard_1020", on_missing="ignore")
|
||||
except Exception as e:
|
||||
print(f"[WARN] Failed to set standard_1020 montage: {e}")
|
||||
|
||||
sfreq = raw.info['sfreq']
|
||||
ch_names = raw.ch_names
|
||||
n_channels = len(ch_names)
|
||||
duration = raw.times[-1] - raw.times[0]
|
||||
|
||||
print(f"[INFO] Channels: {n_channels}, Duration: {duration:.2f}s, SFreq: {sfreq:.2f}Hz")
|
||||
|
||||
return raw, sfreq, ch_names
|
||||
|
||||
|
||||
def load_mat_file(mat_path: str) -> tuple:
|
||||
"""读取 .mat 格式文件,返回 MNE Raw 对象"""
|
||||
print(f"[INFO] Reading MAT file: {mat_path}")
|
||||
import scipy.io
|
||||
|
||||
mat = scipy.io.loadmat(mat_path)
|
||||
eeg = mat['eeg'][0, 0]
|
||||
|
||||
# 提取数据
|
||||
data = eeg['data'] # (T, C)
|
||||
if data.shape[0] < data.shape[1]:
|
||||
data = data.T # 确保是 (T, C)
|
||||
data = data.astype(np.float64) # 确保是 float
|
||||
|
||||
# 提取采样率
|
||||
sfreq = float(eeg['sample_rate'][0, 0])
|
||||
|
||||
# 提取通道名称
|
||||
ch_names_raw = eeg['electrode_name']
|
||||
if ch_names_raw.ndim == 2:
|
||||
ch_names = [str(ch[0]) if isinstance(ch[0], np.bytes_) else str(ch[0]) for ch in ch_names_raw[0]]
|
||||
else:
|
||||
ch_names = [f"EEG{i+1}" for i in range(data.shape[1])]
|
||||
|
||||
n_channels = data.shape[1]
|
||||
n_samples = data.shape[0]
|
||||
duration = n_samples / sfreq
|
||||
|
||||
print(f"[INFO] Channels: {n_channels}, Duration: {duration:.2f}s, SFreq: {sfreq:.2f}Hz")
|
||||
|
||||
# 创建 MNE Raw 对象
|
||||
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=["eeg"] * n_channels)
|
||||
raw = mne.io.RawArray(data.T, info, verbose=False) # (T, C) -> (C, T)
|
||||
|
||||
# 尝试设置通道位置
|
||||
try:
|
||||
electrode_xyz = eeg['electrode_xyz'] # (64, 3)
|
||||
if electrode_xyz.shape[0] == n_channels:
|
||||
ch_pos = {}
|
||||
for i, name in enumerate(ch_names):
|
||||
ch_pos[name] = electrode_xyz[i] / 1000.0 # 转换为米
|
||||
montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')
|
||||
info.set_montage(montage)
|
||||
print("[INFO] Applied electrode positions from mat file")
|
||||
else:
|
||||
raw.set_montage("standard_1020", on_missing="ignore")
|
||||
except Exception as e:
|
||||
print(f"[WARN] Failed to set montage from mat file: {e}")
|
||||
try:
|
||||
raw.set_montage("standard_1020", on_missing="ignore")
|
||||
except:
|
||||
pass
|
||||
|
||||
return raw, sfreq, ch_names
|
||||
|
||||
|
||||
# ==========================
|
||||
# 坏通道检测
|
||||
# ==========================
|
||||
def detect_bad_channels(raw: mne.io.RawArray, z_thresh: float = 3.0) -> list:
|
||||
"""检测坏通道:全零/常数通道 + MAD z-score 离群通道"""
|
||||
data = raw.get_data()
|
||||
ch_names = raw.ch_names
|
||||
bad_chs = []
|
||||
|
||||
ptp = np.ptp(data, axis=1)
|
||||
std = np.std(data, axis=1)
|
||||
|
||||
for i, (p, s) in enumerate(zip(ptp, std)):
|
||||
if p < 1e-12 or s < 1e-12:
|
||||
bad_chs.append(ch_names[i])
|
||||
|
||||
valid_mask = np.array([ch not in bad_chs for ch in ch_names])
|
||||
if valid_mask.sum() > 2:
|
||||
valid_ptp = ptp[valid_mask]
|
||||
med = np.median(valid_ptp)
|
||||
mad = np.median(np.abs(valid_ptp - med)) + 1e-30
|
||||
z = np.abs(ptp - med) / (mad * 1.4826)
|
||||
|
||||
for i, zv in enumerate(z):
|
||||
if zv > z_thresh and ch_names[i] not in bad_chs:
|
||||
bad_chs.append(ch_names[i])
|
||||
|
||||
if bad_chs:
|
||||
print(f"[INFO] Bad channels detected: {bad_chs}")
|
||||
else:
|
||||
print("[INFO] No bad channels detected")
|
||||
|
||||
return bad_chs
|
||||
|
||||
|
||||
# ==========================
|
||||
# 坏段标注
|
||||
# ==========================
|
||||
def annotate_bad_segments(raw: mne.io.RawArray, peak_to_peak_uv: float = 350.0):
|
||||
"""简单坏段检测:按1秒窗口计算峰峰值,超过阈值标为 bad"""
|
||||
peak_to_peak_v = peak_to_peak_uv * 1e-6
|
||||
win = int(raw.info["sfreq"] * 1.0)
|
||||
step = int(raw.info["sfreq"] * 0.5)
|
||||
data = raw.get_data()
|
||||
n_times = data.shape[1]
|
||||
|
||||
onsets = []
|
||||
durations = []
|
||||
|
||||
for start in range(0, n_times - win, step):
|
||||
seg = data[:, start:start + win]
|
||||
ptp = np.ptp(seg, axis=1)
|
||||
if np.any(ptp > peak_to_peak_v):
|
||||
onsets.append(start / raw.info["sfreq"])
|
||||
durations.append(win / raw.info["sfreq"])
|
||||
|
||||
if len(onsets) > 0:
|
||||
ann = mne.Annotations(onset=onsets, duration=durations, description=["BAD_SEG"] * len(onsets))
|
||||
raw.set_annotations(ann)
|
||||
print(f"[INFO] Annotated {len(onsets)} bad segments")
|
||||
|
||||
|
||||
# ==========================
|
||||
# 核心预处理函数
|
||||
# ==========================
|
||||
def preprocess_bdf(raw: mne.io.RawArray) -> mne.io.RawArray:
|
||||
"""BDF 数据预处理流程"""
|
||||
print("[INFO] Starting preprocessing pipeline...")
|
||||
|
||||
# 1) 裁剪首尾 2s
|
||||
crop_sec = 2.0
|
||||
t_start = crop_sec
|
||||
t_end = raw.times[-1] - crop_sec
|
||||
if t_end > t_start:
|
||||
raw = raw.crop(tmin=t_start, tmax=t_end)
|
||||
print(f"[INFO] Cropped: removed first/last {crop_sec}s")
|
||||
|
||||
# 2) 去直流偏置
|
||||
data = raw.get_data()
|
||||
data -= data.mean(axis=1, keepdims=True)
|
||||
raw._data = data
|
||||
print("[INFO] Removed DC offset")
|
||||
|
||||
# 3) 坏通道检测与插值
|
||||
bad_chs = detect_bad_channels(raw)
|
||||
if bad_chs:
|
||||
raw.info["bads"] = bad_chs
|
||||
try:
|
||||
raw_tmp = raw.copy()
|
||||
raw_tmp.set_montage(raw.get_montage(), on_missing="ignore")
|
||||
raw_tmp.interpolate_bads(reset_bads=True, verbose=False)
|
||||
raw = raw_tmp
|
||||
print(f"[INFO] Bad channels interpolated: {bad_chs}")
|
||||
except Exception as e:
|
||||
print(f"[WARN] Bad channel interpolation failed: {e}")
|
||||
raw.info["bads"] = []
|
||||
|
||||
# 4) 50Hz 陷波滤波
|
||||
print(f"[INFO] Applying notch filter: {NOTCH_FREQS}Hz")
|
||||
raw.notch_filter(NOTCH_FREQS, fir_design="firwin", verbose=False)
|
||||
|
||||
# 5) 0.8-40Hz 带通滤波 (使用 padding 消除边界振铃)
|
||||
print(f"[INFO] Applying bandpass filter: {BANDPASS_LOW}-{BANDPASS_HIGH}Hz (with {FILTER_PAD_SEC}s padding)")
|
||||
pad_sec = FILTER_PAD_SEC
|
||||
raw_length = raw.times[-1]
|
||||
pad_start = max(0, pad_sec)
|
||||
pad_end = max(0, pad_sec)
|
||||
|
||||
if raw_length > pad_start + pad_end + 1.0:
|
||||
raw.filter(BANDPASS_LOW, BANDPASS_HIGH, fir_design="firwin",
|
||||
pad='reflect', verbose=False)
|
||||
raw = raw.crop(tmin=pad_sec, tmax=raw_length - pad_sec)
|
||||
print(f"[INFO] Removed {pad_sec}s padding from each side after filtering")
|
||||
else:
|
||||
raw.filter(BANDPASS_LOW, BANDPASS_HIGH, fir_design="firwin", verbose=False)
|
||||
print(f"[WARN] Data too short ({raw_length:.1f}s), skipping padding")
|
||||
|
||||
# 6) 幅值过滤
|
||||
print(f"[INFO] Applying amplitude filter: [{AMPLITUDE_MIN_UV}, {AMPLITUDE_MAX_UV}] μV")
|
||||
amplitude_thresh_v = AMPLITUDE_MAX_UV * 1e-6
|
||||
d = raw.get_data()
|
||||
mask = np.abs(d) > amplitude_thresh_v
|
||||
n_clipped = int(mask.sum())
|
||||
if n_clipped > 0:
|
||||
d[mask] = 0.0
|
||||
raw._data = d
|
||||
print(f"[INFO] Amplitude clipping: {n_clipped} samples exceeded ±200μV, set to 0")
|
||||
|
||||
# 7) 坏段标注
|
||||
annotate_bad_segments(raw, peak_to_peak_uv=BAD_SEGMENT_THRESHOLD_UV)
|
||||
|
||||
# 8) ICA 去伪迹
|
||||
print("[INFO] Running ICA for artifact removal...")
|
||||
ica = ICA(n_components=ICA_N_COMPONENTS, random_state=ICA_RANDOM_STATE,
|
||||
max_iter=ICA_MAX_ITER, method="fastica")
|
||||
ica.fit(raw, reject_by_annotation=True, verbose=False)
|
||||
|
||||
try:
|
||||
eog_inds, _ = ica.find_bads_eog(raw, verbose=False)
|
||||
if eog_inds:
|
||||
ica.exclude.extend(eog_inds)
|
||||
print(f"[INFO] ICA exclude EOG components: {eog_inds}")
|
||||
except Exception as e:
|
||||
print(f"[WARN] ICA EOG detection skipped: {e}")
|
||||
|
||||
raw_clean = ica.apply(raw.copy(), verbose=False)
|
||||
|
||||
# 9) ICA 后再次去直流
|
||||
d = raw_clean.get_data()
|
||||
d -= d.mean(axis=1, keepdims=True)
|
||||
raw_clean._data = d
|
||||
|
||||
print("[INFO] Preprocessing completed")
|
||||
return raw_clean
|
||||
|
||||
|
||||
# ==========================
|
||||
# 输出目录管理
|
||||
# ==========================
|
||||
def ensure_outdir(out_root: str) -> str:
|
||||
"""确保输出目录存在,并清空旧文件(保留 ResultData.txt)"""
|
||||
if os.path.exists(out_root):
|
||||
for filename in os.listdir(out_root):
|
||||
if filename == "ResultData.txt":
|
||||
continue
|
||||
file_path = os.path.join(out_root, filename)
|
||||
try:
|
||||
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||
os.unlink(file_path)
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path)
|
||||
except Exception as e:
|
||||
print(f"[WARN] Failed to delete {file_path}: {e}")
|
||||
else:
|
||||
os.makedirs(out_root, exist_ok=True)
|
||||
return out_root
|
||||
|
||||
|
||||
# ==========================
|
||||
# 通道分区
|
||||
# ==========================
|
||||
def _norm_name(s: str) -> str:
|
||||
return str(s).strip().upper().replace(" ", "")
|
||||
|
||||
|
||||
def build_channel_index_map(ch_names, n_channels: int):
|
||||
if not ch_names or len(ch_names) != n_channels:
|
||||
return {}
|
||||
return {_norm_name(nm): i for i, nm in enumerate(ch_names)}
|
||||
|
||||
|
||||
def pick_indices_by_names(name_to_idx, names):
|
||||
idx = []
|
||||
for n in names:
|
||||
nn = _norm_name(n)
|
||||
if nn in name_to_idx:
|
||||
idx.append(name_to_idx[nn])
|
||||
return sorted(list(set(idx)))
|
||||
|
||||
|
||||
def _fallback_region_indices(n_channels: int):
|
||||
a = int(n_channels * 0.33)
|
||||
b = int(n_channels * 0.66)
|
||||
return (
|
||||
list(range(0, a)), # frontal
|
||||
list(range(a, b)), # central
|
||||
list(range(b, n_channels)), # parietal
|
||||
list(range(0, max(2, a // 2))), # prefrontal
|
||||
list(range(b, n_channels)), # posterior
|
||||
[i for i in range(n_channels) if i % 2 == 0], # left
|
||||
[i for i in range(n_channels) if i % 2 == 1], # right
|
||||
)
|
||||
|
||||
|
||||
def get_region_indices(name_to_idx, n_channels: int):
|
||||
if not name_to_idx:
|
||||
return _fallback_region_indices(n_channels)
|
||||
|
||||
central_names = ["CZ","C1","C2","C3","C4","C5","C6","CP1","CP2","CP3","CP4","CP5","CP6","FC1","FC2","FC3","FC4","FC5","FC6"]
|
||||
frontal_names = ["FZ","F1","F2","F3","F4","F5","F6","F7","F8","AF3","AF4","AF7","AF8","FPZ","FP1","FP2","FCZ"]
|
||||
parietal_names = ["PZ","P1","P2","P3","P4","P5","P6","POZ","PO3","PO4","PO5","PO6","PO7","PO8","CPZ"]
|
||||
prefrontal_names = ["FP1","FP2","FPZ","AF3","AF4","AF7","AF8"]
|
||||
posterior_names = ["O1","O2","OZ","PO7","PO8","PO3","PO4","PZ","P3","P4","P1","P2"]
|
||||
left_names = ["FP1","AF3","AF7","F3","F5","F7"]
|
||||
right_names = ["FP2","AF4","AF8","F4","F6","F8"]
|
||||
|
||||
central = pick_indices_by_names(name_to_idx, central_names)
|
||||
frontal = pick_indices_by_names(name_to_idx, frontal_names)
|
||||
parietal = pick_indices_by_names(name_to_idx, parietal_names)
|
||||
prefrontal = pick_indices_by_names(name_to_idx, prefrontal_names)
|
||||
posterior = pick_indices_by_names(name_to_idx, posterior_names)
|
||||
left = pick_indices_by_names(name_to_idx, left_names)
|
||||
right = pick_indices_by_names(name_to_idx, right_names)
|
||||
|
||||
if not (central and frontal and parietal and prefrontal and posterior):
|
||||
fb = _fallback_region_indices(n_channels)
|
||||
frontal2, central2, parietal2, prefrontal2, posterior2, left2, right2 = fb
|
||||
frontal = frontal if frontal else frontal2
|
||||
central = central if central else central2
|
||||
parietal = parietal if parietal else parietal2
|
||||
prefrontal = prefrontal if prefrontal else prefrontal2
|
||||
posterior = posterior if posterior else posterior2
|
||||
left = left if left else left2
|
||||
right = right if right else right2
|
||||
|
||||
return frontal, central, parietal, prefrontal, posterior, left, right
|
||||
|
||||
|
||||
# ==========================
|
||||
# PSD 和频段功率计算
|
||||
# ==========================
|
||||
def welch_psd(eeg_tc, fs):
|
||||
"""计算 PSD"""
|
||||
nperseg = min(PSD_NPERSEG, eeg_tc.shape[0])
|
||||
noverlap = int(nperseg * 0.75)
|
||||
freqs, pxx = signal.welch(
|
||||
eeg_tc, fs=fs, nperseg=nperseg, noverlap=noverlap,
|
||||
axis=0, scaling="density"
|
||||
)
|
||||
return freqs, pxx
|
||||
|
||||
|
||||
def band_power_from_psd(freqs, pxx_fc, band):
|
||||
lo, hi = band
|
||||
m = (freqs >= lo) & (freqs < hi)
|
||||
if not np.any(m):
|
||||
return np.zeros((pxx_fc.shape[1],), dtype=np.float32)
|
||||
|
||||
return np.trapz(pxx_fc[m, :], freqs[m], axis=0).astype(np.float32)
|
||||
|
||||
|
||||
def region_mean_power(freqs, pxx_fc, idx, band) -> float:
|
||||
if not idx:
|
||||
return 0.0
|
||||
pw = band_power_from_psd(freqs, pxx_fc, band)
|
||||
return float(np.mean(pw[idx]))
|
||||
|
||||
|
||||
def compute_iaf(freqs, pxx_fc, posterior_idx):
|
||||
lo, hi = BANDS_METRICS["Alpha"]
|
||||
m = (freqs >= lo) & (freqs <= hi)
|
||||
if not np.any(m) or not posterior_idx:
|
||||
return 0.0
|
||||
spec = np.mean(pxx_fc[:, posterior_idx], axis=1)
|
||||
sub = spec[m]
|
||||
fsub = freqs[m]
|
||||
return float(fsub[int(np.argmax(sub))])
|
||||
|
||||
|
||||
# ==========================
|
||||
# 画图函数
|
||||
# ==========================
|
||||
def plot_eeg_waveforms(data_uv_tc, fs, ch_names, out_dir, seconds=10, t_start_sec=30.0):
|
||||
"""画 EEG 波形图(固定通道)"""
|
||||
T, C = data_uv_tc.shape
|
||||
|
||||
start_sample = int(t_start_sec * fs)
|
||||
end_sample = int(min(T, start_sample + seconds * fs))
|
||||
|
||||
if start_sample >= T:
|
||||
start_sample = max(0, T - int(seconds * fs))
|
||||
end_sample = T
|
||||
print(f"[WARN] t_start_sec={t_start_sec}s exceeds data, using last {seconds}s")
|
||||
|
||||
seg_samples = end_sample - start_sample
|
||||
x = np.arange(seg_samples) / fs + t_start_sec
|
||||
|
||||
# 过滤有效索引
|
||||
idxs = [i for i in FIXED_EEG_IDXS if 0 <= i < C]
|
||||
if len(idxs) < len(FIXED_EEG_IDXS):
|
||||
missing = [i for i in FIXED_EEG_IDXS if not (0 <= i < C)]
|
||||
print(f"[WARN] Some indices out of range (C={C}): {missing}")
|
||||
|
||||
if len(idxs) == 0:
|
||||
raise RuntimeError(f"No valid indices for data (C={C})")
|
||||
|
||||
picked_names = []
|
||||
for idx in idxs:
|
||||
pos = FIXED_EEG_IDXS.index(idx)
|
||||
std_label = FIXED_EEG_LABELS[pos] if pos < len(FIXED_EEG_LABELS) else f"CH{idx}"
|
||||
if ch_names and idx < len(ch_names):
|
||||
picked_names.append(std_label)
|
||||
else:
|
||||
picked_names.append(std_label)
|
||||
|
||||
fig_h = 1.4 * len(idxs) + 1
|
||||
fig, axes = plt.subplots(len(idxs), 1, figsize=(10, fig_h), sharex=True)
|
||||
if len(idxs) == 1:
|
||||
axes = [axes]
|
||||
|
||||
seg = data_uv_tc[start_sample:end_sample, idxs].T
|
||||
lo = float(np.percentile(seg, 1))
|
||||
hi = float(np.percentile(seg, 99))
|
||||
m = max(abs(lo), abs(hi), 50.0)
|
||||
|
||||
for ax, ch_idx, nm in zip(axes, idxs, picked_names):
|
||||
y = data_uv_tc[start_sample:end_sample, ch_idx]
|
||||
ax.plot(x, y, linewidth=1.2)
|
||||
ax.set_ylabel("uV")
|
||||
ax.set_title(nm, loc="left", fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.set_ylim(-m, m)
|
||||
|
||||
axes[-1].set_xlabel("Time (s)")
|
||||
plt.tight_layout()
|
||||
|
||||
out_path = os.path.join(out_dir, "EEG.png")
|
||||
plt.savefig(out_path, dpi=200)
|
||||
plt.close(fig)
|
||||
print(f"[OK] EEG waveform saved: {out_path}")
|
||||
|
||||
|
||||
def plot_psd(eeg_uV_tc, fs, ch_names, out_dir):
|
||||
"""画 PSD 图"""
|
||||
C = eeg_uV_tc.shape[1]
|
||||
chosen_idx = []
|
||||
|
||||
if ch_names:
|
||||
mp = {n.upper(): i for i, n in enumerate(ch_names)}
|
||||
for p in ["C3", "C4", "CZ"]:
|
||||
if p in mp:
|
||||
chosen_idx.append(mp[p])
|
||||
if len(chosen_idx) < 3:
|
||||
stds = [(i, float(np.std(eeg_uV_tc[:, i]))) for i in range(C)]
|
||||
stds.sort(key=lambda x: x[1], reverse=True)
|
||||
for i, _ in stds:
|
||||
if i not in chosen_idx:
|
||||
chosen_idx.append(i)
|
||||
if len(chosen_idx) == 3:
|
||||
break
|
||||
chosen_name = [ch_names[i] for i in chosen_idx]
|
||||
else:
|
||||
stds = [(i, float(np.std(eeg_uV_tc[:, i]))) for i in range(C)]
|
||||
stds.sort(key=lambda x: x[1], reverse=True)
|
||||
chosen_idx = [i for i, _ in stds[:3]]
|
||||
chosen_name = [f"CH{i}" for i in chosen_idx]
|
||||
|
||||
# 增大 nperseg 提高频率分辨率
|
||||
nperseg = min(PSD_NPERSEG, eeg_uV_tc.shape[0])
|
||||
noverlap = int(nperseg * 0.75)
|
||||
|
||||
fig = plt.figure(figsize=(7.5, 4.8))
|
||||
for idx, nm in zip(chosen_idx, chosen_name):
|
||||
f, pxx = signal.welch(eeg_uV_tc[:, idx], fs=fs, nperseg=nperseg, noverlap=noverlap)
|
||||
mask = (f >= PSD_FMIN) & (f <= PSD_FMAX)
|
||||
p_db = 10 * np.log10(pxx[mask] + 1e-20)
|
||||
plt.plot(f[mask], p_db, linewidth=1.8, label=nm)
|
||||
|
||||
plt.xlabel("Hz")
|
||||
plt.ylabel("Power (dB)")
|
||||
plt.title("PSD")
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
|
||||
out_path = os.path.join(out_dir, "psd.png")
|
||||
plt.savefig(out_path, dpi=200)
|
||||
plt.close(fig)
|
||||
print(f"[OK] PSD saved: {out_path}")
|
||||
|
||||
|
||||
def _get_standard_1020_channel_indices(raw):
|
||||
"""获取符合 standard_1020 montage 的通道索引和名称"""
|
||||
try:
|
||||
standard_montage = mne.channels.make_standard_montage("standard_1020")
|
||||
standard_names_upper = {ch.upper() for ch in standard_montage.ch_names}
|
||||
standard_name_map = {ch.upper(): ch for ch in standard_montage.ch_names}
|
||||
data_ch_names = raw.ch_names
|
||||
exclude_names = {"A1", "A2", "M1", "M2", "LE", "RE", "LM", "RM"}
|
||||
|
||||
valid_indices = []
|
||||
valid_names = []
|
||||
for i, name in enumerate(data_ch_names):
|
||||
name_upper = name.upper()
|
||||
if name_upper in standard_names_upper and name_upper not in exclude_names:
|
||||
valid_indices.append(i)
|
||||
valid_names.append(standard_name_map[name_upper])
|
||||
|
||||
print(f"[INFO] Found {len(valid_indices)}/{len(data_ch_names)} channels matching standard_1020")
|
||||
return valid_indices, valid_names
|
||||
except Exception as e:
|
||||
print(f"[WARN] Failed to get standard_1020 channels: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
def compute_band_powers_for_topomap(raw, bands):
|
||||
"""计算各频段功率,只使用 standard_1020 中有位置的通道"""
|
||||
# 获取 standard_1020 montage 和位置信息
|
||||
standard_montage = mne.channels.make_standard_montage("standard_1020")
|
||||
std_names_upper = {ch.upper() for ch in standard_montage.ch_names}
|
||||
ch_pos_map = standard_montage.get_positions()['ch_pos']
|
||||
|
||||
data_ch_names = raw.ch_names
|
||||
exclude = {"A1", "A2", "M1", "M2", "LE", "RE", "LM", "RM"}
|
||||
|
||||
# 只保留有位置信息的通道
|
||||
valid_indices = []
|
||||
valid_names = []
|
||||
for i, name in enumerate(data_ch_names):
|
||||
name_upper = name.upper()
|
||||
if name_upper in std_names_upper and name_upper not in exclude:
|
||||
if name_upper in ch_pos_map: # 必须有位置
|
||||
valid_indices.append(i)
|
||||
valid_names.append(name)
|
||||
|
||||
if len(valid_indices) < 8:
|
||||
return None
|
||||
|
||||
data = raw.get_data()
|
||||
data_standard = data[valid_indices, :]
|
||||
|
||||
fs = raw.info["sfreq"]
|
||||
n_fft = min(PSD_NPERSEG, data_standard.shape[1])
|
||||
n_overlap = int(n_fft * 0.75)
|
||||
|
||||
psds, freqs = mne.time_frequency.psd_array_welch(
|
||||
data_standard, sfreq=fs,
|
||||
fmin=min(v[0] for v in bands.values()),
|
||||
fmax=max(v[1] for v in bands.values()),
|
||||
n_fft=n_fft, n_overlap=n_overlap,
|
||||
average="mean", verbose=False
|
||||
)
|
||||
|
||||
out = {"_valid_names": valid_names}
|
||||
print(f"[DEBUG] PSD: fs={fs}Hz, n_fft={n_fft}, freq_res={fs/n_fft:.3f}Hz/bin")
|
||||
for k, (fmin, fmax) in bands.items():
|
||||
idx = np.where((freqs >= fmin) & (freqs < fmax))[0]
|
||||
if len(idx) == 0:
|
||||
out[k] = np.zeros(len(valid_indices), dtype=np.float32)
|
||||
print(f"[DEBUG] {k.upper()}: NO freq bins in [{fmin}-{fmax}]Hz")
|
||||
continue
|
||||
print(f"[DEBUG] {k.upper()}: freq bins {freqs[idx[0]]:.2f}-{freqs[idx[-1]]:.2f}Hz (bins {idx[0]}-{idx[-1]}, count={len(idx)})")
|
||||
# 使用线性功率值 (V^2 -> uV^2: * 1e12)
|
||||
bp = np.trapz(psds[:, idx], freqs[idx], axis=1) * 1e12
|
||||
out[k] = bp
|
||||
print(f"[DEBUG] {k.upper()}: power range [{bp.min():.4f}, {bp.max():.4f}] uV^2, mean={bp.mean():.4f}")
|
||||
|
||||
print(f"[INFO] Band powers computed for {len(valid_names)} channels with positions")
|
||||
return out
|
||||
|
||||
|
||||
def _create_topomap_raw(ch_names):
|
||||
"""创建只有 standard_1020 通道位置信息的临时 Raw 对象"""
|
||||
standard_montage = mne.channels.make_standard_montage("standard_1020")
|
||||
ch_pos_map = standard_montage.get_positions()['ch_pos']
|
||||
|
||||
valid_ch_names = []
|
||||
valid_positions = []
|
||||
for name in ch_names:
|
||||
name_upper = name.upper()
|
||||
if name_upper in ch_pos_map:
|
||||
valid_ch_names.append(name)
|
||||
valid_positions.append(ch_pos_map[name_upper])
|
||||
|
||||
if len(valid_ch_names) < 8:
|
||||
return None
|
||||
|
||||
ch_pos = {name: pos for name, pos in zip(valid_ch_names, valid_positions)}
|
||||
montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')
|
||||
|
||||
info = mne.create_info(ch_names=valid_ch_names, sfreq=250.0, ch_types=["eeg"] * len(valid_ch_names))
|
||||
info.set_montage(montage)
|
||||
|
||||
dummy_data = np.zeros((len(valid_ch_names), 1))
|
||||
return mne.io.RawArray(dummy_data, info, verbose=False)
|
||||
|
||||
|
||||
def plot_average_topomap(band_values, out_dir):
|
||||
"""绘制平均拓扑图"""
|
||||
valid_names = band_values.get("_valid_names", [])
|
||||
if not valid_names:
|
||||
return
|
||||
|
||||
values = band_values["broad"]
|
||||
temp_raw = _create_topomap_raw(valid_names)
|
||||
if temp_raw is None:
|
||||
return
|
||||
|
||||
vmin, vmax = _compute_topomap_vlim([values])
|
||||
|
||||
fig, ax = plt.subplots(1, 1, figsize=(6.5, 4.6))
|
||||
im, _ = mne.viz.plot_topomap(
|
||||
values, temp_raw.info, axes=ax, show=False, contours=0,
|
||||
sphere=(0, 0, 0, TOPOMAP_SPHERE_RADIUS), extrapolate='head', border='mean',
|
||||
cmap='turbo'
|
||||
)
|
||||
im.set_clim(vmin=vmin, vmax=vmax)
|
||||
ax.set_title("0.8-30 Hz", fontsize=12)
|
||||
plt.colorbar(im, ax=ax, shrink=0.85)
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(out_dir, "average_topomap.png"), dpi=200)
|
||||
plt.close(fig)
|
||||
print(f"[OK] average_topomap saved")
|
||||
|
||||
|
||||
def plot_band_topomaps(band_values, out_dir):
|
||||
"""绘制分频段拓扑图"""
|
||||
valid_names = band_values.get("_valid_names", [])
|
||||
if not valid_names:
|
||||
return
|
||||
|
||||
order = [
|
||||
("delta", "δ (1-4Hz)"),
|
||||
("theta", "θ (4-8Hz)"),
|
||||
("alpha", "α (8-13Hz)"),
|
||||
("beta", "β (13-30Hz)"),
|
||||
("broad", "1-30 Hz"),
|
||||
]
|
||||
|
||||
temp_raw = _create_topomap_raw(valid_names)
|
||||
if temp_raw is None:
|
||||
return
|
||||
|
||||
all_values = [band_values[k] for k, _ in order]
|
||||
vmin, vmax = _compute_topomap_vlim(all_values)
|
||||
|
||||
fig, axes = plt.subplots(1, 5, figsize=(16, 4.2))
|
||||
ims = []
|
||||
for ax, (k, title) in zip(axes, order):
|
||||
im, _ = mne.viz.plot_topomap(
|
||||
band_values[k], temp_raw.info, axes=ax, show=False, contours=0,
|
||||
sphere=(0, 0, 0, TOPOMAP_SPHERE_RADIUS), extrapolate='head', border='mean',
|
||||
cmap='turbo'
|
||||
)
|
||||
im.set_clim(vmin=vmin, vmax=vmax)
|
||||
ax.set_title(title, fontsize=11)
|
||||
ims.append(im)
|
||||
fig.subplots_adjust(left=0.02, right=0.85, top=0.88, bottom=0.05, wspace=0.35)
|
||||
cax = fig.add_axes([0.87, 0.15, 0.015, 0.7])
|
||||
fig.colorbar(ims[-1], cax=cax)
|
||||
plt.savefig(os.path.join(out_dir, "topomaps.png"), dpi=200)
|
||||
plt.close(fig)
|
||||
print(f"[OK] topomaps saved")
|
||||
|
||||
|
||||
def _compute_topomap_vlim(values):
|
||||
"""计算脑地形图颜色范围"""
|
||||
v_all = np.concatenate(values) if isinstance(values, list) else np.array(values)
|
||||
if TOPOMAP_VMIN is not None and TOPOMAP_VMAX is not None:
|
||||
return TOPOMAP_VMAX - 60, TOPOMAP_VMAX # 保持 50 的范围
|
||||
if TOPOMAP_SYM_SCALE is not None and TOPOMAP_SYM_SCALE > 0:
|
||||
mean_val = np.mean(v_all)
|
||||
std_val = np.std(v_all)
|
||||
return mean_val - std_val * TOPOMAP_SYM_SCALE, mean_val + std_val * TOPOMAP_SYM_SCALE
|
||||
# 统一 vmax:使用所有频段中的最大值
|
||||
# vmin = 0:这样低功率频段会接近 0(白色/冷色),高功率频段突出
|
||||
vmax = np.max(v_all)
|
||||
vmin = 0
|
||||
return vmin, vmax
|
||||
|
||||
|
||||
# ==========================
|
||||
# 预测接口
|
||||
# ==========================
|
||||
def _predict_label_by_model(model_path: str, data_path: str) -> dict:
|
||||
"""调用 infer_pth.py 进行预测"""
|
||||
try:
|
||||
from infer_pth import predict_hc_mdd
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法导入 predict_hc_mdd: {e}")
|
||||
|
||||
import tempfile
|
||||
import scipy.io
|
||||
|
||||
ext = os.path.splitext(data_path)[1].lower()
|
||||
|
||||
if ext == ".mat":
|
||||
# 直接使用 mat 文件
|
||||
result = predict_hc_mdd(os.path.dirname(data_path), model_path)
|
||||
elif ext == ".bdf":
|
||||
# 转换为 mat 格式
|
||||
raw = mne.io.read_raw_bdf(data_path, preload=True, verbose=False)
|
||||
data, times = raw[:]
|
||||
sfreq = raw.info['sfreq']
|
||||
ch_names = raw.ch_names
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
mat_path = os.path.join(temp_dir, "preprocessed_eeg.mat")
|
||||
scipy.io.savemat(mat_path, {
|
||||
'eeg': {
|
||||
'data': (data * 1e6).T,
|
||||
'sample_rate': sfreq,
|
||||
'electrode_name': ch_names
|
||||
}
|
||||
})
|
||||
result = predict_hc_mdd(temp_dir, model_path)
|
||||
else:
|
||||
raise ValueError(f"不支持的文件格式: {ext}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ==========================
|
||||
# 生成 ResultData.txt
|
||||
# ==========================
|
||||
def compute_and_save_txt(model_path, bdf_path, out_dir, eeg_uV_tc, fs, ch_names):
|
||||
"""计算特征指标并保存 ResultData.txt"""
|
||||
# 获取预测结果
|
||||
pred_result = _predict_label_by_model(model_path, bdf_path)
|
||||
pred_label = pred_result.get("pred_label", "UNKNOWN")
|
||||
recommend = "是" if pred_label == "MDD" else "否"
|
||||
|
||||
T, C = eeg_uV_tc.shape
|
||||
mp = build_channel_index_map(ch_names, C)
|
||||
frontal_idx, central_idx, parietal_idx, prefrontal_idx, posterior_idx, left_idx, right_idx = \
|
||||
get_region_indices(mp, C)
|
||||
|
||||
freqs, pxx = welch_psd(eeg_uV_tc, fs)
|
||||
|
||||
# 计算各频段功率比
|
||||
central_alpha = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Alpha"])
|
||||
central_beta = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Beta"])
|
||||
frontal_alpha = region_mean_power(freqs, pxx, frontal_idx, BANDS_METRICS["Alpha"])
|
||||
frontal_beta = region_mean_power(freqs, pxx, frontal_idx, BANDS_METRICS["Beta"])
|
||||
par_alpha = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Alpha"])
|
||||
par_beta = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Beta"])
|
||||
|
||||
central_ab = (central_alpha / (central_beta + EPS)) if central_beta > 0 else 0.0
|
||||
frontal_ab = (frontal_alpha / (frontal_beta + EPS)) if frontal_beta > 0 else 0.0
|
||||
par_ab = (par_alpha / (par_beta + EPS)) if par_beta > 0 else 0.0
|
||||
|
||||
central_theta = region_mean_power(freqs, pxx, central_idx, BANDS_METRICS["Theta"])
|
||||
par_theta = region_mean_power(freqs, pxx, parietal_idx, BANDS_METRICS["Theta"])
|
||||
central_tb = (central_theta / (central_beta + EPS)) if central_beta > 0 else 0.0
|
||||
par_tb = (par_theta / (par_beta + EPS)) if par_beta > 0 else 0.0
|
||||
|
||||
if not left_idx or not right_idx:
|
||||
left_idx = [i for i in prefrontal_idx if (i % 2 == 0)]
|
||||
right_idx = [i for i in prefrontal_idx if (i % 2 == 1)]
|
||||
left_alpha = region_mean_power(freqs, pxx, left_idx, BANDS_METRICS["Alpha"])
|
||||
right_alpha = region_mean_power(freqs, pxx, right_idx, BANDS_METRICS["Alpha"])
|
||||
prefrontal_alpha_asym = float(np.log(right_alpha + EPS) - np.log(left_alpha + EPS))
|
||||
|
||||
iaf = compute_iaf(freqs, pxx, posterior_idx)
|
||||
|
||||
pre_td = region_mean_power(freqs, pxx, prefrontal_idx,
|
||||
(BANDS_METRICS["Delta"][0], BANDS_METRICS["Theta"][1]))
|
||||
pre_total = region_mean_power(freqs, pxx, prefrontal_idx, TOTAL_POWER_BAND)
|
||||
pre_td_rel = (pre_td / (pre_total + EPS)) * 100.0 if pre_total > 0 else 0.0
|
||||
|
||||
def f1(x): return f"{x:.1f}"
|
||||
|
||||
txt = (
|
||||
f"中央区α/β波比值:{f1(central_ab)}\n"
|
||||
f"额区α/β波比值:{f1(frontal_ab)}\n"
|
||||
f"顶区α/β波比值:{f1(par_ab)}\n"
|
||||
f"中央区θ/β波比值:{f1(central_tb)}\n"
|
||||
f"顶区θ/β波比值:{f1(par_tb)}\n"
|
||||
f"前额叶α波不对称性:{f1(prefrontal_alpha_asym)}\n"
|
||||
f"个体化α峰值频率:{f1(iaf)}\n"
|
||||
f"前额叶θ+δ波功率:{f1(pre_td_rel)}\n"
|
||||
f"是否推荐治疗:{recommend}\n"
|
||||
)
|
||||
|
||||
out_path = os.path.join(out_dir, "ResultData.txt")
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
f.write(txt)
|
||||
print(f"[OK] ResultData.txt saved: {out_path}")
|
||||
|
||||
# 打印预测结果
|
||||
print(f"\n========== 预测结果 ==========")
|
||||
print(f"预测标签: {pred_label}")
|
||||
print(f"p(MDD)均值: {pred_result.get('p_mdd_mean', 'N/A'):.4f}")
|
||||
print(f"切片数量: {pred_result.get('n_slices', 'N/A')}")
|
||||
print(f"==============================\n")
|
||||
|
||||
|
||||
# ==========================
|
||||
# 主函数
|
||||
# ==========================
|
||||
def run_all(model_path: str, bdf_dir: str, out_root: str, seconds: int = EEG_PLOT_SECONDS):
|
||||
"""主流程"""
|
||||
if not os.path.exists(bdf_dir):
|
||||
raise RuntimeError(f"输入目录不存在: {bdf_dir}")
|
||||
|
||||
# 支持 .bdf 和 .mat 文件
|
||||
data_files = [f for f in os.listdir(bdf_dir) if f.lower().endswith((".bdf", ".mat"))]
|
||||
if not data_files:
|
||||
raise RuntimeError(f"目录中找不到 .bdf 或 .mat 文件: {bdf_dir}")
|
||||
|
||||
data_files.sort()
|
||||
data_path = os.path.join(bdf_dir, data_files[0])
|
||||
print(f"[INFO] Processing file: {data_path}")
|
||||
|
||||
out_dir = ensure_outdir(out_root)
|
||||
print(f"[INFO] Output directory: {out_dir}")
|
||||
|
||||
raw, sfreq, ch_names = load_data_file(data_path)
|
||||
|
||||
raw_clean = preprocess_bdf(raw)
|
||||
|
||||
try:
|
||||
raw_clean.set_montage("standard_1020", on_missing="ignore")
|
||||
except Exception as e:
|
||||
print(f"[WARN] Failed to re-apply montage: {e}")
|
||||
|
||||
raw_data = raw_clean.get_data()
|
||||
eeg_uV_tc = (raw_data * 1e6).T.astype(np.float32)
|
||||
print(f"[INFO] Preprocessed EEG shape: {eeg_uV_tc.shape}")
|
||||
|
||||
print("[INFO] Generating figures...")
|
||||
plot_psd(eeg_uV_tc, sfreq, ch_names, out_dir)
|
||||
plot_eeg_waveforms(eeg_uV_tc, sfreq, ch_names, out_dir, seconds=seconds)
|
||||
|
||||
print("[INFO] Generating topomaps...")
|
||||
try:
|
||||
band_vals = compute_band_powers_for_topomap(raw_clean, BANDS_TOPOMAP)
|
||||
if band_vals is not None:
|
||||
plot_average_topomap(band_vals, out_dir)
|
||||
plot_band_topomaps(band_vals, out_dir)
|
||||
except Exception as e:
|
||||
print(f"[WARN] Topomap generation failed: {e}")
|
||||
|
||||
print("[INFO] Running prediction...")
|
||||
compute_and_save_txt(model_path, data_path, out_dir, eeg_uV_tc, sfreq, ch_names)
|
||||
|
||||
print("[DONE] All tasks completed.")
|
||||
return out_dir
|
||||
|
||||
|
||||
# ==========================
|
||||
# 命令行入口
|
||||
# ==========================
|
||||
if __name__ == "__main__":
|
||||
import multiprocessing
|
||||
multiprocessing.freeze_support()
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
def get_resource_path(relative_path):
|
||||
"""获取资源绝对路径"""
|
||||
if getattr(sys, 'frozen', False):
|
||||
base_path = os.path.dirname(sys.executable)
|
||||
else:
|
||||
base_path = os.path.dirname(os.path.abspath(__file__))
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
# 默认路径
|
||||
DEFAULT_MODEL = get_resource_path(os.path.join("model", "Model_1.pth"))
|
||||
if getattr(sys, 'frozen', False):
|
||||
EXE_DIR = os.path.dirname(sys.executable)
|
||||
else:
|
||||
EXE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
DEFAULT_BDF_DIR = os.path.join(EXE_DIR, "raw_data")
|
||||
DEFAULT_OUT = os.path.join(EXE_DIR, "out")
|
||||
|
||||
parser = argparse.ArgumentParser(description="EEG Depression Assessment")
|
||||
parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL, help="模型文件路径 (.pth)")
|
||||
parser.add_argument("--bdf_dir", type=str, default=DEFAULT_BDF_DIR, help="输入文件夹路径 (包含 .bdf 或 .mat 文件)")
|
||||
parser.add_argument("--out_root", type=str, default=DEFAULT_OUT, help="结果输出目录")
|
||||
parser.add_argument("--seconds", type=int, default=10, help="画波形图的秒数")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"[*] 运行配置:")
|
||||
print(f" - Model : {args.model_path}")
|
||||
print(f" - Input : {args.bdf_dir}")
|
||||
print(f" - Output: {args.out_root}")
|
||||
|
||||
if not os.path.exists(args.bdf_dir):
|
||||
print(f"[ERROR] 输入目录不存在: {args.bdf_dir}")
|
||||
if not os.path.exists(args.model_path):
|
||||
print(f"[ERROR] 模型文件不存在: {args.model_path}")
|
||||
|
||||
run_all(args.model_path, args.bdf_dir, args.out_root, seconds=args.seconds)
|
||||
Reference in New Issue
Block a user