init commit
This commit is contained in:
55
.gitignore
vendored
Normal file
55
.gitignore
vendored
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# data format
|
||||||
|
*.dat
|
||||||
|
*.csv
|
||||||
|
*.edf
|
||||||
|
*.event
|
||||||
|
*.edf.event
|
||||||
|
*.zip
|
||||||
|
*.xlsx
|
||||||
|
*.mat
|
||||||
|
*.json
|
||||||
|
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself
|
||||||
|
# https://github.com/github/gitignore/blob/main/Python.gitignore
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# Other common ignores
|
||||||
|
node_modules/
|
||||||
|
dist/
|
||||||
|
tmp/
|
||||||
|
temp/
|
||||||
|
|
||||||
|
# Project-specific ignores
|
||||||
|
# Ignore all directories in the root
|
||||||
|
# merge64ch_0127/
|
||||||
|
/P300_speller/braindecode/
|
||||||
|
/P300_speller/data/
|
||||||
|
/P300_speller/pyRiemann/
|
||||||
|
/P300_speller/README/
|
||||||
|
/merge64ch_new/
|
||||||
|
/merge64ch_tianjinZMQdebug/
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
417
Debug_64ch_Decoder/Decoder.py
Normal file
417
Debug_64ch_Decoder/Decoder.py
Normal file
@@ -0,0 +1,417 @@
|
|||||||
|
import ast
|
||||||
|
import threading
|
||||||
|
from datetime import datetime
|
||||||
|
import multiprocessing as mp
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from queue import Empty
|
||||||
|
from scipy import signal
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from Device.SunnyLinker import SunnyLinker64
|
||||||
|
from MI.Algorithm.otherModels import weights_init
|
||||||
|
from SSMVEP.algorithm.tdca import TDCA
|
||||||
|
from SSMVEP.algorithm.base import generate_cca_references
|
||||||
|
from Zmq.zmqServer import zmqServer
|
||||||
|
from Zmq.zmqClient import zmqClient
|
||||||
|
from MI.Algorithm.conformer_2class import onlineTrain
|
||||||
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
from SSVEP.dwfbcca import FbccaDw
|
||||||
|
from Tools.plot_MI_EEG import plotMain
|
||||||
|
|
||||||
|
class Decoder_main(threading.Thread):
|
||||||
|
def __init__(self):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.Runing=True
|
||||||
|
self.decoder = None
|
||||||
|
|
||||||
|
self.fs = 250 # 采样率
|
||||||
|
self.energy = 0 # 电量
|
||||||
|
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
||||||
|
self.decoder_class = None #解码器类别
|
||||||
|
|
||||||
|
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
|
||||||
|
def connect(self):
|
||||||
|
self.thread_data_server = SunnyLinker64('127.0.0.1', 7878, 250, 64,
|
||||||
|
method='tcp')
|
||||||
|
self.thread_data_server.toUv = True
|
||||||
|
self.thread_data_server.start()
|
||||||
|
|
||||||
|
self.zmqServer = zmqServer()
|
||||||
|
self.zmqServer.start()
|
||||||
|
self.zmqClient = zmqClient('127.0.0.1', 8088)
|
||||||
|
self.zmqClient.connect()
|
||||||
|
|
||||||
|
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
||||||
|
# data: (chans, samples)
|
||||||
|
energy = np.mean(np.var(data, axis=1)) # 各通道方差均值
|
||||||
|
if energy > threshold:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
def init_Decoder(self,decoder_class):
|
||||||
|
'''
|
||||||
|
初始化解码器
|
||||||
|
:param decoder_class: 'ssvep' or 'ssmvep' or 'mi'
|
||||||
|
:return:
|
||||||
|
'''
|
||||||
|
self.decoder_class = decoder_class
|
||||||
|
if decoder_class == 'ssvep':
|
||||||
|
self.n_chan = 8
|
||||||
|
self.thread_data_server.interval_inited = False
|
||||||
|
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
|
||||||
|
self.ListFreq = self.zmqServer.targetFreqs
|
||||||
|
self.num_target = len(self.ListFreq)
|
||||||
|
if self.num_target == 0:
|
||||||
|
return
|
||||||
|
# 初始化对象 二代算法
|
||||||
|
self.dw = FbccaDw(self.fs, self.num_target, self.n_chan, 5, 5,
|
||||||
|
0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
|
||||||
|
# frequence band
|
||||||
|
self.dw.filterFrequenceBank()
|
||||||
|
self.dw.setNotchFilterPara()
|
||||||
|
self.calculateCount = 0
|
||||||
|
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.fs),
|
||||||
|
5)
|
||||||
|
self.dw.filterInit()
|
||||||
|
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||||
|
|
||||||
|
elif decoder_class == 'ssmvep':
|
||||||
|
self.thread_data_server.interval_init(decoder_class)
|
||||||
|
self.n_chan = 8
|
||||||
|
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||||
|
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
||||||
|
self.single_train = 10 # 单类别数量
|
||||||
|
self.num_target = 2 # 分类目标数目
|
||||||
|
self.list_freqs = np.array([8, 9]) # 刺激频率
|
||||||
|
self.list_phase = np.array([0, 0]) # 相位
|
||||||
|
self.tdca = TDCA(padding_len=5, n_components=1)
|
||||||
|
self.Yf = generate_cca_references(self.list_freqs, srate=self.fs, T=self.sample_length,
|
||||||
|
phases=self.list_phase, n_harmonics=5)
|
||||||
|
self.parameter_init(5,45)
|
||||||
|
elif decoder_class == 'mi':
|
||||||
|
self.thread_data_server.interval_init(decoder_class)
|
||||||
|
self.n_chan = 21
|
||||||
|
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||||
|
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
||||||
|
self.single_train = 40 # 单类别数量
|
||||||
|
self.num_target = 2 # 分类目标数目
|
||||||
|
|
||||||
|
self.parameter_init(8, 30)
|
||||||
|
|
||||||
|
|
||||||
|
def parameter_init(self,bandPass_low,bandPass_high):
|
||||||
|
self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch] # epoch截取信息
|
||||||
|
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.fs)] # 训练样本epoch
|
||||||
|
self.trainData = [] #训练数据
|
||||||
|
self.trainLabel = [] #训练标签
|
||||||
|
self.plotData = [] #报告分析数据
|
||||||
|
self.plotLabel = [] #报告分析标签
|
||||||
|
self.currentLabel = -1 #刺激界面当前显示的训练标签
|
||||||
|
self.train_started = False #是否开始训练模型
|
||||||
|
self.load_model = False # 调用模型是否完成的标志
|
||||||
|
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子
|
||||||
|
self.b_design = signal.firwin(65, [bandPass_low / (self.fs/2), bandPass_high / (self.fs/2)], pass_zero=False) # 设计8-30Hz带通滤波器
|
||||||
|
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||||
|
filePath = './online_Models/'
|
||||||
|
self.modelPath = ''.join([filePath, fileName, '.pth'])
|
||||||
|
self.mp_data_queue = mp.Queue() #多进程传参队列
|
||||||
|
self.mp_result_queue = mp.Queue() #多进程结果队列
|
||||||
|
|
||||||
|
def preprocess(self, signal_data):
|
||||||
|
# # 计算每行的平均值
|
||||||
|
row_means = np.mean(signal_data, axis=-1, keepdims=True)
|
||||||
|
# 对每一行去均值
|
||||||
|
signal_data = signal_data - row_means
|
||||||
|
|
||||||
|
signal_data = signal.lfilter(self.b_notch, self.a_notch, signal_data, axis=-1) # 工频陷波
|
||||||
|
signal_data = signal.lfilter(self.b_design, 1, signal_data, axis=-1) # 带通滤波
|
||||||
|
return signal_data
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
while self.Runing:
|
||||||
|
if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
|
||||||
|
self.zmqServer.decoder_switch = False
|
||||||
|
self.zmqServer.changeTarget = False
|
||||||
|
self.init_Decoder(self.zmqServer.decoder_class)
|
||||||
|
|
||||||
|
# 同步信息
|
||||||
|
if self.zmqServer.state_mode == 'sync':
|
||||||
|
self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
||||||
|
self.zmqServer.state_mode = 'rest'
|
||||||
|
# 状态异常,报告上位机
|
||||||
|
if self.status_code != self.thread_data_server.status_code:
|
||||||
|
self.status_code = self.thread_data_server.status_code
|
||||||
|
self.zmqClient.send_to_all('status_code', int(self.status_code))
|
||||||
|
|
||||||
|
# 返回电量
|
||||||
|
if self.energy != self.thread_data_server.energy:
|
||||||
|
self.energy = self.thread_data_server.energy
|
||||||
|
self.zmqClient.send_to_all('energy', int(self.energy))
|
||||||
|
|
||||||
|
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
|
||||||
|
self.thread_data_server.Impedance(True)
|
||||||
|
self.zmqServer.open_Impedance = -1
|
||||||
|
elif self.zmqServer.open_Impedance == False:
|
||||||
|
self.thread_data_server.Impedance(False)
|
||||||
|
self.zmqServer.open_Impedance = -1
|
||||||
|
|
||||||
|
if self.zmqServer.get_Impedance: # 返回阻抗值
|
||||||
|
if self.thread_data_server.GetDataLenCount() > 250:
|
||||||
|
Impe_data = self.thread_data_server.getData(250)
|
||||||
|
# 计算阻抗
|
||||||
|
imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
|
||||||
|
self.zmqClient.send_to_all('impedance', imps.tolist())
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
if self.zmqServer.getReport: #返回训练报告内容
|
||||||
|
self.zmqServer.getReport = False
|
||||||
|
allData = np.array(self.plotData)
|
||||||
|
allLabel = np.array(self.plotLabel) + 1
|
||||||
|
nTrials = min(len(allLabel),len(allData))
|
||||||
|
if nTrials == 0:
|
||||||
|
self.zmqClient.send_to_all('miReport',0)
|
||||||
|
else:
|
||||||
|
allData = allData[:nTrials]
|
||||||
|
allLabel = allLabel[:nTrials]
|
||||||
|
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
|
||||||
|
'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
|
||||||
|
compare_names = ['C3', 'CZ', 'C4']
|
||||||
|
miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
|
||||||
|
fs=self.fs)
|
||||||
|
self.zmqClient.send_to_all('miReport',miReport)
|
||||||
|
|
||||||
|
|
||||||
|
if self.decoder_class == 'ssvep':
|
||||||
|
self.decoder_SSVEP()
|
||||||
|
elif self.decoder_class == 'ssmvep':
|
||||||
|
self.decoder_SSMVEP()
|
||||||
|
elif self.decoder_class == 'mi':
|
||||||
|
self.decoder_MI()
|
||||||
|
else:
|
||||||
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
|
if self.thread_data_server.GetDataLenCount() < 25:
|
||||||
|
time.sleep(0.005)
|
||||||
|
continue;
|
||||||
|
self.thread_data_server.getData(25)
|
||||||
|
|
||||||
|
|
||||||
|
def decoder_SSVEP(self):
|
||||||
|
if self.zmqServer.StartDecode:
|
||||||
|
self.zmqServer.StartDecode = False
|
||||||
|
self.decodingSteps = 1
|
||||||
|
self.thread_data_server.ResetAll()
|
||||||
|
print('启动预测')
|
||||||
|
if self.thread_data_server.GetDataLenCount() < 50:
|
||||||
|
time.sleep(0.005)
|
||||||
|
return
|
||||||
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
|
data = self.thread_data_server.getDataViaSSVEP(50)
|
||||||
|
data = data[:self.n_chan, :]
|
||||||
|
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
|
||||||
|
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||||
|
self.dw.warmFilter(data) # 预热
|
||||||
|
self.decodingSteps = 2
|
||||||
|
print('预热数据完成。开始预测')
|
||||||
|
return
|
||||||
|
if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中
|
||||||
|
choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
|
||||||
|
self.calculateCount += 1
|
||||||
|
if choosenNum != -1 and self.is_valid_signal(data):
|
||||||
|
self.decodingSteps = 3
|
||||||
|
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
|
||||||
|
self.calculateCount = 0
|
||||||
|
if self.decodingSteps == 3: # 发送解码后的信息
|
||||||
|
self.zmqClient.send_to_all('result', int(choosenNum))
|
||||||
|
self.decodingSteps = 0
|
||||||
|
print('发送给界面完成。')
|
||||||
|
def decoder_SSMVEP(self):
|
||||||
|
'''模型训练'''
|
||||||
|
if self.load_model == False and all(
|
||||||
|
self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))): # 模型尚未训练完成
|
||||||
|
self.trainData = np.array(self.trainData)
|
||||||
|
self.trainLabel = np.array(self.trainLabel)
|
||||||
|
print(np.shape(self.trainData), (self.trainLabel))
|
||||||
|
# 保存多个数组到文件
|
||||||
|
# np.savez('20250520_yy.npz', array1=self.trainData, array2=self.trainLabel)
|
||||||
|
# self.decoder = self.fbtdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
|
||||||
|
self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
|
||||||
|
now = datetime.now()
|
||||||
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print('模型训练完成', formatted_time)
|
||||||
|
self.load_model = True
|
||||||
|
self.zmqClient.send_to_all('paradigm', 1)
|
||||||
|
|
||||||
|
'''训练阶段采集数据'''
|
||||||
|
if self.zmqServer.state_mode == 'train': # 训练状态
|
||||||
|
if self.zmqServer.StartTrain:
|
||||||
|
self.currentLabel = self.zmqServer.currentLabel
|
||||||
|
self.zmqServer.StartTrain = False
|
||||||
|
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||||
|
self.train_epoch[1] \
|
||||||
|
+ self.thread_data_server.event_inner_idx:
|
||||||
|
time.sleep(0.0001)
|
||||||
|
return
|
||||||
|
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
|
||||||
|
trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据
|
||||||
|
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx])
|
||||||
|
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||||
|
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.train_epoch[
|
||||||
|
0]:self.thread_data_server.event_inner_idx + self.train_epoch[1]]
|
||||||
|
print('trial: ', self.thread_data_server.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
||||||
|
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||||
|
self.trainLabel, list) \
|
||||||
|
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||||
|
self.trainData.append(trainTrial)
|
||||||
|
self.trainLabel.append(self.currentLabel)
|
||||||
|
|
||||||
|
elif self.zmqServer.state_mode == 'predict': # 测试状态
|
||||||
|
if self.load_model == False: # 模型尚未训练完成
|
||||||
|
time.sleep(0.01)
|
||||||
|
return
|
||||||
|
else: # 已有模型
|
||||||
|
if self.zmqServer.StartDecode:
|
||||||
|
self.zmqServer.StartDecode = False
|
||||||
|
now = datetime.now()
|
||||||
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print('启动预测 ', formatted_time)
|
||||||
|
|
||||||
|
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||||
|
self.interval_epoch[1] \
|
||||||
|
+ self.thread_data_server.event_inner_idx:
|
||||||
|
time.sleep(0.0001)
|
||||||
|
return
|
||||||
|
data = self.thread_data_server.get_SSMVEPData() # 读取全部数据
|
||||||
|
print('取出的: ', data.shape, 'event: ', data[-2, self.thread_data_server.event_inner_idx])
|
||||||
|
data = self.preprocess(data[:self.n_chan, :]) # 预处理
|
||||||
|
data = data[:,
|
||||||
|
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||||
|
pad_eeg_test = np.zeros(
|
||||||
|
(data.shape[0], int((self.sample_length + 0.1) * self.fs)))
|
||||||
|
pad_eeg_test[:, :int(self.sample_length * self.fs)] = data
|
||||||
|
choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
|
||||||
|
if isinstance(choosenNum, np.ndarray):
|
||||||
|
choosenNum = choosenNum[0]
|
||||||
|
print('结果:', choosenNum, 'rho: ', sorted(features_2[0]),
|
||||||
|
sorted(features_2[0])[-1] - sorted(features_2[0])[-2])
|
||||||
|
self.zmqClient.send_to_all('result', int(choosenNum))
|
||||||
|
print('发送给界面完成。')
|
||||||
|
else: # 休息状态
|
||||||
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
|
if self.thread_data_server.GetDataLenCount() < 25:
|
||||||
|
time.sleep(0.005)
|
||||||
|
return
|
||||||
|
self.thread_data_server.getData(25)
|
||||||
|
def decoder_MI(self):
|
||||||
|
'''模型训练'''
|
||||||
|
if self.train_started == False and all(
|
||||||
|
self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)): # 模型尚未训练
|
||||||
|
self.zmqClient.send_to_all('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
|
||||||
|
self.train_started = True
|
||||||
|
self.trainData = np.array(self.trainData)
|
||||||
|
self.trainLabel = np.array(self.trainLabel) + 1
|
||||||
|
# print('训练集:',np.shape(self.trainData), (self.trainLabel))
|
||||||
|
p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型
|
||||||
|
p.start()
|
||||||
|
self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath,
|
||||||
|
'n_chan': self.n_chan})
|
||||||
|
|
||||||
|
'''检查模型是否训练完成,调用'''
|
||||||
|
if self.load_model == False and self.train_started == True:
|
||||||
|
try:
|
||||||
|
result = self.mp_result_queue.get_nowait()
|
||||||
|
if result['status'] == 'success':
|
||||||
|
print("模型训练完成,加载新模型")
|
||||||
|
# 调用模型
|
||||||
|
self.model = torch.load(self.modelPath, weights_only=False)
|
||||||
|
self.model.eval()
|
||||||
|
# 模型预热
|
||||||
|
warmup_data = np.random.uniform(-1, 1, (1, 1, self.n_chan, 1000))
|
||||||
|
warmup_data = torch.from_numpy(warmup_data)
|
||||||
|
warmup_data = Variable(warmup_data.type(torch.cuda.FloatTensor))
|
||||||
|
with torch.no_grad():
|
||||||
|
_ = self.model(warmup_data)
|
||||||
|
self.load_model = True
|
||||||
|
self.zmqClient.send_to_all('paradigm', 1) # 模型调用完毕,通知上位机
|
||||||
|
else:
|
||||||
|
print("训练失败:", result['msg'])
|
||||||
|
except Empty:
|
||||||
|
pass # 还没完成
|
||||||
|
except Exception as e:
|
||||||
|
print('模型调用失败: ', e)
|
||||||
|
|
||||||
|
'''训练阶段采集数据'''
|
||||||
|
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
|
||||||
|
if self.zmqServer.StartTrain:
|
||||||
|
self.currentLabel = self.zmqServer.currentLabel
|
||||||
|
self.zmqServer.StartTrain = False
|
||||||
|
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||||
|
self.interval_epoch[1] \
|
||||||
|
+ self.thread_data_server.event_inner_idx:
|
||||||
|
time.sleep(0.0001)
|
||||||
|
return
|
||||||
|
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
|
||||||
|
originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据
|
||||||
|
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx])
|
||||||
|
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
||||||
|
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||||
|
print('trial: ', self.thread_data_server.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
|
||||||
|
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
||||||
|
list) \
|
||||||
|
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||||
|
self.trainData.append(trainTrial)
|
||||||
|
self.trainLabel.append(self.currentLabel)
|
||||||
|
print('训练集:', np.shape(self.trainData))
|
||||||
|
self.plotData.append(originalTrial[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
|
||||||
|
self.plotLabel.append(self.currentLabel)
|
||||||
|
|
||||||
|
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
|
||||||
|
if self.zmqServer.StartDecode:
|
||||||
|
self.zmqServer.StartDecode = False
|
||||||
|
now = datetime.now()
|
||||||
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print('启动预测 ', formatted_time)
|
||||||
|
|
||||||
|
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||||
|
self.interval_epoch[1] \
|
||||||
|
+ self.thread_data_server.event_inner_idx:
|
||||||
|
time.sleep(0.0001)
|
||||||
|
return
|
||||||
|
originalData = self.thread_data_server.get_MIData() # 读取全部数据
|
||||||
|
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.thread_data_server.event_inner_idx])
|
||||||
|
start = time.time()
|
||||||
|
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
|
||||||
|
data = data[:,
|
||||||
|
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||||
|
self.plotData.append(
|
||||||
|
originalData[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
|
||||||
|
|
||||||
|
test_data = data[np.newaxis, np.newaxis, :, :]
|
||||||
|
test_data = torch.from_numpy(test_data)
|
||||||
|
test_data = Variable(test_data.type(torch.cuda.FloatTensor))
|
||||||
|
with torch.no_grad():
|
||||||
|
Cls = self.model(test_data)
|
||||||
|
y_pred = torch.max(Cls, 1)[1]
|
||||||
|
self.plotLabel.append(int(y_pred.item()))
|
||||||
|
print('运动意图识别: ', y_pred)
|
||||||
|
self.zmqClient.send_to_all('result', int(y_pred.item()))
|
||||||
|
end = time.time()
|
||||||
|
print(f'发送给界面完成,耗时{end - start:.3f}s。')
|
||||||
|
else: # 休息状态
|
||||||
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
|
if self.thread_data_server.GetDataLenCount() < 25:
|
||||||
|
time.sleep(0.005)
|
||||||
|
return
|
||||||
|
self.thread_data_server.getData(25)
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
'''
|
||||||
|
停止运行
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
self.zmqServer.stop()
|
||||||
|
self.Runing=False
|
||||||
754
Debug_64ch_Decoder/Device/SunnyLinker.py
Normal file
754
Debug_64ch_Decoder/Device/SunnyLinker.py
Normal file
@@ -0,0 +1,754 @@
|
|||||||
|
# -*-coding:utf-8 -*-
|
||||||
|
'''
|
||||||
|
SunnyLinker的通讯驱动
|
||||||
|
'''
|
||||||
|
import ast
|
||||||
|
import socket
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from threading import Thread, Event
|
||||||
|
import serial
|
||||||
|
from scipy import signal
|
||||||
|
from serial.serialutil import SerialException
|
||||||
|
|
||||||
|
from Device.protocol import ProtocolFrame
|
||||||
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
|
||||||
|
class RingBuffer:
|
||||||
|
def __init__(self, n_chan, n_points):
|
||||||
|
self.n_chan = n_chan
|
||||||
|
self.n_points = n_points
|
||||||
|
self.buffer = np.zeros((n_chan, n_points))
|
||||||
|
self.currentPtr = 0
|
||||||
|
self.readPtr = 0
|
||||||
|
self.nUpdate = 0
|
||||||
|
self.rawData = np.zeros((n_chan, 1))
|
||||||
|
|
||||||
|
## append buffer and update current pointer
|
||||||
|
def appendBuffer(self, data):
|
||||||
|
if self.nUpdate == self.n_points:
|
||||||
|
raise Exception("Buffer is full")
|
||||||
|
|
||||||
|
n = data.shape[1]
|
||||||
|
|
||||||
|
# 计算可以写入的元素数量
|
||||||
|
write_count = min(self.n_points - self.nUpdate, n)
|
||||||
|
# 写入新数据
|
||||||
|
self.buffer[:, np.mod(np.arange(self.currentPtr, self.currentPtr + write_count), self.n_points)] = data[:,:write_count]
|
||||||
|
# 更新结束指针
|
||||||
|
self.currentPtr = (self.currentPtr + write_count) % self.n_points
|
||||||
|
# 更新大小
|
||||||
|
self.nUpdate += write_count
|
||||||
|
|
||||||
|
## get data from buffer
|
||||||
|
def getData(self, count=50):
|
||||||
|
# 确保不会尝试读取超过缓冲区当前大小的数据
|
||||||
|
count = min(count, self.nUpdate)
|
||||||
|
|
||||||
|
# 计算读取结束后的下一个位置
|
||||||
|
next_read_ptr = (self.readPtr + count) % self.n_points
|
||||||
|
if self.readPtr + count <= self.n_points:
|
||||||
|
# 情况 1:不环绕,数据是连续的
|
||||||
|
end_index = next_read_ptr if next_read_ptr != 0 else self.n_points
|
||||||
|
data = self.buffer[:, self.readPtr:end_index]
|
||||||
|
else:
|
||||||
|
# 情况 2:发生环绕,数据被分成两部分
|
||||||
|
# 第一部分:从 readPtr 到缓冲区末尾
|
||||||
|
part1 = self.buffer[:, self.readPtr:]
|
||||||
|
# 第二部分:从缓冲区开头到 (count - part1.shape[1]) 个点
|
||||||
|
part2 = self.buffer[:, :next_read_ptr]
|
||||||
|
# 将两部分在列方向上拼接
|
||||||
|
data = np.concatenate((part1, part2), axis=1)
|
||||||
|
|
||||||
|
# 更新读指针
|
||||||
|
self.readPtr = next_read_ptr
|
||||||
|
# 更新大小
|
||||||
|
self.nUpdate -= count
|
||||||
|
return data
|
||||||
|
|
||||||
|
# reset buffer
|
||||||
|
def resetAllPara(self):
|
||||||
|
self.nUpdate = 0
|
||||||
|
self.currentPtr = 0
|
||||||
|
self.readPtr = 0 # add by lizhenhua 清空读指针
|
||||||
|
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
|
||||||
|
|
||||||
|
|
||||||
|
class SunnyLinker64(Thread, ):
|
||||||
|
serial_port = str(IniRead('system', 'Serial_port'))
|
||||||
|
t_buffer = 10
|
||||||
|
n_chan = 64
|
||||||
|
srate = 250
|
||||||
|
receiveData = b''
|
||||||
|
toUv=True#转为uV
|
||||||
|
RingBufferLock = threading.Lock()
|
||||||
|
|
||||||
|
# 单例模式
|
||||||
|
_instance = None
|
||||||
|
_initialized = False # 检查是否已经初始化
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(SunnyLinker64, cls).__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
def __init__(self, host='127.0.0.1', port=7878, srate=250, n_chan=64,method = 'tcp'):
|
||||||
|
if SunnyLinker64._initialized:
|
||||||
|
return
|
||||||
|
Thread.__init__(self)
|
||||||
|
self.daemon = True
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.srate = srate
|
||||||
|
self.n_chan = n_chan
|
||||||
|
self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输
|
||||||
|
self.__ringBuffer = RingBuffer(self.n_chan + 2,
|
||||||
|
int(np.round(self.t_buffer * self.srate)))
|
||||||
|
self.energy = 0 # 电量
|
||||||
|
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
||||||
|
self.gain_value = 6 # 增益倍数
|
||||||
|
self.interval_inited = False #ssmvep或mi时间窗是否初始化
|
||||||
|
|
||||||
|
# 设置初始化标志为True,防止重复初始化
|
||||||
|
SunnyLinker64._initialized = True
|
||||||
|
|
||||||
|
# --- 新增:用于心跳检测 ---
|
||||||
|
self.last_called = 0 # 初始化为0
|
||||||
|
self.last_called_lock = threading.Lock() # 保护 last_called 的访问
|
||||||
|
|
||||||
|
def interval_init(self,decoder_class):
|
||||||
|
if decoder_class == 'ssmvep':
|
||||||
|
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||||
|
self.interval_epoch = [int(i * self.srate) for i in interval_epoch] # epoch截取信息
|
||||||
|
self.train_epoch = [int(self.interval_epoch[0]),
|
||||||
|
int(self.interval_epoch[1] + 0.1 * self.srate)] # 训练样本epoch
|
||||||
|
self.latency = (self.interval_epoch[
|
||||||
|
1] + 0.1 * self.srate) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉
|
||||||
|
self.train_latency = (self.train_epoch[1] + 0.1 * self.srate) // 5
|
||||||
|
|
||||||
|
elif decoder_class == 'mi':
|
||||||
|
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||||
|
self.interval_epoch = [int(i * self.srate) for i in interval_epoch] # epoch截取信息
|
||||||
|
self.train_epoch = self.interval_epoch.copy()
|
||||||
|
self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;
|
||||||
|
self.train_latency = self.latency
|
||||||
|
|
||||||
|
print('时间窗:', (interval_epoch))
|
||||||
|
self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息
|
||||||
|
self.event_inner_idx = -1 # event在5位数据包内部的idx
|
||||||
|
self.epoch_finished = False # 接收epoch是否完整
|
||||||
|
self.pack_contain_event = False # 当前包是否含有event
|
||||||
|
self.predict_event = 99
|
||||||
|
self.events = [1, 2, self.predict_event]
|
||||||
|
if getattr(self, 'serial', None) and self.serial.is_open:
|
||||||
|
self.serial.close()
|
||||||
|
self.serial = serial.Serial(self.serial_port, 460800, timeout=1) # 连接同步器串口
|
||||||
|
self.interval_inited = True
|
||||||
|
|
||||||
|
def set_sampleRate(self,sampleRate_Code=0x00):
|
||||||
|
'''
|
||||||
|
设置采样率
|
||||||
|
:param sampleRate_Code: 0x00:250Hz,0x01:500Hz,0x02:1000Hz,0x03:2000Hz
|
||||||
|
'''
|
||||||
|
function_code = 0x02
|
||||||
|
gain_code = 0x06
|
||||||
|
sampleRate_Code = [gain_code,sampleRate_Code]
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, sampleRate_Code)
|
||||||
|
if self.method == 'tcp':
|
||||||
|
self.sock.send(packed_data)
|
||||||
|
|
||||||
|
def push_trigger(self,label):
|
||||||
|
'''
|
||||||
|
数据打标
|
||||||
|
@param label:标签类别
|
||||||
|
'''
|
||||||
|
function_code = None
|
||||||
|
label = [label]
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, label)
|
||||||
|
if self.method == 'tcp' and hasattr(self,'serial'):
|
||||||
|
print('发送:', label, datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3])
|
||||||
|
self.serial.write(packed_data)
|
||||||
|
def Impedance(self, On):
|
||||||
|
'''
|
||||||
|
阻抗检测开关
|
||||||
|
:param On:True为开启,False为关闭
|
||||||
|
:return: 组好的协议帧
|
||||||
|
'''
|
||||||
|
function_code = 0x01
|
||||||
|
if On:
|
||||||
|
data = [0x1]
|
||||||
|
self.gain_value = 6
|
||||||
|
else:
|
||||||
|
data = [0x0]
|
||||||
|
self.gain_value = 6
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, data)
|
||||||
|
if self.method == 'tcp':
|
||||||
|
self.sock.send(packed_data)
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
try:
|
||||||
|
if self.method == 'serial':
|
||||||
|
# 开启com口,波特率115200,超时5
|
||||||
|
self.sock = serial.Serial(self.host, self.port, timeout=5)
|
||||||
|
self.sock.flushInput() # 清空缓冲区
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
while not count:
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
# # 接收和存储数据
|
||||||
|
data = (self.sock.read(count))
|
||||||
|
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
|
||||||
|
elif self.method == 'tcp':
|
||||||
|
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
self.sock.connect((self.host, int(self.port)))
|
||||||
|
self.set_sampleRate(0x00) #设置250Hz采样率
|
||||||
|
except Exception as e:
|
||||||
|
print("请打开头环")
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
print("connected")
|
||||||
|
|
||||||
|
def extract_packet(self, packet):
|
||||||
|
# 存储一个点的八通道数据
|
||||||
|
dataList = []
|
||||||
|
# 存储116个点的八通道数据
|
||||||
|
dataMatrix = []
|
||||||
|
|
||||||
|
for j in range(5):
|
||||||
|
for i in range(self.n_chan):
|
||||||
|
if not self.toUv:#原始数据直接输出
|
||||||
|
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
|
||||||
|
194 * j + 25 + 2 + i * 3]
|
||||||
|
|
||||||
|
else:#转为uV
|
||||||
|
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
|
||||||
|
194 * j + 25 + 2 + i * 3]
|
||||||
|
if val < 8388608:
|
||||||
|
val = val * 4.5 / self.gain_value / 8388608 * 1000000;
|
||||||
|
else:
|
||||||
|
val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000;
|
||||||
|
dataList.append(val)
|
||||||
|
#同步触发源
|
||||||
|
val = packet[194 * j + 25 + (i+1) * 3]
|
||||||
|
dataList.append(val)
|
||||||
|
#同步触发序号
|
||||||
|
val = packet[194 * j + 25 + (i+1) * 3+1]
|
||||||
|
dataList.append(val)
|
||||||
|
|
||||||
|
|
||||||
|
# 将数据矩阵进行拼接
|
||||||
|
if len(dataMatrix) == 0:
|
||||||
|
dataMatrix = np.asmatrix(dataList)
|
||||||
|
else:
|
||||||
|
dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0)
|
||||||
|
dataList.clear()
|
||||||
|
return np.transpose(dataMatrix)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.connect()
|
||||||
|
self.running = True
|
||||||
|
self.PackageLength = 998
|
||||||
|
# 启动心跳检测线程
|
||||||
|
threading.Thread(target=self.heartbeat_checker, daemon=True).start()
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
if self.method == 'serial':
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
if count:
|
||||||
|
# 接收和存储数据
|
||||||
|
data = (self.sock.read(count))
|
||||||
|
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
|
||||||
|
elif self.method == 'tcp':
|
||||||
|
data = self.sock.recv(600)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
self.receiveData += data
|
||||||
|
with self.last_called_lock:
|
||||||
|
self.last_called = time.time()
|
||||||
|
self.status_code = 1 # 收到数据,标记为正常
|
||||||
|
if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind(
|
||||||
|
b'\x55\x55') >= self.PackageLength - 2:
|
||||||
|
|
||||||
|
index = self.receiveData.index(b'\xaa')
|
||||||
|
self.receiveData = self.receiveData[index:]
|
||||||
|
if len(self.receiveData) >= self.PackageLength:
|
||||||
|
onepackage = self.receiveData[:self.PackageLength]
|
||||||
|
if onepackage[7] != 0:
|
||||||
|
self.energy = onepackage[7] # 电量
|
||||||
|
self.receiveData = self.receiveData[self.PackageLength:]
|
||||||
|
dataMatrix = self.extract_packet(onepackage)
|
||||||
|
try:
|
||||||
|
with self.RingBufferLock:
|
||||||
|
if self.interval_inited:
|
||||||
|
self.epoch_finished = self.detect_event(dataMatrix)
|
||||||
|
if self.pack_contain_event:
|
||||||
|
self.__ringBuffer.resetAllPara() # 检测到当前pack含有event,清除ringbuffer中之前的数据
|
||||||
|
self.__ringBuffer.appendBuffer(dataMatrix)
|
||||||
|
# self.plotBuffer.appendBuffer(dataMatrix)
|
||||||
|
if self.epoch_finished:
|
||||||
|
time.sleep(0.005)
|
||||||
|
print('epoch_finished: ', datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3])
|
||||||
|
else:
|
||||||
|
self.__ringBuffer.appendBuffer(dataMatrix)
|
||||||
|
except Exception as e:
|
||||||
|
print("锁:写入异常",e)
|
||||||
|
# self.RingBufferLock.release()
|
||||||
|
except ConnectionResetError:
|
||||||
|
self.status_code = 0 # 状态异常
|
||||||
|
print("Connection was reset by the peer.")
|
||||||
|
break
|
||||||
|
self.sock.close()
|
||||||
|
|
||||||
|
# 检测是否含有标签
|
||||||
|
def detect_event(self, samples):
|
||||||
|
self.pack_contain_event = False
|
||||||
|
events = np.array(samples[-2])[0].tolist()
|
||||||
|
for idx, event in enumerate(events):
|
||||||
|
if int(event) in self.events:
|
||||||
|
new_key = "".join(
|
||||||
|
[
|
||||||
|
str(event),
|
||||||
|
datetime.datetime.now().strftime("%Y-%m-%d \
|
||||||
|
-%H-%M-%S"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if event == self.predict_event:
|
||||||
|
self.count_events[new_key] = self.latency + 1
|
||||||
|
else:
|
||||||
|
self.count_events[new_key] = self.train_latency + 1
|
||||||
|
self.event_inner_idx = idx
|
||||||
|
self.pack_contain_event = True
|
||||||
|
drop_items = []
|
||||||
|
for key, value in self.count_events.items():
|
||||||
|
value = value - 1
|
||||||
|
if value == 0:
|
||||||
|
drop_items.append(key)
|
||||||
|
self.count_events[key] = value
|
||||||
|
for key in drop_items:
|
||||||
|
del self.count_events[key]
|
||||||
|
if drop_items:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
# --- 新增:心跳检测线程 ---
|
||||||
|
def heartbeat_checker(self):
|
||||||
|
"""
|
||||||
|
定期检查是否在最近2秒内收到 eegData
|
||||||
|
如果超过2秒未收到,则设置 status_code = 0
|
||||||
|
"""
|
||||||
|
while self.running:
|
||||||
|
time.sleep(0.5) # 每0.5秒检查一次
|
||||||
|
with self.last_called_lock:
|
||||||
|
now = time.time()
|
||||||
|
# 只有收到过一次数据后才开始判断超时
|
||||||
|
if self.last_called > 0 and (now - self.last_called) > 2:
|
||||||
|
if self.status_code != 0:
|
||||||
|
print("EEG data timeout: disconnected")
|
||||||
|
self.status_code = 0
|
||||||
|
|
||||||
|
def getDataViaSSVEP(self,count):
|
||||||
|
'''
|
||||||
|
ssvep的视觉通道,共8个通道
|
||||||
|
@param count: 每通道读取的数值数量
|
||||||
|
@return: 返回最新的数值
|
||||||
|
'''
|
||||||
|
data=self.getData(count)
|
||||||
|
# PO5,POZ,PO6,O2,PO8,OZ,O1,PO7 64是event导联
|
||||||
|
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55,64]
|
||||||
|
row_to_select=np.array(rows_to_extract)
|
||||||
|
data=data[row_to_select,:]
|
||||||
|
return data
|
||||||
|
def get_MIData(self):
|
||||||
|
'''
|
||||||
|
取出当前所有数值
|
||||||
|
:return:
|
||||||
|
'''
|
||||||
|
data = self.getData(self.__ringBuffer.nUpdate)
|
||||||
|
#MI选取导联:FC3,FC1,FCZ,FC2,FC4,C5,C3,C1,CZ,C2,C4,C6,CP3,CP1,CP2,CP4,P3,P1,PZ,P2,P4,event1,event2
|
||||||
|
rows_to_extract = [8, 15, 12, 14, 18, 23, 16,59,50,58,17,45,29,11,10,19,20,61,51,60,21,64,65]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
data = data[row_to_select,:]
|
||||||
|
return data
|
||||||
|
def get_SSMVEPData(self):
|
||||||
|
'''
|
||||||
|
取出当前所有数值
|
||||||
|
:return:
|
||||||
|
'''
|
||||||
|
data = self.getData(self.__ringBuffer.nUpdate)
|
||||||
|
# PO5,POZ,PO6,O2,PO8,OZ,O1,PO7 64是event导联
|
||||||
|
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64,65]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
data = data[row_to_select, :]
|
||||||
|
return data
|
||||||
|
def getImpedance(self, data,decoder_class):
|
||||||
|
'''
|
||||||
|
获取阻抗值,已经放大100倍,单位是kΩ
|
||||||
|
@param data: 准备计算的通道数据,每通道200个值,注意不要把信号打标的通道传进来
|
||||||
|
@return: 返回各个通道的阻抗值
|
||||||
|
'''
|
||||||
|
impedanceList = []
|
||||||
|
for channelindex in range(data.shape[0]):
|
||||||
|
if len(data[channelindex]) > 0:
|
||||||
|
data_list = []
|
||||||
|
# 设计陷波滤波器,去除50Hz成分
|
||||||
|
is50filter = True
|
||||||
|
if is50filter:
|
||||||
|
b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽,1000是采样频率
|
||||||
|
data_list = signal.lfilter(b, a, data[channelindex].tolist())
|
||||||
|
|
||||||
|
else:
|
||||||
|
data_list.extend(data[channelindex].tolist())
|
||||||
|
|
||||||
|
data_list = data_list[-1000:]
|
||||||
|
# 执行FFT
|
||||||
|
fft_result = np.fft.fft(data_list)
|
||||||
|
fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果
|
||||||
|
freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴
|
||||||
|
|
||||||
|
# y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())],
|
||||||
|
# fft_magnitude[1:-1] * 2 / len(t[0].tolist()),
|
||||||
|
# [fft_magnitude[-1] / len(t[0].tolist())]))
|
||||||
|
|
||||||
|
# 找到幅值最大的频率成分的索引(忽略直流分量,即索引0)
|
||||||
|
max_index = np.argmax(fft_magnitude[1:])
|
||||||
|
|
||||||
|
# 获取最大幅值的频率索引(加上1,因为索引0是直流分量)
|
||||||
|
freq_index = max_index + 1
|
||||||
|
|
||||||
|
# 获取最大幅值
|
||||||
|
max_magnitude = fft_magnitude[freq_index]
|
||||||
|
|
||||||
|
# 阻抗
|
||||||
|
import math
|
||||||
|
result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4
|
||||||
|
result *= 0.44 * 100 # 统一放大100倍
|
||||||
|
impedanceList.append(int(result))
|
||||||
|
# print(max_magnitude, result)
|
||||||
|
else:
|
||||||
|
impedanceList.append(0)
|
||||||
|
impedances = np.array(impedanceList)
|
||||||
|
if decoder_class == 'mi':
|
||||||
|
impedances = impedances[np.array([8, 15, 12, 14, 18, 23, 16,59,50,58,17,45,29,11,10,19,20,61,51,60,21])]
|
||||||
|
else:
|
||||||
|
impedances = impedances[np.array([13, 3, 2, 46, 9, 54, 47, 55])]
|
||||||
|
return impedances
|
||||||
|
def getData(self,count):
|
||||||
|
'''
|
||||||
|
获取最新的数据
|
||||||
|
@param count: 每通道返回的最数值数目
|
||||||
|
@return: 所有通道的最新count个数值
|
||||||
|
'''
|
||||||
|
data=None
|
||||||
|
try:
|
||||||
|
with self.RingBufferLock:
|
||||||
|
data = self.__ringBuffer.getData(count)
|
||||||
|
except:
|
||||||
|
print("锁:读取异常")
|
||||||
|
# self.RingBufferLock.release()
|
||||||
|
|
||||||
|
|
||||||
|
return data
|
||||||
|
def GetDataLenCount(self):
|
||||||
|
'''
|
||||||
|
获取最新缓存中每个通道的数量
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
return self.__ringBuffer.nUpdate
|
||||||
|
|
||||||
|
def ResetAll(self):
|
||||||
|
'''
|
||||||
|
清空缓存
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
with self.RingBufferLock:
|
||||||
|
self.__ringBuffer.resetAllPara()
|
||||||
|
def stop(self):
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
class SunnyLinker8(Thread, ):
|
||||||
|
receiveData = ''
|
||||||
|
t_buffer = 10
|
||||||
|
n_chan = 9
|
||||||
|
srate = 1000
|
||||||
|
receiveData = b''
|
||||||
|
toUv=False#转为uV
|
||||||
|
RingBufferLock = threading.Lock()
|
||||||
|
def __init__(self, host, port, srate=1000, n_chan=9,method = 'tcp'):
|
||||||
|
Thread.__init__(self)
|
||||||
|
self.daemon = True
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.srate = srate
|
||||||
|
self.n_chan = n_chan
|
||||||
|
self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输
|
||||||
|
self.__ringBuffer = RingBuffer(self.n_chan + 2,
|
||||||
|
int(np.round(self.t_buffer * self.srate)))
|
||||||
|
self.energy = 0 #电量
|
||||||
|
self.status_code = 0 #与采集设备通信的状态码,0为异常,1为正常
|
||||||
|
self.gain_value = 6 # 增益倍数
|
||||||
|
|
||||||
|
def push_trigger(self,label):
|
||||||
|
'''
|
||||||
|
数据打标
|
||||||
|
@param label:标签类别
|
||||||
|
'''
|
||||||
|
function_code = None
|
||||||
|
label = [label]
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, label)
|
||||||
|
if self.method == 'tcp':
|
||||||
|
self.sock.send(packed_data)
|
||||||
|
elif self.method == 'serial':
|
||||||
|
self.sock.write(packed_data)
|
||||||
|
|
||||||
|
def Impedance(self, On):
|
||||||
|
'''
|
||||||
|
阻抗检测开关
|
||||||
|
:param On:True为开启,False为关闭
|
||||||
|
:return: 组好的协议帧
|
||||||
|
'''
|
||||||
|
function_code = None
|
||||||
|
if On:
|
||||||
|
data = [0xA1]
|
||||||
|
self.gain_value = 24
|
||||||
|
else:
|
||||||
|
data = [0xA0]
|
||||||
|
self.gain_value = 6
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, data)
|
||||||
|
if self.method == 'tcp':
|
||||||
|
self.sock.send(packed_data)
|
||||||
|
elif self.method == 'serial':
|
||||||
|
self.sock.write(packed_data)
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
try:
|
||||||
|
if self.method == 'serial':
|
||||||
|
# 开启com口,波特率115200,超时5
|
||||||
|
self.sock = serial.Serial(self.host, self.port, timeout=5)
|
||||||
|
self.sock.flushInput() # 清空缓冲区
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
while not count:
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
# # 接收和存储数据
|
||||||
|
data = (self.sock.read(count))
|
||||||
|
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
|
||||||
|
elif self.method == 'tcp':
|
||||||
|
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
self.sock.connect((self.host, int(self.port)))
|
||||||
|
except Exception as e:
|
||||||
|
print("请打开头环")
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
print("connected")
|
||||||
|
|
||||||
|
def extract_packet(self, packet):
|
||||||
|
# 存储一个点的八通道数据
|
||||||
|
dataList = []
|
||||||
|
# 存储116个点的八通道数据
|
||||||
|
dataMatrix = []
|
||||||
|
|
||||||
|
# index = (packet[1] << 24) | (packet[2] << 16) | (packet[3] << 8) | packet[4]
|
||||||
|
# print(index)
|
||||||
|
|
||||||
|
for j in range(5):
|
||||||
|
for i in range(self.n_chan):
|
||||||
|
if not self.toUv:#原始数据直接输出
|
||||||
|
val = (packet[26 * j + 25 + i * 3] << 16) | (packet[26 * j + 25 + 1 + i * 3] << 8) | packet[
|
||||||
|
26 * j + 25 + 2 + i * 3]
|
||||||
|
|
||||||
|
else:#转为uV
|
||||||
|
val = (packet[26 * j + 25 + i * 3] << 16) | (packet[26 * j + 25 + 1 + i * 3] << 8) | packet[
|
||||||
|
26 * j + 25 + 2 + i * 3]
|
||||||
|
if val < 8388608:
|
||||||
|
val = val * 4.5 / self.gain_value / 8388608 * 1000000;
|
||||||
|
else:
|
||||||
|
val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000;
|
||||||
|
dataList.append(val)
|
||||||
|
#同步触发源
|
||||||
|
val = packet[26 * j + 25 + (i+1) * 3]
|
||||||
|
dataList.append(val)
|
||||||
|
#同步触发序号
|
||||||
|
val = packet[26 * j + 25 + (i+1) * 3+1]
|
||||||
|
dataList.append(val)
|
||||||
|
|
||||||
|
|
||||||
|
# 将数据矩阵进行拼接
|
||||||
|
if len(dataMatrix) == 0:
|
||||||
|
dataMatrix = np.asmatrix(dataList)
|
||||||
|
else:
|
||||||
|
dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0)
|
||||||
|
dataList.clear()
|
||||||
|
return np.transpose(dataMatrix)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.connect()
|
||||||
|
self.running = True
|
||||||
|
self.PackageLength = 158
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
while self.running:
|
||||||
|
if self.method == 'serial':
|
||||||
|
end_time = time.time()
|
||||||
|
if end_time-start_time > 2: #超过2s未收到数据
|
||||||
|
self.status_code = 0 #状态异常
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
if count:
|
||||||
|
start_time = time.time()
|
||||||
|
self.status_code = 1 # 收到数据,状态正常
|
||||||
|
# 接收和存储数据
|
||||||
|
data = (self.sock.read(count))
|
||||||
|
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
|
||||||
|
elif self.method == 'tcp':
|
||||||
|
data = self.sock.recv(100)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
self.receiveData += data
|
||||||
|
if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind(
|
||||||
|
b'\x55\x55') >= self.PackageLength - 2:
|
||||||
|
|
||||||
|
index = self.receiveData.index(b'\xaa')
|
||||||
|
self.receiveData = self.receiveData[index:]
|
||||||
|
if len(self.receiveData) >= self.PackageLength:
|
||||||
|
onepackage = self.receiveData[:self.PackageLength]
|
||||||
|
if onepackage[7] != 0:
|
||||||
|
self.energy = onepackage[7] # 电量
|
||||||
|
self.receiveData = self.receiveData[self.PackageLength:]
|
||||||
|
dataMatrix = self.extract_packet(onepackage)
|
||||||
|
try:
|
||||||
|
with self.RingBufferLock:
|
||||||
|
self.__ringBuffer.appendBuffer(dataMatrix)
|
||||||
|
except:
|
||||||
|
print("锁:写入异常")
|
||||||
|
self.sock.close()
|
||||||
|
|
||||||
|
except ConnectionResetError:
|
||||||
|
self.status_code = 0 # 状态异常
|
||||||
|
print("Connection was reset by the peer.")
|
||||||
|
except SerialException as Se:
|
||||||
|
self.status_code = 0
|
||||||
|
print('串口通信异常!请检查适配器')
|
||||||
|
|
||||||
|
|
||||||
|
def process_packet(self):
|
||||||
|
if self.circular_buffer.buffer_length > 158:
|
||||||
|
packet = self.circular_buffer.extract_packet()
|
||||||
|
|
||||||
|
if packet:
|
||||||
|
# Here you would parse the packet according to the protocol
|
||||||
|
# print("Received packet:%s,index:%s", len(packet),str(integer_value))
|
||||||
|
return packet
|
||||||
|
else:
|
||||||
|
print("Received Nothing")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def getDataViaSSVEP(self,count):
|
||||||
|
'''
|
||||||
|
ssvep的视觉通道,共8个通道
|
||||||
|
@param count: 每通道读取的数值数量
|
||||||
|
@return: 返回最新的数值
|
||||||
|
'''
|
||||||
|
data=self.getData(count)
|
||||||
|
data=data[:8,:]
|
||||||
|
return data
|
||||||
|
|
||||||
|
def getImpedance(self, data):
|
||||||
|
'''
|
||||||
|
获取阻抗值,已经放大100倍,单位是kΩ
|
||||||
|
@param data: 准备计算的通道数据,每通道200个值,注意不要把信号打标的通道传进来
|
||||||
|
@return: 返回各个通道的阻抗值
|
||||||
|
'''
|
||||||
|
impedanceList = []
|
||||||
|
for channelindex in range(data.shape[0]):
|
||||||
|
if len(data[channelindex]) > 0:
|
||||||
|
data_list = []
|
||||||
|
# 设计陷波滤波器,去除50Hz成分
|
||||||
|
is50filter = True
|
||||||
|
if is50filter:
|
||||||
|
b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽,1000是采样频率
|
||||||
|
data_list = signal.lfilter(b, a, data[channelindex].tolist())
|
||||||
|
|
||||||
|
else:
|
||||||
|
data_list.extend(data[channelindex].tolist())
|
||||||
|
|
||||||
|
data_list = data_list[-1000:]
|
||||||
|
# 执行FFT
|
||||||
|
fft_result = np.fft.fft(data_list)
|
||||||
|
fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果
|
||||||
|
freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴
|
||||||
|
|
||||||
|
# y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())],
|
||||||
|
# fft_magnitude[1:-1] * 2 / len(t[0].tolist()),
|
||||||
|
# [fft_magnitude[-1] / len(t[0].tolist())]))
|
||||||
|
|
||||||
|
# 找到幅值最大的频率成分的索引(忽略直流分量,即索引0)
|
||||||
|
max_index = np.argmax(fft_magnitude[1:])
|
||||||
|
|
||||||
|
# 获取最大幅值的频率索引(加上1,因为索引0是直流分量)
|
||||||
|
freq_index = max_index + 1
|
||||||
|
|
||||||
|
# 获取最大幅值
|
||||||
|
max_magnitude = fft_magnitude[freq_index]
|
||||||
|
|
||||||
|
# 阻抗
|
||||||
|
import math
|
||||||
|
result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4
|
||||||
|
result *= 0.44 * 100 # 统一放大100倍
|
||||||
|
impedanceList.append(int(result))
|
||||||
|
# print(max_magnitude, result)
|
||||||
|
else:
|
||||||
|
impedanceList.append(0)
|
||||||
|
# impedances = ":".join(map(str, impedanceList))
|
||||||
|
impedances = np.array(impedanceList)
|
||||||
|
impedances = impedances[:8]
|
||||||
|
return impedances
|
||||||
|
def getData(self,count):
|
||||||
|
'''
|
||||||
|
获取最新的数据
|
||||||
|
@param count: 每通道返回的最数值数目
|
||||||
|
@return: 所有通道的最新count个数值
|
||||||
|
'''
|
||||||
|
data=None
|
||||||
|
try:
|
||||||
|
with self.RingBufferLock:
|
||||||
|
data = self.__ringBuffer.getData(count)
|
||||||
|
except:
|
||||||
|
print("锁:读取异常")
|
||||||
|
# self.RingBufferLock.release()
|
||||||
|
|
||||||
|
|
||||||
|
return data
|
||||||
|
def GetDataLenCount(self):
|
||||||
|
'''
|
||||||
|
获取最新缓存中每个通道的数量
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
return self.__ringBuffer.nUpdate
|
||||||
|
|
||||||
|
def ResetAll(self):
|
||||||
|
'''
|
||||||
|
清空缓存
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
with self.RingBufferLock:
|
||||||
|
self.__ringBuffer.resetAllPara()
|
||||||
|
def stop(self):
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Usage
|
||||||
|
Linker = SunnyLinker64('127.0.0.1', 5086, 1000, 65)
|
||||||
|
Linker.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(0.005)
|
||||||
|
if(Linker.count()>0):
|
||||||
|
# print(Linker.ringBuffer.nUpdate)
|
||||||
|
t = Linker.getData()
|
||||||
|
print(t.shape[1], Linker.count())
|
||||||
|
# Linker.ringBuffer.nUpdate=0
|
||||||
|
# time.sleep(0.2)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
Linker.stop()
|
||||||
193
Debug_64ch_Decoder/Device/protocol.py
Normal file
193
Debug_64ch_Decoder/Device/protocol.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
from typing import List, Tuple, Union, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolFrame:
|
||||||
|
# 协议常量
|
||||||
|
FRAME_HEADER = 0xAA
|
||||||
|
FRAME_TAIL1 = 0x55
|
||||||
|
FRAME_TAIL2 = 0x55
|
||||||
|
RESERVED_SIZE = 6
|
||||||
|
MIN_FRAME_SIZE = 13 # 帧头1 + 功能1 + 长度2 + 预留6 + CRC1 + 包尾2
|
||||||
|
MAX_DATA_LENGTH = 0xFFFF # 最大数据长度 (2字节能表示的最大值)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calculate_crc8(data: bytes) -> bytes:
|
||||||
|
"""
|
||||||
|
计算CRC8校验值
|
||||||
|
Args:
|
||||||
|
data: 需要计算CRC的数据
|
||||||
|
Returns:
|
||||||
|
一个字节的CRC值(bytes类型)
|
||||||
|
"""
|
||||||
|
crc = 0
|
||||||
|
for byte in data:
|
||||||
|
crc ^= byte
|
||||||
|
for _ in range(8):
|
||||||
|
crc = ((crc << 1) ^ 0x07 if crc & 0x80 else crc << 1) & 0xFF
|
||||||
|
return bytes([crc])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pack(cls, function, data: Union[bytes, bytearray, List[int]],
|
||||||
|
reserved: Optional[Union[bytes, bytearray, List[int]]] = None) -> bytes:
|
||||||
|
"""
|
||||||
|
协议打包函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function: 功能码 (1字节)
|
||||||
|
data: 数据块
|
||||||
|
reserved: 预留字节(6字节,可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
打包后的字节数据
|
||||||
|
"""
|
||||||
|
# 检查功能码
|
||||||
|
if function != None:
|
||||||
|
if not 0 <= function <= 0xFF:
|
||||||
|
raise ValueError("功能码必须是1字节")
|
||||||
|
|
||||||
|
# 转换数据为bytearray
|
||||||
|
if isinstance(data, list):
|
||||||
|
data = bytearray(data)
|
||||||
|
elif isinstance(data, bytes):
|
||||||
|
data = bytearray(data)
|
||||||
|
|
||||||
|
# 检查数据长度
|
||||||
|
data_length = len(data)
|
||||||
|
if data_length > cls.MAX_DATA_LENGTH:
|
||||||
|
raise ValueError(f"数据长度超过最大值 {cls.MAX_DATA_LENGTH}")
|
||||||
|
|
||||||
|
# 处理预留字节
|
||||||
|
if reserved is None:
|
||||||
|
reserved = bytearray([0] * cls.RESERVED_SIZE)
|
||||||
|
else:
|
||||||
|
if isinstance(reserved, list):
|
||||||
|
reserved = bytearray(reserved)
|
||||||
|
elif isinstance(reserved, bytes):
|
||||||
|
reserved = bytearray(reserved)
|
||||||
|
if len(reserved) != cls.RESERVED_SIZE:
|
||||||
|
raise ValueError(f"预留字节必须是{cls.RESERVED_SIZE}字节")
|
||||||
|
|
||||||
|
# 构建帧
|
||||||
|
frame = bytearray([cls.FRAME_HEADER]) # 帧头 (1字节)
|
||||||
|
if function != None:
|
||||||
|
frame.append(function) # 功能码 (1字节)
|
||||||
|
data_length+=6
|
||||||
|
|
||||||
|
# 数据长度 (2字节,大端序)
|
||||||
|
frame.append((data_length >> 8) & 0xFF) # 高字节
|
||||||
|
frame.append(data_length & 0xFF) # 低字节
|
||||||
|
|
||||||
|
if function != None:
|
||||||
|
frame.extend(reserved) # 预留字节 (6字节)
|
||||||
|
frame.extend(data) # 数据块 (变长)
|
||||||
|
|
||||||
|
# 计算CRC (从功能码开始到数据块结束)
|
||||||
|
crc = cls.calculate_crc8(frame[1:]) # 不包含帧头
|
||||||
|
frame.extend(crc) # CRC校验 (1字节)
|
||||||
|
|
||||||
|
# 添加帧尾
|
||||||
|
frame.extend([cls.FRAME_TAIL1, cls.FRAME_TAIL2]) # 帧尾 (2字节)
|
||||||
|
|
||||||
|
return bytes(frame)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unpack(cls, data: Union[bytes, bytearray]) -> Tuple[int, bytearray, bytearray]:
|
||||||
|
"""
|
||||||
|
协议解包函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 待解析的字节数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(功能码, 数据块, 预留字节)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当数据格式不正确时
|
||||||
|
"""
|
||||||
|
# 检查数据长度
|
||||||
|
if len(data) < cls.MIN_FRAME_SIZE:
|
||||||
|
raise ValueError("数据长度不足")
|
||||||
|
|
||||||
|
# 检查帧头
|
||||||
|
if data[0] != cls.FRAME_HEADER:
|
||||||
|
raise ValueError("帧头错误")
|
||||||
|
|
||||||
|
# 检查帧尾
|
||||||
|
if data[-2:] != bytes([cls.FRAME_TAIL1, cls.FRAME_TAIL2]):
|
||||||
|
raise ValueError("帧尾错误")
|
||||||
|
|
||||||
|
# 解析基本信息
|
||||||
|
function = data[1] # 功能码 (1字节)
|
||||||
|
|
||||||
|
# 数据长度 (2字节,大端序)
|
||||||
|
data_length = (data[2] << 8) | data[3]
|
||||||
|
|
||||||
|
reserved = data[4:10] # 预留字节 (6字节)
|
||||||
|
|
||||||
|
# 检查数据长度
|
||||||
|
expected_length = cls.MIN_FRAME_SIZE + data_length
|
||||||
|
if len(data) != expected_length:
|
||||||
|
raise ValueError(f"数据长度不匹配: 期望{expected_length}字节,实际{len(data)}字节")
|
||||||
|
|
||||||
|
# 提取数据块
|
||||||
|
payload = data[10:10 + data_length]
|
||||||
|
|
||||||
|
# 验证CRC (从功能码开始到数据块结束)
|
||||||
|
received_crc = data[-3]
|
||||||
|
calculated_crc = cls.calculate_crc8(data[1:-3])[0] # 获取字节值
|
||||||
|
|
||||||
|
if received_crc != calculated_crc:
|
||||||
|
raise ValueError(f"CRC校验失败: 期望{calculated_crc:02X},实际{received_crc:02X}")
|
||||||
|
|
||||||
|
return function, bytearray(payload), bytearray(reserved)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def print_hex(data: bytes, label: str = ""):
|
||||||
|
"""打印十六进制数据,并按字节添加空格"""
|
||||||
|
hex_str = ' '.join([f"{b:02X}" for b in data])
|
||||||
|
if label:
|
||||||
|
print(f"{label}: {hex_str}")
|
||||||
|
else:
|
||||||
|
print(hex_str)
|
||||||
|
|
||||||
|
|
||||||
|
def print_frame_details(data: bytes):
|
||||||
|
"""打印帧的详细信息"""
|
||||||
|
print("帧详细信息:")
|
||||||
|
print(f"帧头: {data[0]:02X}")
|
||||||
|
print(f"功能码: {data[1]:02X}")
|
||||||
|
print(f"数据长度: {data[2]:02X} {data[3]:02X} ({(data[2] << 8) | data[3]}字节)")
|
||||||
|
print(f"预留字节: {' '.join([f'{b:02X}' for b in data[4:10]])}")
|
||||||
|
data_length = (data[2] << 8) | data[3]
|
||||||
|
print(f"数据块: {' '.join([f'{b:02X}' for b in data[10:10 + data_length]])}")
|
||||||
|
print(f"CRC校验: {data[-3]:02X}")
|
||||||
|
print(f"帧尾: {data[-2]:02X} {data[-1]:02X}")
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
def example_usage():
|
||||||
|
try:
|
||||||
|
|
||||||
|
|
||||||
|
# 示例1:简单数据打包
|
||||||
|
function_code = 0x01
|
||||||
|
data = [0x1]
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, data)
|
||||||
|
print_hex(packed_data, "示例1 - 完整帧")
|
||||||
|
print_frame_details(packed_data)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 示例3:解包验证
|
||||||
|
function, payload, reserved = ProtocolFrame.unpack(packed_data)
|
||||||
|
print("解包结果:")
|
||||||
|
print(f"功能码: 0x{function:02X}")
|
||||||
|
print_hex(payload, "数据块")
|
||||||
|
print_hex(reserved, "预留字节")
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
print(f"错误: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
example_usage()
|
||||||
409
Debug_64ch_Decoder/MI/Algorithm/conformer_2class.py
Normal file
409
Debug_64ch_Decoder/MI/Algorithm/conformer_2class.py
Normal file
@@ -0,0 +1,409 @@
|
|||||||
|
"""
|
||||||
|
EEG Conformer
|
||||||
|
|
||||||
|
Convolutional Transformer for EEG decoding
|
||||||
|
|
||||||
|
Couple CNN and Transformer in a concise manner with amazing results
|
||||||
|
"""
|
||||||
|
# remember to change paths
|
||||||
|
|
||||||
|
import os
|
||||||
|
gpus = [0]
|
||||||
|
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.autograd import Variable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from torch import Tensor
|
||||||
|
from einops import rearrange
|
||||||
|
from einops.layers.torch import Rearrange, Reduce
|
||||||
|
# from common_spatial_pattern import csp
|
||||||
|
|
||||||
|
# from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from torch.backends import cudnn
|
||||||
|
cudnn.benchmark = True
|
||||||
|
cudnn.deterministic = True
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
# writer = SummaryWriter('./TensorBoardX/')
|
||||||
|
|
||||||
|
|
||||||
|
# Convolution module
|
||||||
|
# use conv to capture local features, instead of postion embedding.
|
||||||
|
class PatchEmbedding(nn.Module):
|
||||||
|
def __init__(self, emb_size=40,n_chan=8):
|
||||||
|
# self.patch_size = patch_size
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.shallownet = nn.Sequential(
|
||||||
|
nn.Conv2d(1, 40, (1, 25), (1, 1)),
|
||||||
|
nn.Conv2d(40, 40, (n_chan, 1), (1, 1)),
|
||||||
|
nn.BatchNorm2d(40),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.AvgPool2d((1, 75), (1, 15)), # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.projection = nn.Sequential(
|
||||||
|
nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)), # transpose, conv could enhance fiting ability slightly
|
||||||
|
Rearrange('b e (h) (w) -> b (h w) e'),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
b, _, _, _ = x.shape
|
||||||
|
x = self.shallownet(x)
|
||||||
|
x = self.projection(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
|
||||||
|
def __init__(self, emb_size, num_heads, dropout):
|
||||||
|
super().__init__()
|
||||||
|
self.emb_size = emb_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.keys = nn.Linear(emb_size, emb_size)
|
||||||
|
self.queries = nn.Linear(emb_size, emb_size)
|
||||||
|
self.values = nn.Linear(emb_size, emb_size)
|
||||||
|
self.att_drop = nn.Dropout(dropout)
|
||||||
|
self.projection = nn.Linear(emb_size, emb_size)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
|
||||||
|
queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
|
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
|
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
|
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
||||||
|
if mask is not None:
|
||||||
|
fill_value = torch.finfo(torch.float32).min
|
||||||
|
energy.mask_fill(~mask, fill_value)
|
||||||
|
|
||||||
|
scaling = self.emb_size ** (1 / 2)
|
||||||
|
att = F.softmax(energy / scaling, dim=-1)
|
||||||
|
att = self.att_drop(att)
|
||||||
|
out = torch.einsum('bhal, bhlv -> bhav ', att, values)
|
||||||
|
out = rearrange(out, "b h n d -> b n (h d)")
|
||||||
|
out = self.projection(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAdd(nn.Module):
|
||||||
|
def __init__(self, fn):
|
||||||
|
super().__init__()
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
res = x
|
||||||
|
x = self.fn(x, **kwargs)
|
||||||
|
x += res
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForwardBlock(nn.Sequential):
|
||||||
|
def __init__(self, emb_size, expansion, drop_p):
|
||||||
|
super().__init__(
|
||||||
|
nn.Linear(emb_size, expansion * emb_size),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(drop_p),
|
||||||
|
nn.Linear(expansion * emb_size, emb_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GELU(nn.Module):
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0)))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderBlock(nn.Sequential):
|
||||||
|
def __init__(self,
|
||||||
|
emb_size,
|
||||||
|
num_heads=10,
|
||||||
|
drop_p=0.5,
|
||||||
|
forward_expansion=4,
|
||||||
|
forward_drop_p=0.5):
|
||||||
|
super().__init__(
|
||||||
|
ResidualAdd(nn.Sequential(
|
||||||
|
nn.LayerNorm(emb_size),
|
||||||
|
MultiHeadAttention(emb_size, num_heads, drop_p),
|
||||||
|
nn.Dropout(drop_p)
|
||||||
|
)),
|
||||||
|
ResidualAdd(nn.Sequential(
|
||||||
|
nn.LayerNorm(emb_size),
|
||||||
|
FeedForwardBlock(
|
||||||
|
emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
|
||||||
|
nn.Dropout(drop_p)
|
||||||
|
)
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Sequential):
|
||||||
|
def __init__(self, depth, emb_size):
|
||||||
|
super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationHead(nn.Sequential):
|
||||||
|
def __init__(self, emb_size, n_classes):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# global average pooling
|
||||||
|
self.clshead = nn.Sequential(
|
||||||
|
Reduce('b n e -> b e', reduction='mean'),
|
||||||
|
nn.LayerNorm(emb_size),
|
||||||
|
nn.Linear(emb_size, n_classes)
|
||||||
|
)
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(2440, 256),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.Linear(256, 32),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(32, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.contiguous().view(x.size(0), -1)
|
||||||
|
out = self.fc(x)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Conformer(nn.Sequential):
|
||||||
|
def __init__(self, emb_size=40, depth=6, n_classes=2,n_chan=8, **kwargs):
|
||||||
|
super().__init__(
|
||||||
|
|
||||||
|
PatchEmbedding(emb_size,n_chan),
|
||||||
|
TransformerEncoder(depth, emb_size),
|
||||||
|
ClassificationHead(emb_size, n_classes)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExP():
|
||||||
|
def __init__(self,n_chan):
|
||||||
|
super(ExP, self).__init__()
|
||||||
|
self.n_chan = n_chan
|
||||||
|
self.batch_size = 24
|
||||||
|
self.n_epochs = 250
|
||||||
|
self.c_dim = 4
|
||||||
|
self.lr = 0.0002
|
||||||
|
self.b1 = 0.5
|
||||||
|
self.b2 = 0.999
|
||||||
|
|
||||||
|
self.start_epoch = 0
|
||||||
|
# 创建目录
|
||||||
|
os.makedirs("online_Models", exist_ok=True)
|
||||||
|
self.log_write = open("./online_Models/log_result.txt", "w")
|
||||||
|
|
||||||
|
|
||||||
|
self.Tensor = torch.cuda.FloatTensor
|
||||||
|
self.LongTensor = torch.cuda.LongTensor
|
||||||
|
|
||||||
|
self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()
|
||||||
|
|
||||||
|
self.model = Conformer(n_chan=self.n_chan).cuda()
|
||||||
|
self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
|
||||||
|
self.model = self.model.cuda()
|
||||||
|
|
||||||
|
# self.model = EEGNet().cuda()
|
||||||
|
# self.model = nn.DataParallel(self.model,device_ids=[i for i in range(len(gpus))])
|
||||||
|
# self.model = self.model.cuda()
|
||||||
|
# summary(self.model, (1, 8, 1000))
|
||||||
|
|
||||||
|
|
||||||
|
# Segmentation and Reconstruction (S&R) data augmentation
|
||||||
|
def interaug(self, timg, label):
|
||||||
|
# 确保输入是 numpy 数组(CPU)
|
||||||
|
if isinstance(timg, torch.Tensor):
|
||||||
|
timg = timg.cpu().numpy()
|
||||||
|
if isinstance(label, torch.Tensor):
|
||||||
|
label = label.cpu().numpy()
|
||||||
|
|
||||||
|
aug_data = []
|
||||||
|
aug_label = []
|
||||||
|
for cls4aug in range(2):
|
||||||
|
cls_idx = np.where(label == cls4aug + 1)
|
||||||
|
tmp_data = timg[cls_idx]
|
||||||
|
tmp_label = label[cls_idx]
|
||||||
|
tmp_aug_data = np.zeros((int(self.batch_size / 2), 1, self.n_chan, 1000))
|
||||||
|
for ri in range(int(self.batch_size / 2)):
|
||||||
|
for rj in range(8):
|
||||||
|
rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
|
||||||
|
tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :,
|
||||||
|
rj * 125:(rj + 1) * 125]
|
||||||
|
|
||||||
|
aug_data.append(tmp_aug_data)
|
||||||
|
aug_label.append(tmp_label[:int(self.batch_size / 2)])
|
||||||
|
aug_data = np.concatenate(aug_data)
|
||||||
|
aug_label = np.concatenate(aug_label)
|
||||||
|
aug_shuffle = np.random.permutation(len(aug_data))
|
||||||
|
aug_data = aug_data[aug_shuffle, :, :]
|
||||||
|
aug_label = aug_label[aug_shuffle]
|
||||||
|
|
||||||
|
# 返回 numpy 数组,由调用方决定是否移到 GPU
|
||||||
|
return aug_data, aug_label
|
||||||
|
|
||||||
|
def train(self,all_data,all_label,model_path):
|
||||||
|
all_data = np.array(all_data);all_label = np.array(all_label)
|
||||||
|
all_data = np.expand_dims(all_data, axis=1)
|
||||||
|
train_data, test_data, train_label, test_label = train_test_split(all_data, all_label, test_size=0.2,
|
||||||
|
random_state=42, stratify=all_label,shuffle=True)
|
||||||
|
|
||||||
|
# === 优化:一次性预生成增强数据,避免每个 batch 都重复计算 ===
|
||||||
|
aug_data, aug_label = self.interaug(train_data, train_label)
|
||||||
|
# 将原始数据和增强数据合并,再一起打乱
|
||||||
|
train_data_full = np.concatenate([train_data, aug_data], axis=0)
|
||||||
|
train_label_full = np.concatenate([train_label, aug_label], axis=0)
|
||||||
|
shuffle_idx = np.random.permutation(len(train_data_full))
|
||||||
|
train_data_full = train_data_full[shuffle_idx]
|
||||||
|
train_label_full = train_label_full[shuffle_idx]
|
||||||
|
|
||||||
|
img = torch.from_numpy(train_data_full)
|
||||||
|
label = torch.from_numpy(train_label_full-1)
|
||||||
|
|
||||||
|
dataset = torch.utils.data.TensorDataset(img, label)
|
||||||
|
self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
|
||||||
|
|
||||||
|
test_data = torch.from_numpy(test_data)
|
||||||
|
test_label = torch.from_numpy(test_label-1)
|
||||||
|
test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
|
||||||
|
self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)
|
||||||
|
|
||||||
|
# Optimizers
|
||||||
|
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))
|
||||||
|
|
||||||
|
test_data = Variable(test_data.type(self.Tensor))
|
||||||
|
test_label = Variable(test_label.type(self.LongTensor))
|
||||||
|
|
||||||
|
bestAcc = 0
|
||||||
|
averAcc = 0
|
||||||
|
num = 0
|
||||||
|
Y_true = 0
|
||||||
|
Y_pred = 0
|
||||||
|
|
||||||
|
# Train the cnn model
|
||||||
|
for e in range(self.n_epochs):
|
||||||
|
# in_epoch = time.time()
|
||||||
|
self.model.train()
|
||||||
|
for i, (img, label) in enumerate(self.dataloader):
|
||||||
|
|
||||||
|
img = Variable(img.cuda().type(self.Tensor))
|
||||||
|
label = Variable(label.cuda().type(self.LongTensor))
|
||||||
|
|
||||||
|
outputs = self.model(img)
|
||||||
|
|
||||||
|
loss = self.criterion_cls(outputs, label)
|
||||||
|
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
|
# out_epoch = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
# test process
|
||||||
|
if (e + 1) % 1 == 0:
|
||||||
|
self.model.eval()
|
||||||
|
Cls = self.model(test_data)
|
||||||
|
|
||||||
|
loss_test = self.criterion_cls(Cls, test_label)
|
||||||
|
y_pred = torch.max(Cls, 1)[1]
|
||||||
|
acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
|
||||||
|
train_pred = torch.max(outputs, 1)[1]
|
||||||
|
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
|
||||||
|
|
||||||
|
print('Epoch:', e,
|
||||||
|
' Train loss: %.6f' % loss.detach().cpu().numpy(),
|
||||||
|
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
|
||||||
|
' Train accuracy %.6f' % train_acc,
|
||||||
|
' Test accuracy is %.6f' % acc)
|
||||||
|
|
||||||
|
self.log_write.write(str(e) + " " + str(acc) + "\n")
|
||||||
|
num = num + 1
|
||||||
|
averAcc = averAcc + acc
|
||||||
|
if acc > bestAcc:
|
||||||
|
bestAcc = acc
|
||||||
|
Y_true = test_label
|
||||||
|
Y_pred = y_pred
|
||||||
|
|
||||||
|
|
||||||
|
torch.save(self.model, model_path)
|
||||||
|
averAcc = averAcc / num
|
||||||
|
print('The average accuracy is:', averAcc)
|
||||||
|
print('The best accuracy is:', bestAcc)
|
||||||
|
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
|
||||||
|
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
|
||||||
|
|
||||||
|
return bestAcc, averAcc, Y_true, Y_pred
|
||||||
|
# writer.close()
|
||||||
|
|
||||||
|
|
||||||
|
def onlineTrain(data_queue,result_queue):
|
||||||
|
import torch
|
||||||
|
print(f"[DEBUG] torch.__version__ = {torch.__version__}")
|
||||||
|
print(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}")
|
||||||
|
try:
|
||||||
|
starttime = datetime.datetime.now()
|
||||||
|
|
||||||
|
# seed_n = np.random.randint(2025)
|
||||||
|
seed_n = 1877
|
||||||
|
random.seed(seed_n)
|
||||||
|
np.random.seed(seed_n)
|
||||||
|
torch.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed_all(seed_n)
|
||||||
|
|
||||||
|
|
||||||
|
# 从队列获取训练数据
|
||||||
|
data = data_queue.get(timeout=30)
|
||||||
|
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
|
||||||
|
exp = ExP(n_chan)
|
||||||
|
print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path)
|
||||||
|
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
|
||||||
|
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
||||||
|
|
||||||
|
endtime = datetime.datetime.now()
|
||||||
|
print('train duration: ',str(endtime - starttime))
|
||||||
|
|
||||||
|
# 将模型或参数传回
|
||||||
|
result_queue.put({
|
||||||
|
'status': 'success',
|
||||||
|
'model_state': model_path, # 或保存路径
|
||||||
|
'timestamp': time.time()
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
result_queue.put({'status': 'error', 'msg': str(e)})
|
||||||
|
|
||||||
|
def offlineTrain(all_data,all_label,modelPath):
|
||||||
|
starttime = datetime.datetime.now()
|
||||||
|
|
||||||
|
# seed_n = np.random.randint(2025)
|
||||||
|
seed_n = 1877
|
||||||
|
print('seed is ' + str(seed_n))
|
||||||
|
random.seed(seed_n)
|
||||||
|
np.random.seed(seed_n)
|
||||||
|
torch.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed_all(seed_n)
|
||||||
|
|
||||||
|
exp = ExP()
|
||||||
|
|
||||||
|
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
|
||||||
|
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
||||||
|
|
||||||
|
endtime = datetime.datetime.now()
|
||||||
|
print('train duration: ',str(endtime - starttime))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(time.asctime(time.localtime(time.time())))
|
||||||
|
print(time.asctime(time.localtime(time.time())))
|
||||||
382
Debug_64ch_Decoder/MI/Algorithm/conformer_2class_cpu.py
Normal file
382
Debug_64ch_Decoder/MI/Algorithm/conformer_2class_cpu.py
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
"""
|
||||||
|
EEG Conformer
|
||||||
|
|
||||||
|
Convolutional Transformer for EEG decoding
|
||||||
|
|
||||||
|
Couple CNN and Transformer in a concise manner with amazing results
|
||||||
|
"""
|
||||||
|
# remember to change paths
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch import nn
|
||||||
|
from torch import Tensor
|
||||||
|
from einops import rearrange
|
||||||
|
from einops.layers.torch import Rearrange, Reduce
|
||||||
|
from torch.backends import cudnn
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
# writer = SummaryWriter('./TensorBoardX/')
|
||||||
|
|
||||||
|
|
||||||
|
# Convolution module
|
||||||
|
# use conv to capture local features, instead of postion embedding.
|
||||||
|
class PatchEmbedding(nn.Module):
|
||||||
|
def __init__(self, emb_size=40):
|
||||||
|
# self.patch_size = patch_size
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.shallownet = nn.Sequential(
|
||||||
|
nn.Conv2d(1, 40, (1, 25), (1, 1)),
|
||||||
|
nn.Conv2d(40, 40, (8, 1), (1, 1)),
|
||||||
|
nn.BatchNorm2d(40),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.AvgPool2d((1, 75), (1, 15)), # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.projection = nn.Sequential(
|
||||||
|
nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)), # transpose, conv could enhance fiting ability slightly
|
||||||
|
Rearrange('b e (h) (w) -> b (h w) e'),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
b, _, _, _ = x.shape
|
||||||
|
x = self.shallownet(x)
|
||||||
|
x = self.projection(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
|
||||||
|
def __init__(self, emb_size, num_heads, dropout):
|
||||||
|
super().__init__()
|
||||||
|
self.emb_size = emb_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.keys = nn.Linear(emb_size, emb_size)
|
||||||
|
self.queries = nn.Linear(emb_size, emb_size)
|
||||||
|
self.values = nn.Linear(emb_size, emb_size)
|
||||||
|
self.att_drop = nn.Dropout(dropout)
|
||||||
|
self.projection = nn.Linear(emb_size, emb_size)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
|
||||||
|
queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
|
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
|
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
|
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
||||||
|
if mask is not None:
|
||||||
|
fill_value = torch.finfo(torch.float32).min
|
||||||
|
energy.mask_fill(~mask, fill_value)
|
||||||
|
|
||||||
|
scaling = self.emb_size ** (1 / 2)
|
||||||
|
att = F.softmax(energy / scaling, dim=-1)
|
||||||
|
att = self.att_drop(att)
|
||||||
|
out = torch.einsum('bhal, bhlv -> bhav ', att, values)
|
||||||
|
out = rearrange(out, "b h n d -> b n (h d)")
|
||||||
|
out = self.projection(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAdd(nn.Module):
|
||||||
|
def __init__(self, fn):
|
||||||
|
super().__init__()
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
res = x
|
||||||
|
x = self.fn(x, **kwargs)
|
||||||
|
x += res
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForwardBlock(nn.Sequential):
|
||||||
|
def __init__(self, emb_size, expansion, drop_p):
|
||||||
|
super().__init__(
|
||||||
|
nn.Linear(emb_size, expansion * emb_size),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(drop_p),
|
||||||
|
nn.Linear(expansion * emb_size, emb_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GELU(nn.Module):
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0)))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderBlock(nn.Sequential):
|
||||||
|
def __init__(self,
|
||||||
|
emb_size,
|
||||||
|
num_heads=10,
|
||||||
|
drop_p=0.5,
|
||||||
|
forward_expansion=4,
|
||||||
|
forward_drop_p=0.5):
|
||||||
|
super().__init__(
|
||||||
|
ResidualAdd(nn.Sequential(
|
||||||
|
nn.LayerNorm(emb_size),
|
||||||
|
MultiHeadAttention(emb_size, num_heads, drop_p),
|
||||||
|
nn.Dropout(drop_p)
|
||||||
|
)),
|
||||||
|
ResidualAdd(nn.Sequential(
|
||||||
|
nn.LayerNorm(emb_size),
|
||||||
|
FeedForwardBlock(
|
||||||
|
emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
|
||||||
|
nn.Dropout(drop_p)
|
||||||
|
)
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Sequential):
|
||||||
|
def __init__(self, depth, emb_size):
|
||||||
|
super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationHead(nn.Sequential):
|
||||||
|
def __init__(self, emb_size, n_classes):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# global average pooling
|
||||||
|
self.clshead = nn.Sequential(
|
||||||
|
Reduce('b n e -> b e', reduction='mean'),
|
||||||
|
nn.LayerNorm(emb_size),
|
||||||
|
nn.Linear(emb_size, n_classes)
|
||||||
|
)
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(2440, 256),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.Linear(256, 32),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(32, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.contiguous().view(x.size(0), -1)
|
||||||
|
out = self.fc(x)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Conformer(nn.Sequential):
|
||||||
|
def __init__(self, emb_size=40, depth=6, n_classes=2, **kwargs):
|
||||||
|
super().__init__(
|
||||||
|
|
||||||
|
PatchEmbedding(emb_size),
|
||||||
|
TransformerEncoder(depth, emb_size),
|
||||||
|
ClassificationHead(emb_size, n_classes)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExP():
|
||||||
|
def __init__(self):
|
||||||
|
super(ExP, self).__init__()
|
||||||
|
self.batch_size = 24
|
||||||
|
self.n_epochs = 250
|
||||||
|
self.c_dim = 4
|
||||||
|
self.lr = 0.0002
|
||||||
|
self.b1 = 0.5
|
||||||
|
self.b2 = 0.999
|
||||||
|
|
||||||
|
self.start_epoch = 0
|
||||||
|
|
||||||
|
self.log_write = open("./online_Models/log_result.txt", "w")
|
||||||
|
|
||||||
|
# 自动选择设备:有 GPU 用 GPU,否则用 CPU
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
# self.device = torch.device("cpu")
|
||||||
|
print(f"Using device: {self.device}")
|
||||||
|
|
||||||
|
# 定义张量类型(不再强制使用 cuda)
|
||||||
|
self.Tensor = torch.FloatTensor
|
||||||
|
self.LongTensor = torch.LongTensor
|
||||||
|
|
||||||
|
# 将模型移到指定设备
|
||||||
|
self.model = Conformer().to(self.device)
|
||||||
|
|
||||||
|
# 损失函数也移到设备
|
||||||
|
self.criterion_cls = torch.nn.CrossEntropyLoss().to(self.device)
|
||||||
|
|
||||||
|
# self.model = EEGNet().cuda()
|
||||||
|
# self.model = nn.DataParallel(self.model,device_ids=[i for i in range(len(gpus))])
|
||||||
|
# self.model = self.model.cuda()
|
||||||
|
# summary(self.model, (1, 8, 1000))
|
||||||
|
|
||||||
|
|
||||||
|
# Segmentation and Reconstruction (S&R) data augmentation
|
||||||
|
def interaug(self, timg, label):
|
||||||
|
aug_data = []
|
||||||
|
aug_label = []
|
||||||
|
for cls4aug in range(2):
|
||||||
|
cls_idx = np.where(label == cls4aug + 1)
|
||||||
|
tmp_data = timg[cls_idx]
|
||||||
|
tmp_label = label[cls_idx]
|
||||||
|
tmp_aug_data = np.zeros((int(self.batch_size / 2), 1, 8, 1000))
|
||||||
|
for ri in range(int(self.batch_size / 2)):
|
||||||
|
for rj in range(8):
|
||||||
|
rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
|
||||||
|
tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :,
|
||||||
|
rj * 125:(rj + 1) * 125]
|
||||||
|
|
||||||
|
aug_data.append(tmp_aug_data)
|
||||||
|
aug_label.append(tmp_label[:int(self.batch_size / 2)])
|
||||||
|
aug_data = np.concatenate(aug_data)
|
||||||
|
aug_label = np.concatenate(aug_label)
|
||||||
|
aug_shuffle = np.random.permutation(len(aug_data))
|
||||||
|
aug_data = aug_data[aug_shuffle, :, :]
|
||||||
|
aug_label = aug_label[aug_shuffle]
|
||||||
|
|
||||||
|
aug_data = torch.from_numpy(aug_data).float().to(self.device)
|
||||||
|
aug_label = torch.from_numpy(aug_label - 1).long().to(self.device)
|
||||||
|
return aug_data, aug_label
|
||||||
|
|
||||||
|
def train(self,all_data,all_label,model_path):
|
||||||
|
all_data = np.array(all_data);all_label = np.array(all_label)
|
||||||
|
all_data = np.expand_dims(all_data, axis=1)
|
||||||
|
train_data, test_data, train_label, test_label = train_test_split(all_data, all_label, test_size=0.2,
|
||||||
|
random_state=42, stratify=all_label,shuffle=True)
|
||||||
|
# 转为 Tensor
|
||||||
|
img = torch.from_numpy(train_data).float().to(self.device)
|
||||||
|
label = torch.from_numpy(train_label - 1).long().to(self.device)
|
||||||
|
|
||||||
|
dataset = torch.utils.data.TensorDataset(img, label)
|
||||||
|
self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
|
||||||
|
|
||||||
|
test_data = torch.from_numpy(test_data).float().to(self.device)
|
||||||
|
test_label = torch.from_numpy(test_label - 1).long().to(self.device)
|
||||||
|
test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
|
||||||
|
self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)
|
||||||
|
|
||||||
|
# Optimizers
|
||||||
|
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))
|
||||||
|
|
||||||
|
bestAcc = 0
|
||||||
|
averAcc = 0
|
||||||
|
num = 0
|
||||||
|
Y_true = 0
|
||||||
|
Y_pred = 0
|
||||||
|
|
||||||
|
# Train the cnn model
|
||||||
|
for e in range(self.n_epochs):
|
||||||
|
# in_epoch = time.time()
|
||||||
|
self.model.train()
|
||||||
|
for i, (img, label) in enumerate(self.dataloader):
|
||||||
|
|
||||||
|
# data augmentation
|
||||||
|
aug_data, aug_label = self.interaug(train_data, train_label)
|
||||||
|
img = torch.cat((img, aug_data))
|
||||||
|
label = torch.cat((label, aug_label))
|
||||||
|
|
||||||
|
|
||||||
|
outputs = self.model(img)
|
||||||
|
|
||||||
|
loss = self.criterion_cls(outputs, label)
|
||||||
|
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
|
# out_epoch = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
# test process
|
||||||
|
if (e + 1) % 1 == 0:
|
||||||
|
self.model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
Cls = self.model(test_data)
|
||||||
|
|
||||||
|
loss_test = self.criterion_cls(Cls, test_label)
|
||||||
|
y_pred = torch.max(Cls, 1)[1]
|
||||||
|
acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
|
||||||
|
train_pred = torch.max(outputs, 1)[1]
|
||||||
|
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
|
||||||
|
|
||||||
|
print('Epoch:', e,
|
||||||
|
' Train loss: %.6f' % loss.detach().cpu().numpy(),
|
||||||
|
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
|
||||||
|
' Train accuracy %.6f' % train_acc,
|
||||||
|
' Test accuracy is %.6f' % acc)
|
||||||
|
|
||||||
|
self.log_write.write(str(e) + " " + str(acc) + "\n")
|
||||||
|
num = num + 1
|
||||||
|
averAcc = averAcc + acc
|
||||||
|
if acc > bestAcc:
|
||||||
|
bestAcc = acc
|
||||||
|
Y_true = test_label
|
||||||
|
Y_pred = y_pred
|
||||||
|
|
||||||
|
|
||||||
|
torch.save(self.model, model_path)
|
||||||
|
averAcc = averAcc / num
|
||||||
|
print('The average accuracy is:', averAcc)
|
||||||
|
print('The best accuracy is:', bestAcc)
|
||||||
|
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
|
||||||
|
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
|
||||||
|
|
||||||
|
return bestAcc, averAcc, Y_true, Y_pred
|
||||||
|
# writer.close()
|
||||||
|
|
||||||
|
|
||||||
|
def onlineTrain(data_queue,result_queue):
|
||||||
|
try:
|
||||||
|
starttime = datetime.datetime.now()
|
||||||
|
|
||||||
|
# seed_n = np.random.randint(2025)
|
||||||
|
seed_n = 1877
|
||||||
|
random.seed(seed_n)
|
||||||
|
np.random.seed(seed_n)
|
||||||
|
torch.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed_all(seed_n)
|
||||||
|
|
||||||
|
exp = ExP()
|
||||||
|
# 从队列获取训练数据
|
||||||
|
data = data_queue.get(timeout=30)
|
||||||
|
all_data, all_label,model_path = data['data'], data['label'],data['modelPath']
|
||||||
|
print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path)
|
||||||
|
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
|
||||||
|
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
||||||
|
|
||||||
|
endtime = datetime.datetime.now()
|
||||||
|
print('train duration: ',str(endtime - starttime))
|
||||||
|
|
||||||
|
# 将模型或参数传回
|
||||||
|
result_queue.put({
|
||||||
|
'status': 'success',
|
||||||
|
'model_state': model_path, # 或保存路径
|
||||||
|
'timestamp': time.time()
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
result_queue.put({'status': 'error', 'msg': str(e)})
|
||||||
|
|
||||||
|
def offlineTrain(all_data,all_label,modelPath):
|
||||||
|
starttime = datetime.datetime.now()
|
||||||
|
|
||||||
|
# seed_n = np.random.randint(2025)
|
||||||
|
seed_n = 1877
|
||||||
|
print('seed is ' + str(seed_n))
|
||||||
|
random.seed(seed_n)
|
||||||
|
np.random.seed(seed_n)
|
||||||
|
torch.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed_all(seed_n)
|
||||||
|
|
||||||
|
exp = ExP()
|
||||||
|
|
||||||
|
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
|
||||||
|
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
||||||
|
|
||||||
|
endtime = datetime.datetime.now()
|
||||||
|
print('train duration: ',str(endtime - starttime))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(time.asctime(time.localtime(time.time())))
|
||||||
|
print(time.asctime(time.localtime(time.time())))
|
||||||
184
Debug_64ch_Decoder/MI/Algorithm/otherModels.py
Normal file
184
Debug_64ch_Decoder/MI/Algorithm/otherModels.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
from torchsummary import summary
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def weights_init(m):
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.xavier_uniform_(m.weight)
|
||||||
|
# nn.init.constant(m.bias, 0) # bias may be none
|
||||||
|
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
nn.init.xavier_uniform_(m.weight)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def square_activation(x):
|
||||||
|
return torch.square(x)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_log(x):
|
||||||
|
return torch.clip(torch.log(x), min=1e-7, max=1e7)
|
||||||
|
|
||||||
|
|
||||||
|
class ShallowConvNet(nn.Module):
|
||||||
|
def __init__(self, num_classes=3, chans=19, samples=768):
|
||||||
|
super(ShallowConvNet, self).__init__()
|
||||||
|
self.conv_nums = 40
|
||||||
|
self.features = nn.Sequential(
|
||||||
|
nn.Conv2d(1, self.conv_nums, (1, 25)),
|
||||||
|
nn.Conv2d(self.conv_nums, self.conv_nums, (chans, 1), bias=False),
|
||||||
|
nn.BatchNorm2d(self.conv_nums)
|
||||||
|
)
|
||||||
|
self.avgpool = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15))
|
||||||
|
self.dropout = nn.Dropout()
|
||||||
|
|
||||||
|
out = torch.ones((1, 1, chans, samples))
|
||||||
|
out = self.features(out)
|
||||||
|
out = self.avgpool(out)
|
||||||
|
n_out_time = out.cpu().data.numpy().shape
|
||||||
|
self.classifier = nn.Linear(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.features(x)
|
||||||
|
x = square_activation(x)
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = safe_log(x)
|
||||||
|
x = self.dropout(x)
|
||||||
|
|
||||||
|
features = torch.flatten(x, 1)
|
||||||
|
cls = self.classifier(features)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
class EEGNet(nn.Module):
|
||||||
|
def __init__(self, num_classes=2, chans=8, samples=1000, dropout_rate=0.5, kernel_length=64, F1=8,
|
||||||
|
F2=16,):
|
||||||
|
super(EEGNet, self).__init__()
|
||||||
|
|
||||||
|
self.features = nn.Sequential(
|
||||||
|
nn.Conv2d(1, F1, kernel_size=(1, kernel_length), bias=False),
|
||||||
|
nn.BatchNorm2d(F1),
|
||||||
|
nn.Conv2d(F1, F1, kernel_size=(chans, 1), groups=F1, bias=False), # groups=F1 for depthWiseConv
|
||||||
|
nn.BatchNorm2d(F1),
|
||||||
|
nn.ELU(inplace=True),
|
||||||
|
# nn.ReLU(),
|
||||||
|
nn.AvgPool2d((1, 4)),
|
||||||
|
nn.Dropout(dropout_rate),
|
||||||
|
# for SeparableCon2D
|
||||||
|
# SeparableConv2D(F1, F2, kernel1_size=(1, 16), bias=False),
|
||||||
|
nn.Conv2d(F1, F1, kernel_size=(1, 16), groups=F1, bias=False), # groups=F1 for depthWiseConv
|
||||||
|
nn.BatchNorm2d(F1),
|
||||||
|
nn.ELU(inplace=True),
|
||||||
|
# nn.ReLU(),
|
||||||
|
nn.Conv2d(F1, F2, kernel_size=(1, 1), groups=1, bias=False), # point-wise cnn
|
||||||
|
nn.BatchNorm2d(F2),
|
||||||
|
# nn.ReLU(),
|
||||||
|
nn.ELU(inplace=True),
|
||||||
|
nn.AvgPool2d((1, 8)),
|
||||||
|
nn.Dropout(p=dropout_rate),
|
||||||
|
# nn.Dropout(p=0.5),
|
||||||
|
)
|
||||||
|
out = torch.ones((1, 1, chans, samples))
|
||||||
|
out = self.features(out)
|
||||||
|
n_out_time = out.cpu().data.numpy().shape
|
||||||
|
self.classifier = nn.Linear(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
conv_features = self.features(x)
|
||||||
|
features = torch.flatten(conv_features, 1)
|
||||||
|
cls = self.classifier(features)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
class LMDA(nn.Module):
|
||||||
|
"""
|
||||||
|
LMDA-Net for the paper
|
||||||
|
"""
|
||||||
|
def __init__(self, chans=19, samples=768, num_classes=3, depth=9, kernel=75, channel_depth1=24, channel_depth2=9,
|
||||||
|
ave_depth=1, avepool=5):
|
||||||
|
super(LMDA, self).__init__()
|
||||||
|
self.ave_depth = ave_depth
|
||||||
|
self.channel_weight = nn.Parameter(torch.randn(depth, 1, chans), requires_grad=True)
|
||||||
|
nn.init.xavier_uniform_(self.channel_weight.data)
|
||||||
|
|
||||||
|
|
||||||
|
self.time_conv = nn.Sequential(
|
||||||
|
nn.Conv2d(depth, channel_depth1, kernel_size=(1, 1), groups=1, bias=False),
|
||||||
|
nn.BatchNorm2d(channel_depth1),
|
||||||
|
nn.Conv2d(channel_depth1, channel_depth1, kernel_size=(1, kernel),
|
||||||
|
groups=channel_depth1, bias=False),
|
||||||
|
nn.BatchNorm2d(channel_depth1),
|
||||||
|
nn.GELU(),
|
||||||
|
)
|
||||||
|
# self.avgPool1 = nn.AvgPool2d((1, 24))
|
||||||
|
self.chanel_conv = nn.Sequential(
|
||||||
|
nn.Conv2d(channel_depth1, channel_depth2, kernel_size=(1, 1), groups=1, bias=False),
|
||||||
|
nn.BatchNorm2d(channel_depth2),
|
||||||
|
nn.Conv2d(channel_depth2, channel_depth2, kernel_size=(chans, 1), groups=channel_depth2, bias=False),
|
||||||
|
nn.BatchNorm2d(channel_depth2),
|
||||||
|
nn.GELU(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm = nn.Sequential(
|
||||||
|
nn.AvgPool3d(kernel_size=(1, 1, avepool)),
|
||||||
|
# nn.AdaptiveAvgPool3d((9, 1, 35)),
|
||||||
|
nn.Dropout(p=0.65),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 定义自动填充模块
|
||||||
|
out = torch.ones((1, 1, chans, samples))
|
||||||
|
out = torch.einsum('bdcw, hdc->bhcw', out, self.channel_weight)
|
||||||
|
out = self.time_conv(out)
|
||||||
|
out = self.chanel_conv(out)
|
||||||
|
out = self.norm(out)
|
||||||
|
n_out_time = out.cpu().data.numpy().shape
|
||||||
|
print('In ShallowNet, n_out_time shape: ', n_out_time)
|
||||||
|
self.classifier = nn.Linear(n_out_time[-1]*n_out_time[-2]*n_out_time[-3], num_classes)
|
||||||
|
|
||||||
|
def EEGDepthAttention(self, x):
|
||||||
|
# x: input features with shape [N, C, H, W]
|
||||||
|
|
||||||
|
N, C, H, W = x.size()
|
||||||
|
# K = W if W % 2 else W + 1
|
||||||
|
k = 7
|
||||||
|
adaptive_pool = nn.AdaptiveAvgPool2d((1, W))
|
||||||
|
conv = nn.Conv2d(1, 1, kernel_size=(k, 1), padding=(k//2, 0), bias=True).to(x.device) # original kernel k
|
||||||
|
nn.init.xavier_uniform_(conv.weight)
|
||||||
|
nn.init.constant_(conv.bias, 0)
|
||||||
|
softmax = nn.Softmax(dim=-2)
|
||||||
|
x_pool = adaptive_pool(x)
|
||||||
|
x_transpose = x_pool.transpose(-2, -3)
|
||||||
|
y = conv(x_transpose)
|
||||||
|
y = softmax(y)
|
||||||
|
y = y.transpose(-2, -3)
|
||||||
|
return y * C * x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = torch.einsum('bdcw, hdc->bhcw', x, self.channel_weight)
|
||||||
|
|
||||||
|
x_time = self.time_conv(x) # batch, depth1, channel, samples_
|
||||||
|
x_time = self.EEGDepthAttention(x_time) # DA1
|
||||||
|
|
||||||
|
x = self.chanel_conv(x_time) # batch, depth2, 1, samples_
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
features = torch.flatten(x, 1)
|
||||||
|
cls = self.classifier(features)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
model = ShallowConvNet(num_classes=4, chans=22, samples=1125).cuda()
|
||||||
|
a = torch.randn(12, 1, 3, 875).cuda().float()
|
||||||
|
l2 = model(a)
|
||||||
|
model_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
|
||||||
|
summary(model, show_input=True)
|
||||||
|
|
||||||
|
print(l2.shape)
|
||||||
|
|
||||||
34
Debug_64ch_Decoder/PubLibrary/InifileHelper.py
Normal file
34
Debug_64ch_Decoder/PubLibrary/InifileHelper.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# -*-coding:utf-8 -*-
|
||||||
|
import configparser
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from audioop import error
|
||||||
|
|
||||||
|
# 打包后需要切换到 exe 所在目录来定位 config.ini
|
||||||
|
if getattr(sys, 'frozen', False):
|
||||||
|
_BASE_DIR = os.path.dirname(sys.executable)
|
||||||
|
else:
|
||||||
|
_BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
IniFileName = os.path.join(_BASE_DIR, 'config.ini')
|
||||||
|
|
||||||
|
def IniWrite(section,keyname,value):
|
||||||
|
# 创建ConfigParser对象
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(IniFileName,encoding='utf-8')
|
||||||
|
with open(IniFileName, 'w') as configfile:
|
||||||
|
if not config.has_section(section):
|
||||||
|
config.add_section(section)
|
||||||
|
config[section][keyname]=str(value)
|
||||||
|
config.write(configfile)
|
||||||
|
|
||||||
|
def IniRead(section,key):
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(IniFileName,encoding='utf-8')
|
||||||
|
return config[section][key]
|
||||||
|
except error as e:
|
||||||
|
print(e)
|
||||||
|
# 读取特定section和键的值
|
||||||
|
return '5'
|
||||||
15
Debug_64ch_Decoder/PubLibrary/RunOnce.py
Normal file
15
Debug_64ch_Decoder/PubLibrary/RunOnce.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import ctypes
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def is_program_running(name='Global\\Decoder'):
|
||||||
|
# 创建互斥体
|
||||||
|
mutex_name =name
|
||||||
|
h_mutex = ctypes.windll.kernel32.CreateMutexW(None, False, mutex_name)
|
||||||
|
|
||||||
|
# 检查互斥体是否已经存在
|
||||||
|
if ctypes.windll.kernel32.GetLastError() == 183: # ERROR_ALREADY_EXISTS
|
||||||
|
print("程序已经在运行.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
418
Debug_64ch_Decoder/SSMVEP/algorithm/base.py
Normal file
418
Debug_64ch_Decoder/SSMVEP/algorithm/base.py
Normal file
@@ -0,0 +1,418 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Authors: Swolf <swolfforever@gmail.com>
|
||||||
|
# Date: 2021/1/07
|
||||||
|
# License: MIT License
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional, List, Tuple, Union
|
||||||
|
import warnings
|
||||||
|
import numpy as np
|
||||||
|
from numpy import ndarray
|
||||||
|
from numpy.linalg import linalg
|
||||||
|
from scipy.linalg import solve, qr
|
||||||
|
from scipy.signal import sosfiltfilt, cheby1, cheb1ord
|
||||||
|
from sklearn.base import BaseEstimator, TransformerMixin, clone
|
||||||
|
|
||||||
|
|
||||||
|
def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
|
||||||
|
"""Transform spatial filters to spatial patterns based on paper [1]_.
|
||||||
|
Referring to the method mentioned in article [1],the constructed spatial filter only shows how to combine
|
||||||
|
information from different channels to extract signals of interest from EEG signals, but if our goal is
|
||||||
|
neurophysiological interpretation or visualization of weights, activation patterns need to be constructed
|
||||||
|
from the obtained spatial filters.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
W : ndarray
|
||||||
|
Spatial filters, shape (n_channels, n_filters).
|
||||||
|
Cx : ndarray
|
||||||
|
Covariance matrix of eeg data, shape (n_channels, n_channels).
|
||||||
|
Cs : ndarray
|
||||||
|
Covariance matrix of source data, shape (n_channels, n_channels).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A : ndarray
|
||||||
|
Spatial patterns, shape (n_channels, n_patterns), each column is a spatial pattern.
|
||||||
|
|
||||||
|
References
|
||||||
|
----------
|
||||||
|
.. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate neuroimaging.
|
||||||
|
Neuroimage 87 (2014): 96-110.
|
||||||
|
"""
|
||||||
|
# use linalg.solve instead of inv, makes it more stable
|
||||||
|
# see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
|
||||||
|
# and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
|
||||||
|
A = solve(Cs.T, np.dot(Cx, W).T).T
|
||||||
|
return A
|
||||||
|
|
||||||
|
|
||||||
|
class FilterBank(BaseEstimator, TransformerMixin):
|
||||||
|
"""
|
||||||
|
Filter bank decomposition is a bandpass filter array that divides the input signal into
|
||||||
|
multiple subband components and obtains the eigenvalues of each subband component.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
base_estimator : class
|
||||||
|
Estimator for model training and feature extraction.
|
||||||
|
filterbank : list[ndarray]
|
||||||
|
A bandpass filter bank used to divide the input signal into multiple subband components.
|
||||||
|
n_jobs : int
|
||||||
|
Sets the number of CPU working cores. The default is None.
|
||||||
|
|
||||||
|
References
|
||||||
|
----------
|
||||||
|
.. [1] Chen X, Wang Y, Nakanishi M, et al. High-speed spelling with a noninvasive brain-computer interface[J].
|
||||||
|
Proceedings of the national academy of sciences, 2015, 112(44): E6058-E6067.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_estimator: BaseEstimator,
|
||||||
|
filterbank: List[ndarray],
|
||||||
|
n_jobs: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.base_estimator = base_estimator
|
||||||
|
self.filterbank = filterbank
|
||||||
|
self.n_jobs = n_jobs
|
||||||
|
|
||||||
|
def fit(self, X: ndarray, y: Optional[ndarray] = None, **kwargs):
|
||||||
|
"""
|
||||||
|
Training model
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : None
|
||||||
|
Training signal (parameters can be ignored, only used to maintain code structure).
|
||||||
|
y : None
|
||||||
|
Label data (ibid., ignorable).
|
||||||
|
Yf : None
|
||||||
|
Reference signal (ibid., ignorable).
|
||||||
|
"""
|
||||||
|
self.estimators_ = [
|
||||||
|
clone(self.base_estimator) for _ in range(len(self.filterbank))
|
||||||
|
]
|
||||||
|
X = self.transform_filterbank(X)
|
||||||
|
for i, est in enumerate(self.estimators_):
|
||||||
|
est.fit(X[i], y, **kwargs)
|
||||||
|
# def wrapper(est, X, y, kwargs):
|
||||||
|
# est.fit(X, y, **kwargs)
|
||||||
|
# return est
|
||||||
|
# self.estimators_ = Parallel(n_jobs=self.n_jobs)(
|
||||||
|
# delayed(wrapper)(est, X[i], y, kwargs) for i, est in enumerate(self.estimators_))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, X: ndarray, **kwargs):
|
||||||
|
"""
|
||||||
|
The parameters stored in self are used to convert X into features, and X is filtered through the filter bank to
|
||||||
|
obtain the eigenvalues of each subband component.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray, shape(n_trials, n_channels, n_samples)
|
||||||
|
Test the signal.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
feat : ndarray, shape(n_trials, n_fre)
|
||||||
|
Feature array.
|
||||||
|
"""
|
||||||
|
X = self.transform_filterbank(X)
|
||||||
|
feat = [est.transform(X[i], **kwargs) for i, est in enumerate(self.estimators_)]
|
||||||
|
# def wrapper(est, X, kwargs):
|
||||||
|
# retval = est.transform(X, **kwargs)
|
||||||
|
# return retval
|
||||||
|
# feat = Parallel(n_jobs=self.n_jobs)(
|
||||||
|
# delayed(wrapper)(est, X[i], kwargs) for i, est in enumerate(self.estimators_))
|
||||||
|
feat = np.concatenate(feat, axis=-1)
|
||||||
|
return feat
|
||||||
|
|
||||||
|
def transform_filterbank(self, X: ndarray):
|
||||||
|
"""
|
||||||
|
The input signal is filtered through a filter bank.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray, shape(n_trials, n_channels, n_samples)
|
||||||
|
Input signal.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Xs: ndarray, shape(Nfb, n_trials, n_channels, n_samples)
|
||||||
|
Individual subband components of the input signal.
|
||||||
|
"""
|
||||||
|
Xs = np.stack([sosfiltfilt(sos, X, axis=-1) for sos in self.filterbank])
|
||||||
|
return Xs
|
||||||
|
|
||||||
|
|
||||||
|
class FilterBankSSVEP(FilterBank):
|
||||||
|
"""
|
||||||
|
Filter bank analysis for SSVEP.
|
||||||
|
The SSVEP is analyzed using filter banks, that is, multiple filters are combined to decompose the SSVEP signal
|
||||||
|
into specific segments (subbands containing the original data) and obtain its characteristic data.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filterbank : list[ndarray]
|
||||||
|
The filter bank.
|
||||||
|
base_estimator : class
|
||||||
|
Estimator for model training and feature extraction.
|
||||||
|
filterweights : ndarray
|
||||||
|
Filter weight, default is None.
|
||||||
|
n_jobs : int
|
||||||
|
Sets the number of CPU working cores. The default is None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
filterbank: List[ndarray],
|
||||||
|
base_estimator: BaseEstimator,
|
||||||
|
filterweights: Optional[ndarray] = None,
|
||||||
|
n_jobs: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.filterweights = filterweights
|
||||||
|
super().__init__(base_estimator, filterbank, n_jobs=n_jobs)
|
||||||
|
|
||||||
|
def transform(self, X: ndarray): # type: ignore[override]
|
||||||
|
"""
|
||||||
|
X is converted into features by using the parameters stored in self, and the eigenvalues of each subband
|
||||||
|
component are obtained after the input signal is filtered by the filter bank.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray, shape(n_trials, n_channels, n_samples)
|
||||||
|
Test the signal.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
features : ndarray, shape(n_trials, n_fre)
|
||||||
|
Feature array.
|
||||||
|
"""
|
||||||
|
features = super().transform(X)
|
||||||
|
if self.filterweights is None:
|
||||||
|
return features
|
||||||
|
else:
|
||||||
|
features = np.reshape(
|
||||||
|
features, (features.shape[0], len(self.filterbank), -1)
|
||||||
|
)
|
||||||
|
return np.sum(
|
||||||
|
features * self.filterweights[np.newaxis, :, np.newaxis], axis=1
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def generate_filterbank(
|
||||||
|
passbands: List[Tuple[float, float]],
|
||||||
|
stopbands: List[Tuple[float, float]],
|
||||||
|
srate: int,
|
||||||
|
order: Optional[int] = None,
|
||||||
|
rp: float = 0.5,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a filter bank, that is, obtain a bandpass filter coefficient that can divide the input signal into multiple
|
||||||
|
subband components.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
passbands : list or tuple(float, float)
|
||||||
|
Passband parameters.
|
||||||
|
stopbands : list or tuple(float, float)
|
||||||
|
Stopband parameters.
|
||||||
|
srate : float
|
||||||
|
Sampling rate.
|
||||||
|
order : int
|
||||||
|
Filter order.
|
||||||
|
rp : float
|
||||||
|
The maximum ripple allowed in the passband below the unit gain is 0.5 by default.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Filterbank:ndarray, shape(len(passbands), N, 6)
|
||||||
|
Filter bank coefficient.
|
||||||
|
"""
|
||||||
|
filterbank = []
|
||||||
|
for wp, ws in zip(passbands, stopbands):
|
||||||
|
if order is None:
|
||||||
|
N, wn = cheb1ord(wp, ws, 3, 40, fs=srate)
|
||||||
|
sos = cheby1(N, rp, wn, btype="bandpass", output="sos", fs=srate)
|
||||||
|
else:
|
||||||
|
sos = cheby1(order, rp, wp, btype="bandpass", output="sos", fs=srate)
|
||||||
|
|
||||||
|
filterbank.append(sos)
|
||||||
|
return filterbank
|
||||||
|
|
||||||
|
def process(data):
|
||||||
|
# 白化操作
|
||||||
|
meanValue = np.mat(data.mean(axis=1))
|
||||||
|
meanData = np.repeat(meanValue, data.shape[1], axis=1)
|
||||||
|
whiteTemp = data - meanData
|
||||||
|
# QR 分解
|
||||||
|
rankWhiteTemp = whiteTemp.shape[0]
|
||||||
|
whiteTemp = np.transpose(whiteTemp)
|
||||||
|
Q, R = qr(whiteTemp.A, mode='economic')
|
||||||
|
# 计算矩阵的秩
|
||||||
|
rankQ = linalg.matrix_rank(R)
|
||||||
|
if rankQ == 0:
|
||||||
|
raise ValueError('stats:canoncorr:badData')
|
||||||
|
elif rankQ <= rankWhiteTemp:
|
||||||
|
# warnings.warn('stats:canoncorr:NotFullRank')
|
||||||
|
Q = Q[:, 0:rankQ]
|
||||||
|
return Q, rankQ
|
||||||
|
|
||||||
|
def reference(listFreqs,fs, numberSmples, num_harms):
|
||||||
|
numberFrequence = len(listFreqs)
|
||||||
|
timeIndex = np.arange(1, numberSmples + 1) / fs # time index
|
||||||
|
referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples))
|
||||||
|
for frequenceIndex in range(numberFrequence):
|
||||||
|
temp = []
|
||||||
|
for harmIndex in range(1, num_harms + 1):
|
||||||
|
stimFrequence = listFreqs[frequenceIndex] # in HZ
|
||||||
|
# Sin and Cos
|
||||||
|
temp.extend([np.sin(2 * np.pi * timeIndex * harmIndex * stimFrequence),
|
||||||
|
np.cos(2 * np.pi * timeIndex * harmIndex * stimFrequence)])
|
||||||
|
referenceTemp = np.mat(temp)
|
||||||
|
# 白化操作和QR分解
|
||||||
|
Q, rankQ = process(referenceTemp)
|
||||||
|
referenceData[frequenceIndex] = np.transpose(Q)
|
||||||
|
return referenceData
|
||||||
|
|
||||||
|
def generate_cca_references(
|
||||||
|
freqs: Union[ndarray, int, float],
|
||||||
|
srate,
|
||||||
|
T,
|
||||||
|
phases: Optional[Union[ndarray, int, float]] = None,
|
||||||
|
n_harmonics: int = 1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Construct a sine-cosine reference signal for canonical correlation analysis (CCA).
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
freqs : int or float
|
||||||
|
Frequency.
|
||||||
|
srate : int
|
||||||
|
Sampling rate.
|
||||||
|
T : int
|
||||||
|
Sampling time.
|
||||||
|
phases : int or float
|
||||||
|
Phase, default is None.
|
||||||
|
n_harmonics : int
|
||||||
|
The number of harmonics. The default value is 1.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Yf:ndarray, shape(srate*T, n_harmonics*2)
|
||||||
|
Sine and cosine reference signal.
|
||||||
|
"""
|
||||||
|
if isinstance(freqs, int) or isinstance(freqs, float):
|
||||||
|
freqs = np.array([freqs])
|
||||||
|
freqs = np.array(freqs)[:, np.newaxis]
|
||||||
|
if phases is None:
|
||||||
|
phases = 0
|
||||||
|
if isinstance(phases, int) or isinstance(phases, float):
|
||||||
|
phases = np.array([phases])
|
||||||
|
phases = np.array(phases)[:, np.newaxis]
|
||||||
|
t = np.linspace(0, T, int(T * srate))
|
||||||
|
|
||||||
|
Yf = []
|
||||||
|
for i in range(n_harmonics):
|
||||||
|
Yf.append(
|
||||||
|
np.stack(
|
||||||
|
[
|
||||||
|
np.sin(2 * np.pi * (i + 1) * freqs * t + np.pi * phases),
|
||||||
|
np.cos(2 * np.pi * (i + 1) * freqs * t + np.pi * phases),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
Yf = np.concatenate(Yf, axis=1)
|
||||||
|
return Yf
|
||||||
|
|
||||||
|
|
||||||
|
def sign_flip(u, s, vh=None):
|
||||||
|
"""Flip signs of SVD or EIG using the method in paper [1]_.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
u: ndarray
|
||||||
|
left singular vectors, shape (M, K).
|
||||||
|
s: ndarray
|
||||||
|
singular values, shape (K,).
|
||||||
|
vh: ndarray or None
|
||||||
|
transpose of right singular vectors, shape (K, N).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
u: ndarray
|
||||||
|
corrected left singular vectors.
|
||||||
|
s: ndarray
|
||||||
|
singular values.
|
||||||
|
vh: ndarray
|
||||||
|
transpose of corrected right singular vectors.
|
||||||
|
|
||||||
|
References
|
||||||
|
----------
|
||||||
|
.. [1] https://www.sandia.gov/~tgkolda/pubs/pubfiles/SAND2007-6422.pdf
|
||||||
|
"""
|
||||||
|
if vh is None:
|
||||||
|
total_proj = np.sum(u * s, axis=0)
|
||||||
|
signs = np.sign(total_proj)
|
||||||
|
|
||||||
|
random_idx = signs == 0
|
||||||
|
if np.any(random_idx):
|
||||||
|
signs[random_idx] = 1
|
||||||
|
warnings.warn(
|
||||||
|
"The magnitude is close to zero, the sign will become arbitrary."
|
||||||
|
)
|
||||||
|
|
||||||
|
u = u * signs
|
||||||
|
|
||||||
|
return u, s
|
||||||
|
else:
|
||||||
|
left_proj = np.sum(s[:, np.newaxis] * vh, axis=-1)
|
||||||
|
right_proj = np.sum(u * s, axis=0)
|
||||||
|
total_proj = left_proj + right_proj
|
||||||
|
signs = np.sign(total_proj)
|
||||||
|
|
||||||
|
random_idx = signs == 0
|
||||||
|
if np.any(random_idx):
|
||||||
|
signs[random_idx] = 1
|
||||||
|
warnings.warn(
|
||||||
|
"The magnitude is close to zero, the sign will become arbitrary."
|
||||||
|
)
|
||||||
|
|
||||||
|
u = u * signs
|
||||||
|
vh = signs[:, np.newaxis] * vh
|
||||||
|
|
||||||
|
return u, s, vh
|
||||||
436
Debug_64ch_Decoder/SSMVEP/algorithm/dsp.py
Normal file
436
Debug_64ch_Decoder/SSMVEP/algorithm/dsp.py
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# DSP: Discriminal Spatial Patterns
|
||||||
|
# Authors: Swolf <swolfforever@gmail.com>
|
||||||
|
# Junyang Wang <2144755928@qq.com>
|
||||||
|
# Last update date: 2022-8-11
|
||||||
|
# License: MIT License
|
||||||
|
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
from itertools import combinations
|
||||||
|
import numpy as np
|
||||||
|
from scipy.linalg import eigh
|
||||||
|
from numpy import ndarray
|
||||||
|
from scipy.linalg import solve
|
||||||
|
from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
|
||||||
|
"""Transform spatial filters to spatial patterns based on paper [1]_.
|
||||||
|
Referring to the method mentioned in article [1],the constructed spatial filter only shows how to combine
|
||||||
|
information from different channels to extract signals of interest from EEG signals, but if our goal is
|
||||||
|
neurophysiological interpretation or visualization of weights, activation patterns need to be constructed
|
||||||
|
from the obtained spatial filters.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
W : ndarray
|
||||||
|
Spatial filters, shape (n_channels, n_filters).
|
||||||
|
Cx : ndarray
|
||||||
|
Covariance matrix of eeg data, shape (n_channels, n_channels).
|
||||||
|
Cs : ndarray
|
||||||
|
Covariance matrix of source data, shape (n_channels, n_channels).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A : ndarray
|
||||||
|
Spatial patterns, shape (n_channels, n_patterns), each column is a spatial pattern.
|
||||||
|
|
||||||
|
References
|
||||||
|
----------
|
||||||
|
.. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate neuroimaging.
|
||||||
|
Neuroimage 87 (2014): 96-110.
|
||||||
|
"""
|
||||||
|
# use linalg.solve instead of inv, makes it more stable
|
||||||
|
# see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
|
||||||
|
# and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
|
||||||
|
A = solve(Cs.T, np.dot(Cx, W).T).T
|
||||||
|
return A
|
||||||
|
|
||||||
|
def isPD(B: ndarray) -> bool:
|
||||||
|
"""Returns true when input matrix is positive-definite, via Cholesky decompositon method.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
B : ndarray
|
||||||
|
Any matrix, shape (N, N)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if B is positve-definite.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
Use numpy.linalg rather than scipy.linalg. In this case, scipy.linalg has unpredictable behaviors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
_ = np.linalg.cholesky(B)
|
||||||
|
return True
|
||||||
|
except np.linalg.LinAlgError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def nearestPD(A: ndarray) -> ndarray:
|
||||||
|
"""Find the nearest positive-definite matrix to input.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
A : ndarray
|
||||||
|
Any square matrxi, shape (N, N)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A3 : ndarray
|
||||||
|
positive-definite matrix to A
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
A Python/Numpy port of John D'Errico's `nearestSPD` MATLAB code [1]_, which
|
||||||
|
origins at [2]_.
|
||||||
|
|
||||||
|
References
|
||||||
|
----------
|
||||||
|
.. [1] https://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd
|
||||||
|
.. [2] N.J. Higham, "Computing a nearest symmetric positive semidefinite matrix" (1988):
|
||||||
|
https://doi.org/10.1016/0024-3795(88)90223-6
|
||||||
|
"""
|
||||||
|
|
||||||
|
B = (A + A.T) / 2
|
||||||
|
_, s, V = np.linalg.svd(B)
|
||||||
|
|
||||||
|
H = np.dot(V.T, np.dot(np.diag(s), V))
|
||||||
|
|
||||||
|
A2 = (B + H) / 2
|
||||||
|
|
||||||
|
A3 = (A2 + A2.T) / 2
|
||||||
|
|
||||||
|
if isPD(A3):
|
||||||
|
return A3
|
||||||
|
|
||||||
|
print("Replace current matrix with the nearest positive-definite matrix.")
|
||||||
|
|
||||||
|
spacing = np.spacing(np.linalg.norm(A))
|
||||||
|
# The above is different from [1]. It appears that MATLAB's `chol` Cholesky
|
||||||
|
# decomposition will accept matrixes with exactly 0-eigenvalue, whereas
|
||||||
|
# Numpy's will not. So where [1] uses `eps(mineig)` (where `eps` is Matlab
|
||||||
|
# for `numpy.spacing`), we use the above definition. CAVEAT: our `spacing`
|
||||||
|
# will be much larger than [1]'s `eps(mineig)`, since `mineig` is usually on
|
||||||
|
# the order of 1e-16, and `eps(1e-16)` is on the order of 1e-34, whereas
|
||||||
|
# `spacing` will, for Gaussian random matrixes of small dimension, be on
|
||||||
|
# othe order of 1e-16. In practice, both ways converge, as the unit test
|
||||||
|
# below suggests.
|
||||||
|
eye = np.eye(A.shape[0])
|
||||||
|
k = 1
|
||||||
|
while not isPD(A3):
|
||||||
|
mineig = np.min(np.real(np.linalg.eigvals(A3)))
|
||||||
|
A3 += eye * (-mineig * k**2 + spacing)
|
||||||
|
k += 1
|
||||||
|
|
||||||
|
return A3
|
||||||
|
|
||||||
|
def xiang_dsp_kernel(
|
||||||
|
X: ndarray, y: ndarray
|
||||||
|
) -> Tuple[ndarray, ndarray, ndarray, ndarray]:
|
||||||
|
"""
|
||||||
|
DSP: Discriminal Spatial Patterns, only for two classes[1]_.
|
||||||
|
Import train data to solve spatial filters with DSP,
|
||||||
|
finds a projection matrix that maximize the between-class scatter matrix and
|
||||||
|
minimize the within-class scatter matrix. Currently only support for two types of data.
|
||||||
|
|
||||||
|
Author: Swolf <swolfforever@gmail.com>
|
||||||
|
|
||||||
|
Created on: 2021-1-07
|
||||||
|
|
||||||
|
Update log:
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray
|
||||||
|
EEG train data assuming removing mean, shape (n_trials, n_channels, n_samples)
|
||||||
|
y : ndarray
|
||||||
|
labels of EEG data, shape (n_trials, )
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
W : ndarray
|
||||||
|
spatial filters, shape (n_channels, n_filters)
|
||||||
|
D : ndarray
|
||||||
|
eigenvalues in descending order
|
||||||
|
M : ndarray
|
||||||
|
mean value of all classes and trials, i.e. common mode signals, shape (n_channel, n_samples)
|
||||||
|
A : ndarray
|
||||||
|
spatial patterns, shape (n_channels, n_filters)
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
the implementation removes regularization on within-class scatter matrix Sw.
|
||||||
|
|
||||||
|
References
|
||||||
|
----------
|
||||||
|
.. [1] Liao, Xiang, et al. "Combining spatial filters for the classification of single-trial EEG in
|
||||||
|
a finger movement task." IEEE Transactions on Biomedical Engineering 54.5 (2007): 821-831.
|
||||||
|
"""
|
||||||
|
X, y = np.copy(X), np.copy(y)
|
||||||
|
labels = np.unique(y)
|
||||||
|
X = np.reshape(X, (-1, *X.shape[-2:]))
|
||||||
|
X = X - np.mean(X, axis=-1, keepdims=True)
|
||||||
|
# the number of each label
|
||||||
|
n_labels = np.array([np.sum(y == label) for label in labels])
|
||||||
|
# average template of all trials
|
||||||
|
M = np.mean(X, axis=0)
|
||||||
|
# class conditional template
|
||||||
|
Ms, Ss = zip(
|
||||||
|
*[
|
||||||
|
(
|
||||||
|
np.mean(X[y == label], axis=0),
|
||||||
|
np.sum(
|
||||||
|
np.matmul(X[y == label], np.swapaxes(X[y == label], -1, -2)), axis=0
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for label in labels
|
||||||
|
]
|
||||||
|
)
|
||||||
|
Ms, Ss = np.stack(Ms), np.stack(Ss)
|
||||||
|
# within-class scatter matrix
|
||||||
|
Sw = np.sum(
|
||||||
|
Ss
|
||||||
|
- n_labels[:, np.newaxis, np.newaxis] * np.matmul(Ms, np.swapaxes(Ms, -1, -2)),
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
Ms = Ms - M
|
||||||
|
# between-class scatter matrix
|
||||||
|
Sb = np.sum(
|
||||||
|
n_labels[:, np.newaxis, np.newaxis] * np.matmul(Ms, np.swapaxes(Ms, -1, -2)),
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
D, W = eigh(nearestPD(Sb), nearestPD(Sw))
|
||||||
|
ix = np.argsort(D)[::-1] # in descending order
|
||||||
|
D, W = D[ix], W[:, ix]
|
||||||
|
A = robust_pattern(W, Sb, W.T @ Sb @ W)
|
||||||
|
|
||||||
|
return W, D, M, A
|
||||||
|
|
||||||
|
|
||||||
|
def xiang_dsp_feature(
|
||||||
|
W: ndarray, M: ndarray, X: ndarray, n_components: int = 1
|
||||||
|
) -> ndarray:
|
||||||
|
"""
|
||||||
|
Return DSP features in paper [1]_.
|
||||||
|
|
||||||
|
Author: Swolf <swolfforever@gmail.com>
|
||||||
|
|
||||||
|
Created on: 2021-1-07
|
||||||
|
|
||||||
|
Update log:
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
W : ndarray
|
||||||
|
spatial filters from csp_kernel, shape (n_channels, n_filters)
|
||||||
|
M : ndarray
|
||||||
|
common template for all classes, shape (n_channel, n_samples)
|
||||||
|
X : ndarray
|
||||||
|
eeg test data, shape (n_trials, n_channels, n_samples)
|
||||||
|
n_components : int, optional
|
||||||
|
length of the spatial filters, first k components to use, by default 1
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
features: ndarray
|
||||||
|
features, shape (n_trials, n_components, n_samples)
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
n_components should less than half of the number of channels
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
1. instead of meaning of filtered signals in paper [1]_., we directly return filtered signals.
|
||||||
|
|
||||||
|
References
|
||||||
|
----------
|
||||||
|
.. [1] Liao, Xiang, et al. "Combining spatial filters for the classification of single-trial EEG in
|
||||||
|
a finger movement task." IEEE Transactions on Biomedical Engineering 54.5 (2007): 821-831.
|
||||||
|
"""
|
||||||
|
W, M, X = np.copy(W), np.copy(M), np.copy(X)
|
||||||
|
max_components = W.shape[1]
|
||||||
|
if n_components > max_components:
|
||||||
|
raise ValueError("n_components should less than the number of channels")
|
||||||
|
X = np.reshape(X, (-1, *X.shape[-2:]))
|
||||||
|
X = X - np.mean(X, axis=-1, keepdims=True)
|
||||||
|
# print('************: ',np.shape(W),np.shape(X),np.shape(M))
|
||||||
|
features = np.matmul(W[:, :n_components].T, X - M)
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
class DSP(BaseEstimator, TransformerMixin, ClassifierMixin):
|
||||||
|
"""
|
||||||
|
DSP: Discriminal Spatial Patterns
|
||||||
|
|
||||||
|
Author: Swolf <swolfforever@gmail.com>
|
||||||
|
|
||||||
|
Created on: 2021-1-07
|
||||||
|
|
||||||
|
Update log:
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
n_components : int
|
||||||
|
length of the spatial filter, first k components to use, by default 1
|
||||||
|
transform_method : str
|
||||||
|
method of template matching, by default ’corr‘ (pearson correlation coefficient)
|
||||||
|
classes_ : int
|
||||||
|
number of the EEG classes
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
n_components : int
|
||||||
|
length of the spatial filter, first k components to use, by default 1
|
||||||
|
transform_method : str
|
||||||
|
method of template matching, by default ’corr‘ (pearson correlation coefficient)
|
||||||
|
classes_ : int
|
||||||
|
number of the EEG classes
|
||||||
|
W_ : ndarray, shape(n_channels, n_filters)
|
||||||
|
Spatial filters, shape(n_channels, n_filters), in which n_channels = n_filters
|
||||||
|
D_ : ndarray, shape(n_filters, )
|
||||||
|
eigenvalues in descending order, shape(n_filters, )
|
||||||
|
M_ : ndarray, shape(n_channels, n_samples)
|
||||||
|
mean value of all classes and trials, i.e. common mode signals, shape(n_channels, n_samples)
|
||||||
|
A_ : ndarray, shape(n_channels, n_filters)
|
||||||
|
spatial patterns, shape(n_channels, n_filters)
|
||||||
|
templates_: ndarray, shape(n_classes, n_filters, n_samples)
|
||||||
|
templates of train data, shape(n_classes, n_filters, n_samples)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n_components: int = 1, transform_method: str = "corr"):
|
||||||
|
self.n_components = n_components
|
||||||
|
self.transform_method = transform_method
|
||||||
|
|
||||||
|
def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None):
|
||||||
|
"""
|
||||||
|
Import the train data to get a model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray
|
||||||
|
train data, shape(n_trials, n_channels, n_samples)
|
||||||
|
y : ndarray
|
||||||
|
labels of train data, shape (n_trials, )
|
||||||
|
Yf : ndarray
|
||||||
|
optional parameter
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
W_ : ndarray
|
||||||
|
spatial filters, shape (n_channels, n_filters), in which n_channels = n_filters
|
||||||
|
D_ : ndarray
|
||||||
|
eigenvalues in descending order, shape (n_filters, )
|
||||||
|
M_ : ndarray
|
||||||
|
template for all classes, shape (n_channel, n_samples)
|
||||||
|
A_ : ndarray
|
||||||
|
spatial patterns, shape (n_channels, n_filters)
|
||||||
|
templates_ : ndarray
|
||||||
|
templates of train data, shape (n_channels, n_filters, n_samples)
|
||||||
|
"""
|
||||||
|
X -= np.mean(X, axis=-1, keepdims=True)
|
||||||
|
self.classes_ = np.unique(y)
|
||||||
|
self.W_, self.D_, self.M_, self.A_ = xiang_dsp_kernel(X, y)
|
||||||
|
|
||||||
|
self.templates_ = np.stack(
|
||||||
|
[
|
||||||
|
np.mean(
|
||||||
|
xiang_dsp_feature(
|
||||||
|
self.W_, self.M_, X[y == label], n_components=self.W_.shape[1]
|
||||||
|
),
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
for label in self.classes_
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, X: ndarray):
|
||||||
|
"""
|
||||||
|
Import the test data to get features.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray
|
||||||
|
test data, shape(n_trials, n_channels, n_samples)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
feature : ndarray, shape(n_trials,n_classes)
|
||||||
|
correlation coefficients of templates of train data and features of test data, shape(n_trials, n_classes)
|
||||||
|
"""
|
||||||
|
n_components = self.n_components
|
||||||
|
X -= np.mean(X, axis=-1, keepdims=True)
|
||||||
|
features = xiang_dsp_feature(self.W_, self.M_, X, n_components=n_components)
|
||||||
|
if self.transform_method is None:
|
||||||
|
return features.reshape((features.shape[0], -1))
|
||||||
|
elif self.transform_method == "mean":
|
||||||
|
return np.mean(features, axis=-1)
|
||||||
|
elif self.transform_method == "corr":
|
||||||
|
return self._pearson_features(
|
||||||
|
features, self.templates_[:, :n_components, :]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("non-supported transform method")
|
||||||
|
|
||||||
|
def _pearson_features(self, X: ndarray, templates: ndarray):
|
||||||
|
"""
|
||||||
|
Calculate pearson correlation coefficient.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray
|
||||||
|
features of test data after spatial filters, shape(n_trials, n_components, n_samples)
|
||||||
|
templates : ndarray
|
||||||
|
templates of train data, shape(n_classes, n_components, n_samples)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
corr : ndarray
|
||||||
|
pearson correlation coefficient, shape(n_trials, n_classes)
|
||||||
|
"""
|
||||||
|
X = np.reshape(X, (-1, *X.shape[-2:]))
|
||||||
|
templates = np.reshape(templates, (-1, *templates.shape[-2:]))
|
||||||
|
X = X - np.mean(X, axis=-1, keepdims=True)
|
||||||
|
templates = templates - np.mean(templates, axis=-1, keepdims=True)
|
||||||
|
X = np.reshape(X, (X.shape[0], -1))
|
||||||
|
templates = np.reshape(templates, (templates.shape[0], -1))
|
||||||
|
istd_X = 1 / np.std(X, axis=-1, keepdims=True)
|
||||||
|
istd_templates = 1 / np.std(templates, axis=-1, keepdims=True)
|
||||||
|
corr = (X @ templates.T) / (templates.shape[1] - 1)
|
||||||
|
corr = istd_X * corr * istd_templates.T
|
||||||
|
return corr
|
||||||
|
|
||||||
|
def predict(self, X: ndarray):
|
||||||
|
"""
|
||||||
|
Import the templates and the test data to get prediction labels.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray
|
||||||
|
test data, shape(n_trials, n_channels, n_samples)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
labels : ndarray
|
||||||
|
prediction labels of test data, shape(n_trials,)
|
||||||
|
"""
|
||||||
|
feat = self.transform(X)
|
||||||
|
if self.transform_method == "corr":
|
||||||
|
labels = self.classes_[np.argmax(feat, axis=-1)]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
175
Debug_64ch_Decoder/SSMVEP/algorithm/tdca.py
Normal file
175
Debug_64ch_Decoder/SSMVEP/algorithm/tdca.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Authors: Swolf <swolfforever@gmail.com>
|
||||||
|
# Date: 2021/10/10
|
||||||
|
# License: MIT License
|
||||||
|
"""
|
||||||
|
Task Decomposition Component Analysis.
|
||||||
|
"""
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from scipy.linalg import qr
|
||||||
|
from scipy.stats import pearsonr
|
||||||
|
from numpy import ndarray
|
||||||
|
from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin
|
||||||
|
from typing import Optional, List
|
||||||
|
from SSMVEP.algorithm.base import FilterBankSSVEP
|
||||||
|
from SSMVEP.algorithm.dsp import xiang_dsp_kernel, xiang_dsp_feature
|
||||||
|
|
||||||
|
|
||||||
|
def proj_ref(Yf: ndarray):
|
||||||
|
Q, R = qr(Yf.T, mode="economic")
|
||||||
|
P = Q @ Q.T
|
||||||
|
return P
|
||||||
|
|
||||||
|
|
||||||
|
def aug_2(X: ndarray, n_samples: int, padding_len: int, P: ndarray, training: bool = True):
|
||||||
|
X = X.reshape((-1, *X.shape[-2:]))
|
||||||
|
n_trials, n_channels, n_points = X.shape
|
||||||
|
# if n_points < padding_len + n_samples:
|
||||||
|
# raise ValueError("the length of X should be larger than l+n_samples.")
|
||||||
|
aug_X = np.zeros((n_trials, (padding_len + 1) * n_channels, n_samples))
|
||||||
|
if training:
|
||||||
|
for i in range(padding_len + 1):
|
||||||
|
aug_X[:, i * n_channels : (i + 1) * n_channels, :] = X[
|
||||||
|
..., i : i + n_samples
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
for i in range(padding_len + 1):
|
||||||
|
aug_X[:, i * n_channels : (i + 1) * n_channels, : n_samples - i] = X[
|
||||||
|
..., i:n_samples
|
||||||
|
]
|
||||||
|
aug_Xp = aug_X @ P
|
||||||
|
aug_X = np.concatenate([aug_X, aug_Xp], axis=-1)
|
||||||
|
return aug_X
|
||||||
|
|
||||||
|
|
||||||
|
def tdca_feature(
|
||||||
|
X: ndarray,
|
||||||
|
templates: ndarray,
|
||||||
|
W: ndarray,
|
||||||
|
M: ndarray,
|
||||||
|
Ps: List[ndarray],
|
||||||
|
padding_len: int,
|
||||||
|
n_components: int = 1,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
|
rhos = []
|
||||||
|
for Xk, P in zip(templates, Ps):
|
||||||
|
a = xiang_dsp_feature(
|
||||||
|
W,
|
||||||
|
M,
|
||||||
|
aug_2(X, P.shape[0], padding_len, P, training=training),
|
||||||
|
n_components=n_components,
|
||||||
|
)
|
||||||
|
b = Xk[:n_components, :]
|
||||||
|
a = np.reshape(a, (-1))
|
||||||
|
b = np.reshape(b, (-1))
|
||||||
|
rhos.append(pearsonr(a, b)[0])
|
||||||
|
return rhos
|
||||||
|
|
||||||
|
|
||||||
|
class TDCA(BaseEstimator, TransformerMixin, ClassifierMixin):
|
||||||
|
def __init__(self, padding_len: int, n_components: int = 1):
|
||||||
|
self.padding_len = padding_len
|
||||||
|
self.n_components = n_components
|
||||||
|
|
||||||
|
def fit(self, X: ndarray, y: ndarray, Yf: ndarray):
|
||||||
|
X -= np.mean(X, axis=-1, keepdims=True)
|
||||||
|
self.classes_ = np.unique(y)
|
||||||
|
self.Ps_ = [proj_ref(Yf[i]) for i in range(len(self.classes_))]
|
||||||
|
# print(np.shape(self.Ps_))
|
||||||
|
|
||||||
|
aug_X_list, aug_Y_list = [], []
|
||||||
|
for i, label in enumerate(self.classes_):
|
||||||
|
aug_X_list.append(
|
||||||
|
aug_2(
|
||||||
|
X[y == label],
|
||||||
|
self.Ps_[i].shape[0],
|
||||||
|
self.padding_len,
|
||||||
|
self.Ps_[i],
|
||||||
|
training=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
aug_Y_list.append(y[y == label])
|
||||||
|
|
||||||
|
aug_X = np.concatenate(aug_X_list, axis=0)
|
||||||
|
aug_Y = np.concatenate(aug_Y_list, axis=0)
|
||||||
|
self.W_, _, self.M_, _ = xiang_dsp_kernel(aug_X, aug_Y)
|
||||||
|
|
||||||
|
self.templates_ = np.stack(
|
||||||
|
[
|
||||||
|
np.mean(
|
||||||
|
xiang_dsp_feature(
|
||||||
|
self.W_,
|
||||||
|
self.M_,
|
||||||
|
aug_X[aug_Y == label],
|
||||||
|
n_components=self.W_.shape[1],
|
||||||
|
),
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
for label in self.classes_
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, X: ndarray):
|
||||||
|
n_components = self.n_components
|
||||||
|
X -= np.mean(X, axis=-1, keepdims=True)
|
||||||
|
X = X.reshape((-1, *X.shape[-2:]))
|
||||||
|
rhos = [
|
||||||
|
tdca_feature(
|
||||||
|
tmp,
|
||||||
|
self.templates_,
|
||||||
|
self.W_,
|
||||||
|
self.M_,
|
||||||
|
self.Ps_,
|
||||||
|
self.padding_len,
|
||||||
|
n_components=n_components,
|
||||||
|
)
|
||||||
|
for tmp in X
|
||||||
|
]
|
||||||
|
rhos = np.stack(rhos)
|
||||||
|
return rhos
|
||||||
|
|
||||||
|
def predict(self, X: ndarray):
|
||||||
|
feat = self.transform(X)
|
||||||
|
labels = self.classes_[np.argmax(feat, axis=-1)]
|
||||||
|
return labels,feat
|
||||||
|
|
||||||
|
|
||||||
|
class FBTDCA(FilterBankSSVEP, ClassifierMixin):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
filterbank: List[ndarray],
|
||||||
|
padding_len: int,
|
||||||
|
n_components: int = 1,
|
||||||
|
filterweights: Optional[ndarray] = None,
|
||||||
|
n_jobs: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.padding_len = padding_len
|
||||||
|
self.n_components = n_components
|
||||||
|
self.filterweights = filterweights
|
||||||
|
self.n_jobs = n_jobs
|
||||||
|
super().__init__(
|
||||||
|
filterbank,
|
||||||
|
TDCA(padding_len, n_components=n_components),
|
||||||
|
filterweights=filterweights,
|
||||||
|
n_jobs=n_jobs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None): # type: ignore[override]
|
||||||
|
self.classes_ = np.unique(y)
|
||||||
|
super().fit(X, y, Yf=Yf)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def predict(self, X: ndarray):
|
||||||
|
features = self.transform(X)
|
||||||
|
if self.filterweights is None:
|
||||||
|
features = np.reshape(
|
||||||
|
features, (features.shape[0], len(self.filterbank), -1)
|
||||||
|
)
|
||||||
|
features = np.mean(features, axis=1)
|
||||||
|
labels = self.classes_[np.argmax(features, axis=-1)]
|
||||||
|
return labels,features
|
||||||
527
Debug_64ch_Decoder/SSVEP/dwfbcca.py
Normal file
527
Debug_64ch_Decoder/SSVEP/dwfbcca.py
Normal file
@@ -0,0 +1,527 @@
|
|||||||
|
|
||||||
|
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
from os import error
|
||||||
|
import numpy as np
|
||||||
|
import scipy
|
||||||
|
from numpy.linalg import linalg
|
||||||
|
from scipy.io import loadmat
|
||||||
|
from scipy.linalg import qr
|
||||||
|
from scipy.signal import filtfilt, lfilter
|
||||||
|
|
||||||
|
|
||||||
|
class FbccaDw:
|
||||||
|
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
|
||||||
|
print('******************************************')
|
||||||
|
print('parameter list')
|
||||||
|
print('target:', num_target)
|
||||||
|
print('number of filter bank:', num_filter)
|
||||||
|
print('parameter:', parameter)
|
||||||
|
print('width:', width)
|
||||||
|
self.phase = 0
|
||||||
|
self.bandWidth = width
|
||||||
|
self.winNum = winNum
|
||||||
|
self.num_harms = num_harms
|
||||||
|
self.num_target = num_target
|
||||||
|
self.num_chans = num_chans
|
||||||
|
self.winTimeDelay = stimTime
|
||||||
|
self.fs = fs
|
||||||
|
self.winMaxSampleNum = self.winTimeDelay * self.winNum * self.fs
|
||||||
|
self.winDelayNum = round(self.winTimeDelay * self.fs)
|
||||||
|
self.num_fbs = num_filter
|
||||||
|
parameterValue = np.power(np.arange(1, self.num_fbs + 1), -(parameter[0])) + parameter[1]
|
||||||
|
self.weightValue = parameterValue / (sum(parameterValue))
|
||||||
|
|
||||||
|
self.dataUseLen = [0] * self.winNum
|
||||||
|
self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans])
|
||||||
|
self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans])
|
||||||
|
self.rhoNum = 2
|
||||||
|
self.notchZh = [0]
|
||||||
|
self.filterZf = [0] * self.num_fbs
|
||||||
|
self.north_b = []
|
||||||
|
self.north_a = []
|
||||||
|
self.filterBank_A = []
|
||||||
|
self.filterBank_B = []
|
||||||
|
self.winStep = 1
|
||||||
|
self.DW_cost_method = 'DW11' if method==1 else 'DW1'
|
||||||
|
|
||||||
|
'''
|
||||||
|
filterFrequenceBank:根据刺激频率生成的通带和阻带,用于滤波器组频带分解
|
||||||
|
'''
|
||||||
|
|
||||||
|
def filterFrequenceBank(self):
|
||||||
|
# 阻带的最高频率
|
||||||
|
lastFrequence = 90
|
||||||
|
freqBandWidth = self.bandWidth[1]
|
||||||
|
fStep = self.bandWidth[0]
|
||||||
|
bandFrequence = np.zeros((5, 4))
|
||||||
|
# 第二列频率带
|
||||||
|
band = list(range(freqBandWidth, lastFrequence, fStep))
|
||||||
|
band[:] = [x - 2 for x in band]
|
||||||
|
colValue = np.maximum(np.asmatrix(band), 1)
|
||||||
|
bandFrequence[:, 1] = colValue[0, 0:5]
|
||||||
|
# 第一列频率带
|
||||||
|
bandFrequence[:, 0] = np.maximum(bandFrequence[:, 1] - 4, 1)
|
||||||
|
# 第三列频率带
|
||||||
|
bandFrequence[:, 2] = lastFrequence + 2
|
||||||
|
# 第四列频率带
|
||||||
|
bandFrequence[:, 3] = bandFrequence[:, 2] + 10
|
||||||
|
# bandFrequence = np.array([[30,33,77,82],
|
||||||
|
# [62,68,77,82]])
|
||||||
|
for idx_fb in range(self.num_fbs):
|
||||||
|
Nq = self.fs / 2
|
||||||
|
Wp = [bandFrequence[idx_fb, 1] / Nq, bandFrequence[idx_fb, 2] / Nq]
|
||||||
|
Ws = [bandFrequence[idx_fb, 0] / Nq, bandFrequence[idx_fb, 3] / Nq]
|
||||||
|
[N, Wn] = scipy.signal.cheb1ord(Wp, Ws, 3,
|
||||||
|
40) # band pass filter StopBand=[Ws(1)~Ws(2)] PassBand=[Wp(1)~Wp(2)]
|
||||||
|
[B, A] = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass') # Wn passband edge frequency
|
||||||
|
self.filterBank_A.append(A)
|
||||||
|
self.filterBank_B.append(B)
|
||||||
|
# def filterFrequenceBank(self):
|
||||||
|
# # 阻带的最高频率
|
||||||
|
# lastFrequence = 90
|
||||||
|
# freqBandWidth = self.bandWidth[1]
|
||||||
|
# fStep = self.bandWidth[0]
|
||||||
|
# bandFrequence = np.zeros((5, 4))
|
||||||
|
# # 第二列频率带
|
||||||
|
# band = list(range(freqBandWidth, lastFrequence, fStep))
|
||||||
|
# band[:] = [x - 2 for x in band]
|
||||||
|
# colValue = np.maximum(np.asmatrix(band), 1)
|
||||||
|
# bandFrequence[:, 1] = colValue[0, 0:5]
|
||||||
|
# # 第一列频率带
|
||||||
|
# bandFrequence[:, 0] = np.maximum(bandFrequence[:, 1] - 4, 1)
|
||||||
|
# # 第三列频率带
|
||||||
|
# bandFrequence[:, 2] = lastFrequence + 2
|
||||||
|
# # 第四列频率带
|
||||||
|
# bandFrequence[:, 3] = bandFrequence[:, 2] + 10
|
||||||
|
# for idx_fb in range(self.num_fbs):
|
||||||
|
# Nq = self.fs / 2
|
||||||
|
# Wp = [bandFrequence[idx_fb, 1] / Nq, bandFrequence[idx_fb, 2] / Nq]
|
||||||
|
# Ws = [bandFrequence[idx_fb, 0] / Nq, bandFrequence[idx_fb, 3] / Nq]
|
||||||
|
# [N, Wn] = scipy.signal.cheb1ord(Wp, Ws, 3,
|
||||||
|
# 40) # band pass filter StopBand=[Ws(1)~Ws(2)] PassBand=[Wp(1)~Wp(2)]
|
||||||
|
# [B, A] = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass') # Wn passband edge frequency
|
||||||
|
# self.filterBank_A.append(A)
|
||||||
|
# self.filterBank_B.append(B)
|
||||||
|
|
||||||
|
'''
|
||||||
|
Filter bank analysis
|
||||||
|
Input:
|
||||||
|
eeg : Input eeg data (# of targets, # of channels, Data length [sample])
|
||||||
|
Output:
|
||||||
|
filterData : Generated filter Data
|
||||||
|
'''
|
||||||
|
|
||||||
|
def filterbank(self, eeg):
|
||||||
|
filterData = np.zeros((self.num_fbs, eeg.shape[1], eeg.shape[0]))
|
||||||
|
for filterIndex in range(self.num_fbs):
|
||||||
|
if np.all(self.filterZf[filterIndex] == 0):
|
||||||
|
zi = np.zeros(
|
||||||
|
[max(len(self.filterBank_A[filterIndex]), len(self.filterBank_B[filterIndex])) - 1, self.num_chans])
|
||||||
|
_, self.filterZf[filterIndex] = lfilter(self.filterBank_B[filterIndex], self.filterBank_A[filterIndex],
|
||||||
|
eeg, zi=zi.T)
|
||||||
|
Data = lfilter(self.filterBank_B[filterIndex], self.filterBank_A[filterIndex], eeg)
|
||||||
|
else:
|
||||||
|
Data, self.filterZf[filterIndex] = lfilter(self.filterBank_B[filterIndex],
|
||||||
|
self.filterBank_A[filterIndex], eeg,
|
||||||
|
zi=self.filterZf[filterIndex])
|
||||||
|
filterData[filterIndex, :, :] = Data.T
|
||||||
|
return filterData
|
||||||
|
|
||||||
|
'''
|
||||||
|
process
|
||||||
|
矩阵的白化和QR正则化分解,降低矩阵的维度,加速计算时间
|
||||||
|
Input:
|
||||||
|
data : 输入的二维脑电信号
|
||||||
|
Output:
|
||||||
|
Q : 降维后的矩阵
|
||||||
|
rankQ :正则矩阵的秩
|
||||||
|
'''
|
||||||
|
|
||||||
|
def process(self, data):
|
||||||
|
# 白化操作
|
||||||
|
meanValue = np.asmatrix(data.mean(axis=1))
|
||||||
|
meanData = np.repeat(meanValue, data.shape[1], axis=1)
|
||||||
|
whiteTemp = data - meanData
|
||||||
|
# QR 分解
|
||||||
|
rankWhiteTemp = whiteTemp.shape[0]
|
||||||
|
whiteTemp = np.transpose(whiteTemp)
|
||||||
|
Q, R = qr(whiteTemp.A, mode='economic')
|
||||||
|
# 计算矩阵的秩
|
||||||
|
rankQ = linalg.matrix_rank(R)
|
||||||
|
if rankQ == 0:
|
||||||
|
raise ValueError('stats:canoncorr:badData')
|
||||||
|
elif rankQ <= rankWhiteTemp:
|
||||||
|
# warnings.warn('stats:canoncorr:NotFullRank')
|
||||||
|
Q = Q[:, 0:rankQ]
|
||||||
|
return Q, rankQ
|
||||||
|
|
||||||
|
'''
|
||||||
|
reference
|
||||||
|
Input:
|
||||||
|
listFreqs : 刺激频率列表
|
||||||
|
numberSmples : 用于分类的脑电信号采样点个数
|
||||||
|
num_harms : 谐波数
|
||||||
|
Output:
|
||||||
|
y_ref : 生成的参考信号 (刺激目标数, 2 * 谐波数, 数据长度/采样点数)
|
||||||
|
'''
|
||||||
|
|
||||||
|
def reference(self, listFreqs, numberSmples, num_harms):
|
||||||
|
numberFrequence = len(listFreqs)
|
||||||
|
timeIndex = np.arange(1, numberSmples + 1) / self.fs # time index
|
||||||
|
referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples))
|
||||||
|
for frequenceIndex in range(numberFrequence):
|
||||||
|
temp = []
|
||||||
|
for harmIndex in range(1, num_harms + 1):
|
||||||
|
stimFrequence = listFreqs[frequenceIndex] # in HZ
|
||||||
|
# Sin and Cos
|
||||||
|
temp.extend([np.sin(2 * np.pi * timeIndex * harmIndex * stimFrequence),
|
||||||
|
np.cos(2 * np.pi * timeIndex * harmIndex * stimFrequence)])
|
||||||
|
referenceTemp = np.asmatrix(temp)
|
||||||
|
# 白化操作和QR分解
|
||||||
|
Q, rankQ = self.process(referenceTemp)
|
||||||
|
referenceData[frequenceIndex] = np.transpose(Q)
|
||||||
|
return referenceData
|
||||||
|
|
||||||
|
'''
|
||||||
|
setNorthFilterPara
|
||||||
|
陷波器的参数初始化
|
||||||
|
self.north_b, self.north_a : 陷波器的参数设计
|
||||||
|
'''
|
||||||
|
|
||||||
|
def setNotchFilterPara(self):
|
||||||
|
# notchFilterNum = 3
|
||||||
|
# northFreq = 50
|
||||||
|
# bwDen = 35
|
||||||
|
# wo = northFreq / (self.fs / 2)
|
||||||
|
# bw = wo / bwDen
|
||||||
|
# self.north_b, self.north_a = iirnotch(wo, Q=35) # self.north_b, self.north_a = iircomb(northFreq, bwDen, 'notch')
|
||||||
|
# # n倍零极点,相当于重复滤波n次
|
||||||
|
# if notchFilterNum > 1:
|
||||||
|
# z, p, k = tf2zpk(self.north_b, self.north_a)
|
||||||
|
# zNew = np.repeat(z, notchFilterNum, axis=0)
|
||||||
|
# zNew[1], zNew[4] = zNew[4], zNew[1]
|
||||||
|
# pNew = np.repeat(p, notchFilterNum, axis=0)
|
||||||
|
# pNew[1], pNew[4] = pNew[4], pNew[1]
|
||||||
|
# kNew = np.power(k, notchFilterNum)
|
||||||
|
# self.north_b, self.north_a = zpk2tf(zNew, pNew, kNew)
|
||||||
|
self.north_b = [0.94801603944125245604368501517456, -1.7577184027642647201616910024313,
|
||||||
|
3.9303778338832491279219993884908, -3.7392330345967859095424046245171,
|
||||||
|
3.9303778338832482397435796883656, -1.7577184027642638319832713023061,
|
||||||
|
0.94801603944125156786526531504933]
|
||||||
|
|
||||||
|
self.north_a = [1, -1.8214007435820627200939725298667, 4.0000101767406484043476666556671,
|
||||||
|
-3.7380998614928691026193519064691, 3.8589119784285759173769747576443,
|
||||||
|
-1.6951692350503837491970671180752, 0.89786559147978006745205448169145]
|
||||||
|
|
||||||
|
'''
|
||||||
|
northFilter
|
||||||
|
进行信号的50hz陷波处理
|
||||||
|
Input:
|
||||||
|
data :输入脑电数据
|
||||||
|
Output:
|
||||||
|
dataFiltered : 陷波处理后的脑电数据
|
||||||
|
'''
|
||||||
|
|
||||||
|
def northFilter(self, data):
|
||||||
|
try:
|
||||||
|
if np.all(self.notchZh[0] == 0):
|
||||||
|
zi = np.zeros([max(len(self.north_a), len(self.north_b)) - 1, self.num_chans])
|
||||||
|
_, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=zi.T)
|
||||||
|
dataFiltered = lfilter(self.north_b, self.north_a, data)
|
||||||
|
else:
|
||||||
|
dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0])
|
||||||
|
return np.asmatrix(dataFiltered)
|
||||||
|
except Exception:
|
||||||
|
print(Exception)
|
||||||
|
|
||||||
|
'''
|
||||||
|
getDataQ
|
||||||
|
Inputs:
|
||||||
|
data:脑电数据
|
||||||
|
Rbuffer:待更新的中间系数
|
||||||
|
Output:
|
||||||
|
Qs1 : 脑电特征1
|
||||||
|
Qs2 : 脑电特征2
|
||||||
|
Rbuffer : 单窗口更新后的系数
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
def getDataQ(self, data, Rbuffer):
|
||||||
|
Qs1 = [0] * self.num_fbs
|
||||||
|
Qs2 = [0] * self.num_fbs
|
||||||
|
nulldata = np.zeros([self.num_chans, self.num_chans])
|
||||||
|
Rnum = self.num_chans
|
||||||
|
for fb_num in range(self.num_fbs):
|
||||||
|
fb_data = np.squeeze(data[fb_num, :, :])
|
||||||
|
if np.all(Rbuffer[fb_num] == 0):
|
||||||
|
whiteTemp = fb_data
|
||||||
|
Q, R = qr(whiteTemp, mode='economic')
|
||||||
|
Qs1[fb_num] = nulldata
|
||||||
|
Qs2[fb_num] = Q
|
||||||
|
Rbuffer[fb_num] = R
|
||||||
|
else:
|
||||||
|
whiteTemp = np.concatenate((Rbuffer[fb_num], fb_data), axis=0)
|
||||||
|
Q, R = qr(whiteTemp, mode='economic')
|
||||||
|
Qs1[fb_num] = Q[0:Rnum, :]
|
||||||
|
Qs2[fb_num] = Q[Rnum:, :]
|
||||||
|
Rbuffer[fb_num] = R
|
||||||
|
return Qs1, Qs2, Rbuffer
|
||||||
|
|
||||||
|
'''
|
||||||
|
myCCA:根据脑电特征和参考信号计算相关系数
|
||||||
|
Inputs:
|
||||||
|
dataQ:脑电特征
|
||||||
|
Qc2y:参考信号
|
||||||
|
d : 相关系数取值数
|
||||||
|
Output:
|
||||||
|
rho : 相关系数
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
def myCCA(self, dataQ, Qc2y, d):
|
||||||
|
if len(Qc2y) == 0:
|
||||||
|
Cov = dataQ
|
||||||
|
else:
|
||||||
|
Cov = np.dot(Qc2y, dataQ)
|
||||||
|
# U, S, V = scipy.linalg.svd(Cov, 0)
|
||||||
|
# rho = np.minimum(np.maximum(np.diag(S[0: d]).T, 0), 1)
|
||||||
|
_, S, _ = np.linalg.svd(Cov, full_matrices=False)
|
||||||
|
rho = S
|
||||||
|
return rho[0:d]
|
||||||
|
|
||||||
|
'''
|
||||||
|
weightCCA:计算分类标签
|
||||||
|
Inputs:
|
||||||
|
Qs1:脑电特征1
|
||||||
|
Qs2:脑电特征2
|
||||||
|
ref : 正余弦参考信号
|
||||||
|
Cxy : 协方差中间参数
|
||||||
|
Output:
|
||||||
|
result : 分类标签
|
||||||
|
rho : 相关系数
|
||||||
|
Cxy : 更新后的协方差中间参数
|
||||||
|
'''
|
||||||
|
|
||||||
|
def weightCCA(self, Qs1, Qs2, ref, Cxy):
|
||||||
|
rMax = np.zeros([self.num_fbs, self.num_target])
|
||||||
|
for fi in range(self.num_fbs):
|
||||||
|
for si in range(self.num_target):
|
||||||
|
Qc2y = np.squeeze(ref[si, :, :])
|
||||||
|
# 更新协方差矩阵
|
||||||
|
if np.all(Cxy[fi][si] == 0):
|
||||||
|
Cxy[fi, si] = np.dot(Qc2y, Qs2[fi])
|
||||||
|
else:
|
||||||
|
Cxy[fi, si] = np.dot(Cxy[fi, si], Qs1[fi]) + np.dot(Qc2y, Qs2[fi])
|
||||||
|
r = self.myCCA(Cxy[fi, si], [], self.rhoNum)
|
||||||
|
rMax[fi, si] = r[0]
|
||||||
|
rho = np.dot(self.weightValue, np.power(rMax, 2)) # weighted sum of r from all different filter banks' result
|
||||||
|
result = np.argmax(rho)
|
||||||
|
return result, rho, Cxy
|
||||||
|
|
||||||
|
'''
|
||||||
|
costF:损失函数,根据计算的相关系数,生成决策值,用于和阈值进行比较
|
||||||
|
Inputs:
|
||||||
|
rho:相关系数
|
||||||
|
method:相关系数计算参数
|
||||||
|
C : 参数
|
||||||
|
Output:
|
||||||
|
decideValue : 决策阈值
|
||||||
|
'''
|
||||||
|
|
||||||
|
def costF(self, rho, method, C):
|
||||||
|
rho = rho.tolist()
|
||||||
|
rho.sort(reverse=True)
|
||||||
|
if method == 'DW1':
|
||||||
|
decideValue = (rho[0] - rho[1]) / (sum(rho) - self.num_target * np.log(sum(np.exp(rho))))
|
||||||
|
elif method == 'DW11':
|
||||||
|
decideValue = -(rho[0] - rho[1])
|
||||||
|
elif method == 'DW2':
|
||||||
|
decideValue = (rho[0] - C) / (rho[1] - rho[0])
|
||||||
|
return decideValue
|
||||||
|
|
||||||
|
'''
|
||||||
|
onlineInit:将窗口长度,相位值、中间参数初始化
|
||||||
|
'''
|
||||||
|
|
||||||
|
def onlineInit(self):
|
||||||
|
self.dataUseLen = [0] * self.winNum
|
||||||
|
self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans])
|
||||||
|
self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans])
|
||||||
|
self.phase = 0
|
||||||
|
|
||||||
|
'''
|
||||||
|
filterInit:重置陷波器和滤波器的滤波参数
|
||||||
|
'''
|
||||||
|
|
||||||
|
def filterInit(self):
|
||||||
|
self.notchZh = [0]
|
||||||
|
self.filterZf = [0] * self.num_fbs
|
||||||
|
|
||||||
|
'''
|
||||||
|
warmFilter:预热滤波器,去刺激前的4S数据对陷波器和滤波器参数进行初始化迭代,去除过渡带的效果
|
||||||
|
Inputs:
|
||||||
|
data:预处理脑电数据
|
||||||
|
'''
|
||||||
|
|
||||||
|
def warmFilter(self, data):
|
||||||
|
# 降采样在采集前完成
|
||||||
|
temp = self.preprocessFilter(data) #预热陷波滤波器
|
||||||
|
# 滤波器组频带分解
|
||||||
|
filterData = self.filterbank(temp) #预热滤波器组
|
||||||
|
|
||||||
|
'''
|
||||||
|
myDownSample:数据降采样
|
||||||
|
Inputs:
|
||||||
|
data:脑电数据
|
||||||
|
n:降采样的倍数
|
||||||
|
Output:
|
||||||
|
eegData2 : 降采样后的数据
|
||||||
|
'''
|
||||||
|
|
||||||
|
def myDownSample(self, data, n):
|
||||||
|
data = data[:8, self.phase:]
|
||||||
|
dataNum = data.shape[1]
|
||||||
|
remainNum = (dataNum - 1) % n
|
||||||
|
self.phase = n - 1 - remainNum
|
||||||
|
dataDowmSample = []
|
||||||
|
for value in data:
|
||||||
|
value = value[0:value.size:n]
|
||||||
|
dataDowmSample.append(value)
|
||||||
|
eegData2 = np.array(dataDowmSample).reshape([8, int(np.round(data.shape[1] / n))])
|
||||||
|
return eegData2
|
||||||
|
|
||||||
|
'''
|
||||||
|
preprocessFilter:预处理,调用函数降采样和陷波处理
|
||||||
|
Inputs:
|
||||||
|
data:脑电数据
|
||||||
|
Output:
|
||||||
|
filterData : 降采样和陷波后的数据
|
||||||
|
'''
|
||||||
|
|
||||||
|
def preprocessFilter(self, data):
|
||||||
|
# data = self.myDownSample(data, 4)
|
||||||
|
# filterData = self.northFilter(data[:8, :])
|
||||||
|
filterData = self.northFilter(data[:, :])
|
||||||
|
return filterData
|
||||||
|
|
||||||
|
'''
|
||||||
|
fbccaDWMW:分类函数,对输入的脑电信号进行识别,输出决策标签
|
||||||
|
Inputs:
|
||||||
|
testdata:脑电数据
|
||||||
|
referenceData:参考信号
|
||||||
|
tValue:出决策阈值
|
||||||
|
Output:
|
||||||
|
res : 决策标签
|
||||||
|
rho_new:相关系数
|
||||||
|
minEps:得到的决策阈值
|
||||||
|
'''
|
||||||
|
|
||||||
|
# 动态窗算法主函数
|
||||||
|
def fbccaDWMW(self, testdata, referenceData, tValue,calculateCount):
|
||||||
|
t1 = time.time()
|
||||||
|
# try:
|
||||||
|
# 初始参数
|
||||||
|
res = -1
|
||||||
|
minEps = float("inf")
|
||||||
|
# 降采样和陷波器处理
|
||||||
|
northData = self.preprocessFilter(testdata)
|
||||||
|
newSampleNum = northData.shape[1]
|
||||||
|
# 数据大于延迟长度,则无法根据后面的规则更新窗口
|
||||||
|
if newSampleNum > self.winDelayNum:
|
||||||
|
error('need add window delay time')
|
||||||
|
|
||||||
|
# 防止秩小于导联数
|
||||||
|
if newSampleNum < self.num_chans:
|
||||||
|
warnings.warn('data shape is [%d %d] need more data' % (newSampleNum, northData.shape[0]))
|
||||||
|
# 滤波器组频带分解
|
||||||
|
filterData = self.filterbank(northData)
|
||||||
|
winMinTime = 0
|
||||||
|
# 计算每个窗口的结果
|
||||||
|
for wi in range(0, self.winNum, self.winStep):
|
||||||
|
# print('dataUseLen:',wi,calculateCount, self.dataUseLen)
|
||||||
|
if wi == 0:
|
||||||
|
self.dataUseLen[wi] = self.dataUseLen[wi] + newSampleNum
|
||||||
|
else:
|
||||||
|
if self.dataUseLen[wi] == 0:
|
||||||
|
# 判断当前窗是否为新的窗口(因为每一次新的窗口进来时,都会使上一个窗口datauseLen>50)
|
||||||
|
if self.dataUseLen[wi - self.winStep] > self.winDelayNum*self.winStep:
|
||||||
|
self.dataUseLen[wi] = newSampleNum
|
||||||
|
else:
|
||||||
|
# print('中断: ',wi,calculateCount)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.dataUseLen[wi] = self.dataUseLen[wi] + newSampleNum
|
||||||
|
|
||||||
|
if self.dataUseLen[wi] > self.winMaxSampleNum:
|
||||||
|
self.dataUseLen[wi] = newSampleNum
|
||||||
|
self.Rbuffer[wi, :, :, :] = 0
|
||||||
|
self.Cxy[wi, :, :, :, :] = 0
|
||||||
|
Qs1, Qs2, self.Rbuffer[wi, :, :, :] = self.getDataQ(filterData, self.Rbuffer[wi, :, :, :])
|
||||||
|
si = self.dataUseLen[wi] - newSampleNum
|
||||||
|
ei = self.dataUseLen[wi]
|
||||||
|
ref = referenceData[:, :, si:ei]
|
||||||
|
# 更新协方差
|
||||||
|
predLabel_new, rho_new, self.Cxy[wi, :, :, :] = self.weightCCA(Qs1, Qs2, ref, self.Cxy[wi, :, :, :])
|
||||||
|
# 增加限制,数据长度不能太短
|
||||||
|
if self.dataUseLen[wi] > winMinTime * self.fs:
|
||||||
|
epsilon = self.costF(rho_new, self.DW_cost_method, C=0)
|
||||||
|
if epsilon < minEps:
|
||||||
|
minEps = epsilon
|
||||||
|
predLabel = predLabel_new
|
||||||
|
xxx = rho_new
|
||||||
|
if minEps < tValue:
|
||||||
|
res = predLabel
|
||||||
|
|
||||||
|
if time.time() - t1 > 0.2 and self.winStep < 16:
|
||||||
|
self.winStep = self.winStep * 2
|
||||||
|
# print(self.winStep, " ", time.time() - t1)
|
||||||
|
# if res != -1:
|
||||||
|
# print('--------------------- ',res,xxx,' --------------------------')
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# The number of sub-bands in filter bank analysis
|
||||||
|
fs = 250
|
||||||
|
num_chans = 8
|
||||||
|
num_target = 40
|
||||||
|
num_filterBank = 3
|
||||||
|
num_harm = 5
|
||||||
|
stimTime = 0.2 # 多窗口窗长
|
||||||
|
winNum = 50 # 窗口的个数
|
||||||
|
trials = 1
|
||||||
|
step = 50
|
||||||
|
res = -1
|
||||||
|
list_freqs = [8.0, 8.2, 8.4, 8.6, 8.8, 9.0, 9.2, 9.4, 9.6, 9.8, 10.0, 10.2, 10.4, 10.6, 10.8, 11.0, 11.2, 11.4,
|
||||||
|
11.6, 11.8, 12.0, 12.2, 12.4, 12.6, 12.8, 13., 13.2, 13.4, 13.6, 13.8, 14., 14.2, 14.4, 14.6, 14.8,
|
||||||
|
15.0, 15.2, 15.4, 15.6, 15.8]
|
||||||
|
# 初始化对象
|
||||||
|
dw = FbccaDw(fs, num_target, num_chans, num_filterBank, num_harm, stimTime, [1.0, 0.3], [8, 8], winNum)
|
||||||
|
# frequenceband
|
||||||
|
dw.filterFrequenceBank()
|
||||||
|
referenceData = dw.reference(list_freqs, int(winNum * stimTime * fs), num_harm)
|
||||||
|
dw.setNotchFilterPara()
|
||||||
|
|
||||||
|
prelabels = np.zeros((1, 40))
|
||||||
|
coefficient = np.zeros([1, 1])
|
||||||
|
path = "D:\\工作相关\\项目代码\\SDK封装与测试\\八神BCI脑电信号SDK20230105\\双鹰SDK对比\\offline_data250\\"
|
||||||
|
for index in range(1, trials + 1):
|
||||||
|
D = loadmat(os.path.join(path + str(1) + '-warmData.mat'))
|
||||||
|
warmData = D['warmData']
|
||||||
|
dw.onlineInit()
|
||||||
|
dw.filterInit()
|
||||||
|
dw.warmFilter(warmData.T)
|
||||||
|
|
||||||
|
tagget_i = 0
|
||||||
|
for tagget_i in range(1, step + 1):
|
||||||
|
D = loadmat(os.path.join(path + str(1) + '-' + str(tagget_i) + '.mat'))
|
||||||
|
dataSlice = D['dataTemp']
|
||||||
|
res = dw.fbccaDWMW(dataSlice.T, referenceData, tValue=-0.2)
|
||||||
|
if res != -1:
|
||||||
|
break
|
||||||
|
prelabels[0, index - 1] = res + 1
|
||||||
|
print(index, '--', res + 1," 计算轮数", tagget_i)
|
||||||
851
Debug_64ch_Decoder/Tools/plot_MI_EEG.py
Normal file
851
Debug_64ch_Decoder/Tools/plot_MI_EEG.py
Normal file
@@ -0,0 +1,851 @@
|
|||||||
|
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.patches import Ellipse
|
||||||
|
import matplotlib.cm as cm
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
from scipy.spatial import Delaunay
|
||||||
|
from scipy.interpolate import Rbf
|
||||||
|
from scipy.signal import welch
|
||||||
|
from scipy.stats import sem
|
||||||
|
from scipy.signal import butter, filtfilt, hilbert
|
||||||
|
import base64
|
||||||
|
|
||||||
|
# 位置坐标
|
||||||
|
def read_ch_pos(file_path=r'xy_64.xlsx'):
|
||||||
|
"""
|
||||||
|
将电极位置信息转换为Dict
|
||||||
|
|
||||||
|
参数:
|
||||||
|
file_path: 电极位置存储文件, 必须包含'channel', 'x', 'y', 'z'列
|
||||||
|
|
||||||
|
"""
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
file_path = os.path.join(script_dir,file_path )
|
||||||
|
df = pd.read_excel(file_path)
|
||||||
|
# 确保列名正确
|
||||||
|
if not all(col in df.columns for col in ['channel', 'x', 'y', 'z']):
|
||||||
|
raise ValueError("DataFrame必须包含'channel', 'x', 'y', 'z'列")
|
||||||
|
# 创建电极位置字典
|
||||||
|
ch_pos = {}
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
ch_pos[row['channel']] = [row['x'], row['y'], row['z']]
|
||||||
|
return ch_pos
|
||||||
|
# 头部轮廓
|
||||||
|
def draw_head(ax, center=(0, 0), radius=1.0, zorder=4):
|
||||||
|
"""
|
||||||
|
绘制头部轮廓、鼻子和耳朵。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- ax : matplotlib Axes 对象
|
||||||
|
- center : (x, y) 头中心坐标
|
||||||
|
- radius : float, 头半径
|
||||||
|
- zorder : 绘制层级
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 头圆
|
||||||
|
head = plt.Circle(center, radius, fill=False, color='k', linewidth=1, zorder=zorder)
|
||||||
|
ax.add_artist(head)
|
||||||
|
|
||||||
|
# 鼻子(参考 _make_head_outlines)
|
||||||
|
dx = np.exp(np.arccos(np.deg2rad(12)) * 1j)
|
||||||
|
dx_real, dx_imag = dx.real, dx.imag
|
||||||
|
nose_x = np.array([-dx_real, 0, dx_real]) * radius + center[0]
|
||||||
|
nose_y = np.array([dx_imag, 1.15, dx_imag]) * radius + center[1]
|
||||||
|
ax.plot(nose_x, nose_y, color='k', linewidth=1, zorder=zorder)
|
||||||
|
|
||||||
|
# 耳朵(参考 _make_head_outlines 手动标定)
|
||||||
|
ear_radius = radius * 0.12
|
||||||
|
ear_scale = radius * 2 # 根据半径缩放
|
||||||
|
theta = np.linspace(np.pi / 2, 3 * np.pi / 2, 30)
|
||||||
|
|
||||||
|
# 左耳
|
||||||
|
left_ear_x_array = np.array([0.497, 0.510, 0.518, 0.5299, 0.5419,
|
||||||
|
0.54, 0.547, 0.532, 0.510, 0.489]) * ear_scale
|
||||||
|
left_ear_y_array = np.array([0.0555, 0.0775, 0.0783, 0.0746, 0.0555,
|
||||||
|
-0.0055, -0.0932, -0.1313, -0.1384, -0.1199]) * ear_scale + center[1]
|
||||||
|
ax.plot(center[0] - left_ear_x_array, left_ear_y_array, color='k', linewidth=1, zorder=zorder)
|
||||||
|
|
||||||
|
# 右耳
|
||||||
|
right_ear_x_array = np.array([0.497, 0.510, 0.518, 0.5299, 0.5419,
|
||||||
|
0.54, 0.547, 0.532, 0.510, 0.489]) * ear_scale
|
||||||
|
right_ear_y_array = np.array([0.0555, 0.0775, 0.0783, 0.0746, 0.0555,
|
||||||
|
-0.0055, -0.0932, -0.1313, -0.1384, -0.1199]) * ear_scale + center[1]
|
||||||
|
ax.plot(center[0] + right_ear_x_array, right_ear_y_array, color='k', linewidth=1, zorder=zorder)
|
||||||
|
# 地形图 插值
|
||||||
|
def rbf_D_interpolate(xy, v, center=(0, 0), radius=1.1, grid_res=300,
|
||||||
|
n_extra=32, rbf_func='multiquadric', smooth=0,
|
||||||
|
border='mean', border_scale=1.0001, n_ngb=4):
|
||||||
|
"""
|
||||||
|
使用 RBF + Delaunay 邻域均值方式生成平滑的 EEG topomap 插值表面。
|
||||||
|
|
||||||
|
参数
|
||||||
|
----
|
||||||
|
xy : (N,2) array
|
||||||
|
电极二维坐标(与绘图坐标系一致)
|
||||||
|
v : (N,) array
|
||||||
|
每个电极对应的值(e.g. PSD)
|
||||||
|
center : tuple (x0, y0)
|
||||||
|
头部圆心(默认 (0,0))
|
||||||
|
radius : float
|
||||||
|
头部半径(用于生成边界点与网格范围)
|
||||||
|
grid_res : int
|
||||||
|
网格分辨率(每轴点数)
|
||||||
|
n_extra : int
|
||||||
|
边界虚拟点数量
|
||||||
|
rbf_func : str
|
||||||
|
RBF 内核名称('multiquadric','thin_plate','gaussian',...)
|
||||||
|
smooth : float
|
||||||
|
RBF 平滑参数
|
||||||
|
border : 'mean' or float
|
||||||
|
若 'mean':边界点用邻近真实通道均值赋值(推荐)
|
||||||
|
若 float:边界点赋相同常数值
|
||||||
|
border_scale : float
|
||||||
|
边界点半径相对 radius 的缩放(略微 >1 用以外推)
|
||||||
|
n_ngb : int
|
||||||
|
为每个边界点取值时使用的最近真实通道数
|
||||||
|
|
||||||
|
返回
|
||||||
|
----
|
||||||
|
zi : (grid_res, grid_res) ndarray
|
||||||
|
插值结果(与 grid_x, grid_y 对齐)
|
||||||
|
grid_x, grid_y : ndarrays
|
||||||
|
meshgrid(由 np.meshgrid 生成)
|
||||||
|
"""
|
||||||
|
xy = np.asarray(xy)
|
||||||
|
v = np.asarray(v)
|
||||||
|
if xy.ndim != 2 or xy.shape[1] != 2:
|
||||||
|
raise ValueError("xy must be shape (n_channels, 2)")
|
||||||
|
|
||||||
|
n_points = xy.shape[0]
|
||||||
|
|
||||||
|
# --- 1. 生成边界虚拟点(圆周) ---
|
||||||
|
theta = np.linspace(0.0, 2 * np.pi, n_extra, endpoint=False)
|
||||||
|
r_border = radius * border_scale
|
||||||
|
border_xy = np.column_stack([center[0] + r_border * np.cos(theta),
|
||||||
|
center[1] + r_border * np.sin(theta)])
|
||||||
|
|
||||||
|
# --- 2. 用 Delaunay 建图以便找到邻居(对边界点取邻居均值) ---
|
||||||
|
# 合并用于三角化的位置(真实点 + 边界点)
|
||||||
|
tri_xy = np.vstack([xy, border_xy])
|
||||||
|
tri = Delaunay(tri_xy)
|
||||||
|
|
||||||
|
# --- 3. 为边界点赋值 ---
|
||||||
|
if isinstance(border, str) and border == 'mean':
|
||||||
|
# 使用 Delaunay 的 vertex_neighbor_vertices 索引
|
||||||
|
# 注意:tri.vertex_neighbor_vertices 给出 vertices -> neighbor indptr
|
||||||
|
indices, indptr = tri.vertex_neighbor_vertices
|
||||||
|
v_extra = np.zeros(n_extra)
|
||||||
|
used = np.zeros(n_extra, dtype=bool)
|
||||||
|
# 边界点在 tri_xy 中的索引范围
|
||||||
|
rng = range(n_points, n_points + n_extra)
|
||||||
|
for idx, extra_idx in enumerate(rng):
|
||||||
|
neigh = indptr[indices[extra_idx]:indices[extra_idx + 1]]
|
||||||
|
# 仅保留原始点索引(小于 n_points)
|
||||||
|
neigh = neigh[neigh < n_points]
|
||||||
|
if neigh.size > 0:
|
||||||
|
used[idx] = True
|
||||||
|
# 使用最近 n_ngb 个邻居的均值(若邻居多则取最近的 n_ngb)
|
||||||
|
if neigh.size > n_ngb:
|
||||||
|
# 计算距离并选取最近 n_ngb
|
||||||
|
d = np.linalg.norm(xy[neigh] - tri_xy[extra_idx], axis=1)
|
||||||
|
order = np.argsort(d)[:n_ngb]
|
||||||
|
sel = neigh[order]
|
||||||
|
else:
|
||||||
|
sel = neigh
|
||||||
|
v_extra[idx] = v[sel].mean()
|
||||||
|
if not used.all() and used.any():
|
||||||
|
v_extra[~used] = np.mean(v_extra[used])
|
||||||
|
elif not used.any():
|
||||||
|
v_extra[:] = np.mean(v)
|
||||||
|
else:
|
||||||
|
# border 是数值
|
||||||
|
v_extra = np.full(n_extra, float(border))
|
||||||
|
|
||||||
|
# --- 4. 合并所有已知点并构建 RBF ---
|
||||||
|
all_xy = np.vstack([xy, border_xy])
|
||||||
|
all_v = np.concatenate([v, v_extra])
|
||||||
|
|
||||||
|
rbf = Rbf(all_xy[:, 0], all_xy[:, 1], all_v,
|
||||||
|
function=rbf_func, smooth=smooth)
|
||||||
|
|
||||||
|
# --- 5. 生成网格(使用 meshgrid,与主函数保持一致) ---
|
||||||
|
xmin, xmax = center[0] - radius, center[0] + radius
|
||||||
|
ymin, ymax = center[1] - radius, center[1] + radius
|
||||||
|
xi = np.linspace(xmin, xmax, grid_res)
|
||||||
|
yi = np.linspace(ymin, ymax, grid_res)
|
||||||
|
grid_x, grid_y = np.meshgrid(xi, yi) # meshgrid 与 imshow 对齐
|
||||||
|
|
||||||
|
# --- 6. 评估 RBF,返回与 grid 对齐的 zi ---
|
||||||
|
zi = rbf(grid_x, grid_y)
|
||||||
|
|
||||||
|
return zi, grid_x, grid_y
|
||||||
|
# plv矩阵计算
|
||||||
|
def calculate_plv(data):
|
||||||
|
"""
|
||||||
|
计算相位锁定值(PLV)矩阵。
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : ndarray, shape (num_channels, num_samples)
|
||||||
|
EEG 数据,通道数为 num_channels,样本数为 num_samples。
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
plv_matrix : ndarray, shape (num_channels, num_channels)
|
||||||
|
计算得到的 PLV 矩阵,表示各通道间的相位同步。
|
||||||
|
"""
|
||||||
|
num_channels, num_samples = data.shape
|
||||||
|
plv_matrix = np.zeros((num_channels, num_channels))
|
||||||
|
|
||||||
|
# 计算每个通道的解析信号
|
||||||
|
analytic_signals = np.apply_along_axis(hilbert, axis=1, arr=data)
|
||||||
|
|
||||||
|
for i in range(num_channels):
|
||||||
|
for j in range(i + 1, num_channels): # 只计算上三角矩阵,避免重复计算
|
||||||
|
# 计算 phase difference
|
||||||
|
phase_diff = np.angle(analytic_signals[i] * np.conj(analytic_signals[j]))
|
||||||
|
plv = np.abs(np.mean(np.exp(1j * phase_diff)))
|
||||||
|
plv_matrix[i, j] = plv
|
||||||
|
plv_matrix[j, i] = plv # 对称矩阵
|
||||||
|
|
||||||
|
return plv_matrix
|
||||||
|
# 矩阵阈值化
|
||||||
|
def threshold_proportional(adj, prop=0.2):
|
||||||
|
"""
|
||||||
|
Apply a proportional threshold to retain the top proportion of strongest edges.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
adj : ndarray, shape (n_channels, n_channels)
|
||||||
|
Adjacency matrix to threshold.
|
||||||
|
prop : float
|
||||||
|
Proportion of edges to retain (0 < prop <= 1).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bin_adj : ndarray, shape (n_channels, n_channels)
|
||||||
|
Binary adjacency matrix after thresholding.
|
||||||
|
"""
|
||||||
|
n = adj.shape[0]
|
||||||
|
triu_idx = np.triu_indices(n, k=1)
|
||||||
|
weights = adj[triu_idx]
|
||||||
|
k = int(np.floor(len(weights) * prop))
|
||||||
|
|
||||||
|
# Ensure that at least one edge is retained
|
||||||
|
k = max(k, 1)
|
||||||
|
|
||||||
|
# Get the threshold value
|
||||||
|
thr = np.sort(weights)[-k]
|
||||||
|
|
||||||
|
# Apply the threshold to create a binary adjacency matrix
|
||||||
|
bin_adj = np.where(adj >= thr, adj, 0.0)
|
||||||
|
|
||||||
|
return bin_adj
|
||||||
|
# 单个脑网络
|
||||||
|
def plot_single_network(ch_names,adj,ax=None,
|
||||||
|
node_size=20, node_color='orange',highlight_nodes=[], show_names=True,
|
||||||
|
edge_color='gray', weighted=True,
|
||||||
|
radius=1.1, figsize=(6, 6),cmap='RdYlBu_r'):
|
||||||
|
# 若 ax 未传入,则自己创建
|
||||||
|
own_fig = False
|
||||||
|
if ax is None:
|
||||||
|
fig, ax = plt.subplots(figsize=figsize)
|
||||||
|
own_fig = True
|
||||||
|
else:
|
||||||
|
fig = ax.figure
|
||||||
|
|
||||||
|
# 坐标归一化
|
||||||
|
pos3d = read_ch_pos()
|
||||||
|
all_chs_xy = np.array([pos3d[ch][:2] for ch in pos3d.keys()])
|
||||||
|
all_chs_xy -= all_chs_xy.mean(axis=0)
|
||||||
|
all_chs_xy /= np.sqrt((all_chs_xy ** 2).sum(axis=1)).max()
|
||||||
|
xy_dict = dict(zip(pos3d.keys(), all_chs_xy))
|
||||||
|
xy = np.array([xy_dict[ch] for ch in ch_names])
|
||||||
|
center = xy_dict.get('CZ', np.mean(list(xy_dict.values()), axis=0))
|
||||||
|
|
||||||
|
# ===== 初始化绘图窗口 =====
|
||||||
|
ax.set_aspect('equal')
|
||||||
|
ax.axis('off')
|
||||||
|
# 设置边界(与原类保持一致)
|
||||||
|
ear_radius = radius * 0.12
|
||||||
|
nose_height = radius * 0.15
|
||||||
|
margin_x = radius * 0.12 + 0.05
|
||||||
|
ax.set_xlim(center[0] - radius - margin_x, center[0] + radius + margin_x)
|
||||||
|
ax.set_ylim(center[1] - radius - ear_radius, center[1] + radius + nose_height + ear_radius)
|
||||||
|
|
||||||
|
# 绘制头部轮廓
|
||||||
|
draw_head(ax, center=center, radius=radius)
|
||||||
|
|
||||||
|
# 节点
|
||||||
|
for ch in ch_names:
|
||||||
|
color = 'red' if ch in highlight_nodes else node_color
|
||||||
|
ax.scatter(*xy_dict[ch], s=node_size, color=color, edgecolor='k', zorder=4)
|
||||||
|
if show_names:
|
||||||
|
ax.text(xy_dict[ch][0], xy_dict[ch][1] + 0.03, ch,
|
||||||
|
ha='center', va='bottom', fontsize=8, zorder=5)
|
||||||
|
|
||||||
|
# colorbar
|
||||||
|
norm = mcolors.Normalize(vmin=0, vmax=1)
|
||||||
|
color_map = matplotlib.colormaps.get_cmap(cmap)
|
||||||
|
# ========= 边 ==========
|
||||||
|
N = len(ch_names)
|
||||||
|
for i in range(N):
|
||||||
|
for j in range(i + 1, N):
|
||||||
|
w = adj[i, j]
|
||||||
|
if w > 0:
|
||||||
|
x = [xy[i, 0], xy[j, 0]]
|
||||||
|
y = [xy[i, 1], xy[j, 1]]
|
||||||
|
lw = 1.5
|
||||||
|
if weighted:
|
||||||
|
ax.plot(x, y,
|
||||||
|
color=color_map(norm(w)),
|
||||||
|
linewidth=lw,
|
||||||
|
alpha=0.7,
|
||||||
|
zorder=3)
|
||||||
|
else:
|
||||||
|
ax.plot(x, y,
|
||||||
|
color=edge_color,
|
||||||
|
linewidth=lw,
|
||||||
|
alpha=0.7,
|
||||||
|
zorder=3)
|
||||||
|
|
||||||
|
if own_fig:
|
||||||
|
# 不回传 添加颜色条
|
||||||
|
sm = cm.ScalarMappable(norm=norm, cmap=color_map)
|
||||||
|
cbar = plt.colorbar(sm, ax=ax, fraction=0.035)
|
||||||
|
cbar.set_label('Connection Strength', fontsize=10)
|
||||||
|
cbar.ax.tick_params(direction='in', labelsize=10)
|
||||||
|
plt.show()
|
||||||
|
return fig
|
||||||
|
else:
|
||||||
|
|
||||||
|
return ax
|
||||||
|
# 脑网络对比
|
||||||
|
def plot_multiband_network(ch_names, adj_MI, adj_Rest,cmap='RdYlBu_r'):
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
|
||||||
|
fontsize = 16
|
||||||
|
fig.text(0.285, 0.08, 'MI', fontsize=fontsize, ha='center', va='center', rotation=0)
|
||||||
|
fig.text(0.68, 0.08, 'Rest', fontsize=fontsize, ha='center', va='center', rotation=0)
|
||||||
|
|
||||||
|
im1 = plot_single_network(ch_names,adj_MI,ax=axes[0], show_names=True,cmap=cmap)
|
||||||
|
# Rest 行
|
||||||
|
im2 = plot_single_network(ch_names,adj_Rest,ax=axes[1],show_names=True,cmap=cmap)
|
||||||
|
|
||||||
|
# --- 合并 colorbar(右侧一个) ---
|
||||||
|
norm = mcolors.Normalize(vmin=0, vmax=1)
|
||||||
|
color_map = matplotlib.colormaps.get_cmap(cmap)
|
||||||
|
sm = cm.ScalarMappable(norm=norm, cmap=color_map)
|
||||||
|
cbar = plt.colorbar(sm, ax=axes.ravel().tolist(), fraction=0.02)
|
||||||
|
cbar.set_label('Connection Strength', fontsize=10)
|
||||||
|
cbar.ax.tick_params(direction='in', labelsize=10)
|
||||||
|
|
||||||
|
# 将图像保存到内存字节流(PNG 格式)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
|
||||||
|
plt.close(fig) # 释放内存
|
||||||
|
buf.seek(0)
|
||||||
|
image_bytes = buf.read()
|
||||||
|
buf.close()
|
||||||
|
|
||||||
|
return image_bytes
|
||||||
|
|
||||||
|
# 多个频带psd
|
||||||
|
def compute_band_psd(eeg, fs, bands, labels, trial_idx=0,MI_label=1, Rest_label=2,avg = True):
|
||||||
|
"""
|
||||||
|
eeg: (n_trials, n_channels, n_samples)
|
||||||
|
"""
|
||||||
|
n_trials, n_channels, n_samples = eeg.shape
|
||||||
|
band_names = list(bands.keys())
|
||||||
|
n_bands = len(band_names)
|
||||||
|
|
||||||
|
psd_MI = np.zeros((n_bands, n_channels))
|
||||||
|
psd_Rest = np.zeros((n_bands, n_channels))
|
||||||
|
|
||||||
|
# 先计算所有 trial 的功率谱
|
||||||
|
f, Pxx = welch(eeg, fs=fs, axis=-1, nperseg=fs,noverlap = fs // 2)
|
||||||
|
|
||||||
|
|
||||||
|
for bi, (bname, (f1, f2)) in enumerate(bands.items()):
|
||||||
|
idx = np.logical_and(f >= f1, f <= f2)
|
||||||
|
band_power = Pxx[:, :, idx].mean(axis=-1)
|
||||||
|
|
||||||
|
band_power_flat = band_power.flatten()
|
||||||
|
power_min = band_power_flat.min()
|
||||||
|
power_max = band_power_flat.max()
|
||||||
|
if power_max - power_min > 1e-12:
|
||||||
|
band_power_norm = (band_power - power_min) / (power_max - power_min)
|
||||||
|
else:
|
||||||
|
band_power_norm = band_power
|
||||||
|
|
||||||
|
if avg:
|
||||||
|
psd_MI[bi] = band_power_norm[labels == MI_label].mean(axis=0)
|
||||||
|
psd_Rest[bi] = band_power_norm[labels == Rest_label].mean(axis=0)
|
||||||
|
else:
|
||||||
|
psd_MI[bi] = band_power_norm[labels == MI_label][trial_idx]
|
||||||
|
psd_Rest[bi] = band_power_norm[labels == Rest_label][trial_idx]
|
||||||
|
return band_names, psd_MI, psd_Rest
|
||||||
|
# 单个脑地形图
|
||||||
|
def plot_single_topomap(ch_names, psd_values, cmap='RdYlBu_r', vlim=(0, 1),
|
||||||
|
show_names=True, node_size=3, radius=1.1, grid_res=300,
|
||||||
|
n_contours=None, contour_color='k',
|
||||||
|
ax=None,figsize=(6,6)):
|
||||||
|
# 若 ax 未传入,则自己创建
|
||||||
|
own_fig = False
|
||||||
|
if ax is None:
|
||||||
|
fig, ax = plt.subplots(figsize=figsize)
|
||||||
|
own_fig = True
|
||||||
|
else:
|
||||||
|
fig = ax.figure
|
||||||
|
|
||||||
|
# ===== 初始化绘图窗口 =====
|
||||||
|
ax.set_aspect('equal')
|
||||||
|
ax.axis('off')
|
||||||
|
# ax.set_title("EEG topomap (MNE-like)")
|
||||||
|
|
||||||
|
# 坐标归一化
|
||||||
|
pos3d = read_ch_pos()
|
||||||
|
all_chs_xy = np.array([pos3d[ch][:2] for ch in pos3d.keys()])
|
||||||
|
all_chs_xy -= all_chs_xy.mean(axis=0)
|
||||||
|
all_chs_xy /= np.sqrt((all_chs_xy ** 2).sum(axis=1)).max()
|
||||||
|
pos2d_dict = dict(zip(pos3d.keys(), all_chs_xy))
|
||||||
|
xy = np.array([pos2d_dict[ch] for ch in ch_names])
|
||||||
|
center = pos2d_dict.get('CZ', np.mean(list(pos2d_dict.values()), axis=0))
|
||||||
|
|
||||||
|
# 绘制头部轮廓
|
||||||
|
draw_head(ax, center=center, radius=radius)
|
||||||
|
# 绘制电极
|
||||||
|
fontsize = 4
|
||||||
|
ax.scatter(xy[:, 0], xy[:, 1], c='k', s=node_size, zorder=5)
|
||||||
|
if show_names:
|
||||||
|
for i, ch in enumerate(ch_names):
|
||||||
|
ax.text(xy[i, 0], xy[i, 1] + 0.03, ch,
|
||||||
|
ha='center', va='bottom', fontsize=fontsize, zorder=6)
|
||||||
|
|
||||||
|
# 数据插值
|
||||||
|
zi, grid_x, grid_y = rbf_D_interpolate(
|
||||||
|
xy, psd_values, radius=radius,
|
||||||
|
grid_res=grid_res
|
||||||
|
)
|
||||||
|
xmin, xmax = center[0] - radius, center[0] + radius
|
||||||
|
ymin, ymax = center[1] - radius, center[1] + radius
|
||||||
|
extent = (xmin, xmax, ymin, ymax)
|
||||||
|
im = ax.imshow(zi, extent=extent, origin='lower',
|
||||||
|
cmap=cmap, vmin=vlim[0], vmax=vlim[1],
|
||||||
|
interpolation='bicubic', zorder=0)
|
||||||
|
# 裁剪路径
|
||||||
|
patch_ = Ellipse(center, 2 * radius, 2 * radius, clip_on=True, transform=ax.transData)
|
||||||
|
im.set_clip_path(patch_)
|
||||||
|
# 初始等高线
|
||||||
|
linewidths = 0.5
|
||||||
|
if n_contours is None:
|
||||||
|
cset = ax.contour(grid_x, grid_y, zi,
|
||||||
|
colors=contour_color, linewidths=linewidths, zorder=2)
|
||||||
|
else:
|
||||||
|
cset = ax.contour(grid_x, grid_y, zi, levels=n_contours,
|
||||||
|
colors=contour_color, linewidths=linewidths, zorder=2)
|
||||||
|
cset.set_clip_path(patch_)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if own_fig:
|
||||||
|
# 不回传 添加颜色条
|
||||||
|
plt.colorbar(im, ax=ax, fraction=0.035)
|
||||||
|
plt.show()
|
||||||
|
return fig
|
||||||
|
else:
|
||||||
|
# plt.colorbar(im, ax=ax, fraction=0.035)
|
||||||
|
return im
|
||||||
|
# 脑地形图对比
|
||||||
|
def plot_multiband_topomaps(ch_names, psd_MI, psd_Rest, bands):
|
||||||
|
band_names = list(bands.keys()) # 改动 1:新增这行
|
||||||
|
n_bands = len(band_names)
|
||||||
|
fig, axes = plt.subplots(2, n_bands, figsize=(3*n_bands, 6))
|
||||||
|
|
||||||
|
fontsize = 16
|
||||||
|
|
||||||
|
axes[0, 0].text(-0.1, 0.5, 'MI', transform=axes[0, 0].transAxes, rotation=0, va='center', ha='center', fontsize=fontsize-2)
|
||||||
|
axes[1, 0].text(-0.1, 0.5, 'Rest', transform=axes[1, 0].transAxes, rotation=0, va='center', ha='center', fontsize=fontsize-2)
|
||||||
|
|
||||||
|
imgs = []
|
||||||
|
for i, bname in enumerate(band_names):
|
||||||
|
axes[0, i].set_title(bname, fontsize=fontsize, pad=0)
|
||||||
|
# MI 行
|
||||||
|
im1 = plot_single_topomap(ch_names,psd_MI[i],ax=axes[0, i], show_names=True)
|
||||||
|
# Rest 行
|
||||||
|
im2 = plot_single_topomap(ch_names,psd_Rest[i],ax=axes[1, i],show_names=True)
|
||||||
|
imgs.append(im1)
|
||||||
|
|
||||||
|
# --- 单个右侧合并 colorbar ---
|
||||||
|
cbar = fig.colorbar(imgs[0], ax=axes,fraction=0.02)
|
||||||
|
# cbar.set_label("PSD Power",fontsize=fontsize-4)
|
||||||
|
cbar.ax.tick_params(direction='in', labelsize=10)
|
||||||
|
|
||||||
|
# 将图像保存到内存字节流(PNG 格式)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
|
||||||
|
plt.close(fig) # 释放内存
|
||||||
|
buf.seek(0)
|
||||||
|
image_bytes = buf.read()
|
||||||
|
buf.close()
|
||||||
|
|
||||||
|
return image_bytes
|
||||||
|
|
||||||
|
# 小波
|
||||||
|
def morlet_wavelet(f, fs, n_cycles=7):
|
||||||
|
"""
|
||||||
|
创建 Morlet 小波
|
||||||
|
f: 频率
|
||||||
|
fs: 采样率
|
||||||
|
"""
|
||||||
|
sigma_t = n_cycles / (2 * np.pi * f)
|
||||||
|
t = np.arange(-3*sigma_t, 3*sigma_t, 1/fs)
|
||||||
|
wavelet = (np.pi**-0.25) * np.exp(2j*np.pi*f*t) * np.exp(-(t**2)/(2*sigma_t**2))
|
||||||
|
return wavelet
|
||||||
|
|
||||||
|
|
||||||
|
# 希尔伯特变换 计算ERDS 效果不佳
|
||||||
|
def bandpass_filter(data, fs, band, order=4):
|
||||||
|
nyq = fs / 2
|
||||||
|
b, a = butter(order, [band[0]/nyq, band[1]/nyq], btype='band')
|
||||||
|
return filtfilt(b, a, data, axis=-1)
|
||||||
|
def compute_power_hilbert(filtered_data,is_dB =True):
|
||||||
|
analytic = hilbert(filtered_data, axis=-1)
|
||||||
|
power = np.abs(analytic) ** 2
|
||||||
|
if is_dB:
|
||||||
|
power = 10 * np.log10(power)
|
||||||
|
return power
|
||||||
|
def compute_power(data, fs=250,
|
||||||
|
bands={"mu": (8,12), "beta": (13,30)}):
|
||||||
|
"""
|
||||||
|
返回:
|
||||||
|
power_dict[band] = (n_trials, n_ch, n_samples)
|
||||||
|
"""
|
||||||
|
power_dict = {}
|
||||||
|
for band_name, band_range in bands.items():
|
||||||
|
filt = bandpass_filter(data, fs, band_range)
|
||||||
|
power = compute_power_hilbert(filt)
|
||||||
|
power_dict[band_name] = power
|
||||||
|
|
||||||
|
return power_dict
|
||||||
|
|
||||||
|
def compute_erds(power_MI, power_Rest, baseline_period=None):
|
||||||
|
"""
|
||||||
|
计算事件相关去同步/同步 (ERDS)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
power_MI, power_Rest: (n_trials, n_ch, n_samples)
|
||||||
|
功率数据,单位为 µV² 或 dB(取决于 compute_power_hilbert 的 is_dB 参数)
|
||||||
|
baseline_period: tuple (start_idx, end_idx) or None
|
||||||
|
基线时间段索引。如果为None,使用 Rest 状态的平均值作为基线
|
||||||
|
|
||||||
|
返回:
|
||||||
|
MI_erds_mean, MI_erds_sem
|
||||||
|
Rest_erds_mean, Rest_erds_sem
|
||||||
|
所有返回值的形状为 (n_ch, n_samples)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if baseline_period is not None:
|
||||||
|
start_idx, end_idx = baseline_period
|
||||||
|
baseline = np.concatenate([power_MI[:, :, start_idx:end_idx],
|
||||||
|
power_Rest[:, :, start_idx:end_idx]], axis=0)
|
||||||
|
baseline = baseline.mean(axis=(0, 2), keepdims=True)
|
||||||
|
else:
|
||||||
|
baseline = power_Rest.mean(axis=(0,2), keepdims=True)
|
||||||
|
|
||||||
|
# === ERDS (%) ===
|
||||||
|
MI_erds = (power_MI - baseline) / baseline * 100
|
||||||
|
Rest_erds = (power_Rest - baseline) / baseline * 100
|
||||||
|
|
||||||
|
return (
|
||||||
|
MI_erds.mean(axis=0), sem(MI_erds, axis=0),
|
||||||
|
Rest_erds.mean(axis=0), sem(Rest_erds, axis=0),
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_all_erds(MI_power_dict, Rest_power_dict):
|
||||||
|
"""
|
||||||
|
对多个频带同时计算 ERDS。
|
||||||
|
|
||||||
|
输入:
|
||||||
|
MI_power_dict[band] = (n_trials, n_ch, n_samples)
|
||||||
|
Rest_power_dict[band] = (n_trials, n_ch, n_samples)
|
||||||
|
|
||||||
|
输出:
|
||||||
|
erds_MI[band] = (mean, sem)
|
||||||
|
erds_Rest[band] = (mean, sem)
|
||||||
|
"""
|
||||||
|
|
||||||
|
erds_MI = {}
|
||||||
|
erds_Rest = {}
|
||||||
|
|
||||||
|
for band in MI_power_dict.keys():
|
||||||
|
MI_power = MI_power_dict[band]
|
||||||
|
Rest_power = Rest_power_dict[band]
|
||||||
|
|
||||||
|
MI_mean, MI_sem, Rest_mean, Rest_sem = compute_erds(MI_power, Rest_power)
|
||||||
|
|
||||||
|
erds_MI[band] = (MI_mean, MI_sem)
|
||||||
|
erds_Rest[band] = (Rest_mean, Rest_sem)
|
||||||
|
|
||||||
|
return erds_MI, erds_Rest
|
||||||
|
|
||||||
|
def plot_compare_erds(data_MI, data_Rest, mode="power",
|
||||||
|
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'],
|
||||||
|
compare_names=['C3', 'CZ', 'C4'], bands=['mu', 'beta'],
|
||||||
|
fs=250, t=None, figsize=(12,6)):
|
||||||
|
|
||||||
|
n_bands = len(bands)
|
||||||
|
n_chs = len(compare_names)
|
||||||
|
|
||||||
|
# 自动添加单位
|
||||||
|
if mode == "power":
|
||||||
|
# y_unit = "Power (µV²)"
|
||||||
|
y_unit = "Power (dB)"
|
||||||
|
elif mode == "erds":
|
||||||
|
y_unit = "ERDS (%)"
|
||||||
|
else:
|
||||||
|
y_unit = ""
|
||||||
|
|
||||||
|
if t is None:
|
||||||
|
n_samples = next(iter(data_MI.values())).shape[-1] \
|
||||||
|
if mode=="power" else next(iter(data_MI.values()))[0].shape[-1]
|
||||||
|
t = np.arange(n_samples) / fs
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(n_bands, n_chs, figsize=figsize, sharex=True, sharey=True)
|
||||||
|
|
||||||
|
for i, band in enumerate(bands):
|
||||||
|
|
||||||
|
# 选择数据结构
|
||||||
|
if mode == "power":
|
||||||
|
MI_band = data_MI[band] # (trials, ch, samples)
|
||||||
|
Rest_band = data_Rest[band]
|
||||||
|
|
||||||
|
avg_MI = MI_band.mean(axis=0)
|
||||||
|
sem_MI = MI_band.std(axis=0)/np.sqrt(MI_band.shape[0])
|
||||||
|
|
||||||
|
avg_Rest = Rest_band.mean(axis=0)
|
||||||
|
sem_Rest = Rest_band.std(axis=0)/np.sqrt(Rest_band.shape[0])
|
||||||
|
|
||||||
|
elif mode == "erds":
|
||||||
|
avg_MI, sem_MI = data_MI[band]
|
||||||
|
avg_Rest, sem_Rest = data_Rest[band]
|
||||||
|
|
||||||
|
for j, ch in enumerate(compare_names):
|
||||||
|
ax = axes[i, j] if n_bands > 1 else axes[j]
|
||||||
|
|
||||||
|
ch_idx = ch_names.index(ch)
|
||||||
|
|
||||||
|
# 绘制 MI
|
||||||
|
ax.plot(t, avg_MI[ch_idx], color="C0", label="MI")
|
||||||
|
ax.fill_between(t,
|
||||||
|
avg_MI[ch_idx]-sem_MI[ch_idx],
|
||||||
|
avg_MI[ch_idx]+sem_MI[ch_idx],
|
||||||
|
alpha=0.3, color="C0")
|
||||||
|
|
||||||
|
# 绘制 Rest
|
||||||
|
ax.plot(t, avg_Rest[ch_idx], color="C1", label="Rest")
|
||||||
|
ax.fill_between(t,
|
||||||
|
avg_Rest[ch_idx]-sem_Rest[ch_idx],
|
||||||
|
avg_Rest[ch_idx]+sem_Rest[ch_idx],
|
||||||
|
alpha=0.3, color="C1")
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
ax.set_title(ch)
|
||||||
|
|
||||||
|
# ← Y 轴加单位
|
||||||
|
if j == 0:
|
||||||
|
ax.set_ylabel(f"{band}\n{y_unit}")
|
||||||
|
|
||||||
|
if i == n_bands - 1:
|
||||||
|
ax.set_xlabel("Time (s)")
|
||||||
|
|
||||||
|
ax.grid(alpha=0.3)
|
||||||
|
|
||||||
|
if i == 0 and j == n_chs - 1:
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# 将图像保存到内存字节流(PNG 格式)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
|
||||||
|
plt.close(fig) # 释放内存
|
||||||
|
buf.seek(0)
|
||||||
|
image_bytes = buf.read()
|
||||||
|
buf.close()
|
||||||
|
|
||||||
|
return image_bytes
|
||||||
|
|
||||||
|
# 对比 MI vs Rest 的功率谱密度 PSD
|
||||||
|
def plot_psd_compare(MI_data, Rest_data, ch_names, compare_names=['C3', 'CZ', 'C4'],
|
||||||
|
fs=250, nperseg=None, average=True, show_sem=True,
|
||||||
|
figsize=(12, 3), save_dir=None, filename="psd.png"):
|
||||||
|
"""
|
||||||
|
对比 MI vs Rest 的功率谱密度 PSD
|
||||||
|
|
||||||
|
MI_data, Rest_data: (n_trials, n_ch, n_samples)
|
||||||
|
channels: 需要绘制的通道
|
||||||
|
average: 是否对所有试次平均
|
||||||
|
show_sem: 是否绘制 SEM 阴影
|
||||||
|
"""
|
||||||
|
|
||||||
|
n_trials, n_ch, n_samples = MI_data.shape
|
||||||
|
n_trials = min(len(MI_data), len(Rest_data))
|
||||||
|
# assert Rest_data.shape == MI_data.shape, "MI 和 Rest 数据维度必须一致"
|
||||||
|
|
||||||
|
if nperseg is None:
|
||||||
|
nperseg = fs # 每 1 秒窗长度
|
||||||
|
|
||||||
|
# 计算 MI PSD
|
||||||
|
psd_MI_all = []
|
||||||
|
for trial in range(n_trials):
|
||||||
|
psd_trial = []
|
||||||
|
for ch in range(n_ch):
|
||||||
|
f, Pxx = welch(MI_data[trial, ch], fs=fs, nperseg=nperseg)
|
||||||
|
psd_trial.append(Pxx)
|
||||||
|
psd_MI_all.append(psd_trial)
|
||||||
|
psd_MI_all = np.array(psd_MI_all)
|
||||||
|
|
||||||
|
# 计算 Rest PSD
|
||||||
|
psd_Rest_all = []
|
||||||
|
for trial in range(n_trials):
|
||||||
|
psd_trial = []
|
||||||
|
for ch in range(n_ch):
|
||||||
|
_, Pxx = welch(Rest_data[trial, ch], fs=fs, nperseg=nperseg)
|
||||||
|
psd_trial.append(Pxx)
|
||||||
|
psd_Rest_all.append(psd_trial)
|
||||||
|
psd_Rest_all = np.array(psd_Rest_all)
|
||||||
|
|
||||||
|
# ---- Plot ----
|
||||||
|
fig, ax = plt.subplots(1, len(compare_names), figsize=figsize)
|
||||||
|
if len(compare_names) == 1:
|
||||||
|
ax = [ax]
|
||||||
|
|
||||||
|
for i, ch in enumerate(compare_names):
|
||||||
|
ch_idx = ch_names.index(ch)
|
||||||
|
psd_MI_ch = psd_MI_all[:, ch_idx, :]
|
||||||
|
psd_Rest_ch = psd_Rest_all[:, ch_idx, :]
|
||||||
|
|
||||||
|
if average:
|
||||||
|
mean_MI = psd_MI_ch.mean(axis=0)
|
||||||
|
mean_Rest = psd_Rest_ch.mean(axis=0)
|
||||||
|
|
||||||
|
ax[i].plot(f, mean_MI, color='C0', label='MI')
|
||||||
|
ax[i].plot(f, mean_Rest, color='C1', label='Rest')
|
||||||
|
|
||||||
|
if show_sem:
|
||||||
|
ax[i].fill_between(f, mean_MI - sem(psd_MI_ch, axis=0),
|
||||||
|
mean_MI + sem(psd_MI_ch, axis=0), color='C0', alpha=0.3)
|
||||||
|
ax[i].fill_between(f, mean_Rest - sem(psd_Rest_ch, axis=0),
|
||||||
|
mean_Rest + sem(psd_Rest_ch, axis=0), color='C1', alpha=0.3)
|
||||||
|
else:
|
||||||
|
ax[i].plot(f, psd_MI_ch.T, color='C0', alpha=0.3)
|
||||||
|
ax[i].plot(f, psd_Rest_ch.T, color='C1', alpha=0.3)
|
||||||
|
|
||||||
|
ax[i].set_title(ch)
|
||||||
|
ax[i].set_xlabel("Frequency (Hz)")
|
||||||
|
ax[i].set_ylabel("PSD (μV²/Hz)")
|
||||||
|
ax[i].grid(alpha=0.3)
|
||||||
|
if i == 0:
|
||||||
|
ax[i].legend()
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# 将图像保存到内存字节流(PNG 格式)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
|
||||||
|
plt.close(fig) # 释放内存
|
||||||
|
buf.seek(0)
|
||||||
|
image_bytes = buf.read()
|
||||||
|
buf.close()
|
||||||
|
|
||||||
|
return image_bytes
|
||||||
|
|
||||||
|
|
||||||
|
def plotMain(
|
||||||
|
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'],
|
||||||
|
compare_names = [ 'C3','CZ','C4'],
|
||||||
|
Data = None,labels = None,MI_label = None,Rest_label = None,
|
||||||
|
fs = 250):
|
||||||
|
|
||||||
|
trial_idx = 0
|
||||||
|
|
||||||
|
# 数据划分
|
||||||
|
if not MI_label:
|
||||||
|
label_ = np.unique(labels)
|
||||||
|
else:
|
||||||
|
label_ = (MI_label,Rest_label)
|
||||||
|
MI_data = Data[labels == label_[0]]
|
||||||
|
Rest_data = Data[labels == label_[1]]
|
||||||
|
|
||||||
|
# 典型 EEG 频带
|
||||||
|
FREQ_BANDS = {
|
||||||
|
"Delta (0.8-4Hz)": (0.8, 4),
|
||||||
|
"Theta (4-8Hz)": (4, 8),
|
||||||
|
"Alpha (8-12Hz)": (8, 12),
|
||||||
|
"Beta (12-30Hz)": (12, 30),
|
||||||
|
"All (0.8-30Hz)": (0.8, 30)
|
||||||
|
}
|
||||||
|
# 利用welch估算PSD
|
||||||
|
band_names, psd_MI, psd_Rest= compute_band_psd(
|
||||||
|
eeg=Data,
|
||||||
|
fs=fs,
|
||||||
|
bands=FREQ_BANDS,
|
||||||
|
labels=labels,
|
||||||
|
trial_idx=trial_idx,
|
||||||
|
MI_label=MI_label,
|
||||||
|
Rest_label=Rest_label,
|
||||||
|
avg= True
|
||||||
|
)
|
||||||
|
# 绘制地形图
|
||||||
|
topomaps_imgBytes = plot_multiband_topomaps(
|
||||||
|
ch_names=ch_names,
|
||||||
|
psd_MI=psd_MI,
|
||||||
|
psd_Rest=psd_Rest,
|
||||||
|
bands=FREQ_BANDS
|
||||||
|
)
|
||||||
|
|
||||||
|
# 绘制脑网络
|
||||||
|
mi_plv_matrix = calculate_plv(MI_data[trial_idx])
|
||||||
|
mi_BI_matrix = threshold_proportional(mi_plv_matrix, prop=0.3)
|
||||||
|
rest_plv_matrix = calculate_plv(Rest_data[trial_idx])
|
||||||
|
rest_BI_matrix = threshold_proportional(rest_plv_matrix, prop=0.3)
|
||||||
|
network_imgBytes = plot_multiband_network(ch_names, mi_BI_matrix, rest_BI_matrix)
|
||||||
|
|
||||||
|
# ERDS 先计算erds,后平均
|
||||||
|
MI_power = compute_power(MI_data)
|
||||||
|
Rest_power = compute_power(Rest_data)
|
||||||
|
erds_dict_MI, erds_dict_Rest = compute_all_erds(MI_power, Rest_power)
|
||||||
|
erds_imgBytes = plot_compare_erds(erds_dict_MI, erds_dict_Rest, ch_names=ch_names,
|
||||||
|
compare_names=compare_names, bands=['mu', 'beta'],
|
||||||
|
fs=fs, mode="erds")
|
||||||
|
|
||||||
|
# 绘制PSD
|
||||||
|
psd_imgBytes = plot_psd_compare(MI_data, Rest_data, ch_names = ch_names, compare_names=compare_names,
|
||||||
|
fs=fs, nperseg=None, average=True, show_sem=True,
|
||||||
|
figsize=(12, 3))
|
||||||
|
return {'topomaps_imgBytes':base64.b64encode(topomaps_imgBytes).decode(),'network_imgBytes':base64.b64encode(network_imgBytes).decode(),
|
||||||
|
'erds_imgBytes':base64.b64encode(erds_imgBytes).decode(),'psd_imgBytes':base64.b64encode(psd_imgBytes).decode()}
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
allData = np.random.uniform(-50,50,size=(80,21,1000))
|
||||||
|
allLabel = np.random.randint(1,3,size=(80,))
|
||||||
|
allData = allData[:len(allLabel)]
|
||||||
|
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
|
||||||
|
'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
|
||||||
|
compare_names = ['C3', 'CZ', 'C4']
|
||||||
|
ret = plotMain(ch_names=ch_names, compare_names=compare_names, Data=allData, labels=allLabel, MI_label=1, Rest_label=2,
|
||||||
|
fs=250)
|
||||||
|
print('计算完成,开始发送')
|
||||||
|
from Tools.zmqClient import zmqClient
|
||||||
|
|
||||||
|
zmqClient = zmqClient('192.168.76.101', 8088)
|
||||||
|
zmqClient.connect()
|
||||||
|
zmqClient.send_to_all('miReport', ret)
|
||||||
57
Debug_64ch_Decoder/Zmq/zmqClient.py
Normal file
57
Debug_64ch_Decoder/Zmq/zmqClient.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
|
||||||
|
class zmqClient:
|
||||||
|
def __init__(self, host, port):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.client_socket = None
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
# 记录客户端连接前的状态
|
||||||
|
self.state = {
|
||||||
|
'status_code': None,
|
||||||
|
'energy': None
|
||||||
|
}
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
# 创建 ZeroMQ 上下文
|
||||||
|
self.context = zmq.Context()
|
||||||
|
# 创建 REQ 套接字(请求端)
|
||||||
|
self.client_socket = self.context.socket(zmq.DEALER)
|
||||||
|
# client_id = b'client1'
|
||||||
|
# self.client_socket.setsockopt(zmq.IDENTITY,client_id)
|
||||||
|
self.client_socket.connect(f"tcp://{self.host}:{self.port}") # 连接到服务器
|
||||||
|
self.running = True
|
||||||
|
|
||||||
|
def send_to_all(self, method,params):
|
||||||
|
if method in self.state.keys():
|
||||||
|
self.state[method] = params
|
||||||
|
try:
|
||||||
|
if self.running and self.client_socket != None:
|
||||||
|
msg = {'method': method, 'params': params}
|
||||||
|
# 发送响应
|
||||||
|
print(msg)
|
||||||
|
self.client_socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
|
||||||
|
else:
|
||||||
|
if method in self.state.keys():
|
||||||
|
self.state[method] = params
|
||||||
|
except ConnectionResetError:
|
||||||
|
print("Connection lost.")
|
||||||
|
self.running = False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred: {e}")
|
||||||
|
|
||||||
|
def close_connection(self):
|
||||||
|
self.running = False
|
||||||
|
self.client_socket.close()
|
||||||
|
self.context.term()
|
||||||
|
print("Client closed explicitly.")
|
||||||
|
# 使用TCP客户端
|
||||||
|
if __name__ == "__main__":
|
||||||
|
client = zmqClient('127.0.0.1', 8099)
|
||||||
|
client.connect()
|
||||||
|
# client.close_connection()
|
||||||
104
Debug_64ch_Decoder/Zmq/zmqServer.py
Normal file
104
Debug_64ch_Decoder/Zmq/zmqServer.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
import threading
|
||||||
|
import json
|
||||||
|
from Device.SunnyLinker import SunnyLinker64
|
||||||
|
|
||||||
|
class zmqServer(threading.Thread):
|
||||||
|
def __init__(self, host='0.0.0.0', port=8099):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.running = False
|
||||||
|
self.get_Impedance = False # 是否返回阻抗值
|
||||||
|
self.open_Impedance = None # 是否开启阻抗检测功能
|
||||||
|
self.StartDecode = False # false 停止解码,true=开始解码
|
||||||
|
self.StartTrain = False # False未进入训练状态,True处于训练状态
|
||||||
|
self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态
|
||||||
|
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
|
||||||
|
self.IsExitApp = False # 当socket收到2的时候,就置为True,代表遥退出系统了。
|
||||||
|
self.getReport = False # 获取训练报告内容
|
||||||
|
self.daemon = True
|
||||||
|
# 创建 ZeroMQ 上下文
|
||||||
|
self.context = zmq.Context()
|
||||||
|
# 创建 REP 套接字(响应端)
|
||||||
|
self.socket = self.context.socket(zmq.ROUTER)
|
||||||
|
self.socket.bind(f"tcp://{self.host}:{self.port}") # 绑定到端口 8099
|
||||||
|
self.targetFreqs = []
|
||||||
|
self.changeTarget = False # 更换目标频率
|
||||||
|
self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化
|
||||||
|
self.labels = [0x01, 0x02,0x03]
|
||||||
|
|
||||||
|
self.decoder_switch = False #更换解码器
|
||||||
|
self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi'
|
||||||
|
def run(self):
|
||||||
|
self.running = True
|
||||||
|
print(f"Server is running on {self.host}:{self.port}")
|
||||||
|
try:
|
||||||
|
while self.running:
|
||||||
|
# 等待客户端请求
|
||||||
|
_,_,message = self.socket.recv_multipart()
|
||||||
|
message = json.loads(message.decode('utf-8'))
|
||||||
|
print(f"Received request: {message}")
|
||||||
|
# 处理请求
|
||||||
|
method = message.get("method")
|
||||||
|
params = message.get("params")
|
||||||
|
if method == "sync":
|
||||||
|
self.state_mode = 'sync'
|
||||||
|
if method == "targetFreqs":
|
||||||
|
if not isinstance(params,list):
|
||||||
|
print('targetFreqs must be a list')
|
||||||
|
continue
|
||||||
|
if params != self.targetFreqs:
|
||||||
|
self.targetFreqs = params
|
||||||
|
self.changeTarget = True
|
||||||
|
if method == "decoderClass":
|
||||||
|
if not isinstance(params,str):
|
||||||
|
print('decoderClass must be a str')
|
||||||
|
continue
|
||||||
|
if params != self.decoder_class:
|
||||||
|
self.decoder_class = params
|
||||||
|
self.decoder_switch = True
|
||||||
|
if method == "getReport":
|
||||||
|
self.getReport = True
|
||||||
|
if method == "train":#训练状态
|
||||||
|
self.state_mode = 'train'
|
||||||
|
self.StartTrain = True
|
||||||
|
self.currentLabel = params # 当前刺激端的训练标签
|
||||||
|
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
|
||||||
|
elif method == "predict":#预测状态
|
||||||
|
self.state_mode = 'predict'
|
||||||
|
if params == 1: #开始解码
|
||||||
|
self.StartDecode = True
|
||||||
|
self.sunnyLinker.push_trigger(0x63)
|
||||||
|
elif params == 2: #停止解码
|
||||||
|
self.IsExitApp = True
|
||||||
|
self.running = False
|
||||||
|
elif method == "rest": #休息状态
|
||||||
|
self.state_mode = 'rest'
|
||||||
|
elif method == "impedance":
|
||||||
|
if params == 1:
|
||||||
|
self.open_Impedance = True # 开启阻抗
|
||||||
|
self.get_Impedance = True # 返回阻抗
|
||||||
|
elif params == 2:
|
||||||
|
self.open_Impedance = False # 关闭阻抗
|
||||||
|
self.get_Impedance = False # 停止返回阻抗
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An socket error occurred: {e}")
|
||||||
|
finally:
|
||||||
|
self.running = False
|
||||||
|
# 关闭套接字和上下文
|
||||||
|
self.socket.close()
|
||||||
|
self.context.term()
|
||||||
|
print("Server socket and context closed.")
|
||||||
|
def stop(self):
|
||||||
|
"""显式关闭服务器"""
|
||||||
|
self.running = False
|
||||||
|
self.socket.close()
|
||||||
|
self.context.term()
|
||||||
|
print("Server closed explicitly.")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
server = zmqServer()
|
||||||
|
server.start()
|
||||||
132
Debug_64ch_Decoder/build_algorithm.spec
Normal file
132
Debug_64ch_Decoder/build_algorithm.spec
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
# -*- mode: python ; coding: utf-8 -*-
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from PyInstaller.utils.hooks import collect_submodules, collect_data_files
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 1. 工程配置区 (Project Config)
|
||||||
|
# ========================================================
|
||||||
|
block_cipher = None
|
||||||
|
ENTRY_POINT = 'runDecoder.py'
|
||||||
|
APP_NAME = 'runDecoder'
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 2. 依赖分析 (Dependency Analysis)
|
||||||
|
# ========================================================
|
||||||
|
# 显式声明的隐藏导入,确保 PyInstaller 能找到所有 C 扩展和动态模块
|
||||||
|
hidden_imports = [
|
||||||
|
# sklearn Cython 扩展(极易被遗漏)
|
||||||
|
'sklearn.utils._cython_blas',
|
||||||
|
'torchsummary',
|
||||||
|
'sklearn.neighbors._typedefs',
|
||||||
|
'sklearn.neighbors._quad_tree',
|
||||||
|
'sklearn.tree._utils',
|
||||||
|
'sklearn.tree._criterion',
|
||||||
|
'sklearn.tree._splitter',
|
||||||
|
'sklearn.tree._tree',
|
||||||
|
'sklearn.utils._weight_vector',
|
||||||
|
# torch 核心模块
|
||||||
|
'torch',
|
||||||
|
'torch.nn',
|
||||||
|
'torch.nn.modules',
|
||||||
|
'torch.nn.modules.activation',
|
||||||
|
'torch.nn.modules.batchnorm',
|
||||||
|
'torch.nn.modules.conv',
|
||||||
|
'torch.nn.modules.dropout',
|
||||||
|
'torch.nn.modules.linear',
|
||||||
|
'torch.nn.modules.normalization',
|
||||||
|
'torch.nn.modules.pooling',
|
||||||
|
'torch.nn.functional',
|
||||||
|
'torch.autograd',
|
||||||
|
'torch.optim',
|
||||||
|
'torch.utils.data',
|
||||||
|
'torch.cuda',
|
||||||
|
# einops(必须显式添加)
|
||||||
|
'einops',
|
||||||
|
'einops.layers',
|
||||||
|
'einops.layers.torch',
|
||||||
|
# 并行计算相关
|
||||||
|
'multiprocessing',
|
||||||
|
'multiprocessing.connection',
|
||||||
|
'multiprocessing.context',
|
||||||
|
'multiprocessing.managers',
|
||||||
|
'multiprocessing.pool',
|
||||||
|
'multiprocessing.process',
|
||||||
|
'multiprocessing.queues',
|
||||||
|
'multiprocessing.reduction',
|
||||||
|
'multiprocessing.sharedctypes',
|
||||||
|
'multiprocessing.synchronize',
|
||||||
|
'multiprocessing.util',
|
||||||
|
]
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 3. 资源锚定 (Data Anchoring)
|
||||||
|
# ========================================================
|
||||||
|
# 收集 torch 的数据文件(triton、算子权重等)
|
||||||
|
datas = collect_data_files('torch')
|
||||||
|
datas += collect_data_files('torchvision')
|
||||||
|
# 收集 einops 数据文件
|
||||||
|
datas += collect_data_files('einops')
|
||||||
|
# 收集 sklearn 数据文件
|
||||||
|
datas += collect_data_files('sklearn')
|
||||||
|
# 收集 scipy 数据文件
|
||||||
|
datas += collect_data_files('scipy')
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 4. 构建流程 (Build Process)
|
||||||
|
# ========================================================
|
||||||
|
a = Analysis(
|
||||||
|
[ENTRY_POINT],
|
||||||
|
pathex=[],
|
||||||
|
binaries=[],
|
||||||
|
datas=datas,
|
||||||
|
hiddenimports=hidden_imports,
|
||||||
|
hookspath=[],
|
||||||
|
hooksconfig={},
|
||||||
|
runtime_hooks=['rthook.py'],
|
||||||
|
excludes=['tkinter', 'PyQt5', 'PySide2', 'PySide6', 'IPython', 'notebook', 'jupyter'],
|
||||||
|
win_no_prefer_redirects=False,
|
||||||
|
win_private_assemblies=False,
|
||||||
|
cipher=block_cipher,
|
||||||
|
noarchive=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||||
|
|
||||||
|
exe = EXE(
|
||||||
|
pyz,
|
||||||
|
a.scripts,
|
||||||
|
[],
|
||||||
|
exclude_binaries=True,
|
||||||
|
name=APP_NAME,
|
||||||
|
debug=False,
|
||||||
|
bootloader_ignore_signals=False,
|
||||||
|
strip=False,
|
||||||
|
upx=False,
|
||||||
|
console=True,
|
||||||
|
disable_windowed_traceback=False,
|
||||||
|
argv_emulation=False,
|
||||||
|
target_arch=None,
|
||||||
|
codesign_identity=None,
|
||||||
|
entitlements_file=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 5. 打包模式: OneDir (单文件夹) + 资源旁路
|
||||||
|
# ========================================================
|
||||||
|
coll = COLLECT(
|
||||||
|
exe,
|
||||||
|
a.binaries,
|
||||||
|
a.zipfiles,
|
||||||
|
a.datas,
|
||||||
|
# 显式复制资源文件夹到 exe 同级目录
|
||||||
|
Tree('online_Models', prefix='online_Models', excludes=['*.pyc', '__pycache__']),
|
||||||
|
Tree('Tools', prefix='Tools', excludes=['*.pyc', '__pycache__']),
|
||||||
|
# config.ini 作为单独文件
|
||||||
|
[('config.ini', 'config.ini', 'DATA')],
|
||||||
|
strip=False,
|
||||||
|
upx=False,
|
||||||
|
upx_exclude=[],
|
||||||
|
name=APP_NAME,
|
||||||
|
)
|
||||||
88
Debug_64ch_Decoder/build_with_copy.py
Normal file
88
Debug_64ch_Decoder/build_with_copy.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 1. 定义路径
|
||||||
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
DIST_DIR = os.path.join(BASE_DIR, 'dist')
|
||||||
|
APP_NAME = 'runDecoder'
|
||||||
|
TARGET_DIR = os.path.join(DIST_DIR, APP_NAME)
|
||||||
|
|
||||||
|
# 定义需要复制的资源 {源路径: 目标子路径}
|
||||||
|
# 目标子路径相对于 TARGET_DIR
|
||||||
|
RESOURCES = {
|
||||||
|
'config.ini': 'config.ini',
|
||||||
|
'online_Models': 'online_Models',
|
||||||
|
'Tools': 'Tools',
|
||||||
|
}
|
||||||
|
|
||||||
|
# 2. 清理旧构建
|
||||||
|
print("[1/3] Cleaning up old builds...")
|
||||||
|
if os.path.exists(DIST_DIR):
|
||||||
|
try:
|
||||||
|
shutil.rmtree(DIST_DIR)
|
||||||
|
print(" Cleaned dist/")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not clean dist/: {e}")
|
||||||
|
|
||||||
|
BUILD_DIR = os.path.join(BASE_DIR, 'build')
|
||||||
|
if os.path.exists(BUILD_DIR):
|
||||||
|
try:
|
||||||
|
shutil.rmtree(BUILD_DIR)
|
||||||
|
print(" Cleaned build/")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not clean build/: {e}")
|
||||||
|
|
||||||
|
# 3. 运行 PyInstaller
|
||||||
|
print("[2/3] Running PyInstaller...")
|
||||||
|
# 注意:我们这里不传 --noupx,因为已经在 spec 文件里把 upx=False 写死了
|
||||||
|
cmd = [
|
||||||
|
"pyinstaller",
|
||||||
|
"build_algorithm.spec",
|
||||||
|
"--clean"
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.check_call(cmd, shell=True)
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
print("Error: PyInstaller failed.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 4. 复制外部资源文件夹
|
||||||
|
print("[3/3] Verifying and Copying external resources...")
|
||||||
|
|
||||||
|
for src_name, dst_name in RESOURCES.items():
|
||||||
|
src_path = os.path.join(BASE_DIR, src_name)
|
||||||
|
dst_path = os.path.join(TARGET_DIR, dst_name)
|
||||||
|
|
||||||
|
if os.path.exists(src_path):
|
||||||
|
if os.path.isfile(src_path):
|
||||||
|
# 如果是文件
|
||||||
|
try:
|
||||||
|
shutil.copy2(src_path, dst_path)
|
||||||
|
print(f" Copied file: {src_name} -> {dst_name}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Error copying file {src_name}: {e}")
|
||||||
|
else:
|
||||||
|
# 如果是文件夹
|
||||||
|
if os.path.exists(dst_path):
|
||||||
|
try:
|
||||||
|
shutil.rmtree(dst_path) # 先删除 spec 生成的旧文件夹 (如果有)
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not remove existing dir {dst_path}: {e}")
|
||||||
|
try:
|
||||||
|
shutil.copytree(src_path, dst_path, ignore=shutil.ignore_patterns('*.pyc', '__pycache__'))
|
||||||
|
print(f" Copied dir: {src_name} -> {dst_name}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Error copying dir {src_name}: {e}")
|
||||||
|
else:
|
||||||
|
print(f" Warning: Source resource not found at {src_path}")
|
||||||
|
|
||||||
|
print("\n" + "="*50)
|
||||||
|
print(f"SUCCESS! Build artifacts are in: {TARGET_DIR}")
|
||||||
|
print("="*50)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
153
Debug_64ch_Decoder/config.ini
Normal file
153
Debug_64ch_Decoder/config.ini
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
[system]
|
||||||
|
SSVEP_ThresholdValue = [1,-0.023]
|
||||||
|
;SSVEP_ThresholdValue = [2,-0.00200]
|
||||||
|
SSMVEP_IntervalEpoch = [0.2,2.2]
|
||||||
|
MI_IntervalEpoch = [0.5,4.5]
|
||||||
|
Device_type=2
|
||||||
|
Right_rehabilitation = 5
|
||||||
|
Fault_rehabilitation = 5
|
||||||
|
Num_blocks = 3
|
||||||
|
Num_trials = 20
|
||||||
|
Audio_device = -1
|
||||||
|
Rest_time = 2
|
||||||
|
Serial_port = COM44
|
||||||
|
|
||||||
|
|
||||||
|
[Layout]
|
||||||
|
main_splitter_left = 993
|
||||||
|
main_splitter_right = 922
|
||||||
|
right_splitter_left = 233
|
||||||
|
right_splitter_right = 771
|
||||||
|
left_splitter_left = 503
|
||||||
|
left_splitter_right = 501q
|
||||||
|
|
||||||
|
[channel]
|
||||||
|
channel_x_fp1 = 419
|
||||||
|
channel_y_fp1 = 124
|
||||||
|
channel_x_fc1 = 439
|
||||||
|
channel_y_fc1 = 296
|
||||||
|
channel_x_fp2 = 576
|
||||||
|
channel_y_fp2 = 124
|
||||||
|
channel_x_fc2 = 556
|
||||||
|
channel_y_fc2 = 299
|
||||||
|
channel_x_f3 = 397
|
||||||
|
channel_y_f3 = 231
|
||||||
|
channel_x_cp1 = 439
|
||||||
|
channel_y_cp1 = 426
|
||||||
|
channel_x_f4 = 601
|
||||||
|
channel_y_f4 = 232
|
||||||
|
channel_x_cp2 = 559
|
||||||
|
channel_y_cp2 = 425
|
||||||
|
channel_x_fc3 = 379
|
||||||
|
channel_y_fc3 = 295
|
||||||
|
channel_x_af4 = 571
|
||||||
|
channel_y_af4 = 171
|
||||||
|
channel_x_po8 = 645
|
||||||
|
channel_y_po8 = 564
|
||||||
|
channel_x_fpz = 499
|
||||||
|
channel_y_fpz = 112
|
||||||
|
channel_x_fcz = 499
|
||||||
|
channel_y_fcz = 300
|
||||||
|
channel_x_poz = 500
|
||||||
|
channel_y_poz = 554
|
||||||
|
channel_x_po5 = 387
|
||||||
|
channel_y_po5 = 551
|
||||||
|
channel_x_po6 = 611
|
||||||
|
channel_y_po6 = 551
|
||||||
|
channel_x_c3 = 373
|
||||||
|
channel_y_c3 = 363
|
||||||
|
channel_x_fc5 = 319
|
||||||
|
channel_y_fc5 = 292
|
||||||
|
channel_x_c4 = 620
|
||||||
|
channel_y_c4 = 363
|
||||||
|
channel_x_fc6 = 676
|
||||||
|
channel_y_fc6 = 288
|
||||||
|
channel_x_p3 = 398
|
||||||
|
channel_y_p3 = 491
|
||||||
|
channel_x_cp5 = 322
|
||||||
|
channel_y_cp5 = 430
|
||||||
|
channel_x_p4 = 600
|
||||||
|
channel_y_p4 = 489
|
||||||
|
channel_x_cp6 = 678
|
||||||
|
channel_y_cp6 = 430
|
||||||
|
channel_x_c5 = 313
|
||||||
|
channel_y_c5 = 361
|
||||||
|
channel_x_f6 = 650
|
||||||
|
channel_y_f6 = 223
|
||||||
|
channel_x_f5 = 349
|
||||||
|
channel_y_f5 = 224
|
||||||
|
channel_x_po4 = 573
|
||||||
|
channel_y_po4 = 551
|
||||||
|
channel_x_po3 = 429
|
||||||
|
channel_y_po3 = 550
|
||||||
|
channel_x_cp4 = 619
|
||||||
|
channel_y_cp4 = 424
|
||||||
|
channel_x_cp3 = 381
|
||||||
|
channel_y_cp3 = 426
|
||||||
|
channel_x_fc4 = 619
|
||||||
|
channel_y_fc4 = 295
|
||||||
|
channel_x_o1 = 423
|
||||||
|
channel_y_o1 = 598
|
||||||
|
channel_x_ft9 = 252
|
||||||
|
channel_y_ft9 = 168
|
||||||
|
channel_x_o2 = 576
|
||||||
|
channel_y_o2 = 597
|
||||||
|
channel_x_ft10 = 798
|
||||||
|
channel_y_ft10 = 277
|
||||||
|
channel_x_f7 = 295
|
||||||
|
channel_y_f7 = 214
|
||||||
|
channel_x_tp9 = 202
|
||||||
|
channel_y_tp9 = 445
|
||||||
|
channel_x_f8 = 701
|
||||||
|
channel_y_f8 = 215
|
||||||
|
channel_x_t7 = 252
|
||||||
|
channel_y_t7 = 362
|
||||||
|
channel_x_tp7 = 261
|
||||||
|
channel_y_tp7 = 436
|
||||||
|
channel_x_ft8 = 734
|
||||||
|
channel_y_ft8 = 283
|
||||||
|
channel_x_ft7 = 264
|
||||||
|
channel_y_ft7 = 286
|
||||||
|
channel_x_af8 = 645
|
||||||
|
channel_y_af8 = 159
|
||||||
|
channel_x_af7 = 351
|
||||||
|
channel_y_af7 = 160
|
||||||
|
channel_x_p6 = 652
|
||||||
|
channel_y_p6 = 499
|
||||||
|
channel_x_p5 = 348
|
||||||
|
channel_y_p5 = 499
|
||||||
|
channel_x_c6 = 683
|
||||||
|
channel_y_c6 = 362
|
||||||
|
channel_x_f1 = 447
|
||||||
|
channel_y_f1 = 236
|
||||||
|
channel_x_t8 = 745
|
||||||
|
channel_y_t8 = 361
|
||||||
|
channel_x_f2 = 549
|
||||||
|
channel_y_f2 = 235
|
||||||
|
channel_x_p7 = 300
|
||||||
|
channel_y_p7 = 505
|
||||||
|
channel_x_c1 = 435
|
||||||
|
channel_y_c1 = 363
|
||||||
|
channel_x_p8 = 698
|
||||||
|
channel_y_p8 = 508
|
||||||
|
channel_x_c2 = 559
|
||||||
|
channel_y_c2 = 359
|
||||||
|
channel_x_fz = 499
|
||||||
|
channel_y_fz = 238
|
||||||
|
channel_x_po7 = 354
|
||||||
|
channel_y_po7 = 562
|
||||||
|
channel_x_tp8 = 735
|
||||||
|
channel_y_tp8 = 438
|
||||||
|
channel_x_oz = 498
|
||||||
|
channel_y_oz = 609
|
||||||
|
channel_x_af3 = 428
|
||||||
|
channel_y_af3 = 170
|
||||||
|
channel_x_pz = 501
|
||||||
|
channel_y_pz = 486
|
||||||
|
channel_x_p2 = 551
|
||||||
|
channel_y_p2 = 483
|
||||||
|
channel_x_cz = 499
|
||||||
|
channel_y_cz = 361
|
||||||
|
channel_x_p1 = 448
|
||||||
|
channel_y_p1 = 488
|
||||||
|
|
||||||
BIN
Debug_64ch_Decoder/online_Models/Model_2025-11-15-11-11-50.pth
Normal file
BIN
Debug_64ch_Decoder/online_Models/Model_2025-11-15-11-11-50.pth
Normal file
Binary file not shown.
BIN
Debug_64ch_Decoder/online_Models/Model_2025-11-17-16-55-25.pth
Normal file
BIN
Debug_64ch_Decoder/online_Models/Model_2025-11-17-16-55-25.pth
Normal file
Binary file not shown.
BIN
Debug_64ch_Decoder/online_Models/Model_2025-11-18-10-15-35.pth
Normal file
BIN
Debug_64ch_Decoder/online_Models/Model_2025-11-18-10-15-35.pth
Normal file
Binary file not shown.
252
Debug_64ch_Decoder/online_Models/log_result.txt
Normal file
252
Debug_64ch_Decoder/online_Models/log_result.txt
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
0 0.5
|
||||||
|
1 0.5
|
||||||
|
2 0.5
|
||||||
|
3 0.5
|
||||||
|
4 0.5
|
||||||
|
5 0.5
|
||||||
|
6 0.4375
|
||||||
|
7 0.5
|
||||||
|
8 0.5
|
||||||
|
9 0.5
|
||||||
|
10 0.5
|
||||||
|
11 0.5
|
||||||
|
12 0.3125
|
||||||
|
13 0.5
|
||||||
|
14 0.375
|
||||||
|
15 0.5625
|
||||||
|
16 0.3125
|
||||||
|
17 0.4375
|
||||||
|
18 0.4375
|
||||||
|
19 0.4375
|
||||||
|
20 0.375
|
||||||
|
21 0.375
|
||||||
|
22 0.4375
|
||||||
|
23 0.4375
|
||||||
|
24 0.4375
|
||||||
|
25 0.4375
|
||||||
|
26 0.4375
|
||||||
|
27 0.4375
|
||||||
|
28 0.4375
|
||||||
|
29 0.5
|
||||||
|
30 0.5
|
||||||
|
31 0.5625
|
||||||
|
32 0.375
|
||||||
|
33 0.5625
|
||||||
|
34 0.5
|
||||||
|
35 0.4375
|
||||||
|
36 0.5
|
||||||
|
37 0.4375
|
||||||
|
38 0.4375
|
||||||
|
39 0.5625
|
||||||
|
40 0.5
|
||||||
|
41 0.5
|
||||||
|
42 0.5
|
||||||
|
43 0.5
|
||||||
|
44 0.5
|
||||||
|
45 0.5
|
||||||
|
46 0.5625
|
||||||
|
47 0.5625
|
||||||
|
48 0.4375
|
||||||
|
49 0.4375
|
||||||
|
50 0.5
|
||||||
|
51 0.5625
|
||||||
|
52 0.5
|
||||||
|
53 0.4375
|
||||||
|
54 0.5
|
||||||
|
55 0.625
|
||||||
|
56 0.4375
|
||||||
|
57 0.625
|
||||||
|
58 0.5
|
||||||
|
59 0.5
|
||||||
|
60 0.5
|
||||||
|
61 0.5625
|
||||||
|
62 0.625
|
||||||
|
63 0.625
|
||||||
|
64 0.5
|
||||||
|
65 0.5625
|
||||||
|
66 0.5
|
||||||
|
67 0.5
|
||||||
|
68 0.5
|
||||||
|
69 0.5
|
||||||
|
70 0.625
|
||||||
|
71 0.5
|
||||||
|
72 0.4375
|
||||||
|
73 0.5625
|
||||||
|
74 0.5625
|
||||||
|
75 0.625
|
||||||
|
76 0.4375
|
||||||
|
77 0.4375
|
||||||
|
78 0.4375
|
||||||
|
79 0.5625
|
||||||
|
80 0.5
|
||||||
|
81 0.5
|
||||||
|
82 0.4375
|
||||||
|
83 0.4375
|
||||||
|
84 0.4375
|
||||||
|
85 0.4375
|
||||||
|
86 0.625
|
||||||
|
87 0.5625
|
||||||
|
88 0.4375
|
||||||
|
89 0.4375
|
||||||
|
90 0.5625
|
||||||
|
91 0.4375
|
||||||
|
92 0.4375
|
||||||
|
93 0.5
|
||||||
|
94 0.4375
|
||||||
|
95 0.5625
|
||||||
|
96 0.5625
|
||||||
|
97 0.5
|
||||||
|
98 0.625
|
||||||
|
99 0.5625
|
||||||
|
100 0.5
|
||||||
|
101 0.5
|
||||||
|
102 0.5
|
||||||
|
103 0.5
|
||||||
|
104 0.5
|
||||||
|
105 0.625
|
||||||
|
106 0.625
|
||||||
|
107 0.625
|
||||||
|
108 0.4375
|
||||||
|
109 0.5625
|
||||||
|
110 0.5
|
||||||
|
111 0.625
|
||||||
|
112 0.5625
|
||||||
|
113 0.5
|
||||||
|
114 0.5
|
||||||
|
115 0.625
|
||||||
|
116 0.5
|
||||||
|
117 0.5625
|
||||||
|
118 0.625
|
||||||
|
119 0.625
|
||||||
|
120 0.4375
|
||||||
|
121 0.4375
|
||||||
|
122 0.4375
|
||||||
|
123 0.5
|
||||||
|
124 0.625
|
||||||
|
125 0.625
|
||||||
|
126 0.625
|
||||||
|
127 0.625
|
||||||
|
128 0.6875
|
||||||
|
129 0.5625
|
||||||
|
130 0.5625
|
||||||
|
131 0.4375
|
||||||
|
132 0.4375
|
||||||
|
133 0.4375
|
||||||
|
134 0.4375
|
||||||
|
135 0.5625
|
||||||
|
136 0.625
|
||||||
|
137 0.5625
|
||||||
|
138 0.5
|
||||||
|
139 0.4375
|
||||||
|
140 0.5
|
||||||
|
141 0.625
|
||||||
|
142 0.625
|
||||||
|
143 0.5625
|
||||||
|
144 0.625
|
||||||
|
145 0.5625
|
||||||
|
146 0.5625
|
||||||
|
147 0.5
|
||||||
|
148 0.5
|
||||||
|
149 0.5
|
||||||
|
150 0.4375
|
||||||
|
151 0.4375
|
||||||
|
152 0.5625
|
||||||
|
153 0.625
|
||||||
|
154 0.5
|
||||||
|
155 0.625
|
||||||
|
156 0.625
|
||||||
|
157 0.625
|
||||||
|
158 0.5625
|
||||||
|
159 0.5625
|
||||||
|
160 0.5625
|
||||||
|
161 0.625
|
||||||
|
162 0.5
|
||||||
|
163 0.5625
|
||||||
|
164 0.625
|
||||||
|
165 0.4375
|
||||||
|
166 0.5625
|
||||||
|
167 0.625
|
||||||
|
168 0.625
|
||||||
|
169 0.5625
|
||||||
|
170 0.5625
|
||||||
|
171 0.5
|
||||||
|
172 0.4375
|
||||||
|
173 0.5625
|
||||||
|
174 0.5
|
||||||
|
175 0.4375
|
||||||
|
176 0.5625
|
||||||
|
177 0.5
|
||||||
|
178 0.5625
|
||||||
|
179 0.5625
|
||||||
|
180 0.5625
|
||||||
|
181 0.5
|
||||||
|
182 0.5625
|
||||||
|
183 0.5
|
||||||
|
184 0.5625
|
||||||
|
185 0.5625
|
||||||
|
186 0.5625
|
||||||
|
187 0.5
|
||||||
|
188 0.4375
|
||||||
|
189 0.5
|
||||||
|
190 0.4375
|
||||||
|
191 0.4375
|
||||||
|
192 0.5
|
||||||
|
193 0.5625
|
||||||
|
194 0.5625
|
||||||
|
195 0.5625
|
||||||
|
196 0.625
|
||||||
|
197 0.5
|
||||||
|
198 0.5625
|
||||||
|
199 0.625
|
||||||
|
200 0.5
|
||||||
|
201 0.5
|
||||||
|
202 0.625
|
||||||
|
203 0.5625
|
||||||
|
204 0.625
|
||||||
|
205 0.5
|
||||||
|
206 0.5
|
||||||
|
207 0.625
|
||||||
|
208 0.625
|
||||||
|
209 0.5625
|
||||||
|
210 0.625
|
||||||
|
211 0.4375
|
||||||
|
212 0.5625
|
||||||
|
213 0.5
|
||||||
|
214 0.5
|
||||||
|
215 0.5625
|
||||||
|
216 0.4375
|
||||||
|
217 0.5
|
||||||
|
218 0.5625
|
||||||
|
219 0.5
|
||||||
|
220 0.625
|
||||||
|
221 0.5625
|
||||||
|
222 0.5625
|
||||||
|
223 0.625
|
||||||
|
224 0.5625
|
||||||
|
225 0.5625
|
||||||
|
226 0.625
|
||||||
|
227 0.5625
|
||||||
|
228 0.6875
|
||||||
|
229 0.5
|
||||||
|
230 0.5625
|
||||||
|
231 0.625
|
||||||
|
232 0.5
|
||||||
|
233 0.625
|
||||||
|
234 0.5
|
||||||
|
235 0.5
|
||||||
|
236 0.5
|
||||||
|
237 0.4375
|
||||||
|
238 0.625
|
||||||
|
239 0.5625
|
||||||
|
240 0.5625
|
||||||
|
241 0.5
|
||||||
|
242 0.5
|
||||||
|
243 0.5625
|
||||||
|
244 0.5625
|
||||||
|
245 0.5625
|
||||||
|
246 0.625
|
||||||
|
247 0.5
|
||||||
|
248 0.5
|
||||||
|
249 0.4375
|
||||||
|
The average accuracy is: 0.5235
|
||||||
|
The best accuracy is: 0.6875
|
||||||
22
Debug_64ch_Decoder/rthook.py
Normal file
22
Debug_64ch_Decoder/rthook.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 0. Matplotlib 非交互式后端(必须在导入 matplotlib.pyplot 之前设置)
|
||||||
|
# plot_MI_EEG.py 等模块会用到 pyplot,必须在打包后的无显示器环境下工作
|
||||||
|
# ============================================================
|
||||||
|
if getattr(sys, 'frozen', False):
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
os.environ.setdefault('MPLBACKEND', 'Agg')
|
||||||
|
|
||||||
|
# 1. 路径自适应:在 Frozen 模式下,将当前工作目录切换到可执行文件所在目录
|
||||||
|
# 这样代码中使用的相对路径(如 './config.ini')就能正确指向 exe 旁边的文件
|
||||||
|
if getattr(sys, 'frozen', False):
|
||||||
|
os.chdir(os.path.dirname(sys.executable))
|
||||||
|
|
||||||
|
# 2. 多进程保护:防止 Windows 下的无限递归炸弹
|
||||||
|
# Windows 下 multiprocessing 需要 freeze_support()
|
||||||
|
if sys.platform.startswith('win'):
|
||||||
|
multiprocessing.freeze_support()
|
||||||
17
Debug_64ch_Decoder/runDecoder.py
Normal file
17
Debug_64ch_Decoder/runDecoder.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
|
||||||
|
import time
|
||||||
|
from Decoder import Decoder_main
|
||||||
|
from PubLibrary.RunOnce import is_program_running
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
if not is_program_running():
|
||||||
|
decoder = Decoder_main()
|
||||||
|
decoder.connect()
|
||||||
|
|
||||||
|
try:
|
||||||
|
decoder.start()
|
||||||
|
while not decoder.zmqServer.IsExitApp:
|
||||||
|
time.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
decoder.stop()
|
||||||
632
Debug_64ch_Decoder_Optimize/Decoder.py
Normal file
632
Debug_64ch_Decoder_Optimize/Decoder.py
Normal file
@@ -0,0 +1,632 @@
|
|||||||
|
import ast
|
||||||
|
import threading
|
||||||
|
from datetime import datetime
|
||||||
|
import multiprocessing as mp
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from queue import Empty
|
||||||
|
from scipy import signal
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from Device.SunnyLinker import SunnyLinker64
|
||||||
|
from SSMVEP.algorithm.tdca import TDCA
|
||||||
|
from SSMVEP.algorithm.base import generate_cca_references
|
||||||
|
from concentration.algorithm.calculate_focus import Calculate
|
||||||
|
from blinkdetection.algorithm.eye_detection import blink_detection
|
||||||
|
from Zmq.zmqServer import zmqServer
|
||||||
|
from Zmq.zmqClient import zmqClient
|
||||||
|
from MI.Algorithm.conformer_2class import onlineTrain
|
||||||
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
from SSVEP.dwfbcca import FbccaDw
|
||||||
|
from Tools.plot_MI_EEG import plotMain
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
class Decoder_main(threading.Thread):
|
||||||
|
def __init__(self):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.Runing=True
|
||||||
|
self.decoder = None
|
||||||
|
|
||||||
|
self.fs = 250 # 采样率
|
||||||
|
self.energy = 0 # 电量
|
||||||
|
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
||||||
|
self.decoder_class = None #解码器类别
|
||||||
|
|
||||||
|
self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果
|
||||||
|
|
||||||
|
def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None):
|
||||||
|
self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type'))
|
||||||
|
_device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host'))
|
||||||
|
_device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port'))
|
||||||
|
_upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host'))
|
||||||
|
_upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port'))
|
||||||
|
|
||||||
|
if self.DeviceType == 1:
|
||||||
|
self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.fs, 64, method='tcp')
|
||||||
|
self.thread_data_server.host = _device_host
|
||||||
|
self.thread_data_server.port = _device_port
|
||||||
|
|
||||||
|
self.thread_data_server.toUv = True
|
||||||
|
self.thread_data_server.start()
|
||||||
|
|
||||||
|
self.zmqServer = zmqServer()
|
||||||
|
self.zmqServer.start()
|
||||||
|
|
||||||
|
self.zmqClient = zmqClient(_upper_host, _upper_port)
|
||||||
|
self.zmqClient.set_zmq_server(self.zmqServer)
|
||||||
|
self.zmqClient.connect()
|
||||||
|
|
||||||
|
def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号
|
||||||
|
# data: (chans, samples)
|
||||||
|
energy = np.mean(np.var(data, axis=1)) # 各通道方差均值
|
||||||
|
if energy > threshold:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
def init_Decoder(self,decoder_class):
|
||||||
|
'''
|
||||||
|
初始化解码器
|
||||||
|
:param decoder_class: 'ssvep' or 'ssmvep' or 'mi' or 'concentration' or ''
|
||||||
|
:return:
|
||||||
|
'''
|
||||||
|
self.decoder_class = decoder_class
|
||||||
|
if decoder_class == 'ssvep' or decoder_class == 'pvs':
|
||||||
|
self.n_chan = 8
|
||||||
|
self.thread_data_server.interval_inited = False
|
||||||
|
DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
|
||||||
|
self.ListFreq = self.zmqServer.targetFreqs
|
||||||
|
self.num_target = len(self.ListFreq)
|
||||||
|
if self.num_target == 0:
|
||||||
|
return
|
||||||
|
# 初始化对象 二代算法
|
||||||
|
self.dw = FbccaDw(self.fs, self.num_target, self.n_chan, 5, 5,
|
||||||
|
0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method)
|
||||||
|
# frequence band
|
||||||
|
self.dw.filterFrequenceBank()
|
||||||
|
self.dw.setNotchFilterPara()
|
||||||
|
self.calculateCount = 0
|
||||||
|
self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.fs),
|
||||||
|
5)
|
||||||
|
self.dw.filterInit()
|
||||||
|
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||||
|
|
||||||
|
elif decoder_class == 'ssmvep':
|
||||||
|
self.thread_data_server.interval_init(decoder_class)
|
||||||
|
self.n_chan = 8
|
||||||
|
self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||||
|
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
||||||
|
self.single_train = 10 # 单类别数量
|
||||||
|
self.num_target = 2 # 分类目标数目
|
||||||
|
self.list_freqs = np.array([8, 9]) # 刺激频率
|
||||||
|
self.list_phase = np.array([0, 0]) # 相位
|
||||||
|
self.tdca = TDCA(padding_len=5, n_components=1)
|
||||||
|
self.Yf = generate_cca_references(self.list_freqs, srate=self.fs, T=self.sample_length,
|
||||||
|
phases=self.list_phase, n_harmonics=5)
|
||||||
|
self.parameter_init(5,45)
|
||||||
|
|
||||||
|
elif decoder_class == 'mi' or decoder_class == 'ma':
|
||||||
|
self.thread_data_server.interval_init(decoder_class)
|
||||||
|
self.n_chan = 21
|
||||||
|
self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||||
|
self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位
|
||||||
|
self.single_train = 40 # 单类别数量
|
||||||
|
self.num_target = 2 # 分类目标数目
|
||||||
|
|
||||||
|
self.parameter_init(8, 30)
|
||||||
|
|
||||||
|
elif decoder_class == 'concentration':
|
||||||
|
self.thread_data_server.interval_inited = False
|
||||||
|
self.n_chan = 6
|
||||||
|
self.win_len = 10
|
||||||
|
self.win_step = 1
|
||||||
|
self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
|
||||||
|
self.calculate = Calculate(self.low_threshold, self.high_threshold, self.fs, self.win_len)
|
||||||
|
self.interval_epoch = [0, 1]
|
||||||
|
self.parameter_init(2, 40)
|
||||||
|
# self.eegQueue moved to Calculate class
|
||||||
|
|
||||||
|
elif decoder_class == 'blink':
|
||||||
|
self.n_chan = 2
|
||||||
|
self.l_freq = 0.1 # 带通滤波器低频截止
|
||||||
|
self.h_freq = 8.0 # 带通滤波器高频截止
|
||||||
|
self.total_samples = 0 # 总采样点数
|
||||||
|
self.window_ms = 600 # 检测窗口大小 (ms)
|
||||||
|
self.step_ms = 100 # 滑动步长 (ms)
|
||||||
|
self.window_samples = int(self.window_ms * self.fs / 1000) # 150个样本点
|
||||||
|
self.step_samples = int(self.step_ms * self.fs / 1000) # 25个样本点
|
||||||
|
self.buffer_size = self.window_samples + self.step_samples * 5
|
||||||
|
self.fp1_buffer = deque(maxlen=self.buffer_size)
|
||||||
|
self.fp2_buffer = deque(maxlen=self.buffer_size)
|
||||||
|
self.sample_counter = 0
|
||||||
|
# 预计算滤波器系数,避免在循环中重复设计
|
||||||
|
self.Dmin, self.Dmax, self.EMin, self.EMax, self.jitterwin,self.double_blink_interval,self.double_blink_jitter = ast.literal_eval(IniRead('system', 'blink'))
|
||||||
|
self.blink_count = 0 # 单次眨眼的次数
|
||||||
|
self.last_blink_time = 0 # 上次检测到单次眨眼的时间(样本索引)
|
||||||
|
self.blink_timestamps = deque(maxlen=10) # 记录最近10次 单次眨眼的时间戳
|
||||||
|
self.double_blink_count = 0 # 连续两次眨眼的次数
|
||||||
|
self.double_blink_events = [] # 连续眨眼事件记录
|
||||||
|
self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳
|
||||||
|
self.blink_events = []
|
||||||
|
self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.fs / 2), self.h_freq / (self.fs / 2)], btype='band')
|
||||||
|
|
||||||
|
def parameter_init(self,bandPass_low,bandPass_high):
|
||||||
|
self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch] # epoch截取信息
|
||||||
|
self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.fs)] # 训练样本epoch
|
||||||
|
self.trainData = [] #训练数据
|
||||||
|
self.trainLabel = [] #训练标签
|
||||||
|
self.plotData = [] #报告分析数据
|
||||||
|
self.plotLabel = [] #报告分析标签
|
||||||
|
self.currentLabel = -1 #刺激界面当前显示的训练标签
|
||||||
|
self.train_started = False #是否开始训练模型
|
||||||
|
self.load_model = False # 调用模型是否完成的标志
|
||||||
|
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子
|
||||||
|
self.b_design = signal.firwin(65, [bandPass_low / (self.fs/2), bandPass_high / (self.fs/2)], pass_zero=False) # 设计8-30Hz带通滤波器
|
||||||
|
fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||||
|
filePath = './online_Models/'
|
||||||
|
self.modelPath = ''.join([filePath, fileName, '.pth'])
|
||||||
|
self.mp_data_queue = mp.Queue() #多进程传参队列
|
||||||
|
self.mp_result_queue = mp.Queue() #多进程结果队列
|
||||||
|
|
||||||
|
def preprocess(self, signal_data):
|
||||||
|
# # 计算每行的平均值
|
||||||
|
row_means = np.mean(signal_data, axis=-1, keepdims=True)
|
||||||
|
# 对每一行去均值
|
||||||
|
signal_data = signal_data - row_means
|
||||||
|
|
||||||
|
signal_data = signal.lfilter(self.b_notch, self.a_notch, signal_data, axis=-1) # 工频陷波
|
||||||
|
signal_data = signal.lfilter(self.b_design, 1, signal_data, axis=-1) # 带通滤波
|
||||||
|
return signal_data
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
while self.Runing:
|
||||||
|
if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
|
||||||
|
print(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}")
|
||||||
|
self.zmqServer.decoder_switch = False
|
||||||
|
self.zmqServer.changeTarget = False
|
||||||
|
self.reset_state() # 切换前先统一清理旧状态
|
||||||
|
self.init_Decoder(self.zmqServer.decoder_class)
|
||||||
|
|
||||||
|
# 同步信息
|
||||||
|
if self.zmqServer.state_mode == 'sync':
|
||||||
|
self.zmqClient.send_to_all('sync', self.zmqClient.state)
|
||||||
|
self.zmqServer.state_mode = 'rest'
|
||||||
|
# 状态异常,报告上位机
|
||||||
|
if self.status_code != self.thread_data_server.status_code:
|
||||||
|
self.status_code = self.thread_data_server.status_code
|
||||||
|
self.zmqClient.send_to_all('status_code', int(self.status_code))
|
||||||
|
print('status code')
|
||||||
|
|
||||||
|
# 返回电量
|
||||||
|
if self.energy != self.thread_data_server.energy:
|
||||||
|
self.energy = self.thread_data_server.energy
|
||||||
|
self.zmqClient.send_to_all('energy', int(self.energy))
|
||||||
|
print('energy')
|
||||||
|
|
||||||
|
if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次
|
||||||
|
self.thread_data_server.Impedance(True)
|
||||||
|
print('Impedance')
|
||||||
|
self.zmqServer.open_Impedance = -1
|
||||||
|
elif self.zmqServer.open_Impedance == False:
|
||||||
|
self.thread_data_server.Impedance(False)
|
||||||
|
self.zmqServer.open_Impedance = -1
|
||||||
|
|
||||||
|
if self.zmqServer.get_Impedance: # 返回阻抗值
|
||||||
|
# print(self.zmqServer.get_Impedance)
|
||||||
|
# print(self.thread_data_server.GetDataLenCount())
|
||||||
|
if self.thread_data_server.GetDataLenCount() > 250:
|
||||||
|
Impe_data = self.thread_data_server.getData(250)
|
||||||
|
# 计算阻抗
|
||||||
|
imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class)
|
||||||
|
self.zmqClient.send_to_all('impedance', imps.tolist())
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
if self.zmqServer.getReport: #返回训练报告内容
|
||||||
|
self.zmqServer.getReport = False
|
||||||
|
allData = np.array(self.plotData)
|
||||||
|
allLabel = np.array(self.plotLabel) + 1
|
||||||
|
nTrials = min(len(allLabel),len(allData))
|
||||||
|
if nTrials < 30:
|
||||||
|
self.zmqClient.send_to_all('miReport',0)
|
||||||
|
else:
|
||||||
|
allData = allData[:nTrials]
|
||||||
|
allLabel = allLabel[:nTrials]
|
||||||
|
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
|
||||||
|
'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
|
||||||
|
compare_names = ['C3', 'CZ', 'C4']
|
||||||
|
miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2,
|
||||||
|
fs=self.fs)
|
||||||
|
self.zmqClient.send_to_all('miReport',miReport)
|
||||||
|
|
||||||
|
|
||||||
|
# --- 取数优先:先执行 decoder(消费环形缓冲),再处理 plot/report 等重负载 ---
|
||||||
|
try:
|
||||||
|
if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs':
|
||||||
|
self.decoder_SSVEP()
|
||||||
|
elif self.decoder_class == 'ssmvep':
|
||||||
|
self.decoder_SSMVEP()
|
||||||
|
elif self.decoder_class == 'mi':
|
||||||
|
self.decoder_MI()
|
||||||
|
elif self.decoder_class == 'concentration':
|
||||||
|
self.decoder_concentration()
|
||||||
|
elif self.decoder_class == 'blink':
|
||||||
|
self.decoder_blink()
|
||||||
|
else:
|
||||||
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
|
if self.thread_data_server.GetDataLenCount() < 25:
|
||||||
|
time.sleep(0.005)
|
||||||
|
continue;
|
||||||
|
self.thread_data_server.getData(25)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Decoder Loop Error: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
time.sleep(0.1) # Prevent CPU spin if error is persistent
|
||||||
|
|
||||||
|
def decoder_SSVEP(self):
|
||||||
|
if self.zmqServer.StartDecode:
|
||||||
|
self.zmqServer.StartDecode = False
|
||||||
|
self.decodingSteps = 1
|
||||||
|
self.thread_data_server.ResetAll()
|
||||||
|
print('启动预测')
|
||||||
|
if self.thread_data_server.GetDataLenCount() < 50:
|
||||||
|
time.sleep(0.005)
|
||||||
|
return
|
||||||
|
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
||||||
|
return
|
||||||
|
data = self.thread_data_server.getDataViaSSVEP(50)
|
||||||
|
data = data[:self.n_chan, :]
|
||||||
|
if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热
|
||||||
|
self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时
|
||||||
|
self.dw.warmFilter(data) # 预热
|
||||||
|
self.decodingSteps = 2
|
||||||
|
print('预热数据完成。开始预测')
|
||||||
|
return
|
||||||
|
if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中
|
||||||
|
choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
|
||||||
|
self.calculateCount += 1
|
||||||
|
if choosenNum != -1 and self.is_valid_signal(data):
|
||||||
|
self.decodingSteps = 3
|
||||||
|
print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount))
|
||||||
|
self.calculateCount = 0
|
||||||
|
if self.decodingSteps == 3: # 发送解码后的信息
|
||||||
|
self.zmqClient.send_to_all('result', int(choosenNum))
|
||||||
|
self.decodingSteps = 0
|
||||||
|
print('发送给界面完成。')
|
||||||
|
|
||||||
|
def decoder_SSMVEP(self):
|
||||||
|
'''模型训练'''
|
||||||
|
if self.load_model == False and all(
|
||||||
|
self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))): # 模型尚未训练完成
|
||||||
|
self.trainData = np.array(self.trainData)
|
||||||
|
self.trainLabel = np.array(self.trainLabel)
|
||||||
|
print(np.shape(self.trainData), (self.trainLabel))
|
||||||
|
# 保存多个数组到文件
|
||||||
|
# np.savez('20250520_yy.npz', array1=self.trainData, array2=self.trainLabel)
|
||||||
|
# self.decoder = self.fbtdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
|
||||||
|
self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
|
||||||
|
now = datetime.now()
|
||||||
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print('模型训练完成', formatted_time)
|
||||||
|
self.load_model = True
|
||||||
|
self.zmqClient.send_to_all('paradigm', 1)
|
||||||
|
|
||||||
|
'''训练阶段采集数据'''
|
||||||
|
if self.zmqServer.state_mode == 'train': # 训练状态
|
||||||
|
if self.zmqServer.StartTrain:
|
||||||
|
self.currentLabel = self.zmqServer.currentLabel
|
||||||
|
self.zmqServer.StartTrain = False
|
||||||
|
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||||
|
self.train_epoch[1] \
|
||||||
|
+ self.thread_data_server.event_inner_idx:
|
||||||
|
time.sleep(0.0001)
|
||||||
|
return
|
||||||
|
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
|
||||||
|
trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据
|
||||||
|
print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx])
|
||||||
|
trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理
|
||||||
|
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.train_epoch[
|
||||||
|
0]:self.thread_data_server.event_inner_idx + self.train_epoch[1]]
|
||||||
|
print('trial: ', self.thread_data_server.event_inner_idx, self.train_epoch[0], self.train_epoch[1])
|
||||||
|
if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance(
|
||||||
|
self.trainLabel, list) \
|
||||||
|
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||||
|
self.trainData.append(trainTrial)
|
||||||
|
self.trainLabel.append(self.currentLabel)
|
||||||
|
|
||||||
|
elif self.zmqServer.state_mode == 'predict': # 测试状态
|
||||||
|
if self.load_model == False: # 模型尚未训练完成
|
||||||
|
time.sleep(0.01)
|
||||||
|
return
|
||||||
|
else: # 已有模型
|
||||||
|
if self.zmqServer.StartDecode:
|
||||||
|
self.zmqServer.StartDecode = False
|
||||||
|
now = datetime.now()
|
||||||
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print('启动预测 ', formatted_time)
|
||||||
|
|
||||||
|
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||||
|
self.interval_epoch[1] \
|
||||||
|
+ self.thread_data_server.event_inner_idx:
|
||||||
|
time.sleep(0.0001)
|
||||||
|
return
|
||||||
|
data = self.thread_data_server.get_SSMVEPData() # 读取全部数据
|
||||||
|
print('取出的: ', data.shape, 'event: ', data[-2, self.thread_data_server.event_inner_idx])
|
||||||
|
data = self.preprocess(data[:self.n_chan, :]) # 预处理
|
||||||
|
data = data[:,
|
||||||
|
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||||
|
pad_eeg_test = np.zeros(
|
||||||
|
(data.shape[0], int((self.sample_length + 0.1) * self.fs)))
|
||||||
|
pad_eeg_test[:, :int(self.sample_length * self.fs)] = data
|
||||||
|
choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
|
||||||
|
if isinstance(choosenNum, np.ndarray):
|
||||||
|
choosenNum = choosenNum[0]
|
||||||
|
print('结果:', choosenNum, 'rho: ', sorted(features_2[0]),
|
||||||
|
sorted(features_2[0])[-1] - sorted(features_2[0])[-2])
|
||||||
|
self.zmqClient.send_to_all('result', int(choosenNum))
|
||||||
|
print('发送给界面完成。')
|
||||||
|
else: # 休息状态
|
||||||
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
|
if self.thread_data_server.GetDataLenCount() < 25:
|
||||||
|
time.sleep(0.005)
|
||||||
|
return
|
||||||
|
self.thread_data_server.getData(25)
|
||||||
|
|
||||||
|
def decoder_MI(self):
|
||||||
|
'''模型训练'''
|
||||||
|
if self.train_started == False and all(
|
||||||
|
self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)): # 模型尚未训练
|
||||||
|
self.zmqClient.send_to_all('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机
|
||||||
|
self.train_started = True
|
||||||
|
self.trainData = np.array(self.trainData)
|
||||||
|
self.trainLabel = np.array(self.trainLabel) + 1
|
||||||
|
# print('训练集:',np.shape(self.trainData), (self.trainLabel))
|
||||||
|
p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型
|
||||||
|
p.start()
|
||||||
|
self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath,
|
||||||
|
'n_chan': self.n_chan})
|
||||||
|
|
||||||
|
'''检查模型是否训练完成,调用'''
|
||||||
|
if self.load_model == False and self.train_started == True:
|
||||||
|
try:
|
||||||
|
result = self.mp_result_queue.get_nowait()
|
||||||
|
if result['status'] == 'success':
|
||||||
|
print("模型训练完成,加载新模型")
|
||||||
|
# 调用模型
|
||||||
|
self.model = torch.load(self.modelPath, weights_only=False)
|
||||||
|
self.model.eval()
|
||||||
|
# 模型预热
|
||||||
|
warmup_data = np.random.uniform(-1, 1, (1, 1, self.n_chan, 1000))
|
||||||
|
warmup_data = torch.from_numpy(warmup_data)
|
||||||
|
warmup_data = Variable(warmup_data.type(torch.cuda.FloatTensor))
|
||||||
|
with torch.no_grad():
|
||||||
|
_ = self.model(warmup_data)
|
||||||
|
self.load_model = True
|
||||||
|
self.zmqClient.send_to_all('paradigm', 1) # 模型调用完毕,通知上位机
|
||||||
|
else:
|
||||||
|
print("训练失败:", result['msg'])
|
||||||
|
except Empty:
|
||||||
|
pass # 还没完成
|
||||||
|
except Exception as e:
|
||||||
|
print('模型调用失败: ', e)
|
||||||
|
|
||||||
|
'''训练阶段采集数据'''
|
||||||
|
if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态
|
||||||
|
if self.zmqServer.StartTrain:
|
||||||
|
self.currentLabel = self.zmqServer.currentLabel
|
||||||
|
self.zmqServer.StartTrain = False
|
||||||
|
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||||
|
self.interval_epoch[1] \
|
||||||
|
+ self.thread_data_server.event_inner_idx:
|
||||||
|
time.sleep(0.0001)
|
||||||
|
return
|
||||||
|
print('训练队列数据:', self.thread_data_server.GetDataLenCount())
|
||||||
|
originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据
|
||||||
|
print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx])
|
||||||
|
trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理
|
||||||
|
trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||||
|
print('trial: ', self.thread_data_server.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1])
|
||||||
|
if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel,
|
||||||
|
list) \
|
||||||
|
and self.trainLabel.count(self.currentLabel) < self.single_train:
|
||||||
|
self.trainData.append(trainTrial)
|
||||||
|
self.trainLabel.append(self.currentLabel)
|
||||||
|
print('训练集:', np.shape(self.trainData))
|
||||||
|
self.plotData.append(originalTrial[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
|
||||||
|
self.plotLabel.append(self.currentLabel)
|
||||||
|
|
||||||
|
elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态
|
||||||
|
if self.zmqServer.StartDecode:
|
||||||
|
self.zmqServer.StartDecode = False
|
||||||
|
now = datetime.now()
|
||||||
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print('启动预测 ', formatted_time)
|
||||||
|
|
||||||
|
if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \
|
||||||
|
self.interval_epoch[1] \
|
||||||
|
+ self.thread_data_server.event_inner_idx:
|
||||||
|
time.sleep(0.0001)
|
||||||
|
return
|
||||||
|
originalData = self.thread_data_server.get_MIData() # 读取全部数据
|
||||||
|
print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.thread_data_server.event_inner_idx])
|
||||||
|
start = time.time()
|
||||||
|
data = self.preprocess(originalData[:self.n_chan, :]) # 预处理
|
||||||
|
data = data[:,
|
||||||
|
self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]
|
||||||
|
self.plotData.append(
|
||||||
|
originalData[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[
|
||||||
|
0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]])
|
||||||
|
|
||||||
|
test_data = data[np.newaxis, np.newaxis, :, :]
|
||||||
|
test_data = torch.from_numpy(test_data)
|
||||||
|
test_data = Variable(test_data.type(torch.cuda.FloatTensor))
|
||||||
|
with torch.no_grad():
|
||||||
|
Cls = self.model(test_data)
|
||||||
|
y_pred = torch.max(Cls, 1)[1]
|
||||||
|
self.plotLabel.append(int(y_pred.item()))
|
||||||
|
print('运动意图识别: ', y_pred)
|
||||||
|
self.zmqClient.send_to_all('result', int(y_pred.item()))
|
||||||
|
end = time.time()
|
||||||
|
print(f'发送给界面完成,耗时{end - start:.3f}s。')
|
||||||
|
else: # 休息状态
|
||||||
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
|
if self.thread_data_server.GetDataLenCount() < 25:
|
||||||
|
time.sleep(0.005)
|
||||||
|
return
|
||||||
|
self.thread_data_server.getData(25)
|
||||||
|
|
||||||
|
def decoder_concentration(self):
|
||||||
|
if self.zmqServer.state_mode == 'predict':
|
||||||
|
if self.zmqServer.StartDecode:
|
||||||
|
self.zmqServer.StartDecode = False
|
||||||
|
self.thread_data_server.ResetAll()
|
||||||
|
now = datetime.now()
|
||||||
|
formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
|
||||||
|
print('启动专注力预测 ', formatted_time)
|
||||||
|
if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.fs): # 每win_step得出一次结果
|
||||||
|
time.sleep(0.005)
|
||||||
|
return
|
||||||
|
if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码
|
||||||
|
return
|
||||||
|
data = self.thread_data_server.get_concentrateData(int(self.win_step * self.fs)) # 修改每次读取的数据
|
||||||
|
result = self.calculate.queueOpt(data)
|
||||||
|
if result is not None:
|
||||||
|
self.zmqClient.send_to_all('result', int(result))
|
||||||
|
else: # 休息状态
|
||||||
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
|
if self.thread_data_server.GetDataLenCount() < 25:
|
||||||
|
time.sleep(0.005)
|
||||||
|
return
|
||||||
|
self.thread_data_server.getData(25)
|
||||||
|
|
||||||
|
#### Blink detection #####
|
||||||
|
def check_double_blink(self, current_time):
|
||||||
|
"""
|
||||||
|
检查是否检测到连续两次眨眼
|
||||||
|
@param current_time: 当前眨眼时间戳
|
||||||
|
@return: True表示检测到连续两次眨眼
|
||||||
|
"""
|
||||||
|
if len(self.blink_timestamps) < 2:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查是否在去抖期内
|
||||||
|
if self.last_double_blink_time > 0:
|
||||||
|
time_since_last_double_blink = current_time - self.last_double_blink_time
|
||||||
|
if time_since_last_double_blink < self.double_blink_jitter:
|
||||||
|
return False # 在去抖期内,忽略连续眨眼检测
|
||||||
|
last_time = self.blink_timestamps[-1] # 当前眨眼
|
||||||
|
prev_time = self.blink_timestamps[-2] # 上次眨眼
|
||||||
|
|
||||||
|
interval = last_time - prev_time
|
||||||
|
if interval <= self.double_blink_interval:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def process_blink_detection(self):
|
||||||
|
"""
|
||||||
|
在缓冲区数据上执行,单次眨眼检测
|
||||||
|
"""
|
||||||
|
if len(self.fp1_buffer) < self.window_samples:
|
||||||
|
return
|
||||||
|
|
||||||
|
fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:])
|
||||||
|
fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:])
|
||||||
|
# 计算FP1和FP2的平均
|
||||||
|
fp12_mean = (fp1_data + fp2_data) / 2.0
|
||||||
|
# 带通滤波
|
||||||
|
try:
|
||||||
|
fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Filter error: {e}")
|
||||||
|
return
|
||||||
|
F = np.diff(fp12_filtered)
|
||||||
|
if len(F) < 3:
|
||||||
|
return
|
||||||
|
b, d, e = blink_detection(F, self.fs, self.Dmin, self.Dmax, self.EMin, self.EMax)
|
||||||
|
|
||||||
|
if b == 1:
|
||||||
|
samples_since_last = self.total_samples - self.last_blink_time
|
||||||
|
time_since_last_ms = (samples_since_last / self.fs) * 1000
|
||||||
|
if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms
|
||||||
|
self.blink_count += 1
|
||||||
|
self.last_blink_time = self.total_samples
|
||||||
|
current_time = time.time()
|
||||||
|
self.blink_timestamps.append(current_time)
|
||||||
|
blink_event = {
|
||||||
|
'count': self.blink_count,
|
||||||
|
'time': current_time,
|
||||||
|
'sample_index': self.total_samples,
|
||||||
|
'duration_ms': d,
|
||||||
|
'energy': e
|
||||||
|
}
|
||||||
|
self.blink_events.append(blink_event)
|
||||||
|
self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机
|
||||||
|
if self.check_double_blink(current_time):
|
||||||
|
self.double_blink_count += 1
|
||||||
|
interval = self.blink_timestamps[-1] - self.blink_timestamps[-2]
|
||||||
|
double_blink_event = {
|
||||||
|
'double_blink_count': self.double_blink_count,
|
||||||
|
'blink1_time': self.blink_timestamps[-2],
|
||||||
|
'blink2_time': self.blink_timestamps[-1],
|
||||||
|
'interval': interval
|
||||||
|
}
|
||||||
|
self.double_blink_events.append(double_blink_event)
|
||||||
|
self.last_double_blink_time = current_time
|
||||||
|
self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件
|
||||||
|
|
||||||
|
def decoder_blink(self):
|
||||||
|
if self.thread_data_server.GetDataLenCount() < 50:
|
||||||
|
time.sleep(0.005)
|
||||||
|
return
|
||||||
|
if self.zmqServer.get_Impedance == False: # 非阻抗检测状态
|
||||||
|
data = self.thread_data_server.get_blinkData(50)
|
||||||
|
fp1_data = data[0, :] # ch1 (相当于FP1)
|
||||||
|
fp2_data = data[1, :] # ch2 (相当于FP2)
|
||||||
|
for i in range(len(fp1_data)):
|
||||||
|
self.fp1_buffer.append(fp1_data[i])
|
||||||
|
self.fp2_buffer.append(fp2_data[i])
|
||||||
|
self.total_samples += 1
|
||||||
|
self.sample_counter += 1
|
||||||
|
|
||||||
|
if self.sample_counter >= self.step_samples:
|
||||||
|
self.process_blink_detection()
|
||||||
|
self.sample_counter = 0
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
'''
|
||||||
|
停止运行
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
self.zmqServer.stop()
|
||||||
|
self.Runing=False
|
||||||
|
|
||||||
|
def reset_state(self):
|
||||||
|
"""清空解码器状态和缓存数据"""
|
||||||
|
# 重置设备层缓存
|
||||||
|
self.thread_data_server.reset_state()
|
||||||
|
|
||||||
|
# 重置解码状态
|
||||||
|
self.decodingSteps = 0
|
||||||
|
self.calculateCount = 0
|
||||||
|
|
||||||
|
# 重置训练数据
|
||||||
|
self.plotData = []
|
||||||
|
self.plotLabel = []
|
||||||
|
self.trainData = []
|
||||||
|
self.trainLabel = []
|
||||||
|
self.currentLabel = -1
|
||||||
|
self.train_started = False
|
||||||
|
self.load_model = False
|
||||||
|
|
||||||
|
# 重置多进程队列,确保切换 decoder 时旧数据不会泄漏到新队列
|
||||||
|
if hasattr(self, 'mp_data_queue'):
|
||||||
|
while not self.mp_data_queue.empty():
|
||||||
|
try: self.mp_data_queue.get_nowait()
|
||||||
|
except Empty: pass
|
||||||
|
if hasattr(self, 'mp_result_queue'):
|
||||||
|
while not self.mp_result_queue.empty():
|
||||||
|
try: self.mp_result_queue.get_nowait()
|
||||||
|
except Empty: pass
|
||||||
814
Debug_64ch_Decoder_Optimize/Device/SunnyLinker.py
Normal file
814
Debug_64ch_Decoder_Optimize/Device/SunnyLinker.py
Normal file
@@ -0,0 +1,814 @@
|
|||||||
|
# -*-coding:utf-8 -*-
|
||||||
|
'''
|
||||||
|
SunnyLinker的通讯驱动
|
||||||
|
'''
|
||||||
|
import ast
|
||||||
|
import socket
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
from typing import Dict
|
||||||
|
from collections import deque
|
||||||
|
import numpy as np
|
||||||
|
from threading import Thread, Event
|
||||||
|
import serial
|
||||||
|
from scipy import signal
|
||||||
|
from serial.serialutil import SerialException
|
||||||
|
|
||||||
|
from Device.protocol import ProtocolFrame
|
||||||
|
from PubLibrary.InifileHelper import IniRead
|
||||||
|
|
||||||
|
class RingBuffer:
|
||||||
|
def __init__(self, n_chan, n_points):
|
||||||
|
self.n_chan = n_chan
|
||||||
|
self.n_points = n_points
|
||||||
|
self.buffer = np.zeros((n_chan, n_points))
|
||||||
|
self.currentPtr = 0
|
||||||
|
self.readPtr = 0
|
||||||
|
self.nUpdate = 0
|
||||||
|
self.rawData = np.zeros((n_chan, 1))
|
||||||
|
|
||||||
|
## append buffer and update current pointer
|
||||||
|
def appendBuffer(self, data):
|
||||||
|
if self.nUpdate == self.n_points:
|
||||||
|
raise Exception("Buffer is full")
|
||||||
|
|
||||||
|
n = data.shape[1]
|
||||||
|
|
||||||
|
# 计算可以写入的元素数量
|
||||||
|
write_count = min(self.n_points - self.nUpdate, n)
|
||||||
|
# 写入新数据
|
||||||
|
self.buffer[:, np.mod(np.arange(self.currentPtr, self.currentPtr + write_count), self.n_points)] = data[:,:write_count]
|
||||||
|
# 更新结束指针
|
||||||
|
self.currentPtr = (self.currentPtr + write_count) % self.n_points
|
||||||
|
# 更新大小
|
||||||
|
self.nUpdate += write_count
|
||||||
|
|
||||||
|
## get data from buffer
|
||||||
|
def getData(self, count=50):
|
||||||
|
# 确保不会尝试读取超过缓冲区当前大小的数据
|
||||||
|
count = min(count, self.nUpdate)
|
||||||
|
|
||||||
|
# 计算读取结束后的下一个位置
|
||||||
|
next_read_ptr = (self.readPtr + count) % self.n_points
|
||||||
|
if self.readPtr + count <= self.n_points:
|
||||||
|
# 情况 1:不环绕,数据是连续的
|
||||||
|
end_index = next_read_ptr if next_read_ptr != 0 else self.n_points
|
||||||
|
data = self.buffer[:, self.readPtr:end_index]
|
||||||
|
else:
|
||||||
|
# 情况 2:发生环绕,数据被分成两部分
|
||||||
|
# 第一部分:从 readPtr 到缓冲区末尾
|
||||||
|
part1 = self.buffer[:, self.readPtr:]
|
||||||
|
# 第二部分:从缓冲区开头到 (count - part1.shape[1]) 个点
|
||||||
|
part2 = self.buffer[:, :next_read_ptr]
|
||||||
|
# 将两部分在列方向上拼接
|
||||||
|
data = np.concatenate((part1, part2), axis=1)
|
||||||
|
|
||||||
|
# 更新读指针
|
||||||
|
self.readPtr = next_read_ptr
|
||||||
|
# 更新大小
|
||||||
|
self.nUpdate -= count
|
||||||
|
return data
|
||||||
|
|
||||||
|
# reset buffer
|
||||||
|
def resetAllPara(self):
|
||||||
|
self.nUpdate = 0
|
||||||
|
self.currentPtr = 0
|
||||||
|
self.readPtr = 0 # add by lizhenhua 清空读指针
|
||||||
|
self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区
|
||||||
|
|
||||||
|
|
||||||
|
class SunnyLinker64(Thread, ):
|
||||||
|
serial_port = str(IniRead('system', 'Serial_port'))
|
||||||
|
t_buffer = 10
|
||||||
|
n_chan = 64
|
||||||
|
srate = 250
|
||||||
|
win_len = 10
|
||||||
|
win_step = 1
|
||||||
|
ring_buffer = 5
|
||||||
|
receiveData = b''
|
||||||
|
toUv=True#转为uV
|
||||||
|
RingBufferLock = threading.Lock()
|
||||||
|
|
||||||
|
# 单例模式
|
||||||
|
_instance = None
|
||||||
|
_initialized = False # 检查是否已经初始化
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(SunnyLinker64, cls).__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
def __init__(self, host='127.0.0.1', port=7878, srate=250, n_chan=64,method = 'tcp'):
|
||||||
|
if SunnyLinker64._initialized:
|
||||||
|
return
|
||||||
|
Thread.__init__(self)
|
||||||
|
self.daemon = True
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.srate = srate
|
||||||
|
self.n_chan = n_chan
|
||||||
|
self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输
|
||||||
|
self.__ringBuffer = RingBuffer(self.n_chan + 2,
|
||||||
|
int(np.round(self.t_buffer * self.srate)))
|
||||||
|
self.energy = 0 # 电量
|
||||||
|
self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常
|
||||||
|
self.gain_value = 6 # 增益倍数
|
||||||
|
self.interval_inited = False #ssmvep或mi时间窗是否初始化
|
||||||
|
|
||||||
|
# 设置初始化标志为True,防止重复初始化
|
||||||
|
SunnyLinker64._initialized = True
|
||||||
|
|
||||||
|
# --- 新增:用于心跳检测 ---
|
||||||
|
self.last_called = 0 # 初始化为0
|
||||||
|
self.last_called_lock = threading.Lock() # 保护 last_called 的访问
|
||||||
|
|
||||||
|
def reset_state(self):
|
||||||
|
"""清空采集器状态和缓存数据"""
|
||||||
|
with self.RingBufferLock:
|
||||||
|
self.__ringBuffer.resetAllPara()
|
||||||
|
self.count_events = {}
|
||||||
|
self.epoch_finished = False
|
||||||
|
self.pack_contain_event = False
|
||||||
|
self.event_inner_idx = -1
|
||||||
|
self.interval_inited = False
|
||||||
|
|
||||||
|
def interval_init(self,decoder_class):
|
||||||
|
if decoder_class == 'ssmvep':
|
||||||
|
interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))
|
||||||
|
self.interval_epoch = [int(i * self.srate) for i in interval_epoch] # epoch截取信息
|
||||||
|
self.train_epoch = [int(self.interval_epoch[0]),
|
||||||
|
int(self.interval_epoch[1] + 0.1 * self.srate)] # 训练样本epoch
|
||||||
|
self.latency = (self.interval_epoch[
|
||||||
|
1] + 0.1 * self.srate) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉
|
||||||
|
self.train_latency = (self.train_epoch[1] + 0.1 * self.srate) // 5
|
||||||
|
|
||||||
|
elif decoder_class == 'mi':
|
||||||
|
interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))
|
||||||
|
self.interval_epoch = [int(i * self.srate) for i in interval_epoch] # epoch截取信息
|
||||||
|
self.train_epoch = self.interval_epoch.copy()
|
||||||
|
self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;
|
||||||
|
self.train_latency = self.latency
|
||||||
|
|
||||||
|
print('时间窗:', (interval_epoch))
|
||||||
|
self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息
|
||||||
|
self.event_inner_idx = -1 # event在5位数据包内部的idx
|
||||||
|
self.epoch_finished = False # 接收epoch是否完整
|
||||||
|
self.pack_contain_event = False # 当前包是否含有event
|
||||||
|
self.predict_event = 99
|
||||||
|
self.events = [1, 2, self.predict_event]
|
||||||
|
if getattr(self, 'serial', None) and self.serial.is_open:
|
||||||
|
self.serial.close()
|
||||||
|
self.serial = serial.Serial(self.serial_port, 460800, timeout=1) # 连接同步器串口
|
||||||
|
self.interval_inited = True
|
||||||
|
|
||||||
|
def set_sampleRate(self,sampleRate_Code=0x00):
|
||||||
|
'''
|
||||||
|
设置采样率
|
||||||
|
:param sampleRate_Code: 0x00:250Hz,0x01:500Hz,0x02:1000Hz,0x03:2000Hz
|
||||||
|
'''
|
||||||
|
function_code = 0x02
|
||||||
|
gain_code = 0x06
|
||||||
|
sampleRate_Code = [gain_code,sampleRate_Code]
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, sampleRate_Code)
|
||||||
|
if self.method == 'tcp':
|
||||||
|
self.sock.send(packed_data)
|
||||||
|
|
||||||
|
def push_trigger(self,label):
|
||||||
|
'''
|
||||||
|
数据打标
|
||||||
|
@param label:标签类别
|
||||||
|
'''
|
||||||
|
function_code = None
|
||||||
|
label = [label]
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, label)
|
||||||
|
if self.method == 'tcp' and hasattr(self,'serial'):
|
||||||
|
print('发送:', label, datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3])
|
||||||
|
self.serial.write(packed_data)
|
||||||
|
def Impedance(self, On):
|
||||||
|
'''
|
||||||
|
阻抗检测开关
|
||||||
|
:param On:True为开启,False为关闭
|
||||||
|
:return: 组好的协议帧
|
||||||
|
'''
|
||||||
|
function_code = 0x01
|
||||||
|
if On:
|
||||||
|
data = [0x1]
|
||||||
|
self.gain_value = 6
|
||||||
|
else:
|
||||||
|
data = [0x0]
|
||||||
|
self.gain_value = 6
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, data)
|
||||||
|
if self.method == 'tcp':
|
||||||
|
self.sock.send(packed_data)
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
try:
|
||||||
|
if self.method == 'serial':
|
||||||
|
# 开启com口,波特率115200,超时5
|
||||||
|
self.sock = serial.Serial(self.host, self.port, timeout=5)
|
||||||
|
self.sock.flushInput() # 清空缓冲区
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
while not count:
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
# # 接收和存储数据
|
||||||
|
data = (self.sock.read(count))
|
||||||
|
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
|
||||||
|
elif self.method == 'tcp':
|
||||||
|
# 重连前关闭旧 socket,避免资源泄漏
|
||||||
|
if hasattr(self, 'sock') and self.sock:
|
||||||
|
try:
|
||||||
|
self.sock.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
self.sock.connect((self.host, int(self.port)))
|
||||||
|
self.set_sampleRate(0x00) #设置250Hz采样率
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print("请打开头环")
|
||||||
|
print(e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
print("connected")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def extract_packet(self, packet):
|
||||||
|
# 存储一个点的八通道数据
|
||||||
|
dataList = []
|
||||||
|
# 存储116个点的八通道数据
|
||||||
|
dataMatrix = []
|
||||||
|
|
||||||
|
for j in range(5):
|
||||||
|
for i in range(self.n_chan):
|
||||||
|
if not self.toUv:#原始数据直接输出
|
||||||
|
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
|
||||||
|
194 * j + 25 + 2 + i * 3]
|
||||||
|
|
||||||
|
else:#转为uV
|
||||||
|
val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[
|
||||||
|
194 * j + 25 + 2 + i * 3]
|
||||||
|
if val < 8388608:
|
||||||
|
val = val * 4.5 / self.gain_value / 8388608 * 1000000;
|
||||||
|
else:
|
||||||
|
val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000;
|
||||||
|
dataList.append(val)
|
||||||
|
#同步触发源
|
||||||
|
val = packet[194 * j + 25 + (i+1) * 3]
|
||||||
|
dataList.append(val)
|
||||||
|
#同步触发序号
|
||||||
|
val = packet[194 * j + 25 + (i+1) * 3+1]
|
||||||
|
dataList.append(val)
|
||||||
|
|
||||||
|
|
||||||
|
# 将数据矩阵进行拼接
|
||||||
|
if len(dataMatrix) == 0:
|
||||||
|
dataMatrix = np.asmatrix(dataList)
|
||||||
|
else:
|
||||||
|
dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0)
|
||||||
|
dataList.clear()
|
||||||
|
return np.transpose(dataMatrix)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.running = True
|
||||||
|
self.PackageLength = 998
|
||||||
|
|
||||||
|
# 尝试连接循环,断开后自动重连
|
||||||
|
while self.running:
|
||||||
|
if self.connect():
|
||||||
|
break
|
||||||
|
print(f"无法连接到 {self.host}:{self.port},15秒后重试...")
|
||||||
|
time.sleep(15)
|
||||||
|
|
||||||
|
# 启动心跳检测线程
|
||||||
|
threading.Thread(target=self.heartbeat_checker, daemon=True).start()
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
if self.method == 'serial':
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
if count:
|
||||||
|
# 接收和存储数据
|
||||||
|
data = (self.sock.read(count))
|
||||||
|
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
|
||||||
|
elif self.method == 'tcp':
|
||||||
|
data = self.sock.recv(600)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
self.receiveData += data
|
||||||
|
with self.last_called_lock:
|
||||||
|
self.last_called = time.time()
|
||||||
|
self.status_code = 1 # 收到数据,标记为正常
|
||||||
|
if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind(
|
||||||
|
b'\x55\x55') >= self.PackageLength - 2:
|
||||||
|
|
||||||
|
index = self.receiveData.index(b'\xaa')
|
||||||
|
self.receiveData = self.receiveData[index:]
|
||||||
|
if len(self.receiveData) >= self.PackageLength:
|
||||||
|
onepackage = self.receiveData[:self.PackageLength]
|
||||||
|
if onepackage[7] != 0:
|
||||||
|
self.energy = onepackage[7] # 电量
|
||||||
|
self.receiveData = self.receiveData[self.PackageLength:]
|
||||||
|
dataMatrix = self.extract_packet(onepackage)
|
||||||
|
try:
|
||||||
|
with self.RingBufferLock:
|
||||||
|
if self.interval_inited:
|
||||||
|
self.epoch_finished = self.detect_event(dataMatrix)
|
||||||
|
if self.pack_contain_event:
|
||||||
|
self.__ringBuffer.resetAllPara() # 检测到当前pack含有event,清除ringbuffer中之前的数据
|
||||||
|
self.__ringBuffer.appendBuffer(dataMatrix)
|
||||||
|
# self.plotBuffer.appendBuffer(dataMatrix)
|
||||||
|
if self.epoch_finished:
|
||||||
|
time.sleep(0.005)
|
||||||
|
print('epoch_finished: ', datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3])
|
||||||
|
else:
|
||||||
|
self.__ringBuffer.appendBuffer(dataMatrix)
|
||||||
|
except Exception as e:
|
||||||
|
print("锁:写入异常",e)
|
||||||
|
# self.RingBufferLock.release()
|
||||||
|
except ConnectionResetError:
|
||||||
|
self.status_code = 0 # 状态异常
|
||||||
|
print("Connection was reset by the peer. 正在重新连接...")
|
||||||
|
self.sock.close()
|
||||||
|
# 退出循环后,run() 开头的重连循环会自动接管
|
||||||
|
break
|
||||||
|
# 如果 running=True,重连循环会接管,不会执行到这里
|
||||||
|
|
||||||
|
# 检测是否含有标签
|
||||||
|
def detect_event(self, samples):
|
||||||
|
self.pack_contain_event = False
|
||||||
|
events = np.array(samples[-2])[0].tolist()
|
||||||
|
for idx, event in enumerate(events):
|
||||||
|
if int(event) in self.events:
|
||||||
|
new_key = "".join(
|
||||||
|
[
|
||||||
|
str(event),
|
||||||
|
datetime.datetime.now().strftime("%Y-%m-%d \
|
||||||
|
-%H-%M-%S"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if event == self.predict_event:
|
||||||
|
self.count_events[new_key] = self.latency + 1
|
||||||
|
else:
|
||||||
|
self.count_events[new_key] = self.train_latency + 1
|
||||||
|
self.event_inner_idx = idx
|
||||||
|
self.pack_contain_event = True
|
||||||
|
drop_items = []
|
||||||
|
for key, value in self.count_events.items():
|
||||||
|
value = value - 1
|
||||||
|
if value == 0:
|
||||||
|
drop_items.append(key)
|
||||||
|
self.count_events[key] = value
|
||||||
|
for key in drop_items:
|
||||||
|
del self.count_events[key]
|
||||||
|
if drop_items:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
# --- 新增:心跳检测线程 ---
|
||||||
|
def heartbeat_checker(self):
|
||||||
|
"""
|
||||||
|
定期检查是否在最近2秒内收到 eegData
|
||||||
|
如果超过2秒未收到,则设置 status_code = 0
|
||||||
|
"""
|
||||||
|
while self.running:
|
||||||
|
time.sleep(0.5) # 每0.5秒检查一次
|
||||||
|
with self.last_called_lock:
|
||||||
|
now = time.time()
|
||||||
|
# 只有收到过一次数据后才开始判断超时
|
||||||
|
if self.last_called > 0 and (now - self.last_called) > 30:
|
||||||
|
if self.status_code != 0:
|
||||||
|
print("EEG data timeout: disconnected")
|
||||||
|
self.status_code = 0
|
||||||
|
|
||||||
|
def getDataViaSSVEP(self,count):
|
||||||
|
'''
|
||||||
|
ssvep的视觉通道,共8个通道
|
||||||
|
@param count: 每通道读取的数值数量
|
||||||
|
@return: 返回最新的数值
|
||||||
|
'''
|
||||||
|
data=self.getData(count)
|
||||||
|
# PO5,POZ,PO6,O2,PO8,OZ,O1,PO7 64是event导联
|
||||||
|
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55,64]
|
||||||
|
row_to_select=np.array(rows_to_extract)
|
||||||
|
data=data[row_to_select,:]
|
||||||
|
return data
|
||||||
|
def get_MIData(self):
|
||||||
|
'''
|
||||||
|
取出当前所有数值
|
||||||
|
:return:
|
||||||
|
'''
|
||||||
|
data = self.getData(self.__ringBuffer.nUpdate)
|
||||||
|
#MI选取导联:FC3,FC1,FCZ,FC2,FC4,C5,C3,C1,CZ,C2,C4,C6,CP3,CP1,CP2,CP4,P3,P1,PZ,P2,P4,event1,event2
|
||||||
|
rows_to_extract = [8, 15, 12, 14, 18, 23, 16,59,50,58,17,45,29,11,10,19,20,61,51,60,21,64,65]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
data = data[row_to_select,:]
|
||||||
|
return data
|
||||||
|
def get_SSMVEPData(self):
|
||||||
|
'''
|
||||||
|
取出当前所有数值
|
||||||
|
:return:
|
||||||
|
'''
|
||||||
|
data = self.getData(self.__ringBuffer.nUpdate)
|
||||||
|
# PO5,POZ,PO6,O2,PO8,OZ,O1,PO7 64是event导联
|
||||||
|
rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64,65]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
data = data[row_to_select, :]
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_concentrateData(self,count):
|
||||||
|
'''
|
||||||
|
@param count: 每通道读取的数值数量
|
||||||
|
@return: 返回最新的数值
|
||||||
|
'''
|
||||||
|
data=self.getData(count)
|
||||||
|
rows_to_extract = [0, 1]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
data = data[row_to_select, :]
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_blinkData(self,count):
|
||||||
|
'''
|
||||||
|
@param count: 每通道读取的数值数量
|
||||||
|
@return: 返回最新的数值
|
||||||
|
'''
|
||||||
|
data=self.getData(count)
|
||||||
|
rows_to_extract = [0,1]
|
||||||
|
row_to_select = np.array(rows_to_extract)
|
||||||
|
data = data[row_to_select, :]
|
||||||
|
return data
|
||||||
|
|
||||||
|
def getImpedance(self, data,decoder_class):
|
||||||
|
'''
|
||||||
|
获取阻抗值,已经放大100倍,单位是kΩ
|
||||||
|
@param data: 准备计算的通道数据,每通道200个值,注意不要把信号打标的通道传进来
|
||||||
|
@return: 返回各个通道的阻抗值
|
||||||
|
'''
|
||||||
|
impedanceList = []
|
||||||
|
for channelindex in range(data.shape[0]):
|
||||||
|
if len(data[channelindex]) > 0:
|
||||||
|
data_list = []
|
||||||
|
# 设计陷波滤波器,去除50Hz成分
|
||||||
|
is50filter = True
|
||||||
|
if is50filter:
|
||||||
|
b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽,1000是采样频率
|
||||||
|
data_list = signal.lfilter(b, a, data[channelindex].tolist())
|
||||||
|
|
||||||
|
else:
|
||||||
|
data_list.extend(data[channelindex].tolist())
|
||||||
|
|
||||||
|
data_list = data_list[-1000:]
|
||||||
|
# 执行FFT
|
||||||
|
fft_result = np.fft.fft(data_list)
|
||||||
|
fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果
|
||||||
|
freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴
|
||||||
|
|
||||||
|
# y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())],
|
||||||
|
# fft_magnitude[1:-1] * 2 / len(t[0].tolist()),
|
||||||
|
# [fft_magnitude[-1] / len(t[0].tolist())]))
|
||||||
|
|
||||||
|
# 找到幅值最大的频率成分的索引(忽略直流分量,即索引0)
|
||||||
|
max_index = np.argmax(fft_magnitude[1:])
|
||||||
|
|
||||||
|
# 获取最大幅值的频率索引(加上1,因为索引0是直流分量)
|
||||||
|
freq_index = max_index + 1
|
||||||
|
|
||||||
|
# 获取最大幅值
|
||||||
|
max_magnitude = fft_magnitude[freq_index]
|
||||||
|
|
||||||
|
# 阻抗
|
||||||
|
import math
|
||||||
|
result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4
|
||||||
|
result *= 0.44 * 100 # 统一放大100倍
|
||||||
|
impedanceList.append(int(result))
|
||||||
|
# print(max_magnitude, result)
|
||||||
|
else:
|
||||||
|
impedanceList.append(0)
|
||||||
|
impedances = np.array(impedanceList)
|
||||||
|
if decoder_class in ('mi', 'ma'):
|
||||||
|
impedances = impedances[np.array([8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21])]
|
||||||
|
elif decoder_class == 'blink':
|
||||||
|
impedances = impedances[np.array([0, 1])]
|
||||||
|
elif decoder_class == 'concentration':
|
||||||
|
impedances = impedances[np.array([0, 1])]
|
||||||
|
else:
|
||||||
|
impedances = impedances[np.array([13, 3, 2, 46, 9, 54, 47, 55])]
|
||||||
|
return impedances
|
||||||
|
def getData(self,count):
|
||||||
|
'''
|
||||||
|
获取最新的数据
|
||||||
|
@param count: 每通道返回的最数值数目
|
||||||
|
@return: 所有通道的最新count个数值
|
||||||
|
'''
|
||||||
|
data=None
|
||||||
|
try:
|
||||||
|
with self.RingBufferLock:
|
||||||
|
data = self.__ringBuffer.getData(count)
|
||||||
|
except:
|
||||||
|
print("锁:读取异常")
|
||||||
|
# self.RingBufferLock.release()
|
||||||
|
|
||||||
|
|
||||||
|
return data
|
||||||
|
def GetDataLenCount(self):
|
||||||
|
'''
|
||||||
|
获取最新缓存中每个通道的数量
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
return self.__ringBuffer.nUpdate
|
||||||
|
|
||||||
|
def ResetAll(self):
|
||||||
|
'''
|
||||||
|
清空缓存
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
with self.RingBufferLock:
|
||||||
|
self.__ringBuffer.resetAllPara()
|
||||||
|
def stop(self):
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
class SunnyLinker8(Thread, ):
|
||||||
|
receiveData = ''
|
||||||
|
t_buffer = 10
|
||||||
|
n_chan = 9
|
||||||
|
srate = 1000
|
||||||
|
receiveData = b''
|
||||||
|
toUv=False#转为uV
|
||||||
|
RingBufferLock = threading.Lock()
|
||||||
|
def __init__(self, host, port, srate=1000, n_chan=9,method = 'tcp'):
|
||||||
|
Thread.__init__(self)
|
||||||
|
self.daemon = True
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.srate = srate
|
||||||
|
self.n_chan = n_chan
|
||||||
|
self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输
|
||||||
|
self.__ringBuffer = RingBuffer(self.n_chan + 2,
|
||||||
|
int(np.round(self.t_buffer * self.srate)))
|
||||||
|
self.energy = 0 #电量
|
||||||
|
self.status_code = 0 #与采集设备通信的状态码,0为异常,1为正常
|
||||||
|
self.gain_value = 6 # 增益倍数
|
||||||
|
|
||||||
|
def push_trigger(self,label):
|
||||||
|
'''
|
||||||
|
数据打标
|
||||||
|
@param label:标签类别
|
||||||
|
'''
|
||||||
|
function_code = None
|
||||||
|
label = [label]
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, label)
|
||||||
|
if self.method == 'tcp':
|
||||||
|
self.sock.send(packed_data)
|
||||||
|
elif self.method == 'serial':
|
||||||
|
self.sock.write(packed_data)
|
||||||
|
|
||||||
|
def Impedance(self, On):
|
||||||
|
'''
|
||||||
|
阻抗检测开关
|
||||||
|
:param On:True为开启,False为关闭
|
||||||
|
:return: 组好的协议帧
|
||||||
|
'''
|
||||||
|
function_code = None
|
||||||
|
if On:
|
||||||
|
data = [0xA1]
|
||||||
|
self.gain_value = 24
|
||||||
|
else:
|
||||||
|
data = [0xA0]
|
||||||
|
self.gain_value = 6
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, data)
|
||||||
|
if self.method == 'tcp':
|
||||||
|
self.sock.send(packed_data)
|
||||||
|
elif self.method == 'serial':
|
||||||
|
self.sock.write(packed_data)
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
try:
|
||||||
|
if self.method == 'serial':
|
||||||
|
# 开启com口,波特率115200,超时5
|
||||||
|
self.sock = serial.Serial(self.host, self.port, timeout=5)
|
||||||
|
self.sock.flushInput() # 清空缓冲区
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
while not count:
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
# # 接收和存储数据
|
||||||
|
data = (self.sock.read(count))
|
||||||
|
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
|
||||||
|
print("connected")
|
||||||
|
elif self.method == 'tcp':
|
||||||
|
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
self.sock.connect((self.host, int(self.port)))
|
||||||
|
print("connected")
|
||||||
|
except Exception as e:
|
||||||
|
print("请打开头环")
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
print("connected")
|
||||||
|
|
||||||
|
def extract_packet(self, packet):
|
||||||
|
# 存储一个点的八通道数据
|
||||||
|
dataList = []
|
||||||
|
# 存储116个点的八通道数据
|
||||||
|
dataMatrix = []
|
||||||
|
|
||||||
|
# index = (packet[1] << 24) | (packet[2] << 16) | (packet[3] << 8) | packet[4]
|
||||||
|
# print(index)
|
||||||
|
|
||||||
|
for j in range(5):
|
||||||
|
for i in range(self.n_chan):
|
||||||
|
if not self.toUv:#原始数据直接输出
|
||||||
|
val = (packet[26 * j + 25 + i * 3] << 16) | (packet[26 * j + 25 + 1 + i * 3] << 8) | packet[
|
||||||
|
26 * j + 25 + 2 + i * 3]
|
||||||
|
|
||||||
|
else:#转为uV
|
||||||
|
val = (packet[26 * j + 25 + i * 3] << 16) | (packet[26 * j + 25 + 1 + i * 3] << 8) | packet[
|
||||||
|
26 * j + 25 + 2 + i * 3]
|
||||||
|
if val < 8388608:
|
||||||
|
val = val * 4.5 / self.gain_value / 8388608 * 1000000;
|
||||||
|
else:
|
||||||
|
val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000;
|
||||||
|
dataList.append(val)
|
||||||
|
#同步触发源
|
||||||
|
val = packet[26 * j + 25 + (i+1) * 3]
|
||||||
|
dataList.append(val)
|
||||||
|
#同步触发序号
|
||||||
|
val = packet[26 * j + 25 + (i+1) * 3+1]
|
||||||
|
dataList.append(val)
|
||||||
|
|
||||||
|
|
||||||
|
# 将数据矩阵进行拼接
|
||||||
|
if len(dataMatrix) == 0:
|
||||||
|
dataMatrix = np.asmatrix(dataList)
|
||||||
|
else:
|
||||||
|
dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0)
|
||||||
|
dataList.clear()
|
||||||
|
return np.transpose(dataMatrix)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.connect()
|
||||||
|
self.running = True
|
||||||
|
self.PackageLength = 158
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
while self.running:
|
||||||
|
if self.method == 'serial':
|
||||||
|
end_time = time.time()
|
||||||
|
if end_time-start_time > 2: #超过2s未收到数据
|
||||||
|
self.status_code = 0 #状态异常
|
||||||
|
count = self.sock.inWaiting() # 获取串口缓冲区数据
|
||||||
|
if count:
|
||||||
|
start_time = time.time()
|
||||||
|
self.status_code = 1 # 收到数据,状态正常
|
||||||
|
# 接收和存储数据
|
||||||
|
data = (self.sock.read(count))
|
||||||
|
self.receiveData = self.receiveData + data # 将接收数据存储在字符串中
|
||||||
|
elif self.method == 'tcp':
|
||||||
|
data = self.sock.recv(100)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
self.receiveData += data
|
||||||
|
if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind(
|
||||||
|
b'\x55\x55') >= self.PackageLength - 2:
|
||||||
|
|
||||||
|
index = self.receiveData.index(b'\xaa')
|
||||||
|
self.receiveData = self.receiveData[index:]
|
||||||
|
if len(self.receiveData) >= self.PackageLength:
|
||||||
|
onepackage = self.receiveData[:self.PackageLength]
|
||||||
|
if onepackage[7] != 0:
|
||||||
|
self.energy = onepackage[7] # 电量
|
||||||
|
self.receiveData = self.receiveData[self.PackageLength:]
|
||||||
|
dataMatrix = self.extract_packet(onepackage)
|
||||||
|
try:
|
||||||
|
with self.RingBufferLock:
|
||||||
|
self.__ringBuffer.appendBuffer(dataMatrix)
|
||||||
|
except:
|
||||||
|
print("锁:写入异常")
|
||||||
|
self.sock.close()
|
||||||
|
|
||||||
|
except ConnectionResetError:
|
||||||
|
self.status_code = 0 # 状态异常
|
||||||
|
print("Connection was reset by the peer.")
|
||||||
|
except SerialException as Se:
|
||||||
|
self.status_code = 0
|
||||||
|
print('串口通信异常!请检查适配器')
|
||||||
|
|
||||||
|
|
||||||
|
def process_packet(self):
|
||||||
|
if self.circular_buffer.buffer_length > 158:
|
||||||
|
packet = self.circular_buffer.extract_packet()
|
||||||
|
|
||||||
|
if packet:
|
||||||
|
# Here you would parse the packet according to the protocol
|
||||||
|
# print("Received packet:%s,index:%s", len(packet),str(integer_value))
|
||||||
|
return packet
|
||||||
|
else:
|
||||||
|
print("Received Nothing")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def getDataViaSSVEP(self,count):
|
||||||
|
'''
|
||||||
|
ssvep的视觉通道,共8个通道
|
||||||
|
@param count: 每通道读取的数值数量
|
||||||
|
@return: 返回最新的数值
|
||||||
|
'''
|
||||||
|
data=self.getData(count)
|
||||||
|
data=data[:8,:]
|
||||||
|
return data
|
||||||
|
|
||||||
|
def getImpedance(self, data):
|
||||||
|
'''
|
||||||
|
获取阻抗值,已经放大100倍,单位是kΩ
|
||||||
|
@param data: 准备计算的通道数据,每通道200个值,注意不要把信号打标的通道传进来
|
||||||
|
@return: 返回各个通道的阻抗值
|
||||||
|
'''
|
||||||
|
impedanceList = []
|
||||||
|
for channelindex in range(data.shape[0]):
|
||||||
|
if len(data[channelindex]) > 0:
|
||||||
|
data_list = []
|
||||||
|
# 设计陷波滤波器,去除50Hz成分
|
||||||
|
is50filter = True
|
||||||
|
if is50filter:
|
||||||
|
b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽,1000是采样频率
|
||||||
|
data_list = signal.lfilter(b, a, data[channelindex].tolist())
|
||||||
|
|
||||||
|
else:
|
||||||
|
data_list.extend(data[channelindex].tolist())
|
||||||
|
|
||||||
|
data_list = data_list[-1000:]
|
||||||
|
# 执行FFT
|
||||||
|
fft_result = np.fft.fft(data_list)
|
||||||
|
fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果
|
||||||
|
freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴
|
||||||
|
|
||||||
|
# y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())],
|
||||||
|
# fft_magnitude[1:-1] * 2 / len(t[0].tolist()),
|
||||||
|
# [fft_magnitude[-1] / len(t[0].tolist())]))
|
||||||
|
|
||||||
|
# 找到幅值最大的频率成分的索引(忽略直流分量,即索引0)
|
||||||
|
max_index = np.argmax(fft_magnitude[1:])
|
||||||
|
|
||||||
|
# 获取最大幅值的频率索引(加上1,因为索引0是直流分量)
|
||||||
|
freq_index = max_index + 1
|
||||||
|
|
||||||
|
# 获取最大幅值
|
||||||
|
max_magnitude = fft_magnitude[freq_index]
|
||||||
|
|
||||||
|
# 阻抗
|
||||||
|
import math
|
||||||
|
result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4
|
||||||
|
result *= 0.44 * 100 # 统一放大100倍
|
||||||
|
impedanceList.append(int(result))
|
||||||
|
# print(max_magnitude, result)
|
||||||
|
else:
|
||||||
|
impedanceList.append(0)
|
||||||
|
# impedances = ":".join(map(str, impedanceList))
|
||||||
|
impedances = np.array(impedanceList)
|
||||||
|
impedances = impedances[:8]
|
||||||
|
return impedances
|
||||||
|
def getData(self,count):
|
||||||
|
'''
|
||||||
|
获取最新的数据
|
||||||
|
@param count: 每通道返回的最数值数目
|
||||||
|
@return: 所有通道的最新count个数值
|
||||||
|
'''
|
||||||
|
data=None
|
||||||
|
try:
|
||||||
|
with self.RingBufferLock:
|
||||||
|
data = self.__ringBuffer.getData(count)
|
||||||
|
except:
|
||||||
|
print("锁:读取异常")
|
||||||
|
# self.RingBufferLock.release()
|
||||||
|
|
||||||
|
|
||||||
|
return data
|
||||||
|
def GetDataLenCount(self):
|
||||||
|
'''
|
||||||
|
获取最新缓存中每个通道的数量
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
return self.__ringBuffer.nUpdate
|
||||||
|
|
||||||
|
def ResetAll(self):
|
||||||
|
'''
|
||||||
|
清空缓存
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
with self.RingBufferLock:
|
||||||
|
self.__ringBuffer.resetAllPara()
|
||||||
|
def stop(self):
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Usage
|
||||||
|
Linker = SunnyLinker64('127.0.0.1', 5086, 1000, 65)
|
||||||
|
Linker.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(0.005)
|
||||||
|
if(Linker.count()>0):
|
||||||
|
# print(Linker.ringBuffer.nUpdate)
|
||||||
|
t = Linker.getData()
|
||||||
|
print(t.shape[1], Linker.count())
|
||||||
|
# Linker.ringBuffer.nUpdate=0
|
||||||
|
# time.sleep(0.2)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
Linker.stop()
|
||||||
193
Debug_64ch_Decoder_Optimize/Device/protocol.py
Normal file
193
Debug_64ch_Decoder_Optimize/Device/protocol.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
from typing import List, Tuple, Union, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolFrame:
|
||||||
|
# 协议常量
|
||||||
|
FRAME_HEADER = 0xAA
|
||||||
|
FRAME_TAIL1 = 0x55
|
||||||
|
FRAME_TAIL2 = 0x55
|
||||||
|
RESERVED_SIZE = 6
|
||||||
|
MIN_FRAME_SIZE = 13 # 帧头1 + 功能1 + 长度2 + 预留6 + CRC1 + 包尾2
|
||||||
|
MAX_DATA_LENGTH = 0xFFFF # 最大数据长度 (2字节能表示的最大值)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calculate_crc8(data: bytes) -> bytes:
|
||||||
|
"""
|
||||||
|
计算CRC8校验值
|
||||||
|
Args:
|
||||||
|
data: 需要计算CRC的数据
|
||||||
|
Returns:
|
||||||
|
一个字节的CRC值(bytes类型)
|
||||||
|
"""
|
||||||
|
crc = 0
|
||||||
|
for byte in data:
|
||||||
|
crc ^= byte
|
||||||
|
for _ in range(8):
|
||||||
|
crc = ((crc << 1) ^ 0x07 if crc & 0x80 else crc << 1) & 0xFF
|
||||||
|
return bytes([crc])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pack(cls, function, data: Union[bytes, bytearray, List[int]],
|
||||||
|
reserved: Optional[Union[bytes, bytearray, List[int]]] = None) -> bytes:
|
||||||
|
"""
|
||||||
|
协议打包函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function: 功能码 (1字节)
|
||||||
|
data: 数据块
|
||||||
|
reserved: 预留字节(6字节,可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
打包后的字节数据
|
||||||
|
"""
|
||||||
|
# 检查功能码
|
||||||
|
if function != None:
|
||||||
|
if not 0 <= function <= 0xFF:
|
||||||
|
raise ValueError("功能码必须是1字节")
|
||||||
|
|
||||||
|
# 转换数据为bytearray
|
||||||
|
if isinstance(data, list):
|
||||||
|
data = bytearray(data)
|
||||||
|
elif isinstance(data, bytes):
|
||||||
|
data = bytearray(data)
|
||||||
|
|
||||||
|
# 检查数据长度
|
||||||
|
data_length = len(data)
|
||||||
|
if data_length > cls.MAX_DATA_LENGTH:
|
||||||
|
raise ValueError(f"数据长度超过最大值 {cls.MAX_DATA_LENGTH}")
|
||||||
|
|
||||||
|
# 处理预留字节
|
||||||
|
if reserved is None:
|
||||||
|
reserved = bytearray([0] * cls.RESERVED_SIZE)
|
||||||
|
else:
|
||||||
|
if isinstance(reserved, list):
|
||||||
|
reserved = bytearray(reserved)
|
||||||
|
elif isinstance(reserved, bytes):
|
||||||
|
reserved = bytearray(reserved)
|
||||||
|
if len(reserved) != cls.RESERVED_SIZE:
|
||||||
|
raise ValueError(f"预留字节必须是{cls.RESERVED_SIZE}字节")
|
||||||
|
|
||||||
|
# 构建帧
|
||||||
|
frame = bytearray([cls.FRAME_HEADER]) # 帧头 (1字节)
|
||||||
|
if function != None:
|
||||||
|
frame.append(function) # 功能码 (1字节)
|
||||||
|
data_length+=6
|
||||||
|
|
||||||
|
# 数据长度 (2字节,大端序)
|
||||||
|
frame.append((data_length >> 8) & 0xFF) # 高字节
|
||||||
|
frame.append(data_length & 0xFF) # 低字节
|
||||||
|
|
||||||
|
if function != None:
|
||||||
|
frame.extend(reserved) # 预留字节 (6字节)
|
||||||
|
frame.extend(data) # 数据块 (变长)
|
||||||
|
|
||||||
|
# 计算CRC (从功能码开始到数据块结束)
|
||||||
|
crc = cls.calculate_crc8(frame[1:]) # 不包含帧头
|
||||||
|
frame.extend(crc) # CRC校验 (1字节)
|
||||||
|
|
||||||
|
# 添加帧尾
|
||||||
|
frame.extend([cls.FRAME_TAIL1, cls.FRAME_TAIL2]) # 帧尾 (2字节)
|
||||||
|
|
||||||
|
return bytes(frame)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unpack(cls, data: Union[bytes, bytearray]) -> Tuple[int, bytearray, bytearray]:
|
||||||
|
"""
|
||||||
|
协议解包函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 待解析的字节数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(功能码, 数据块, 预留字节)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当数据格式不正确时
|
||||||
|
"""
|
||||||
|
# 检查数据长度
|
||||||
|
if len(data) < cls.MIN_FRAME_SIZE:
|
||||||
|
raise ValueError("数据长度不足")
|
||||||
|
|
||||||
|
# 检查帧头
|
||||||
|
if data[0] != cls.FRAME_HEADER:
|
||||||
|
raise ValueError("帧头错误")
|
||||||
|
|
||||||
|
# 检查帧尾
|
||||||
|
if data[-2:] != bytes([cls.FRAME_TAIL1, cls.FRAME_TAIL2]):
|
||||||
|
raise ValueError("帧尾错误")
|
||||||
|
|
||||||
|
# 解析基本信息
|
||||||
|
function = data[1] # 功能码 (1字节)
|
||||||
|
|
||||||
|
# 数据长度 (2字节,大端序)
|
||||||
|
data_length = (data[2] << 8) | data[3]
|
||||||
|
|
||||||
|
reserved = data[4:10] # 预留字节 (6字节)
|
||||||
|
|
||||||
|
# 检查数据长度
|
||||||
|
expected_length = cls.MIN_FRAME_SIZE + data_length
|
||||||
|
if len(data) != expected_length:
|
||||||
|
raise ValueError(f"数据长度不匹配: 期望{expected_length}字节,实际{len(data)}字节")
|
||||||
|
|
||||||
|
# 提取数据块
|
||||||
|
payload = data[10:10 + data_length]
|
||||||
|
|
||||||
|
# 验证CRC (从功能码开始到数据块结束)
|
||||||
|
received_crc = data[-3]
|
||||||
|
calculated_crc = cls.calculate_crc8(data[1:-3])[0] # 获取字节值
|
||||||
|
|
||||||
|
if received_crc != calculated_crc:
|
||||||
|
raise ValueError(f"CRC校验失败: 期望{calculated_crc:02X},实际{received_crc:02X}")
|
||||||
|
|
||||||
|
return function, bytearray(payload), bytearray(reserved)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def print_hex(data: bytes, label: str = ""):
|
||||||
|
"""打印十六进制数据,并按字节添加空格"""
|
||||||
|
hex_str = ' '.join([f"{b:02X}" for b in data])
|
||||||
|
if label:
|
||||||
|
print(f"{label}: {hex_str}")
|
||||||
|
else:
|
||||||
|
print(hex_str)
|
||||||
|
|
||||||
|
|
||||||
|
def print_frame_details(data: bytes):
|
||||||
|
"""打印帧的详细信息"""
|
||||||
|
print("帧详细信息:")
|
||||||
|
print(f"帧头: {data[0]:02X}")
|
||||||
|
print(f"功能码: {data[1]:02X}")
|
||||||
|
print(f"数据长度: {data[2]:02X} {data[3]:02X} ({(data[2] << 8) | data[3]}字节)")
|
||||||
|
print(f"预留字节: {' '.join([f'{b:02X}' for b in data[4:10]])}")
|
||||||
|
data_length = (data[2] << 8) | data[3]
|
||||||
|
print(f"数据块: {' '.join([f'{b:02X}' for b in data[10:10 + data_length]])}")
|
||||||
|
print(f"CRC校验: {data[-3]:02X}")
|
||||||
|
print(f"帧尾: {data[-2]:02X} {data[-1]:02X}")
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
def example_usage():
|
||||||
|
try:
|
||||||
|
|
||||||
|
|
||||||
|
# 示例1:简单数据打包
|
||||||
|
function_code = 0x01
|
||||||
|
data = [0x1]
|
||||||
|
packed_data = ProtocolFrame.pack(function_code, data)
|
||||||
|
print_hex(packed_data, "示例1 - 完整帧")
|
||||||
|
print_frame_details(packed_data)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 示例3:解包验证
|
||||||
|
function, payload, reserved = ProtocolFrame.unpack(packed_data)
|
||||||
|
print("解包结果:")
|
||||||
|
print(f"功能码: 0x{function:02X}")
|
||||||
|
print_hex(payload, "数据块")
|
||||||
|
print_hex(reserved, "预留字节")
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
print(f"错误: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
example_usage()
|
||||||
409
Debug_64ch_Decoder_Optimize/MI/Algorithm/conformer_2class.py
Normal file
409
Debug_64ch_Decoder_Optimize/MI/Algorithm/conformer_2class.py
Normal file
@@ -0,0 +1,409 @@
|
|||||||
|
"""
|
||||||
|
EEG Conformer
|
||||||
|
|
||||||
|
Convolutional Transformer for EEG decoding
|
||||||
|
|
||||||
|
Couple CNN and Transformer in a concise manner with amazing results
|
||||||
|
"""
|
||||||
|
# remember to change paths
|
||||||
|
|
||||||
|
import os
|
||||||
|
gpus = [0]
|
||||||
|
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.autograd import Variable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from torch import Tensor
|
||||||
|
from einops import rearrange
|
||||||
|
from einops.layers.torch import Rearrange, Reduce
|
||||||
|
# from common_spatial_pattern import csp
|
||||||
|
|
||||||
|
# from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from torch.backends import cudnn
|
||||||
|
cudnn.benchmark = True
|
||||||
|
cudnn.deterministic = True
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
# writer = SummaryWriter('./TensorBoardX/')
|
||||||
|
|
||||||
|
|
||||||
|
# Convolution module
|
||||||
|
# use conv to capture local features, instead of postion embedding.
|
||||||
|
class PatchEmbedding(nn.Module):
|
||||||
|
def __init__(self, emb_size=40,n_chan=8):
|
||||||
|
# self.patch_size = patch_size
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.shallownet = nn.Sequential(
|
||||||
|
nn.Conv2d(1, 40, (1, 25), (1, 1)),
|
||||||
|
nn.Conv2d(40, 40, (n_chan, 1), (1, 1)),
|
||||||
|
nn.BatchNorm2d(40),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.AvgPool2d((1, 75), (1, 15)), # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.projection = nn.Sequential(
|
||||||
|
nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)), # transpose, conv could enhance fiting ability slightly
|
||||||
|
Rearrange('b e (h) (w) -> b (h w) e'),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
b, _, _, _ = x.shape
|
||||||
|
x = self.shallownet(x)
|
||||||
|
x = self.projection(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
|
||||||
|
def __init__(self, emb_size, num_heads, dropout):
|
||||||
|
super().__init__()
|
||||||
|
self.emb_size = emb_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.keys = nn.Linear(emb_size, emb_size)
|
||||||
|
self.queries = nn.Linear(emb_size, emb_size)
|
||||||
|
self.values = nn.Linear(emb_size, emb_size)
|
||||||
|
self.att_drop = nn.Dropout(dropout)
|
||||||
|
self.projection = nn.Linear(emb_size, emb_size)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
|
||||||
|
queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
|
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
|
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
|
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
||||||
|
if mask is not None:
|
||||||
|
fill_value = torch.finfo(torch.float32).min
|
||||||
|
energy.mask_fill(~mask, fill_value)
|
||||||
|
|
||||||
|
scaling = self.emb_size ** (1 / 2)
|
||||||
|
att = F.softmax(energy / scaling, dim=-1)
|
||||||
|
att = self.att_drop(att)
|
||||||
|
out = torch.einsum('bhal, bhlv -> bhav ', att, values)
|
||||||
|
out = rearrange(out, "b h n d -> b n (h d)")
|
||||||
|
out = self.projection(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAdd(nn.Module):
|
||||||
|
def __init__(self, fn):
|
||||||
|
super().__init__()
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
res = x
|
||||||
|
x = self.fn(x, **kwargs)
|
||||||
|
x += res
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForwardBlock(nn.Sequential):
|
||||||
|
def __init__(self, emb_size, expansion, drop_p):
|
||||||
|
super().__init__(
|
||||||
|
nn.Linear(emb_size, expansion * emb_size),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(drop_p),
|
||||||
|
nn.Linear(expansion * emb_size, emb_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GELU(nn.Module):
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0)))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderBlock(nn.Sequential):
|
||||||
|
def __init__(self,
|
||||||
|
emb_size,
|
||||||
|
num_heads=10,
|
||||||
|
drop_p=0.5,
|
||||||
|
forward_expansion=4,
|
||||||
|
forward_drop_p=0.5):
|
||||||
|
super().__init__(
|
||||||
|
ResidualAdd(nn.Sequential(
|
||||||
|
nn.LayerNorm(emb_size),
|
||||||
|
MultiHeadAttention(emb_size, num_heads, drop_p),
|
||||||
|
nn.Dropout(drop_p)
|
||||||
|
)),
|
||||||
|
ResidualAdd(nn.Sequential(
|
||||||
|
nn.LayerNorm(emb_size),
|
||||||
|
FeedForwardBlock(
|
||||||
|
emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
|
||||||
|
nn.Dropout(drop_p)
|
||||||
|
)
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Sequential):
|
||||||
|
def __init__(self, depth, emb_size):
|
||||||
|
super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationHead(nn.Sequential):
|
||||||
|
def __init__(self, emb_size, n_classes):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# global average pooling
|
||||||
|
self.clshead = nn.Sequential(
|
||||||
|
Reduce('b n e -> b e', reduction='mean'),
|
||||||
|
nn.LayerNorm(emb_size),
|
||||||
|
nn.Linear(emb_size, n_classes)
|
||||||
|
)
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(2440, 256),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.Linear(256, 32),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(32, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.contiguous().view(x.size(0), -1)
|
||||||
|
out = self.fc(x)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Conformer(nn.Sequential):
|
||||||
|
def __init__(self, emb_size=40, depth=6, n_classes=2,n_chan=8, **kwargs):
|
||||||
|
super().__init__(
|
||||||
|
|
||||||
|
PatchEmbedding(emb_size,n_chan),
|
||||||
|
TransformerEncoder(depth, emb_size),
|
||||||
|
ClassificationHead(emb_size, n_classes)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExP():
|
||||||
|
def __init__(self,n_chan):
|
||||||
|
super(ExP, self).__init__()
|
||||||
|
self.n_chan = n_chan
|
||||||
|
self.batch_size = 24
|
||||||
|
self.n_epochs = 250
|
||||||
|
self.c_dim = 4
|
||||||
|
self.lr = 0.0002
|
||||||
|
self.b1 = 0.5
|
||||||
|
self.b2 = 0.999
|
||||||
|
|
||||||
|
self.start_epoch = 0
|
||||||
|
# 创建目录
|
||||||
|
os.makedirs("online_Models", exist_ok=True)
|
||||||
|
self.log_write = open("./online_Models/log_result.txt", "w")
|
||||||
|
|
||||||
|
|
||||||
|
self.Tensor = torch.cuda.FloatTensor
|
||||||
|
self.LongTensor = torch.cuda.LongTensor
|
||||||
|
|
||||||
|
self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()
|
||||||
|
|
||||||
|
self.model = Conformer(n_chan=self.n_chan).cuda()
|
||||||
|
self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
|
||||||
|
self.model = self.model.cuda()
|
||||||
|
|
||||||
|
# self.model = EEGNet().cuda()
|
||||||
|
# self.model = nn.DataParallel(self.model,device_ids=[i for i in range(len(gpus))])
|
||||||
|
# self.model = self.model.cuda()
|
||||||
|
# summary(self.model, (1, 8, 1000))
|
||||||
|
|
||||||
|
|
||||||
|
# Segmentation and Reconstruction (S&R) data augmentation
|
||||||
|
def interaug(self, timg, label):
|
||||||
|
# 确保输入是 numpy 数组(CPU)
|
||||||
|
if isinstance(timg, torch.Tensor):
|
||||||
|
timg = timg.cpu().numpy()
|
||||||
|
if isinstance(label, torch.Tensor):
|
||||||
|
label = label.cpu().numpy()
|
||||||
|
|
||||||
|
aug_data = []
|
||||||
|
aug_label = []
|
||||||
|
for cls4aug in range(2):
|
||||||
|
cls_idx = np.where(label == cls4aug + 1)
|
||||||
|
tmp_data = timg[cls_idx]
|
||||||
|
tmp_label = label[cls_idx]
|
||||||
|
tmp_aug_data = np.zeros((int(self.batch_size / 2), 1, self.n_chan, 1000))
|
||||||
|
for ri in range(int(self.batch_size / 2)):
|
||||||
|
for rj in range(8):
|
||||||
|
rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
|
||||||
|
tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :,
|
||||||
|
rj * 125:(rj + 1) * 125]
|
||||||
|
|
||||||
|
aug_data.append(tmp_aug_data)
|
||||||
|
aug_label.append(tmp_label[:int(self.batch_size / 2)])
|
||||||
|
aug_data = np.concatenate(aug_data)
|
||||||
|
aug_label = np.concatenate(aug_label)
|
||||||
|
aug_shuffle = np.random.permutation(len(aug_data))
|
||||||
|
aug_data = aug_data[aug_shuffle, :, :]
|
||||||
|
aug_label = aug_label[aug_shuffle]
|
||||||
|
|
||||||
|
# 返回 numpy 数组,由调用方决定是否移到 GPU
|
||||||
|
return aug_data, aug_label
|
||||||
|
|
||||||
|
def train(self,all_data,all_label,model_path):
|
||||||
|
all_data = np.array(all_data);all_label = np.array(all_label)
|
||||||
|
all_data = np.expand_dims(all_data, axis=1)
|
||||||
|
train_data, test_data, train_label, test_label = train_test_split(all_data, all_label, test_size=0.2,
|
||||||
|
random_state=42, stratify=all_label,shuffle=True)
|
||||||
|
|
||||||
|
# === 优化:一次性预生成增强数据,避免每个 batch 都重复计算 ===
|
||||||
|
aug_data, aug_label = self.interaug(train_data, train_label)
|
||||||
|
# 将原始数据和增强数据合并,再一起打乱
|
||||||
|
train_data_full = np.concatenate([train_data, aug_data], axis=0)
|
||||||
|
train_label_full = np.concatenate([train_label, aug_label], axis=0)
|
||||||
|
shuffle_idx = np.random.permutation(len(train_data_full))
|
||||||
|
train_data_full = train_data_full[shuffle_idx]
|
||||||
|
train_label_full = train_label_full[shuffle_idx]
|
||||||
|
|
||||||
|
img = torch.from_numpy(train_data_full)
|
||||||
|
label = torch.from_numpy(train_label_full-1)
|
||||||
|
|
||||||
|
dataset = torch.utils.data.TensorDataset(img, label)
|
||||||
|
self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
|
||||||
|
|
||||||
|
test_data = torch.from_numpy(test_data)
|
||||||
|
test_label = torch.from_numpy(test_label-1)
|
||||||
|
test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
|
||||||
|
self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)
|
||||||
|
|
||||||
|
# Optimizers
|
||||||
|
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))
|
||||||
|
|
||||||
|
test_data = Variable(test_data.type(self.Tensor))
|
||||||
|
test_label = Variable(test_label.type(self.LongTensor))
|
||||||
|
|
||||||
|
bestAcc = 0
|
||||||
|
averAcc = 0
|
||||||
|
num = 0
|
||||||
|
Y_true = 0
|
||||||
|
Y_pred = 0
|
||||||
|
|
||||||
|
# Train the cnn model
|
||||||
|
for e in range(self.n_epochs):
|
||||||
|
# in_epoch = time.time()
|
||||||
|
self.model.train()
|
||||||
|
for i, (img, label) in enumerate(self.dataloader):
|
||||||
|
|
||||||
|
img = Variable(img.cuda().type(self.Tensor))
|
||||||
|
label = Variable(label.cuda().type(self.LongTensor))
|
||||||
|
|
||||||
|
outputs = self.model(img)
|
||||||
|
|
||||||
|
loss = self.criterion_cls(outputs, label)
|
||||||
|
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
|
# out_epoch = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
# test process
|
||||||
|
if (e + 1) % 1 == 0:
|
||||||
|
self.model.eval()
|
||||||
|
Cls = self.model(test_data)
|
||||||
|
|
||||||
|
loss_test = self.criterion_cls(Cls, test_label)
|
||||||
|
y_pred = torch.max(Cls, 1)[1]
|
||||||
|
acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
|
||||||
|
train_pred = torch.max(outputs, 1)[1]
|
||||||
|
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
|
||||||
|
|
||||||
|
print('Epoch:', e,
|
||||||
|
' Train loss: %.6f' % loss.detach().cpu().numpy(),
|
||||||
|
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
|
||||||
|
' Train accuracy %.6f' % train_acc,
|
||||||
|
' Test accuracy is %.6f' % acc)
|
||||||
|
|
||||||
|
self.log_write.write(str(e) + " " + str(acc) + "\n")
|
||||||
|
num = num + 1
|
||||||
|
averAcc = averAcc + acc
|
||||||
|
if acc > bestAcc:
|
||||||
|
bestAcc = acc
|
||||||
|
Y_true = test_label
|
||||||
|
Y_pred = y_pred
|
||||||
|
|
||||||
|
|
||||||
|
torch.save(self.model, model_path)
|
||||||
|
averAcc = averAcc / num
|
||||||
|
print('The average accuracy is:', averAcc)
|
||||||
|
print('The best accuracy is:', bestAcc)
|
||||||
|
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
|
||||||
|
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
|
||||||
|
|
||||||
|
return bestAcc, averAcc, Y_true, Y_pred
|
||||||
|
# writer.close()
|
||||||
|
|
||||||
|
|
||||||
|
def onlineTrain(data_queue,result_queue):
|
||||||
|
import torch
|
||||||
|
print(f"[DEBUG] torch.__version__ = {torch.__version__}")
|
||||||
|
print(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}")
|
||||||
|
try:
|
||||||
|
starttime = datetime.datetime.now()
|
||||||
|
|
||||||
|
# seed_n = np.random.randint(2025)
|
||||||
|
seed_n = 1877
|
||||||
|
random.seed(seed_n)
|
||||||
|
np.random.seed(seed_n)
|
||||||
|
torch.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed_all(seed_n)
|
||||||
|
|
||||||
|
|
||||||
|
# 从队列获取训练数据
|
||||||
|
data = data_queue.get(timeout=30)
|
||||||
|
all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan']
|
||||||
|
exp = ExP(n_chan)
|
||||||
|
print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path)
|
||||||
|
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
|
||||||
|
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
||||||
|
|
||||||
|
endtime = datetime.datetime.now()
|
||||||
|
print('train duration: ',str(endtime - starttime))
|
||||||
|
|
||||||
|
# 将模型或参数传回
|
||||||
|
result_queue.put({
|
||||||
|
'status': 'success',
|
||||||
|
'model_state': model_path, # 或保存路径
|
||||||
|
'timestamp': time.time()
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
result_queue.put({'status': 'error', 'msg': str(e)})
|
||||||
|
|
||||||
|
def offlineTrain(all_data,all_label,modelPath):
|
||||||
|
starttime = datetime.datetime.now()
|
||||||
|
|
||||||
|
# seed_n = np.random.randint(2025)
|
||||||
|
seed_n = 1877
|
||||||
|
print('seed is ' + str(seed_n))
|
||||||
|
random.seed(seed_n)
|
||||||
|
np.random.seed(seed_n)
|
||||||
|
torch.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed_all(seed_n)
|
||||||
|
|
||||||
|
exp = ExP()
|
||||||
|
|
||||||
|
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
|
||||||
|
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
||||||
|
|
||||||
|
endtime = datetime.datetime.now()
|
||||||
|
print('train duration: ',str(endtime - starttime))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(time.asctime(time.localtime(time.time())))
|
||||||
|
print(time.asctime(time.localtime(time.time())))
|
||||||
382
Debug_64ch_Decoder_Optimize/MI/Algorithm/conformer_2class_cpu.py
Normal file
382
Debug_64ch_Decoder_Optimize/MI/Algorithm/conformer_2class_cpu.py
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
"""
|
||||||
|
EEG Conformer
|
||||||
|
|
||||||
|
Convolutional Transformer for EEG decoding
|
||||||
|
|
||||||
|
Couple CNN and Transformer in a concise manner with amazing results
|
||||||
|
"""
|
||||||
|
# remember to change paths
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch import nn
|
||||||
|
from torch import Tensor
|
||||||
|
from einops import rearrange
|
||||||
|
from einops.layers.torch import Rearrange, Reduce
|
||||||
|
from torch.backends import cudnn
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
# writer = SummaryWriter('./TensorBoardX/')
|
||||||
|
|
||||||
|
|
||||||
|
# Convolution module
|
||||||
|
# use conv to capture local features, instead of postion embedding.
|
||||||
|
class PatchEmbedding(nn.Module):
|
||||||
|
def __init__(self, emb_size=40):
|
||||||
|
# self.patch_size = patch_size
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.shallownet = nn.Sequential(
|
||||||
|
nn.Conv2d(1, 40, (1, 25), (1, 1)),
|
||||||
|
nn.Conv2d(40, 40, (8, 1), (1, 1)),
|
||||||
|
nn.BatchNorm2d(40),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.AvgPool2d((1, 75), (1, 15)), # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.projection = nn.Sequential(
|
||||||
|
nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)), # transpose, conv could enhance fiting ability slightly
|
||||||
|
Rearrange('b e (h) (w) -> b (h w) e'),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
b, _, _, _ = x.shape
|
||||||
|
x = self.shallownet(x)
|
||||||
|
x = self.projection(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
|
||||||
|
def __init__(self, emb_size, num_heads, dropout):
|
||||||
|
super().__init__()
|
||||||
|
self.emb_size = emb_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.keys = nn.Linear(emb_size, emb_size)
|
||||||
|
self.queries = nn.Linear(emb_size, emb_size)
|
||||||
|
self.values = nn.Linear(emb_size, emb_size)
|
||||||
|
self.att_drop = nn.Dropout(dropout)
|
||||||
|
self.projection = nn.Linear(emb_size, emb_size)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
|
||||||
|
queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
|
keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
|
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||||
|
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
||||||
|
if mask is not None:
|
||||||
|
fill_value = torch.finfo(torch.float32).min
|
||||||
|
energy.mask_fill(~mask, fill_value)
|
||||||
|
|
||||||
|
scaling = self.emb_size ** (1 / 2)
|
||||||
|
att = F.softmax(energy / scaling, dim=-1)
|
||||||
|
att = self.att_drop(att)
|
||||||
|
out = torch.einsum('bhal, bhlv -> bhav ', att, values)
|
||||||
|
out = rearrange(out, "b h n d -> b n (h d)")
|
||||||
|
out = self.projection(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAdd(nn.Module):
|
||||||
|
def __init__(self, fn):
|
||||||
|
super().__init__()
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
res = x
|
||||||
|
x = self.fn(x, **kwargs)
|
||||||
|
x += res
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForwardBlock(nn.Sequential):
|
||||||
|
def __init__(self, emb_size, expansion, drop_p):
|
||||||
|
super().__init__(
|
||||||
|
nn.Linear(emb_size, expansion * emb_size),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(drop_p),
|
||||||
|
nn.Linear(expansion * emb_size, emb_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GELU(nn.Module):
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0)))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderBlock(nn.Sequential):
|
||||||
|
def __init__(self,
|
||||||
|
emb_size,
|
||||||
|
num_heads=10,
|
||||||
|
drop_p=0.5,
|
||||||
|
forward_expansion=4,
|
||||||
|
forward_drop_p=0.5):
|
||||||
|
super().__init__(
|
||||||
|
ResidualAdd(nn.Sequential(
|
||||||
|
nn.LayerNorm(emb_size),
|
||||||
|
MultiHeadAttention(emb_size, num_heads, drop_p),
|
||||||
|
nn.Dropout(drop_p)
|
||||||
|
)),
|
||||||
|
ResidualAdd(nn.Sequential(
|
||||||
|
nn.LayerNorm(emb_size),
|
||||||
|
FeedForwardBlock(
|
||||||
|
emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
|
||||||
|
nn.Dropout(drop_p)
|
||||||
|
)
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Sequential):
|
||||||
|
def __init__(self, depth, emb_size):
|
||||||
|
super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationHead(nn.Sequential):
|
||||||
|
def __init__(self, emb_size, n_classes):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# global average pooling
|
||||||
|
self.clshead = nn.Sequential(
|
||||||
|
Reduce('b n e -> b e', reduction='mean'),
|
||||||
|
nn.LayerNorm(emb_size),
|
||||||
|
nn.Linear(emb_size, n_classes)
|
||||||
|
)
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(2440, 256),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.Linear(256, 32),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(32, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.contiguous().view(x.size(0), -1)
|
||||||
|
out = self.fc(x)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Conformer(nn.Sequential):
|
||||||
|
def __init__(self, emb_size=40, depth=6, n_classes=2, **kwargs):
|
||||||
|
super().__init__(
|
||||||
|
|
||||||
|
PatchEmbedding(emb_size),
|
||||||
|
TransformerEncoder(depth, emb_size),
|
||||||
|
ClassificationHead(emb_size, n_classes)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExP():
|
||||||
|
def __init__(self):
|
||||||
|
super(ExP, self).__init__()
|
||||||
|
self.batch_size = 24
|
||||||
|
self.n_epochs = 250
|
||||||
|
self.c_dim = 4
|
||||||
|
self.lr = 0.0002
|
||||||
|
self.b1 = 0.5
|
||||||
|
self.b2 = 0.999
|
||||||
|
|
||||||
|
self.start_epoch = 0
|
||||||
|
|
||||||
|
self.log_write = open("./online_Models/log_result.txt", "w")
|
||||||
|
|
||||||
|
# 自动选择设备:有 GPU 用 GPU,否则用 CPU
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
# self.device = torch.device("cpu")
|
||||||
|
print(f"Using device: {self.device}")
|
||||||
|
|
||||||
|
# 定义张量类型(不再强制使用 cuda)
|
||||||
|
self.Tensor = torch.FloatTensor
|
||||||
|
self.LongTensor = torch.LongTensor
|
||||||
|
|
||||||
|
# 将模型移到指定设备
|
||||||
|
self.model = Conformer().to(self.device)
|
||||||
|
|
||||||
|
# 损失函数也移到设备
|
||||||
|
self.criterion_cls = torch.nn.CrossEntropyLoss().to(self.device)
|
||||||
|
|
||||||
|
# self.model = EEGNet().cuda()
|
||||||
|
# self.model = nn.DataParallel(self.model,device_ids=[i for i in range(len(gpus))])
|
||||||
|
# self.model = self.model.cuda()
|
||||||
|
# summary(self.model, (1, 8, 1000))
|
||||||
|
|
||||||
|
|
||||||
|
# Segmentation and Reconstruction (S&R) data augmentation
|
||||||
|
def interaug(self, timg, label):
|
||||||
|
aug_data = []
|
||||||
|
aug_label = []
|
||||||
|
for cls4aug in range(2):
|
||||||
|
cls_idx = np.where(label == cls4aug + 1)
|
||||||
|
tmp_data = timg[cls_idx]
|
||||||
|
tmp_label = label[cls_idx]
|
||||||
|
tmp_aug_data = np.zeros((int(self.batch_size / 2), 1, 8, 1000))
|
||||||
|
for ri in range(int(self.batch_size / 2)):
|
||||||
|
for rj in range(8):
|
||||||
|
rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
|
||||||
|
tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :,
|
||||||
|
rj * 125:(rj + 1) * 125]
|
||||||
|
|
||||||
|
aug_data.append(tmp_aug_data)
|
||||||
|
aug_label.append(tmp_label[:int(self.batch_size / 2)])
|
||||||
|
aug_data = np.concatenate(aug_data)
|
||||||
|
aug_label = np.concatenate(aug_label)
|
||||||
|
aug_shuffle = np.random.permutation(len(aug_data))
|
||||||
|
aug_data = aug_data[aug_shuffle, :, :]
|
||||||
|
aug_label = aug_label[aug_shuffle]
|
||||||
|
|
||||||
|
aug_data = torch.from_numpy(aug_data).float().to(self.device)
|
||||||
|
aug_label = torch.from_numpy(aug_label - 1).long().to(self.device)
|
||||||
|
return aug_data, aug_label
|
||||||
|
|
||||||
|
def train(self,all_data,all_label,model_path):
|
||||||
|
all_data = np.array(all_data);all_label = np.array(all_label)
|
||||||
|
all_data = np.expand_dims(all_data, axis=1)
|
||||||
|
train_data, test_data, train_label, test_label = train_test_split(all_data, all_label, test_size=0.2,
|
||||||
|
random_state=42, stratify=all_label,shuffle=True)
|
||||||
|
# 转为 Tensor
|
||||||
|
img = torch.from_numpy(train_data).float().to(self.device)
|
||||||
|
label = torch.from_numpy(train_label - 1).long().to(self.device)
|
||||||
|
|
||||||
|
dataset = torch.utils.data.TensorDataset(img, label)
|
||||||
|
self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
|
||||||
|
|
||||||
|
test_data = torch.from_numpy(test_data).float().to(self.device)
|
||||||
|
test_label = torch.from_numpy(test_label - 1).long().to(self.device)
|
||||||
|
test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
|
||||||
|
self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)
|
||||||
|
|
||||||
|
# Optimizers
|
||||||
|
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))
|
||||||
|
|
||||||
|
bestAcc = 0
|
||||||
|
averAcc = 0
|
||||||
|
num = 0
|
||||||
|
Y_true = 0
|
||||||
|
Y_pred = 0
|
||||||
|
|
||||||
|
# Train the cnn model
|
||||||
|
for e in range(self.n_epochs):
|
||||||
|
# in_epoch = time.time()
|
||||||
|
self.model.train()
|
||||||
|
for i, (img, label) in enumerate(self.dataloader):
|
||||||
|
|
||||||
|
# data augmentation
|
||||||
|
aug_data, aug_label = self.interaug(train_data, train_label)
|
||||||
|
img = torch.cat((img, aug_data))
|
||||||
|
label = torch.cat((label, aug_label))
|
||||||
|
|
||||||
|
|
||||||
|
outputs = self.model(img)
|
||||||
|
|
||||||
|
loss = self.criterion_cls(outputs, label)
|
||||||
|
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
|
# out_epoch = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
# test process
|
||||||
|
if (e + 1) % 1 == 0:
|
||||||
|
self.model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
Cls = self.model(test_data)
|
||||||
|
|
||||||
|
loss_test = self.criterion_cls(Cls, test_label)
|
||||||
|
y_pred = torch.max(Cls, 1)[1]
|
||||||
|
acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
|
||||||
|
train_pred = torch.max(outputs, 1)[1]
|
||||||
|
train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
|
||||||
|
|
||||||
|
print('Epoch:', e,
|
||||||
|
' Train loss: %.6f' % loss.detach().cpu().numpy(),
|
||||||
|
' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
|
||||||
|
' Train accuracy %.6f' % train_acc,
|
||||||
|
' Test accuracy is %.6f' % acc)
|
||||||
|
|
||||||
|
self.log_write.write(str(e) + " " + str(acc) + "\n")
|
||||||
|
num = num + 1
|
||||||
|
averAcc = averAcc + acc
|
||||||
|
if acc > bestAcc:
|
||||||
|
bestAcc = acc
|
||||||
|
Y_true = test_label
|
||||||
|
Y_pred = y_pred
|
||||||
|
|
||||||
|
|
||||||
|
torch.save(self.model, model_path)
|
||||||
|
averAcc = averAcc / num
|
||||||
|
print('The average accuracy is:', averAcc)
|
||||||
|
print('The best accuracy is:', bestAcc)
|
||||||
|
self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
|
||||||
|
self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
|
||||||
|
|
||||||
|
return bestAcc, averAcc, Y_true, Y_pred
|
||||||
|
# writer.close()
|
||||||
|
|
||||||
|
|
||||||
|
def onlineTrain(data_queue,result_queue):
|
||||||
|
try:
|
||||||
|
starttime = datetime.datetime.now()
|
||||||
|
|
||||||
|
# seed_n = np.random.randint(2025)
|
||||||
|
seed_n = 1877
|
||||||
|
random.seed(seed_n)
|
||||||
|
np.random.seed(seed_n)
|
||||||
|
torch.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed_all(seed_n)
|
||||||
|
|
||||||
|
exp = ExP()
|
||||||
|
# 从队列获取训练数据
|
||||||
|
data = data_queue.get(timeout=30)
|
||||||
|
all_data, all_label,model_path = data['data'], data['label'],data['modelPath']
|
||||||
|
print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path)
|
||||||
|
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path)
|
||||||
|
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
||||||
|
|
||||||
|
endtime = datetime.datetime.now()
|
||||||
|
print('train duration: ',str(endtime - starttime))
|
||||||
|
|
||||||
|
# 将模型或参数传回
|
||||||
|
result_queue.put({
|
||||||
|
'status': 'success',
|
||||||
|
'model_state': model_path, # 或保存路径
|
||||||
|
'timestamp': time.time()
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
result_queue.put({'status': 'error', 'msg': str(e)})
|
||||||
|
|
||||||
|
def offlineTrain(all_data,all_label,modelPath):
|
||||||
|
starttime = datetime.datetime.now()
|
||||||
|
|
||||||
|
# seed_n = np.random.randint(2025)
|
||||||
|
seed_n = 1877
|
||||||
|
print('seed is ' + str(seed_n))
|
||||||
|
random.seed(seed_n)
|
||||||
|
np.random.seed(seed_n)
|
||||||
|
torch.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed(seed_n)
|
||||||
|
torch.cuda.manual_seed_all(seed_n)
|
||||||
|
|
||||||
|
exp = ExP()
|
||||||
|
|
||||||
|
bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath)
|
||||||
|
print('THE BEST ACCURACY IS ' + str(bestAcc))
|
||||||
|
|
||||||
|
endtime = datetime.datetime.now()
|
||||||
|
print('train duration: ',str(endtime - starttime))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(time.asctime(time.localtime(time.time())))
|
||||||
|
print(time.asctime(time.localtime(time.time())))
|
||||||
184
Debug_64ch_Decoder_Optimize/MI/Algorithm/otherModels.py
Normal file
184
Debug_64ch_Decoder_Optimize/MI/Algorithm/otherModels.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
from torchsummary import summary
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def weights_init(m):
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.xavier_uniform_(m.weight)
|
||||||
|
# nn.init.constant(m.bias, 0) # bias may be none
|
||||||
|
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
nn.init.xavier_uniform_(m.weight)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def square_activation(x):
|
||||||
|
return torch.square(x)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_log(x):
|
||||||
|
return torch.clip(torch.log(x), min=1e-7, max=1e7)
|
||||||
|
|
||||||
|
|
||||||
|
class ShallowConvNet(nn.Module):
|
||||||
|
def __init__(self, num_classes=3, chans=19, samples=768):
|
||||||
|
super(ShallowConvNet, self).__init__()
|
||||||
|
self.conv_nums = 40
|
||||||
|
self.features = nn.Sequential(
|
||||||
|
nn.Conv2d(1, self.conv_nums, (1, 25)),
|
||||||
|
nn.Conv2d(self.conv_nums, self.conv_nums, (chans, 1), bias=False),
|
||||||
|
nn.BatchNorm2d(self.conv_nums)
|
||||||
|
)
|
||||||
|
self.avgpool = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15))
|
||||||
|
self.dropout = nn.Dropout()
|
||||||
|
|
||||||
|
out = torch.ones((1, 1, chans, samples))
|
||||||
|
out = self.features(out)
|
||||||
|
out = self.avgpool(out)
|
||||||
|
n_out_time = out.cpu().data.numpy().shape
|
||||||
|
self.classifier = nn.Linear(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.features(x)
|
||||||
|
x = square_activation(x)
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = safe_log(x)
|
||||||
|
x = self.dropout(x)
|
||||||
|
|
||||||
|
features = torch.flatten(x, 1)
|
||||||
|
cls = self.classifier(features)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
class EEGNet(nn.Module):
|
||||||
|
def __init__(self, num_classes=2, chans=8, samples=1000, dropout_rate=0.5, kernel_length=64, F1=8,
|
||||||
|
F2=16,):
|
||||||
|
super(EEGNet, self).__init__()
|
||||||
|
|
||||||
|
self.features = nn.Sequential(
|
||||||
|
nn.Conv2d(1, F1, kernel_size=(1, kernel_length), bias=False),
|
||||||
|
nn.BatchNorm2d(F1),
|
||||||
|
nn.Conv2d(F1, F1, kernel_size=(chans, 1), groups=F1, bias=False), # groups=F1 for depthWiseConv
|
||||||
|
nn.BatchNorm2d(F1),
|
||||||
|
nn.ELU(inplace=True),
|
||||||
|
# nn.ReLU(),
|
||||||
|
nn.AvgPool2d((1, 4)),
|
||||||
|
nn.Dropout(dropout_rate),
|
||||||
|
# for SeparableCon2D
|
||||||
|
# SeparableConv2D(F1, F2, kernel1_size=(1, 16), bias=False),
|
||||||
|
nn.Conv2d(F1, F1, kernel_size=(1, 16), groups=F1, bias=False), # groups=F1 for depthWiseConv
|
||||||
|
nn.BatchNorm2d(F1),
|
||||||
|
nn.ELU(inplace=True),
|
||||||
|
# nn.ReLU(),
|
||||||
|
nn.Conv2d(F1, F2, kernel_size=(1, 1), groups=1, bias=False), # point-wise cnn
|
||||||
|
nn.BatchNorm2d(F2),
|
||||||
|
# nn.ReLU(),
|
||||||
|
nn.ELU(inplace=True),
|
||||||
|
nn.AvgPool2d((1, 8)),
|
||||||
|
nn.Dropout(p=dropout_rate),
|
||||||
|
# nn.Dropout(p=0.5),
|
||||||
|
)
|
||||||
|
out = torch.ones((1, 1, chans, samples))
|
||||||
|
out = self.features(out)
|
||||||
|
n_out_time = out.cpu().data.numpy().shape
|
||||||
|
self.classifier = nn.Linear(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
conv_features = self.features(x)
|
||||||
|
features = torch.flatten(conv_features, 1)
|
||||||
|
cls = self.classifier(features)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
class LMDA(nn.Module):
|
||||||
|
"""
|
||||||
|
LMDA-Net for the paper
|
||||||
|
"""
|
||||||
|
def __init__(self, chans=19, samples=768, num_classes=3, depth=9, kernel=75, channel_depth1=24, channel_depth2=9,
|
||||||
|
ave_depth=1, avepool=5):
|
||||||
|
super(LMDA, self).__init__()
|
||||||
|
self.ave_depth = ave_depth
|
||||||
|
self.channel_weight = nn.Parameter(torch.randn(depth, 1, chans), requires_grad=True)
|
||||||
|
nn.init.xavier_uniform_(self.channel_weight.data)
|
||||||
|
|
||||||
|
|
||||||
|
self.time_conv = nn.Sequential(
|
||||||
|
nn.Conv2d(depth, channel_depth1, kernel_size=(1, 1), groups=1, bias=False),
|
||||||
|
nn.BatchNorm2d(channel_depth1),
|
||||||
|
nn.Conv2d(channel_depth1, channel_depth1, kernel_size=(1, kernel),
|
||||||
|
groups=channel_depth1, bias=False),
|
||||||
|
nn.BatchNorm2d(channel_depth1),
|
||||||
|
nn.GELU(),
|
||||||
|
)
|
||||||
|
# self.avgPool1 = nn.AvgPool2d((1, 24))
|
||||||
|
self.chanel_conv = nn.Sequential(
|
||||||
|
nn.Conv2d(channel_depth1, channel_depth2, kernel_size=(1, 1), groups=1, bias=False),
|
||||||
|
nn.BatchNorm2d(channel_depth2),
|
||||||
|
nn.Conv2d(channel_depth2, channel_depth2, kernel_size=(chans, 1), groups=channel_depth2, bias=False),
|
||||||
|
nn.BatchNorm2d(channel_depth2),
|
||||||
|
nn.GELU(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm = nn.Sequential(
|
||||||
|
nn.AvgPool3d(kernel_size=(1, 1, avepool)),
|
||||||
|
# nn.AdaptiveAvgPool3d((9, 1, 35)),
|
||||||
|
nn.Dropout(p=0.65),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 定义自动填充模块
|
||||||
|
out = torch.ones((1, 1, chans, samples))
|
||||||
|
out = torch.einsum('bdcw, hdc->bhcw', out, self.channel_weight)
|
||||||
|
out = self.time_conv(out)
|
||||||
|
out = self.chanel_conv(out)
|
||||||
|
out = self.norm(out)
|
||||||
|
n_out_time = out.cpu().data.numpy().shape
|
||||||
|
print('In ShallowNet, n_out_time shape: ', n_out_time)
|
||||||
|
self.classifier = nn.Linear(n_out_time[-1]*n_out_time[-2]*n_out_time[-3], num_classes)
|
||||||
|
|
||||||
|
def EEGDepthAttention(self, x):
|
||||||
|
# x: input features with shape [N, C, H, W]
|
||||||
|
|
||||||
|
N, C, H, W = x.size()
|
||||||
|
# K = W if W % 2 else W + 1
|
||||||
|
k = 7
|
||||||
|
adaptive_pool = nn.AdaptiveAvgPool2d((1, W))
|
||||||
|
conv = nn.Conv2d(1, 1, kernel_size=(k, 1), padding=(k//2, 0), bias=True).to(x.device) # original kernel k
|
||||||
|
nn.init.xavier_uniform_(conv.weight)
|
||||||
|
nn.init.constant_(conv.bias, 0)
|
||||||
|
softmax = nn.Softmax(dim=-2)
|
||||||
|
x_pool = adaptive_pool(x)
|
||||||
|
x_transpose = x_pool.transpose(-2, -3)
|
||||||
|
y = conv(x_transpose)
|
||||||
|
y = softmax(y)
|
||||||
|
y = y.transpose(-2, -3)
|
||||||
|
return y * C * x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = torch.einsum('bdcw, hdc->bhcw', x, self.channel_weight)
|
||||||
|
|
||||||
|
x_time = self.time_conv(x) # batch, depth1, channel, samples_
|
||||||
|
x_time = self.EEGDepthAttention(x_time) # DA1
|
||||||
|
|
||||||
|
x = self.chanel_conv(x_time) # batch, depth2, 1, samples_
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
features = torch.flatten(x, 1)
|
||||||
|
cls = self.classifier(features)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
model = ShallowConvNet(num_classes=4, chans=22, samples=1125).cuda()
|
||||||
|
a = torch.randn(12, 1, 3, 875).cuda().float()
|
||||||
|
l2 = model(a)
|
||||||
|
model_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
|
||||||
|
summary(model, show_input=True)
|
||||||
|
|
||||||
|
print(l2.shape)
|
||||||
|
|
||||||
30
Debug_64ch_Decoder_Optimize/PubLibrary/InifileHelper.py
Normal file
30
Debug_64ch_Decoder_Optimize/PubLibrary/InifileHelper.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# -*-coding:utf-8 -*-
|
||||||
|
import configparser
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from audioop import error
|
||||||
|
|
||||||
|
BASE_DIR = os.getcwd()
|
||||||
|
IniFileName = os.path.join(BASE_DIR, 'config.ini')
|
||||||
|
# IniFileName=os.path.join( 'config.ini')
|
||||||
|
|
||||||
|
def IniWrite(section,keyname,value):
|
||||||
|
# 创建ConfigParser对象
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(IniFileName,encoding='utf-8')
|
||||||
|
with open(IniFileName, 'w') as configfile:
|
||||||
|
if not config.has_section(section):
|
||||||
|
config.add_section(section)
|
||||||
|
config[section][keyname]=str(value)
|
||||||
|
config.write(configfile)
|
||||||
|
|
||||||
|
def IniRead(section,key):
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(IniFileName,encoding='utf-8')
|
||||||
|
return config[section][key]
|
||||||
|
except error as e:
|
||||||
|
print(e)
|
||||||
|
# 读取特定section和键的值
|
||||||
|
return '5'
|
||||||
15
Debug_64ch_Decoder_Optimize/PubLibrary/RunOnce.py
Normal file
15
Debug_64ch_Decoder_Optimize/PubLibrary/RunOnce.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import ctypes
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def is_program_running(name='Global\\Decoder'):
|
||||||
|
# 创建互斥体
|
||||||
|
mutex_name =name
|
||||||
|
h_mutex = ctypes.windll.kernel32.CreateMutexW(None, False, mutex_name)
|
||||||
|
|
||||||
|
# 检查互斥体是否已经存在
|
||||||
|
if ctypes.windll.kernel32.GetLastError() == 183: # ERROR_ALREADY_EXISTS
|
||||||
|
print("程序已经在运行.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
418
Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/base.py
Normal file
418
Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/base.py
Normal file
@@ -0,0 +1,418 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Authors: Swolf <swolfforever@gmail.com>
|
||||||
|
# Date: 2021/1/07
|
||||||
|
# License: MIT License
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional, List, Tuple, Union
|
||||||
|
import warnings
|
||||||
|
import numpy as np
|
||||||
|
from numpy import ndarray
|
||||||
|
from numpy.linalg import linalg
|
||||||
|
from scipy.linalg import solve, qr
|
||||||
|
from scipy.signal import sosfiltfilt, cheby1, cheb1ord
|
||||||
|
from sklearn.base import BaseEstimator, TransformerMixin, clone
|
||||||
|
|
||||||
|
|
||||||
|
def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
|
||||||
|
"""Transform spatial filters to spatial patterns based on paper [1]_.
|
||||||
|
Referring to the method mentioned in article [1],the constructed spatial filter only shows how to combine
|
||||||
|
information from different channels to extract signals of interest from EEG signals, but if our goal is
|
||||||
|
neurophysiological interpretation or visualization of weights, activation patterns need to be constructed
|
||||||
|
from the obtained spatial filters.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
W : ndarray
|
||||||
|
Spatial filters, shape (n_channels, n_filters).
|
||||||
|
Cx : ndarray
|
||||||
|
Covariance matrix of eeg data, shape (n_channels, n_channels).
|
||||||
|
Cs : ndarray
|
||||||
|
Covariance matrix of source data, shape (n_channels, n_channels).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A : ndarray
|
||||||
|
Spatial patterns, shape (n_channels, n_patterns), each column is a spatial pattern.
|
||||||
|
|
||||||
|
References
|
||||||
|
----------
|
||||||
|
.. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate neuroimaging.
|
||||||
|
Neuroimage 87 (2014): 96-110.
|
||||||
|
"""
|
||||||
|
# use linalg.solve instead of inv, makes it more stable
|
||||||
|
# see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
|
||||||
|
# and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
|
||||||
|
A = solve(Cs.T, np.dot(Cx, W).T).T
|
||||||
|
return A
|
||||||
|
|
||||||
|
|
||||||
|
class FilterBank(BaseEstimator, TransformerMixin):
|
||||||
|
"""
|
||||||
|
Filter bank decomposition is a bandpass filter array that divides the input signal into
|
||||||
|
multiple subband components and obtains the eigenvalues of each subband component.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
base_estimator : class
|
||||||
|
Estimator for model training and feature extraction.
|
||||||
|
filterbank : list[ndarray]
|
||||||
|
A bandpass filter bank used to divide the input signal into multiple subband components.
|
||||||
|
n_jobs : int
|
||||||
|
Sets the number of CPU working cores. The default is None.
|
||||||
|
|
||||||
|
References
|
||||||
|
----------
|
||||||
|
.. [1] Chen X, Wang Y, Nakanishi M, et al. High-speed spelling with a noninvasive brain-computer interface[J].
|
||||||
|
Proceedings of the national academy of sciences, 2015, 112(44): E6058-E6067.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_estimator: BaseEstimator,
|
||||||
|
filterbank: List[ndarray],
|
||||||
|
n_jobs: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.base_estimator = base_estimator
|
||||||
|
self.filterbank = filterbank
|
||||||
|
self.n_jobs = n_jobs
|
||||||
|
|
||||||
|
def fit(self, X: ndarray, y: Optional[ndarray] = None, **kwargs):
|
||||||
|
"""
|
||||||
|
Training model
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : None
|
||||||
|
Training signal (parameters can be ignored, only used to maintain code structure).
|
||||||
|
y : None
|
||||||
|
Label data (ibid., ignorable).
|
||||||
|
Yf : None
|
||||||
|
Reference signal (ibid., ignorable).
|
||||||
|
"""
|
||||||
|
self.estimators_ = [
|
||||||
|
clone(self.base_estimator) for _ in range(len(self.filterbank))
|
||||||
|
]
|
||||||
|
X = self.transform_filterbank(X)
|
||||||
|
for i, est in enumerate(self.estimators_):
|
||||||
|
est.fit(X[i], y, **kwargs)
|
||||||
|
# def wrapper(est, X, y, kwargs):
|
||||||
|
# est.fit(X, y, **kwargs)
|
||||||
|
# return est
|
||||||
|
# self.estimators_ = Parallel(n_jobs=self.n_jobs)(
|
||||||
|
# delayed(wrapper)(est, X[i], y, kwargs) for i, est in enumerate(self.estimators_))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, X: ndarray, **kwargs):
|
||||||
|
"""
|
||||||
|
The parameters stored in self are used to convert X into features, and X is filtered through the filter bank to
|
||||||
|
obtain the eigenvalues of each subband component.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray, shape(n_trials, n_channels, n_samples)
|
||||||
|
Test the signal.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
feat : ndarray, shape(n_trials, n_fre)
|
||||||
|
Feature array.
|
||||||
|
"""
|
||||||
|
X = self.transform_filterbank(X)
|
||||||
|
feat = [est.transform(X[i], **kwargs) for i, est in enumerate(self.estimators_)]
|
||||||
|
# def wrapper(est, X, kwargs):
|
||||||
|
# retval = est.transform(X, **kwargs)
|
||||||
|
# return retval
|
||||||
|
# feat = Parallel(n_jobs=self.n_jobs)(
|
||||||
|
# delayed(wrapper)(est, X[i], kwargs) for i, est in enumerate(self.estimators_))
|
||||||
|
feat = np.concatenate(feat, axis=-1)
|
||||||
|
return feat
|
||||||
|
|
||||||
|
def transform_filterbank(self, X: ndarray):
|
||||||
|
"""
|
||||||
|
The input signal is filtered through a filter bank.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray, shape(n_trials, n_channels, n_samples)
|
||||||
|
Input signal.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Xs: ndarray, shape(Nfb, n_trials, n_channels, n_samples)
|
||||||
|
Individual subband components of the input signal.
|
||||||
|
"""
|
||||||
|
Xs = np.stack([sosfiltfilt(sos, X, axis=-1) for sos in self.filterbank])
|
||||||
|
return Xs
|
||||||
|
|
||||||
|
|
||||||
|
class FilterBankSSVEP(FilterBank):
|
||||||
|
"""
|
||||||
|
Filter bank analysis for SSVEP.
|
||||||
|
The SSVEP is analyzed using filter banks, that is, multiple filters are combined to decompose the SSVEP signal
|
||||||
|
into specific segments (subbands containing the original data) and obtain its characteristic data.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filterbank : list[ndarray]
|
||||||
|
The filter bank.
|
||||||
|
base_estimator : class
|
||||||
|
Estimator for model training and feature extraction.
|
||||||
|
filterweights : ndarray
|
||||||
|
Filter weight, default is None.
|
||||||
|
n_jobs : int
|
||||||
|
Sets the number of CPU working cores. The default is None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
filterbank: List[ndarray],
|
||||||
|
base_estimator: BaseEstimator,
|
||||||
|
filterweights: Optional[ndarray] = None,
|
||||||
|
n_jobs: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.filterweights = filterweights
|
||||||
|
super().__init__(base_estimator, filterbank, n_jobs=n_jobs)
|
||||||
|
|
||||||
|
def transform(self, X: ndarray): # type: ignore[override]
|
||||||
|
"""
|
||||||
|
X is converted into features by using the parameters stored in self, and the eigenvalues of each subband
|
||||||
|
component are obtained after the input signal is filtered by the filter bank.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray, shape(n_trials, n_channels, n_samples)
|
||||||
|
Test the signal.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
features : ndarray, shape(n_trials, n_fre)
|
||||||
|
Feature array.
|
||||||
|
"""
|
||||||
|
features = super().transform(X)
|
||||||
|
if self.filterweights is None:
|
||||||
|
return features
|
||||||
|
else:
|
||||||
|
features = np.reshape(
|
||||||
|
features, (features.shape[0], len(self.filterbank), -1)
|
||||||
|
)
|
||||||
|
return np.sum(
|
||||||
|
features * self.filterweights[np.newaxis, :, np.newaxis], axis=1
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def generate_filterbank(
|
||||||
|
passbands: List[Tuple[float, float]],
|
||||||
|
stopbands: List[Tuple[float, float]],
|
||||||
|
srate: int,
|
||||||
|
order: Optional[int] = None,
|
||||||
|
rp: float = 0.5,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a filter bank, that is, obtain a bandpass filter coefficient that can divide the input signal into multiple
|
||||||
|
subband components.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
passbands : list or tuple(float, float)
|
||||||
|
Passband parameters.
|
||||||
|
stopbands : list or tuple(float, float)
|
||||||
|
Stopband parameters.
|
||||||
|
srate : float
|
||||||
|
Sampling rate.
|
||||||
|
order : int
|
||||||
|
Filter order.
|
||||||
|
rp : float
|
||||||
|
The maximum ripple allowed in the passband below the unit gain is 0.5 by default.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Filterbank:ndarray, shape(len(passbands), N, 6)
|
||||||
|
Filter bank coefficient.
|
||||||
|
"""
|
||||||
|
filterbank = []
|
||||||
|
for wp, ws in zip(passbands, stopbands):
|
||||||
|
if order is None:
|
||||||
|
N, wn = cheb1ord(wp, ws, 3, 40, fs=srate)
|
||||||
|
sos = cheby1(N, rp, wn, btype="bandpass", output="sos", fs=srate)
|
||||||
|
else:
|
||||||
|
sos = cheby1(order, rp, wp, btype="bandpass", output="sos", fs=srate)
|
||||||
|
|
||||||
|
filterbank.append(sos)
|
||||||
|
return filterbank
|
||||||
|
|
||||||
|
def process(data):
|
||||||
|
# 白化操作
|
||||||
|
meanValue = np.mat(data.mean(axis=1))
|
||||||
|
meanData = np.repeat(meanValue, data.shape[1], axis=1)
|
||||||
|
whiteTemp = data - meanData
|
||||||
|
# QR 分解
|
||||||
|
rankWhiteTemp = whiteTemp.shape[0]
|
||||||
|
whiteTemp = np.transpose(whiteTemp)
|
||||||
|
Q, R = qr(whiteTemp.A, mode='economic')
|
||||||
|
# 计算矩阵的秩
|
||||||
|
rankQ = linalg.matrix_rank(R)
|
||||||
|
if rankQ == 0:
|
||||||
|
raise ValueError('stats:canoncorr:badData')
|
||||||
|
elif rankQ <= rankWhiteTemp:
|
||||||
|
# warnings.warn('stats:canoncorr:NotFullRank')
|
||||||
|
Q = Q[:, 0:rankQ]
|
||||||
|
return Q, rankQ
|
||||||
|
|
||||||
|
def reference(listFreqs,fs, numberSmples, num_harms):
|
||||||
|
numberFrequence = len(listFreqs)
|
||||||
|
timeIndex = np.arange(1, numberSmples + 1) / fs # time index
|
||||||
|
referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples))
|
||||||
|
for frequenceIndex in range(numberFrequence):
|
||||||
|
temp = []
|
||||||
|
for harmIndex in range(1, num_harms + 1):
|
||||||
|
stimFrequence = listFreqs[frequenceIndex] # in HZ
|
||||||
|
# Sin and Cos
|
||||||
|
temp.extend([np.sin(2 * np.pi * timeIndex * harmIndex * stimFrequence),
|
||||||
|
np.cos(2 * np.pi * timeIndex * harmIndex * stimFrequence)])
|
||||||
|
referenceTemp = np.mat(temp)
|
||||||
|
# 白化操作和QR分解
|
||||||
|
Q, rankQ = process(referenceTemp)
|
||||||
|
referenceData[frequenceIndex] = np.transpose(Q)
|
||||||
|
return referenceData
|
||||||
|
|
||||||
|
def generate_cca_references(
|
||||||
|
freqs: Union[ndarray, int, float],
|
||||||
|
srate,
|
||||||
|
T,
|
||||||
|
phases: Optional[Union[ndarray, int, float]] = None,
|
||||||
|
n_harmonics: int = 1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Construct a sine-cosine reference signal for canonical correlation analysis (CCA).
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
freqs : int or float
|
||||||
|
Frequency.
|
||||||
|
srate : int
|
||||||
|
Sampling rate.
|
||||||
|
T : int
|
||||||
|
Sampling time.
|
||||||
|
phases : int or float
|
||||||
|
Phase, default is None.
|
||||||
|
n_harmonics : int
|
||||||
|
The number of harmonics. The default value is 1.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Yf:ndarray, shape(srate*T, n_harmonics*2)
|
||||||
|
Sine and cosine reference signal.
|
||||||
|
"""
|
||||||
|
if isinstance(freqs, int) or isinstance(freqs, float):
|
||||||
|
freqs = np.array([freqs])
|
||||||
|
freqs = np.array(freqs)[:, np.newaxis]
|
||||||
|
if phases is None:
|
||||||
|
phases = 0
|
||||||
|
if isinstance(phases, int) or isinstance(phases, float):
|
||||||
|
phases = np.array([phases])
|
||||||
|
phases = np.array(phases)[:, np.newaxis]
|
||||||
|
t = np.linspace(0, T, int(T * srate))
|
||||||
|
|
||||||
|
Yf = []
|
||||||
|
for i in range(n_harmonics):
|
||||||
|
Yf.append(
|
||||||
|
np.stack(
|
||||||
|
[
|
||||||
|
np.sin(2 * np.pi * (i + 1) * freqs * t + np.pi * phases),
|
||||||
|
np.cos(2 * np.pi * (i + 1) * freqs * t + np.pi * phases),
|
||||||
|
],
|
||||||
|
axis=1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
Yf = np.concatenate(Yf, axis=1)
|
||||||
|
return Yf
|
||||||
|
|
||||||
|
|
||||||
|
def sign_flip(u, s, vh=None):
|
||||||
|
"""Flip signs of SVD or EIG using the method in paper [1]_.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
u: ndarray
|
||||||
|
left singular vectors, shape (M, K).
|
||||||
|
s: ndarray
|
||||||
|
singular values, shape (K,).
|
||||||
|
vh: ndarray or None
|
||||||
|
transpose of right singular vectors, shape (K, N).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
u: ndarray
|
||||||
|
corrected left singular vectors.
|
||||||
|
s: ndarray
|
||||||
|
singular values.
|
||||||
|
vh: ndarray
|
||||||
|
transpose of corrected right singular vectors.
|
||||||
|
|
||||||
|
References
|
||||||
|
----------
|
||||||
|
.. [1] https://www.sandia.gov/~tgkolda/pubs/pubfiles/SAND2007-6422.pdf
|
||||||
|
"""
|
||||||
|
if vh is None:
|
||||||
|
total_proj = np.sum(u * s, axis=0)
|
||||||
|
signs = np.sign(total_proj)
|
||||||
|
|
||||||
|
random_idx = signs == 0
|
||||||
|
if np.any(random_idx):
|
||||||
|
signs[random_idx] = 1
|
||||||
|
warnings.warn(
|
||||||
|
"The magnitude is close to zero, the sign will become arbitrary."
|
||||||
|
)
|
||||||
|
|
||||||
|
u = u * signs
|
||||||
|
|
||||||
|
return u, s
|
||||||
|
else:
|
||||||
|
left_proj = np.sum(s[:, np.newaxis] * vh, axis=-1)
|
||||||
|
right_proj = np.sum(u * s, axis=0)
|
||||||
|
total_proj = left_proj + right_proj
|
||||||
|
signs = np.sign(total_proj)
|
||||||
|
|
||||||
|
random_idx = signs == 0
|
||||||
|
if np.any(random_idx):
|
||||||
|
signs[random_idx] = 1
|
||||||
|
warnings.warn(
|
||||||
|
"The magnitude is close to zero, the sign will become arbitrary."
|
||||||
|
)
|
||||||
|
|
||||||
|
u = u * signs
|
||||||
|
vh = signs[:, np.newaxis] * vh
|
||||||
|
|
||||||
|
return u, s, vh
|
||||||
436
Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/dsp.py
Normal file
436
Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/dsp.py
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# DSP: Discriminal Spatial Patterns
|
||||||
|
# Authors: Swolf <swolfforever@gmail.com>
|
||||||
|
# Junyang Wang <2144755928@qq.com>
|
||||||
|
# Last update date: 2022-8-11
|
||||||
|
# License: MIT License
|
||||||
|
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
from itertools import combinations
|
||||||
|
import numpy as np
|
||||||
|
from scipy.linalg import eigh
|
||||||
|
from numpy import ndarray
|
||||||
|
from scipy.linalg import solve
|
||||||
|
from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
|
||||||
|
"""Transform spatial filters to spatial patterns based on paper [1]_.
|
||||||
|
Referring to the method mentioned in article [1],the constructed spatial filter only shows how to combine
|
||||||
|
information from different channels to extract signals of interest from EEG signals, but if our goal is
|
||||||
|
neurophysiological interpretation or visualization of weights, activation patterns need to be constructed
|
||||||
|
from the obtained spatial filters.
|
||||||
|
|
||||||
|
update log:
|
||||||
|
2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
W : ndarray
|
||||||
|
Spatial filters, shape (n_channels, n_filters).
|
||||||
|
Cx : ndarray
|
||||||
|
Covariance matrix of eeg data, shape (n_channels, n_channels).
|
||||||
|
Cs : ndarray
|
||||||
|
Covariance matrix of source data, shape (n_channels, n_channels).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A : ndarray
|
||||||
|
Spatial patterns, shape (n_channels, n_patterns), each column is a spatial pattern.
|
||||||
|
|
||||||
|
References
|
||||||
|
----------
|
||||||
|
.. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate neuroimaging.
|
||||||
|
Neuroimage 87 (2014): 96-110.
|
||||||
|
"""
|
||||||
|
# use linalg.solve instead of inv, makes it more stable
|
||||||
|
# see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
|
||||||
|
# and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
|
||||||
|
A = solve(Cs.T, np.dot(Cx, W).T).T
|
||||||
|
return A
|
||||||
|
|
||||||
|
def isPD(B: ndarray) -> bool:
|
||||||
|
"""Returns true when input matrix is positive-definite, via Cholesky decompositon method.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
B : ndarray
|
||||||
|
Any matrix, shape (N, N)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if B is positve-definite.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
Use numpy.linalg rather than scipy.linalg. In this case, scipy.linalg has unpredictable behaviors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
_ = np.linalg.cholesky(B)
|
||||||
|
return True
|
||||||
|
except np.linalg.LinAlgError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def nearestPD(A: ndarray) -> ndarray:
|
||||||
|
"""Find the nearest positive-definite matrix to input.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
A : ndarray
|
||||||
|
Any square matrxi, shape (N, N)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A3 : ndarray
|
||||||
|
positive-definite matrix to A
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
A Python/Numpy port of John D'Errico's `nearestSPD` MATLAB code [1]_, which
|
||||||
|
origins at [2]_.
|
||||||
|
|
||||||
|
References
|
||||||
|
----------
|
||||||
|
.. [1] https://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd
|
||||||
|
.. [2] N.J. Higham, "Computing a nearest symmetric positive semidefinite matrix" (1988):
|
||||||
|
https://doi.org/10.1016/0024-3795(88)90223-6
|
||||||
|
"""
|
||||||
|
|
||||||
|
B = (A + A.T) / 2
|
||||||
|
_, s, V = np.linalg.svd(B)
|
||||||
|
|
||||||
|
H = np.dot(V.T, np.dot(np.diag(s), V))
|
||||||
|
|
||||||
|
A2 = (B + H) / 2
|
||||||
|
|
||||||
|
A3 = (A2 + A2.T) / 2
|
||||||
|
|
||||||
|
if isPD(A3):
|
||||||
|
return A3
|
||||||
|
|
||||||
|
print("Replace current matrix with the nearest positive-definite matrix.")
|
||||||
|
|
||||||
|
spacing = np.spacing(np.linalg.norm(A))
|
||||||
|
# The above is different from [1]. It appears that MATLAB's `chol` Cholesky
|
||||||
|
# decomposition will accept matrixes with exactly 0-eigenvalue, whereas
|
||||||
|
# Numpy's will not. So where [1] uses `eps(mineig)` (where `eps` is Matlab
|
||||||
|
# for `numpy.spacing`), we use the above definition. CAVEAT: our `spacing`
|
||||||
|
# will be much larger than [1]'s `eps(mineig)`, since `mineig` is usually on
|
||||||
|
# the order of 1e-16, and `eps(1e-16)` is on the order of 1e-34, whereas
|
||||||
|
# `spacing` will, for Gaussian random matrixes of small dimension, be on
|
||||||
|
# othe order of 1e-16. In practice, both ways converge, as the unit test
|
||||||
|
# below suggests.
|
||||||
|
eye = np.eye(A.shape[0])
|
||||||
|
k = 1
|
||||||
|
while not isPD(A3):
|
||||||
|
mineig = np.min(np.real(np.linalg.eigvals(A3)))
|
||||||
|
A3 += eye * (-mineig * k**2 + spacing)
|
||||||
|
k += 1
|
||||||
|
|
||||||
|
return A3
|
||||||
|
|
||||||
|
def xiang_dsp_kernel(
|
||||||
|
X: ndarray, y: ndarray
|
||||||
|
) -> Tuple[ndarray, ndarray, ndarray, ndarray]:
|
||||||
|
"""
|
||||||
|
DSP: Discriminal Spatial Patterns, only for two classes[1]_.
|
||||||
|
Import train data to solve spatial filters with DSP,
|
||||||
|
finds a projection matrix that maximize the between-class scatter matrix and
|
||||||
|
minimize the within-class scatter matrix. Currently only support for two types of data.
|
||||||
|
|
||||||
|
Author: Swolf <swolfforever@gmail.com>
|
||||||
|
|
||||||
|
Created on: 2021-1-07
|
||||||
|
|
||||||
|
Update log:
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray
|
||||||
|
EEG train data assuming removing mean, shape (n_trials, n_channels, n_samples)
|
||||||
|
y : ndarray
|
||||||
|
labels of EEG data, shape (n_trials, )
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
W : ndarray
|
||||||
|
spatial filters, shape (n_channels, n_filters)
|
||||||
|
D : ndarray
|
||||||
|
eigenvalues in descending order
|
||||||
|
M : ndarray
|
||||||
|
mean value of all classes and trials, i.e. common mode signals, shape (n_channel, n_samples)
|
||||||
|
A : ndarray
|
||||||
|
spatial patterns, shape (n_channels, n_filters)
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
the implementation removes regularization on within-class scatter matrix Sw.
|
||||||
|
|
||||||
|
References
|
||||||
|
----------
|
||||||
|
.. [1] Liao, Xiang, et al. "Combining spatial filters for the classification of single-trial EEG in
|
||||||
|
a finger movement task." IEEE Transactions on Biomedical Engineering 54.5 (2007): 821-831.
|
||||||
|
"""
|
||||||
|
X, y = np.copy(X), np.copy(y)
|
||||||
|
labels = np.unique(y)
|
||||||
|
X = np.reshape(X, (-1, *X.shape[-2:]))
|
||||||
|
X = X - np.mean(X, axis=-1, keepdims=True)
|
||||||
|
# the number of each label
|
||||||
|
n_labels = np.array([np.sum(y == label) for label in labels])
|
||||||
|
# average template of all trials
|
||||||
|
M = np.mean(X, axis=0)
|
||||||
|
# class conditional template
|
||||||
|
Ms, Ss = zip(
|
||||||
|
*[
|
||||||
|
(
|
||||||
|
np.mean(X[y == label], axis=0),
|
||||||
|
np.sum(
|
||||||
|
np.matmul(X[y == label], np.swapaxes(X[y == label], -1, -2)), axis=0
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for label in labels
|
||||||
|
]
|
||||||
|
)
|
||||||
|
Ms, Ss = np.stack(Ms), np.stack(Ss)
|
||||||
|
# within-class scatter matrix
|
||||||
|
Sw = np.sum(
|
||||||
|
Ss
|
||||||
|
- n_labels[:, np.newaxis, np.newaxis] * np.matmul(Ms, np.swapaxes(Ms, -1, -2)),
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
Ms = Ms - M
|
||||||
|
# between-class scatter matrix
|
||||||
|
Sb = np.sum(
|
||||||
|
n_labels[:, np.newaxis, np.newaxis] * np.matmul(Ms, np.swapaxes(Ms, -1, -2)),
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
D, W = eigh(nearestPD(Sb), nearestPD(Sw))
|
||||||
|
ix = np.argsort(D)[::-1] # in descending order
|
||||||
|
D, W = D[ix], W[:, ix]
|
||||||
|
A = robust_pattern(W, Sb, W.T @ Sb @ W)
|
||||||
|
|
||||||
|
return W, D, M, A
|
||||||
|
|
||||||
|
|
||||||
|
def xiang_dsp_feature(
|
||||||
|
W: ndarray, M: ndarray, X: ndarray, n_components: int = 1
|
||||||
|
) -> ndarray:
|
||||||
|
"""
|
||||||
|
Return DSP features in paper [1]_.
|
||||||
|
|
||||||
|
Author: Swolf <swolfforever@gmail.com>
|
||||||
|
|
||||||
|
Created on: 2021-1-07
|
||||||
|
|
||||||
|
Update log:
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
W : ndarray
|
||||||
|
spatial filters from csp_kernel, shape (n_channels, n_filters)
|
||||||
|
M : ndarray
|
||||||
|
common template for all classes, shape (n_channel, n_samples)
|
||||||
|
X : ndarray
|
||||||
|
eeg test data, shape (n_trials, n_channels, n_samples)
|
||||||
|
n_components : int, optional
|
||||||
|
length of the spatial filters, first k components to use, by default 1
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
features: ndarray
|
||||||
|
features, shape (n_trials, n_components, n_samples)
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
n_components should less than half of the number of channels
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
1. instead of meaning of filtered signals in paper [1]_., we directly return filtered signals.
|
||||||
|
|
||||||
|
References
|
||||||
|
----------
|
||||||
|
.. [1] Liao, Xiang, et al. "Combining spatial filters for the classification of single-trial EEG in
|
||||||
|
a finger movement task." IEEE Transactions on Biomedical Engineering 54.5 (2007): 821-831.
|
||||||
|
"""
|
||||||
|
W, M, X = np.copy(W), np.copy(M), np.copy(X)
|
||||||
|
max_components = W.shape[1]
|
||||||
|
if n_components > max_components:
|
||||||
|
raise ValueError("n_components should less than the number of channels")
|
||||||
|
X = np.reshape(X, (-1, *X.shape[-2:]))
|
||||||
|
X = X - np.mean(X, axis=-1, keepdims=True)
|
||||||
|
# print('************: ',np.shape(W),np.shape(X),np.shape(M))
|
||||||
|
features = np.matmul(W[:, :n_components].T, X - M)
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
class DSP(BaseEstimator, TransformerMixin, ClassifierMixin):
|
||||||
|
"""
|
||||||
|
DSP: Discriminal Spatial Patterns
|
||||||
|
|
||||||
|
Author: Swolf <swolfforever@gmail.com>
|
||||||
|
|
||||||
|
Created on: 2021-1-07
|
||||||
|
|
||||||
|
Update log:
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
n_components : int
|
||||||
|
length of the spatial filter, first k components to use, by default 1
|
||||||
|
transform_method : str
|
||||||
|
method of template matching, by default ’corr‘ (pearson correlation coefficient)
|
||||||
|
classes_ : int
|
||||||
|
number of the EEG classes
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
n_components : int
|
||||||
|
length of the spatial filter, first k components to use, by default 1
|
||||||
|
transform_method : str
|
||||||
|
method of template matching, by default ’corr‘ (pearson correlation coefficient)
|
||||||
|
classes_ : int
|
||||||
|
number of the EEG classes
|
||||||
|
W_ : ndarray, shape(n_channels, n_filters)
|
||||||
|
Spatial filters, shape(n_channels, n_filters), in which n_channels = n_filters
|
||||||
|
D_ : ndarray, shape(n_filters, )
|
||||||
|
eigenvalues in descending order, shape(n_filters, )
|
||||||
|
M_ : ndarray, shape(n_channels, n_samples)
|
||||||
|
mean value of all classes and trials, i.e. common mode signals, shape(n_channels, n_samples)
|
||||||
|
A_ : ndarray, shape(n_channels, n_filters)
|
||||||
|
spatial patterns, shape(n_channels, n_filters)
|
||||||
|
templates_: ndarray, shape(n_classes, n_filters, n_samples)
|
||||||
|
templates of train data, shape(n_classes, n_filters, n_samples)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n_components: int = 1, transform_method: str = "corr"):
|
||||||
|
self.n_components = n_components
|
||||||
|
self.transform_method = transform_method
|
||||||
|
|
||||||
|
def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None):
|
||||||
|
"""
|
||||||
|
Import the train data to get a model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray
|
||||||
|
train data, shape(n_trials, n_channels, n_samples)
|
||||||
|
y : ndarray
|
||||||
|
labels of train data, shape (n_trials, )
|
||||||
|
Yf : ndarray
|
||||||
|
optional parameter
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
W_ : ndarray
|
||||||
|
spatial filters, shape (n_channels, n_filters), in which n_channels = n_filters
|
||||||
|
D_ : ndarray
|
||||||
|
eigenvalues in descending order, shape (n_filters, )
|
||||||
|
M_ : ndarray
|
||||||
|
template for all classes, shape (n_channel, n_samples)
|
||||||
|
A_ : ndarray
|
||||||
|
spatial patterns, shape (n_channels, n_filters)
|
||||||
|
templates_ : ndarray
|
||||||
|
templates of train data, shape (n_channels, n_filters, n_samples)
|
||||||
|
"""
|
||||||
|
X -= np.mean(X, axis=-1, keepdims=True)
|
||||||
|
self.classes_ = np.unique(y)
|
||||||
|
self.W_, self.D_, self.M_, self.A_ = xiang_dsp_kernel(X, y)
|
||||||
|
|
||||||
|
self.templates_ = np.stack(
|
||||||
|
[
|
||||||
|
np.mean(
|
||||||
|
xiang_dsp_feature(
|
||||||
|
self.W_, self.M_, X[y == label], n_components=self.W_.shape[1]
|
||||||
|
),
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
for label in self.classes_
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, X: ndarray):
|
||||||
|
"""
|
||||||
|
Import the test data to get features.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray
|
||||||
|
test data, shape(n_trials, n_channels, n_samples)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
feature : ndarray, shape(n_trials,n_classes)
|
||||||
|
correlation coefficients of templates of train data and features of test data, shape(n_trials, n_classes)
|
||||||
|
"""
|
||||||
|
n_components = self.n_components
|
||||||
|
X -= np.mean(X, axis=-1, keepdims=True)
|
||||||
|
features = xiang_dsp_feature(self.W_, self.M_, X, n_components=n_components)
|
||||||
|
if self.transform_method is None:
|
||||||
|
return features.reshape((features.shape[0], -1))
|
||||||
|
elif self.transform_method == "mean":
|
||||||
|
return np.mean(features, axis=-1)
|
||||||
|
elif self.transform_method == "corr":
|
||||||
|
return self._pearson_features(
|
||||||
|
features, self.templates_[:, :n_components, :]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("non-supported transform method")
|
||||||
|
|
||||||
|
def _pearson_features(self, X: ndarray, templates: ndarray):
|
||||||
|
"""
|
||||||
|
Calculate pearson correlation coefficient.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray
|
||||||
|
features of test data after spatial filters, shape(n_trials, n_components, n_samples)
|
||||||
|
templates : ndarray
|
||||||
|
templates of train data, shape(n_classes, n_components, n_samples)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
corr : ndarray
|
||||||
|
pearson correlation coefficient, shape(n_trials, n_classes)
|
||||||
|
"""
|
||||||
|
X = np.reshape(X, (-1, *X.shape[-2:]))
|
||||||
|
templates = np.reshape(templates, (-1, *templates.shape[-2:]))
|
||||||
|
X = X - np.mean(X, axis=-1, keepdims=True)
|
||||||
|
templates = templates - np.mean(templates, axis=-1, keepdims=True)
|
||||||
|
X = np.reshape(X, (X.shape[0], -1))
|
||||||
|
templates = np.reshape(templates, (templates.shape[0], -1))
|
||||||
|
istd_X = 1 / np.std(X, axis=-1, keepdims=True)
|
||||||
|
istd_templates = 1 / np.std(templates, axis=-1, keepdims=True)
|
||||||
|
corr = (X @ templates.T) / (templates.shape[1] - 1)
|
||||||
|
corr = istd_X * corr * istd_templates.T
|
||||||
|
return corr
|
||||||
|
|
||||||
|
def predict(self, X: ndarray):
|
||||||
|
"""
|
||||||
|
Import the templates and the test data to get prediction labels.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
X : ndarray
|
||||||
|
test data, shape(n_trials, n_channels, n_samples)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
labels : ndarray
|
||||||
|
prediction labels of test data, shape(n_trials,)
|
||||||
|
"""
|
||||||
|
feat = self.transform(X)
|
||||||
|
if self.transform_method == "corr":
|
||||||
|
labels = self.classes_[np.argmax(feat, axis=-1)]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
175
Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/tdca.py
Normal file
175
Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/tdca.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Authors: Swolf <swolfforever@gmail.com>
|
||||||
|
# Date: 2021/10/10
|
||||||
|
# License: MIT License
|
||||||
|
"""
|
||||||
|
Task Decomposition Component Analysis.
|
||||||
|
"""
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from scipy.linalg import qr
|
||||||
|
from scipy.stats import pearsonr
|
||||||
|
from numpy import ndarray
|
||||||
|
from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin
|
||||||
|
from typing import Optional, List
|
||||||
|
from SSMVEP.algorithm.base import FilterBankSSVEP
|
||||||
|
from SSMVEP.algorithm.dsp import xiang_dsp_kernel, xiang_dsp_feature
|
||||||
|
|
||||||
|
|
||||||
|
def proj_ref(Yf: ndarray):
|
||||||
|
Q, R = qr(Yf.T, mode="economic")
|
||||||
|
P = Q @ Q.T
|
||||||
|
return P
|
||||||
|
|
||||||
|
|
||||||
|
def aug_2(X: ndarray, n_samples: int, padding_len: int, P: ndarray, training: bool = True):
|
||||||
|
X = X.reshape((-1, *X.shape[-2:]))
|
||||||
|
n_trials, n_channels, n_points = X.shape
|
||||||
|
# if n_points < padding_len + n_samples:
|
||||||
|
# raise ValueError("the length of X should be larger than l+n_samples.")
|
||||||
|
aug_X = np.zeros((n_trials, (padding_len + 1) * n_channels, n_samples))
|
||||||
|
if training:
|
||||||
|
for i in range(padding_len + 1):
|
||||||
|
aug_X[:, i * n_channels : (i + 1) * n_channels, :] = X[
|
||||||
|
..., i : i + n_samples
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
for i in range(padding_len + 1):
|
||||||
|
aug_X[:, i * n_channels : (i + 1) * n_channels, : n_samples - i] = X[
|
||||||
|
..., i:n_samples
|
||||||
|
]
|
||||||
|
aug_Xp = aug_X @ P
|
||||||
|
aug_X = np.concatenate([aug_X, aug_Xp], axis=-1)
|
||||||
|
return aug_X
|
||||||
|
|
||||||
|
|
||||||
|
def tdca_feature(
|
||||||
|
X: ndarray,
|
||||||
|
templates: ndarray,
|
||||||
|
W: ndarray,
|
||||||
|
M: ndarray,
|
||||||
|
Ps: List[ndarray],
|
||||||
|
padding_len: int,
|
||||||
|
n_components: int = 1,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
|
rhos = []
|
||||||
|
for Xk, P in zip(templates, Ps):
|
||||||
|
a = xiang_dsp_feature(
|
||||||
|
W,
|
||||||
|
M,
|
||||||
|
aug_2(X, P.shape[0], padding_len, P, training=training),
|
||||||
|
n_components=n_components,
|
||||||
|
)
|
||||||
|
b = Xk[:n_components, :]
|
||||||
|
a = np.reshape(a, (-1))
|
||||||
|
b = np.reshape(b, (-1))
|
||||||
|
rhos.append(pearsonr(a, b)[0])
|
||||||
|
return rhos
|
||||||
|
|
||||||
|
|
||||||
|
class TDCA(BaseEstimator, TransformerMixin, ClassifierMixin):
|
||||||
|
def __init__(self, padding_len: int, n_components: int = 1):
|
||||||
|
self.padding_len = padding_len
|
||||||
|
self.n_components = n_components
|
||||||
|
|
||||||
|
def fit(self, X: ndarray, y: ndarray, Yf: ndarray):
|
||||||
|
X -= np.mean(X, axis=-1, keepdims=True)
|
||||||
|
self.classes_ = np.unique(y)
|
||||||
|
self.Ps_ = [proj_ref(Yf[i]) for i in range(len(self.classes_))]
|
||||||
|
# print(np.shape(self.Ps_))
|
||||||
|
|
||||||
|
aug_X_list, aug_Y_list = [], []
|
||||||
|
for i, label in enumerate(self.classes_):
|
||||||
|
aug_X_list.append(
|
||||||
|
aug_2(
|
||||||
|
X[y == label],
|
||||||
|
self.Ps_[i].shape[0],
|
||||||
|
self.padding_len,
|
||||||
|
self.Ps_[i],
|
||||||
|
training=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
aug_Y_list.append(y[y == label])
|
||||||
|
|
||||||
|
aug_X = np.concatenate(aug_X_list, axis=0)
|
||||||
|
aug_Y = np.concatenate(aug_Y_list, axis=0)
|
||||||
|
self.W_, _, self.M_, _ = xiang_dsp_kernel(aug_X, aug_Y)
|
||||||
|
|
||||||
|
self.templates_ = np.stack(
|
||||||
|
[
|
||||||
|
np.mean(
|
||||||
|
xiang_dsp_feature(
|
||||||
|
self.W_,
|
||||||
|
self.M_,
|
||||||
|
aug_X[aug_Y == label],
|
||||||
|
n_components=self.W_.shape[1],
|
||||||
|
),
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
for label in self.classes_
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, X: ndarray):
|
||||||
|
n_components = self.n_components
|
||||||
|
X -= np.mean(X, axis=-1, keepdims=True)
|
||||||
|
X = X.reshape((-1, *X.shape[-2:]))
|
||||||
|
rhos = [
|
||||||
|
tdca_feature(
|
||||||
|
tmp,
|
||||||
|
self.templates_,
|
||||||
|
self.W_,
|
||||||
|
self.M_,
|
||||||
|
self.Ps_,
|
||||||
|
self.padding_len,
|
||||||
|
n_components=n_components,
|
||||||
|
)
|
||||||
|
for tmp in X
|
||||||
|
]
|
||||||
|
rhos = np.stack(rhos)
|
||||||
|
return rhos
|
||||||
|
|
||||||
|
def predict(self, X: ndarray):
|
||||||
|
feat = self.transform(X)
|
||||||
|
labels = self.classes_[np.argmax(feat, axis=-1)]
|
||||||
|
return labels,feat
|
||||||
|
|
||||||
|
|
||||||
|
class FBTDCA(FilterBankSSVEP, ClassifierMixin):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
filterbank: List[ndarray],
|
||||||
|
padding_len: int,
|
||||||
|
n_components: int = 1,
|
||||||
|
filterweights: Optional[ndarray] = None,
|
||||||
|
n_jobs: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.padding_len = padding_len
|
||||||
|
self.n_components = n_components
|
||||||
|
self.filterweights = filterweights
|
||||||
|
self.n_jobs = n_jobs
|
||||||
|
super().__init__(
|
||||||
|
filterbank,
|
||||||
|
TDCA(padding_len, n_components=n_components),
|
||||||
|
filterweights=filterweights,
|
||||||
|
n_jobs=n_jobs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None): # type: ignore[override]
|
||||||
|
self.classes_ = np.unique(y)
|
||||||
|
super().fit(X, y, Yf=Yf)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def predict(self, X: ndarray):
|
||||||
|
features = self.transform(X)
|
||||||
|
if self.filterweights is None:
|
||||||
|
features = np.reshape(
|
||||||
|
features, (features.shape[0], len(self.filterbank), -1)
|
||||||
|
)
|
||||||
|
features = np.mean(features, axis=1)
|
||||||
|
labels = self.classes_[np.argmax(features, axis=-1)]
|
||||||
|
return labels,features
|
||||||
529
Debug_64ch_Decoder_Optimize/SSVEP/dwfbcca.py
Normal file
529
Debug_64ch_Decoder_Optimize/SSVEP/dwfbcca.py
Normal file
@@ -0,0 +1,529 @@
|
|||||||
|
|
||||||
|
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
from os import error
|
||||||
|
import numpy as np
|
||||||
|
import scipy
|
||||||
|
from numpy.linalg import linalg
|
||||||
|
from scipy.io import loadmat
|
||||||
|
from scipy.linalg import qr
|
||||||
|
from scipy.signal import filtfilt, lfilter
|
||||||
|
# from numpy.linalg import _umath_linalg
|
||||||
|
|
||||||
|
|
||||||
|
class FbccaDw:
|
||||||
|
def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method):
|
||||||
|
print('******************************************')
|
||||||
|
print('parameter list')
|
||||||
|
print('target:', num_target)
|
||||||
|
print('number of filter bank:', num_filter)
|
||||||
|
print('parameter:', parameter)
|
||||||
|
print('width:', width)
|
||||||
|
self.phase = 0
|
||||||
|
self.bandWidth = width
|
||||||
|
self.winNum = winNum
|
||||||
|
self.num_harms = num_harms
|
||||||
|
self.num_target = num_target
|
||||||
|
self.num_chans = num_chans
|
||||||
|
self.winTimeDelay = stimTime
|
||||||
|
self.fs = fs
|
||||||
|
self.winMaxSampleNum = self.winTimeDelay * self.winNum * self.fs
|
||||||
|
self.winDelayNum = round(self.winTimeDelay * self.fs)
|
||||||
|
self.num_fbs = num_filter
|
||||||
|
parameterValue = np.power(np.arange(1, self.num_fbs + 1), -(parameter[0])) + parameter[1]
|
||||||
|
self.weightValue = parameterValue / (sum(parameterValue))
|
||||||
|
|
||||||
|
self.dataUseLen = [0] * self.winNum
|
||||||
|
self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans])
|
||||||
|
self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans])
|
||||||
|
self.rhoNum = 2
|
||||||
|
self.notchZh = [0]
|
||||||
|
self.filterZf = [0] * self.num_fbs
|
||||||
|
self.north_b = []
|
||||||
|
self.north_a = []
|
||||||
|
self.filterBank_A = []
|
||||||
|
self.filterBank_B = []
|
||||||
|
self.winStep = 1
|
||||||
|
self.DW_cost_method = 'DW11' if method==1 else 'DW1'
|
||||||
|
|
||||||
|
'''
|
||||||
|
filterFrequenceBank:根据刺激频率生成的通带和阻带,用于滤波器组频带分解
|
||||||
|
'''
|
||||||
|
|
||||||
|
def filterFrequenceBank(self):
|
||||||
|
# 阻带的最高频率
|
||||||
|
lastFrequence = 90
|
||||||
|
freqBandWidth = self.bandWidth[1]
|
||||||
|
fStep = self.bandWidth[0]
|
||||||
|
bandFrequence = np.zeros((5, 4))
|
||||||
|
# 第二列频率带
|
||||||
|
band = list(range(freqBandWidth, lastFrequence, fStep))
|
||||||
|
band[:] = [x - 2 for x in band]
|
||||||
|
colValue = np.maximum(np.asmatrix(band), 1)
|
||||||
|
bandFrequence[:, 1] = colValue[0, 0:5]
|
||||||
|
# 第一列频率带
|
||||||
|
bandFrequence[:, 0] = np.maximum(bandFrequence[:, 1] - 4, 1)
|
||||||
|
# 第三列频率带
|
||||||
|
bandFrequence[:, 2] = lastFrequence + 2
|
||||||
|
# 第四列频率带
|
||||||
|
bandFrequence[:, 3] = bandFrequence[:, 2] + 10
|
||||||
|
# bandFrequence = np.array([[30,33,77,82],
|
||||||
|
# [62,68,77,82]])
|
||||||
|
for idx_fb in range(self.num_fbs):
|
||||||
|
Nq = self.fs / 2
|
||||||
|
Wp = [bandFrequence[idx_fb, 1] / Nq, bandFrequence[idx_fb, 2] / Nq]
|
||||||
|
Ws = [bandFrequence[idx_fb, 0] / Nq, bandFrequence[idx_fb, 3] / Nq]
|
||||||
|
[N, Wn] = scipy.signal.cheb1ord(Wp, Ws, 3,
|
||||||
|
40) # band pass filter StopBand=[Ws(1)~Ws(2)] PassBand=[Wp(1)~Wp(2)]
|
||||||
|
[B, A] = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass') # Wn passband edge frequency
|
||||||
|
self.filterBank_A.append(A)
|
||||||
|
self.filterBank_B.append(B)
|
||||||
|
# def filterFrequenceBank(self):
|
||||||
|
# # 阻带的最高频率
|
||||||
|
# lastFrequence = 90
|
||||||
|
# freqBandWidth = self.bandWidth[1]
|
||||||
|
# fStep = self.bandWidth[0]
|
||||||
|
# bandFrequence = np.zeros((5, 4))
|
||||||
|
# # 第二列频率带
|
||||||
|
# band = list(range(freqBandWidth, lastFrequence, fStep))
|
||||||
|
# band[:] = [x - 2 for x in band]
|
||||||
|
# colValue = np.maximum(np.asmatrix(band), 1)
|
||||||
|
# bandFrequence[:, 1] = colValue[0, 0:5]
|
||||||
|
# # 第一列频率带
|
||||||
|
# bandFrequence[:, 0] = np.maximum(bandFrequence[:, 1] - 4, 1)
|
||||||
|
# # 第三列频率带
|
||||||
|
# bandFrequence[:, 2] = lastFrequence + 2
|
||||||
|
# # 第四列频率带
|
||||||
|
# bandFrequence[:, 3] = bandFrequence[:, 2] + 10
|
||||||
|
# for idx_fb in range(self.num_fbs):
|
||||||
|
# Nq = self.fs / 2
|
||||||
|
# Wp = [bandFrequence[idx_fb, 1] / Nq, bandFrequence[idx_fb, 2] / Nq]
|
||||||
|
# Ws = [bandFrequence[idx_fb, 0] / Nq, bandFrequence[idx_fb, 3] / Nq]
|
||||||
|
# [N, Wn] = scipy.signal.cheb1ord(Wp, Ws, 3,
|
||||||
|
# 40) # band pass filter StopBand=[Ws(1)~Ws(2)] PassBand=[Wp(1)~Wp(2)]
|
||||||
|
# [B, A] = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass') # Wn passband edge frequency
|
||||||
|
# self.filterBank_A.append(A)
|
||||||
|
# self.filterBank_B.append(B)
|
||||||
|
|
||||||
|
'''
|
||||||
|
Filter bank analysis
|
||||||
|
Input:
|
||||||
|
eeg : Input eeg data (# of targets, # of channels, Data length [sample])
|
||||||
|
Output:
|
||||||
|
filterData : Generated filter Data
|
||||||
|
'''
|
||||||
|
|
||||||
|
def filterbank(self, eeg):
|
||||||
|
filterData = np.zeros((self.num_fbs, eeg.shape[1], eeg.shape[0]))
|
||||||
|
for filterIndex in range(self.num_fbs):
|
||||||
|
if np.all(self.filterZf[filterIndex] == 0):
|
||||||
|
zi = np.zeros(
|
||||||
|
[max(len(self.filterBank_A[filterIndex]), len(self.filterBank_B[filterIndex])) - 1, self.num_chans])
|
||||||
|
_, self.filterZf[filterIndex] = lfilter(self.filterBank_B[filterIndex], self.filterBank_A[filterIndex],
|
||||||
|
eeg, zi=zi.T)
|
||||||
|
Data = lfilter(self.filterBank_B[filterIndex], self.filterBank_A[filterIndex], eeg)
|
||||||
|
else:
|
||||||
|
Data, self.filterZf[filterIndex] = lfilter(self.filterBank_B[filterIndex],
|
||||||
|
self.filterBank_A[filterIndex], eeg,
|
||||||
|
zi=self.filterZf[filterIndex])
|
||||||
|
filterData[filterIndex, :, :] = Data.T
|
||||||
|
return filterData
|
||||||
|
|
||||||
|
'''
|
||||||
|
process
|
||||||
|
矩阵的白化和QR正则化分解,降低矩阵的维度,加速计算时间
|
||||||
|
Input:
|
||||||
|
data : 输入的二维脑电信号
|
||||||
|
Output:
|
||||||
|
Q : 降维后的矩阵
|
||||||
|
rankQ :正则矩阵的秩
|
||||||
|
'''
|
||||||
|
|
||||||
|
def process(self, data):
|
||||||
|
# 白化操作
|
||||||
|
meanValue = np.asmatrix(data.mean(axis=1))
|
||||||
|
meanData = np.repeat(meanValue, data.shape[1], axis=1)
|
||||||
|
whiteTemp = data - meanData
|
||||||
|
# QR 分解
|
||||||
|
rankWhiteTemp = whiteTemp.shape[0]
|
||||||
|
whiteTemp = np.transpose(whiteTemp)
|
||||||
|
Q, R = qr(whiteTemp.A, mode='economic')
|
||||||
|
# 计算矩阵的秩
|
||||||
|
rankQ = linalg.matrix_rank(R)
|
||||||
|
if rankQ == 0:
|
||||||
|
raise ValueError('stats:canoncorr:badData')
|
||||||
|
elif rankQ <= rankWhiteTemp:
|
||||||
|
# warnings.warn('stats:canoncorr:NotFullRank')
|
||||||
|
Q = Q[:, 0:rankQ]
|
||||||
|
return Q, rankQ
|
||||||
|
|
||||||
|
'''
|
||||||
|
reference
|
||||||
|
Input:
|
||||||
|
listFreqs : 刺激频率列表
|
||||||
|
numberSmples : 用于分类的脑电信号采样点个数
|
||||||
|
num_harms : 谐波数
|
||||||
|
Output:
|
||||||
|
y_ref : 生成的参考信号 (刺激目标数, 2 * 谐波数, 数据长度/采样点数)
|
||||||
|
'''
|
||||||
|
|
||||||
|
def reference(self, listFreqs, numberSmples, num_harms):
|
||||||
|
numberFrequence = len(listFreqs)
|
||||||
|
timeIndex = np.arange(1, numberSmples + 1) / self.fs # time index
|
||||||
|
referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples))
|
||||||
|
for frequenceIndex in range(numberFrequence):
|
||||||
|
temp = []
|
||||||
|
for harmIndex in range(1, num_harms + 1):
|
||||||
|
stimFrequence = listFreqs[frequenceIndex] # in HZ
|
||||||
|
# Sin and Cos
|
||||||
|
temp.extend([np.sin(2 * np.pi * timeIndex * harmIndex * stimFrequence),
|
||||||
|
np.cos(2 * np.pi * timeIndex * harmIndex * stimFrequence)])
|
||||||
|
referenceTemp = np.asmatrix(temp)
|
||||||
|
# 白化操作和QR分解
|
||||||
|
Q, rankQ = self.process(referenceTemp)
|
||||||
|
referenceData[frequenceIndex] = np.transpose(Q)
|
||||||
|
return referenceData
|
||||||
|
|
||||||
|
'''
|
||||||
|
setNorthFilterPara
|
||||||
|
陷波器的参数初始化
|
||||||
|
self.north_b, self.north_a : 陷波器的参数设计
|
||||||
|
'''
|
||||||
|
|
||||||
|
def setNotchFilterPara(self):
|
||||||
|
# notchFilterNum = 3
|
||||||
|
# northFreq = 50
|
||||||
|
# bwDen = 35
|
||||||
|
# wo = northFreq / (self.fs / 2)
|
||||||
|
# bw = wo / bwDen
|
||||||
|
# self.north_b, self.north_a = iirnotch(wo, Q=35) # self.north_b, self.north_a = iircomb(northFreq, bwDen, 'notch')
|
||||||
|
# # n倍零极点,相当于重复滤波n次
|
||||||
|
# if notchFilterNum > 1:
|
||||||
|
# z, p, k = tf2zpk(self.north_b, self.north_a)
|
||||||
|
# zNew = np.repeat(z, notchFilterNum, axis=0)
|
||||||
|
# zNew[1], zNew[4] = zNew[4], zNew[1]
|
||||||
|
# pNew = np.repeat(p, notchFilterNum, axis=0)
|
||||||
|
# pNew[1], pNew[4] = pNew[4], pNew[1]
|
||||||
|
# kNew = np.power(k, notchFilterNum)
|
||||||
|
# self.north_b, self.north_a = zpk2tf(zNew, pNew, kNew)
|
||||||
|
self.north_b = [0.94801603944125245604368501517456, -1.7577184027642647201616910024313,
|
||||||
|
3.9303778338832491279219993884908, -3.7392330345967859095424046245171,
|
||||||
|
3.9303778338832482397435796883656, -1.7577184027642638319832713023061,
|
||||||
|
0.94801603944125156786526531504933]
|
||||||
|
|
||||||
|
self.north_a = [1, -1.8214007435820627200939725298667, 4.0000101767406484043476666556671,
|
||||||
|
-3.7380998614928691026193519064691, 3.8589119784285759173769747576443,
|
||||||
|
-1.6951692350503837491970671180752, 0.89786559147978006745205448169145]
|
||||||
|
|
||||||
|
'''
|
||||||
|
northFilter
|
||||||
|
进行信号的50hz陷波处理
|
||||||
|
Input:
|
||||||
|
data :输入脑电数据
|
||||||
|
Output:
|
||||||
|
dataFiltered : 陷波处理后的脑电数据
|
||||||
|
'''
|
||||||
|
|
||||||
|
def northFilter(self, data):
|
||||||
|
try:
|
||||||
|
if np.all(self.notchZh[0] == 0):
|
||||||
|
zi = np.zeros([max(len(self.north_a), len(self.north_b)) - 1, self.num_chans])
|
||||||
|
_, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=zi.T)
|
||||||
|
dataFiltered = lfilter(self.north_b, self.north_a, data)
|
||||||
|
else:
|
||||||
|
dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0])
|
||||||
|
return np.asmatrix(dataFiltered)
|
||||||
|
except Exception:
|
||||||
|
print(Exception)
|
||||||
|
|
||||||
|
'''
|
||||||
|
getDataQ
|
||||||
|
Inputs:
|
||||||
|
data:脑电数据
|
||||||
|
Rbuffer:待更新的中间系数
|
||||||
|
Output:
|
||||||
|
Qs1 : 脑电特征1
|
||||||
|
Qs2 : 脑电特征2
|
||||||
|
Rbuffer : 单窗口更新后的系数
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
def getDataQ(self, data, Rbuffer):
|
||||||
|
Qs1 = [0] * self.num_fbs
|
||||||
|
Qs2 = [0] * self.num_fbs
|
||||||
|
nulldata = np.zeros([self.num_chans, self.num_chans])
|
||||||
|
Rnum = self.num_chans
|
||||||
|
for fb_num in range(self.num_fbs):
|
||||||
|
fb_data = np.squeeze(data[fb_num, :, :])
|
||||||
|
if np.all(Rbuffer[fb_num] == 0):
|
||||||
|
whiteTemp = fb_data
|
||||||
|
Q, R = qr(whiteTemp, mode='economic')
|
||||||
|
Qs1[fb_num] = nulldata
|
||||||
|
Qs2[fb_num] = Q
|
||||||
|
Rbuffer[fb_num] = R
|
||||||
|
else:
|
||||||
|
whiteTemp = np.concatenate((Rbuffer[fb_num], fb_data), axis=0)
|
||||||
|
Q, R = qr(whiteTemp, mode='economic')
|
||||||
|
Qs1[fb_num] = Q[0:Rnum, :]
|
||||||
|
Qs2[fb_num] = Q[Rnum:, :]
|
||||||
|
Rbuffer[fb_num] = R
|
||||||
|
return Qs1, Qs2, Rbuffer
|
||||||
|
|
||||||
|
'''
|
||||||
|
myCCA:根据脑电特征和参考信号计算相关系数
|
||||||
|
Inputs:
|
||||||
|
dataQ:脑电特征
|
||||||
|
Qc2y:参考信号
|
||||||
|
d : 相关系数取值数
|
||||||
|
Output:
|
||||||
|
rho : 相关系数
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
def myCCA(self, dataQ, Qc2y, d):
|
||||||
|
if len(Qc2y) == 0:
|
||||||
|
Cov = dataQ
|
||||||
|
else:
|
||||||
|
Cov = np.dot(Qc2y, dataQ)
|
||||||
|
# U, S, V = scipy.linalg.svd(Cov, 0)
|
||||||
|
# rho = np.minimum(np.maximum(np.diag(S[0: d]).T, 0), 1)
|
||||||
|
# gufunc = _umath_linalg.svd_n
|
||||||
|
# rho = gufunc(Cov)
|
||||||
|
rho = np.linalg.svd(Cov, compute_uv=False)
|
||||||
|
return rho[0:d]
|
||||||
|
|
||||||
|
'''
|
||||||
|
weightCCA:计算分类标签
|
||||||
|
Inputs:
|
||||||
|
Qs1:脑电特征1
|
||||||
|
Qs2:脑电特征2
|
||||||
|
ref : 正余弦参考信号
|
||||||
|
Cxy : 协方差中间参数
|
||||||
|
Output:
|
||||||
|
result : 分类标签
|
||||||
|
rho : 相关系数
|
||||||
|
Cxy : 更新后的协方差中间参数
|
||||||
|
'''
|
||||||
|
|
||||||
|
def weightCCA(self, Qs1, Qs2, ref, Cxy):
|
||||||
|
rMax = np.zeros([self.num_fbs, self.num_target])
|
||||||
|
for fi in range(self.num_fbs):
|
||||||
|
for si in range(self.num_target):
|
||||||
|
Qc2y = np.squeeze(ref[si, :, :])
|
||||||
|
# 更新协方差矩阵
|
||||||
|
if np.all(Cxy[fi][si] == 0):
|
||||||
|
Cxy[fi, si] = np.dot(Qc2y, Qs2[fi])
|
||||||
|
else:
|
||||||
|
Cxy[fi, si] = np.dot(Cxy[fi, si], Qs1[fi]) + np.dot(Qc2y, Qs2[fi])
|
||||||
|
r = self.myCCA(Cxy[fi, si], [], self.rhoNum)
|
||||||
|
rMax[fi, si] = r[0]
|
||||||
|
rho = np.dot(self.weightValue, np.power(rMax, 2)) # weighted sum of r from all different filter banks' result
|
||||||
|
result = np.argmax(rho)
|
||||||
|
return result, rho, Cxy
|
||||||
|
|
||||||
|
'''
|
||||||
|
costF:损失函数,根据计算的相关系数,生成决策值,用于和阈值进行比较
|
||||||
|
Inputs:
|
||||||
|
rho:相关系数
|
||||||
|
method:相关系数计算参数
|
||||||
|
C : 参数
|
||||||
|
Output:
|
||||||
|
decideValue : 决策阈值
|
||||||
|
'''
|
||||||
|
|
||||||
|
def costF(self, rho, method, C):
|
||||||
|
rho = rho.tolist()
|
||||||
|
rho.sort(reverse=True)
|
||||||
|
if method == 'DW1':
|
||||||
|
decideValue = (rho[0] - rho[1]) / (sum(rho) - self.num_target * np.log(sum(np.exp(rho))))
|
||||||
|
elif method == 'DW11':
|
||||||
|
decideValue = -(rho[0] - rho[1])
|
||||||
|
elif method == 'DW2':
|
||||||
|
decideValue = (rho[0] - C) / (rho[1] - rho[0])
|
||||||
|
return decideValue
|
||||||
|
|
||||||
|
'''
|
||||||
|
onlineInit:将窗口长度,相位值、中间参数初始化
|
||||||
|
'''
|
||||||
|
|
||||||
|
def onlineInit(self):
|
||||||
|
self.dataUseLen = [0] * self.winNum
|
||||||
|
self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans])
|
||||||
|
self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans])
|
||||||
|
self.phase = 0
|
||||||
|
|
||||||
|
'''
|
||||||
|
filterInit:重置陷波器和滤波器的滤波参数
|
||||||
|
'''
|
||||||
|
|
||||||
|
def filterInit(self):
|
||||||
|
self.notchZh = [0]
|
||||||
|
self.filterZf = [0] * self.num_fbs
|
||||||
|
|
||||||
|
'''
|
||||||
|
warmFilter:预热滤波器,去刺激前的4S数据对陷波器和滤波器参数进行初始化迭代,去除过渡带的效果
|
||||||
|
Inputs:
|
||||||
|
data:预处理脑电数据
|
||||||
|
'''
|
||||||
|
|
||||||
|
def warmFilter(self, data):
|
||||||
|
# 降采样在采集前完成
|
||||||
|
temp = self.preprocessFilter(data) #预热陷波滤波器
|
||||||
|
# 滤波器组频带分解
|
||||||
|
filterData = self.filterbank(temp) #预热滤波器组
|
||||||
|
|
||||||
|
'''
|
||||||
|
myDownSample:数据降采样
|
||||||
|
Inputs:
|
||||||
|
data:脑电数据
|
||||||
|
n:降采样的倍数
|
||||||
|
Output:
|
||||||
|
eegData2 : 降采样后的数据
|
||||||
|
'''
|
||||||
|
|
||||||
|
def myDownSample(self, data, n):
|
||||||
|
data = data[:8, self.phase:]
|
||||||
|
dataNum = data.shape[1]
|
||||||
|
remainNum = (dataNum - 1) % n
|
||||||
|
self.phase = n - 1 - remainNum
|
||||||
|
dataDowmSample = []
|
||||||
|
for value in data:
|
||||||
|
value = value[0:value.size:n]
|
||||||
|
dataDowmSample.append(value)
|
||||||
|
eegData2 = np.array(dataDowmSample).reshape([8, int(np.round(data.shape[1] / n))])
|
||||||
|
return eegData2
|
||||||
|
|
||||||
|
'''
|
||||||
|
preprocessFilter:预处理,调用函数降采样和陷波处理
|
||||||
|
Inputs:
|
||||||
|
data:脑电数据
|
||||||
|
Output:
|
||||||
|
filterData : 降采样和陷波后的数据
|
||||||
|
'''
|
||||||
|
|
||||||
|
def preprocessFilter(self, data):
|
||||||
|
# data = self.myDownSample(data, 4)
|
||||||
|
# filterData = self.northFilter(data[:8, :])
|
||||||
|
filterData = self.northFilter(data[:, :])
|
||||||
|
return filterData
|
||||||
|
|
||||||
|
'''
|
||||||
|
fbccaDWMW:分类函数,对输入的脑电信号进行识别,输出决策标签
|
||||||
|
Inputs:
|
||||||
|
testdata:脑电数据
|
||||||
|
referenceData:参考信号
|
||||||
|
tValue:出决策阈值
|
||||||
|
Output:
|
||||||
|
res : 决策标签
|
||||||
|
rho_new:相关系数
|
||||||
|
minEps:得到的决策阈值
|
||||||
|
'''
|
||||||
|
|
||||||
|
# 动态窗算法主函数
|
||||||
|
def fbccaDWMW(self, testdata, referenceData, tValue,calculateCount):
|
||||||
|
t1 = time.time()
|
||||||
|
# try:
|
||||||
|
# 初始参数
|
||||||
|
res = -1
|
||||||
|
minEps = float("inf")
|
||||||
|
# 降采样和陷波器处理
|
||||||
|
northData = self.preprocessFilter(testdata)
|
||||||
|
newSampleNum = northData.shape[1]
|
||||||
|
# 数据大于延迟长度,则无法根据后面的规则更新窗口
|
||||||
|
if newSampleNum > self.winDelayNum:
|
||||||
|
error('need add window delay time')
|
||||||
|
|
||||||
|
# 防止秩小于导联数
|
||||||
|
if newSampleNum < self.num_chans:
|
||||||
|
warnings.warn('data shape is [%d %d] need more data' % (newSampleNum, northData.shape[0]))
|
||||||
|
# 滤波器组频带分解
|
||||||
|
filterData = self.filterbank(northData)
|
||||||
|
winMinTime = 0
|
||||||
|
# 计算每个窗口的结果
|
||||||
|
for wi in range(0, self.winNum, self.winStep):
|
||||||
|
# print('dataUseLen:',wi,calculateCount, self.dataUseLen)
|
||||||
|
if wi == 0:
|
||||||
|
self.dataUseLen[wi] = self.dataUseLen[wi] + newSampleNum
|
||||||
|
else:
|
||||||
|
if self.dataUseLen[wi] == 0:
|
||||||
|
# 判断当前窗是否为新的窗口(因为每一次新的窗口进来时,都会使上一个窗口datauseLen>50)
|
||||||
|
if self.dataUseLen[wi - self.winStep] > self.winDelayNum*self.winStep:
|
||||||
|
self.dataUseLen[wi] = newSampleNum
|
||||||
|
else:
|
||||||
|
# print('中断: ',wi,calculateCount)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.dataUseLen[wi] = self.dataUseLen[wi] + newSampleNum
|
||||||
|
|
||||||
|
if self.dataUseLen[wi] > self.winMaxSampleNum:
|
||||||
|
self.dataUseLen[wi] = newSampleNum
|
||||||
|
self.Rbuffer[wi, :, :, :] = 0
|
||||||
|
self.Cxy[wi, :, :, :, :] = 0
|
||||||
|
Qs1, Qs2, self.Rbuffer[wi, :, :, :] = self.getDataQ(filterData, self.Rbuffer[wi, :, :, :])
|
||||||
|
si = self.dataUseLen[wi] - newSampleNum
|
||||||
|
ei = self.dataUseLen[wi]
|
||||||
|
ref = referenceData[:, :, si:ei]
|
||||||
|
# 更新协方差
|
||||||
|
predLabel_new, rho_new, self.Cxy[wi, :, :, :] = self.weightCCA(Qs1, Qs2, ref, self.Cxy[wi, :, :, :])
|
||||||
|
# 增加限制,数据长度不能太短
|
||||||
|
if self.dataUseLen[wi] > winMinTime * self.fs:
|
||||||
|
epsilon = self.costF(rho_new, self.DW_cost_method, C=0)
|
||||||
|
if epsilon < minEps:
|
||||||
|
minEps = epsilon
|
||||||
|
predLabel = predLabel_new
|
||||||
|
xxx = rho_new
|
||||||
|
if minEps < tValue:
|
||||||
|
res = predLabel
|
||||||
|
|
||||||
|
if time.time() - t1 > 0.2 and self.winStep < 16:
|
||||||
|
self.winStep = self.winStep * 2
|
||||||
|
# print(self.winStep, " ", time.time() - t1)
|
||||||
|
# if res != -1:
|
||||||
|
# print('--------------------- ',res,xxx,' --------------------------')
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# The number of sub-bands in filter bank analysis
|
||||||
|
fs = 250
|
||||||
|
num_chans = 8
|
||||||
|
num_target = 40
|
||||||
|
num_filterBank = 3
|
||||||
|
num_harm = 5
|
||||||
|
stimTime = 0.2 # 多窗口窗长
|
||||||
|
winNum = 50 # 窗口的个数
|
||||||
|
trials = 1
|
||||||
|
step = 50
|
||||||
|
res = -1
|
||||||
|
list_freqs = [8.0, 8.2, 8.4, 8.6, 8.8, 9.0, 9.2, 9.4, 9.6, 9.8, 10.0, 10.2, 10.4, 10.6, 10.8, 11.0, 11.2, 11.4,
|
||||||
|
11.6, 11.8, 12.0, 12.2, 12.4, 12.6, 12.8, 13., 13.2, 13.4, 13.6, 13.8, 14., 14.2, 14.4, 14.6, 14.8,
|
||||||
|
15.0, 15.2, 15.4, 15.6, 15.8]
|
||||||
|
# 初始化对象
|
||||||
|
dw = FbccaDw(fs, num_target, num_chans, num_filterBank, num_harm, stimTime, [1.0, 0.3], [8, 8], winNum)
|
||||||
|
# frequenceband
|
||||||
|
dw.filterFrequenceBank()
|
||||||
|
referenceData = dw.reference(list_freqs, int(winNum * stimTime * fs), num_harm)
|
||||||
|
dw.setNotchFilterPara()
|
||||||
|
|
||||||
|
prelabels = np.zeros((1, 40))
|
||||||
|
coefficient = np.zeros([1, 1])
|
||||||
|
path = "D:\\工作相关\\项目代码\\SDK封装与测试\\八神BCI脑电信号SDK20230105\\双鹰SDK对比\\offline_data250\\"
|
||||||
|
for index in range(1, trials + 1):
|
||||||
|
D = loadmat(os.path.join(path + str(1) + '-warmData.mat'))
|
||||||
|
warmData = D['warmData']
|
||||||
|
dw.onlineInit()
|
||||||
|
dw.filterInit()
|
||||||
|
dw.warmFilter(warmData.T)
|
||||||
|
|
||||||
|
tagget_i = 0
|
||||||
|
for tagget_i in range(1, step + 1):
|
||||||
|
D = loadmat(os.path.join(path + str(1) + '-' + str(tagget_i) + '.mat'))
|
||||||
|
dataSlice = D['dataTemp']
|
||||||
|
res = dw.fbccaDWMW(dataSlice.T, referenceData, tValue=-0.2)
|
||||||
|
if res != -1:
|
||||||
|
break
|
||||||
|
prelabels[0, index - 1] = res + 1
|
||||||
|
print(index, '--', res + 1," 计算轮数", tagget_i)
|
||||||
851
Debug_64ch_Decoder_Optimize/Tools/plot_MI_EEG.py
Normal file
851
Debug_64ch_Decoder_Optimize/Tools/plot_MI_EEG.py
Normal file
@@ -0,0 +1,851 @@
|
|||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.patches import Ellipse
|
||||||
|
import matplotlib.cm as cm
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
from scipy.spatial import Delaunay
|
||||||
|
from scipy.interpolate import Rbf
|
||||||
|
from scipy.signal import welch
|
||||||
|
from scipy.stats import sem
|
||||||
|
from scipy.signal import butter, filtfilt, hilbert
|
||||||
|
import base64
|
||||||
|
|
||||||
|
# 位置坐标
|
||||||
|
def read_ch_pos(file_path=r'xy_64.xlsx'):
|
||||||
|
"""
|
||||||
|
将电极位置信息转换为Dict
|
||||||
|
|
||||||
|
参数:
|
||||||
|
file_path: 电极位置存储文件, 必须包含'channel', 'x', 'y', 'z'列
|
||||||
|
|
||||||
|
"""
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
file_path = os.path.join(script_dir,file_path )
|
||||||
|
df = pd.read_excel(file_path)
|
||||||
|
# 确保列名正确
|
||||||
|
if not all(col in df.columns for col in ['channel', 'x', 'y', 'z']):
|
||||||
|
raise ValueError("DataFrame必须包含'channel', 'x', 'y', 'z'列")
|
||||||
|
# 创建电极位置字典
|
||||||
|
ch_pos = {}
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
ch_pos[row['channel']] = [row['x'], row['y'], row['z']]
|
||||||
|
return ch_pos
|
||||||
|
# 头部轮廓
|
||||||
|
def draw_head(ax, center=(0, 0), radius=1.0, zorder=4):
|
||||||
|
"""
|
||||||
|
绘制头部轮廓、鼻子和耳朵。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- ax : matplotlib Axes 对象
|
||||||
|
- center : (x, y) 头中心坐标
|
||||||
|
- radius : float, 头半径
|
||||||
|
- zorder : 绘制层级
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 头圆
|
||||||
|
head = plt.Circle(center, radius, fill=False, color='k', linewidth=1, zorder=zorder)
|
||||||
|
ax.add_artist(head)
|
||||||
|
|
||||||
|
# 鼻子(参考 _make_head_outlines)
|
||||||
|
dx = np.exp(np.arccos(np.deg2rad(12)) * 1j)
|
||||||
|
dx_real, dx_imag = dx.real, dx.imag
|
||||||
|
nose_x = np.array([-dx_real, 0, dx_real]) * radius + center[0]
|
||||||
|
nose_y = np.array([dx_imag, 1.15, dx_imag]) * radius + center[1]
|
||||||
|
ax.plot(nose_x, nose_y, color='k', linewidth=1, zorder=zorder)
|
||||||
|
|
||||||
|
# 耳朵(参考 _make_head_outlines 手动标定)
|
||||||
|
ear_radius = radius * 0.12
|
||||||
|
ear_scale = radius * 2 # 根据半径缩放
|
||||||
|
theta = np.linspace(np.pi / 2, 3 * np.pi / 2, 30)
|
||||||
|
|
||||||
|
# 左耳
|
||||||
|
left_ear_x_array = np.array([0.497, 0.510, 0.518, 0.5299, 0.5419,
|
||||||
|
0.54, 0.547, 0.532, 0.510, 0.489]) * ear_scale
|
||||||
|
left_ear_y_array = np.array([0.0555, 0.0775, 0.0783, 0.0746, 0.0555,
|
||||||
|
-0.0055, -0.0932, -0.1313, -0.1384, -0.1199]) * ear_scale + center[1]
|
||||||
|
ax.plot(center[0] - left_ear_x_array, left_ear_y_array, color='k', linewidth=1, zorder=zorder)
|
||||||
|
|
||||||
|
# 右耳
|
||||||
|
right_ear_x_array = np.array([0.497, 0.510, 0.518, 0.5299, 0.5419,
|
||||||
|
0.54, 0.547, 0.532, 0.510, 0.489]) * ear_scale
|
||||||
|
right_ear_y_array = np.array([0.0555, 0.0775, 0.0783, 0.0746, 0.0555,
|
||||||
|
-0.0055, -0.0932, -0.1313, -0.1384, -0.1199]) * ear_scale + center[1]
|
||||||
|
ax.plot(center[0] + right_ear_x_array, right_ear_y_array, color='k', linewidth=1, zorder=zorder)
|
||||||
|
# 地形图 插值
|
||||||
|
def rbf_D_interpolate(xy, v, center=(0, 0), radius=1.1, grid_res=300,
|
||||||
|
n_extra=32, rbf_func='multiquadric', smooth=0,
|
||||||
|
border='mean', border_scale=1.0001, n_ngb=4):
|
||||||
|
"""
|
||||||
|
使用 RBF + Delaunay 邻域均值方式生成平滑的 EEG topomap 插值表面。
|
||||||
|
|
||||||
|
参数
|
||||||
|
----
|
||||||
|
xy : (N,2) array
|
||||||
|
电极二维坐标(与绘图坐标系一致)
|
||||||
|
v : (N,) array
|
||||||
|
每个电极对应的值(e.g. PSD)
|
||||||
|
center : tuple (x0, y0)
|
||||||
|
头部圆心(默认 (0,0))
|
||||||
|
radius : float
|
||||||
|
头部半径(用于生成边界点与网格范围)
|
||||||
|
grid_res : int
|
||||||
|
网格分辨率(每轴点数)
|
||||||
|
n_extra : int
|
||||||
|
边界虚拟点数量
|
||||||
|
rbf_func : str
|
||||||
|
RBF 内核名称('multiquadric','thin_plate','gaussian',...)
|
||||||
|
smooth : float
|
||||||
|
RBF 平滑参数
|
||||||
|
border : 'mean' or float
|
||||||
|
若 'mean':边界点用邻近真实通道均值赋值(推荐)
|
||||||
|
若 float:边界点赋相同常数值
|
||||||
|
border_scale : float
|
||||||
|
边界点半径相对 radius 的缩放(略微 >1 用以外推)
|
||||||
|
n_ngb : int
|
||||||
|
为每个边界点取值时使用的最近真实通道数
|
||||||
|
|
||||||
|
返回
|
||||||
|
----
|
||||||
|
zi : (grid_res, grid_res) ndarray
|
||||||
|
插值结果(与 grid_x, grid_y 对齐)
|
||||||
|
grid_x, grid_y : ndarrays
|
||||||
|
meshgrid(由 np.meshgrid 生成)
|
||||||
|
"""
|
||||||
|
xy = np.asarray(xy)
|
||||||
|
v = np.asarray(v)
|
||||||
|
if xy.ndim != 2 or xy.shape[1] != 2:
|
||||||
|
raise ValueError("xy must be shape (n_channels, 2)")
|
||||||
|
|
||||||
|
n_points = xy.shape[0]
|
||||||
|
|
||||||
|
# --- 1. 生成边界虚拟点(圆周) ---
|
||||||
|
theta = np.linspace(0.0, 2 * np.pi, n_extra, endpoint=False)
|
||||||
|
r_border = radius * border_scale
|
||||||
|
border_xy = np.column_stack([center[0] + r_border * np.cos(theta),
|
||||||
|
center[1] + r_border * np.sin(theta)])
|
||||||
|
|
||||||
|
# --- 2. 用 Delaunay 建图以便找到邻居(对边界点取邻居均值) ---
|
||||||
|
# 合并用于三角化的位置(真实点 + 边界点)
|
||||||
|
tri_xy = np.vstack([xy, border_xy])
|
||||||
|
tri = Delaunay(tri_xy)
|
||||||
|
|
||||||
|
# --- 3. 为边界点赋值 ---
|
||||||
|
if isinstance(border, str) and border == 'mean':
|
||||||
|
# 使用 Delaunay 的 vertex_neighbor_vertices 索引
|
||||||
|
# 注意:tri.vertex_neighbor_vertices 给出 vertices -> neighbor indptr
|
||||||
|
indices, indptr = tri.vertex_neighbor_vertices
|
||||||
|
v_extra = np.zeros(n_extra)
|
||||||
|
used = np.zeros(n_extra, dtype=bool)
|
||||||
|
# 边界点在 tri_xy 中的索引范围
|
||||||
|
rng = range(n_points, n_points + n_extra)
|
||||||
|
for idx, extra_idx in enumerate(rng):
|
||||||
|
neigh = indptr[indices[extra_idx]:indices[extra_idx + 1]]
|
||||||
|
# 仅保留原始点索引(小于 n_points)
|
||||||
|
neigh = neigh[neigh < n_points]
|
||||||
|
if neigh.size > 0:
|
||||||
|
used[idx] = True
|
||||||
|
# 使用最近 n_ngb 个邻居的均值(若邻居多则取最近的 n_ngb)
|
||||||
|
if neigh.size > n_ngb:
|
||||||
|
# 计算距离并选取最近 n_ngb
|
||||||
|
d = np.linalg.norm(xy[neigh] - tri_xy[extra_idx], axis=1)
|
||||||
|
order = np.argsort(d)[:n_ngb]
|
||||||
|
sel = neigh[order]
|
||||||
|
else:
|
||||||
|
sel = neigh
|
||||||
|
v_extra[idx] = v[sel].mean()
|
||||||
|
if not used.all() and used.any():
|
||||||
|
v_extra[~used] = np.mean(v_extra[used])
|
||||||
|
elif not used.any():
|
||||||
|
v_extra[:] = np.mean(v)
|
||||||
|
else:
|
||||||
|
# border 是数值
|
||||||
|
v_extra = np.full(n_extra, float(border))
|
||||||
|
|
||||||
|
# --- 4. 合并所有已知点并构建 RBF ---
|
||||||
|
all_xy = np.vstack([xy, border_xy])
|
||||||
|
all_v = np.concatenate([v, v_extra])
|
||||||
|
|
||||||
|
rbf = Rbf(all_xy[:, 0], all_xy[:, 1], all_v,
|
||||||
|
function=rbf_func, smooth=smooth)
|
||||||
|
|
||||||
|
# --- 5. 生成网格(使用 meshgrid,与主函数保持一致) ---
|
||||||
|
xmin, xmax = center[0] - radius, center[0] + radius
|
||||||
|
ymin, ymax = center[1] - radius, center[1] + radius
|
||||||
|
xi = np.linspace(xmin, xmax, grid_res)
|
||||||
|
yi = np.linspace(ymin, ymax, grid_res)
|
||||||
|
grid_x, grid_y = np.meshgrid(xi, yi) # meshgrid 与 imshow 对齐
|
||||||
|
|
||||||
|
# --- 6. 评估 RBF,返回与 grid 对齐的 zi ---
|
||||||
|
zi = rbf(grid_x, grid_y)
|
||||||
|
|
||||||
|
return zi, grid_x, grid_y
|
||||||
|
# plv矩阵计算
|
||||||
|
def calculate_plv(data):
|
||||||
|
"""
|
||||||
|
计算相位锁定值(PLV)矩阵。
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : ndarray, shape (num_channels, num_samples)
|
||||||
|
EEG 数据,通道数为 num_channels,样本数为 num_samples。
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
plv_matrix : ndarray, shape (num_channels, num_channels)
|
||||||
|
计算得到的 PLV 矩阵,表示各通道间的相位同步。
|
||||||
|
"""
|
||||||
|
num_channels, num_samples = data.shape
|
||||||
|
plv_matrix = np.zeros((num_channels, num_channels))
|
||||||
|
|
||||||
|
# 计算每个通道的解析信号
|
||||||
|
analytic_signals = np.apply_along_axis(hilbert, axis=1, arr=data)
|
||||||
|
|
||||||
|
for i in range(num_channels):
|
||||||
|
for j in range(i + 1, num_channels): # 只计算上三角矩阵,避免重复计算
|
||||||
|
# 计算 phase difference
|
||||||
|
phase_diff = np.angle(analytic_signals[i] * np.conj(analytic_signals[j]))
|
||||||
|
plv = np.abs(np.mean(np.exp(1j * phase_diff)))
|
||||||
|
plv_matrix[i, j] = plv
|
||||||
|
plv_matrix[j, i] = plv # 对称矩阵
|
||||||
|
|
||||||
|
return plv_matrix
|
||||||
|
# 矩阵阈值化
|
||||||
|
def threshold_proportional(adj, prop=0.2):
|
||||||
|
"""
|
||||||
|
Apply a proportional threshold to retain the top proportion of strongest edges.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
adj : ndarray, shape (n_channels, n_channels)
|
||||||
|
Adjacency matrix to threshold.
|
||||||
|
prop : float
|
||||||
|
Proportion of edges to retain (0 < prop <= 1).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bin_adj : ndarray, shape (n_channels, n_channels)
|
||||||
|
Binary adjacency matrix after thresholding.
|
||||||
|
"""
|
||||||
|
n = adj.shape[0]
|
||||||
|
triu_idx = np.triu_indices(n, k=1)
|
||||||
|
weights = adj[triu_idx]
|
||||||
|
k = int(np.floor(len(weights) * prop))
|
||||||
|
|
||||||
|
# Ensure that at least one edge is retained
|
||||||
|
k = max(k, 1)
|
||||||
|
|
||||||
|
# Get the threshold value
|
||||||
|
thr = np.sort(weights)[-k]
|
||||||
|
|
||||||
|
# Apply the threshold to create a binary adjacency matrix
|
||||||
|
bin_adj = np.where(adj >= thr, adj, 0.0)
|
||||||
|
|
||||||
|
return bin_adj
|
||||||
|
# 单个脑网络
|
||||||
|
def plot_single_network(ch_names,adj,ax=None,
|
||||||
|
node_size=20, node_color='orange',highlight_nodes=[], show_names=True,
|
||||||
|
edge_color='gray', weighted=True,
|
||||||
|
radius=1.1, figsize=(6, 6),cmap='RdYlBu_r'):
|
||||||
|
# 若 ax 未传入,则自己创建
|
||||||
|
own_fig = False
|
||||||
|
if ax is None:
|
||||||
|
fig, ax = plt.subplots(figsize=figsize)
|
||||||
|
own_fig = True
|
||||||
|
else:
|
||||||
|
fig = ax.figure
|
||||||
|
|
||||||
|
# 坐标归一化
|
||||||
|
pos3d = read_ch_pos()
|
||||||
|
all_chs_xy = np.array([pos3d[ch][:2] for ch in pos3d.keys()])
|
||||||
|
all_chs_xy -= all_chs_xy.mean(axis=0)
|
||||||
|
all_chs_xy /= np.sqrt((all_chs_xy ** 2).sum(axis=1)).max()
|
||||||
|
xy_dict = dict(zip(pos3d.keys(), all_chs_xy))
|
||||||
|
xy = np.array([xy_dict[ch] for ch in ch_names])
|
||||||
|
center = xy_dict.get('CZ', np.mean(list(xy_dict.values()), axis=0))
|
||||||
|
|
||||||
|
# ===== 初始化绘图窗口 =====
|
||||||
|
ax.set_aspect('equal')
|
||||||
|
ax.axis('off')
|
||||||
|
# 设置边界(与原类保持一致)
|
||||||
|
ear_radius = radius * 0.12
|
||||||
|
nose_height = radius * 0.15
|
||||||
|
margin_x = radius * 0.12 + 0.05
|
||||||
|
ax.set_xlim(center[0] - radius - margin_x, center[0] + radius + margin_x)
|
||||||
|
ax.set_ylim(center[1] - radius - ear_radius, center[1] + radius + nose_height + ear_radius)
|
||||||
|
|
||||||
|
# 绘制头部轮廓
|
||||||
|
draw_head(ax, center=center, radius=radius)
|
||||||
|
|
||||||
|
# 节点
|
||||||
|
for ch in ch_names:
|
||||||
|
color = 'red' if ch in highlight_nodes else node_color
|
||||||
|
ax.scatter(*xy_dict[ch], s=node_size, color=color, edgecolor='k', zorder=4)
|
||||||
|
if show_names:
|
||||||
|
ax.text(xy_dict[ch][0], xy_dict[ch][1] + 0.03, ch,
|
||||||
|
ha='center', va='bottom', fontsize=8, zorder=5)
|
||||||
|
|
||||||
|
# colorbar
|
||||||
|
norm = mcolors.Normalize(vmin=0, vmax=1)
|
||||||
|
color_map = matplotlib.colormaps.get_cmap(cmap)
|
||||||
|
# ========= 边 ==========
|
||||||
|
N = len(ch_names)
|
||||||
|
for i in range(N):
|
||||||
|
for j in range(i + 1, N):
|
||||||
|
w = adj[i, j]
|
||||||
|
if w > 0:
|
||||||
|
x = [xy[i, 0], xy[j, 0]]
|
||||||
|
y = [xy[i, 1], xy[j, 1]]
|
||||||
|
lw = 1.5
|
||||||
|
if weighted:
|
||||||
|
ax.plot(x, y,
|
||||||
|
color=color_map(norm(w)),
|
||||||
|
linewidth=lw,
|
||||||
|
alpha=0.7,
|
||||||
|
zorder=3)
|
||||||
|
else:
|
||||||
|
ax.plot(x, y,
|
||||||
|
color=edge_color,
|
||||||
|
linewidth=lw,
|
||||||
|
alpha=0.7,
|
||||||
|
zorder=3)
|
||||||
|
|
||||||
|
if own_fig:
|
||||||
|
# 不回传 添加颜色条
|
||||||
|
sm = cm.ScalarMappable(norm=norm, cmap=color_map)
|
||||||
|
cbar = plt.colorbar(sm, ax=ax, fraction=0.035)
|
||||||
|
cbar.set_label('Connection Strength', fontsize=10)
|
||||||
|
cbar.ax.tick_params(direction='in', labelsize=10)
|
||||||
|
plt.show()
|
||||||
|
return fig
|
||||||
|
else:
|
||||||
|
|
||||||
|
return ax
|
||||||
|
# 脑网络对比
|
||||||
|
def plot_multiband_network(ch_names, adj_MI, adj_Rest,cmap='RdYlBu_r'):
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
|
||||||
|
fontsize = 16
|
||||||
|
fig.text(0.285, 0.08, 'MI', fontsize=fontsize, ha='center', va='center', rotation=0)
|
||||||
|
fig.text(0.68, 0.08, 'Rest', fontsize=fontsize, ha='center', va='center', rotation=0)
|
||||||
|
|
||||||
|
im1 = plot_single_network(ch_names,adj_MI,ax=axes[0], show_names=True,cmap=cmap)
|
||||||
|
# Rest 行
|
||||||
|
im2 = plot_single_network(ch_names,adj_Rest,ax=axes[1],show_names=True,cmap=cmap)
|
||||||
|
|
||||||
|
# --- 合并 colorbar(右侧一个) ---
|
||||||
|
norm = mcolors.Normalize(vmin=0, vmax=1)
|
||||||
|
color_map = matplotlib.colormaps.get_cmap(cmap)
|
||||||
|
sm = cm.ScalarMappable(norm=norm, cmap=color_map)
|
||||||
|
cbar = plt.colorbar(sm, ax=axes.ravel().tolist(), fraction=0.02)
|
||||||
|
cbar.set_label('Connection Strength', fontsize=10)
|
||||||
|
cbar.ax.tick_params(direction='in', labelsize=10)
|
||||||
|
|
||||||
|
# 将图像保存到内存字节流(PNG 格式)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
|
||||||
|
plt.close(fig) # 释放内存
|
||||||
|
buf.seek(0)
|
||||||
|
image_bytes = buf.read()
|
||||||
|
buf.close()
|
||||||
|
|
||||||
|
return image_bytes
|
||||||
|
|
||||||
|
# 多个频带psd
|
||||||
|
def compute_band_psd(eeg, fs, bands, labels, trial_idx=0,MI_label=1, Rest_label=2,avg = True):
|
||||||
|
"""
|
||||||
|
eeg: (n_trials, n_channels, n_samples)
|
||||||
|
"""
|
||||||
|
n_trials, n_channels, n_samples = eeg.shape
|
||||||
|
band_names = list(bands.keys())
|
||||||
|
n_bands = len(band_names)
|
||||||
|
|
||||||
|
psd_MI = np.zeros((n_bands, n_channels))
|
||||||
|
psd_Rest = np.zeros((n_bands, n_channels))
|
||||||
|
|
||||||
|
# 先计算所有 trial 的功率谱
|
||||||
|
f, Pxx = welch(eeg, fs=fs, axis=-1, nperseg=fs,noverlap = fs // 2)
|
||||||
|
|
||||||
|
|
||||||
|
for bi, (bname, (f1, f2)) in enumerate(bands.items()):
|
||||||
|
idx = np.logical_and(f >= f1, f <= f2)
|
||||||
|
band_power = Pxx[:, :, idx].mean(axis=-1)
|
||||||
|
|
||||||
|
band_power_flat = band_power.flatten()
|
||||||
|
power_min = band_power_flat.min()
|
||||||
|
power_max = band_power_flat.max()
|
||||||
|
if power_max - power_min > 1e-12:
|
||||||
|
band_power_norm = (band_power - power_min) / (power_max - power_min)
|
||||||
|
else:
|
||||||
|
band_power_norm = band_power
|
||||||
|
|
||||||
|
if avg:
|
||||||
|
psd_MI[bi] = band_power_norm[labels == MI_label].mean(axis=0)
|
||||||
|
psd_Rest[bi] = band_power_norm[labels == Rest_label].mean(axis=0)
|
||||||
|
else:
|
||||||
|
psd_MI[bi] = band_power_norm[labels == MI_label][trial_idx]
|
||||||
|
psd_Rest[bi] = band_power_norm[labels == Rest_label][trial_idx]
|
||||||
|
return band_names, psd_MI, psd_Rest
|
||||||
|
# 单个脑地形图
|
||||||
|
def plot_single_topomap(ch_names, psd_values, cmap='RdYlBu_r', vlim=(0, 1),
|
||||||
|
show_names=True, node_size=3, radius=1.1, grid_res=300,
|
||||||
|
n_contours=None, contour_color='k',
|
||||||
|
ax=None,figsize=(6,6)):
|
||||||
|
# 若 ax 未传入,则自己创建
|
||||||
|
own_fig = False
|
||||||
|
if ax is None:
|
||||||
|
fig, ax = plt.subplots(figsize=figsize)
|
||||||
|
own_fig = True
|
||||||
|
else:
|
||||||
|
fig = ax.figure
|
||||||
|
|
||||||
|
# ===== 初始化绘图窗口 =====
|
||||||
|
ax.set_aspect('equal')
|
||||||
|
ax.axis('off')
|
||||||
|
# ax.set_title("EEG topomap (MNE-like)")
|
||||||
|
|
||||||
|
# 坐标归一化
|
||||||
|
pos3d = read_ch_pos()
|
||||||
|
all_chs_xy = np.array([pos3d[ch][:2] for ch in pos3d.keys()])
|
||||||
|
all_chs_xy -= all_chs_xy.mean(axis=0)
|
||||||
|
all_chs_xy /= np.sqrt((all_chs_xy ** 2).sum(axis=1)).max()
|
||||||
|
pos2d_dict = dict(zip(pos3d.keys(), all_chs_xy))
|
||||||
|
xy = np.array([pos2d_dict[ch] for ch in ch_names])
|
||||||
|
center = pos2d_dict.get('CZ', np.mean(list(pos2d_dict.values()), axis=0))
|
||||||
|
|
||||||
|
# 绘制头部轮廓
|
||||||
|
draw_head(ax, center=center, radius=radius)
|
||||||
|
# 绘制电极
|
||||||
|
fontsize = 4
|
||||||
|
ax.scatter(xy[:, 0], xy[:, 1], c='k', s=node_size, zorder=5)
|
||||||
|
if show_names:
|
||||||
|
for i, ch in enumerate(ch_names):
|
||||||
|
ax.text(xy[i, 0], xy[i, 1] + 0.03, ch,
|
||||||
|
ha='center', va='bottom', fontsize=fontsize, zorder=6)
|
||||||
|
|
||||||
|
# 数据插值
|
||||||
|
zi, grid_x, grid_y = rbf_D_interpolate(
|
||||||
|
xy, psd_values, radius=radius,
|
||||||
|
grid_res=grid_res
|
||||||
|
)
|
||||||
|
xmin, xmax = center[0] - radius, center[0] + radius
|
||||||
|
ymin, ymax = center[1] - radius, center[1] + radius
|
||||||
|
extent = (xmin, xmax, ymin, ymax)
|
||||||
|
im = ax.imshow(zi, extent=extent, origin='lower',
|
||||||
|
cmap=cmap, vmin=vlim[0], vmax=vlim[1],
|
||||||
|
interpolation='bicubic', zorder=0)
|
||||||
|
# 裁剪路径
|
||||||
|
patch_ = Ellipse(center, 2 * radius, 2 * radius, clip_on=True, transform=ax.transData)
|
||||||
|
im.set_clip_path(patch_)
|
||||||
|
# 初始等高线
|
||||||
|
linewidths = 0.5
|
||||||
|
if n_contours is None:
|
||||||
|
cset = ax.contour(grid_x, grid_y, zi,
|
||||||
|
colors=contour_color, linewidths=linewidths, zorder=2)
|
||||||
|
else:
|
||||||
|
cset = ax.contour(grid_x, grid_y, zi, levels=n_contours,
|
||||||
|
colors=contour_color, linewidths=linewidths, zorder=2)
|
||||||
|
cset.set_clip_path(patch_)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if own_fig:
|
||||||
|
# 不回传 添加颜色条
|
||||||
|
plt.colorbar(im, ax=ax, fraction=0.035)
|
||||||
|
plt.show()
|
||||||
|
return fig
|
||||||
|
else:
|
||||||
|
# plt.colorbar(im, ax=ax, fraction=0.035)
|
||||||
|
return im
|
||||||
|
# 脑地形图对比
|
||||||
|
def plot_multiband_topomaps(ch_names, psd_MI, psd_Rest, bands):
|
||||||
|
band_names = list(bands.keys()) # 改动 1:新增这行
|
||||||
|
n_bands = len(band_names)
|
||||||
|
fig, axes = plt.subplots(2, n_bands, figsize=(3*n_bands, 6))
|
||||||
|
|
||||||
|
fontsize = 16
|
||||||
|
|
||||||
|
axes[0, 0].text(-0.1, 0.5, 'MI', transform=axes[0, 0].transAxes, rotation=0, va='center', ha='center', fontsize=fontsize-2)
|
||||||
|
axes[1, 0].text(-0.1, 0.5, 'Rest', transform=axes[1, 0].transAxes, rotation=0, va='center', ha='center', fontsize=fontsize-2)
|
||||||
|
|
||||||
|
imgs = []
|
||||||
|
for i, bname in enumerate(band_names):
|
||||||
|
axes[0, i].set_title(bname, fontsize=fontsize, pad=0)
|
||||||
|
# MI 行
|
||||||
|
im1 = plot_single_topomap(ch_names,psd_MI[i],ax=axes[0, i], show_names=True)
|
||||||
|
# Rest 行
|
||||||
|
im2 = plot_single_topomap(ch_names,psd_Rest[i],ax=axes[1, i],show_names=True)
|
||||||
|
imgs.append(im1)
|
||||||
|
|
||||||
|
# --- 单个右侧合并 colorbar ---
|
||||||
|
cbar = fig.colorbar(imgs[0], ax=axes,fraction=0.02)
|
||||||
|
# cbar.set_label("PSD Power",fontsize=fontsize-4)
|
||||||
|
cbar.ax.tick_params(direction='in', labelsize=10)
|
||||||
|
|
||||||
|
# 将图像保存到内存字节流(PNG 格式)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
|
||||||
|
plt.close(fig) # 释放内存
|
||||||
|
buf.seek(0)
|
||||||
|
image_bytes = buf.read()
|
||||||
|
buf.close()
|
||||||
|
|
||||||
|
return image_bytes
|
||||||
|
|
||||||
|
# 小波
|
||||||
|
def morlet_wavelet(f, fs, n_cycles=7):
|
||||||
|
"""
|
||||||
|
创建 Morlet 小波
|
||||||
|
f: 频率
|
||||||
|
fs: 采样率
|
||||||
|
"""
|
||||||
|
sigma_t = n_cycles / (2 * np.pi * f)
|
||||||
|
t = np.arange(-3*sigma_t, 3*sigma_t, 1/fs)
|
||||||
|
wavelet = (np.pi**-0.25) * np.exp(2j*np.pi*f*t) * np.exp(-(t**2)/(2*sigma_t**2))
|
||||||
|
return wavelet
|
||||||
|
|
||||||
|
|
||||||
|
# 希尔伯特变换 计算ERDS 效果不佳
|
||||||
|
def bandpass_filter(data, fs, band, order=4):
|
||||||
|
nyq = fs / 2
|
||||||
|
b, a = butter(order, [band[0]/nyq, band[1]/nyq], btype='band')
|
||||||
|
return filtfilt(b, a, data, axis=-1)
|
||||||
|
def compute_power_hilbert(filtered_data,is_dB =True):
|
||||||
|
analytic = hilbert(filtered_data, axis=-1)
|
||||||
|
power = np.abs(analytic) ** 2
|
||||||
|
if is_dB:
|
||||||
|
power = 10 * np.log10(power)
|
||||||
|
return power
|
||||||
|
def compute_power(data, fs=250,
|
||||||
|
bands={"mu": (8,12), "beta": (13,30)}):
|
||||||
|
"""
|
||||||
|
返回:
|
||||||
|
power_dict[band] = (n_trials, n_ch, n_samples)
|
||||||
|
"""
|
||||||
|
power_dict = {}
|
||||||
|
for band_name, band_range in bands.items():
|
||||||
|
filt = bandpass_filter(data, fs, band_range)
|
||||||
|
power = compute_power_hilbert(filt)
|
||||||
|
power_dict[band_name] = power
|
||||||
|
|
||||||
|
return power_dict
|
||||||
|
|
||||||
|
def compute_erds(power_MI, power_Rest, baseline_period=None):
|
||||||
|
"""
|
||||||
|
计算事件相关去同步/同步 (ERDS)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
power_MI, power_Rest: (n_trials, n_ch, n_samples)
|
||||||
|
功率数据,单位为 µV² 或 dB(取决于 compute_power_hilbert 的 is_dB 参数)
|
||||||
|
baseline_period: tuple (start_idx, end_idx) or None
|
||||||
|
基线时间段索引。如果为None,使用 Rest 状态的平均值作为基线
|
||||||
|
|
||||||
|
返回:
|
||||||
|
MI_erds_mean, MI_erds_sem
|
||||||
|
Rest_erds_mean, Rest_erds_sem
|
||||||
|
所有返回值的形状为 (n_ch, n_samples)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if baseline_period is not None:
|
||||||
|
start_idx, end_idx = baseline_period
|
||||||
|
baseline = np.concatenate([power_MI[:, :, start_idx:end_idx],
|
||||||
|
power_Rest[:, :, start_idx:end_idx]], axis=0)
|
||||||
|
baseline = baseline.mean(axis=(0, 2), keepdims=True)
|
||||||
|
else:
|
||||||
|
baseline = power_Rest.mean(axis=(0,2), keepdims=True)
|
||||||
|
|
||||||
|
# === ERDS (%) ===
|
||||||
|
MI_erds = (power_MI - baseline) / baseline * 100
|
||||||
|
Rest_erds = (power_Rest - baseline) / baseline * 100
|
||||||
|
|
||||||
|
return (
|
||||||
|
MI_erds.mean(axis=0), sem(MI_erds, axis=0),
|
||||||
|
Rest_erds.mean(axis=0), sem(Rest_erds, axis=0),
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_all_erds(MI_power_dict, Rest_power_dict):
|
||||||
|
"""
|
||||||
|
对多个频带同时计算 ERDS。
|
||||||
|
|
||||||
|
输入:
|
||||||
|
MI_power_dict[band] = (n_trials, n_ch, n_samples)
|
||||||
|
Rest_power_dict[band] = (n_trials, n_ch, n_samples)
|
||||||
|
|
||||||
|
输出:
|
||||||
|
erds_MI[band] = (mean, sem)
|
||||||
|
erds_Rest[band] = (mean, sem)
|
||||||
|
"""
|
||||||
|
|
||||||
|
erds_MI = {}
|
||||||
|
erds_Rest = {}
|
||||||
|
|
||||||
|
for band in MI_power_dict.keys():
|
||||||
|
MI_power = MI_power_dict[band]
|
||||||
|
Rest_power = Rest_power_dict[band]
|
||||||
|
|
||||||
|
MI_mean, MI_sem, Rest_mean, Rest_sem = compute_erds(MI_power, Rest_power)
|
||||||
|
|
||||||
|
erds_MI[band] = (MI_mean, MI_sem)
|
||||||
|
erds_Rest[band] = (Rest_mean, Rest_sem)
|
||||||
|
|
||||||
|
return erds_MI, erds_Rest
|
||||||
|
|
||||||
|
def plot_compare_erds(data_MI, data_Rest, mode="power",
|
||||||
|
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'],
|
||||||
|
compare_names=['C3', 'CZ', 'C4'], bands=['mu', 'beta'],
|
||||||
|
fs=250, t=None, figsize=(12,6)):
|
||||||
|
|
||||||
|
n_bands = len(bands)
|
||||||
|
n_chs = len(compare_names)
|
||||||
|
|
||||||
|
# 自动添加单位
|
||||||
|
if mode == "power":
|
||||||
|
# y_unit = "Power (µV²)"
|
||||||
|
y_unit = "Power (dB)"
|
||||||
|
elif mode == "erds":
|
||||||
|
y_unit = "ERDS (%)"
|
||||||
|
else:
|
||||||
|
y_unit = ""
|
||||||
|
|
||||||
|
if t is None:
|
||||||
|
n_samples = next(iter(data_MI.values())).shape[-1] \
|
||||||
|
if mode=="power" else next(iter(data_MI.values()))[0].shape[-1]
|
||||||
|
t = np.arange(n_samples) / fs
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(n_bands, n_chs, figsize=figsize, sharex=True, sharey=True)
|
||||||
|
|
||||||
|
for i, band in enumerate(bands):
|
||||||
|
|
||||||
|
# 选择数据结构
|
||||||
|
if mode == "power":
|
||||||
|
MI_band = data_MI[band] # (trials, ch, samples)
|
||||||
|
Rest_band = data_Rest[band]
|
||||||
|
|
||||||
|
avg_MI = MI_band.mean(axis=0)
|
||||||
|
sem_MI = MI_band.std(axis=0)/np.sqrt(MI_band.shape[0])
|
||||||
|
|
||||||
|
avg_Rest = Rest_band.mean(axis=0)
|
||||||
|
sem_Rest = Rest_band.std(axis=0)/np.sqrt(Rest_band.shape[0])
|
||||||
|
|
||||||
|
elif mode == "erds":
|
||||||
|
avg_MI, sem_MI = data_MI[band]
|
||||||
|
avg_Rest, sem_Rest = data_Rest[band]
|
||||||
|
|
||||||
|
for j, ch in enumerate(compare_names):
|
||||||
|
ax = axes[i, j] if n_bands > 1 else axes[j]
|
||||||
|
|
||||||
|
ch_idx = ch_names.index(ch)
|
||||||
|
|
||||||
|
# 绘制 MI
|
||||||
|
ax.plot(t, avg_MI[ch_idx], color="C0", label="MI")
|
||||||
|
ax.fill_between(t,
|
||||||
|
avg_MI[ch_idx]-sem_MI[ch_idx],
|
||||||
|
avg_MI[ch_idx]+sem_MI[ch_idx],
|
||||||
|
alpha=0.3, color="C0")
|
||||||
|
|
||||||
|
# 绘制 Rest
|
||||||
|
ax.plot(t, avg_Rest[ch_idx], color="C1", label="Rest")
|
||||||
|
ax.fill_between(t,
|
||||||
|
avg_Rest[ch_idx]-sem_Rest[ch_idx],
|
||||||
|
avg_Rest[ch_idx]+sem_Rest[ch_idx],
|
||||||
|
alpha=0.3, color="C1")
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
ax.set_title(ch)
|
||||||
|
|
||||||
|
# ← Y 轴加单位
|
||||||
|
if j == 0:
|
||||||
|
ax.set_ylabel(f"{band}\n{y_unit}")
|
||||||
|
|
||||||
|
if i == n_bands - 1:
|
||||||
|
ax.set_xlabel("Time (s)")
|
||||||
|
|
||||||
|
ax.grid(alpha=0.3)
|
||||||
|
|
||||||
|
if i == 0 and j == n_chs - 1:
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# 将图像保存到内存字节流(PNG 格式)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
|
||||||
|
plt.close(fig) # 释放内存
|
||||||
|
buf.seek(0)
|
||||||
|
image_bytes = buf.read()
|
||||||
|
buf.close()
|
||||||
|
|
||||||
|
return image_bytes
|
||||||
|
|
||||||
|
# 对比 MI vs Rest 的功率谱密度 PSD
|
||||||
|
def plot_psd_compare(MI_data, Rest_data, ch_names, compare_names=['C3', 'CZ', 'C4'],
|
||||||
|
fs=250, nperseg=None, average=True, show_sem=True,
|
||||||
|
figsize=(12, 3), save_dir=None, filename="psd.png"):
|
||||||
|
"""
|
||||||
|
对比 MI vs Rest 的功率谱密度 PSD
|
||||||
|
|
||||||
|
MI_data, Rest_data: (n_trials, n_ch, n_samples)
|
||||||
|
channels: 需要绘制的通道
|
||||||
|
average: 是否对所有试次平均
|
||||||
|
show_sem: 是否绘制 SEM 阴影
|
||||||
|
"""
|
||||||
|
|
||||||
|
n_trials, n_ch, n_samples = MI_data.shape
|
||||||
|
n_trials = min(len(MI_data), len(Rest_data))
|
||||||
|
# assert Rest_data.shape == MI_data.shape, "MI 和 Rest 数据维度必须一致"
|
||||||
|
|
||||||
|
if nperseg is None:
|
||||||
|
nperseg = fs # 每 1 秒窗长度
|
||||||
|
|
||||||
|
# 计算 MI PSD
|
||||||
|
psd_MI_all = []
|
||||||
|
for trial in range(n_trials):
|
||||||
|
psd_trial = []
|
||||||
|
for ch in range(n_ch):
|
||||||
|
f, Pxx = welch(MI_data[trial, ch], fs=fs, nperseg=nperseg)
|
||||||
|
psd_trial.append(Pxx)
|
||||||
|
psd_MI_all.append(psd_trial)
|
||||||
|
psd_MI_all = np.array(psd_MI_all)
|
||||||
|
|
||||||
|
# 计算 Rest PSD
|
||||||
|
psd_Rest_all = []
|
||||||
|
for trial in range(n_trials):
|
||||||
|
psd_trial = []
|
||||||
|
for ch in range(n_ch):
|
||||||
|
_, Pxx = welch(Rest_data[trial, ch], fs=fs, nperseg=nperseg)
|
||||||
|
psd_trial.append(Pxx)
|
||||||
|
psd_Rest_all.append(psd_trial)
|
||||||
|
psd_Rest_all = np.array(psd_Rest_all)
|
||||||
|
|
||||||
|
# ---- Plot ----
|
||||||
|
fig, ax = plt.subplots(1, len(compare_names), figsize=figsize)
|
||||||
|
if len(compare_names) == 1:
|
||||||
|
ax = [ax]
|
||||||
|
|
||||||
|
for i, ch in enumerate(compare_names):
|
||||||
|
ch_idx = ch_names.index(ch)
|
||||||
|
psd_MI_ch = psd_MI_all[:, ch_idx, :]
|
||||||
|
psd_Rest_ch = psd_Rest_all[:, ch_idx, :]
|
||||||
|
|
||||||
|
if average:
|
||||||
|
mean_MI = psd_MI_ch.mean(axis=0)
|
||||||
|
mean_Rest = psd_Rest_ch.mean(axis=0)
|
||||||
|
|
||||||
|
ax[i].plot(f, mean_MI, color='C0', label='MI')
|
||||||
|
ax[i].plot(f, mean_Rest, color='C1', label='Rest')
|
||||||
|
|
||||||
|
if show_sem:
|
||||||
|
ax[i].fill_between(f, mean_MI - sem(psd_MI_ch, axis=0),
|
||||||
|
mean_MI + sem(psd_MI_ch, axis=0), color='C0', alpha=0.3)
|
||||||
|
ax[i].fill_between(f, mean_Rest - sem(psd_Rest_ch, axis=0),
|
||||||
|
mean_Rest + sem(psd_Rest_ch, axis=0), color='C1', alpha=0.3)
|
||||||
|
else:
|
||||||
|
ax[i].plot(f, psd_MI_ch.T, color='C0', alpha=0.3)
|
||||||
|
ax[i].plot(f, psd_Rest_ch.T, color='C1', alpha=0.3)
|
||||||
|
|
||||||
|
ax[i].set_title(ch)
|
||||||
|
ax[i].set_xlabel("Frequency (Hz)")
|
||||||
|
ax[i].set_ylabel("PSD (μV²/Hz)")
|
||||||
|
ax[i].grid(alpha=0.3)
|
||||||
|
if i == 0:
|
||||||
|
ax[i].legend()
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# 将图像保存到内存字节流(PNG 格式)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
|
||||||
|
plt.close(fig) # 释放内存
|
||||||
|
buf.seek(0)
|
||||||
|
image_bytes = buf.read()
|
||||||
|
buf.close()
|
||||||
|
|
||||||
|
return image_bytes
|
||||||
|
|
||||||
|
|
||||||
|
def plotMain(
|
||||||
|
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'],
|
||||||
|
compare_names = [ 'C3','CZ','C4'],
|
||||||
|
Data = None,labels = None,MI_label = None,Rest_label = None,
|
||||||
|
fs = 250):
|
||||||
|
|
||||||
|
trial_idx = 0
|
||||||
|
|
||||||
|
# 数据划分
|
||||||
|
if not MI_label:
|
||||||
|
label_ = np.unique(labels)
|
||||||
|
else:
|
||||||
|
label_ = (MI_label,Rest_label)
|
||||||
|
MI_data = Data[labels == label_[0]]
|
||||||
|
Rest_data = Data[labels == label_[1]]
|
||||||
|
|
||||||
|
# 典型 EEG 频带
|
||||||
|
FREQ_BANDS = {
|
||||||
|
"Delta (0.8-4Hz)": (0.8, 4),
|
||||||
|
"Theta (4-8Hz)": (4, 8),
|
||||||
|
"Alpha (8-12Hz)": (8, 12),
|
||||||
|
"Beta (12-30Hz)": (12, 30),
|
||||||
|
"All (0.8-30Hz)": (0.8, 30)
|
||||||
|
}
|
||||||
|
# 利用welch估算PSD
|
||||||
|
band_names, psd_MI, psd_Rest= compute_band_psd(
|
||||||
|
eeg=Data,
|
||||||
|
fs=fs,
|
||||||
|
bands=FREQ_BANDS,
|
||||||
|
labels=labels,
|
||||||
|
trial_idx=trial_idx,
|
||||||
|
MI_label=MI_label,
|
||||||
|
Rest_label=Rest_label,
|
||||||
|
avg= True
|
||||||
|
)
|
||||||
|
# 绘制地形图
|
||||||
|
topomaps_imgBytes = plot_multiband_topomaps(
|
||||||
|
ch_names=ch_names,
|
||||||
|
psd_MI=psd_MI,
|
||||||
|
psd_Rest=psd_Rest,
|
||||||
|
bands=FREQ_BANDS
|
||||||
|
)
|
||||||
|
|
||||||
|
# 绘制脑网络
|
||||||
|
mi_plv_matrix = calculate_plv(MI_data[trial_idx])
|
||||||
|
mi_BI_matrix = threshold_proportional(mi_plv_matrix, prop=0.3)
|
||||||
|
rest_plv_matrix = calculate_plv(Rest_data[trial_idx])
|
||||||
|
rest_BI_matrix = threshold_proportional(rest_plv_matrix, prop=0.3)
|
||||||
|
network_imgBytes = plot_multiband_network(ch_names, mi_BI_matrix, rest_BI_matrix)
|
||||||
|
|
||||||
|
# ERDS 先计算erds,后平均
|
||||||
|
MI_power = compute_power(MI_data)
|
||||||
|
Rest_power = compute_power(Rest_data)
|
||||||
|
erds_dict_MI, erds_dict_Rest = compute_all_erds(MI_power, Rest_power)
|
||||||
|
erds_imgBytes = plot_compare_erds(erds_dict_MI, erds_dict_Rest, ch_names=ch_names,
|
||||||
|
compare_names=compare_names, bands=['mu', 'beta'],
|
||||||
|
fs=fs, mode="erds")
|
||||||
|
|
||||||
|
# 绘制PSD
|
||||||
|
psd_imgBytes = plot_psd_compare(MI_data, Rest_data, ch_names = ch_names, compare_names=compare_names,
|
||||||
|
fs=fs, nperseg=None, average=True, show_sem=True,
|
||||||
|
figsize=(12, 3))
|
||||||
|
return {'topomaps_imgBytes':base64.b64encode(topomaps_imgBytes).decode(),'network_imgBytes':base64.b64encode(network_imgBytes).decode(),
|
||||||
|
'erds_imgBytes':base64.b64encode(erds_imgBytes).decode(),'psd_imgBytes':base64.b64encode(psd_imgBytes).decode()}
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
allData = np.random.uniform(-50,50,size=(80,21,1000))
|
||||||
|
allLabel = np.random.randint(1,3,size=(80,))
|
||||||
|
allData = allData[:len(allLabel)]
|
||||||
|
ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1',
|
||||||
|
'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4']
|
||||||
|
compare_names = ['C3', 'CZ', 'C4']
|
||||||
|
ret = plotMain(ch_names=ch_names, compare_names=compare_names, Data=allData, labels=allLabel, MI_label=1, Rest_label=2,
|
||||||
|
fs=250)
|
||||||
|
print('计算完成,开始发送')
|
||||||
|
from Zmq.zmqClient import zmqClient
|
||||||
|
|
||||||
|
zmqClient = zmqClient('192.168.76.101', 8088)
|
||||||
|
zmqClient.connect()
|
||||||
|
zmqClient.send_to_all('miReport', ret)
|
||||||
68
Debug_64ch_Decoder_Optimize/Zmq/zmqClient.py
Normal file
68
Debug_64ch_Decoder_Optimize/Zmq/zmqClient.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
|
||||||
|
class zmqClient:
|
||||||
|
def __init__(self, host, port):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.client_socket = None
|
||||||
|
self.running = False
|
||||||
|
self.zmq_server = None # Reference to zmqServer for Unity communication
|
||||||
|
# 记录客户端连接前的状态
|
||||||
|
self.state = {
|
||||||
|
'status_code': None,
|
||||||
|
'energy': None
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_zmq_server(self, server):
|
||||||
|
"""Set the zmqServer instance to forward messages to Unity"""
|
||||||
|
self.zmq_server = server
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
# 创建 ZeroMQ 上下文
|
||||||
|
self.context = zmq.Context()
|
||||||
|
# 创建 REQ 套接字(请求端)
|
||||||
|
self.client_socket = self.context.socket(zmq.DEALER)
|
||||||
|
# client_id = b'client1'
|
||||||
|
# self.client_socket.setsockopt(zmq.IDENTITY,client_id)
|
||||||
|
self.client_socket.connect(f"tcp://{self.host}:{self.port}") # 连接到服务器
|
||||||
|
self.running = True
|
||||||
|
|
||||||
|
def send_to_all(self, method,params):
|
||||||
|
if method in self.state.keys():
|
||||||
|
self.state[method] = params
|
||||||
|
|
||||||
|
# Also send to Unity via zmqServer if connected
|
||||||
|
if self.zmq_server:
|
||||||
|
self.zmq_server.broadcast_message(method, params)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.running and self.client_socket != None:
|
||||||
|
msg = {'method': method, 'params': params}
|
||||||
|
if method in ['single_trial_plot', 'miReport']:
|
||||||
|
print(f"{{'method': '{method}', 'params': <Base64 Image Data>}}")
|
||||||
|
else:
|
||||||
|
print(msg)
|
||||||
|
self.client_socket.send_multipart([b'', json.dumps(msg).encode('utf-8')])
|
||||||
|
else:
|
||||||
|
if method in self.state.keys():
|
||||||
|
self.state[method] = params
|
||||||
|
except ConnectionResetError:
|
||||||
|
print("Connection lost.")
|
||||||
|
self.running = False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred: {e}")
|
||||||
|
|
||||||
|
def close_connection(self):
|
||||||
|
self.running = False
|
||||||
|
self.client_socket.close()
|
||||||
|
self.context.term()
|
||||||
|
print("Client closed explicitly.")
|
||||||
|
# 使用TCP客户端
|
||||||
|
if __name__ == "__main__":
|
||||||
|
client = zmqClient('127.0.0.1', 8099)
|
||||||
|
client.connect()
|
||||||
|
# client.close_connection()
|
||||||
149
Debug_64ch_Decoder_Optimize/Zmq/zmqServer.py
Normal file
149
Debug_64ch_Decoder_Optimize/Zmq/zmqServer.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
import threading
|
||||||
|
import json
|
||||||
|
import queue
|
||||||
|
from Device.SunnyLinker import SunnyLinker64
|
||||||
|
|
||||||
|
class zmqServer(threading.Thread):
|
||||||
|
def __init__(self, host='0.0.0.0', port=8099):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.running = False
|
||||||
|
self.get_Impedance = False # 是否返回阻抗值
|
||||||
|
self.open_Impedance = None # 是否开启阻抗检测功能
|
||||||
|
self.StartDecode = False # false 停止解码,true=开始解码
|
||||||
|
self.StartTrain = False # False未进入训练状态,True处于训练状态
|
||||||
|
self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态
|
||||||
|
self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签
|
||||||
|
self.IsExitApp = False # 当socket收到2的时候,就置为True,代表要退出系统了。
|
||||||
|
self.getReport = False # 获取训练报告内容
|
||||||
|
self.daemon = True
|
||||||
|
# 创建 ZeroMQ 上下文
|
||||||
|
self.context = zmq.Context()
|
||||||
|
# 创建 REP 套接字(响应端)
|
||||||
|
self.socket = self.context.socket(zmq.ROUTER)
|
||||||
|
self.socket.bind(f"tcp://{self.host}:{self.port}") # 绑定到端口 8099
|
||||||
|
self.targetFreqs = []
|
||||||
|
self.changeTarget = False # 更换目标频率
|
||||||
|
self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化
|
||||||
|
self.labels = [0x01, 0x02,0x03]
|
||||||
|
|
||||||
|
self.decoder_switch = False #更换解码器
|
||||||
|
self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi'
|
||||||
|
# Client Management (e.g. Unity, Other listeners)
|
||||||
|
self.clients = set() # 维护客户端ID
|
||||||
|
self.send_queue = queue.Queue() # 发送队列,安全信箱,维护socket线程
|
||||||
|
|
||||||
|
def broadcast_message(self, method, params):
|
||||||
|
"""Put message into queue to be sent to all connected clients"""
|
||||||
|
self.send_queue.put((method, params))
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.running = True
|
||||||
|
print(f"Server is running on {self.host}:{self.port}")
|
||||||
|
# Use Poller for non-blocking receive
|
||||||
|
poller = zmq.Poller()
|
||||||
|
poller.register(self.socket, zmq.POLLIN)
|
||||||
|
try:
|
||||||
|
while self.running:
|
||||||
|
# 1. Process Send Queue (Send to all clients)
|
||||||
|
while not self.send_queue.empty():
|
||||||
|
method, params = self.send_queue.get()
|
||||||
|
if self.clients:
|
||||||
|
try:
|
||||||
|
msg = {'method': method, 'params': params}
|
||||||
|
msg_bytes = json.dumps(msg).encode('utf-8')
|
||||||
|
if method in ['single_trial_plot', 'single_trial_plot', 'miReport']:
|
||||||
|
print(f"{{'method': '{method}', 'params': <Base64 Image Data>}}")
|
||||||
|
else:
|
||||||
|
print(f"Sending message: {msg}")
|
||||||
|
# Broadcast to all maintained clients
|
||||||
|
for client_id in list(self.clients):
|
||||||
|
try:
|
||||||
|
# Send: [ID, Empty, JSON]
|
||||||
|
self.socket.send_multipart([client_id, b'', msg_bytes])
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error sending to client {client_id}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error preparing broadcast: {e}")
|
||||||
|
|
||||||
|
# 2. Process Receive (Commands)
|
||||||
|
socks = dict(poller.poll(10)) # 100ms timeout
|
||||||
|
if self.socket in socks and socks[self.socket] == zmq.POLLIN:
|
||||||
|
frames = self.socket.recv_multipart()
|
||||||
|
if len(frames) < 3:
|
||||||
|
continue
|
||||||
|
ident, _, message_bytes = frames[:3]
|
||||||
|
if ident not in self.clients: # register client ID
|
||||||
|
self.clients.add(ident)
|
||||||
|
print(f"New Client Detected: {ident}")
|
||||||
|
try:
|
||||||
|
message = json.loads(message_bytes.decode('utf-8'))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
print(f"Received request: {message}")
|
||||||
|
|
||||||
|
method = message.get("method") # process request
|
||||||
|
params = message.get("params")
|
||||||
|
|
||||||
|
if method == "sync":
|
||||||
|
self.state_mode = 'sync'
|
||||||
|
if method == "targetFreqs":
|
||||||
|
if not isinstance(params,list):
|
||||||
|
print('targetFreqs must be a list')
|
||||||
|
continue
|
||||||
|
if params != self.targetFreqs:
|
||||||
|
self.targetFreqs = params
|
||||||
|
self.changeTarget = True
|
||||||
|
if method == "decoderClass":
|
||||||
|
if not isinstance(params,str):
|
||||||
|
print('decoderClass must be a str')
|
||||||
|
continue
|
||||||
|
if params != self.decoder_class:
|
||||||
|
self.decoder_class = params
|
||||||
|
self.decoder_switch = True
|
||||||
|
if method == "getReport":
|
||||||
|
self.getReport = True
|
||||||
|
if method == "train":#训练状态
|
||||||
|
self.state_mode = 'train'
|
||||||
|
self.StartTrain = True
|
||||||
|
self.currentLabel = params # 当前刺激端的训练标签
|
||||||
|
self.sunnyLinker.push_trigger(self.labels[self.currentLabel])
|
||||||
|
elif method == "predict":#预测状态
|
||||||
|
self.state_mode = 'predict'
|
||||||
|
if params == 1: #开始解码
|
||||||
|
self.StartDecode = True
|
||||||
|
self.sunnyLinker.push_trigger(0x63)
|
||||||
|
elif params == 2: #停止解码
|
||||||
|
self.IsExitApp = True
|
||||||
|
self.running = False
|
||||||
|
elif method == "rest": #休息状态
|
||||||
|
self.state_mode = 'rest'
|
||||||
|
elif method == "impedance":
|
||||||
|
if params == 1:
|
||||||
|
self.open_Impedance = True # 开启阻抗
|
||||||
|
self.get_Impedance = True # 返回阻抗
|
||||||
|
elif params == 2:
|
||||||
|
self.open_Impedance = False # 关闭阻抗
|
||||||
|
self.get_Impedance = False # 停止返回阻抗
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An socket error occurred: {e}")
|
||||||
|
finally:
|
||||||
|
self.running = False
|
||||||
|
# 关闭套接字和上下文
|
||||||
|
self.socket.close()
|
||||||
|
self.context.term()
|
||||||
|
print("Server socket and context closed.")
|
||||||
|
def stop(self):
|
||||||
|
"""显式关闭服务器"""
|
||||||
|
self.running = False
|
||||||
|
self.socket.close()
|
||||||
|
self.context.term()
|
||||||
|
print("Server closed explicitly.")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
server = zmqServer()
|
||||||
|
server.start()
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Created on Mon Sep 29 16:14:17 2025
|
||||||
|
|
||||||
|
@author: 23749
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from scipy.signal import butter, filtfilt
|
||||||
|
|
||||||
|
## 1.Bandpass Filter
|
||||||
|
def butter_bandpass(lowcut, highcut, fs, order=4):
|
||||||
|
# 滤波器
|
||||||
|
nyq = 0.5 * fs #ny:Nyquist频率,即能表示的最大有效频率
|
||||||
|
low = lowcut / nyq
|
||||||
|
high = highcut / nyq
|
||||||
|
b, a = butter(order, [low, high], btype='band') #巴特沃斯滤波器,order=4阶
|
||||||
|
return b, a
|
||||||
|
|
||||||
|
def bandpass_filter(data, lowcut, highcut, fs, order=4):
|
||||||
|
b, a = butter_bandpass(lowcut, highcut, fs, order)
|
||||||
|
return filtfilt(b, a, data)
|
||||||
|
|
||||||
|
## 2.Eye Blink Dectection
|
||||||
|
def blink_detection(F, fs, Dmin, Dmax, Emin, Emax):
|
||||||
|
"""
|
||||||
|
波形检测
|
||||||
|
输入: 差分特征向量 F, 采样率 fs
|
||||||
|
输出: b (0/1), 以及计算出的 d, e
|
||||||
|
"""
|
||||||
|
|
||||||
|
if F is None or len(F) < 3:
|
||||||
|
return 0, None, None
|
||||||
|
|
||||||
|
# 找最大时间(peak) & 最小时间(valley)
|
||||||
|
t_peak = np.argmax(F)
|
||||||
|
t_valley = np.argmin(F)
|
||||||
|
|
||||||
|
# 要求 peak 在 valley 之前(符合 blink 形态),否则交换
|
||||||
|
if t_valley < t_peak:
|
||||||
|
t_peak, t_valley = t_valley, t_peak
|
||||||
|
|
||||||
|
# 计算持续时间 d (ms)
|
||||||
|
d = (t_valley - t_peak) * 1000.0 / fs
|
||||||
|
|
||||||
|
# 计算能量 e (差分平方和)
|
||||||
|
e = np.sum(F[t_peak:t_valley + 1] ** 2)
|
||||||
|
|
||||||
|
# 阈值判定
|
||||||
|
if Dmin <= d <= Dmax and Emin <= e <= Emax:
|
||||||
|
b = 1 # 检测到眨眼
|
||||||
|
else:
|
||||||
|
b = 0 # 否则 no blink
|
||||||
|
return b, d, e
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
fs = 250 # 采样率
|
||||||
|
t = np.arange(0, 5, 1/fs)
|
||||||
|
eog = 0.01 * np.random.randn(len(t)) # 基线+噪声
|
||||||
|
|
||||||
|
# 模拟眨眼(在 2.0s 注入脉冲)
|
||||||
|
center = int(2.0 * fs)
|
||||||
|
eog[center:center+5] += 0.5
|
||||||
|
eog[center+5:center+15] -= 0.4
|
||||||
|
|
||||||
|
# 测试 blink_detection
|
||||||
|
F = np.diff(eog)
|
||||||
|
b, d, e = blink_detection(F, fs, 70, 500, 0.1, 10)
|
||||||
|
print(f"Detected: {b}, Duration: {d}ms, Energy: {e}")
|
||||||
98
Debug_64ch_Decoder_Optimize/build_algorithm.spec
Normal file
98
Debug_64ch_Decoder_Optimize/build_algorithm.spec
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
# -*- mode: python ; coding: utf-8 -*-
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from PyInstaller.utils.hooks import collect_submodules, collect_data_files
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 1. 工程配置区 (Project Config)
|
||||||
|
# ========================================================
|
||||||
|
block_cipher = None
|
||||||
|
ENTRY_POINT = 'runDecoder.py'
|
||||||
|
APP_NAME = 'runDecoder'
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 2. 依赖分析 (Dependency Analysis)
|
||||||
|
# ========================================================
|
||||||
|
# 收集 sklearn, scipy 可能遗漏的隐藏导入
|
||||||
|
hidden_imports = [
|
||||||
|
'sklearn.utils._cython_blas',
|
||||||
|
'sklearn.neighbors.typedefs',
|
||||||
|
'sklearn.neighbors.quad_tree',
|
||||||
|
'sklearn.tree',
|
||||||
|
'sklearn.tree._utils',
|
||||||
|
'einops', # 必须显式添加
|
||||||
|
]
|
||||||
|
|
||||||
|
# 收集 torch 相关的隐式导入
|
||||||
|
hidden_imports += ['torch', 'torchvision']
|
||||||
|
|
||||||
|
# 收集 pandas 相关的隐式导入
|
||||||
|
hidden_imports += ['pandas']
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 3. 资源锚定 (Data Anchoring)
|
||||||
|
# ========================================================
|
||||||
|
# Analysis 中的 datas 用于将文件嵌入到内部
|
||||||
|
datas = []
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 4. 构建流程 (Build Process)
|
||||||
|
# ========================================================
|
||||||
|
a = Analysis(
|
||||||
|
[ENTRY_POINT],
|
||||||
|
pathex=[],
|
||||||
|
binaries=[],
|
||||||
|
datas=datas,
|
||||||
|
hiddenimports=hidden_imports,
|
||||||
|
hookspath=[],
|
||||||
|
hooksconfig={},
|
||||||
|
runtime_hooks=['rthook.py'], # 添加运行时钩子,处理路径和多进程
|
||||||
|
excludes=['tkinter', 'PyQt5', 'PySide2', 'IPython', 'notebook'], # 排除 GUI 和交互式库减小体积
|
||||||
|
win_no_prefer_redirects=False,
|
||||||
|
win_private_assemblies=False,
|
||||||
|
cipher=block_cipher,
|
||||||
|
noarchive=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||||
|
|
||||||
|
exe = EXE(
|
||||||
|
pyz,
|
||||||
|
a.scripts,
|
||||||
|
[],
|
||||||
|
exclude_binaries=True,
|
||||||
|
name=APP_NAME,
|
||||||
|
debug=False,
|
||||||
|
bootloader_ignore_signals=False,
|
||||||
|
strip=False,
|
||||||
|
upx=False,
|
||||||
|
console=True, # 保持 True 以便查看日志,部署时可改为 False
|
||||||
|
disable_windowed_traceback=False,
|
||||||
|
argv_emulation=False,
|
||||||
|
target_arch=None,
|
||||||
|
codesign_identity=None,
|
||||||
|
entitlements_file=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ========================================================
|
||||||
|
# 5. 打包模式: OneDir (单文件夹) + 资源旁路
|
||||||
|
# ========================================================
|
||||||
|
# 使用 Tree 将文件夹原样复制到 dist/APP_NAME/ 下
|
||||||
|
# 格式: Tree('源路径', prefix='目标子目录')
|
||||||
|
|
||||||
|
coll = COLLECT(
|
||||||
|
exe,
|
||||||
|
a.binaries,
|
||||||
|
a.zipfiles,
|
||||||
|
a.datas,
|
||||||
|
# 显式复制资源文件夹到 exe 同级目录
|
||||||
|
Tree('online_Models', prefix='online_Models', excludes=['*.pyc']),
|
||||||
|
Tree('Tools', prefix='Tools', excludes=['*.pyc']),
|
||||||
|
# config.ini 作为单独文件
|
||||||
|
[('config.ini', 'config.ini', 'DATA')],
|
||||||
|
strip=False,
|
||||||
|
upx=False,
|
||||||
|
upx_exclude=[],
|
||||||
|
name=APP_NAME,
|
||||||
|
)
|
||||||
88
Debug_64ch_Decoder_Optimize/build_with_copy.py
Normal file
88
Debug_64ch_Decoder_Optimize/build_with_copy.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 1. 定义路径
|
||||||
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
DIST_DIR = os.path.join(BASE_DIR, 'dist')
|
||||||
|
APP_NAME = 'runDecoder'
|
||||||
|
TARGET_DIR = os.path.join(DIST_DIR, APP_NAME)
|
||||||
|
|
||||||
|
# 定义需要复制的资源 {源路径: 目标子路径}
|
||||||
|
# 目标子路径相对于 TARGET_DIR
|
||||||
|
RESOURCES = {
|
||||||
|
'config.ini': 'config.ini',
|
||||||
|
'online_Models': 'online_Models',
|
||||||
|
'Tools': 'Tools',
|
||||||
|
}
|
||||||
|
|
||||||
|
# 2. 清理旧构建
|
||||||
|
print("[1/3] Cleaning up old builds...")
|
||||||
|
if os.path.exists(DIST_DIR):
|
||||||
|
try:
|
||||||
|
shutil.rmtree(DIST_DIR)
|
||||||
|
print(" Cleaned dist/")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not clean dist/: {e}")
|
||||||
|
|
||||||
|
BUILD_DIR = os.path.join(BASE_DIR, 'build')
|
||||||
|
if os.path.exists(BUILD_DIR):
|
||||||
|
try:
|
||||||
|
shutil.rmtree(BUILD_DIR)
|
||||||
|
print(" Cleaned build/")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not clean build/: {e}")
|
||||||
|
|
||||||
|
# 3. 运行 PyInstaller
|
||||||
|
print("[2/3] Running PyInstaller...")
|
||||||
|
# 注意:我们这里不传 --noupx,因为已经在 spec 文件里把 upx=False 写死了
|
||||||
|
cmd = [
|
||||||
|
"pyinstaller",
|
||||||
|
"build_algorithm.spec",
|
||||||
|
"--clean"
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.check_call(cmd, shell=True)
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
print("Error: PyInstaller failed.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 4. 复制外部资源文件夹
|
||||||
|
print("[3/3] Verifying and Copying external resources...")
|
||||||
|
|
||||||
|
for src_name, dst_name in RESOURCES.items():
|
||||||
|
src_path = os.path.join(BASE_DIR, src_name)
|
||||||
|
dst_path = os.path.join(TARGET_DIR, dst_name)
|
||||||
|
|
||||||
|
if os.path.exists(src_path):
|
||||||
|
if os.path.isfile(src_path):
|
||||||
|
# 如果是文件
|
||||||
|
try:
|
||||||
|
shutil.copy2(src_path, dst_path)
|
||||||
|
print(f" Copied file: {src_name} -> {dst_name}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Error copying file {src_name}: {e}")
|
||||||
|
else:
|
||||||
|
# 如果是文件夹
|
||||||
|
if os.path.exists(dst_path):
|
||||||
|
try:
|
||||||
|
shutil.rmtree(dst_path) # 先删除 spec 生成的旧文件夹 (如果有)
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not remove existing dir {dst_path}: {e}")
|
||||||
|
try:
|
||||||
|
shutil.copytree(src_path, dst_path, ignore=shutil.ignore_patterns('*.pyc', '__pycache__'))
|
||||||
|
print(f" Copied dir: {src_name} -> {dst_name}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Error copying dir {src_name}: {e}")
|
||||||
|
else:
|
||||||
|
print(f" Warning: Source resource not found at {src_path}")
|
||||||
|
|
||||||
|
print("\n" + "="*50)
|
||||||
|
print(f"SUCCESS! Build artifacts are in: {TARGET_DIR}")
|
||||||
|
print("="*50)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,396 @@
|
|||||||
|
import numpy as np
|
||||||
|
from scipy.signal import welch
|
||||||
|
from scipy.fft import fft
|
||||||
|
from scipy import signal
|
||||||
|
from collections import deque
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
# import logging
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
|
||||||
|
# logger = logging.getLogger(__name__)
|
||||||
|
#
|
||||||
|
# try:
|
||||||
|
# import matplotlib
|
||||||
|
# matplotlib.use('Agg')
|
||||||
|
# import matplotlib.pyplot as plt
|
||||||
|
# MATPLOTLIB_AVAILABLE = True
|
||||||
|
# except ImportError:
|
||||||
|
# MATPLOTLIB_AVAILABLE = False
|
||||||
|
# logger.warning("matplotlib未安装,报告图表功能不可用")
|
||||||
|
|
||||||
|
|
||||||
|
class Calculate():
|
||||||
|
def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10):
|
||||||
|
self.Threshold_value_low = Threshold_value_low
|
||||||
|
self.Threshold_value_high = Threshold_value_high
|
||||||
|
self.fs = fs
|
||||||
|
self.focus_result = []
|
||||||
|
self.CLI_result = []
|
||||||
|
self.EVI_result = []
|
||||||
|
self.eegQueue = deque(maxlen=win_len)
|
||||||
|
|
||||||
|
# # 存储历史数据用于绘图
|
||||||
|
# self.beta_history = []
|
||||||
|
# self.alpha_history = []
|
||||||
|
# self.theta_history = []
|
||||||
|
# self.focus_history = []
|
||||||
|
# self.timestamp_history = []
|
||||||
|
#
|
||||||
|
# # 记录开始时间
|
||||||
|
# self.start_time = None
|
||||||
|
# self.recording = False
|
||||||
|
#
|
||||||
|
# # 图表保存路径
|
||||||
|
# self.chart_dir = "reports"
|
||||||
|
# if not os.path.exists(self.chart_dir):
|
||||||
|
# os.makedirs(self.chart_dir)
|
||||||
|
# print(f"[调试] 创建目录: {self.chart_dir}")
|
||||||
|
|
||||||
|
# 初始化滤波器
|
||||||
|
self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30)
|
||||||
|
self.b_design = signal.firwin(65, [2 / (self.fs/2), 40 / (self.fs/2)], pass_zero=False)
|
||||||
|
|
||||||
|
print("[调试] Calculate 类初始化完成")
|
||||||
|
|
||||||
|
def calculate_focus(self, beta, alpha, theta):
|
||||||
|
"""
|
||||||
|
专注度计算 - 固定映射版本
|
||||||
|
"""
|
||||||
|
# 原始比值
|
||||||
|
raw = beta / (alpha + theta + 1e-10)
|
||||||
|
|
||||||
|
# Sigmoid 映射:让 raw 在 0.3-1.5 区间敏感
|
||||||
|
# 参数可调:
|
||||||
|
# k = 12 (斜率,越大越陡)
|
||||||
|
# x0 = 0.6 (中心点,raw=0.6时focus≈50)
|
||||||
|
k = 12.0
|
||||||
|
x0 = 0.6
|
||||||
|
focus = 100.0 / (1.0 + np.exp(-k * (raw - x0)))
|
||||||
|
|
||||||
|
# 可选:添加滑动平均平滑
|
||||||
|
return int(focus)
|
||||||
|
|
||||||
|
def calculate_all(self, data, fs, nperseg=1000):
|
||||||
|
mean_x = np.mean(data, axis=-1, keepdims=True)
|
||||||
|
data = data - mean_x
|
||||||
|
freqs, psd = self.compute_psd_multichannel(data, fs, nperseg)
|
||||||
|
beta_psd = np.sum(self.band_psd(freqs, psd, (13, 30)))
|
||||||
|
alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13)))
|
||||||
|
theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8)))
|
||||||
|
|
||||||
|
print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
focus_score = self.calculate_focus(beta_psd, alpha_psd, theta_psd)
|
||||||
|
focus_score = max(0, min(100, focus_score))
|
||||||
|
|
||||||
|
self.focus_result.append(focus_score)
|
||||||
|
if len(self.focus_result) > 3:
|
||||||
|
self.focus_result.pop(0)
|
||||||
|
final_focus = int(self.simple_moving_average(self.focus_result, window_size=5))
|
||||||
|
|
||||||
|
cli_denom = alpha_psd + beta_psd
|
||||||
|
CLI_score = np.log(theta_psd / (cli_denom + 1e-10)) if cli_denom > 0 else 0
|
||||||
|
self.CLI_result.append(CLI_score)
|
||||||
|
if len(self.CLI_result) > 5:
|
||||||
|
self.CLI_result.pop(0)
|
||||||
|
final_CLI = round(self.simple_moving_average(self.CLI_result, window_size=5), 2)
|
||||||
|
|
||||||
|
return final_focus, final_CLI, beta_psd, alpha_psd, theta_psd
|
||||||
|
|
||||||
|
def compute_psd_multichannel(self, data, fs=250, nperseg=1000):
|
||||||
|
n_samples = data.shape[-1]
|
||||||
|
if n_samples < nperseg:
|
||||||
|
nperseg = n_samples
|
||||||
|
|
||||||
|
noverlap = 500
|
||||||
|
if noverlap >= nperseg:
|
||||||
|
noverlap = int(nperseg / 2)
|
||||||
|
|
||||||
|
if nperseg == 0:
|
||||||
|
return np.array([]), np.zeros((data.shape[0], 0))
|
||||||
|
|
||||||
|
freqs, psd = welch(data, fs=fs, nperseg=nperseg, noverlap=noverlap, axis=-1)
|
||||||
|
return freqs, psd
|
||||||
|
|
||||||
|
def band_psd(self, freqs, psd, band):
|
||||||
|
idx = np.logical_and(freqs >= band[0], freqs <= band[1])
|
||||||
|
return np.sum(psd[:, idx], axis=-1)
|
||||||
|
|
||||||
|
def simple_moving_average(self, data, window_size=5):
|
||||||
|
if len(data) == 0:
|
||||||
|
return 30
|
||||||
|
window = data[-window_size:]
|
||||||
|
return sum(window) / len(window)
|
||||||
|
|
||||||
|
def reset_queue(self):
|
||||||
|
self.eegQueue.clear()
|
||||||
|
|
||||||
|
# def start_recording(self):
|
||||||
|
# """开始记录数据"""
|
||||||
|
# self.recording = True
|
||||||
|
# self.start_time = time.time()
|
||||||
|
# self.beta_history = []
|
||||||
|
# self.alpha_history = []
|
||||||
|
# self.theta_history = []
|
||||||
|
# self.focus_history = []
|
||||||
|
# self.timestamp_history = []
|
||||||
|
# print("[调试] ========== 开始记录专注度数据 ==========")
|
||||||
|
|
||||||
|
# def stop_recording(self):
|
||||||
|
# """停止记录并生成图表"""
|
||||||
|
# print(f"[调试] stop_recording被调用, recording={self.recording}, focus_history长度={len(self.focus_history)}")
|
||||||
|
# self.recording = False
|
||||||
|
# if len(self.focus_history) > 0:
|
||||||
|
# print("[调试] 数据非空,开始生成图表...")
|
||||||
|
# # 保存到本地文件
|
||||||
|
# chart_path = self.save_chart_to_file()
|
||||||
|
# if chart_path:
|
||||||
|
# print(f"[调试] 本地文件保存成功: {chart_path}")
|
||||||
|
# else:
|
||||||
|
# print("[调试] 本地文件保存失败")
|
||||||
|
# # 生成base64编码
|
||||||
|
# base64_data = self.generate_chart_base64()
|
||||||
|
# return base64_data
|
||||||
|
# else:
|
||||||
|
# print("[调试] 没有数据可保存,focus_history为空")
|
||||||
|
# return None
|
||||||
|
|
||||||
|
# def add_data_point(self, focus, beta, alpha, theta):
|
||||||
|
# if not self.recording:
|
||||||
|
# return
|
||||||
|
# current_time = time.time()
|
||||||
|
# elapsed = current_time - self.start_time
|
||||||
|
#
|
||||||
|
# self.beta_history.append(beta)
|
||||||
|
# self.alpha_history.append(alpha)
|
||||||
|
# self.theta_history.append(theta)
|
||||||
|
# self.focus_history.append(focus)
|
||||||
|
# self.timestamp_history.append(elapsed)
|
||||||
|
# print(f"[调试] 记录数据点: time={elapsed:.1f}s, focus={focus}, beta={beta:.2f}")
|
||||||
|
|
||||||
|
# def save_chart_to_file(self):
|
||||||
|
# """
|
||||||
|
# 保存图表到本地文件(唯一实现)
|
||||||
|
# """
|
||||||
|
# print(f"[调试] save_chart_to_file被调用, MATPLOTLIB_AVAILABLE={MATPLOTLIB_AVAILABLE}")
|
||||||
|
#
|
||||||
|
# if not MATPLOTLIB_AVAILABLE:
|
||||||
|
# print("[调试] matplotlib不可用,无法保存")
|
||||||
|
# return None
|
||||||
|
#
|
||||||
|
# if len(self.focus_history) < 2:
|
||||||
|
# print(f"[调试] 数据点不足,需要至少2个点,当前{len(self.focus_history)}个点")
|
||||||
|
# return None
|
||||||
|
#
|
||||||
|
# print(f"[调试] 开始保存图表到本地文件...")
|
||||||
|
#
|
||||||
|
# # 确保所有列表长度一致
|
||||||
|
# min_len = min(len(self.beta_history), len(self.alpha_history),
|
||||||
|
# len(self.theta_history), len(self.focus_history),
|
||||||
|
# len(self.timestamp_history))
|
||||||
|
#
|
||||||
|
# print(f"[调试] 数据长度: min_len={min_len}")
|
||||||
|
#
|
||||||
|
# beta_list = self.beta_history[:min_len]
|
||||||
|
# alpha_list = self.alpha_history[:min_len]
|
||||||
|
# theta_list = self.theta_history[:min_len]
|
||||||
|
# focus_list = self.focus_history[:min_len]
|
||||||
|
# times = self.timestamp_history[:min_len]
|
||||||
|
#
|
||||||
|
# # 生成文件名
|
||||||
|
# timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||||
|
# chart_path = os.path.join(self.chart_dir, f"concentration_report_{timestamp}.png")
|
||||||
|
# print(f"[调试] 保存路径: {chart_path}")
|
||||||
|
#
|
||||||
|
# try:
|
||||||
|
# # 创建图表
|
||||||
|
# fig, ax1 = plt.subplots(figsize=(14, 8))
|
||||||
|
#
|
||||||
|
# # 左Y轴:功率数据
|
||||||
|
# ax1.plot(times, beta_list, 'b-', linewidth=1.5, alpha=0.8, label='Beta Power')
|
||||||
|
# ax1.plot(times, alpha_list, 'g-', linewidth=1.5, alpha=0.8, label='Alpha Power')
|
||||||
|
# ax1.plot(times, theta_list, 'orange', linewidth=1.5, alpha=0.8, label='Theta Power')
|
||||||
|
# ax1.set_xlabel('Time (s)', fontsize=12)
|
||||||
|
# ax1.set_ylabel('Band Power', fontsize=12, color='black')
|
||||||
|
# ax1.tick_params(axis='y', labelcolor='black')
|
||||||
|
# ax1.legend(loc='upper left')
|
||||||
|
# ax1.grid(True, alpha=0.3)
|
||||||
|
#
|
||||||
|
# # 右Y轴:专注度
|
||||||
|
# ax2 = ax1.twinx()
|
||||||
|
# ax2.plot(times, focus_list, 'r-', linewidth=2, alpha=0.9, label='Focus (%)')
|
||||||
|
# ax2.set_ylabel('Focus (%)', fontsize=12, color='red')
|
||||||
|
# ax2.tick_params(axis='y', labelcolor='red')
|
||||||
|
# ax2.set_ylim(0, 105)
|
||||||
|
# ax2.legend(loc='upper right')
|
||||||
|
#
|
||||||
|
# # 标题
|
||||||
|
# duration = times[-1] if times else 0
|
||||||
|
# avg_focus = np.mean(focus_list) if focus_list else 0
|
||||||
|
# plt.title(f'Concentration and EEG Band Power Trend\nDuration: {duration:.1f}s, Avg Focus: {avg_focus:.1f}%',
|
||||||
|
# fontsize=14)
|
||||||
|
#
|
||||||
|
# plt.tight_layout()
|
||||||
|
# plt.savefig(chart_path, dpi=150, bbox_inches='tight')
|
||||||
|
# plt.close()
|
||||||
|
#
|
||||||
|
# print(f"\n{'='*60}")
|
||||||
|
# print(f"专注度报告图片已保存到本地:")
|
||||||
|
# print(f" 文件路径: {chart_path}")
|
||||||
|
# print(f" 数据点数: {min_len}")
|
||||||
|
# print(f" 时长: {duration:.1f}秒")
|
||||||
|
# print(f" 平均专注度: {avg_focus:.1f}%")
|
||||||
|
# print(f"{'='*60}\n")
|
||||||
|
#
|
||||||
|
# return chart_path
|
||||||
|
#
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"[调试] 保存文件时出错: {e}")
|
||||||
|
# import traceback
|
||||||
|
# traceback.print_exc()
|
||||||
|
# return None
|
||||||
|
#
|
||||||
|
# def generate_chart_base64(self):
|
||||||
|
# """
|
||||||
|
# 生成图表的base64编码(用于网络传输)
|
||||||
|
# """
|
||||||
|
# if not MATPLOTLIB_AVAILABLE:
|
||||||
|
# return None
|
||||||
|
#
|
||||||
|
# if len(self.focus_history) < 2:
|
||||||
|
# return None
|
||||||
|
#
|
||||||
|
# min_len = min(len(self.beta_history), len(self.alpha_history),
|
||||||
|
# len(self.theta_history), len(self.focus_history),
|
||||||
|
# len(self.timestamp_history))
|
||||||
|
#
|
||||||
|
# beta_list = self.beta_history[:min_len]
|
||||||
|
# alpha_list = self.alpha_history[:min_len]
|
||||||
|
# theta_list = self.theta_history[:min_len]
|
||||||
|
# focus_list = self.focus_history[:min_len]
|
||||||
|
# times = self.timestamp_history[:min_len]
|
||||||
|
#
|
||||||
|
# fig, ax1 = plt.subplots(figsize=(14, 8))
|
||||||
|
#
|
||||||
|
# ax1.plot(times, beta_list, 'b-', linewidth=1.5, alpha=0.8, label='Beta Power')
|
||||||
|
# ax1.plot(times, alpha_list, 'g-', linewidth=1.5, alpha=0.8, label='Alpha Power')
|
||||||
|
# ax1.plot(times, theta_list, 'orange', linewidth=1.5, alpha=0.8, label='Theta Power')
|
||||||
|
# ax1.set_xlabel('Time (s)', fontsize=12)
|
||||||
|
# ax1.set_ylabel('Band Power', fontsize=12, color='black')
|
||||||
|
# ax1.tick_params(axis='y', labelcolor='black')
|
||||||
|
# ax1.legend(loc='upper left')
|
||||||
|
# ax1.grid(True, alpha=0.3)
|
||||||
|
#
|
||||||
|
# ax2 = ax1.twinx()
|
||||||
|
# ax2.plot(times, focus_list, 'r-', linewidth=2, alpha=0.9, label='Focus (%)')
|
||||||
|
# ax2.set_ylabel('Focus (%)', fontsize=12, color='red')
|
||||||
|
# ax2.tick_params(axis='y', labelcolor='red')
|
||||||
|
# ax2.set_ylim(0, 105)
|
||||||
|
# ax2.legend(loc='upper right')
|
||||||
|
#
|
||||||
|
# duration = times[-1] if times else 0
|
||||||
|
# avg_focus = np.mean(focus_list) if focus_list else 0
|
||||||
|
# plt.title(f'Concentration and EEG Band Power Trend\nDuration: {duration:.1f}s, Avg Focus: {avg_focus:.1f}%',
|
||||||
|
# fontsize=14)
|
||||||
|
#
|
||||||
|
# plt.tight_layout()
|
||||||
|
#
|
||||||
|
# buffer = io.BytesIO()
|
||||||
|
# plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
|
||||||
|
# buffer.seek(0)
|
||||||
|
# image_base64 = base64.b64encode(buffer.read()).decode('utf-8')
|
||||||
|
# plt.close()
|
||||||
|
#
|
||||||
|
# return image_base64
|
||||||
|
|
||||||
|
def queueOpt(self, data):
|
||||||
|
if data is None or data.size == 0:
|
||||||
|
return None
|
||||||
|
if len(self.eegQueue) < self.eegQueue.maxlen:
|
||||||
|
self.eegQueue.append(data)
|
||||||
|
else:
|
||||||
|
self.eegQueue.append(data)
|
||||||
|
|
||||||
|
if len(self.eegQueue) == self.eegQueue.maxlen:
|
||||||
|
eegData = np.hstack([self.eegQueue[i] for i in range(len(self.eegQueue))])
|
||||||
|
if eegData.size == 0:
|
||||||
|
return None
|
||||||
|
eegData -= np.mean(eegData, axis=-1, keepdims=True)
|
||||||
|
eegData = signal.lfilter(self.b_notch, self.a_notch, eegData)
|
||||||
|
eegData = signal.lfilter(self.b_design, 1, eegData)
|
||||||
|
focus_score, CLI_score, beta, alpha, theta = self.calculate_all(eegData, fs=self.fs, nperseg=1000)
|
||||||
|
|
||||||
|
# self.add_data_point(focus_score, beta, alpha, theta)
|
||||||
|
|
||||||
|
return focus_score
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class Calculate2():
|
||||||
|
def __init__(self, Threshold_value_low, Threshold_value_high):
|
||||||
|
self.Threshold_value_low = Threshold_value_low
|
||||||
|
self.Threshold_value_high = Threshold_value_high
|
||||||
|
self.focus_result = []
|
||||||
|
self.theta_result = []
|
||||||
|
self.alpha_result = []
|
||||||
|
self.flow_result = []
|
||||||
|
|
||||||
|
def calculate_all(self, data, fs, L=2500):
|
||||||
|
mean_x = np.mean(data, axis=-1, keepdims=True)
|
||||||
|
data = data - mean_x
|
||||||
|
|
||||||
|
Y = fft(data, axis=-1)
|
||||||
|
P2 = np.abs(Y / L)
|
||||||
|
P1 = P2[:, :L // 2 + 1]
|
||||||
|
P1[:, 1:-1] = 2 * P1[:, 1:-1]
|
||||||
|
|
||||||
|
beta_power = self.PSD(P1, L, fs, 13, 30)
|
||||||
|
alpha_power = self.PSD(P1, L, fs, 8, 13)
|
||||||
|
theta_power = self.PSD(P1, L, fs, 4, 8)
|
||||||
|
gamma_power = self.PSD(P1, L, fs, 30, 100)
|
||||||
|
|
||||||
|
focus_score = beta_power / (alpha_power + theta_power)
|
||||||
|
print('focus score:', focus_score)
|
||||||
|
focus_score = ((focus_score - self.Threshold_value_low) * 100) / (self.Threshold_value_high - self.Threshold_value_low)
|
||||||
|
self.focus_result.append(focus_score)
|
||||||
|
if len(self.focus_result) > 3:
|
||||||
|
self.focus_result.pop(0)
|
||||||
|
final_focus = int(self.simple_moving_average(self.focus_result, window_size=3))
|
||||||
|
|
||||||
|
self.theta_result.append(theta_power)
|
||||||
|
if len(self.theta_result) > 30:
|
||||||
|
self.theta_result.pop(0)
|
||||||
|
self.alpha_result.append(alpha_power)
|
||||||
|
if len(self.alpha_result) > 30:
|
||||||
|
self.alpha_result.pop(0)
|
||||||
|
rest_theta = self.simple_moving_average(self.theta_result, window_size=30)
|
||||||
|
rest_alpha = self.simple_moving_average(self.alpha_result, window_size=30)
|
||||||
|
distraction_score = (theta_power / rest_theta) * (1 - (alpha_power / rest_alpha))
|
||||||
|
|
||||||
|
flow_score = gamma_power / beta_power
|
||||||
|
flow_score = (flow_score / self.Threshold_value_high) * 100
|
||||||
|
self.flow_result.append(flow_score)
|
||||||
|
if len(self.flow_result) > 3:
|
||||||
|
self.flow_result.pop(0)
|
||||||
|
final_flow = int(self.simple_moving_average(self.flow_result, window_size=3))
|
||||||
|
|
||||||
|
return final_focus, distraction_score, final_flow
|
||||||
|
|
||||||
|
def PSD(self, P1, L, Fs, s_freq, e_freq):
|
||||||
|
s_point = round(s_freq * L / Fs)
|
||||||
|
e_point = round(e_freq * L / Fs)
|
||||||
|
x, y = P1.shape
|
||||||
|
band_PSD = 0
|
||||||
|
for i in range(x):
|
||||||
|
for j in range(s_point, e_point):
|
||||||
|
band_PSD += P1[i, j] ** 2
|
||||||
|
return band_PSD
|
||||||
|
|
||||||
|
def simple_moving_average(self, data, window_size=3):
|
||||||
|
if len(data) == 0:
|
||||||
|
return []
|
||||||
|
window = data[-window_size:]
|
||||||
|
return sum(window) / len(window)
|
||||||
161
Debug_64ch_Decoder_Optimize/config.ini
Normal file
161
Debug_64ch_Decoder_Optimize/config.ini
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
[system]
|
||||||
|
SSVEP_ThresholdValue = [1,-0.023]
|
||||||
|
;SSVEP_ThresholdValue = [2,-0.00200]
|
||||||
|
SSMVEP_IntervalEpoch = [0.2,2.2]
|
||||||
|
list_freqs = [8, 9]
|
||||||
|
phase = [0, 0]
|
||||||
|
concentration_ThresholdValue = [0.1, 0.8]
|
||||||
|
MI_IntervalEpoch = [0.5,4.5]
|
||||||
|
blink = [70,500,100,500,800,3,2]
|
||||||
|
Right_rehabilitation = 5
|
||||||
|
Fault_rehabilitation = 5
|
||||||
|
Num_blocks = 1
|
||||||
|
Num_trials = 10
|
||||||
|
Audio_device = 0
|
||||||
|
Rest_time = 2
|
||||||
|
Device_type = 1
|
||||||
|
Device_Host = 127.0.0.1
|
||||||
|
Device_Port = 5086
|
||||||
|
Upper_Host = 127.0.0.1
|
||||||
|
Upper_Port = 8088
|
||||||
|
Serial_port = COM44
|
||||||
|
|
||||||
|
|
||||||
|
[Layout]
|
||||||
|
main_splitter_left = 993
|
||||||
|
main_splitter_right = 922
|
||||||
|
right_splitter_left = 233
|
||||||
|
right_splitter_right = 771
|
||||||
|
left_splitter_left = 503
|
||||||
|
left_splitter_right = 501q
|
||||||
|
|
||||||
|
[channel]
|
||||||
|
channel_x_fp1 = 419
|
||||||
|
channel_y_fp1 = 124
|
||||||
|
channel_x_fc1 = 439
|
||||||
|
channel_y_fc1 = 296
|
||||||
|
channel_x_fp2 = 576
|
||||||
|
channel_y_fp2 = 124
|
||||||
|
channel_x_fc2 = 556
|
||||||
|
channel_y_fc2 = 299
|
||||||
|
channel_x_f3 = 397
|
||||||
|
channel_y_f3 = 231
|
||||||
|
channel_x_cp1 = 439
|
||||||
|
channel_y_cp1 = 426
|
||||||
|
channel_x_f4 = 601
|
||||||
|
channel_y_f4 = 232
|
||||||
|
channel_x_cp2 = 559
|
||||||
|
channel_y_cp2 = 425
|
||||||
|
channel_x_fc3 = 379
|
||||||
|
channel_y_fc3 = 295
|
||||||
|
channel_x_af4 = 571
|
||||||
|
channel_y_af4 = 171
|
||||||
|
channel_x_po8 = 645
|
||||||
|
channel_y_po8 = 564
|
||||||
|
channel_x_fpz = 499
|
||||||
|
channel_y_fpz = 112
|
||||||
|
channel_x_fcz = 499
|
||||||
|
channel_y_fcz = 300
|
||||||
|
channel_x_poz = 500
|
||||||
|
channel_y_poz = 554
|
||||||
|
channel_x_po5 = 387
|
||||||
|
channel_y_po5 = 551
|
||||||
|
channel_x_po6 = 611
|
||||||
|
channel_y_po6 = 551
|
||||||
|
channel_x_c3 = 373
|
||||||
|
channel_y_c3 = 363
|
||||||
|
channel_x_fc5 = 319
|
||||||
|
channel_y_fc5 = 292
|
||||||
|
channel_x_c4 = 620
|
||||||
|
channel_y_c4 = 363
|
||||||
|
channel_x_fc6 = 676
|
||||||
|
channel_y_fc6 = 288
|
||||||
|
channel_x_p3 = 398
|
||||||
|
channel_y_p3 = 491
|
||||||
|
channel_x_cp5 = 322
|
||||||
|
channel_y_cp5 = 430
|
||||||
|
channel_x_p4 = 600
|
||||||
|
channel_y_p4 = 489
|
||||||
|
channel_x_cp6 = 678
|
||||||
|
channel_y_cp6 = 430
|
||||||
|
channel_x_c5 = 313
|
||||||
|
channel_y_c5 = 361
|
||||||
|
channel_x_f6 = 650
|
||||||
|
channel_y_f6 = 223
|
||||||
|
channel_x_f5 = 349
|
||||||
|
channel_y_f5 = 224
|
||||||
|
channel_x_po4 = 573
|
||||||
|
channel_y_po4 = 551
|
||||||
|
channel_x_po3 = 429
|
||||||
|
channel_y_po3 = 550
|
||||||
|
channel_x_cp4 = 619
|
||||||
|
channel_y_cp4 = 424
|
||||||
|
channel_x_cp3 = 381
|
||||||
|
channel_y_cp3 = 426
|
||||||
|
channel_x_fc4 = 619
|
||||||
|
channel_y_fc4 = 295
|
||||||
|
channel_x_o1 = 423
|
||||||
|
channel_y_o1 = 598
|
||||||
|
channel_x_ft9 = 252
|
||||||
|
channel_y_ft9 = 168
|
||||||
|
channel_x_o2 = 576
|
||||||
|
channel_y_o2 = 597
|
||||||
|
channel_x_ft10 = 798
|
||||||
|
channel_y_ft10 = 277
|
||||||
|
channel_x_f7 = 295
|
||||||
|
channel_y_f7 = 214
|
||||||
|
channel_x_tp9 = 202
|
||||||
|
channel_y_tp9 = 445
|
||||||
|
channel_x_f8 = 701
|
||||||
|
channel_y_f8 = 215
|
||||||
|
channel_x_t7 = 252
|
||||||
|
channel_y_t7 = 362
|
||||||
|
channel_x_tp7 = 261
|
||||||
|
channel_y_tp7 = 436
|
||||||
|
channel_x_ft8 = 734
|
||||||
|
channel_y_ft8 = 283
|
||||||
|
channel_x_ft7 = 264
|
||||||
|
channel_y_ft7 = 286
|
||||||
|
channel_x_af8 = 645
|
||||||
|
channel_y_af8 = 159
|
||||||
|
channel_x_af7 = 351
|
||||||
|
channel_y_af7 = 160
|
||||||
|
channel_x_p6 = 652
|
||||||
|
channel_y_p6 = 499
|
||||||
|
channel_x_p5 = 348
|
||||||
|
channel_y_p5 = 499
|
||||||
|
channel_x_c6 = 683
|
||||||
|
channel_y_c6 = 362
|
||||||
|
channel_x_f1 = 447
|
||||||
|
channel_y_f1 = 236
|
||||||
|
channel_x_t8 = 745
|
||||||
|
channel_y_t8 = 361
|
||||||
|
channel_x_f2 = 549
|
||||||
|
channel_y_f2 = 235
|
||||||
|
channel_x_p7 = 300
|
||||||
|
channel_y_p7 = 505
|
||||||
|
channel_x_c1 = 435
|
||||||
|
channel_y_c1 = 363
|
||||||
|
channel_x_p8 = 698
|
||||||
|
channel_y_p8 = 508
|
||||||
|
channel_x_c2 = 559
|
||||||
|
channel_y_c2 = 359
|
||||||
|
channel_x_fz = 499
|
||||||
|
channel_y_fz = 238
|
||||||
|
channel_x_po7 = 354
|
||||||
|
channel_y_po7 = 562
|
||||||
|
channel_x_tp8 = 735
|
||||||
|
channel_y_tp8 = 438
|
||||||
|
channel_x_oz = 498
|
||||||
|
channel_y_oz = 609
|
||||||
|
channel_x_af3 = 428
|
||||||
|
channel_y_af3 = 170
|
||||||
|
channel_x_pz = 501
|
||||||
|
channel_y_pz = 486
|
||||||
|
channel_x_p2 = 551
|
||||||
|
channel_y_p2 = 483
|
||||||
|
channel_x_cz = 499
|
||||||
|
channel_y_cz = 361
|
||||||
|
channel_x_p1 = 448
|
||||||
|
channel_y_p1 = 488
|
||||||
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
252
Debug_64ch_Decoder_Optimize/online_Models/log_result.txt
Normal file
252
Debug_64ch_Decoder_Optimize/online_Models/log_result.txt
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
0 0.5
|
||||||
|
1 0.5
|
||||||
|
2 0.375
|
||||||
|
3 0.5
|
||||||
|
4 0.4375
|
||||||
|
5 0.375
|
||||||
|
6 0.5
|
||||||
|
7 0.5
|
||||||
|
8 0.375
|
||||||
|
9 0.375
|
||||||
|
10 0.375
|
||||||
|
11 0.375
|
||||||
|
12 0.5
|
||||||
|
13 0.5625
|
||||||
|
14 0.5625
|
||||||
|
15 0.5
|
||||||
|
16 0.5
|
||||||
|
17 0.5
|
||||||
|
18 0.5
|
||||||
|
19 0.5625
|
||||||
|
20 0.4375
|
||||||
|
21 0.5
|
||||||
|
22 0.5
|
||||||
|
23 0.375
|
||||||
|
24 0.375
|
||||||
|
25 0.375
|
||||||
|
26 0.375
|
||||||
|
27 0.375
|
||||||
|
28 0.3125
|
||||||
|
29 0.375
|
||||||
|
30 0.5625
|
||||||
|
31 0.5
|
||||||
|
32 0.5
|
||||||
|
33 0.5625
|
||||||
|
34 0.5625
|
||||||
|
35 0.3125
|
||||||
|
36 0.3125
|
||||||
|
37 0.3125
|
||||||
|
38 0.375
|
||||||
|
39 0.5625
|
||||||
|
40 0.3125
|
||||||
|
41 0.5625
|
||||||
|
42 0.3125
|
||||||
|
43 0.375
|
||||||
|
44 0.5625
|
||||||
|
45 0.5
|
||||||
|
46 0.375
|
||||||
|
47 0.375
|
||||||
|
48 0.3125
|
||||||
|
49 0.375
|
||||||
|
50 0.375
|
||||||
|
51 0.5
|
||||||
|
52 0.5625
|
||||||
|
53 0.375
|
||||||
|
54 0.5625
|
||||||
|
55 0.5625
|
||||||
|
56 0.375
|
||||||
|
57 0.375
|
||||||
|
58 0.375
|
||||||
|
59 0.5
|
||||||
|
60 0.3125
|
||||||
|
61 0.375
|
||||||
|
62 0.375
|
||||||
|
63 0.375
|
||||||
|
64 0.375
|
||||||
|
65 0.375
|
||||||
|
66 0.3125
|
||||||
|
67 0.375
|
||||||
|
68 0.5625
|
||||||
|
69 0.5625
|
||||||
|
70 0.5625
|
||||||
|
71 0.5
|
||||||
|
72 0.5625
|
||||||
|
73 0.375
|
||||||
|
74 0.375
|
||||||
|
75 0.375
|
||||||
|
76 0.375
|
||||||
|
77 0.375
|
||||||
|
78 0.5
|
||||||
|
79 0.375
|
||||||
|
80 0.375
|
||||||
|
81 0.5
|
||||||
|
82 0.375
|
||||||
|
83 0.375
|
||||||
|
84 0.375
|
||||||
|
85 0.375
|
||||||
|
86 0.3125
|
||||||
|
87 0.375
|
||||||
|
88 0.375
|
||||||
|
89 0.5
|
||||||
|
90 0.375
|
||||||
|
91 0.4375
|
||||||
|
92 0.3125
|
||||||
|
93 0.3125
|
||||||
|
94 0.375
|
||||||
|
95 0.375
|
||||||
|
96 0.375
|
||||||
|
97 0.375
|
||||||
|
98 0.3125
|
||||||
|
99 0.4375
|
||||||
|
100 0.375
|
||||||
|
101 0.375
|
||||||
|
102 0.375
|
||||||
|
103 0.3125
|
||||||
|
104 0.5625
|
||||||
|
105 0.5
|
||||||
|
106 0.5625
|
||||||
|
107 0.5625
|
||||||
|
108 0.5
|
||||||
|
109 0.3125
|
||||||
|
110 0.5625
|
||||||
|
111 0.5625
|
||||||
|
112 0.5
|
||||||
|
113 0.3125
|
||||||
|
114 0.5
|
||||||
|
115 0.3125
|
||||||
|
116 0.375
|
||||||
|
117 0.3125
|
||||||
|
118 0.3125
|
||||||
|
119 0.3125
|
||||||
|
120 0.3125
|
||||||
|
121 0.375
|
||||||
|
122 0.375
|
||||||
|
123 0.375
|
||||||
|
124 0.375
|
||||||
|
125 0.3125
|
||||||
|
126 0.375
|
||||||
|
127 0.375
|
||||||
|
128 0.375
|
||||||
|
129 0.375
|
||||||
|
130 0.5625
|
||||||
|
131 0.375
|
||||||
|
132 0.5
|
||||||
|
133 0.3125
|
||||||
|
134 0.3125
|
||||||
|
135 0.3125
|
||||||
|
136 0.375
|
||||||
|
137 0.5
|
||||||
|
138 0.3125
|
||||||
|
139 0.375
|
||||||
|
140 0.3125
|
||||||
|
141 0.3125
|
||||||
|
142 0.3125
|
||||||
|
143 0.5625
|
||||||
|
144 0.3125
|
||||||
|
145 0.375
|
||||||
|
146 0.5
|
||||||
|
147 0.5
|
||||||
|
148 0.375
|
||||||
|
149 0.4375
|
||||||
|
150 0.5
|
||||||
|
151 0.3125
|
||||||
|
152 0.375
|
||||||
|
153 0.375
|
||||||
|
154 0.375
|
||||||
|
155 0.3125
|
||||||
|
156 0.375
|
||||||
|
157 0.4375
|
||||||
|
158 0.4375
|
||||||
|
159 0.375
|
||||||
|
160 0.375
|
||||||
|
161 0.3125
|
||||||
|
162 0.375
|
||||||
|
163 0.375
|
||||||
|
164 0.375
|
||||||
|
165 0.3125
|
||||||
|
166 0.3125
|
||||||
|
167 0.3125
|
||||||
|
168 0.375
|
||||||
|
169 0.3125
|
||||||
|
170 0.3125
|
||||||
|
171 0.3125
|
||||||
|
172 0.375
|
||||||
|
173 0.3125
|
||||||
|
174 0.3125
|
||||||
|
175 0.5
|
||||||
|
176 0.3125
|
||||||
|
177 0.375
|
||||||
|
178 0.375
|
||||||
|
179 0.3125
|
||||||
|
180 0.3125
|
||||||
|
181 0.3125
|
||||||
|
182 0.3125
|
||||||
|
183 0.5625
|
||||||
|
184 0.5625
|
||||||
|
185 0.3125
|
||||||
|
186 0.5
|
||||||
|
187 0.5
|
||||||
|
188 0.5625
|
||||||
|
189 0.5
|
||||||
|
190 0.5625
|
||||||
|
191 0.5625
|
||||||
|
192 0.5625
|
||||||
|
193 0.5
|
||||||
|
194 0.5
|
||||||
|
195 0.5625
|
||||||
|
196 0.5625
|
||||||
|
197 0.5625
|
||||||
|
198 0.5625
|
||||||
|
199 0.5
|
||||||
|
200 0.5625
|
||||||
|
201 0.5625
|
||||||
|
202 0.375
|
||||||
|
203 0.375
|
||||||
|
204 0.375
|
||||||
|
205 0.375
|
||||||
|
206 0.375
|
||||||
|
207 0.5
|
||||||
|
208 0.5
|
||||||
|
209 0.5625
|
||||||
|
210 0.5625
|
||||||
|
211 0.5625
|
||||||
|
212 0.3125
|
||||||
|
213 0.5
|
||||||
|
214 0.5
|
||||||
|
215 0.5625
|
||||||
|
216 0.5
|
||||||
|
217 0.5
|
||||||
|
218 0.5
|
||||||
|
219 0.5625
|
||||||
|
220 0.5
|
||||||
|
221 0.4375
|
||||||
|
222 0.5
|
||||||
|
223 0.5
|
||||||
|
224 0.4375
|
||||||
|
225 0.5
|
||||||
|
226 0.4375
|
||||||
|
227 0.5
|
||||||
|
228 0.5
|
||||||
|
229 0.375
|
||||||
|
230 0.375
|
||||||
|
231 0.3125
|
||||||
|
232 0.375
|
||||||
|
233 0.375
|
||||||
|
234 0.375
|
||||||
|
235 0.5625
|
||||||
|
236 0.5625
|
||||||
|
237 0.5625
|
||||||
|
238 0.5625
|
||||||
|
239 0.5625
|
||||||
|
240 0.5
|
||||||
|
241 0.5
|
||||||
|
242 0.5
|
||||||
|
243 0.5625
|
||||||
|
244 0.5625
|
||||||
|
245 0.375
|
||||||
|
246 0.375
|
||||||
|
247 0.375
|
||||||
|
248 0.3125
|
||||||
|
249 0.375
|
||||||
|
The average accuracy is: 0.42675
|
||||||
|
The best accuracy is: 0.5625
|
||||||
13
Debug_64ch_Decoder_Optimize/rthook.py
Normal file
13
Debug_64ch_Decoder_Optimize/rthook.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
|
# 1. 路径自适应:在 Frozen 模式下,将当前工作目录切换到可执行文件所在目录
|
||||||
|
# 这样代码中使用的相对路径(如 './config.ini')就能正确指向 exe 旁边的文件
|
||||||
|
if getattr(sys, 'frozen', False):
|
||||||
|
os.chdir(os.path.dirname(sys.executable))
|
||||||
|
|
||||||
|
# 2. 多进程保护:防止 Windows 下的无限递归炸弹
|
||||||
|
# Windows 下 multiprocessing 需要 freeze_support()
|
||||||
|
if sys.platform.startswith('win'):
|
||||||
|
multiprocessing.freeze_support()
|
||||||
35
Debug_64ch_Decoder_Optimize/runDecoder.py
Normal file
35
Debug_64ch_Decoder_Optimize/runDecoder.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from Decoder import Decoder_main
|
||||||
|
from PubLibrary.RunOnce import is_program_running
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if not is_program_running():
|
||||||
|
# 解析命令行参数
|
||||||
|
parser = argparse.ArgumentParser(description="EEG Decoder Application")
|
||||||
|
parser.add_argument('-dt', '--device-type', type=int, default=None, help="Device Type")
|
||||||
|
parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP")
|
||||||
|
parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port")
|
||||||
|
parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP")
|
||||||
|
parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
decoder = Decoder_main()
|
||||||
|
decoder.connect(
|
||||||
|
device_type=args.device_type,
|
||||||
|
device_host=args.device_host,
|
||||||
|
device_port=args.device_port,
|
||||||
|
upper_host=args.upper_host,
|
||||||
|
upper_port=args.upper_port
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
decoder.start()
|
||||||
|
while not decoder.zmqServer.IsExitApp:
|
||||||
|
time.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
decoder.stop()
|
||||||
Reference in New Issue
Block a user