diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0365811 --- /dev/null +++ b/.gitignore @@ -0,0 +1,55 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ + +# Distribution / packaging +build/ +dist/ + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# data format +*.dat +*.csv +*.edf +*.event +*.edf.event +*.zip +*.xlsx +*.mat +*.json + + +# PyCharm +# JetBrains specific template is maintained in a separate repository that is not distributed with PyCharm itself +# https://github.com/github/gitignore/blob/main/Python.gitignore +.idea/ + +# Logs +*.log + +# Other common ignores +node_modules/ +dist/ +tmp/ +temp/ + +# Project-specific ignores +# Ignore all directories in the root +# merge64ch_0127/ +/P300_speller/braindecode/ +/P300_speller/data/ +/P300_speller/pyRiemann/ +/P300_speller/README/ +/merge64ch_new/ +/merge64ch_tianjinZMQdebug/ + + + + diff --git a/Debug_64ch_Decoder/Decoder.py b/Debug_64ch_Decoder/Decoder.py new file mode 100644 index 0000000..21c121c --- /dev/null +++ b/Debug_64ch_Decoder/Decoder.py @@ -0,0 +1,417 @@ +import ast +import threading +from datetime import datetime +import multiprocessing as mp +import numpy as np +import time +import torch +from queue import Empty +from scipy import signal +from torch.autograd import Variable +from Device.SunnyLinker import SunnyLinker64 +from MI.Algorithm.otherModels import weights_init +from SSMVEP.algorithm.tdca import TDCA +from SSMVEP.algorithm.base import generate_cca_references +from Zmq.zmqServer import zmqServer +from Zmq.zmqClient import zmqClient +from MI.Algorithm.conformer_2class import onlineTrain +from PubLibrary.InifileHelper import IniRead +from SSVEP.dwfbcca import FbccaDw +from Tools.plot_MI_EEG import plotMain + +class Decoder_main(threading.Thread): + def __init__(self): + threading.Thread.__init__(self) + self.Runing=True + self.decoder = None + + self.fs = 250 # 采样率 + self.energy = 0 # 电量 + self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常 + self.decoder_class = None #解码器类别 + + self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果 + def connect(self): + self.thread_data_server = SunnyLinker64('127.0.0.1', 7878, 250, 64, + method='tcp') + self.thread_data_server.toUv = True + self.thread_data_server.start() + + self.zmqServer = zmqServer() + self.zmqServer.start() + self.zmqClient = zmqClient('127.0.0.1', 8088) + self.zmqClient.connect() + + def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号 + # data: (chans, samples) + energy = np.mean(np.var(data, axis=1)) # 各通道方差均值 + if energy > threshold: + return False + return True + def init_Decoder(self,decoder_class): + ''' + 初始化解码器 + :param decoder_class: 'ssvep' or 'ssmvep' or 'mi' + :return: + ''' + self.decoder_class = decoder_class + if decoder_class == 'ssvep': + self.n_chan = 8 + self.thread_data_server.interval_inited = False + DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue')) + self.ListFreq = self.zmqServer.targetFreqs + self.num_target = len(self.ListFreq) + if self.num_target == 0: + return + # 初始化对象 二代算法 + self.dw = FbccaDw(self.fs, self.num_target, self.n_chan, 5, 5, + 0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method) + # frequence band + self.dw.filterFrequenceBank() + self.dw.setNotchFilterPara() + self.calculateCount = 0 + self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.fs), + 5) + self.dw.filterInit() + self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时 + + elif decoder_class == 'ssmvep': + self.thread_data_server.interval_init(decoder_class) + self.n_chan = 8 + self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) + self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位 + self.single_train = 10 # 单类别数量 + self.num_target = 2 # 分类目标数目 + self.list_freqs = np.array([8, 9]) # 刺激频率 + self.list_phase = np.array([0, 0]) # 相位 + self.tdca = TDCA(padding_len=5, n_components=1) + self.Yf = generate_cca_references(self.list_freqs, srate=self.fs, T=self.sample_length, + phases=self.list_phase, n_harmonics=5) + self.parameter_init(5,45) + elif decoder_class == 'mi': + self.thread_data_server.interval_init(decoder_class) + self.n_chan = 21 + self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) + self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位 + self.single_train = 40 # 单类别数量 + self.num_target = 2 # 分类目标数目 + + self.parameter_init(8, 30) + + + def parameter_init(self,bandPass_low,bandPass_high): + self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch] # epoch截取信息 + self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.fs)] # 训练样本epoch + self.trainData = [] #训练数据 + self.trainLabel = [] #训练标签 + self.plotData = [] #报告分析数据 + self.plotLabel = [] #报告分析标签 + self.currentLabel = -1 #刺激界面当前显示的训练标签 + self.train_started = False #是否开始训练模型 + self.load_model = False # 调用模型是否完成的标志 + self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子 + self.b_design = signal.firwin(65, [bandPass_low / (self.fs/2), bandPass_high / (self.fs/2)], pass_zero=False) # 设计8-30Hz带通滤波器 + fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + filePath = './online_Models/' + self.modelPath = ''.join([filePath, fileName, '.pth']) + self.mp_data_queue = mp.Queue() #多进程传参队列 + self.mp_result_queue = mp.Queue() #多进程结果队列 + + def preprocess(self, signal_data): + # # 计算每行的平均值 + row_means = np.mean(signal_data, axis=-1, keepdims=True) + # 对每一行去均值 + signal_data = signal_data - row_means + + signal_data = signal.lfilter(self.b_notch, self.a_notch, signal_data, axis=-1) # 工频陷波 + signal_data = signal.lfilter(self.b_design, 1, signal_data, axis=-1) # 带通滤波 + return signal_data + + def run(self): + while self.Runing: + if self.zmqServer.decoder_switch or self.zmqServer.changeTarget: + self.zmqServer.decoder_switch = False + self.zmqServer.changeTarget = False + self.init_Decoder(self.zmqServer.decoder_class) + + # 同步信息 + if self.zmqServer.state_mode == 'sync': + self.zmqClient.send_to_all('sync', self.zmqClient.state) + self.zmqServer.state_mode = 'rest' + # 状态异常,报告上位机 + if self.status_code != self.thread_data_server.status_code: + self.status_code = self.thread_data_server.status_code + self.zmqClient.send_to_all('status_code', int(self.status_code)) + + # 返回电量 + if self.energy != self.thread_data_server.energy: + self.energy = self.thread_data_server.energy + self.zmqClient.send_to_all('energy', int(self.energy)) + + if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次 + self.thread_data_server.Impedance(True) + self.zmqServer.open_Impedance = -1 + elif self.zmqServer.open_Impedance == False: + self.thread_data_server.Impedance(False) + self.zmqServer.open_Impedance = -1 + + if self.zmqServer.get_Impedance: # 返回阻抗值 + if self.thread_data_server.GetDataLenCount() > 250: + Impe_data = self.thread_data_server.getData(250) + # 计算阻抗 + imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class) + self.zmqClient.send_to_all('impedance', imps.tolist()) + else: + pass + if self.zmqServer.getReport: #返回训练报告内容 + self.zmqServer.getReport = False + allData = np.array(self.plotData) + allLabel = np.array(self.plotLabel) + 1 + nTrials = min(len(allLabel),len(allData)) + if nTrials == 0: + self.zmqClient.send_to_all('miReport',0) + else: + allData = allData[:nTrials] + allLabel = allLabel[:nTrials] + ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', + 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'] + compare_names = ['C3', 'CZ', 'C4'] + miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2, + fs=self.fs) + self.zmqClient.send_to_all('miReport',miReport) + + + if self.decoder_class == 'ssvep': + self.decoder_SSVEP() + elif self.decoder_class == 'ssmvep': + self.decoder_SSMVEP() + elif self.decoder_class == 'mi': + self.decoder_MI() + else: + if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 + if self.thread_data_server.GetDataLenCount() < 25: + time.sleep(0.005) + continue; + self.thread_data_server.getData(25) + + + def decoder_SSVEP(self): + if self.zmqServer.StartDecode: + self.zmqServer.StartDecode = False + self.decodingSteps = 1 + self.thread_data_server.ResetAll() + print('启动预测') + if self.thread_data_server.GetDataLenCount() < 50: + time.sleep(0.005) + return + if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 + data = self.thread_data_server.getDataViaSSVEP(50) + data = data[:self.n_chan, :] + if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热 + self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时 + self.dw.warmFilter(data) # 预热 + self.decodingSteps = 2 + print('预热数据完成。开始预测') + return + if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中 + choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount) + self.calculateCount += 1 + if choosenNum != -1 and self.is_valid_signal(data): + self.decodingSteps = 3 + print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount)) + self.calculateCount = 0 + if self.decodingSteps == 3: # 发送解码后的信息 + self.zmqClient.send_to_all('result', int(choosenNum)) + self.decodingSteps = 0 + print('发送给界面完成。') + def decoder_SSMVEP(self): + '''模型训练''' + if self.load_model == False and all( + self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))): # 模型尚未训练完成 + self.trainData = np.array(self.trainData) + self.trainLabel = np.array(self.trainLabel) + print(np.shape(self.trainData), (self.trainLabel)) + # 保存多个数组到文件 + # np.savez('20250520_yy.npz', array1=self.trainData, array2=self.trainLabel) + # self.decoder = self.fbtdca.fit(self.trainData, self.trainLabel, Yf=self.Yf) + self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf) + now = datetime.now() + formatted_time = now.strftime('%H:%M:%S.%f')[:-3] + print('模型训练完成', formatted_time) + self.load_model = True + self.zmqClient.send_to_all('paradigm', 1) + + '''训练阶段采集数据''' + if self.zmqServer.state_mode == 'train': # 训练状态 + if self.zmqServer.StartTrain: + self.currentLabel = self.zmqServer.currentLabel + self.zmqServer.StartTrain = False + if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ + self.train_epoch[1] \ + + self.thread_data_server.event_inner_idx: + time.sleep(0.0001) + return + print('训练队列数据:', self.thread_data_server.GetDataLenCount()) + trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据 + print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx]) + trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 + trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.train_epoch[ + 0]:self.thread_data_server.event_inner_idx + self.train_epoch[1]] + print('trial: ', self.thread_data_server.event_inner_idx, self.train_epoch[0], self.train_epoch[1]) + if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance( + self.trainLabel, list) \ + and self.trainLabel.count(self.currentLabel) < self.single_train: + self.trainData.append(trainTrial) + self.trainLabel.append(self.currentLabel) + + elif self.zmqServer.state_mode == 'predict': # 测试状态 + if self.load_model == False: # 模型尚未训练完成 + time.sleep(0.01) + return + else: # 已有模型 + if self.zmqServer.StartDecode: + self.zmqServer.StartDecode = False + now = datetime.now() + formatted_time = now.strftime('%H:%M:%S.%f')[:-3] + print('启动预测 ', formatted_time) + + if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ + self.interval_epoch[1] \ + + self.thread_data_server.event_inner_idx: + time.sleep(0.0001) + return + data = self.thread_data_server.get_SSMVEPData() # 读取全部数据 + print('取出的: ', data.shape, 'event: ', data[-2, self.thread_data_server.event_inner_idx]) + data = self.preprocess(data[:self.n_chan, :]) # 预处理 + data = data[:, + self.thread_data_server.event_inner_idx + self.interval_epoch[ + 0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]] + pad_eeg_test = np.zeros( + (data.shape[0], int((self.sample_length + 0.1) * self.fs))) + pad_eeg_test[:, :int(self.sample_length * self.fs)] = data + choosenNum, features_2 = self.decoder.predict(pad_eeg_test) + if isinstance(choosenNum, np.ndarray): + choosenNum = choosenNum[0] + print('结果:', choosenNum, 'rho: ', sorted(features_2[0]), + sorted(features_2[0])[-1] - sorted(features_2[0])[-2]) + self.zmqClient.send_to_all('result', int(choosenNum)) + print('发送给界面完成。') + else: # 休息状态 + if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 + if self.thread_data_server.GetDataLenCount() < 25: + time.sleep(0.005) + return + self.thread_data_server.getData(25) + def decoder_MI(self): + '''模型训练''' + if self.train_started == False and all( + self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)): # 模型尚未训练 + self.zmqClient.send_to_all('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机 + self.train_started = True + self.trainData = np.array(self.trainData) + self.trainLabel = np.array(self.trainLabel) + 1 + # print('训练集:',np.shape(self.trainData), (self.trainLabel)) + p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型 + p.start() + self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath, + 'n_chan': self.n_chan}) + + '''检查模型是否训练完成,调用''' + if self.load_model == False and self.train_started == True: + try: + result = self.mp_result_queue.get_nowait() + if result['status'] == 'success': + print("模型训练完成,加载新模型") + # 调用模型 + self.model = torch.load(self.modelPath, weights_only=False) + self.model.eval() + # 模型预热 + warmup_data = np.random.uniform(-1, 1, (1, 1, self.n_chan, 1000)) + warmup_data = torch.from_numpy(warmup_data) + warmup_data = Variable(warmup_data.type(torch.cuda.FloatTensor)) + with torch.no_grad(): + _ = self.model(warmup_data) + self.load_model = True + self.zmqClient.send_to_all('paradigm', 1) # 模型调用完毕,通知上位机 + else: + print("训练失败:", result['msg']) + except Empty: + pass # 还没完成 + except Exception as e: + print('模型调用失败: ', e) + + '''训练阶段采集数据''' + if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态 + if self.zmqServer.StartTrain: + self.currentLabel = self.zmqServer.currentLabel + self.zmqServer.StartTrain = False + if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ + self.interval_epoch[1] \ + + self.thread_data_server.event_inner_idx: + time.sleep(0.0001) + return + print('训练队列数据:', self.thread_data_server.GetDataLenCount()) + originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据 + print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx]) + trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理 + trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.interval_epoch[ + 0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]] + print('trial: ', self.thread_data_server.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1]) + if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel, + list) \ + and self.trainLabel.count(self.currentLabel) < self.single_train: + self.trainData.append(trainTrial) + self.trainLabel.append(self.currentLabel) + print('训练集:', np.shape(self.trainData)) + self.plotData.append(originalTrial[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[ + 0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]) + self.plotLabel.append(self.currentLabel) + + elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态 + if self.zmqServer.StartDecode: + self.zmqServer.StartDecode = False + now = datetime.now() + formatted_time = now.strftime('%H:%M:%S.%f')[:-3] + print('启动预测 ', formatted_time) + + if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ + self.interval_epoch[1] \ + + self.thread_data_server.event_inner_idx: + time.sleep(0.0001) + return + originalData = self.thread_data_server.get_MIData() # 读取全部数据 + print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.thread_data_server.event_inner_idx]) + start = time.time() + data = self.preprocess(originalData[:self.n_chan, :]) # 预处理 + data = data[:, + self.thread_data_server.event_inner_idx + self.interval_epoch[ + 0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]] + self.plotData.append( + originalData[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[ + 0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]) + + test_data = data[np.newaxis, np.newaxis, :, :] + test_data = torch.from_numpy(test_data) + test_data = Variable(test_data.type(torch.cuda.FloatTensor)) + with torch.no_grad(): + Cls = self.model(test_data) + y_pred = torch.max(Cls, 1)[1] + self.plotLabel.append(int(y_pred.item())) + print('运动意图识别: ', y_pred) + self.zmqClient.send_to_all('result', int(y_pred.item())) + end = time.time() + print(f'发送给界面完成,耗时{end - start:.3f}s。') + else: # 休息状态 + if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 + if self.thread_data_server.GetDataLenCount() < 25: + time.sleep(0.005) + return + self.thread_data_server.getData(25) + + def stop(self): + ''' + 停止运行 + @return: + ''' + self.zmqServer.stop() + self.Runing=False \ No newline at end of file diff --git a/Debug_64ch_Decoder/Device/SunnyLinker.py b/Debug_64ch_Decoder/Device/SunnyLinker.py new file mode 100644 index 0000000..3a7ac59 --- /dev/null +++ b/Debug_64ch_Decoder/Device/SunnyLinker.py @@ -0,0 +1,754 @@ +# -*-coding:utf-8 -*- +''' +SunnyLinker的通讯驱动 +''' +import ast +import socket +import threading +import time +import datetime +from typing import Dict + +import numpy as np +from threading import Thread, Event +import serial +from scipy import signal +from serial.serialutil import SerialException + +from Device.protocol import ProtocolFrame +from PubLibrary.InifileHelper import IniRead + +class RingBuffer: + def __init__(self, n_chan, n_points): + self.n_chan = n_chan + self.n_points = n_points + self.buffer = np.zeros((n_chan, n_points)) + self.currentPtr = 0 + self.readPtr = 0 + self.nUpdate = 0 + self.rawData = np.zeros((n_chan, 1)) + + ## append buffer and update current pointer + def appendBuffer(self, data): + if self.nUpdate == self.n_points: + raise Exception("Buffer is full") + + n = data.shape[1] + + # 计算可以写入的元素数量 + write_count = min(self.n_points - self.nUpdate, n) + # 写入新数据 + self.buffer[:, np.mod(np.arange(self.currentPtr, self.currentPtr + write_count), self.n_points)] = data[:,:write_count] + # 更新结束指针 + self.currentPtr = (self.currentPtr + write_count) % self.n_points + # 更新大小 + self.nUpdate += write_count + + ## get data from buffer + def getData(self, count=50): + # 确保不会尝试读取超过缓冲区当前大小的数据 + count = min(count, self.nUpdate) + + # 计算读取结束后的下一个位置 + next_read_ptr = (self.readPtr + count) % self.n_points + if self.readPtr + count <= self.n_points: + # 情况 1:不环绕,数据是连续的 + end_index = next_read_ptr if next_read_ptr != 0 else self.n_points + data = self.buffer[:, self.readPtr:end_index] + else: + # 情况 2:发生环绕,数据被分成两部分 + # 第一部分:从 readPtr 到缓冲区末尾 + part1 = self.buffer[:, self.readPtr:] + # 第二部分:从缓冲区开头到 (count - part1.shape[1]) 个点 + part2 = self.buffer[:, :next_read_ptr] + # 将两部分在列方向上拼接 + data = np.concatenate((part1, part2), axis=1) + + # 更新读指针 + self.readPtr = next_read_ptr + # 更新大小 + self.nUpdate -= count + return data + + # reset buffer + def resetAllPara(self): + self.nUpdate = 0 + self.currentPtr = 0 + self.readPtr = 0 # add by lizhenhua 清空读指针 + self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区 + + +class SunnyLinker64(Thread, ): + serial_port = str(IniRead('system', 'Serial_port')) + t_buffer = 10 + n_chan = 64 + srate = 250 + receiveData = b'' + toUv=True#转为uV + RingBufferLock = threading.Lock() + + # 单例模式 + _instance = None + _initialized = False # 检查是否已经初始化 + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super(SunnyLinker64, cls).__new__(cls) + return cls._instance + def __init__(self, host='127.0.0.1', port=7878, srate=250, n_chan=64,method = 'tcp'): + if SunnyLinker64._initialized: + return + Thread.__init__(self) + self.daemon = True + self.host = host + self.port = port + self.srate = srate + self.n_chan = n_chan + self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输 + self.__ringBuffer = RingBuffer(self.n_chan + 2, + int(np.round(self.t_buffer * self.srate))) + self.energy = 0 # 电量 + self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常 + self.gain_value = 6 # 增益倍数 + self.interval_inited = False #ssmvep或mi时间窗是否初始化 + + # 设置初始化标志为True,防止重复初始化 + SunnyLinker64._initialized = True + + # --- 新增:用于心跳检测 --- + self.last_called = 0 # 初始化为0 + self.last_called_lock = threading.Lock() # 保护 last_called 的访问 + + def interval_init(self,decoder_class): + if decoder_class == 'ssmvep': + interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) + self.interval_epoch = [int(i * self.srate) for i in interval_epoch] # epoch截取信息 + self.train_epoch = [int(self.interval_epoch[0]), + int(self.interval_epoch[1] + 0.1 * self.srate)] # 训练样本epoch + self.latency = (self.interval_epoch[ + 1] + 0.1 * self.srate) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉 + self.train_latency = (self.train_epoch[1] + 0.1 * self.srate) // 5 + + elif decoder_class == 'mi': + interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) + self.interval_epoch = [int(i * self.srate) for i in interval_epoch] # epoch截取信息 + self.train_epoch = self.interval_epoch.copy() + self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点; + self.train_latency = self.latency + + print('时间窗:', (interval_epoch)) + self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息 + self.event_inner_idx = -1 # event在5位数据包内部的idx + self.epoch_finished = False # 接收epoch是否完整 + self.pack_contain_event = False # 当前包是否含有event + self.predict_event = 99 + self.events = [1, 2, self.predict_event] + if getattr(self, 'serial', None) and self.serial.is_open: + self.serial.close() + self.serial = serial.Serial(self.serial_port, 460800, timeout=1) # 连接同步器串口 + self.interval_inited = True + + def set_sampleRate(self,sampleRate_Code=0x00): + ''' + 设置采样率 + :param sampleRate_Code: 0x00:250Hz,0x01:500Hz,0x02:1000Hz,0x03:2000Hz + ''' + function_code = 0x02 + gain_code = 0x06 + sampleRate_Code = [gain_code,sampleRate_Code] + packed_data = ProtocolFrame.pack(function_code, sampleRate_Code) + if self.method == 'tcp': + self.sock.send(packed_data) + + def push_trigger(self,label): + ''' + 数据打标 + @param label:标签类别 + ''' + function_code = None + label = [label] + packed_data = ProtocolFrame.pack(function_code, label) + if self.method == 'tcp' and hasattr(self,'serial'): + print('发送:', label, datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3]) + self.serial.write(packed_data) + def Impedance(self, On): + ''' + 阻抗检测开关 + :param On:True为开启,False为关闭 + :return: 组好的协议帧 + ''' + function_code = 0x01 + if On: + data = [0x1] + self.gain_value = 6 + else: + data = [0x0] + self.gain_value = 6 + packed_data = ProtocolFrame.pack(function_code, data) + if self.method == 'tcp': + self.sock.send(packed_data) + + def connect(self): + try: + if self.method == 'serial': + # 开启com口,波特率115200,超时5 + self.sock = serial.Serial(self.host, self.port, timeout=5) + self.sock.flushInput() # 清空缓冲区 + count = self.sock.inWaiting() # 获取串口缓冲区数据 + while not count: + count = self.sock.inWaiting() # 获取串口缓冲区数据 + # # 接收和存储数据 + data = (self.sock.read(count)) + self.receiveData = self.receiveData + data # 将接收数据存储在字符串中 + elif self.method == 'tcp': + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.connect((self.host, int(self.port))) + self.set_sampleRate(0x00) #设置250Hz采样率 + except Exception as e: + print("请打开头环") + print(e) + + print("connected") + + def extract_packet(self, packet): + # 存储一个点的八通道数据 + dataList = [] + # 存储116个点的八通道数据 + dataMatrix = [] + + for j in range(5): + for i in range(self.n_chan): + if not self.toUv:#原始数据直接输出 + val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[ + 194 * j + 25 + 2 + i * 3] + + else:#转为uV + val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[ + 194 * j + 25 + 2 + i * 3] + if val < 8388608: + val = val * 4.5 / self.gain_value / 8388608 * 1000000; + else: + val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000; + dataList.append(val) + #同步触发源 + val = packet[194 * j + 25 + (i+1) * 3] + dataList.append(val) + #同步触发序号 + val = packet[194 * j + 25 + (i+1) * 3+1] + dataList.append(val) + + + # 将数据矩阵进行拼接 + if len(dataMatrix) == 0: + dataMatrix = np.asmatrix(dataList) + else: + dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0) + dataList.clear() + return np.transpose(dataMatrix) + + def run(self): + self.connect() + self.running = True + self.PackageLength = 998 + # 启动心跳检测线程 + threading.Thread(target=self.heartbeat_checker, daemon=True).start() + while self.running: + try: + if self.method == 'serial': + count = self.sock.inWaiting() # 获取串口缓冲区数据 + if count: + # 接收和存储数据 + data = (self.sock.read(count)) + self.receiveData = self.receiveData + data # 将接收数据存储在字符串中 + elif self.method == 'tcp': + data = self.sock.recv(600) + if not data: + break + self.receiveData += data + with self.last_called_lock: + self.last_called = time.time() + self.status_code = 1 # 收到数据,标记为正常 + if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind( + b'\x55\x55') >= self.PackageLength - 2: + + index = self.receiveData.index(b'\xaa') + self.receiveData = self.receiveData[index:] + if len(self.receiveData) >= self.PackageLength: + onepackage = self.receiveData[:self.PackageLength] + if onepackage[7] != 0: + self.energy = onepackage[7] # 电量 + self.receiveData = self.receiveData[self.PackageLength:] + dataMatrix = self.extract_packet(onepackage) + try: + with self.RingBufferLock: + if self.interval_inited: + self.epoch_finished = self.detect_event(dataMatrix) + if self.pack_contain_event: + self.__ringBuffer.resetAllPara() # 检测到当前pack含有event,清除ringbuffer中之前的数据 + self.__ringBuffer.appendBuffer(dataMatrix) + # self.plotBuffer.appendBuffer(dataMatrix) + if self.epoch_finished: + time.sleep(0.005) + print('epoch_finished: ', datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3]) + else: + self.__ringBuffer.appendBuffer(dataMatrix) + except Exception as e: + print("锁:写入异常",e) + # self.RingBufferLock.release() + except ConnectionResetError: + self.status_code = 0 # 状态异常 + print("Connection was reset by the peer.") + break + self.sock.close() + + # 检测是否含有标签 + def detect_event(self, samples): + self.pack_contain_event = False + events = np.array(samples[-2])[0].tolist() + for idx, event in enumerate(events): + if int(event) in self.events: + new_key = "".join( + [ + str(event), + datetime.datetime.now().strftime("%Y-%m-%d \ + -%H-%M-%S"), + ] + ) + if event == self.predict_event: + self.count_events[new_key] = self.latency + 1 + else: + self.count_events[new_key] = self.train_latency + 1 + self.event_inner_idx = idx + self.pack_contain_event = True + drop_items = [] + for key, value in self.count_events.items(): + value = value - 1 + if value == 0: + drop_items.append(key) + self.count_events[key] = value + for key in drop_items: + del self.count_events[key] + if drop_items: + return True + return False + + # --- 新增:心跳检测线程 --- + def heartbeat_checker(self): + """ + 定期检查是否在最近2秒内收到 eegData + 如果超过2秒未收到,则设置 status_code = 0 + """ + while self.running: + time.sleep(0.5) # 每0.5秒检查一次 + with self.last_called_lock: + now = time.time() + # 只有收到过一次数据后才开始判断超时 + if self.last_called > 0 and (now - self.last_called) > 2: + if self.status_code != 0: + print("EEG data timeout: disconnected") + self.status_code = 0 + + def getDataViaSSVEP(self,count): + ''' + ssvep的视觉通道,共8个通道 + @param count: 每通道读取的数值数量 + @return: 返回最新的数值 + ''' + data=self.getData(count) + # PO5,POZ,PO6,O2,PO8,OZ,O1,PO7 64是event导联 + rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55,64] + row_to_select=np.array(rows_to_extract) + data=data[row_to_select,:] + return data + def get_MIData(self): + ''' + 取出当前所有数值 + :return: + ''' + data = self.getData(self.__ringBuffer.nUpdate) + #MI选取导联:FC3,FC1,FCZ,FC2,FC4,C5,C3,C1,CZ,C2,C4,C6,CP3,CP1,CP2,CP4,P3,P1,PZ,P2,P4,event1,event2 + rows_to_extract = [8, 15, 12, 14, 18, 23, 16,59,50,58,17,45,29,11,10,19,20,61,51,60,21,64,65] + row_to_select = np.array(rows_to_extract) + data = data[row_to_select,:] + return data + def get_SSMVEPData(self): + ''' + 取出当前所有数值 + :return: + ''' + data = self.getData(self.__ringBuffer.nUpdate) + # PO5,POZ,PO6,O2,PO8,OZ,O1,PO7 64是event导联 + rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64,65] + row_to_select = np.array(rows_to_extract) + data = data[row_to_select, :] + return data + def getImpedance(self, data,decoder_class): + ''' + 获取阻抗值,已经放大100倍,单位是kΩ + @param data: 准备计算的通道数据,每通道200个值,注意不要把信号打标的通道传进来 + @return: 返回各个通道的阻抗值 + ''' + impedanceList = [] + for channelindex in range(data.shape[0]): + if len(data[channelindex]) > 0: + data_list = [] + # 设计陷波滤波器,去除50Hz成分 + is50filter = True + if is50filter: + b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽,1000是采样频率 + data_list = signal.lfilter(b, a, data[channelindex].tolist()) + + else: + data_list.extend(data[channelindex].tolist()) + + data_list = data_list[-1000:] + # 执行FFT + fft_result = np.fft.fft(data_list) + fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果 + freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴 + + # y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())], + # fft_magnitude[1:-1] * 2 / len(t[0].tolist()), + # [fft_magnitude[-1] / len(t[0].tolist())])) + + # 找到幅值最大的频率成分的索引(忽略直流分量,即索引0) + max_index = np.argmax(fft_magnitude[1:]) + + # 获取最大幅值的频率索引(加上1,因为索引0是直流分量) + freq_index = max_index + 1 + + # 获取最大幅值 + max_magnitude = fft_magnitude[freq_index] + + # 阻抗 + import math + result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4 + result *= 0.44 * 100 # 统一放大100倍 + impedanceList.append(int(result)) + # print(max_magnitude, result) + else: + impedanceList.append(0) + impedances = np.array(impedanceList) + if decoder_class == 'mi': + impedances = impedances[np.array([8, 15, 12, 14, 18, 23, 16,59,50,58,17,45,29,11,10,19,20,61,51,60,21])] + else: + impedances = impedances[np.array([13, 3, 2, 46, 9, 54, 47, 55])] + return impedances + def getData(self,count): + ''' + 获取最新的数据 + @param count: 每通道返回的最数值数目 + @return: 所有通道的最新count个数值 + ''' + data=None + try: + with self.RingBufferLock: + data = self.__ringBuffer.getData(count) + except: + print("锁:读取异常") + # self.RingBufferLock.release() + + + return data + def GetDataLenCount(self): + ''' + 获取最新缓存中每个通道的数量 + @return: + ''' + return self.__ringBuffer.nUpdate + + def ResetAll(self): + ''' + 清空缓存 + @return: + ''' + with self.RingBufferLock: + self.__ringBuffer.resetAllPara() + def stop(self): + self.running = False + +class SunnyLinker8(Thread, ): + receiveData = '' + t_buffer = 10 + n_chan = 9 + srate = 1000 + receiveData = b'' + toUv=False#转为uV + RingBufferLock = threading.Lock() + def __init__(self, host, port, srate=1000, n_chan=9,method = 'tcp'): + Thread.__init__(self) + self.daemon = True + self.host = host + self.port = port + self.srate = srate + self.n_chan = n_chan + self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输 + self.__ringBuffer = RingBuffer(self.n_chan + 2, + int(np.round(self.t_buffer * self.srate))) + self.energy = 0 #电量 + self.status_code = 0 #与采集设备通信的状态码,0为异常,1为正常 + self.gain_value = 6 # 增益倍数 + + def push_trigger(self,label): + ''' + 数据打标 + @param label:标签类别 + ''' + function_code = None + label = [label] + packed_data = ProtocolFrame.pack(function_code, label) + if self.method == 'tcp': + self.sock.send(packed_data) + elif self.method == 'serial': + self.sock.write(packed_data) + + def Impedance(self, On): + ''' + 阻抗检测开关 + :param On:True为开启,False为关闭 + :return: 组好的协议帧 + ''' + function_code = None + if On: + data = [0xA1] + self.gain_value = 24 + else: + data = [0xA0] + self.gain_value = 6 + packed_data = ProtocolFrame.pack(function_code, data) + if self.method == 'tcp': + self.sock.send(packed_data) + elif self.method == 'serial': + self.sock.write(packed_data) + + def connect(self): + try: + if self.method == 'serial': + # 开启com口,波特率115200,超时5 + self.sock = serial.Serial(self.host, self.port, timeout=5) + self.sock.flushInput() # 清空缓冲区 + count = self.sock.inWaiting() # 获取串口缓冲区数据 + while not count: + count = self.sock.inWaiting() # 获取串口缓冲区数据 + # # 接收和存储数据 + data = (self.sock.read(count)) + self.receiveData = self.receiveData + data # 将接收数据存储在字符串中 + elif self.method == 'tcp': + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.connect((self.host, int(self.port))) + except Exception as e: + print("请打开头环") + print(e) + + print("connected") + + def extract_packet(self, packet): + # 存储一个点的八通道数据 + dataList = [] + # 存储116个点的八通道数据 + dataMatrix = [] + + # index = (packet[1] << 24) | (packet[2] << 16) | (packet[3] << 8) | packet[4] + # print(index) + + for j in range(5): + for i in range(self.n_chan): + if not self.toUv:#原始数据直接输出 + val = (packet[26 * j + 25 + i * 3] << 16) | (packet[26 * j + 25 + 1 + i * 3] << 8) | packet[ + 26 * j + 25 + 2 + i * 3] + + else:#转为uV + val = (packet[26 * j + 25 + i * 3] << 16) | (packet[26 * j + 25 + 1 + i * 3] << 8) | packet[ + 26 * j + 25 + 2 + i * 3] + if val < 8388608: + val = val * 4.5 / self.gain_value / 8388608 * 1000000; + else: + val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000; + dataList.append(val) + #同步触发源 + val = packet[26 * j + 25 + (i+1) * 3] + dataList.append(val) + #同步触发序号 + val = packet[26 * j + 25 + (i+1) * 3+1] + dataList.append(val) + + + # 将数据矩阵进行拼接 + if len(dataMatrix) == 0: + dataMatrix = np.asmatrix(dataList) + else: + dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0) + dataList.clear() + return np.transpose(dataMatrix) + + def run(self): + self.connect() + self.running = True + self.PackageLength = 158 + start_time = time.time() + + try: + while self.running: + if self.method == 'serial': + end_time = time.time() + if end_time-start_time > 2: #超过2s未收到数据 + self.status_code = 0 #状态异常 + count = self.sock.inWaiting() # 获取串口缓冲区数据 + if count: + start_time = time.time() + self.status_code = 1 # 收到数据,状态正常 + # 接收和存储数据 + data = (self.sock.read(count)) + self.receiveData = self.receiveData + data # 将接收数据存储在字符串中 + elif self.method == 'tcp': + data = self.sock.recv(100) + if not data: + break + self.receiveData += data + if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind( + b'\x55\x55') >= self.PackageLength - 2: + + index = self.receiveData.index(b'\xaa') + self.receiveData = self.receiveData[index:] + if len(self.receiveData) >= self.PackageLength: + onepackage = self.receiveData[:self.PackageLength] + if onepackage[7] != 0: + self.energy = onepackage[7] # 电量 + self.receiveData = self.receiveData[self.PackageLength:] + dataMatrix = self.extract_packet(onepackage) + try: + with self.RingBufferLock: + self.__ringBuffer.appendBuffer(dataMatrix) + except: + print("锁:写入异常") + self.sock.close() + + except ConnectionResetError: + self.status_code = 0 # 状态异常 + print("Connection was reset by the peer.") + except SerialException as Se: + self.status_code = 0 + print('串口通信异常!请检查适配器') + + + def process_packet(self): + if self.circular_buffer.buffer_length > 158: + packet = self.circular_buffer.extract_packet() + + if packet: + # Here you would parse the packet according to the protocol + # print("Received packet:%s,index:%s", len(packet),str(integer_value)) + return packet + else: + print("Received Nothing") + return None + + def getDataViaSSVEP(self,count): + ''' + ssvep的视觉通道,共8个通道 + @param count: 每通道读取的数值数量 + @return: 返回最新的数值 + ''' + data=self.getData(count) + data=data[:8,:] + return data + + def getImpedance(self, data): + ''' + 获取阻抗值,已经放大100倍,单位是kΩ + @param data: 准备计算的通道数据,每通道200个值,注意不要把信号打标的通道传进来 + @return: 返回各个通道的阻抗值 + ''' + impedanceList = [] + for channelindex in range(data.shape[0]): + if len(data[channelindex]) > 0: + data_list = [] + # 设计陷波滤波器,去除50Hz成分 + is50filter = True + if is50filter: + b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽,1000是采样频率 + data_list = signal.lfilter(b, a, data[channelindex].tolist()) + + else: + data_list.extend(data[channelindex].tolist()) + + data_list = data_list[-1000:] + # 执行FFT + fft_result = np.fft.fft(data_list) + fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果 + freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴 + + # y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())], + # fft_magnitude[1:-1] * 2 / len(t[0].tolist()), + # [fft_magnitude[-1] / len(t[0].tolist())])) + + # 找到幅值最大的频率成分的索引(忽略直流分量,即索引0) + max_index = np.argmax(fft_magnitude[1:]) + + # 获取最大幅值的频率索引(加上1,因为索引0是直流分量) + freq_index = max_index + 1 + + # 获取最大幅值 + max_magnitude = fft_magnitude[freq_index] + + # 阻抗 + import math + result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4 + result *= 0.44 * 100 # 统一放大100倍 + impedanceList.append(int(result)) + # print(max_magnitude, result) + else: + impedanceList.append(0) + # impedances = ":".join(map(str, impedanceList)) + impedances = np.array(impedanceList) + impedances = impedances[:8] + return impedances + def getData(self,count): + ''' + 获取最新的数据 + @param count: 每通道返回的最数值数目 + @return: 所有通道的最新count个数值 + ''' + data=None + try: + with self.RingBufferLock: + data = self.__ringBuffer.getData(count) + except: + print("锁:读取异常") + # self.RingBufferLock.release() + + + return data + def GetDataLenCount(self): + ''' + 获取最新缓存中每个通道的数量 + @return: + ''' + return self.__ringBuffer.nUpdate + + def ResetAll(self): + ''' + 清空缓存 + @return: + ''' + with self.RingBufferLock: + self.__ringBuffer.resetAllPara() + def stop(self): + self.running = False + + +if __name__ == "__main__": + # Usage + Linker = SunnyLinker64('127.0.0.1', 5086, 1000, 65) + Linker.start() + + try: + while True: + time.sleep(0.005) + if(Linker.count()>0): + # print(Linker.ringBuffer.nUpdate) + t = Linker.getData() + print(t.shape[1], Linker.count()) + # Linker.ringBuffer.nUpdate=0 + # time.sleep(0.2) + except KeyboardInterrupt: + Linker.stop() diff --git a/Debug_64ch_Decoder/Device/protocol.py b/Debug_64ch_Decoder/Device/protocol.py new file mode 100644 index 0000000..62b274b --- /dev/null +++ b/Debug_64ch_Decoder/Device/protocol.py @@ -0,0 +1,193 @@ +from typing import List, Tuple, Union, Optional + + +class ProtocolFrame: + # 协议常量 + FRAME_HEADER = 0xAA + FRAME_TAIL1 = 0x55 + FRAME_TAIL2 = 0x55 + RESERVED_SIZE = 6 + MIN_FRAME_SIZE = 13 # 帧头1 + 功能1 + 长度2 + 预留6 + CRC1 + 包尾2 + MAX_DATA_LENGTH = 0xFFFF # 最大数据长度 (2字节能表示的最大值) + + @staticmethod + def calculate_crc8(data: bytes) -> bytes: + """ + 计算CRC8校验值 + Args: + data: 需要计算CRC的数据 + Returns: + 一个字节的CRC值(bytes类型) + """ + crc = 0 + for byte in data: + crc ^= byte + for _ in range(8): + crc = ((crc << 1) ^ 0x07 if crc & 0x80 else crc << 1) & 0xFF + return bytes([crc]) + + @classmethod + def pack(cls, function, data: Union[bytes, bytearray, List[int]], + reserved: Optional[Union[bytes, bytearray, List[int]]] = None) -> bytes: + """ + 协议打包函数 + + Args: + function: 功能码 (1字节) + data: 数据块 + reserved: 预留字节(6字节,可选) + + Returns: + 打包后的字节数据 + """ + # 检查功能码 + if function != None: + if not 0 <= function <= 0xFF: + raise ValueError("功能码必须是1字节") + + # 转换数据为bytearray + if isinstance(data, list): + data = bytearray(data) + elif isinstance(data, bytes): + data = bytearray(data) + + # 检查数据长度 + data_length = len(data) + if data_length > cls.MAX_DATA_LENGTH: + raise ValueError(f"数据长度超过最大值 {cls.MAX_DATA_LENGTH}") + + # 处理预留字节 + if reserved is None: + reserved = bytearray([0] * cls.RESERVED_SIZE) + else: + if isinstance(reserved, list): + reserved = bytearray(reserved) + elif isinstance(reserved, bytes): + reserved = bytearray(reserved) + if len(reserved) != cls.RESERVED_SIZE: + raise ValueError(f"预留字节必须是{cls.RESERVED_SIZE}字节") + + # 构建帧 + frame = bytearray([cls.FRAME_HEADER]) # 帧头 (1字节) + if function != None: + frame.append(function) # 功能码 (1字节) + data_length+=6 + + # 数据长度 (2字节,大端序) + frame.append((data_length >> 8) & 0xFF) # 高字节 + frame.append(data_length & 0xFF) # 低字节 + + if function != None: + frame.extend(reserved) # 预留字节 (6字节) + frame.extend(data) # 数据块 (变长) + + # 计算CRC (从功能码开始到数据块结束) + crc = cls.calculate_crc8(frame[1:]) # 不包含帧头 + frame.extend(crc) # CRC校验 (1字节) + + # 添加帧尾 + frame.extend([cls.FRAME_TAIL1, cls.FRAME_TAIL2]) # 帧尾 (2字节) + + return bytes(frame) + + @classmethod + def unpack(cls, data: Union[bytes, bytearray]) -> Tuple[int, bytearray, bytearray]: + """ + 协议解包函数 + + Args: + data: 待解析的字节数据 + + Returns: + (功能码, 数据块, 预留字节) + + Raises: + ValueError: 当数据格式不正确时 + """ + # 检查数据长度 + if len(data) < cls.MIN_FRAME_SIZE: + raise ValueError("数据长度不足") + + # 检查帧头 + if data[0] != cls.FRAME_HEADER: + raise ValueError("帧头错误") + + # 检查帧尾 + if data[-2:] != bytes([cls.FRAME_TAIL1, cls.FRAME_TAIL2]): + raise ValueError("帧尾错误") + + # 解析基本信息 + function = data[1] # 功能码 (1字节) + + # 数据长度 (2字节,大端序) + data_length = (data[2] << 8) | data[3] + + reserved = data[4:10] # 预留字节 (6字节) + + # 检查数据长度 + expected_length = cls.MIN_FRAME_SIZE + data_length + if len(data) != expected_length: + raise ValueError(f"数据长度不匹配: 期望{expected_length}字节,实际{len(data)}字节") + + # 提取数据块 + payload = data[10:10 + data_length] + + # 验证CRC (从功能码开始到数据块结束) + received_crc = data[-3] + calculated_crc = cls.calculate_crc8(data[1:-3])[0] # 获取字节值 + + if received_crc != calculated_crc: + raise ValueError(f"CRC校验失败: 期望{calculated_crc:02X},实际{received_crc:02X}") + + return function, bytearray(payload), bytearray(reserved) + + + +def print_hex(data: bytes, label: str = ""): + """打印十六进制数据,并按字节添加空格""" + hex_str = ' '.join([f"{b:02X}" for b in data]) + if label: + print(f"{label}: {hex_str}") + else: + print(hex_str) + + +def print_frame_details(data: bytes): + """打印帧的详细信息""" + print("帧详细信息:") + print(f"帧头: {data[0]:02X}") + print(f"功能码: {data[1]:02X}") + print(f"数据长度: {data[2]:02X} {data[3]:02X} ({(data[2] << 8) | data[3]}字节)") + print(f"预留字节: {' '.join([f'{b:02X}' for b in data[4:10]])}") + data_length = (data[2] << 8) | data[3] + print(f"数据块: {' '.join([f'{b:02X}' for b in data[10:10 + data_length]])}") + print(f"CRC校验: {data[-3]:02X}") + print(f"帧尾: {data[-2]:02X} {data[-1]:02X}") + + +# 使用示例 +def example_usage(): + try: + + + # 示例1:简单数据打包 + function_code = 0x01 + data = [0x1] + packed_data = ProtocolFrame.pack(function_code, data) + print_hex(packed_data, "示例1 - 完整帧") + print_frame_details(packed_data) + print() + + # 示例3:解包验证 + function, payload, reserved = ProtocolFrame.unpack(packed_data) + print("解包结果:") + print(f"功能码: 0x{function:02X}") + print_hex(payload, "数据块") + print_hex(reserved, "预留字节") + + except ValueError as e: + print(f"错误: {e}") + + +if __name__ == "__main__": + example_usage() \ No newline at end of file diff --git a/Debug_64ch_Decoder/MI/Algorithm/conformer_2class.py b/Debug_64ch_Decoder/MI/Algorithm/conformer_2class.py new file mode 100644 index 0000000..8148b68 --- /dev/null +++ b/Debug_64ch_Decoder/MI/Algorithm/conformer_2class.py @@ -0,0 +1,409 @@ +""" +EEG Conformer + +Convolutional Transformer for EEG decoding + +Couple CNN and Transformer in a concise manner with amazing results +""" +# remember to change paths + +import os +gpus = [0] +os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' +os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus)) +import numpy as np +import math +import random +import time +import datetime + +from torch.utils.data import DataLoader +from torch.autograd import Variable + +import torch +import torch.nn.functional as F +from torch import nn +from torch import Tensor +from einops import rearrange +from einops.layers.torch import Rearrange, Reduce +# from common_spatial_pattern import csp + +# from torch.utils.tensorboard import SummaryWriter +from torch.backends import cudnn +cudnn.benchmark = True +cudnn.deterministic = True +from sklearn.model_selection import train_test_split +# writer = SummaryWriter('./TensorBoardX/') + + +# Convolution module +# use conv to capture local features, instead of postion embedding. +class PatchEmbedding(nn.Module): + def __init__(self, emb_size=40,n_chan=8): + # self.patch_size = patch_size + super().__init__() + + self.shallownet = nn.Sequential( + nn.Conv2d(1, 40, (1, 25), (1, 1)), + nn.Conv2d(40, 40, (n_chan, 1), (1, 1)), + nn.BatchNorm2d(40), + nn.ELU(), + nn.AvgPool2d((1, 75), (1, 15)), # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT + nn.Dropout(0.5), + ) + + self.projection = nn.Sequential( + nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)), # transpose, conv could enhance fiting ability slightly + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + + def forward(self, x: Tensor) -> Tensor: + b, _, _, _ = x.shape + x = self.shallownet(x) + x = self.projection(x) + return x + + +class MultiHeadAttention(nn.Module): + def __init__(self, emb_size, num_heads, dropout): + super().__init__() + self.emb_size = emb_size + self.num_heads = num_heads + self.keys = nn.Linear(emb_size, emb_size) + self.queries = nn.Linear(emb_size, emb_size) + self.values = nn.Linear(emb_size, emb_size) + self.att_drop = nn.Dropout(dropout) + self.projection = nn.Linear(emb_size, emb_size) + + def forward(self, x: Tensor, mask: Tensor = None) -> Tensor: + queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads) + keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads) + values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads) + energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) + if mask is not None: + fill_value = torch.finfo(torch.float32).min + energy.mask_fill(~mask, fill_value) + + scaling = self.emb_size ** (1 / 2) + att = F.softmax(energy / scaling, dim=-1) + att = self.att_drop(att) + out = torch.einsum('bhal, bhlv -> bhav ', att, values) + out = rearrange(out, "b h n d -> b n (h d)") + out = self.projection(out) + return out + + +class ResidualAdd(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, **kwargs): + res = x + x = self.fn(x, **kwargs) + x += res + return x + + +class FeedForwardBlock(nn.Sequential): + def __init__(self, emb_size, expansion, drop_p): + super().__init__( + nn.Linear(emb_size, expansion * emb_size), + nn.GELU(), + nn.Dropout(drop_p), + nn.Linear(expansion * emb_size, emb_size), + ) + + +class GELU(nn.Module): + def forward(self, input: Tensor) -> Tensor: + return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0))) + + +class TransformerEncoderBlock(nn.Sequential): + def __init__(self, + emb_size, + num_heads=10, + drop_p=0.5, + forward_expansion=4, + forward_drop_p=0.5): + super().__init__( + ResidualAdd(nn.Sequential( + nn.LayerNorm(emb_size), + MultiHeadAttention(emb_size, num_heads, drop_p), + nn.Dropout(drop_p) + )), + ResidualAdd(nn.Sequential( + nn.LayerNorm(emb_size), + FeedForwardBlock( + emb_size, expansion=forward_expansion, drop_p=forward_drop_p), + nn.Dropout(drop_p) + ) + )) + + +class TransformerEncoder(nn.Sequential): + def __init__(self, depth, emb_size): + super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)]) + + +class ClassificationHead(nn.Sequential): + def __init__(self, emb_size, n_classes): + super().__init__() + + # global average pooling + self.clshead = nn.Sequential( + Reduce('b n e -> b e', reduction='mean'), + nn.LayerNorm(emb_size), + nn.Linear(emb_size, n_classes) + ) + self.fc = nn.Sequential( + nn.Linear(2440, 256), + nn.ELU(), + nn.Dropout(0.5), + nn.Linear(256, 32), + nn.ELU(), + nn.Dropout(0.3), + nn.Linear(32, 2) + ) + + def forward(self, x): + x = x.contiguous().view(x.size(0), -1) + out = self.fc(x) + return out + + +class Conformer(nn.Sequential): + def __init__(self, emb_size=40, depth=6, n_classes=2,n_chan=8, **kwargs): + super().__init__( + + PatchEmbedding(emb_size,n_chan), + TransformerEncoder(depth, emb_size), + ClassificationHead(emb_size, n_classes) + ) + + +class ExP(): + def __init__(self,n_chan): + super(ExP, self).__init__() + self.n_chan = n_chan + self.batch_size = 24 + self.n_epochs = 250 + self.c_dim = 4 + self.lr = 0.0002 + self.b1 = 0.5 + self.b2 = 0.999 + + self.start_epoch = 0 + # 创建目录 + os.makedirs("online_Models", exist_ok=True) + self.log_write = open("./online_Models/log_result.txt", "w") + + + self.Tensor = torch.cuda.FloatTensor + self.LongTensor = torch.cuda.LongTensor + + self.criterion_cls = torch.nn.CrossEntropyLoss().cuda() + + self.model = Conformer(n_chan=self.n_chan).cuda() + self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))]) + self.model = self.model.cuda() + + # self.model = EEGNet().cuda() + # self.model = nn.DataParallel(self.model,device_ids=[i for i in range(len(gpus))]) + # self.model = self.model.cuda() + # summary(self.model, (1, 8, 1000)) + + + # Segmentation and Reconstruction (S&R) data augmentation + def interaug(self, timg, label): + # 确保输入是 numpy 数组(CPU) + if isinstance(timg, torch.Tensor): + timg = timg.cpu().numpy() + if isinstance(label, torch.Tensor): + label = label.cpu().numpy() + + aug_data = [] + aug_label = [] + for cls4aug in range(2): + cls_idx = np.where(label == cls4aug + 1) + tmp_data = timg[cls_idx] + tmp_label = label[cls_idx] + tmp_aug_data = np.zeros((int(self.batch_size / 2), 1, self.n_chan, 1000)) + for ri in range(int(self.batch_size / 2)): + for rj in range(8): + rand_idx = np.random.randint(0, tmp_data.shape[0], 8) + tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :, + rj * 125:(rj + 1) * 125] + + aug_data.append(tmp_aug_data) + aug_label.append(tmp_label[:int(self.batch_size / 2)]) + aug_data = np.concatenate(aug_data) + aug_label = np.concatenate(aug_label) + aug_shuffle = np.random.permutation(len(aug_data)) + aug_data = aug_data[aug_shuffle, :, :] + aug_label = aug_label[aug_shuffle] + + # 返回 numpy 数组,由调用方决定是否移到 GPU + return aug_data, aug_label + + def train(self,all_data,all_label,model_path): + all_data = np.array(all_data);all_label = np.array(all_label) + all_data = np.expand_dims(all_data, axis=1) + train_data, test_data, train_label, test_label = train_test_split(all_data, all_label, test_size=0.2, + random_state=42, stratify=all_label,shuffle=True) + + # === 优化:一次性预生成增强数据,避免每个 batch 都重复计算 === + aug_data, aug_label = self.interaug(train_data, train_label) + # 将原始数据和增强数据合并,再一起打乱 + train_data_full = np.concatenate([train_data, aug_data], axis=0) + train_label_full = np.concatenate([train_label, aug_label], axis=0) + shuffle_idx = np.random.permutation(len(train_data_full)) + train_data_full = train_data_full[shuffle_idx] + train_label_full = train_label_full[shuffle_idx] + + img = torch.from_numpy(train_data_full) + label = torch.from_numpy(train_label_full-1) + + dataset = torch.utils.data.TensorDataset(img, label) + self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True) + + test_data = torch.from_numpy(test_data) + test_label = torch.from_numpy(test_label-1) + test_dataset = torch.utils.data.TensorDataset(test_data, test_label) + self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True) + + # Optimizers + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2)) + + test_data = Variable(test_data.type(self.Tensor)) + test_label = Variable(test_label.type(self.LongTensor)) + + bestAcc = 0 + averAcc = 0 + num = 0 + Y_true = 0 + Y_pred = 0 + + # Train the cnn model + for e in range(self.n_epochs): + # in_epoch = time.time() + self.model.train() + for i, (img, label) in enumerate(self.dataloader): + + img = Variable(img.cuda().type(self.Tensor)) + label = Variable(label.cuda().type(self.LongTensor)) + + outputs = self.model(img) + + loss = self.criterion_cls(outputs, label) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + + # out_epoch = time.time() + + + # test process + if (e + 1) % 1 == 0: + self.model.eval() + Cls = self.model(test_data) + + loss_test = self.criterion_cls(Cls, test_label) + y_pred = torch.max(Cls, 1)[1] + acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0)) + train_pred = torch.max(outputs, 1)[1] + train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0)) + + print('Epoch:', e, + ' Train loss: %.6f' % loss.detach().cpu().numpy(), + ' Test loss: %.6f' % loss_test.detach().cpu().numpy(), + ' Train accuracy %.6f' % train_acc, + ' Test accuracy is %.6f' % acc) + + self.log_write.write(str(e) + " " + str(acc) + "\n") + num = num + 1 + averAcc = averAcc + acc + if acc > bestAcc: + bestAcc = acc + Y_true = test_label + Y_pred = y_pred + + + torch.save(self.model, model_path) + averAcc = averAcc / num + print('The average accuracy is:', averAcc) + print('The best accuracy is:', bestAcc) + self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n") + self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n") + + return bestAcc, averAcc, Y_true, Y_pred + # writer.close() + + +def onlineTrain(data_queue,result_queue): + import torch + print(f"[DEBUG] torch.__version__ = {torch.__version__}") + print(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}") + try: + starttime = datetime.datetime.now() + + # seed_n = np.random.randint(2025) + seed_n = 1877 + random.seed(seed_n) + np.random.seed(seed_n) + torch.manual_seed(seed_n) + torch.cuda.manual_seed(seed_n) + torch.cuda.manual_seed_all(seed_n) + + + # 从队列获取训练数据 + data = data_queue.get(timeout=30) + all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan'] + exp = ExP(n_chan) + print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path) + bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path) + print('THE BEST ACCURACY IS ' + str(bestAcc)) + + endtime = datetime.datetime.now() + print('train duration: ',str(endtime - starttime)) + + # 将模型或参数传回 + result_queue.put({ + 'status': 'success', + 'model_state': model_path, # 或保存路径 + 'timestamp': time.time() + }) + except Exception as e: + result_queue.put({'status': 'error', 'msg': str(e)}) + +def offlineTrain(all_data,all_label,modelPath): + starttime = datetime.datetime.now() + + # seed_n = np.random.randint(2025) + seed_n = 1877 + print('seed is ' + str(seed_n)) + random.seed(seed_n) + np.random.seed(seed_n) + torch.manual_seed(seed_n) + torch.cuda.manual_seed(seed_n) + torch.cuda.manual_seed_all(seed_n) + + exp = ExP() + + bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath) + print('THE BEST ACCURACY IS ' + str(bestAcc)) + + endtime = datetime.datetime.now() + print('train duration: ',str(endtime - starttime)) + + + +if __name__ == "__main__": + print(time.asctime(time.localtime(time.time()))) + print(time.asctime(time.localtime(time.time()))) diff --git a/Debug_64ch_Decoder/MI/Algorithm/conformer_2class_cpu.py b/Debug_64ch_Decoder/MI/Algorithm/conformer_2class_cpu.py new file mode 100644 index 0000000..6e29bc3 --- /dev/null +++ b/Debug_64ch_Decoder/MI/Algorithm/conformer_2class_cpu.py @@ -0,0 +1,382 @@ +""" +EEG Conformer + +Convolutional Transformer for EEG decoding + +Couple CNN and Transformer in a concise manner with amazing results +""" +# remember to change paths +import os +import numpy as np +import math +import random +import time +import datetime + +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch import nn +from torch import Tensor +from einops import rearrange +from einops.layers.torch import Rearrange, Reduce +from torch.backends import cudnn +from sklearn.model_selection import train_test_split +# writer = SummaryWriter('./TensorBoardX/') + + +# Convolution module +# use conv to capture local features, instead of postion embedding. +class PatchEmbedding(nn.Module): + def __init__(self, emb_size=40): + # self.patch_size = patch_size + super().__init__() + + self.shallownet = nn.Sequential( + nn.Conv2d(1, 40, (1, 25), (1, 1)), + nn.Conv2d(40, 40, (8, 1), (1, 1)), + nn.BatchNorm2d(40), + nn.ELU(), + nn.AvgPool2d((1, 75), (1, 15)), # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT + nn.Dropout(0.5), + ) + + self.projection = nn.Sequential( + nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)), # transpose, conv could enhance fiting ability slightly + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + + def forward(self, x: Tensor) -> Tensor: + b, _, _, _ = x.shape + x = self.shallownet(x) + x = self.projection(x) + return x + + +class MultiHeadAttention(nn.Module): + def __init__(self, emb_size, num_heads, dropout): + super().__init__() + self.emb_size = emb_size + self.num_heads = num_heads + self.keys = nn.Linear(emb_size, emb_size) + self.queries = nn.Linear(emb_size, emb_size) + self.values = nn.Linear(emb_size, emb_size) + self.att_drop = nn.Dropout(dropout) + self.projection = nn.Linear(emb_size, emb_size) + + def forward(self, x: Tensor, mask: Tensor = None) -> Tensor: + queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads) + keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads) + values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads) + energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) + if mask is not None: + fill_value = torch.finfo(torch.float32).min + energy.mask_fill(~mask, fill_value) + + scaling = self.emb_size ** (1 / 2) + att = F.softmax(energy / scaling, dim=-1) + att = self.att_drop(att) + out = torch.einsum('bhal, bhlv -> bhav ', att, values) + out = rearrange(out, "b h n d -> b n (h d)") + out = self.projection(out) + return out + + +class ResidualAdd(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, **kwargs): + res = x + x = self.fn(x, **kwargs) + x += res + return x + + +class FeedForwardBlock(nn.Sequential): + def __init__(self, emb_size, expansion, drop_p): + super().__init__( + nn.Linear(emb_size, expansion * emb_size), + nn.GELU(), + nn.Dropout(drop_p), + nn.Linear(expansion * emb_size, emb_size), + ) + + +class GELU(nn.Module): + def forward(self, input: Tensor) -> Tensor: + return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0))) + + +class TransformerEncoderBlock(nn.Sequential): + def __init__(self, + emb_size, + num_heads=10, + drop_p=0.5, + forward_expansion=4, + forward_drop_p=0.5): + super().__init__( + ResidualAdd(nn.Sequential( + nn.LayerNorm(emb_size), + MultiHeadAttention(emb_size, num_heads, drop_p), + nn.Dropout(drop_p) + )), + ResidualAdd(nn.Sequential( + nn.LayerNorm(emb_size), + FeedForwardBlock( + emb_size, expansion=forward_expansion, drop_p=forward_drop_p), + nn.Dropout(drop_p) + ) + )) + + +class TransformerEncoder(nn.Sequential): + def __init__(self, depth, emb_size): + super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)]) + + +class ClassificationHead(nn.Sequential): + def __init__(self, emb_size, n_classes): + super().__init__() + + # global average pooling + self.clshead = nn.Sequential( + Reduce('b n e -> b e', reduction='mean'), + nn.LayerNorm(emb_size), + nn.Linear(emb_size, n_classes) + ) + self.fc = nn.Sequential( + nn.Linear(2440, 256), + nn.ELU(), + nn.Dropout(0.5), + nn.Linear(256, 32), + nn.ELU(), + nn.Dropout(0.3), + nn.Linear(32, 2) + ) + + def forward(self, x): + x = x.contiguous().view(x.size(0), -1) + out = self.fc(x) + return out + + +class Conformer(nn.Sequential): + def __init__(self, emb_size=40, depth=6, n_classes=2, **kwargs): + super().__init__( + + PatchEmbedding(emb_size), + TransformerEncoder(depth, emb_size), + ClassificationHead(emb_size, n_classes) + ) + + +class ExP(): + def __init__(self): + super(ExP, self).__init__() + self.batch_size = 24 + self.n_epochs = 250 + self.c_dim = 4 + self.lr = 0.0002 + self.b1 = 0.5 + self.b2 = 0.999 + + self.start_epoch = 0 + + self.log_write = open("./online_Models/log_result.txt", "w") + + # 自动选择设备:有 GPU 用 GPU,否则用 CPU + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # self.device = torch.device("cpu") + print(f"Using device: {self.device}") + + # 定义张量类型(不再强制使用 cuda) + self.Tensor = torch.FloatTensor + self.LongTensor = torch.LongTensor + + # 将模型移到指定设备 + self.model = Conformer().to(self.device) + + # 损失函数也移到设备 + self.criterion_cls = torch.nn.CrossEntropyLoss().to(self.device) + + # self.model = EEGNet().cuda() + # self.model = nn.DataParallel(self.model,device_ids=[i for i in range(len(gpus))]) + # self.model = self.model.cuda() + # summary(self.model, (1, 8, 1000)) + + + # Segmentation and Reconstruction (S&R) data augmentation + def interaug(self, timg, label): + aug_data = [] + aug_label = [] + for cls4aug in range(2): + cls_idx = np.where(label == cls4aug + 1) + tmp_data = timg[cls_idx] + tmp_label = label[cls_idx] + tmp_aug_data = np.zeros((int(self.batch_size / 2), 1, 8, 1000)) + for ri in range(int(self.batch_size / 2)): + for rj in range(8): + rand_idx = np.random.randint(0, tmp_data.shape[0], 8) + tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :, + rj * 125:(rj + 1) * 125] + + aug_data.append(tmp_aug_data) + aug_label.append(tmp_label[:int(self.batch_size / 2)]) + aug_data = np.concatenate(aug_data) + aug_label = np.concatenate(aug_label) + aug_shuffle = np.random.permutation(len(aug_data)) + aug_data = aug_data[aug_shuffle, :, :] + aug_label = aug_label[aug_shuffle] + + aug_data = torch.from_numpy(aug_data).float().to(self.device) + aug_label = torch.from_numpy(aug_label - 1).long().to(self.device) + return aug_data, aug_label + + def train(self,all_data,all_label,model_path): + all_data = np.array(all_data);all_label = np.array(all_label) + all_data = np.expand_dims(all_data, axis=1) + train_data, test_data, train_label, test_label = train_test_split(all_data, all_label, test_size=0.2, + random_state=42, stratify=all_label,shuffle=True) + # 转为 Tensor + img = torch.from_numpy(train_data).float().to(self.device) + label = torch.from_numpy(train_label - 1).long().to(self.device) + + dataset = torch.utils.data.TensorDataset(img, label) + self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True) + + test_data = torch.from_numpy(test_data).float().to(self.device) + test_label = torch.from_numpy(test_label - 1).long().to(self.device) + test_dataset = torch.utils.data.TensorDataset(test_data, test_label) + self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True) + + # Optimizers + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2)) + + bestAcc = 0 + averAcc = 0 + num = 0 + Y_true = 0 + Y_pred = 0 + + # Train the cnn model + for e in range(self.n_epochs): + # in_epoch = time.time() + self.model.train() + for i, (img, label) in enumerate(self.dataloader): + + # data augmentation + aug_data, aug_label = self.interaug(train_data, train_label) + img = torch.cat((img, aug_data)) + label = torch.cat((label, aug_label)) + + + outputs = self.model(img) + + loss = self.criterion_cls(outputs, label) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + + # out_epoch = time.time() + + + # test process + if (e + 1) % 1 == 0: + self.model.eval() + with torch.no_grad(): + Cls = self.model(test_data) + + loss_test = self.criterion_cls(Cls, test_label) + y_pred = torch.max(Cls, 1)[1] + acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0)) + train_pred = torch.max(outputs, 1)[1] + train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0)) + + print('Epoch:', e, + ' Train loss: %.6f' % loss.detach().cpu().numpy(), + ' Test loss: %.6f' % loss_test.detach().cpu().numpy(), + ' Train accuracy %.6f' % train_acc, + ' Test accuracy is %.6f' % acc) + + self.log_write.write(str(e) + " " + str(acc) + "\n") + num = num + 1 + averAcc = averAcc + acc + if acc > bestAcc: + bestAcc = acc + Y_true = test_label + Y_pred = y_pred + + + torch.save(self.model, model_path) + averAcc = averAcc / num + print('The average accuracy is:', averAcc) + print('The best accuracy is:', bestAcc) + self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n") + self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n") + + return bestAcc, averAcc, Y_true, Y_pred + # writer.close() + + +def onlineTrain(data_queue,result_queue): + try: + starttime = datetime.datetime.now() + + # seed_n = np.random.randint(2025) + seed_n = 1877 + random.seed(seed_n) + np.random.seed(seed_n) + torch.manual_seed(seed_n) + torch.cuda.manual_seed(seed_n) + torch.cuda.manual_seed_all(seed_n) + + exp = ExP() + # 从队列获取训练数据 + data = data_queue.get(timeout=30) + all_data, all_label,model_path = data['data'], data['label'],data['modelPath'] + print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path) + bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path) + print('THE BEST ACCURACY IS ' + str(bestAcc)) + + endtime = datetime.datetime.now() + print('train duration: ',str(endtime - starttime)) + + # 将模型或参数传回 + result_queue.put({ + 'status': 'success', + 'model_state': model_path, # 或保存路径 + 'timestamp': time.time() + }) + except Exception as e: + result_queue.put({'status': 'error', 'msg': str(e)}) + +def offlineTrain(all_data,all_label,modelPath): + starttime = datetime.datetime.now() + + # seed_n = np.random.randint(2025) + seed_n = 1877 + print('seed is ' + str(seed_n)) + random.seed(seed_n) + np.random.seed(seed_n) + torch.manual_seed(seed_n) + torch.cuda.manual_seed(seed_n) + torch.cuda.manual_seed_all(seed_n) + + exp = ExP() + + bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath) + print('THE BEST ACCURACY IS ' + str(bestAcc)) + + endtime = datetime.datetime.now() + print('train duration: ',str(endtime - starttime)) + + + +if __name__ == "__main__": + print(time.asctime(time.localtime(time.time()))) + print(time.asctime(time.localtime(time.time()))) diff --git a/Debug_64ch_Decoder/MI/Algorithm/otherModels.py b/Debug_64ch_Decoder/MI/Algorithm/otherModels.py new file mode 100644 index 0000000..be03ac2 --- /dev/null +++ b/Debug_64ch_Decoder/MI/Algorithm/otherModels.py @@ -0,0 +1,184 @@ +from torchsummary import summary +import torch +import torch.nn as nn + + +def weights_init(m): + if isinstance(m, nn.Conv2d): + nn.init.xavier_uniform_(m.weight) + # nn.init.constant(m.bias, 0) # bias may be none + + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.constant_(m.bias, 0) + + + +def square_activation(x): + return torch.square(x) + + +def safe_log(x): + return torch.clip(torch.log(x), min=1e-7, max=1e7) + + +class ShallowConvNet(nn.Module): + def __init__(self, num_classes=3, chans=19, samples=768): + super(ShallowConvNet, self).__init__() + self.conv_nums = 40 + self.features = nn.Sequential( + nn.Conv2d(1, self.conv_nums, (1, 25)), + nn.Conv2d(self.conv_nums, self.conv_nums, (chans, 1), bias=False), + nn.BatchNorm2d(self.conv_nums) + ) + self.avgpool = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15)) + self.dropout = nn.Dropout() + + out = torch.ones((1, 1, chans, samples)) + out = self.features(out) + out = self.avgpool(out) + n_out_time = out.cpu().data.numpy().shape + self.classifier = nn.Linear(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes) + + def forward(self, x): + x = self.features(x) + x = square_activation(x) + x = self.avgpool(x) + x = safe_log(x) + x = self.dropout(x) + + features = torch.flatten(x, 1) + cls = self.classifier(features) + return cls + + +class EEGNet(nn.Module): + def __init__(self, num_classes=2, chans=8, samples=1000, dropout_rate=0.5, kernel_length=64, F1=8, + F2=16,): + super(EEGNet, self).__init__() + + self.features = nn.Sequential( + nn.Conv2d(1, F1, kernel_size=(1, kernel_length), bias=False), + nn.BatchNorm2d(F1), + nn.Conv2d(F1, F1, kernel_size=(chans, 1), groups=F1, bias=False), # groups=F1 for depthWiseConv + nn.BatchNorm2d(F1), + nn.ELU(inplace=True), + # nn.ReLU(), + nn.AvgPool2d((1, 4)), + nn.Dropout(dropout_rate), + # for SeparableCon2D + # SeparableConv2D(F1, F2, kernel1_size=(1, 16), bias=False), + nn.Conv2d(F1, F1, kernel_size=(1, 16), groups=F1, bias=False), # groups=F1 for depthWiseConv + nn.BatchNorm2d(F1), + nn.ELU(inplace=True), + # nn.ReLU(), + nn.Conv2d(F1, F2, kernel_size=(1, 1), groups=1, bias=False), # point-wise cnn + nn.BatchNorm2d(F2), + # nn.ReLU(), + nn.ELU(inplace=True), + nn.AvgPool2d((1, 8)), + nn.Dropout(p=dropout_rate), + # nn.Dropout(p=0.5), + ) + out = torch.ones((1, 1, chans, samples)) + out = self.features(out) + n_out_time = out.cpu().data.numpy().shape + self.classifier = nn.Linear(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes) + + def forward(self, x): + conv_features = self.features(x) + features = torch.flatten(conv_features, 1) + cls = self.classifier(features) + return cls + + +class LMDA(nn.Module): + """ + LMDA-Net for the paper + """ + def __init__(self, chans=19, samples=768, num_classes=3, depth=9, kernel=75, channel_depth1=24, channel_depth2=9, + ave_depth=1, avepool=5): + super(LMDA, self).__init__() + self.ave_depth = ave_depth + self.channel_weight = nn.Parameter(torch.randn(depth, 1, chans), requires_grad=True) + nn.init.xavier_uniform_(self.channel_weight.data) + + + self.time_conv = nn.Sequential( + nn.Conv2d(depth, channel_depth1, kernel_size=(1, 1), groups=1, bias=False), + nn.BatchNorm2d(channel_depth1), + nn.Conv2d(channel_depth1, channel_depth1, kernel_size=(1, kernel), + groups=channel_depth1, bias=False), + nn.BatchNorm2d(channel_depth1), + nn.GELU(), + ) + # self.avgPool1 = nn.AvgPool2d((1, 24)) + self.chanel_conv = nn.Sequential( + nn.Conv2d(channel_depth1, channel_depth2, kernel_size=(1, 1), groups=1, bias=False), + nn.BatchNorm2d(channel_depth2), + nn.Conv2d(channel_depth2, channel_depth2, kernel_size=(chans, 1), groups=channel_depth2, bias=False), + nn.BatchNorm2d(channel_depth2), + nn.GELU(), + ) + + self.norm = nn.Sequential( + nn.AvgPool3d(kernel_size=(1, 1, avepool)), + # nn.AdaptiveAvgPool3d((9, 1, 35)), + nn.Dropout(p=0.65), + ) + + # 定义自动填充模块 + out = torch.ones((1, 1, chans, samples)) + out = torch.einsum('bdcw, hdc->bhcw', out, self.channel_weight) + out = self.time_conv(out) + out = self.chanel_conv(out) + out = self.norm(out) + n_out_time = out.cpu().data.numpy().shape + print('In ShallowNet, n_out_time shape: ', n_out_time) + self.classifier = nn.Linear(n_out_time[-1]*n_out_time[-2]*n_out_time[-3], num_classes) + + def EEGDepthAttention(self, x): + # x: input features with shape [N, C, H, W] + + N, C, H, W = x.size() + # K = W if W % 2 else W + 1 + k = 7 + adaptive_pool = nn.AdaptiveAvgPool2d((1, W)) + conv = nn.Conv2d(1, 1, kernel_size=(k, 1), padding=(k//2, 0), bias=True).to(x.device) # original kernel k + nn.init.xavier_uniform_(conv.weight) + nn.init.constant_(conv.bias, 0) + softmax = nn.Softmax(dim=-2) + x_pool = adaptive_pool(x) + x_transpose = x_pool.transpose(-2, -3) + y = conv(x_transpose) + y = softmax(y) + y = y.transpose(-2, -3) + return y * C * x + + def forward(self, x): + x = torch.einsum('bdcw, hdc->bhcw', x, self.channel_weight) + + x_time = self.time_conv(x) # batch, depth1, channel, samples_ + x_time = self.EEGDepthAttention(x_time) # DA1 + + x = self.chanel_conv(x_time) # batch, depth2, 1, samples_ + x = self.norm(x) + + features = torch.flatten(x, 1) + cls = self.classifier(features) + return cls + + +if __name__ == '__main__': + model = ShallowConvNet(num_classes=4, chans=22, samples=1125).cuda() + a = torch.randn(12, 1, 3, 875).cuda().float() + l2 = model(a) + model_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2) + summary(model, show_input=True) + + print(l2.shape) + diff --git a/Debug_64ch_Decoder/PubLibrary/InifileHelper.py b/Debug_64ch_Decoder/PubLibrary/InifileHelper.py new file mode 100644 index 0000000..3a771dc --- /dev/null +++ b/Debug_64ch_Decoder/PubLibrary/InifileHelper.py @@ -0,0 +1,34 @@ +# -*-coding:utf-8 -*- +import configparser +import os +import sys +from audioop import error + +# 打包后需要切换到 exe 所在目录来定位 config.ini +if getattr(sys, 'frozen', False): + _BASE_DIR = os.path.dirname(sys.executable) +else: + _BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +IniFileName = os.path.join(_BASE_DIR, 'config.ini') + +def IniWrite(section,keyname,value): + # 创建ConfigParser对象 + config = configparser.ConfigParser() + config.read(IniFileName,encoding='utf-8') + with open(IniFileName, 'w') as configfile: + if not config.has_section(section): + config.add_section(section) + config[section][keyname]=str(value) + config.write(configfile) + +def IniRead(section,key): + + try: + config = configparser.ConfigParser() + config.read(IniFileName,encoding='utf-8') + return config[section][key] + except error as e: + print(e) + # 读取特定section和键的值 + return '5' \ No newline at end of file diff --git a/Debug_64ch_Decoder/PubLibrary/RunOnce.py b/Debug_64ch_Decoder/PubLibrary/RunOnce.py new file mode 100644 index 0000000..4201773 --- /dev/null +++ b/Debug_64ch_Decoder/PubLibrary/RunOnce.py @@ -0,0 +1,15 @@ +import ctypes +import sys + + +def is_program_running(name='Global\\Decoder'): + # 创建互斥体 + mutex_name =name + h_mutex = ctypes.windll.kernel32.CreateMutexW(None, False, mutex_name) + + # 检查互斥体是否已经存在 + if ctypes.windll.kernel32.GetLastError() == 183: # ERROR_ALREADY_EXISTS + print("程序已经在运行.") + return True + + return False \ No newline at end of file diff --git a/Debug_64ch_Decoder/SSMVEP/algorithm/base.py b/Debug_64ch_Decoder/SSMVEP/algorithm/base.py new file mode 100644 index 0000000..dc00835 --- /dev/null +++ b/Debug_64ch_Decoder/SSMVEP/algorithm/base.py @@ -0,0 +1,418 @@ +# -*- coding: utf-8 -*- +# +# Authors: Swolf +# Date: 2021/1/07 +# License: MIT License + + +from typing import Optional, List, Tuple, Union +import warnings +import numpy as np +from numpy import ndarray +from numpy.linalg import linalg +from scipy.linalg import solve, qr +from scipy.signal import sosfiltfilt, cheby1, cheb1ord +from sklearn.base import BaseEstimator, TransformerMixin, clone + + +def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray: + """Transform spatial filters to spatial patterns based on paper [1]_. + Referring to the method mentioned in article [1],the constructed spatial filter only shows how to combine + information from different channels to extract signals of interest from EEG signals, but if our goal is + neurophysiological interpretation or visualization of weights, activation patterns need to be constructed + from the obtained spatial filters. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + W : ndarray + Spatial filters, shape (n_channels, n_filters). + Cx : ndarray + Covariance matrix of eeg data, shape (n_channels, n_channels). + Cs : ndarray + Covariance matrix of source data, shape (n_channels, n_channels). + + Returns + ------- + A : ndarray + Spatial patterns, shape (n_channels, n_patterns), each column is a spatial pattern. + + References + ---------- + .. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate neuroimaging. + Neuroimage 87 (2014): 96-110. + """ + # use linalg.solve instead of inv, makes it more stable + # see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py + # and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html + A = solve(Cs.T, np.dot(Cx, W).T).T + return A + + +class FilterBank(BaseEstimator, TransformerMixin): + """ + Filter bank decomposition is a bandpass filter array that divides the input signal into + multiple subband components and obtains the eigenvalues of each subband component. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + base_estimator : class + Estimator for model training and feature extraction. + filterbank : list[ndarray] + A bandpass filter bank used to divide the input signal into multiple subband components. + n_jobs : int + Sets the number of CPU working cores. The default is None. + + References + ---------- + .. [1] Chen X, Wang Y, Nakanishi M, et al. High-speed spelling with a noninvasive brain-computer interface[J]. + Proceedings of the national academy of sciences, 2015, 112(44): E6058-E6067. + """ + def __init__( + self, + base_estimator: BaseEstimator, + filterbank: List[ndarray], + n_jobs: Optional[int] = None, + ): + self.base_estimator = base_estimator + self.filterbank = filterbank + self.n_jobs = n_jobs + + def fit(self, X: ndarray, y: Optional[ndarray] = None, **kwargs): + """ + Training model + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + X : None + Training signal (parameters can be ignored, only used to maintain code structure). + y : None + Label data (ibid., ignorable). + Yf : None + Reference signal (ibid., ignorable). + """ + self.estimators_ = [ + clone(self.base_estimator) for _ in range(len(self.filterbank)) + ] + X = self.transform_filterbank(X) + for i, est in enumerate(self.estimators_): + est.fit(X[i], y, **kwargs) + # def wrapper(est, X, y, kwargs): + # est.fit(X, y, **kwargs) + # return est + # self.estimators_ = Parallel(n_jobs=self.n_jobs)( + # delayed(wrapper)(est, X[i], y, kwargs) for i, est in enumerate(self.estimators_)) + return self + + def transform(self, X: ndarray, **kwargs): + """ + The parameters stored in self are used to convert X into features, and X is filtered through the filter bank to + obtain the eigenvalues of each subband component. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + X : ndarray, shape(n_trials, n_channels, n_samples) + Test the signal. + + Returns + ------- + feat : ndarray, shape(n_trials, n_fre) + Feature array. + """ + X = self.transform_filterbank(X) + feat = [est.transform(X[i], **kwargs) for i, est in enumerate(self.estimators_)] + # def wrapper(est, X, kwargs): + # retval = est.transform(X, **kwargs) + # return retval + # feat = Parallel(n_jobs=self.n_jobs)( + # delayed(wrapper)(est, X[i], kwargs) for i, est in enumerate(self.estimators_)) + feat = np.concatenate(feat, axis=-1) + return feat + + def transform_filterbank(self, X: ndarray): + """ + The input signal is filtered through a filter bank. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + X : ndarray, shape(n_trials, n_channels, n_samples) + Input signal. + + Returns + ------- + Xs: ndarray, shape(Nfb, n_trials, n_channels, n_samples) + Individual subband components of the input signal. + """ + Xs = np.stack([sosfiltfilt(sos, X, axis=-1) for sos in self.filterbank]) + return Xs + + +class FilterBankSSVEP(FilterBank): + """ + Filter bank analysis for SSVEP. + The SSVEP is analyzed using filter banks, that is, multiple filters are combined to decompose the SSVEP signal + into specific segments (subbands containing the original data) and obtain its characteristic data. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + filterbank : list[ndarray] + The filter bank. + base_estimator : class + Estimator for model training and feature extraction. + filterweights : ndarray + Filter weight, default is None. + n_jobs : int + Sets the number of CPU working cores. The default is None. + """ + + def __init__( + self, + filterbank: List[ndarray], + base_estimator: BaseEstimator, + filterweights: Optional[ndarray] = None, + n_jobs: Optional[int] = None, + ): + self.filterweights = filterweights + super().__init__(base_estimator, filterbank, n_jobs=n_jobs) + + def transform(self, X: ndarray): # type: ignore[override] + """ + X is converted into features by using the parameters stored in self, and the eigenvalues of each subband + component are obtained after the input signal is filtered by the filter bank. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + X : ndarray, shape(n_trials, n_channels, n_samples) + Test the signal. + + Returns + ------- + features : ndarray, shape(n_trials, n_fre) + Feature array. + """ + features = super().transform(X) + if self.filterweights is None: + return features + else: + features = np.reshape( + features, (features.shape[0], len(self.filterbank), -1) + ) + return np.sum( + features * self.filterweights[np.newaxis, :, np.newaxis], axis=1 + ) + + + +def generate_filterbank( + passbands: List[Tuple[float, float]], + stopbands: List[Tuple[float, float]], + srate: int, + order: Optional[int] = None, + rp: float = 0.5, +): + """ + Create a filter bank, that is, obtain a bandpass filter coefficient that can divide the input signal into multiple + subband components. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + passbands : list or tuple(float, float) + Passband parameters. + stopbands : list or tuple(float, float) + Stopband parameters. + srate : float + Sampling rate. + order : int + Filter order. + rp : float + The maximum ripple allowed in the passband below the unit gain is 0.5 by default. + + Returns + ------- + Filterbank:ndarray, shape(len(passbands), N, 6) + Filter bank coefficient. + """ + filterbank = [] + for wp, ws in zip(passbands, stopbands): + if order is None: + N, wn = cheb1ord(wp, ws, 3, 40, fs=srate) + sos = cheby1(N, rp, wn, btype="bandpass", output="sos", fs=srate) + else: + sos = cheby1(order, rp, wp, btype="bandpass", output="sos", fs=srate) + + filterbank.append(sos) + return filterbank + +def process(data): + # 白化操作 + meanValue = np.mat(data.mean(axis=1)) + meanData = np.repeat(meanValue, data.shape[1], axis=1) + whiteTemp = data - meanData + # QR 分解 + rankWhiteTemp = whiteTemp.shape[0] + whiteTemp = np.transpose(whiteTemp) + Q, R = qr(whiteTemp.A, mode='economic') + # 计算矩阵的秩 + rankQ = linalg.matrix_rank(R) + if rankQ == 0: + raise ValueError('stats:canoncorr:badData') + elif rankQ <= rankWhiteTemp: + # warnings.warn('stats:canoncorr:NotFullRank') + Q = Q[:, 0:rankQ] + return Q, rankQ + +def reference(listFreqs,fs, numberSmples, num_harms): + numberFrequence = len(listFreqs) + timeIndex = np.arange(1, numberSmples + 1) / fs # time index + referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples)) + for frequenceIndex in range(numberFrequence): + temp = [] + for harmIndex in range(1, num_harms + 1): + stimFrequence = listFreqs[frequenceIndex] # in HZ + # Sin and Cos + temp.extend([np.sin(2 * np.pi * timeIndex * harmIndex * stimFrequence), + np.cos(2 * np.pi * timeIndex * harmIndex * stimFrequence)]) + referenceTemp = np.mat(temp) + # 白化操作和QR分解 + Q, rankQ = process(referenceTemp) + referenceData[frequenceIndex] = np.transpose(Q) + return referenceData + +def generate_cca_references( + freqs: Union[ndarray, int, float], + srate, + T, + phases: Optional[Union[ndarray, int, float]] = None, + n_harmonics: int = 1, +): + """ + Construct a sine-cosine reference signal for canonical correlation analysis (CCA). + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + freqs : int or float + Frequency. + srate : int + Sampling rate. + T : int + Sampling time. + phases : int or float + Phase, default is None. + n_harmonics : int + The number of harmonics. The default value is 1. + + Returns + ------- + Yf:ndarray, shape(srate*T, n_harmonics*2) + Sine and cosine reference signal. + """ + if isinstance(freqs, int) or isinstance(freqs, float): + freqs = np.array([freqs]) + freqs = np.array(freqs)[:, np.newaxis] + if phases is None: + phases = 0 + if isinstance(phases, int) or isinstance(phases, float): + phases = np.array([phases]) + phases = np.array(phases)[:, np.newaxis] + t = np.linspace(0, T, int(T * srate)) + + Yf = [] + for i in range(n_harmonics): + Yf.append( + np.stack( + [ + np.sin(2 * np.pi * (i + 1) * freqs * t + np.pi * phases), + np.cos(2 * np.pi * (i + 1) * freqs * t + np.pi * phases), + ], + axis=1, + ) + ) + Yf = np.concatenate(Yf, axis=1) + return Yf + + +def sign_flip(u, s, vh=None): + """Flip signs of SVD or EIG using the method in paper [1]_. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + u: ndarray + left singular vectors, shape (M, K). + s: ndarray + singular values, shape (K,). + vh: ndarray or None + transpose of right singular vectors, shape (K, N). + + Returns + ------- + u: ndarray + corrected left singular vectors. + s: ndarray + singular values. + vh: ndarray + transpose of corrected right singular vectors. + + References + ---------- + .. [1] https://www.sandia.gov/~tgkolda/pubs/pubfiles/SAND2007-6422.pdf + """ + if vh is None: + total_proj = np.sum(u * s, axis=0) + signs = np.sign(total_proj) + + random_idx = signs == 0 + if np.any(random_idx): + signs[random_idx] = 1 + warnings.warn( + "The magnitude is close to zero, the sign will become arbitrary." + ) + + u = u * signs + + return u, s + else: + left_proj = np.sum(s[:, np.newaxis] * vh, axis=-1) + right_proj = np.sum(u * s, axis=0) + total_proj = left_proj + right_proj + signs = np.sign(total_proj) + + random_idx = signs == 0 + if np.any(random_idx): + signs[random_idx] = 1 + warnings.warn( + "The magnitude is close to zero, the sign will become arbitrary." + ) + + u = u * signs + vh = signs[:, np.newaxis] * vh + + return u, s, vh diff --git a/Debug_64ch_Decoder/SSMVEP/algorithm/dsp.py b/Debug_64ch_Decoder/SSMVEP/algorithm/dsp.py new file mode 100644 index 0000000..a2ae853 --- /dev/null +++ b/Debug_64ch_Decoder/SSMVEP/algorithm/dsp.py @@ -0,0 +1,436 @@ +# -*- coding: utf-8 -*- +# DSP: Discriminal Spatial Patterns +# Authors: Swolf +# Junyang Wang <2144755928@qq.com> +# Last update date: 2022-8-11 +# License: MIT License + +from typing import Optional, List, Tuple +from itertools import combinations +import numpy as np +from scipy.linalg import eigh +from numpy import ndarray +from scipy.linalg import solve +from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin + + + +def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray: + """Transform spatial filters to spatial patterns based on paper [1]_. + Referring to the method mentioned in article [1],the constructed spatial filter only shows how to combine + information from different channels to extract signals of interest from EEG signals, but if our goal is + neurophysiological interpretation or visualization of weights, activation patterns need to be constructed + from the obtained spatial filters. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + W : ndarray + Spatial filters, shape (n_channels, n_filters). + Cx : ndarray + Covariance matrix of eeg data, shape (n_channels, n_channels). + Cs : ndarray + Covariance matrix of source data, shape (n_channels, n_channels). + + Returns + ------- + A : ndarray + Spatial patterns, shape (n_channels, n_patterns), each column is a spatial pattern. + + References + ---------- + .. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate neuroimaging. + Neuroimage 87 (2014): 96-110. + """ + # use linalg.solve instead of inv, makes it more stable + # see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py + # and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html + A = solve(Cs.T, np.dot(Cx, W).T).T + return A + +def isPD(B: ndarray) -> bool: + """Returns true when input matrix is positive-definite, via Cholesky decompositon method. + + Parameters + ---------- + B : ndarray + Any matrix, shape (N, N) + + Returns + ------- + bool + True if B is positve-definite. + + Notes + ----- + Use numpy.linalg rather than scipy.linalg. In this case, scipy.linalg has unpredictable behaviors. + """ + + try: + _ = np.linalg.cholesky(B) + return True + except np.linalg.LinAlgError: + return False + +def nearestPD(A: ndarray) -> ndarray: + """Find the nearest positive-definite matrix to input. + + Parameters + ---------- + A : ndarray + Any square matrxi, shape (N, N) + + Returns + ------- + A3 : ndarray + positive-definite matrix to A + + Notes + ----- + A Python/Numpy port of John D'Errico's `nearestSPD` MATLAB code [1]_, which + origins at [2]_. + + References + ---------- + .. [1] https://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd + .. [2] N.J. Higham, "Computing a nearest symmetric positive semidefinite matrix" (1988): + https://doi.org/10.1016/0024-3795(88)90223-6 + """ + + B = (A + A.T) / 2 + _, s, V = np.linalg.svd(B) + + H = np.dot(V.T, np.dot(np.diag(s), V)) + + A2 = (B + H) / 2 + + A3 = (A2 + A2.T) / 2 + + if isPD(A3): + return A3 + + print("Replace current matrix with the nearest positive-definite matrix.") + + spacing = np.spacing(np.linalg.norm(A)) + # The above is different from [1]. It appears that MATLAB's `chol` Cholesky + # decomposition will accept matrixes with exactly 0-eigenvalue, whereas + # Numpy's will not. So where [1] uses `eps(mineig)` (where `eps` is Matlab + # for `numpy.spacing`), we use the above definition. CAVEAT: our `spacing` + # will be much larger than [1]'s `eps(mineig)`, since `mineig` is usually on + # the order of 1e-16, and `eps(1e-16)` is on the order of 1e-34, whereas + # `spacing` will, for Gaussian random matrixes of small dimension, be on + # othe order of 1e-16. In practice, both ways converge, as the unit test + # below suggests. + eye = np.eye(A.shape[0]) + k = 1 + while not isPD(A3): + mineig = np.min(np.real(np.linalg.eigvals(A3))) + A3 += eye * (-mineig * k**2 + spacing) + k += 1 + + return A3 + +def xiang_dsp_kernel( + X: ndarray, y: ndarray +) -> Tuple[ndarray, ndarray, ndarray, ndarray]: + """ + DSP: Discriminal Spatial Patterns, only for two classes[1]_. + Import train data to solve spatial filters with DSP, + finds a projection matrix that maximize the between-class scatter matrix and + minimize the within-class scatter matrix. Currently only support for two types of data. + + Author: Swolf + + Created on: 2021-1-07 + + Update log: + + Parameters + ---------- + X : ndarray + EEG train data assuming removing mean, shape (n_trials, n_channels, n_samples) + y : ndarray + labels of EEG data, shape (n_trials, ) + + Returns + ------- + W : ndarray + spatial filters, shape (n_channels, n_filters) + D : ndarray + eigenvalues in descending order + M : ndarray + mean value of all classes and trials, i.e. common mode signals, shape (n_channel, n_samples) + A : ndarray + spatial patterns, shape (n_channels, n_filters) + + Notes + ----- + the implementation removes regularization on within-class scatter matrix Sw. + + References + ---------- + .. [1] Liao, Xiang, et al. "Combining spatial filters for the classification of single-trial EEG in + a finger movement task." IEEE Transactions on Biomedical Engineering 54.5 (2007): 821-831. + """ + X, y = np.copy(X), np.copy(y) + labels = np.unique(y) + X = np.reshape(X, (-1, *X.shape[-2:])) + X = X - np.mean(X, axis=-1, keepdims=True) + # the number of each label + n_labels = np.array([np.sum(y == label) for label in labels]) + # average template of all trials + M = np.mean(X, axis=0) + # class conditional template + Ms, Ss = zip( + *[ + ( + np.mean(X[y == label], axis=0), + np.sum( + np.matmul(X[y == label], np.swapaxes(X[y == label], -1, -2)), axis=0 + ), + ) + for label in labels + ] + ) + Ms, Ss = np.stack(Ms), np.stack(Ss) + # within-class scatter matrix + Sw = np.sum( + Ss + - n_labels[:, np.newaxis, np.newaxis] * np.matmul(Ms, np.swapaxes(Ms, -1, -2)), + axis=0, + ) + Ms = Ms - M + # between-class scatter matrix + Sb = np.sum( + n_labels[:, np.newaxis, np.newaxis] * np.matmul(Ms, np.swapaxes(Ms, -1, -2)), + axis=0, + ) + + D, W = eigh(nearestPD(Sb), nearestPD(Sw)) + ix = np.argsort(D)[::-1] # in descending order + D, W = D[ix], W[:, ix] + A = robust_pattern(W, Sb, W.T @ Sb @ W) + + return W, D, M, A + + +def xiang_dsp_feature( + W: ndarray, M: ndarray, X: ndarray, n_components: int = 1 +) -> ndarray: + """ + Return DSP features in paper [1]_. + + Author: Swolf + + Created on: 2021-1-07 + + Update log: + + Parameters + ---------- + W : ndarray + spatial filters from csp_kernel, shape (n_channels, n_filters) + M : ndarray + common template for all classes, shape (n_channel, n_samples) + X : ndarray + eeg test data, shape (n_trials, n_channels, n_samples) + n_components : int, optional + length of the spatial filters, first k components to use, by default 1 + + Returns + ------- + features: ndarray + features, shape (n_trials, n_components, n_samples) + + Raises + ------ + ValueError + n_components should less than half of the number of channels + + Notes + ----- + 1. instead of meaning of filtered signals in paper [1]_., we directly return filtered signals. + + References + ---------- + .. [1] Liao, Xiang, et al. "Combining spatial filters for the classification of single-trial EEG in + a finger movement task." IEEE Transactions on Biomedical Engineering 54.5 (2007): 821-831. + """ + W, M, X = np.copy(W), np.copy(M), np.copy(X) + max_components = W.shape[1] + if n_components > max_components: + raise ValueError("n_components should less than the number of channels") + X = np.reshape(X, (-1, *X.shape[-2:])) + X = X - np.mean(X, axis=-1, keepdims=True) + # print('************: ',np.shape(W),np.shape(X),np.shape(M)) + features = np.matmul(W[:, :n_components].T, X - M) + return features + + +class DSP(BaseEstimator, TransformerMixin, ClassifierMixin): + """ + DSP: Discriminal Spatial Patterns + + Author: Swolf + + Created on: 2021-1-07 + + Update log: + + Parameters + ---------- + n_components : int + length of the spatial filter, first k components to use, by default 1 + transform_method : str + method of template matching, by default ’corr‘ (pearson correlation coefficient) + classes_ : int + number of the EEG classes + + Attributes + ---------- + n_components : int + length of the spatial filter, first k components to use, by default 1 + transform_method : str + method of template matching, by default ’corr‘ (pearson correlation coefficient) + classes_ : int + number of the EEG classes + W_ : ndarray, shape(n_channels, n_filters) + Spatial filters, shape(n_channels, n_filters), in which n_channels = n_filters + D_ : ndarray, shape(n_filters, ) + eigenvalues in descending order, shape(n_filters, ) + M_ : ndarray, shape(n_channels, n_samples) + mean value of all classes and trials, i.e. common mode signals, shape(n_channels, n_samples) + A_ : ndarray, shape(n_channels, n_filters) + spatial patterns, shape(n_channels, n_filters) + templates_: ndarray, shape(n_classes, n_filters, n_samples) + templates of train data, shape(n_classes, n_filters, n_samples) + + """ + + def __init__(self, n_components: int = 1, transform_method: str = "corr"): + self.n_components = n_components + self.transform_method = transform_method + + def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None): + """ + Import the train data to get a model. + + Parameters + ---------- + X : ndarray + train data, shape(n_trials, n_channels, n_samples) + y : ndarray + labels of train data, shape (n_trials, ) + Yf : ndarray + optional parameter + + Returns + ------- + W_ : ndarray + spatial filters, shape (n_channels, n_filters), in which n_channels = n_filters + D_ : ndarray + eigenvalues in descending order, shape (n_filters, ) + M_ : ndarray + template for all classes, shape (n_channel, n_samples) + A_ : ndarray + spatial patterns, shape (n_channels, n_filters) + templates_ : ndarray + templates of train data, shape (n_channels, n_filters, n_samples) + """ + X -= np.mean(X, axis=-1, keepdims=True) + self.classes_ = np.unique(y) + self.W_, self.D_, self.M_, self.A_ = xiang_dsp_kernel(X, y) + + self.templates_ = np.stack( + [ + np.mean( + xiang_dsp_feature( + self.W_, self.M_, X[y == label], n_components=self.W_.shape[1] + ), + axis=0, + ) + for label in self.classes_ + ] + ) + return self + + def transform(self, X: ndarray): + """ + Import the test data to get features. + + Parameters + ---------- + X : ndarray + test data, shape(n_trials, n_channels, n_samples) + + Returns + ------- + feature : ndarray, shape(n_trials,n_classes) + correlation coefficients of templates of train data and features of test data, shape(n_trials, n_classes) + """ + n_components = self.n_components + X -= np.mean(X, axis=-1, keepdims=True) + features = xiang_dsp_feature(self.W_, self.M_, X, n_components=n_components) + if self.transform_method is None: + return features.reshape((features.shape[0], -1)) + elif self.transform_method == "mean": + return np.mean(features, axis=-1) + elif self.transform_method == "corr": + return self._pearson_features( + features, self.templates_[:, :n_components, :] + ) + else: + raise ValueError("non-supported transform method") + + def _pearson_features(self, X: ndarray, templates: ndarray): + """ + Calculate pearson correlation coefficient. + + Parameters + ---------- + X : ndarray + features of test data after spatial filters, shape(n_trials, n_components, n_samples) + templates : ndarray + templates of train data, shape(n_classes, n_components, n_samples) + + Returns + ------- + corr : ndarray + pearson correlation coefficient, shape(n_trials, n_classes) + """ + X = np.reshape(X, (-1, *X.shape[-2:])) + templates = np.reshape(templates, (-1, *templates.shape[-2:])) + X = X - np.mean(X, axis=-1, keepdims=True) + templates = templates - np.mean(templates, axis=-1, keepdims=True) + X = np.reshape(X, (X.shape[0], -1)) + templates = np.reshape(templates, (templates.shape[0], -1)) + istd_X = 1 / np.std(X, axis=-1, keepdims=True) + istd_templates = 1 / np.std(templates, axis=-1, keepdims=True) + corr = (X @ templates.T) / (templates.shape[1] - 1) + corr = istd_X * corr * istd_templates.T + return corr + + def predict(self, X: ndarray): + """ + Import the templates and the test data to get prediction labels. + + Parameters + ---------- + X : ndarray + test data, shape(n_trials, n_channels, n_samples) + + Returns + ------- + labels : ndarray + prediction labels of test data, shape(n_trials,) + """ + feat = self.transform(X) + if self.transform_method == "corr": + labels = self.classes_[np.argmax(feat, axis=-1)] + else: + raise NotImplementedError() + return labels + + diff --git a/Debug_64ch_Decoder/SSMVEP/algorithm/tdca.py b/Debug_64ch_Decoder/SSMVEP/algorithm/tdca.py new file mode 100644 index 0000000..7b7247b --- /dev/null +++ b/Debug_64ch_Decoder/SSMVEP/algorithm/tdca.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- +# +# Authors: Swolf +# Date: 2021/10/10 +# License: MIT License +""" +Task Decomposition Component Analysis. +""" +from typing import List + +import numpy as np +from scipy.linalg import qr +from scipy.stats import pearsonr +from numpy import ndarray +from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin +from typing import Optional, List +from SSMVEP.algorithm.base import FilterBankSSVEP +from SSMVEP.algorithm.dsp import xiang_dsp_kernel, xiang_dsp_feature + + +def proj_ref(Yf: ndarray): + Q, R = qr(Yf.T, mode="economic") + P = Q @ Q.T + return P + + +def aug_2(X: ndarray, n_samples: int, padding_len: int, P: ndarray, training: bool = True): + X = X.reshape((-1, *X.shape[-2:])) + n_trials, n_channels, n_points = X.shape + # if n_points < padding_len + n_samples: + # raise ValueError("the length of X should be larger than l+n_samples.") + aug_X = np.zeros((n_trials, (padding_len + 1) * n_channels, n_samples)) + if training: + for i in range(padding_len + 1): + aug_X[:, i * n_channels : (i + 1) * n_channels, :] = X[ + ..., i : i + n_samples + ] + else: + for i in range(padding_len + 1): + aug_X[:, i * n_channels : (i + 1) * n_channels, : n_samples - i] = X[ + ..., i:n_samples + ] + aug_Xp = aug_X @ P + aug_X = np.concatenate([aug_X, aug_Xp], axis=-1) + return aug_X + + +def tdca_feature( + X: ndarray, + templates: ndarray, + W: ndarray, + M: ndarray, + Ps: List[ndarray], + padding_len: int, + n_components: int = 1, + training=False, +): + rhos = [] + for Xk, P in zip(templates, Ps): + a = xiang_dsp_feature( + W, + M, + aug_2(X, P.shape[0], padding_len, P, training=training), + n_components=n_components, + ) + b = Xk[:n_components, :] + a = np.reshape(a, (-1)) + b = np.reshape(b, (-1)) + rhos.append(pearsonr(a, b)[0]) + return rhos + + +class TDCA(BaseEstimator, TransformerMixin, ClassifierMixin): + def __init__(self, padding_len: int, n_components: int = 1): + self.padding_len = padding_len + self.n_components = n_components + + def fit(self, X: ndarray, y: ndarray, Yf: ndarray): + X -= np.mean(X, axis=-1, keepdims=True) + self.classes_ = np.unique(y) + self.Ps_ = [proj_ref(Yf[i]) for i in range(len(self.classes_))] + # print(np.shape(self.Ps_)) + + aug_X_list, aug_Y_list = [], [] + for i, label in enumerate(self.classes_): + aug_X_list.append( + aug_2( + X[y == label], + self.Ps_[i].shape[0], + self.padding_len, + self.Ps_[i], + training=True, + ) + ) + aug_Y_list.append(y[y == label]) + + aug_X = np.concatenate(aug_X_list, axis=0) + aug_Y = np.concatenate(aug_Y_list, axis=0) + self.W_, _, self.M_, _ = xiang_dsp_kernel(aug_X, aug_Y) + + self.templates_ = np.stack( + [ + np.mean( + xiang_dsp_feature( + self.W_, + self.M_, + aug_X[aug_Y == label], + n_components=self.W_.shape[1], + ), + axis=0, + ) + for label in self.classes_ + ] + ) + return self + + def transform(self, X: ndarray): + n_components = self.n_components + X -= np.mean(X, axis=-1, keepdims=True) + X = X.reshape((-1, *X.shape[-2:])) + rhos = [ + tdca_feature( + tmp, + self.templates_, + self.W_, + self.M_, + self.Ps_, + self.padding_len, + n_components=n_components, + ) + for tmp in X + ] + rhos = np.stack(rhos) + return rhos + + def predict(self, X: ndarray): + feat = self.transform(X) + labels = self.classes_[np.argmax(feat, axis=-1)] + return labels,feat + + +class FBTDCA(FilterBankSSVEP, ClassifierMixin): + def __init__( + self, + filterbank: List[ndarray], + padding_len: int, + n_components: int = 1, + filterweights: Optional[ndarray] = None, + n_jobs: Optional[int] = None, + ): + self.padding_len = padding_len + self.n_components = n_components + self.filterweights = filterweights + self.n_jobs = n_jobs + super().__init__( + filterbank, + TDCA(padding_len, n_components=n_components), + filterweights=filterweights, + n_jobs=n_jobs, + ) + + def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None): # type: ignore[override] + self.classes_ = np.unique(y) + super().fit(X, y, Yf=Yf) + return self + + def predict(self, X: ndarray): + features = self.transform(X) + if self.filterweights is None: + features = np.reshape( + features, (features.shape[0], len(self.filterbank), -1) + ) + features = np.mean(features, axis=1) + labels = self.classes_[np.argmax(features, axis=-1)] + return labels,features diff --git a/Debug_64ch_Decoder/SSVEP/dwfbcca.py b/Debug_64ch_Decoder/SSVEP/dwfbcca.py new file mode 100644 index 0000000..773f325 --- /dev/null +++ b/Debug_64ch_Decoder/SSVEP/dwfbcca.py @@ -0,0 +1,527 @@ + + +# -*- coding: utf-8 -*- +import os +import time +import warnings +from os import error +import numpy as np +import scipy +from numpy.linalg import linalg +from scipy.io import loadmat +from scipy.linalg import qr +from scipy.signal import filtfilt, lfilter + + +class FbccaDw: + def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method): + print('******************************************') + print('parameter list') + print('target:', num_target) + print('number of filter bank:', num_filter) + print('parameter:', parameter) + print('width:', width) + self.phase = 0 + self.bandWidth = width + self.winNum = winNum + self.num_harms = num_harms + self.num_target = num_target + self.num_chans = num_chans + self.winTimeDelay = stimTime + self.fs = fs + self.winMaxSampleNum = self.winTimeDelay * self.winNum * self.fs + self.winDelayNum = round(self.winTimeDelay * self.fs) + self.num_fbs = num_filter + parameterValue = np.power(np.arange(1, self.num_fbs + 1), -(parameter[0])) + parameter[1] + self.weightValue = parameterValue / (sum(parameterValue)) + + self.dataUseLen = [0] * self.winNum + self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans]) + self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans]) + self.rhoNum = 2 + self.notchZh = [0] + self.filterZf = [0] * self.num_fbs + self.north_b = [] + self.north_a = [] + self.filterBank_A = [] + self.filterBank_B = [] + self.winStep = 1 + self.DW_cost_method = 'DW11' if method==1 else 'DW1' + + ''' + filterFrequenceBank:根据刺激频率生成的通带和阻带,用于滤波器组频带分解 + ''' + + def filterFrequenceBank(self): + # 阻带的最高频率 + lastFrequence = 90 + freqBandWidth = self.bandWidth[1] + fStep = self.bandWidth[0] + bandFrequence = np.zeros((5, 4)) + # 第二列频率带 + band = list(range(freqBandWidth, lastFrequence, fStep)) + band[:] = [x - 2 for x in band] + colValue = np.maximum(np.asmatrix(band), 1) + bandFrequence[:, 1] = colValue[0, 0:5] + # 第一列频率带 + bandFrequence[:, 0] = np.maximum(bandFrequence[:, 1] - 4, 1) + # 第三列频率带 + bandFrequence[:, 2] = lastFrequence + 2 + # 第四列频率带 + bandFrequence[:, 3] = bandFrequence[:, 2] + 10 + # bandFrequence = np.array([[30,33,77,82], + # [62,68,77,82]]) + for idx_fb in range(self.num_fbs): + Nq = self.fs / 2 + Wp = [bandFrequence[idx_fb, 1] / Nq, bandFrequence[idx_fb, 2] / Nq] + Ws = [bandFrequence[idx_fb, 0] / Nq, bandFrequence[idx_fb, 3] / Nq] + [N, Wn] = scipy.signal.cheb1ord(Wp, Ws, 3, + 40) # band pass filter StopBand=[Ws(1)~Ws(2)] PassBand=[Wp(1)~Wp(2)] + [B, A] = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass') # Wn passband edge frequency + self.filterBank_A.append(A) + self.filterBank_B.append(B) + # def filterFrequenceBank(self): + # # 阻带的最高频率 + # lastFrequence = 90 + # freqBandWidth = self.bandWidth[1] + # fStep = self.bandWidth[0] + # bandFrequence = np.zeros((5, 4)) + # # 第二列频率带 + # band = list(range(freqBandWidth, lastFrequence, fStep)) + # band[:] = [x - 2 for x in band] + # colValue = np.maximum(np.asmatrix(band), 1) + # bandFrequence[:, 1] = colValue[0, 0:5] + # # 第一列频率带 + # bandFrequence[:, 0] = np.maximum(bandFrequence[:, 1] - 4, 1) + # # 第三列频率带 + # bandFrequence[:, 2] = lastFrequence + 2 + # # 第四列频率带 + # bandFrequence[:, 3] = bandFrequence[:, 2] + 10 + # for idx_fb in range(self.num_fbs): + # Nq = self.fs / 2 + # Wp = [bandFrequence[idx_fb, 1] / Nq, bandFrequence[idx_fb, 2] / Nq] + # Ws = [bandFrequence[idx_fb, 0] / Nq, bandFrequence[idx_fb, 3] / Nq] + # [N, Wn] = scipy.signal.cheb1ord(Wp, Ws, 3, + # 40) # band pass filter StopBand=[Ws(1)~Ws(2)] PassBand=[Wp(1)~Wp(2)] + # [B, A] = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass') # Wn passband edge frequency + # self.filterBank_A.append(A) + # self.filterBank_B.append(B) + + ''' + Filter bank analysis + Input: + eeg : Input eeg data (# of targets, # of channels, Data length [sample]) + Output: + filterData : Generated filter Data + ''' + + def filterbank(self, eeg): + filterData = np.zeros((self.num_fbs, eeg.shape[1], eeg.shape[0])) + for filterIndex in range(self.num_fbs): + if np.all(self.filterZf[filterIndex] == 0): + zi = np.zeros( + [max(len(self.filterBank_A[filterIndex]), len(self.filterBank_B[filterIndex])) - 1, self.num_chans]) + _, self.filterZf[filterIndex] = lfilter(self.filterBank_B[filterIndex], self.filterBank_A[filterIndex], + eeg, zi=zi.T) + Data = lfilter(self.filterBank_B[filterIndex], self.filterBank_A[filterIndex], eeg) + else: + Data, self.filterZf[filterIndex] = lfilter(self.filterBank_B[filterIndex], + self.filterBank_A[filterIndex], eeg, + zi=self.filterZf[filterIndex]) + filterData[filterIndex, :, :] = Data.T + return filterData + + ''' + process + 矩阵的白化和QR正则化分解,降低矩阵的维度,加速计算时间 + Input: + data : 输入的二维脑电信号 + Output: + Q : 降维后的矩阵 + rankQ :正则矩阵的秩 + ''' + + def process(self, data): + # 白化操作 + meanValue = np.asmatrix(data.mean(axis=1)) + meanData = np.repeat(meanValue, data.shape[1], axis=1) + whiteTemp = data - meanData + # QR 分解 + rankWhiteTemp = whiteTemp.shape[0] + whiteTemp = np.transpose(whiteTemp) + Q, R = qr(whiteTemp.A, mode='economic') + # 计算矩阵的秩 + rankQ = linalg.matrix_rank(R) + if rankQ == 0: + raise ValueError('stats:canoncorr:badData') + elif rankQ <= rankWhiteTemp: + # warnings.warn('stats:canoncorr:NotFullRank') + Q = Q[:, 0:rankQ] + return Q, rankQ + + ''' + reference + Input: + listFreqs : 刺激频率列表 + numberSmples : 用于分类的脑电信号采样点个数 + num_harms : 谐波数 + Output: + y_ref : 生成的参考信号 (刺激目标数, 2 * 谐波数, 数据长度/采样点数) + ''' + + def reference(self, listFreqs, numberSmples, num_harms): + numberFrequence = len(listFreqs) + timeIndex = np.arange(1, numberSmples + 1) / self.fs # time index + referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples)) + for frequenceIndex in range(numberFrequence): + temp = [] + for harmIndex in range(1, num_harms + 1): + stimFrequence = listFreqs[frequenceIndex] # in HZ + # Sin and Cos + temp.extend([np.sin(2 * np.pi * timeIndex * harmIndex * stimFrequence), + np.cos(2 * np.pi * timeIndex * harmIndex * stimFrequence)]) + referenceTemp = np.asmatrix(temp) + # 白化操作和QR分解 + Q, rankQ = self.process(referenceTemp) + referenceData[frequenceIndex] = np.transpose(Q) + return referenceData + + ''' + setNorthFilterPara + 陷波器的参数初始化 + self.north_b, self.north_a : 陷波器的参数设计 + ''' + + def setNotchFilterPara(self): + # notchFilterNum = 3 + # northFreq = 50 + # bwDen = 35 + # wo = northFreq / (self.fs / 2) + # bw = wo / bwDen + # self.north_b, self.north_a = iirnotch(wo, Q=35) # self.north_b, self.north_a = iircomb(northFreq, bwDen, 'notch') + # # n倍零极点,相当于重复滤波n次 + # if notchFilterNum > 1: + # z, p, k = tf2zpk(self.north_b, self.north_a) + # zNew = np.repeat(z, notchFilterNum, axis=0) + # zNew[1], zNew[4] = zNew[4], zNew[1] + # pNew = np.repeat(p, notchFilterNum, axis=0) + # pNew[1], pNew[4] = pNew[4], pNew[1] + # kNew = np.power(k, notchFilterNum) + # self.north_b, self.north_a = zpk2tf(zNew, pNew, kNew) + self.north_b = [0.94801603944125245604368501517456, -1.7577184027642647201616910024313, + 3.9303778338832491279219993884908, -3.7392330345967859095424046245171, + 3.9303778338832482397435796883656, -1.7577184027642638319832713023061, + 0.94801603944125156786526531504933] + + self.north_a = [1, -1.8214007435820627200939725298667, 4.0000101767406484043476666556671, + -3.7380998614928691026193519064691, 3.8589119784285759173769747576443, + -1.6951692350503837491970671180752, 0.89786559147978006745205448169145] + + ''' + northFilter + 进行信号的50hz陷波处理 + Input: + data :输入脑电数据 + Output: + dataFiltered : 陷波处理后的脑电数据 + ''' + + def northFilter(self, data): + try: + if np.all(self.notchZh[0] == 0): + zi = np.zeros([max(len(self.north_a), len(self.north_b)) - 1, self.num_chans]) + _, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=zi.T) + dataFiltered = lfilter(self.north_b, self.north_a, data) + else: + dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0]) + return np.asmatrix(dataFiltered) + except Exception: + print(Exception) + + ''' + getDataQ + Inputs: + data:脑电数据 + Rbuffer:待更新的中间系数 + Output: + Qs1 : 脑电特征1 + Qs2 : 脑电特征2 + Rbuffer : 单窗口更新后的系数 + + ''' + + def getDataQ(self, data, Rbuffer): + Qs1 = [0] * self.num_fbs + Qs2 = [0] * self.num_fbs + nulldata = np.zeros([self.num_chans, self.num_chans]) + Rnum = self.num_chans + for fb_num in range(self.num_fbs): + fb_data = np.squeeze(data[fb_num, :, :]) + if np.all(Rbuffer[fb_num] == 0): + whiteTemp = fb_data + Q, R = qr(whiteTemp, mode='economic') + Qs1[fb_num] = nulldata + Qs2[fb_num] = Q + Rbuffer[fb_num] = R + else: + whiteTemp = np.concatenate((Rbuffer[fb_num], fb_data), axis=0) + Q, R = qr(whiteTemp, mode='economic') + Qs1[fb_num] = Q[0:Rnum, :] + Qs2[fb_num] = Q[Rnum:, :] + Rbuffer[fb_num] = R + return Qs1, Qs2, Rbuffer + + ''' + myCCA:根据脑电特征和参考信号计算相关系数 + Inputs: + dataQ:脑电特征 + Qc2y:参考信号 + d : 相关系数取值数 + Output: + rho : 相关系数 + + ''' + + def myCCA(self, dataQ, Qc2y, d): + if len(Qc2y) == 0: + Cov = dataQ + else: + Cov = np.dot(Qc2y, dataQ) + # U, S, V = scipy.linalg.svd(Cov, 0) + # rho = np.minimum(np.maximum(np.diag(S[0: d]).T, 0), 1) + _, S, _ = np.linalg.svd(Cov, full_matrices=False) + rho = S + return rho[0:d] + + ''' + weightCCA:计算分类标签 + Inputs: + Qs1:脑电特征1 + Qs2:脑电特征2 + ref : 正余弦参考信号 + Cxy : 协方差中间参数 + Output: + result : 分类标签 + rho : 相关系数 + Cxy : 更新后的协方差中间参数 + ''' + + def weightCCA(self, Qs1, Qs2, ref, Cxy): + rMax = np.zeros([self.num_fbs, self.num_target]) + for fi in range(self.num_fbs): + for si in range(self.num_target): + Qc2y = np.squeeze(ref[si, :, :]) + # 更新协方差矩阵 + if np.all(Cxy[fi][si] == 0): + Cxy[fi, si] = np.dot(Qc2y, Qs2[fi]) + else: + Cxy[fi, si] = np.dot(Cxy[fi, si], Qs1[fi]) + np.dot(Qc2y, Qs2[fi]) + r = self.myCCA(Cxy[fi, si], [], self.rhoNum) + rMax[fi, si] = r[0] + rho = np.dot(self.weightValue, np.power(rMax, 2)) # weighted sum of r from all different filter banks' result + result = np.argmax(rho) + return result, rho, Cxy + + ''' + costF:损失函数,根据计算的相关系数,生成决策值,用于和阈值进行比较 + Inputs: + rho:相关系数 + method:相关系数计算参数 + C : 参数 + Output: + decideValue : 决策阈值 + ''' + + def costF(self, rho, method, C): + rho = rho.tolist() + rho.sort(reverse=True) + if method == 'DW1': + decideValue = (rho[0] - rho[1]) / (sum(rho) - self.num_target * np.log(sum(np.exp(rho)))) + elif method == 'DW11': + decideValue = -(rho[0] - rho[1]) + elif method == 'DW2': + decideValue = (rho[0] - C) / (rho[1] - rho[0]) + return decideValue + + ''' + onlineInit:将窗口长度,相位值、中间参数初始化 + ''' + + def onlineInit(self): + self.dataUseLen = [0] * self.winNum + self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans]) + self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans]) + self.phase = 0 + + ''' + filterInit:重置陷波器和滤波器的滤波参数 + ''' + + def filterInit(self): + self.notchZh = [0] + self.filterZf = [0] * self.num_fbs + + ''' + warmFilter:预热滤波器,去刺激前的4S数据对陷波器和滤波器参数进行初始化迭代,去除过渡带的效果 + Inputs: + data:预处理脑电数据 + ''' + + def warmFilter(self, data): + # 降采样在采集前完成 + temp = self.preprocessFilter(data) #预热陷波滤波器 + # 滤波器组频带分解 + filterData = self.filterbank(temp) #预热滤波器组 + + ''' + myDownSample:数据降采样 + Inputs: + data:脑电数据 + n:降采样的倍数 + Output: + eegData2 : 降采样后的数据 + ''' + + def myDownSample(self, data, n): + data = data[:8, self.phase:] + dataNum = data.shape[1] + remainNum = (dataNum - 1) % n + self.phase = n - 1 - remainNum + dataDowmSample = [] + for value in data: + value = value[0:value.size:n] + dataDowmSample.append(value) + eegData2 = np.array(dataDowmSample).reshape([8, int(np.round(data.shape[1] / n))]) + return eegData2 + + ''' + preprocessFilter:预处理,调用函数降采样和陷波处理 + Inputs: + data:脑电数据 + Output: + filterData : 降采样和陷波后的数据 + ''' + + def preprocessFilter(self, data): + # data = self.myDownSample(data, 4) + # filterData = self.northFilter(data[:8, :]) + filterData = self.northFilter(data[:, :]) + return filterData + + ''' + fbccaDWMW:分类函数,对输入的脑电信号进行识别,输出决策标签 + Inputs: + testdata:脑电数据 + referenceData:参考信号 + tValue:出决策阈值 + Output: + res : 决策标签 + rho_new:相关系数 + minEps:得到的决策阈值 + ''' + + # 动态窗算法主函数 + def fbccaDWMW(self, testdata, referenceData, tValue,calculateCount): + t1 = time.time() + # try: + # 初始参数 + res = -1 + minEps = float("inf") + # 降采样和陷波器处理 + northData = self.preprocessFilter(testdata) + newSampleNum = northData.shape[1] + # 数据大于延迟长度,则无法根据后面的规则更新窗口 + if newSampleNum > self.winDelayNum: + error('need add window delay time') + + # 防止秩小于导联数 + if newSampleNum < self.num_chans: + warnings.warn('data shape is [%d %d] need more data' % (newSampleNum, northData.shape[0])) + # 滤波器组频带分解 + filterData = self.filterbank(northData) + winMinTime = 0 + # 计算每个窗口的结果 + for wi in range(0, self.winNum, self.winStep): + # print('dataUseLen:',wi,calculateCount, self.dataUseLen) + if wi == 0: + self.dataUseLen[wi] = self.dataUseLen[wi] + newSampleNum + else: + if self.dataUseLen[wi] == 0: + # 判断当前窗是否为新的窗口(因为每一次新的窗口进来时,都会使上一个窗口datauseLen>50) + if self.dataUseLen[wi - self.winStep] > self.winDelayNum*self.winStep: + self.dataUseLen[wi] = newSampleNum + else: + # print('中断: ',wi,calculateCount) + break + else: + self.dataUseLen[wi] = self.dataUseLen[wi] + newSampleNum + + if self.dataUseLen[wi] > self.winMaxSampleNum: + self.dataUseLen[wi] = newSampleNum + self.Rbuffer[wi, :, :, :] = 0 + self.Cxy[wi, :, :, :, :] = 0 + Qs1, Qs2, self.Rbuffer[wi, :, :, :] = self.getDataQ(filterData, self.Rbuffer[wi, :, :, :]) + si = self.dataUseLen[wi] - newSampleNum + ei = self.dataUseLen[wi] + ref = referenceData[:, :, si:ei] + # 更新协方差 + predLabel_new, rho_new, self.Cxy[wi, :, :, :] = self.weightCCA(Qs1, Qs2, ref, self.Cxy[wi, :, :, :]) + # 增加限制,数据长度不能太短 + if self.dataUseLen[wi] > winMinTime * self.fs: + epsilon = self.costF(rho_new, self.DW_cost_method, C=0) + if epsilon < minEps: + minEps = epsilon + predLabel = predLabel_new + xxx = rho_new + if minEps < tValue: + res = predLabel + + if time.time() - t1 > 0.2 and self.winStep < 16: + self.winStep = self.winStep * 2 + # print(self.winStep, " ", time.time() - t1) + # if res != -1: + # print('--------------------- ',res,xxx,' --------------------------') + return res + + +if __name__ == '__main__': + # The number of sub-bands in filter bank analysis + fs = 250 + num_chans = 8 + num_target = 40 + num_filterBank = 3 + num_harm = 5 + stimTime = 0.2 # 多窗口窗长 + winNum = 50 # 窗口的个数 + trials = 1 + step = 50 + res = -1 + list_freqs = [8.0, 8.2, 8.4, 8.6, 8.8, 9.0, 9.2, 9.4, 9.6, 9.8, 10.0, 10.2, 10.4, 10.6, 10.8, 11.0, 11.2, 11.4, + 11.6, 11.8, 12.0, 12.2, 12.4, 12.6, 12.8, 13., 13.2, 13.4, 13.6, 13.8, 14., 14.2, 14.4, 14.6, 14.8, + 15.0, 15.2, 15.4, 15.6, 15.8] + # 初始化对象 + dw = FbccaDw(fs, num_target, num_chans, num_filterBank, num_harm, stimTime, [1.0, 0.3], [8, 8], winNum) + # frequenceband + dw.filterFrequenceBank() + referenceData = dw.reference(list_freqs, int(winNum * stimTime * fs), num_harm) + dw.setNotchFilterPara() + + prelabels = np.zeros((1, 40)) + coefficient = np.zeros([1, 1]) + path = "D:\\工作相关\\项目代码\\SDK封装与测试\\八神BCI脑电信号SDK20230105\\双鹰SDK对比\\offline_data250\\" + for index in range(1, trials + 1): + D = loadmat(os.path.join(path + str(1) + '-warmData.mat')) + warmData = D['warmData'] + dw.onlineInit() + dw.filterInit() + dw.warmFilter(warmData.T) + + tagget_i = 0 + for tagget_i in range(1, step + 1): + D = loadmat(os.path.join(path + str(1) + '-' + str(tagget_i) + '.mat')) + dataSlice = D['dataTemp'] + res = dw.fbccaDWMW(dataSlice.T, referenceData, tValue=-0.2) + if res != -1: + break + prelabels[0, index - 1] = res + 1 + print(index, '--', res + 1," 计算轮数", tagget_i) diff --git a/Debug_64ch_Decoder/Tools/plot_MI_EEG.py b/Debug_64ch_Decoder/Tools/plot_MI_EEG.py new file mode 100644 index 0000000..4392f49 --- /dev/null +++ b/Debug_64ch_Decoder/Tools/plot_MI_EEG.py @@ -0,0 +1,851 @@ + +import os +import io + +import numpy as np +import pandas as pd +import matplotlib +import matplotlib.pyplot as plt +from matplotlib.patches import Ellipse +import matplotlib.cm as cm +import matplotlib.colors as mcolors +from scipy.spatial import Delaunay +from scipy.interpolate import Rbf +from scipy.signal import welch +from scipy.stats import sem +from scipy.signal import butter, filtfilt, hilbert +import base64 + +# 位置坐标 +def read_ch_pos(file_path=r'xy_64.xlsx'): + """ + 将电极位置信息转换为Dict + + 参数: + file_path: 电极位置存储文件, 必须包含'channel', 'x', 'y', 'z'列 + + """ + script_dir = os.path.dirname(os.path.abspath(__file__)) + file_path = os.path.join(script_dir,file_path ) + df = pd.read_excel(file_path) + # 确保列名正确 + if not all(col in df.columns for col in ['channel', 'x', 'y', 'z']): + raise ValueError("DataFrame必须包含'channel', 'x', 'y', 'z'列") + # 创建电极位置字典 + ch_pos = {} + for _, row in df.iterrows(): + ch_pos[row['channel']] = [row['x'], row['y'], row['z']] + return ch_pos +# 头部轮廓 +def draw_head(ax, center=(0, 0), radius=1.0, zorder=4): + """ + 绘制头部轮廓、鼻子和耳朵。 + + 参数: + - ax : matplotlib Axes 对象 + - center : (x, y) 头中心坐标 + - radius : float, 头半径 + - zorder : 绘制层级 + """ + + # 头圆 + head = plt.Circle(center, radius, fill=False, color='k', linewidth=1, zorder=zorder) + ax.add_artist(head) + + # 鼻子(参考 _make_head_outlines) + dx = np.exp(np.arccos(np.deg2rad(12)) * 1j) + dx_real, dx_imag = dx.real, dx.imag + nose_x = np.array([-dx_real, 0, dx_real]) * radius + center[0] + nose_y = np.array([dx_imag, 1.15, dx_imag]) * radius + center[1] + ax.plot(nose_x, nose_y, color='k', linewidth=1, zorder=zorder) + + # 耳朵(参考 _make_head_outlines 手动标定) + ear_radius = radius * 0.12 + ear_scale = radius * 2 # 根据半径缩放 + theta = np.linspace(np.pi / 2, 3 * np.pi / 2, 30) + + # 左耳 + left_ear_x_array = np.array([0.497, 0.510, 0.518, 0.5299, 0.5419, + 0.54, 0.547, 0.532, 0.510, 0.489]) * ear_scale + left_ear_y_array = np.array([0.0555, 0.0775, 0.0783, 0.0746, 0.0555, + -0.0055, -0.0932, -0.1313, -0.1384, -0.1199]) * ear_scale + center[1] + ax.plot(center[0] - left_ear_x_array, left_ear_y_array, color='k', linewidth=1, zorder=zorder) + + # 右耳 + right_ear_x_array = np.array([0.497, 0.510, 0.518, 0.5299, 0.5419, + 0.54, 0.547, 0.532, 0.510, 0.489]) * ear_scale + right_ear_y_array = np.array([0.0555, 0.0775, 0.0783, 0.0746, 0.0555, + -0.0055, -0.0932, -0.1313, -0.1384, -0.1199]) * ear_scale + center[1] + ax.plot(center[0] + right_ear_x_array, right_ear_y_array, color='k', linewidth=1, zorder=zorder) +# 地形图 插值 +def rbf_D_interpolate(xy, v, center=(0, 0), radius=1.1, grid_res=300, + n_extra=32, rbf_func='multiquadric', smooth=0, + border='mean', border_scale=1.0001, n_ngb=4): + """ + 使用 RBF + Delaunay 邻域均值方式生成平滑的 EEG topomap 插值表面。 + + 参数 + ---- + xy : (N,2) array + 电极二维坐标(与绘图坐标系一致) + v : (N,) array + 每个电极对应的值(e.g. PSD) + center : tuple (x0, y0) + 头部圆心(默认 (0,0)) + radius : float + 头部半径(用于生成边界点与网格范围) + grid_res : int + 网格分辨率(每轴点数) + n_extra : int + 边界虚拟点数量 + rbf_func : str + RBF 内核名称('multiquadric','thin_plate','gaussian',...) + smooth : float + RBF 平滑参数 + border : 'mean' or float + 若 'mean':边界点用邻近真实通道均值赋值(推荐) + 若 float:边界点赋相同常数值 + border_scale : float + 边界点半径相对 radius 的缩放(略微 >1 用以外推) + n_ngb : int + 为每个边界点取值时使用的最近真实通道数 + + 返回 + ---- + zi : (grid_res, grid_res) ndarray + 插值结果(与 grid_x, grid_y 对齐) + grid_x, grid_y : ndarrays + meshgrid(由 np.meshgrid 生成) + """ + xy = np.asarray(xy) + v = np.asarray(v) + if xy.ndim != 2 or xy.shape[1] != 2: + raise ValueError("xy must be shape (n_channels, 2)") + + n_points = xy.shape[0] + + # --- 1. 生成边界虚拟点(圆周) --- + theta = np.linspace(0.0, 2 * np.pi, n_extra, endpoint=False) + r_border = radius * border_scale + border_xy = np.column_stack([center[0] + r_border * np.cos(theta), + center[1] + r_border * np.sin(theta)]) + + # --- 2. 用 Delaunay 建图以便找到邻居(对边界点取邻居均值) --- + # 合并用于三角化的位置(真实点 + 边界点) + tri_xy = np.vstack([xy, border_xy]) + tri = Delaunay(tri_xy) + + # --- 3. 为边界点赋值 --- + if isinstance(border, str) and border == 'mean': + # 使用 Delaunay 的 vertex_neighbor_vertices 索引 + # 注意:tri.vertex_neighbor_vertices 给出 vertices -> neighbor indptr + indices, indptr = tri.vertex_neighbor_vertices + v_extra = np.zeros(n_extra) + used = np.zeros(n_extra, dtype=bool) + # 边界点在 tri_xy 中的索引范围 + rng = range(n_points, n_points + n_extra) + for idx, extra_idx in enumerate(rng): + neigh = indptr[indices[extra_idx]:indices[extra_idx + 1]] + # 仅保留原始点索引(小于 n_points) + neigh = neigh[neigh < n_points] + if neigh.size > 0: + used[idx] = True + # 使用最近 n_ngb 个邻居的均值(若邻居多则取最近的 n_ngb) + if neigh.size > n_ngb: + # 计算距离并选取最近 n_ngb + d = np.linalg.norm(xy[neigh] - tri_xy[extra_idx], axis=1) + order = np.argsort(d)[:n_ngb] + sel = neigh[order] + else: + sel = neigh + v_extra[idx] = v[sel].mean() + if not used.all() and used.any(): + v_extra[~used] = np.mean(v_extra[used]) + elif not used.any(): + v_extra[:] = np.mean(v) + else: + # border 是数值 + v_extra = np.full(n_extra, float(border)) + + # --- 4. 合并所有已知点并构建 RBF --- + all_xy = np.vstack([xy, border_xy]) + all_v = np.concatenate([v, v_extra]) + + rbf = Rbf(all_xy[:, 0], all_xy[:, 1], all_v, + function=rbf_func, smooth=smooth) + + # --- 5. 生成网格(使用 meshgrid,与主函数保持一致) --- + xmin, xmax = center[0] - radius, center[0] + radius + ymin, ymax = center[1] - radius, center[1] + radius + xi = np.linspace(xmin, xmax, grid_res) + yi = np.linspace(ymin, ymax, grid_res) + grid_x, grid_y = np.meshgrid(xi, yi) # meshgrid 与 imshow 对齐 + + # --- 6. 评估 RBF,返回与 grid 对齐的 zi --- + zi = rbf(grid_x, grid_y) + + return zi, grid_x, grid_y +# plv矩阵计算 +def calculate_plv(data): + """ + 计算相位锁定值(PLV)矩阵。 + + Parameters + ---------- + data : ndarray, shape (num_channels, num_samples) + EEG 数据,通道数为 num_channels,样本数为 num_samples。 + + Returns + ------- + plv_matrix : ndarray, shape (num_channels, num_channels) + 计算得到的 PLV 矩阵,表示各通道间的相位同步。 + """ + num_channels, num_samples = data.shape + plv_matrix = np.zeros((num_channels, num_channels)) + + # 计算每个通道的解析信号 + analytic_signals = np.apply_along_axis(hilbert, axis=1, arr=data) + + for i in range(num_channels): + for j in range(i + 1, num_channels): # 只计算上三角矩阵,避免重复计算 + # 计算 phase difference + phase_diff = np.angle(analytic_signals[i] * np.conj(analytic_signals[j])) + plv = np.abs(np.mean(np.exp(1j * phase_diff))) + plv_matrix[i, j] = plv + plv_matrix[j, i] = plv # 对称矩阵 + + return plv_matrix +# 矩阵阈值化 +def threshold_proportional(adj, prop=0.2): + """ + Apply a proportional threshold to retain the top proportion of strongest edges. + + Parameters + ---------- + adj : ndarray, shape (n_channels, n_channels) + Adjacency matrix to threshold. + prop : float + Proportion of edges to retain (0 < prop <= 1). + + Returns + ------- + bin_adj : ndarray, shape (n_channels, n_channels) + Binary adjacency matrix after thresholding. + """ + n = adj.shape[0] + triu_idx = np.triu_indices(n, k=1) + weights = adj[triu_idx] + k = int(np.floor(len(weights) * prop)) + + # Ensure that at least one edge is retained + k = max(k, 1) + + # Get the threshold value + thr = np.sort(weights)[-k] + + # Apply the threshold to create a binary adjacency matrix + bin_adj = np.where(adj >= thr, adj, 0.0) + + return bin_adj +# 单个脑网络 +def plot_single_network(ch_names,adj,ax=None, + node_size=20, node_color='orange',highlight_nodes=[], show_names=True, + edge_color='gray', weighted=True, + radius=1.1, figsize=(6, 6),cmap='RdYlBu_r'): + # 若 ax 未传入,则自己创建 + own_fig = False + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + own_fig = True + else: + fig = ax.figure + + # 坐标归一化 + pos3d = read_ch_pos() + all_chs_xy = np.array([pos3d[ch][:2] for ch in pos3d.keys()]) + all_chs_xy -= all_chs_xy.mean(axis=0) + all_chs_xy /= np.sqrt((all_chs_xy ** 2).sum(axis=1)).max() + xy_dict = dict(zip(pos3d.keys(), all_chs_xy)) + xy = np.array([xy_dict[ch] for ch in ch_names]) + center = xy_dict.get('CZ', np.mean(list(xy_dict.values()), axis=0)) + + # ===== 初始化绘图窗口 ===== + ax.set_aspect('equal') + ax.axis('off') + # 设置边界(与原类保持一致) + ear_radius = radius * 0.12 + nose_height = radius * 0.15 + margin_x = radius * 0.12 + 0.05 + ax.set_xlim(center[0] - radius - margin_x, center[0] + radius + margin_x) + ax.set_ylim(center[1] - radius - ear_radius, center[1] + radius + nose_height + ear_radius) + + # 绘制头部轮廓 + draw_head(ax, center=center, radius=radius) + + # 节点 + for ch in ch_names: + color = 'red' if ch in highlight_nodes else node_color + ax.scatter(*xy_dict[ch], s=node_size, color=color, edgecolor='k', zorder=4) + if show_names: + ax.text(xy_dict[ch][0], xy_dict[ch][1] + 0.03, ch, + ha='center', va='bottom', fontsize=8, zorder=5) + + # colorbar + norm = mcolors.Normalize(vmin=0, vmax=1) + color_map = matplotlib.colormaps.get_cmap(cmap) + # ========= 边 ========== + N = len(ch_names) + for i in range(N): + for j in range(i + 1, N): + w = adj[i, j] + if w > 0: + x = [xy[i, 0], xy[j, 0]] + y = [xy[i, 1], xy[j, 1]] + lw = 1.5 + if weighted: + ax.plot(x, y, + color=color_map(norm(w)), + linewidth=lw, + alpha=0.7, + zorder=3) + else: + ax.plot(x, y, + color=edge_color, + linewidth=lw, + alpha=0.7, + zorder=3) + + if own_fig: + # 不回传 添加颜色条 + sm = cm.ScalarMappable(norm=norm, cmap=color_map) + cbar = plt.colorbar(sm, ax=ax, fraction=0.035) + cbar.set_label('Connection Strength', fontsize=10) + cbar.ax.tick_params(direction='in', labelsize=10) + plt.show() + return fig + else: + + return ax +# 脑网络对比 +def plot_multiband_network(ch_names, adj_MI, adj_Rest,cmap='RdYlBu_r'): + + fig, axes = plt.subplots(1, 2, figsize=(8, 4)) + fontsize = 16 + fig.text(0.285, 0.08, 'MI', fontsize=fontsize, ha='center', va='center', rotation=0) + fig.text(0.68, 0.08, 'Rest', fontsize=fontsize, ha='center', va='center', rotation=0) + + im1 = plot_single_network(ch_names,adj_MI,ax=axes[0], show_names=True,cmap=cmap) + # Rest 行 + im2 = plot_single_network(ch_names,adj_Rest,ax=axes[1],show_names=True,cmap=cmap) + + # --- 合并 colorbar(右侧一个) --- + norm = mcolors.Normalize(vmin=0, vmax=1) + color_map = matplotlib.colormaps.get_cmap(cmap) + sm = cm.ScalarMappable(norm=norm, cmap=color_map) + cbar = plt.colorbar(sm, ax=axes.ravel().tolist(), fraction=0.02) + cbar.set_label('Connection Strength', fontsize=10) + cbar.ax.tick_params(direction='in', labelsize=10) + + # 将图像保存到内存字节流(PNG 格式) + buf = io.BytesIO() + fig.savefig(buf, format='png', dpi=300, bbox_inches='tight') + plt.close(fig) # 释放内存 + buf.seek(0) + image_bytes = buf.read() + buf.close() + + return image_bytes + +# 多个频带psd +def compute_band_psd(eeg, fs, bands, labels, trial_idx=0,MI_label=1, Rest_label=2,avg = True): + """ + eeg: (n_trials, n_channels, n_samples) + """ + n_trials, n_channels, n_samples = eeg.shape + band_names = list(bands.keys()) + n_bands = len(band_names) + + psd_MI = np.zeros((n_bands, n_channels)) + psd_Rest = np.zeros((n_bands, n_channels)) + + # 先计算所有 trial 的功率谱 + f, Pxx = welch(eeg, fs=fs, axis=-1, nperseg=fs,noverlap = fs // 2) + + + for bi, (bname, (f1, f2)) in enumerate(bands.items()): + idx = np.logical_and(f >= f1, f <= f2) + band_power = Pxx[:, :, idx].mean(axis=-1) + + band_power_flat = band_power.flatten() + power_min = band_power_flat.min() + power_max = band_power_flat.max() + if power_max - power_min > 1e-12: + band_power_norm = (band_power - power_min) / (power_max - power_min) + else: + band_power_norm = band_power + + if avg: + psd_MI[bi] = band_power_norm[labels == MI_label].mean(axis=0) + psd_Rest[bi] = band_power_norm[labels == Rest_label].mean(axis=0) + else: + psd_MI[bi] = band_power_norm[labels == MI_label][trial_idx] + psd_Rest[bi] = band_power_norm[labels == Rest_label][trial_idx] + return band_names, psd_MI, psd_Rest +# 单个脑地形图 +def plot_single_topomap(ch_names, psd_values, cmap='RdYlBu_r', vlim=(0, 1), + show_names=True, node_size=3, radius=1.1, grid_res=300, + n_contours=None, contour_color='k', + ax=None,figsize=(6,6)): + # 若 ax 未传入,则自己创建 + own_fig = False + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + own_fig = True + else: + fig = ax.figure + + # ===== 初始化绘图窗口 ===== + ax.set_aspect('equal') + ax.axis('off') + # ax.set_title("EEG topomap (MNE-like)") + + # 坐标归一化 + pos3d = read_ch_pos() + all_chs_xy = np.array([pos3d[ch][:2] for ch in pos3d.keys()]) + all_chs_xy -= all_chs_xy.mean(axis=0) + all_chs_xy /= np.sqrt((all_chs_xy ** 2).sum(axis=1)).max() + pos2d_dict = dict(zip(pos3d.keys(), all_chs_xy)) + xy = np.array([pos2d_dict[ch] for ch in ch_names]) + center = pos2d_dict.get('CZ', np.mean(list(pos2d_dict.values()), axis=0)) + + # 绘制头部轮廓 + draw_head(ax, center=center, radius=radius) + # 绘制电极 + fontsize = 4 + ax.scatter(xy[:, 0], xy[:, 1], c='k', s=node_size, zorder=5) + if show_names: + for i, ch in enumerate(ch_names): + ax.text(xy[i, 0], xy[i, 1] + 0.03, ch, + ha='center', va='bottom', fontsize=fontsize, zorder=6) + + # 数据插值 + zi, grid_x, grid_y = rbf_D_interpolate( + xy, psd_values, radius=radius, + grid_res=grid_res + ) + xmin, xmax = center[0] - radius, center[0] + radius + ymin, ymax = center[1] - radius, center[1] + radius + extent = (xmin, xmax, ymin, ymax) + im = ax.imshow(zi, extent=extent, origin='lower', + cmap=cmap, vmin=vlim[0], vmax=vlim[1], + interpolation='bicubic', zorder=0) + # 裁剪路径 + patch_ = Ellipse(center, 2 * radius, 2 * radius, clip_on=True, transform=ax.transData) + im.set_clip_path(patch_) + # 初始等高线 + linewidths = 0.5 + if n_contours is None: + cset = ax.contour(grid_x, grid_y, zi, + colors=contour_color, linewidths=linewidths, zorder=2) + else: + cset = ax.contour(grid_x, grid_y, zi, levels=n_contours, + colors=contour_color, linewidths=linewidths, zorder=2) + cset.set_clip_path(patch_) + + + + if own_fig: + # 不回传 添加颜色条 + plt.colorbar(im, ax=ax, fraction=0.035) + plt.show() + return fig + else: + # plt.colorbar(im, ax=ax, fraction=0.035) + return im +# 脑地形图对比 +def plot_multiband_topomaps(ch_names, psd_MI, psd_Rest, bands): + band_names = list(bands.keys()) # 改动 1:新增这行 + n_bands = len(band_names) + fig, axes = plt.subplots(2, n_bands, figsize=(3*n_bands, 6)) + + fontsize = 16 + + axes[0, 0].text(-0.1, 0.5, 'MI', transform=axes[0, 0].transAxes, rotation=0, va='center', ha='center', fontsize=fontsize-2) + axes[1, 0].text(-0.1, 0.5, 'Rest', transform=axes[1, 0].transAxes, rotation=0, va='center', ha='center', fontsize=fontsize-2) + + imgs = [] + for i, bname in enumerate(band_names): + axes[0, i].set_title(bname, fontsize=fontsize, pad=0) + # MI 行 + im1 = plot_single_topomap(ch_names,psd_MI[i],ax=axes[0, i], show_names=True) + # Rest 行 + im2 = plot_single_topomap(ch_names,psd_Rest[i],ax=axes[1, i],show_names=True) + imgs.append(im1) + + # --- 单个右侧合并 colorbar --- + cbar = fig.colorbar(imgs[0], ax=axes,fraction=0.02) + # cbar.set_label("PSD Power",fontsize=fontsize-4) + cbar.ax.tick_params(direction='in', labelsize=10) + + # 将图像保存到内存字节流(PNG 格式) + buf = io.BytesIO() + fig.savefig(buf, format='png', dpi=300, bbox_inches='tight') + plt.close(fig) # 释放内存 + buf.seek(0) + image_bytes = buf.read() + buf.close() + + return image_bytes + +# 小波 +def morlet_wavelet(f, fs, n_cycles=7): + """ + 创建 Morlet 小波 + f: 频率 + fs: 采样率 + """ + sigma_t = n_cycles / (2 * np.pi * f) + t = np.arange(-3*sigma_t, 3*sigma_t, 1/fs) + wavelet = (np.pi**-0.25) * np.exp(2j*np.pi*f*t) * np.exp(-(t**2)/(2*sigma_t**2)) + return wavelet + + +# 希尔伯特变换 计算ERDS 效果不佳 +def bandpass_filter(data, fs, band, order=4): + nyq = fs / 2 + b, a = butter(order, [band[0]/nyq, band[1]/nyq], btype='band') + return filtfilt(b, a, data, axis=-1) +def compute_power_hilbert(filtered_data,is_dB =True): + analytic = hilbert(filtered_data, axis=-1) + power = np.abs(analytic) ** 2 + if is_dB: + power = 10 * np.log10(power) + return power +def compute_power(data, fs=250, + bands={"mu": (8,12), "beta": (13,30)}): + """ + 返回: + power_dict[band] = (n_trials, n_ch, n_samples) + """ + power_dict = {} + for band_name, band_range in bands.items(): + filt = bandpass_filter(data, fs, band_range) + power = compute_power_hilbert(filt) + power_dict[band_name] = power + + return power_dict + +def compute_erds(power_MI, power_Rest, baseline_period=None): + """ + 计算事件相关去同步/同步 (ERDS) + + Parameters + ---------- + power_MI, power_Rest: (n_trials, n_ch, n_samples) + 功率数据,单位为 µV² 或 dB(取决于 compute_power_hilbert 的 is_dB 参数) + baseline_period: tuple (start_idx, end_idx) or None + 基线时间段索引。如果为None,使用 Rest 状态的平均值作为基线 + + 返回: + MI_erds_mean, MI_erds_sem + Rest_erds_mean, Rest_erds_sem + 所有返回值的形状为 (n_ch, n_samples) + """ + + if baseline_period is not None: + start_idx, end_idx = baseline_period + baseline = np.concatenate([power_MI[:, :, start_idx:end_idx], + power_Rest[:, :, start_idx:end_idx]], axis=0) + baseline = baseline.mean(axis=(0, 2), keepdims=True) + else: + baseline = power_Rest.mean(axis=(0,2), keepdims=True) + + # === ERDS (%) === + MI_erds = (power_MI - baseline) / baseline * 100 + Rest_erds = (power_Rest - baseline) / baseline * 100 + + return ( + MI_erds.mean(axis=0), sem(MI_erds, axis=0), + Rest_erds.mean(axis=0), sem(Rest_erds, axis=0), + ) + +def compute_all_erds(MI_power_dict, Rest_power_dict): + """ + 对多个频带同时计算 ERDS。 + + 输入: + MI_power_dict[band] = (n_trials, n_ch, n_samples) + Rest_power_dict[band] = (n_trials, n_ch, n_samples) + + 输出: + erds_MI[band] = (mean, sem) + erds_Rest[band] = (mean, sem) + """ + + erds_MI = {} + erds_Rest = {} + + for band in MI_power_dict.keys(): + MI_power = MI_power_dict[band] + Rest_power = Rest_power_dict[band] + + MI_mean, MI_sem, Rest_mean, Rest_sem = compute_erds(MI_power, Rest_power) + + erds_MI[band] = (MI_mean, MI_sem) + erds_Rest[band] = (Rest_mean, Rest_sem) + + return erds_MI, erds_Rest + +def plot_compare_erds(data_MI, data_Rest, mode="power", + ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'], + compare_names=['C3', 'CZ', 'C4'], bands=['mu', 'beta'], + fs=250, t=None, figsize=(12,6)): + + n_bands = len(bands) + n_chs = len(compare_names) + + # 自动添加单位 + if mode == "power": + # y_unit = "Power (µV²)" + y_unit = "Power (dB)" + elif mode == "erds": + y_unit = "ERDS (%)" + else: + y_unit = "" + + if t is None: + n_samples = next(iter(data_MI.values())).shape[-1] \ + if mode=="power" else next(iter(data_MI.values()))[0].shape[-1] + t = np.arange(n_samples) / fs + + fig, axes = plt.subplots(n_bands, n_chs, figsize=figsize, sharex=True, sharey=True) + + for i, band in enumerate(bands): + + # 选择数据结构 + if mode == "power": + MI_band = data_MI[band] # (trials, ch, samples) + Rest_band = data_Rest[band] + + avg_MI = MI_band.mean(axis=0) + sem_MI = MI_band.std(axis=0)/np.sqrt(MI_band.shape[0]) + + avg_Rest = Rest_band.mean(axis=0) + sem_Rest = Rest_band.std(axis=0)/np.sqrt(Rest_band.shape[0]) + + elif mode == "erds": + avg_MI, sem_MI = data_MI[band] + avg_Rest, sem_Rest = data_Rest[band] + + for j, ch in enumerate(compare_names): + ax = axes[i, j] if n_bands > 1 else axes[j] + + ch_idx = ch_names.index(ch) + + # 绘制 MI + ax.plot(t, avg_MI[ch_idx], color="C0", label="MI") + ax.fill_between(t, + avg_MI[ch_idx]-sem_MI[ch_idx], + avg_MI[ch_idx]+sem_MI[ch_idx], + alpha=0.3, color="C0") + + # 绘制 Rest + ax.plot(t, avg_Rest[ch_idx], color="C1", label="Rest") + ax.fill_between(t, + avg_Rest[ch_idx]-sem_Rest[ch_idx], + avg_Rest[ch_idx]+sem_Rest[ch_idx], + alpha=0.3, color="C1") + + if i == 0: + ax.set_title(ch) + + # ← Y 轴加单位 + if j == 0: + ax.set_ylabel(f"{band}\n{y_unit}") + + if i == n_bands - 1: + ax.set_xlabel("Time (s)") + + ax.grid(alpha=0.3) + + if i == 0 and j == n_chs - 1: + ax.legend() + + plt.tight_layout() + + # 将图像保存到内存字节流(PNG 格式) + buf = io.BytesIO() + fig.savefig(buf, format='png', dpi=300, bbox_inches='tight') + plt.close(fig) # 释放内存 + buf.seek(0) + image_bytes = buf.read() + buf.close() + + return image_bytes + +# 对比 MI vs Rest 的功率谱密度 PSD +def plot_psd_compare(MI_data, Rest_data, ch_names, compare_names=['C3', 'CZ', 'C4'], + fs=250, nperseg=None, average=True, show_sem=True, + figsize=(12, 3), save_dir=None, filename="psd.png"): + """ + 对比 MI vs Rest 的功率谱密度 PSD + + MI_data, Rest_data: (n_trials, n_ch, n_samples) + channels: 需要绘制的通道 + average: 是否对所有试次平均 + show_sem: 是否绘制 SEM 阴影 + """ + + n_trials, n_ch, n_samples = MI_data.shape + n_trials = min(len(MI_data), len(Rest_data)) + # assert Rest_data.shape == MI_data.shape, "MI 和 Rest 数据维度必须一致" + + if nperseg is None: + nperseg = fs # 每 1 秒窗长度 + + # 计算 MI PSD + psd_MI_all = [] + for trial in range(n_trials): + psd_trial = [] + for ch in range(n_ch): + f, Pxx = welch(MI_data[trial, ch], fs=fs, nperseg=nperseg) + psd_trial.append(Pxx) + psd_MI_all.append(psd_trial) + psd_MI_all = np.array(psd_MI_all) + + # 计算 Rest PSD + psd_Rest_all = [] + for trial in range(n_trials): + psd_trial = [] + for ch in range(n_ch): + _, Pxx = welch(Rest_data[trial, ch], fs=fs, nperseg=nperseg) + psd_trial.append(Pxx) + psd_Rest_all.append(psd_trial) + psd_Rest_all = np.array(psd_Rest_all) + + # ---- Plot ---- + fig, ax = plt.subplots(1, len(compare_names), figsize=figsize) + if len(compare_names) == 1: + ax = [ax] + + for i, ch in enumerate(compare_names): + ch_idx = ch_names.index(ch) + psd_MI_ch = psd_MI_all[:, ch_idx, :] + psd_Rest_ch = psd_Rest_all[:, ch_idx, :] + + if average: + mean_MI = psd_MI_ch.mean(axis=0) + mean_Rest = psd_Rest_ch.mean(axis=0) + + ax[i].plot(f, mean_MI, color='C0', label='MI') + ax[i].plot(f, mean_Rest, color='C1', label='Rest') + + if show_sem: + ax[i].fill_between(f, mean_MI - sem(psd_MI_ch, axis=0), + mean_MI + sem(psd_MI_ch, axis=0), color='C0', alpha=0.3) + ax[i].fill_between(f, mean_Rest - sem(psd_Rest_ch, axis=0), + mean_Rest + sem(psd_Rest_ch, axis=0), color='C1', alpha=0.3) + else: + ax[i].plot(f, psd_MI_ch.T, color='C0', alpha=0.3) + ax[i].plot(f, psd_Rest_ch.T, color='C1', alpha=0.3) + + ax[i].set_title(ch) + ax[i].set_xlabel("Frequency (Hz)") + ax[i].set_ylabel("PSD (μV²/Hz)") + ax[i].grid(alpha=0.3) + if i == 0: + ax[i].legend() + + plt.tight_layout() + + # 将图像保存到内存字节流(PNG 格式) + buf = io.BytesIO() + fig.savefig(buf, format='png', dpi=300, bbox_inches='tight') + plt.close(fig) # 释放内存 + buf.seek(0) + image_bytes = buf.read() + buf.close() + + return image_bytes + + +def plotMain( + ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'], + compare_names = [ 'C3','CZ','C4'], + Data = None,labels = None,MI_label = None,Rest_label = None, + fs = 250): + + trial_idx = 0 + + # 数据划分 + if not MI_label: + label_ = np.unique(labels) + else: + label_ = (MI_label,Rest_label) + MI_data = Data[labels == label_[0]] + Rest_data = Data[labels == label_[1]] + + # 典型 EEG 频带 + FREQ_BANDS = { + "Delta (0.8-4Hz)": (0.8, 4), + "Theta (4-8Hz)": (4, 8), + "Alpha (8-12Hz)": (8, 12), + "Beta (12-30Hz)": (12, 30), + "All (0.8-30Hz)": (0.8, 30) + } + # 利用welch估算PSD + band_names, psd_MI, psd_Rest= compute_band_psd( + eeg=Data, + fs=fs, + bands=FREQ_BANDS, + labels=labels, + trial_idx=trial_idx, + MI_label=MI_label, + Rest_label=Rest_label, + avg= True + ) + # 绘制地形图 + topomaps_imgBytes = plot_multiband_topomaps( + ch_names=ch_names, + psd_MI=psd_MI, + psd_Rest=psd_Rest, + bands=FREQ_BANDS + ) + + # 绘制脑网络 + mi_plv_matrix = calculate_plv(MI_data[trial_idx]) + mi_BI_matrix = threshold_proportional(mi_plv_matrix, prop=0.3) + rest_plv_matrix = calculate_plv(Rest_data[trial_idx]) + rest_BI_matrix = threshold_proportional(rest_plv_matrix, prop=0.3) + network_imgBytes = plot_multiband_network(ch_names, mi_BI_matrix, rest_BI_matrix) + + # ERDS 先计算erds,后平均 + MI_power = compute_power(MI_data) + Rest_power = compute_power(Rest_data) + erds_dict_MI, erds_dict_Rest = compute_all_erds(MI_power, Rest_power) + erds_imgBytes = plot_compare_erds(erds_dict_MI, erds_dict_Rest, ch_names=ch_names, + compare_names=compare_names, bands=['mu', 'beta'], + fs=fs, mode="erds") + + # 绘制PSD + psd_imgBytes = plot_psd_compare(MI_data, Rest_data, ch_names = ch_names, compare_names=compare_names, + fs=fs, nperseg=None, average=True, show_sem=True, + figsize=(12, 3)) + return {'topomaps_imgBytes':base64.b64encode(topomaps_imgBytes).decode(),'network_imgBytes':base64.b64encode(network_imgBytes).decode(), + 'erds_imgBytes':base64.b64encode(erds_imgBytes).decode(),'psd_imgBytes':base64.b64encode(psd_imgBytes).decode()} + +if __name__ == '__main__': + allData = np.random.uniform(-50,50,size=(80,21,1000)) + allLabel = np.random.randint(1,3,size=(80,)) + allData = allData[:len(allLabel)] + ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', + 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'] + compare_names = ['C3', 'CZ', 'C4'] + ret = plotMain(ch_names=ch_names, compare_names=compare_names, Data=allData, labels=allLabel, MI_label=1, Rest_label=2, + fs=250) + print('计算完成,开始发送') + from Tools.zmqClient import zmqClient + + zmqClient = zmqClient('192.168.76.101', 8088) + zmqClient.connect() + zmqClient.send_to_all('miReport', ret) diff --git a/Debug_64ch_Decoder/Zmq/zmqClient.py b/Debug_64ch_Decoder/Zmq/zmqClient.py new file mode 100644 index 0000000..0d29377 --- /dev/null +++ b/Debug_64ch_Decoder/Zmq/zmqClient.py @@ -0,0 +1,57 @@ +import threading +import time +import json +import zmq + + +class zmqClient: + def __init__(self, host, port): + self.host = host + self.port = port + self.client_socket = None + self.running = False + + # 记录客户端连接前的状态 + self.state = { + 'status_code': None, + 'energy': None + } + + def connect(self): + # 创建 ZeroMQ 上下文 + self.context = zmq.Context() + # 创建 REQ 套接字(请求端) + self.client_socket = self.context.socket(zmq.DEALER) + # client_id = b'client1' + # self.client_socket.setsockopt(zmq.IDENTITY,client_id) + self.client_socket.connect(f"tcp://{self.host}:{self.port}") # 连接到服务器 + self.running = True + + def send_to_all(self, method,params): + if method in self.state.keys(): + self.state[method] = params + try: + if self.running and self.client_socket != None: + msg = {'method': method, 'params': params} + # 发送响应 + print(msg) + self.client_socket.send_multipart([b'', json.dumps(msg).encode('utf-8')]) + else: + if method in self.state.keys(): + self.state[method] = params + except ConnectionResetError: + print("Connection lost.") + self.running = False + except Exception as e: + print(f"An error occurred: {e}") + + def close_connection(self): + self.running = False + self.client_socket.close() + self.context.term() + print("Client closed explicitly.") +# 使用TCP客户端 +if __name__ == "__main__": + client = zmqClient('127.0.0.1', 8099) + client.connect() + # client.close_connection() \ No newline at end of file diff --git a/Debug_64ch_Decoder/Zmq/zmqServer.py b/Debug_64ch_Decoder/Zmq/zmqServer.py new file mode 100644 index 0000000..5fb5bed --- /dev/null +++ b/Debug_64ch_Decoder/Zmq/zmqServer.py @@ -0,0 +1,104 @@ +import numpy as np +import zmq +import threading +import json +from Device.SunnyLinker import SunnyLinker64 + +class zmqServer(threading.Thread): + def __init__(self, host='0.0.0.0', port=8099): + threading.Thread.__init__(self) + self.host = host + self.port = port + self.running = False + self.get_Impedance = False # 是否返回阻抗值 + self.open_Impedance = None # 是否开启阻抗检测功能 + self.StartDecode = False # false 停止解码,true=开始解码 + self.StartTrain = False # False未进入训练状态,True处于训练状态 + self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态 + self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签 + self.IsExitApp = False # 当socket收到2的时候,就置为True,代表遥退出系统了。 + self.getReport = False # 获取训练报告内容 + self.daemon = True + # 创建 ZeroMQ 上下文 + self.context = zmq.Context() + # 创建 REP 套接字(响应端) + self.socket = self.context.socket(zmq.ROUTER) + self.socket.bind(f"tcp://{self.host}:{self.port}") # 绑定到端口 8099 + self.targetFreqs = [] + self.changeTarget = False # 更换目标频率 + self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化 + self.labels = [0x01, 0x02,0x03] + + self.decoder_switch = False #更换解码器 + self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi' + def run(self): + self.running = True + print(f"Server is running on {self.host}:{self.port}") + try: + while self.running: + # 等待客户端请求 + _,_,message = self.socket.recv_multipart() + message = json.loads(message.decode('utf-8')) + print(f"Received request: {message}") + # 处理请求 + method = message.get("method") + params = message.get("params") + if method == "sync": + self.state_mode = 'sync' + if method == "targetFreqs": + if not isinstance(params,list): + print('targetFreqs must be a list') + continue + if params != self.targetFreqs: + self.targetFreqs = params + self.changeTarget = True + if method == "decoderClass": + if not isinstance(params,str): + print('decoderClass must be a str') + continue + if params != self.decoder_class: + self.decoder_class = params + self.decoder_switch = True + if method == "getReport": + self.getReport = True + if method == "train":#训练状态 + self.state_mode = 'train' + self.StartTrain = True + self.currentLabel = params # 当前刺激端的训练标签 + self.sunnyLinker.push_trigger(self.labels[self.currentLabel]) + elif method == "predict":#预测状态 + self.state_mode = 'predict' + if params == 1: #开始解码 + self.StartDecode = True + self.sunnyLinker.push_trigger(0x63) + elif params == 2: #停止解码 + self.IsExitApp = True + self.running = False + elif method == "rest": #休息状态 + self.state_mode = 'rest' + elif method == "impedance": + if params == 1: + self.open_Impedance = True # 开启阻抗 + self.get_Impedance = True # 返回阻抗 + elif params == 2: + self.open_Impedance = False # 关闭阻抗 + self.get_Impedance = False # 停止返回阻抗 + + except Exception as e: + print(f"An socket error occurred: {e}") + finally: + self.running = False + # 关闭套接字和上下文 + self.socket.close() + self.context.term() + print("Server socket and context closed.") + def stop(self): + """显式关闭服务器""" + self.running = False + self.socket.close() + self.context.term() + print("Server closed explicitly.") + +if __name__ == '__main__': + server = zmqServer() + server.start() \ No newline at end of file diff --git a/Debug_64ch_Decoder/build_algorithm.spec b/Debug_64ch_Decoder/build_algorithm.spec new file mode 100644 index 0000000..9c14ffc --- /dev/null +++ b/Debug_64ch_Decoder/build_algorithm.spec @@ -0,0 +1,132 @@ +# -*- mode: python ; coding: utf-8 -*- + +import sys +import os +from PyInstaller.utils.hooks import collect_submodules, collect_data_files + +# ======================================================== +# 1. 工程配置区 (Project Config) +# ======================================================== +block_cipher = None +ENTRY_POINT = 'runDecoder.py' +APP_NAME = 'runDecoder' + +# ======================================================== +# 2. 依赖分析 (Dependency Analysis) +# ======================================================== +# 显式声明的隐藏导入,确保 PyInstaller 能找到所有 C 扩展和动态模块 +hidden_imports = [ + # sklearn Cython 扩展(极易被遗漏) + 'sklearn.utils._cython_blas', + 'torchsummary', + 'sklearn.neighbors._typedefs', + 'sklearn.neighbors._quad_tree', + 'sklearn.tree._utils', + 'sklearn.tree._criterion', + 'sklearn.tree._splitter', + 'sklearn.tree._tree', + 'sklearn.utils._weight_vector', + # torch 核心模块 + 'torch', + 'torch.nn', + 'torch.nn.modules', + 'torch.nn.modules.activation', + 'torch.nn.modules.batchnorm', + 'torch.nn.modules.conv', + 'torch.nn.modules.dropout', + 'torch.nn.modules.linear', + 'torch.nn.modules.normalization', + 'torch.nn.modules.pooling', + 'torch.nn.functional', + 'torch.autograd', + 'torch.optim', + 'torch.utils.data', + 'torch.cuda', + # einops(必须显式添加) + 'einops', + 'einops.layers', + 'einops.layers.torch', + # 并行计算相关 + 'multiprocessing', + 'multiprocessing.connection', + 'multiprocessing.context', + 'multiprocessing.managers', + 'multiprocessing.pool', + 'multiprocessing.process', + 'multiprocessing.queues', + 'multiprocessing.reduction', + 'multiprocessing.sharedctypes', + 'multiprocessing.synchronize', + 'multiprocessing.util', +] + +# ======================================================== +# 3. 资源锚定 (Data Anchoring) +# ======================================================== +# 收集 torch 的数据文件(triton、算子权重等) +datas = collect_data_files('torch') +datas += collect_data_files('torchvision') +# 收集 einops 数据文件 +datas += collect_data_files('einops') +# 收集 sklearn 数据文件 +datas += collect_data_files('sklearn') +# 收集 scipy 数据文件 +datas += collect_data_files('scipy') + +# ======================================================== +# 4. 构建流程 (Build Process) +# ======================================================== +a = Analysis( + [ENTRY_POINT], + pathex=[], + binaries=[], + datas=datas, + hiddenimports=hidden_imports, + hookspath=[], + hooksconfig={}, + runtime_hooks=['rthook.py'], + excludes=['tkinter', 'PyQt5', 'PySide2', 'PySide6', 'IPython', 'notebook', 'jupyter'], + win_no_prefer_redirects=False, + win_private_assemblies=False, + cipher=block_cipher, + noarchive=False, +) + +pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) + +exe = EXE( + pyz, + a.scripts, + [], + exclude_binaries=True, + name=APP_NAME, + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=False, + console=True, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, +) + +# ======================================================== +# 5. 打包模式: OneDir (单文件夹) + 资源旁路 +# ======================================================== +coll = COLLECT( + exe, + a.binaries, + a.zipfiles, + a.datas, + # 显式复制资源文件夹到 exe 同级目录 + Tree('online_Models', prefix='online_Models', excludes=['*.pyc', '__pycache__']), + Tree('Tools', prefix='Tools', excludes=['*.pyc', '__pycache__']), + # config.ini 作为单独文件 + [('config.ini', 'config.ini', 'DATA')], + strip=False, + upx=False, + upx_exclude=[], + name=APP_NAME, +) diff --git a/Debug_64ch_Decoder/build_with_copy.py b/Debug_64ch_Decoder/build_with_copy.py new file mode 100644 index 0000000..e8f86d6 --- /dev/null +++ b/Debug_64ch_Decoder/build_with_copy.py @@ -0,0 +1,88 @@ +import os +import shutil +import subprocess +import sys + +def main(): + # 1. 定义路径 + BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + DIST_DIR = os.path.join(BASE_DIR, 'dist') + APP_NAME = 'runDecoder' + TARGET_DIR = os.path.join(DIST_DIR, APP_NAME) + + # 定义需要复制的资源 {源路径: 目标子路径} + # 目标子路径相对于 TARGET_DIR + RESOURCES = { + 'config.ini': 'config.ini', + 'online_Models': 'online_Models', + 'Tools': 'Tools', + } + + # 2. 清理旧构建 + print("[1/3] Cleaning up old builds...") + if os.path.exists(DIST_DIR): + try: + shutil.rmtree(DIST_DIR) + print(" Cleaned dist/") + except Exception as e: + print(f" Warning: Could not clean dist/: {e}") + + BUILD_DIR = os.path.join(BASE_DIR, 'build') + if os.path.exists(BUILD_DIR): + try: + shutil.rmtree(BUILD_DIR) + print(" Cleaned build/") + except Exception as e: + print(f" Warning: Could not clean build/: {e}") + + # 3. 运行 PyInstaller + print("[2/3] Running PyInstaller...") + # 注意:我们这里不传 --noupx,因为已经在 spec 文件里把 upx=False 写死了 + cmd = [ + "pyinstaller", + "build_algorithm.spec", + "--clean" + ] + + try: + subprocess.check_call(cmd, shell=True) + except subprocess.CalledProcessError: + print("Error: PyInstaller failed.") + sys.exit(1) + + # 4. 复制外部资源文件夹 + print("[3/3] Verifying and Copying external resources...") + + for src_name, dst_name in RESOURCES.items(): + src_path = os.path.join(BASE_DIR, src_name) + dst_path = os.path.join(TARGET_DIR, dst_name) + + if os.path.exists(src_path): + if os.path.isfile(src_path): + # 如果是文件 + try: + shutil.copy2(src_path, dst_path) + print(f" Copied file: {src_name} -> {dst_name}") + except Exception as e: + print(f" Error copying file {src_name}: {e}") + else: + # 如果是文件夹 + if os.path.exists(dst_path): + try: + shutil.rmtree(dst_path) # 先删除 spec 生成的旧文件夹 (如果有) + except Exception as e: + print(f" Warning: Could not remove existing dir {dst_path}: {e}") + try: + shutil.copytree(src_path, dst_path, ignore=shutil.ignore_patterns('*.pyc', '__pycache__')) + print(f" Copied dir: {src_name} -> {dst_name}") + except Exception as e: + print(f" Error copying dir {src_name}: {e}") + else: + print(f" Warning: Source resource not found at {src_path}") + + print("\n" + "="*50) + print(f"SUCCESS! Build artifacts are in: {TARGET_DIR}") + print("="*50) + +if __name__ == "__main__": + main() diff --git a/Debug_64ch_Decoder/config.ini b/Debug_64ch_Decoder/config.ini new file mode 100644 index 0000000..f6efd88 --- /dev/null +++ b/Debug_64ch_Decoder/config.ini @@ -0,0 +1,153 @@ +[system] +SSVEP_ThresholdValue = [1,-0.023] +;SSVEP_ThresholdValue = [2,-0.00200] +SSMVEP_IntervalEpoch = [0.2,2.2] +MI_IntervalEpoch = [0.5,4.5] +Device_type=2 +Right_rehabilitation = 5 +Fault_rehabilitation = 5 +Num_blocks = 3 +Num_trials = 20 +Audio_device = -1 +Rest_time = 2 +Serial_port = COM44 + + +[Layout] +main_splitter_left = 993 +main_splitter_right = 922 +right_splitter_left = 233 +right_splitter_right = 771 +left_splitter_left = 503 +left_splitter_right = 501q + +[channel] +channel_x_fp1 = 419 +channel_y_fp1 = 124 +channel_x_fc1 = 439 +channel_y_fc1 = 296 +channel_x_fp2 = 576 +channel_y_fp2 = 124 +channel_x_fc2 = 556 +channel_y_fc2 = 299 +channel_x_f3 = 397 +channel_y_f3 = 231 +channel_x_cp1 = 439 +channel_y_cp1 = 426 +channel_x_f4 = 601 +channel_y_f4 = 232 +channel_x_cp2 = 559 +channel_y_cp2 = 425 +channel_x_fc3 = 379 +channel_y_fc3 = 295 +channel_x_af4 = 571 +channel_y_af4 = 171 +channel_x_po8 = 645 +channel_y_po8 = 564 +channel_x_fpz = 499 +channel_y_fpz = 112 +channel_x_fcz = 499 +channel_y_fcz = 300 +channel_x_poz = 500 +channel_y_poz = 554 +channel_x_po5 = 387 +channel_y_po5 = 551 +channel_x_po6 = 611 +channel_y_po6 = 551 +channel_x_c3 = 373 +channel_y_c3 = 363 +channel_x_fc5 = 319 +channel_y_fc5 = 292 +channel_x_c4 = 620 +channel_y_c4 = 363 +channel_x_fc6 = 676 +channel_y_fc6 = 288 +channel_x_p3 = 398 +channel_y_p3 = 491 +channel_x_cp5 = 322 +channel_y_cp5 = 430 +channel_x_p4 = 600 +channel_y_p4 = 489 +channel_x_cp6 = 678 +channel_y_cp6 = 430 +channel_x_c5 = 313 +channel_y_c5 = 361 +channel_x_f6 = 650 +channel_y_f6 = 223 +channel_x_f5 = 349 +channel_y_f5 = 224 +channel_x_po4 = 573 +channel_y_po4 = 551 +channel_x_po3 = 429 +channel_y_po3 = 550 +channel_x_cp4 = 619 +channel_y_cp4 = 424 +channel_x_cp3 = 381 +channel_y_cp3 = 426 +channel_x_fc4 = 619 +channel_y_fc4 = 295 +channel_x_o1 = 423 +channel_y_o1 = 598 +channel_x_ft9 = 252 +channel_y_ft9 = 168 +channel_x_o2 = 576 +channel_y_o2 = 597 +channel_x_ft10 = 798 +channel_y_ft10 = 277 +channel_x_f7 = 295 +channel_y_f7 = 214 +channel_x_tp9 = 202 +channel_y_tp9 = 445 +channel_x_f8 = 701 +channel_y_f8 = 215 +channel_x_t7 = 252 +channel_y_t7 = 362 +channel_x_tp7 = 261 +channel_y_tp7 = 436 +channel_x_ft8 = 734 +channel_y_ft8 = 283 +channel_x_ft7 = 264 +channel_y_ft7 = 286 +channel_x_af8 = 645 +channel_y_af8 = 159 +channel_x_af7 = 351 +channel_y_af7 = 160 +channel_x_p6 = 652 +channel_y_p6 = 499 +channel_x_p5 = 348 +channel_y_p5 = 499 +channel_x_c6 = 683 +channel_y_c6 = 362 +channel_x_f1 = 447 +channel_y_f1 = 236 +channel_x_t8 = 745 +channel_y_t8 = 361 +channel_x_f2 = 549 +channel_y_f2 = 235 +channel_x_p7 = 300 +channel_y_p7 = 505 +channel_x_c1 = 435 +channel_y_c1 = 363 +channel_x_p8 = 698 +channel_y_p8 = 508 +channel_x_c2 = 559 +channel_y_c2 = 359 +channel_x_fz = 499 +channel_y_fz = 238 +channel_x_po7 = 354 +channel_y_po7 = 562 +channel_x_tp8 = 735 +channel_y_tp8 = 438 +channel_x_oz = 498 +channel_y_oz = 609 +channel_x_af3 = 428 +channel_y_af3 = 170 +channel_x_pz = 501 +channel_y_pz = 486 +channel_x_p2 = 551 +channel_y_p2 = 483 +channel_x_cz = 499 +channel_y_cz = 361 +channel_x_p1 = 448 +channel_y_p1 = 488 + diff --git a/Debug_64ch_Decoder/online_Models/Model_2025-11-15-11-11-50.pth b/Debug_64ch_Decoder/online_Models/Model_2025-11-15-11-11-50.pth new file mode 100644 index 0000000..5df5491 Binary files /dev/null and b/Debug_64ch_Decoder/online_Models/Model_2025-11-15-11-11-50.pth differ diff --git a/Debug_64ch_Decoder/online_Models/Model_2025-11-17-16-55-25.pth b/Debug_64ch_Decoder/online_Models/Model_2025-11-17-16-55-25.pth new file mode 100644 index 0000000..efe3af6 Binary files /dev/null and b/Debug_64ch_Decoder/online_Models/Model_2025-11-17-16-55-25.pth differ diff --git a/Debug_64ch_Decoder/online_Models/Model_2025-11-18-10-15-35.pth b/Debug_64ch_Decoder/online_Models/Model_2025-11-18-10-15-35.pth new file mode 100644 index 0000000..b527c6b Binary files /dev/null and b/Debug_64ch_Decoder/online_Models/Model_2025-11-18-10-15-35.pth differ diff --git a/Debug_64ch_Decoder/online_Models/log_result.txt b/Debug_64ch_Decoder/online_Models/log_result.txt new file mode 100644 index 0000000..f29601a --- /dev/null +++ b/Debug_64ch_Decoder/online_Models/log_result.txt @@ -0,0 +1,252 @@ +0 0.5 +1 0.5 +2 0.5 +3 0.5 +4 0.5 +5 0.5 +6 0.4375 +7 0.5 +8 0.5 +9 0.5 +10 0.5 +11 0.5 +12 0.3125 +13 0.5 +14 0.375 +15 0.5625 +16 0.3125 +17 0.4375 +18 0.4375 +19 0.4375 +20 0.375 +21 0.375 +22 0.4375 +23 0.4375 +24 0.4375 +25 0.4375 +26 0.4375 +27 0.4375 +28 0.4375 +29 0.5 +30 0.5 +31 0.5625 +32 0.375 +33 0.5625 +34 0.5 +35 0.4375 +36 0.5 +37 0.4375 +38 0.4375 +39 0.5625 +40 0.5 +41 0.5 +42 0.5 +43 0.5 +44 0.5 +45 0.5 +46 0.5625 +47 0.5625 +48 0.4375 +49 0.4375 +50 0.5 +51 0.5625 +52 0.5 +53 0.4375 +54 0.5 +55 0.625 +56 0.4375 +57 0.625 +58 0.5 +59 0.5 +60 0.5 +61 0.5625 +62 0.625 +63 0.625 +64 0.5 +65 0.5625 +66 0.5 +67 0.5 +68 0.5 +69 0.5 +70 0.625 +71 0.5 +72 0.4375 +73 0.5625 +74 0.5625 +75 0.625 +76 0.4375 +77 0.4375 +78 0.4375 +79 0.5625 +80 0.5 +81 0.5 +82 0.4375 +83 0.4375 +84 0.4375 +85 0.4375 +86 0.625 +87 0.5625 +88 0.4375 +89 0.4375 +90 0.5625 +91 0.4375 +92 0.4375 +93 0.5 +94 0.4375 +95 0.5625 +96 0.5625 +97 0.5 +98 0.625 +99 0.5625 +100 0.5 +101 0.5 +102 0.5 +103 0.5 +104 0.5 +105 0.625 +106 0.625 +107 0.625 +108 0.4375 +109 0.5625 +110 0.5 +111 0.625 +112 0.5625 +113 0.5 +114 0.5 +115 0.625 +116 0.5 +117 0.5625 +118 0.625 +119 0.625 +120 0.4375 +121 0.4375 +122 0.4375 +123 0.5 +124 0.625 +125 0.625 +126 0.625 +127 0.625 +128 0.6875 +129 0.5625 +130 0.5625 +131 0.4375 +132 0.4375 +133 0.4375 +134 0.4375 +135 0.5625 +136 0.625 +137 0.5625 +138 0.5 +139 0.4375 +140 0.5 +141 0.625 +142 0.625 +143 0.5625 +144 0.625 +145 0.5625 +146 0.5625 +147 0.5 +148 0.5 +149 0.5 +150 0.4375 +151 0.4375 +152 0.5625 +153 0.625 +154 0.5 +155 0.625 +156 0.625 +157 0.625 +158 0.5625 +159 0.5625 +160 0.5625 +161 0.625 +162 0.5 +163 0.5625 +164 0.625 +165 0.4375 +166 0.5625 +167 0.625 +168 0.625 +169 0.5625 +170 0.5625 +171 0.5 +172 0.4375 +173 0.5625 +174 0.5 +175 0.4375 +176 0.5625 +177 0.5 +178 0.5625 +179 0.5625 +180 0.5625 +181 0.5 +182 0.5625 +183 0.5 +184 0.5625 +185 0.5625 +186 0.5625 +187 0.5 +188 0.4375 +189 0.5 +190 0.4375 +191 0.4375 +192 0.5 +193 0.5625 +194 0.5625 +195 0.5625 +196 0.625 +197 0.5 +198 0.5625 +199 0.625 +200 0.5 +201 0.5 +202 0.625 +203 0.5625 +204 0.625 +205 0.5 +206 0.5 +207 0.625 +208 0.625 +209 0.5625 +210 0.625 +211 0.4375 +212 0.5625 +213 0.5 +214 0.5 +215 0.5625 +216 0.4375 +217 0.5 +218 0.5625 +219 0.5 +220 0.625 +221 0.5625 +222 0.5625 +223 0.625 +224 0.5625 +225 0.5625 +226 0.625 +227 0.5625 +228 0.6875 +229 0.5 +230 0.5625 +231 0.625 +232 0.5 +233 0.625 +234 0.5 +235 0.5 +236 0.5 +237 0.4375 +238 0.625 +239 0.5625 +240 0.5625 +241 0.5 +242 0.5 +243 0.5625 +244 0.5625 +245 0.5625 +246 0.625 +247 0.5 +248 0.5 +249 0.4375 +The average accuracy is: 0.5235 +The best accuracy is: 0.6875 diff --git a/Debug_64ch_Decoder/rthook.py b/Debug_64ch_Decoder/rthook.py new file mode 100644 index 0000000..40131dc --- /dev/null +++ b/Debug_64ch_Decoder/rthook.py @@ -0,0 +1,22 @@ +import sys +import os +import multiprocessing + +# ============================================================ +# 0. Matplotlib 非交互式后端(必须在导入 matplotlib.pyplot 之前设置) +# plot_MI_EEG.py 等模块会用到 pyplot,必须在打包后的无显示器环境下工作 +# ============================================================ +if getattr(sys, 'frozen', False): + import matplotlib + matplotlib.use('Agg') + os.environ.setdefault('MPLBACKEND', 'Agg') + +# 1. 路径自适应:在 Frozen 模式下,将当前工作目录切换到可执行文件所在目录 +# 这样代码中使用的相对路径(如 './config.ini')就能正确指向 exe 旁边的文件 +if getattr(sys, 'frozen', False): + os.chdir(os.path.dirname(sys.executable)) + +# 2. 多进程保护:防止 Windows 下的无限递归炸弹 +# Windows 下 multiprocessing 需要 freeze_support() +if sys.platform.startswith('win'): + multiprocessing.freeze_support() diff --git a/Debug_64ch_Decoder/runDecoder.py b/Debug_64ch_Decoder/runDecoder.py new file mode 100644 index 0000000..e98c7d8 --- /dev/null +++ b/Debug_64ch_Decoder/runDecoder.py @@ -0,0 +1,17 @@ + +import time +from Decoder import Decoder_main +from PubLibrary.RunOnce import is_program_running + +if __name__ == "__main__": + + if not is_program_running(): + decoder = Decoder_main() + decoder.connect() + + try: + decoder.start() + while not decoder.zmqServer.IsExitApp: + time.sleep(1) + except KeyboardInterrupt: + decoder.stop() \ No newline at end of file diff --git a/Debug_64ch_Decoder_Optimize/Decoder.py b/Debug_64ch_Decoder_Optimize/Decoder.py new file mode 100644 index 0000000..9b79d83 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/Decoder.py @@ -0,0 +1,632 @@ +import ast +import threading +from datetime import datetime +import multiprocessing as mp +import numpy as np +import time +import torch +from queue import Empty +from scipy import signal +from torch.autograd import Variable +from Device.SunnyLinker import SunnyLinker64 +from SSMVEP.algorithm.tdca import TDCA +from SSMVEP.algorithm.base import generate_cca_references +from concentration.algorithm.calculate_focus import Calculate +from blinkdetection.algorithm.eye_detection import blink_detection +from Zmq.zmqServer import zmqServer +from Zmq.zmqClient import zmqClient +from MI.Algorithm.conformer_2class import onlineTrain +from PubLibrary.InifileHelper import IniRead +from SSVEP.dwfbcca import FbccaDw +from Tools.plot_MI_EEG import plotMain +from collections import deque + +class Decoder_main(threading.Thread): + def __init__(self): + threading.Thread.__init__(self) + self.Runing=True + self.decoder = None + + self.fs = 250 # 采样率 + self.energy = 0 # 电量 + self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常 + self.decoder_class = None #解码器类别 + + self.decodingSteps = 0 # 0=停止解码 1=预热 2=解码中 3=解码完成,发送解码结果 + + def connect(self, device_type=None, device_host=None, device_port=None, upper_host=None, upper_port=None): + self.DeviceType = device_type if device_type is not None else int(IniRead('system', 'Device_type')) + _device_host = device_host if device_host is not None else str(IniRead('system', 'Device_Host')) + _device_port = device_port if device_port is not None else int(IniRead('system', 'Device_Port')) + _upper_host = upper_host if upper_host is not None else str(IniRead('system', 'Upper_Host')) + _upper_port = upper_port if upper_port is not None else int(IniRead('system', 'Upper_Port')) + + if self.DeviceType == 1: + self.thread_data_server = SunnyLinker64(_device_host, _device_port, self.fs, 64, method='tcp') + self.thread_data_server.host = _device_host + self.thread_data_server.port = _device_port + + self.thread_data_server.toUv = True + self.thread_data_server.start() + + self.zmqServer = zmqServer() + self.zmqServer.start() + + self.zmqClient = zmqClient(_upper_host, _upper_port) + self.zmqClient.set_zmq_server(self.zmqServer) + self.zmqClient.connect() + + def is_valid_signal(self, data, threshold=1e5): # 判断当前信号是否为有效信号 + # data: (chans, samples) + energy = np.mean(np.var(data, axis=1)) # 各通道方差均值 + if energy > threshold: + return False + return True + def init_Decoder(self,decoder_class): + ''' + 初始化解码器 + :param decoder_class: 'ssvep' or 'ssmvep' or 'mi' or 'concentration' or '' + :return: + ''' + self.decoder_class = decoder_class + if decoder_class == 'ssvep' or decoder_class == 'pvs': + self.n_chan = 8 + self.thread_data_server.interval_inited = False + DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue')) + self.ListFreq = self.zmqServer.targetFreqs + self.num_target = len(self.ListFreq) + if self.num_target == 0: + return + # 初始化对象 二代算法 + self.dw = FbccaDw(self.fs, self.num_target, self.n_chan, 5, 5, + 0.2, [2.0, 0.1], [8, 7], 50, DW_cost_method) + # frequence band + self.dw.filterFrequenceBank() + self.dw.setNotchFilterPara() + self.calculateCount = 0 + self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.fs), + 5) + self.dw.filterInit() + self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时 + + elif decoder_class == 'ssmvep': + self.thread_data_server.interval_init(decoder_class) + self.n_chan = 8 + self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) + self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位 + self.single_train = 10 # 单类别数量 + self.num_target = 2 # 分类目标数目 + self.list_freqs = np.array([8, 9]) # 刺激频率 + self.list_phase = np.array([0, 0]) # 相位 + self.tdca = TDCA(padding_len=5, n_components=1) + self.Yf = generate_cca_references(self.list_freqs, srate=self.fs, T=self.sample_length, + phases=self.list_phase, n_harmonics=5) + self.parameter_init(5,45) + + elif decoder_class == 'mi' or decoder_class == 'ma': + self.thread_data_server.interval_init(decoder_class) + self.n_chan = 21 + self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) + self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6) # 解码数据长度2s,# 精确到小数点后6位 + self.single_train = 40 # 单类别数量 + self.num_target = 2 # 分类目标数目 + + self.parameter_init(8, 30) + + elif decoder_class == 'concentration': + self.thread_data_server.interval_inited = False + self.n_chan = 6 + self.win_len = 10 + self.win_step = 1 + self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue')) + self.calculate = Calculate(self.low_threshold, self.high_threshold, self.fs, self.win_len) + self.interval_epoch = [0, 1] + self.parameter_init(2, 40) + # self.eegQueue moved to Calculate class + + elif decoder_class == 'blink': + self.n_chan = 2 + self.l_freq = 0.1 # 带通滤波器低频截止 + self.h_freq = 8.0 # 带通滤波器高频截止 + self.total_samples = 0 # 总采样点数 + self.window_ms = 600 # 检测窗口大小 (ms) + self.step_ms = 100 # 滑动步长 (ms) + self.window_samples = int(self.window_ms * self.fs / 1000) # 150个样本点 + self.step_samples = int(self.step_ms * self.fs / 1000) # 25个样本点 + self.buffer_size = self.window_samples + self.step_samples * 5 + self.fp1_buffer = deque(maxlen=self.buffer_size) + self.fp2_buffer = deque(maxlen=self.buffer_size) + self.sample_counter = 0 + # 预计算滤波器系数,避免在循环中重复设计 + self.Dmin, self.Dmax, self.EMin, self.EMax, self.jitterwin,self.double_blink_interval,self.double_blink_jitter = ast.literal_eval(IniRead('system', 'blink')) + self.blink_count = 0 # 单次眨眼的次数 + self.last_blink_time = 0 # 上次检测到单次眨眼的时间(样本索引) + self.blink_timestamps = deque(maxlen=10) # 记录最近10次 单次眨眼的时间戳 + self.double_blink_count = 0 # 连续两次眨眼的次数 + self.double_blink_events = [] # 连续眨眼事件记录 + self.last_double_blink_time = 0 # 上次检测到连续眨眼的时间戳 + self.blink_events = [] + self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.fs / 2), self.h_freq / (self.fs / 2)], btype='band') + + def parameter_init(self,bandPass_low,bandPass_high): + self.interval_epoch = [int(i * self.fs) for i in self.interval_epoch] # epoch截取信息 + self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * self.fs)] # 训练样本epoch + self.trainData = [] #训练数据 + self.trainLabel = [] #训练标签 + self.plotData = [] #报告分析数据 + self.plotLabel = [] #报告分析标签 + self.currentLabel = -1 #刺激界面当前显示的训练标签 + self.train_started = False #是否开始训练模型 + self.load_model = False # 调用模型是否完成的标志 + self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) # 50Hz工频陷波,250是采样率,30是质量因子 + self.b_design = signal.firwin(65, [bandPass_low / (self.fs/2), bandPass_high / (self.fs/2)], pass_zero=False) # 设计8-30Hz带通滤波器 + fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + filePath = './online_Models/' + self.modelPath = ''.join([filePath, fileName, '.pth']) + self.mp_data_queue = mp.Queue() #多进程传参队列 + self.mp_result_queue = mp.Queue() #多进程结果队列 + + def preprocess(self, signal_data): + # # 计算每行的平均值 + row_means = np.mean(signal_data, axis=-1, keepdims=True) + # 对每一行去均值 + signal_data = signal_data - row_means + + signal_data = signal.lfilter(self.b_notch, self.a_notch, signal_data, axis=-1) # 工频陷波 + signal_data = signal.lfilter(self.b_design, 1, signal_data, axis=-1) # 带通滤波 + return signal_data + + def run(self): + while self.Runing: + if self.zmqServer.decoder_switch or self.zmqServer.changeTarget: + print(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}") + self.zmqServer.decoder_switch = False + self.zmqServer.changeTarget = False + self.reset_state() # 切换前先统一清理旧状态 + self.init_Decoder(self.zmqServer.decoder_class) + + # 同步信息 + if self.zmqServer.state_mode == 'sync': + self.zmqClient.send_to_all('sync', self.zmqClient.state) + self.zmqServer.state_mode = 'rest' + # 状态异常,报告上位机 + if self.status_code != self.thread_data_server.status_code: + self.status_code = self.thread_data_server.status_code + self.zmqClient.send_to_all('status_code', int(self.status_code)) + print('status code') + + # 返回电量 + if self.energy != self.thread_data_server.energy: + self.energy = self.thread_data_server.energy + self.zmqClient.send_to_all('energy', int(self.energy)) + print('energy') + + if self.zmqServer.open_Impedance == True: # 开启阻抗检测功能,仅运行一次 + self.thread_data_server.Impedance(True) + print('Impedance') + self.zmqServer.open_Impedance = -1 + elif self.zmqServer.open_Impedance == False: + self.thread_data_server.Impedance(False) + self.zmqServer.open_Impedance = -1 + + if self.zmqServer.get_Impedance: # 返回阻抗值 + # print(self.zmqServer.get_Impedance) + # print(self.thread_data_server.GetDataLenCount()) + if self.thread_data_server.GetDataLenCount() > 250: + Impe_data = self.thread_data_server.getData(250) + # 计算阻抗 + imps = self.thread_data_server.getImpedance(Impe_data,self.zmqServer.decoder_class) + self.zmqClient.send_to_all('impedance', imps.tolist()) + else: + pass + if self.zmqServer.getReport: #返回训练报告内容 + self.zmqServer.getReport = False + allData = np.array(self.plotData) + allLabel = np.array(self.plotLabel) + 1 + nTrials = min(len(allLabel),len(allData)) + if nTrials < 30: + self.zmqClient.send_to_all('miReport',0) + else: + allData = allData[:nTrials] + allLabel = allLabel[:nTrials] + ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', + 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'] + compare_names = ['C3', 'CZ', 'C4'] + miReport = plotMain(ch_names=ch_names,compare_names=compare_names,Data=allData,labels=allLabel,MI_label=1,Rest_label=2, + fs=self.fs) + self.zmqClient.send_to_all('miReport',miReport) + + + # --- 取数优先:先执行 decoder(消费环形缓冲),再处理 plot/report 等重负载 --- + try: + if self.decoder_class == 'ssvep' or self.decoder_class == 'pvs': + self.decoder_SSVEP() + elif self.decoder_class == 'ssmvep': + self.decoder_SSMVEP() + elif self.decoder_class == 'mi': + self.decoder_MI() + elif self.decoder_class == 'concentration': + self.decoder_concentration() + elif self.decoder_class == 'blink': + self.decoder_blink() + else: + if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 + if self.thread_data_server.GetDataLenCount() < 25: + time.sleep(0.005) + continue; + self.thread_data_server.getData(25) + except Exception as e: + print(f"Decoder Loop Error: {e}") + import traceback + traceback.print_exc() + time.sleep(0.1) # Prevent CPU spin if error is persistent + + def decoder_SSVEP(self): + if self.zmqServer.StartDecode: + self.zmqServer.StartDecode = False + self.decodingSteps = 1 + self.thread_data_server.ResetAll() + print('启动预测') + if self.thread_data_server.GetDataLenCount() < 50: + time.sleep(0.005) + return + if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码 + return + data = self.thread_data_server.getDataViaSSVEP(50) + data = data[:self.n_chan, :] + if self.decodingSteps == 1 and hasattr(self,'dw'): # 开始预热 + self.dw.onlineInit() # 刺激闪烁的第1s重置 --在线数据采集时 + self.dw.warmFilter(data) # 预热 + self.decodingSteps = 2 + print('预热数据完成。开始预测') + return + if self.decodingSteps == 2 and hasattr(self,'dw'): # 解码中 + choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount) + self.calculateCount += 1 + if choosenNum != -1 and self.is_valid_signal(data): + self.decodingSteps = 3 + print('预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount)) + self.calculateCount = 0 + if self.decodingSteps == 3: # 发送解码后的信息 + self.zmqClient.send_to_all('result', int(choosenNum)) + self.decodingSteps = 0 + print('发送给界面完成。') + + def decoder_SSMVEP(self): + '''模型训练''' + if self.load_model == False and all( + self.trainLabel.count(i) >= self.single_train for i in range(len(self.list_freqs))): # 模型尚未训练完成 + self.trainData = np.array(self.trainData) + self.trainLabel = np.array(self.trainLabel) + print(np.shape(self.trainData), (self.trainLabel)) + # 保存多个数组到文件 + # np.savez('20250520_yy.npz', array1=self.trainData, array2=self.trainLabel) + # self.decoder = self.fbtdca.fit(self.trainData, self.trainLabel, Yf=self.Yf) + self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf) + now = datetime.now() + formatted_time = now.strftime('%H:%M:%S.%f')[:-3] + print('模型训练完成', formatted_time) + self.load_model = True + self.zmqClient.send_to_all('paradigm', 1) + + '''训练阶段采集数据''' + if self.zmqServer.state_mode == 'train': # 训练状态 + if self.zmqServer.StartTrain: + self.currentLabel = self.zmqServer.currentLabel + self.zmqServer.StartTrain = False + if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ + self.train_epoch[1] \ + + self.thread_data_server.event_inner_idx: + time.sleep(0.0001) + return + print('训练队列数据:', self.thread_data_server.GetDataLenCount()) + trainTrial = self.thread_data_server.get_SSMVEPData() # 取出所有数据 + print('取出的: ', trainTrial.shape, 'event: ', trainTrial[-2, self.thread_data_server.event_inner_idx]) + trainTrial = self.preprocess(trainTrial[:self.n_chan, :]) # 预处理 + trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.train_epoch[ + 0]:self.thread_data_server.event_inner_idx + self.train_epoch[1]] + print('trial: ', self.thread_data_server.event_inner_idx, self.train_epoch[0], self.train_epoch[1]) + if trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0]) and isinstance( + self.trainLabel, list) \ + and self.trainLabel.count(self.currentLabel) < self.single_train: + self.trainData.append(trainTrial) + self.trainLabel.append(self.currentLabel) + + elif self.zmqServer.state_mode == 'predict': # 测试状态 + if self.load_model == False: # 模型尚未训练完成 + time.sleep(0.01) + return + else: # 已有模型 + if self.zmqServer.StartDecode: + self.zmqServer.StartDecode = False + now = datetime.now() + formatted_time = now.strftime('%H:%M:%S.%f')[:-3] + print('启动预测 ', formatted_time) + + if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ + self.interval_epoch[1] \ + + self.thread_data_server.event_inner_idx: + time.sleep(0.0001) + return + data = self.thread_data_server.get_SSMVEPData() # 读取全部数据 + print('取出的: ', data.shape, 'event: ', data[-2, self.thread_data_server.event_inner_idx]) + data = self.preprocess(data[:self.n_chan, :]) # 预处理 + data = data[:, + self.thread_data_server.event_inner_idx + self.interval_epoch[ + 0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]] + pad_eeg_test = np.zeros( + (data.shape[0], int((self.sample_length + 0.1) * self.fs))) + pad_eeg_test[:, :int(self.sample_length * self.fs)] = data + choosenNum, features_2 = self.decoder.predict(pad_eeg_test) + if isinstance(choosenNum, np.ndarray): + choosenNum = choosenNum[0] + print('结果:', choosenNum, 'rho: ', sorted(features_2[0]), + sorted(features_2[0])[-1] - sorted(features_2[0])[-2]) + self.zmqClient.send_to_all('result', int(choosenNum)) + print('发送给界面完成。') + else: # 休息状态 + if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 + if self.thread_data_server.GetDataLenCount() < 25: + time.sleep(0.005) + return + self.thread_data_server.getData(25) + + def decoder_MI(self): + '''模型训练''' + if self.train_started == False and all( + self.trainLabel.count(i) >= self.single_train for i in range(self.num_target)): # 模型尚未训练 + self.zmqClient.send_to_all('paradigm', 2) # 模型训练前,训练集采集完毕,通知上位机 + self.train_started = True + self.trainData = np.array(self.trainData) + self.trainLabel = np.array(self.trainLabel) + 1 + # print('训练集:',np.shape(self.trainData), (self.trainLabel)) + p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue)) # 开启子进程,训练模型 + p.start() + self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel, 'modelPath': self.modelPath, + 'n_chan': self.n_chan}) + + '''检查模型是否训练完成,调用''' + if self.load_model == False and self.train_started == True: + try: + result = self.mp_result_queue.get_nowait() + if result['status'] == 'success': + print("模型训练完成,加载新模型") + # 调用模型 + self.model = torch.load(self.modelPath, weights_only=False) + self.model.eval() + # 模型预热 + warmup_data = np.random.uniform(-1, 1, (1, 1, self.n_chan, 1000)) + warmup_data = torch.from_numpy(warmup_data) + warmup_data = Variable(warmup_data.type(torch.cuda.FloatTensor)) + with torch.no_grad(): + _ = self.model(warmup_data) + self.load_model = True + self.zmqClient.send_to_all('paradigm', 1) # 模型调用完毕,通知上位机 + else: + print("训练失败:", result['msg']) + except Empty: + pass # 还没完成 + except Exception as e: + print('模型调用失败: ', e) + + '''训练阶段采集数据''' + if self.zmqServer.state_mode == 'train' and self.train_started == False: # 训练状态 + if self.zmqServer.StartTrain: + self.currentLabel = self.zmqServer.currentLabel + self.zmqServer.StartTrain = False + if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ + self.interval_epoch[1] \ + + self.thread_data_server.event_inner_idx: + time.sleep(0.0001) + return + print('训练队列数据:', self.thread_data_server.GetDataLenCount()) + originalTrial = self.thread_data_server.get_MIData() # 取出MI导联数据 + print('取出的: ', originalTrial.shape, 'event: ', originalTrial[-2, self.thread_data_server.event_inner_idx]) + trainTrial = self.preprocess(originalTrial[:self.n_chan, :]) # 预处理 + trainTrial = trainTrial[:, self.thread_data_server.event_inner_idx + self.interval_epoch[ + 0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]] + print('trial: ', self.thread_data_server.event_inner_idx, self.interval_epoch[0], self.interval_epoch[1]) + if trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0]) and isinstance(self.trainLabel, + list) \ + and self.trainLabel.count(self.currentLabel) < self.single_train: + self.trainData.append(trainTrial) + self.trainLabel.append(self.currentLabel) + print('训练集:', np.shape(self.trainData)) + self.plotData.append(originalTrial[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[ + 0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]) + self.plotLabel.append(self.currentLabel) + + elif self.zmqServer.state_mode == 'predict' and self.load_model == True: # 测试状态 + if self.zmqServer.StartDecode: + self.zmqServer.StartDecode = False + now = datetime.now() + formatted_time = now.strftime('%H:%M:%S.%f')[:-3] + print('启动预测 ', formatted_time) + + if self.thread_data_server.epoch_finished == False or self.thread_data_server.GetDataLenCount() < \ + self.interval_epoch[1] \ + + self.thread_data_server.event_inner_idx: + time.sleep(0.0001) + return + originalData = self.thread_data_server.get_MIData() # 读取全部数据 + print('取出的: ', originalData.shape, 'event: ', originalData[-2, self.thread_data_server.event_inner_idx]) + start = time.time() + data = self.preprocess(originalData[:self.n_chan, :]) # 预处理 + data = data[:, + self.thread_data_server.event_inner_idx + self.interval_epoch[ + 0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]] + self.plotData.append( + originalData[:self.n_chan, self.thread_data_server.event_inner_idx + self.interval_epoch[ + 0]:self.thread_data_server.event_inner_idx + self.interval_epoch[1]]) + + test_data = data[np.newaxis, np.newaxis, :, :] + test_data = torch.from_numpy(test_data) + test_data = Variable(test_data.type(torch.cuda.FloatTensor)) + with torch.no_grad(): + Cls = self.model(test_data) + y_pred = torch.max(Cls, 1)[1] + self.plotLabel.append(int(y_pred.item())) + print('运动意图识别: ', y_pred) + self.zmqClient.send_to_all('result', int(y_pred.item())) + end = time.time() + print(f'发送给界面完成,耗时{end - start:.3f}s。') + else: # 休息状态 + if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 + if self.thread_data_server.GetDataLenCount() < 25: + time.sleep(0.005) + return + self.thread_data_server.getData(25) + + def decoder_concentration(self): + if self.zmqServer.state_mode == 'predict': + if self.zmqServer.StartDecode: + self.zmqServer.StartDecode = False + self.thread_data_server.ResetAll() + now = datetime.now() + formatted_time = now.strftime('%H:%M:%S.%f')[:-3] + print('启动专注力预测 ', formatted_time) + if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.fs): # 每win_step得出一次结果 + time.sleep(0.005) + return + if self.zmqServer.get_Impedance != False: # 阻抗检测状态不解码 + return + data = self.thread_data_server.get_concentrateData(int(self.win_step * self.fs)) # 修改每次读取的数据 + result = self.calculate.queueOpt(data) + if result is not None: + self.zmqClient.send_to_all('result', int(result)) + else: # 休息状态 + if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 + if self.thread_data_server.GetDataLenCount() < 25: + time.sleep(0.005) + return + self.thread_data_server.getData(25) + +#### Blink detection ##### + def check_double_blink(self, current_time): + """ + 检查是否检测到连续两次眨眼 + @param current_time: 当前眨眼时间戳 + @return: True表示检测到连续两次眨眼 + """ + if len(self.blink_timestamps) < 2: + return False + + # 检查是否在去抖期内 + if self.last_double_blink_time > 0: + time_since_last_double_blink = current_time - self.last_double_blink_time + if time_since_last_double_blink < self.double_blink_jitter: + return False # 在去抖期内,忽略连续眨眼检测 + last_time = self.blink_timestamps[-1] # 当前眨眼 + prev_time = self.blink_timestamps[-2] # 上次眨眼 + + interval = last_time - prev_time + if interval <= self.double_blink_interval: + return True + + return False + + def process_blink_detection(self): + """ + 在缓冲区数据上执行,单次眨眼检测 + """ + if len(self.fp1_buffer) < self.window_samples: + return + + fp1_data = np.array(list(self.fp1_buffer)[-self.window_samples:]) + fp2_data = np.array(list(self.fp2_buffer)[-self.window_samples:]) + # 计算FP1和FP2的平均 + fp12_mean = (fp1_data + fp2_data) / 2.0 + # 带通滤波 + try: + fp12_filtered = signal.filtfilt(self.blink_b, self.blink_a, fp12_mean) + except Exception as e: + print(f"Filter error: {e}") + return + F = np.diff(fp12_filtered) + if len(F) < 3: + return + b, d, e = blink_detection(F, self.fs, self.Dmin, self.Dmax, self.EMin, self.EMax) + + if b == 1: + samples_since_last = self.total_samples - self.last_blink_time + time_since_last_ms = (samples_since_last / self.fs) * 1000 + if time_since_last_ms >= self.jitterwin: # self.jitterwin 单次眨眼去抖 using time_since_last_ms + self.blink_count += 1 + self.last_blink_time = self.total_samples + current_time = time.time() + self.blink_timestamps.append(current_time) + blink_event = { + 'count': self.blink_count, + 'time': current_time, + 'sample_index': self.total_samples, + 'duration_ms': d, + 'energy': e + } + self.blink_events.append(blink_event) + self.zmqClient.send_to_all('result', 1) # 检测到眨眼信号,通知上位机 + if self.check_double_blink(current_time): + self.double_blink_count += 1 + interval = self.blink_timestamps[-1] - self.blink_timestamps[-2] + double_blink_event = { + 'double_blink_count': self.double_blink_count, + 'blink1_time': self.blink_timestamps[-2], + 'blink2_time': self.blink_timestamps[-1], + 'interval': interval + } + self.double_blink_events.append(double_blink_event) + self.last_double_blink_time = current_time + self.zmqClient.send_to_all('result', 2) # 发送双次眨眼事件 + + def decoder_blink(self): + if self.thread_data_server.GetDataLenCount() < 50: + time.sleep(0.005) + return + if self.zmqServer.get_Impedance == False: # 非阻抗检测状态 + data = self.thread_data_server.get_blinkData(50) + fp1_data = data[0, :] # ch1 (相当于FP1) + fp2_data = data[1, :] # ch2 (相当于FP2) + for i in range(len(fp1_data)): + self.fp1_buffer.append(fp1_data[i]) + self.fp2_buffer.append(fp2_data[i]) + self.total_samples += 1 + self.sample_counter += 1 + + if self.sample_counter >= self.step_samples: + self.process_blink_detection() + self.sample_counter = 0 + + def stop(self): + ''' + 停止运行 + @return: + ''' + self.zmqServer.stop() + self.Runing=False + + def reset_state(self): + """清空解码器状态和缓存数据""" + # 重置设备层缓存 + self.thread_data_server.reset_state() + + # 重置解码状态 + self.decodingSteps = 0 + self.calculateCount = 0 + + # 重置训练数据 + self.plotData = [] + self.plotLabel = [] + self.trainData = [] + self.trainLabel = [] + self.currentLabel = -1 + self.train_started = False + self.load_model = False + + # 重置多进程队列,确保切换 decoder 时旧数据不会泄漏到新队列 + if hasattr(self, 'mp_data_queue'): + while not self.mp_data_queue.empty(): + try: self.mp_data_queue.get_nowait() + except Empty: pass + if hasattr(self, 'mp_result_queue'): + while not self.mp_result_queue.empty(): + try: self.mp_result_queue.get_nowait() + except Empty: pass \ No newline at end of file diff --git a/Debug_64ch_Decoder_Optimize/Device/SunnyLinker.py b/Debug_64ch_Decoder_Optimize/Device/SunnyLinker.py new file mode 100644 index 0000000..4543d4b --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/Device/SunnyLinker.py @@ -0,0 +1,814 @@ +# -*-coding:utf-8 -*- +''' +SunnyLinker的通讯驱动 +''' +import ast +import socket +import threading +import time +import datetime +from typing import Dict +from collections import deque +import numpy as np +from threading import Thread, Event +import serial +from scipy import signal +from serial.serialutil import SerialException + +from Device.protocol import ProtocolFrame +from PubLibrary.InifileHelper import IniRead + +class RingBuffer: + def __init__(self, n_chan, n_points): + self.n_chan = n_chan + self.n_points = n_points + self.buffer = np.zeros((n_chan, n_points)) + self.currentPtr = 0 + self.readPtr = 0 + self.nUpdate = 0 + self.rawData = np.zeros((n_chan, 1)) + + ## append buffer and update current pointer + def appendBuffer(self, data): + if self.nUpdate == self.n_points: + raise Exception("Buffer is full") + + n = data.shape[1] + + # 计算可以写入的元素数量 + write_count = min(self.n_points - self.nUpdate, n) + # 写入新数据 + self.buffer[:, np.mod(np.arange(self.currentPtr, self.currentPtr + write_count), self.n_points)] = data[:,:write_count] + # 更新结束指针 + self.currentPtr = (self.currentPtr + write_count) % self.n_points + # 更新大小 + self.nUpdate += write_count + + ## get data from buffer + def getData(self, count=50): + # 确保不会尝试读取超过缓冲区当前大小的数据 + count = min(count, self.nUpdate) + + # 计算读取结束后的下一个位置 + next_read_ptr = (self.readPtr + count) % self.n_points + if self.readPtr + count <= self.n_points: + # 情况 1:不环绕,数据是连续的 + end_index = next_read_ptr if next_read_ptr != 0 else self.n_points + data = self.buffer[:, self.readPtr:end_index] + else: + # 情况 2:发生环绕,数据被分成两部分 + # 第一部分:从 readPtr 到缓冲区末尾 + part1 = self.buffer[:, self.readPtr:] + # 第二部分:从缓冲区开头到 (count - part1.shape[1]) 个点 + part2 = self.buffer[:, :next_read_ptr] + # 将两部分在列方向上拼接 + data = np.concatenate((part1, part2), axis=1) + + # 更新读指针 + self.readPtr = next_read_ptr + # 更新大小 + self.nUpdate -= count + return data + + # reset buffer + def resetAllPara(self): + self.nUpdate = 0 + self.currentPtr = 0 + self.readPtr = 0 # add by lizhenhua 清空读指针 + self.buffer = np.zeros((self.n_chan, self.n_points)) # add by lizhenhua 清空环形缓冲区 + + +class SunnyLinker64(Thread, ): + serial_port = str(IniRead('system', 'Serial_port')) + t_buffer = 10 + n_chan = 64 + srate = 250 + win_len = 10 + win_step = 1 + ring_buffer = 5 + receiveData = b'' + toUv=True#转为uV + RingBufferLock = threading.Lock() + + # 单例模式 + _instance = None + _initialized = False # 检查是否已经初始化 + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super(SunnyLinker64, cls).__new__(cls) + return cls._instance + def __init__(self, host='127.0.0.1', port=7878, srate=250, n_chan=64,method = 'tcp'): + if SunnyLinker64._initialized: + return + Thread.__init__(self) + self.daemon = True + self.host = host + self.port = port + self.srate = srate + self.n_chan = n_chan + self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输 + self.__ringBuffer = RingBuffer(self.n_chan + 2, + int(np.round(self.t_buffer * self.srate))) + self.energy = 0 # 电量 + self.status_code = 0 # 与采集设备通信的状态码,0为异常,1为正常 + self.gain_value = 6 # 增益倍数 + self.interval_inited = False #ssmvep或mi时间窗是否初始化 + + # 设置初始化标志为True,防止重复初始化 + SunnyLinker64._initialized = True + + # --- 新增:用于心跳检测 --- + self.last_called = 0 # 初始化为0 + self.last_called_lock = threading.Lock() # 保护 last_called 的访问 + + def reset_state(self): + """清空采集器状态和缓存数据""" + with self.RingBufferLock: + self.__ringBuffer.resetAllPara() + self.count_events = {} + self.epoch_finished = False + self.pack_contain_event = False + self.event_inner_idx = -1 + self.interval_inited = False + + def interval_init(self,decoder_class): + if decoder_class == 'ssmvep': + interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch')) + self.interval_epoch = [int(i * self.srate) for i in interval_epoch] # epoch截取信息 + self.train_epoch = [int(self.interval_epoch[0]), + int(self.interval_epoch[1] + 0.1 * self.srate)] # 训练样本epoch + self.latency = (self.interval_epoch[ + 1] + 0.1 * self.srate) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点;0.1表示比实际需要的长度多取0.1,会被截掉 + self.train_latency = (self.train_epoch[1] + 0.1 * self.srate) // 5 + + elif decoder_class == 'mi': + interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch')) + self.interval_epoch = [int(i * self.srate) for i in interval_epoch] # epoch截取信息 + self.train_epoch = self.interval_epoch.copy() + self.latency = (self.interval_epoch[1]) // 5 # 提取epoch的延迟标记,5代表每次解包得到的5位采样点; + self.train_latency = self.latency + + print('时间窗:', (interval_epoch)) + self.count_events: Dict[str, int] = {} # 表示包延迟的计数信息 + self.event_inner_idx = -1 # event在5位数据包内部的idx + self.epoch_finished = False # 接收epoch是否完整 + self.pack_contain_event = False # 当前包是否含有event + self.predict_event = 99 + self.events = [1, 2, self.predict_event] + if getattr(self, 'serial', None) and self.serial.is_open: + self.serial.close() + self.serial = serial.Serial(self.serial_port, 460800, timeout=1) # 连接同步器串口 + self.interval_inited = True + + def set_sampleRate(self,sampleRate_Code=0x00): + ''' + 设置采样率 + :param sampleRate_Code: 0x00:250Hz,0x01:500Hz,0x02:1000Hz,0x03:2000Hz + ''' + function_code = 0x02 + gain_code = 0x06 + sampleRate_Code = [gain_code,sampleRate_Code] + packed_data = ProtocolFrame.pack(function_code, sampleRate_Code) + if self.method == 'tcp': + self.sock.send(packed_data) + + def push_trigger(self,label): + ''' + 数据打标 + @param label:标签类别 + ''' + function_code = None + label = [label] + packed_data = ProtocolFrame.pack(function_code, label) + if self.method == 'tcp' and hasattr(self,'serial'): + print('发送:', label, datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3]) + self.serial.write(packed_data) + def Impedance(self, On): + ''' + 阻抗检测开关 + :param On:True为开启,False为关闭 + :return: 组好的协议帧 + ''' + function_code = 0x01 + if On: + data = [0x1] + self.gain_value = 6 + else: + data = [0x0] + self.gain_value = 6 + packed_data = ProtocolFrame.pack(function_code, data) + if self.method == 'tcp': + self.sock.send(packed_data) + + def connect(self): + try: + if self.method == 'serial': + # 开启com口,波特率115200,超时5 + self.sock = serial.Serial(self.host, self.port, timeout=5) + self.sock.flushInput() # 清空缓冲区 + count = self.sock.inWaiting() # 获取串口缓冲区数据 + while not count: + count = self.sock.inWaiting() # 获取串口缓冲区数据 + # # 接收和存储数据 + data = (self.sock.read(count)) + self.receiveData = self.receiveData + data # 将接收数据存储在字符串中 + elif self.method == 'tcp': + # 重连前关闭旧 socket,避免资源泄漏 + if hasattr(self, 'sock') and self.sock: + try: + self.sock.close() + except Exception: + pass + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.connect((self.host, int(self.port))) + self.set_sampleRate(0x00) #设置250Hz采样率 + return True + except Exception as e: + print("请打开头环") + print(e) + return False + + print("connected") + return True + + def extract_packet(self, packet): + # 存储一个点的八通道数据 + dataList = [] + # 存储116个点的八通道数据 + dataMatrix = [] + + for j in range(5): + for i in range(self.n_chan): + if not self.toUv:#原始数据直接输出 + val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[ + 194 * j + 25 + 2 + i * 3] + + else:#转为uV + val = (packet[194 * j + 25 + i * 3] << 16) | (packet[194 * j + 25 + 1 + i * 3] << 8) | packet[ + 194 * j + 25 + 2 + i * 3] + if val < 8388608: + val = val * 4.5 / self.gain_value / 8388608 * 1000000; + else: + val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000; + dataList.append(val) + #同步触发源 + val = packet[194 * j + 25 + (i+1) * 3] + dataList.append(val) + #同步触发序号 + val = packet[194 * j + 25 + (i+1) * 3+1] + dataList.append(val) + + + # 将数据矩阵进行拼接 + if len(dataMatrix) == 0: + dataMatrix = np.asmatrix(dataList) + else: + dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0) + dataList.clear() + return np.transpose(dataMatrix) + + def run(self): + self.running = True + self.PackageLength = 998 + + # 尝试连接循环,断开后自动重连 + while self.running: + if self.connect(): + break + print(f"无法连接到 {self.host}:{self.port},15秒后重试...") + time.sleep(15) + + # 启动心跳检测线程 + threading.Thread(target=self.heartbeat_checker, daemon=True).start() + while self.running: + try: + if self.method == 'serial': + count = self.sock.inWaiting() # 获取串口缓冲区数据 + if count: + # 接收和存储数据 + data = (self.sock.read(count)) + self.receiveData = self.receiveData + data # 将接收数据存储在字符串中 + elif self.method == 'tcp': + data = self.sock.recv(600) + if not data: + break + self.receiveData += data + with self.last_called_lock: + self.last_called = time.time() + self.status_code = 1 # 收到数据,标记为正常 + if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind( + b'\x55\x55') >= self.PackageLength - 2: + + index = self.receiveData.index(b'\xaa') + self.receiveData = self.receiveData[index:] + if len(self.receiveData) >= self.PackageLength: + onepackage = self.receiveData[:self.PackageLength] + if onepackage[7] != 0: + self.energy = onepackage[7] # 电量 + self.receiveData = self.receiveData[self.PackageLength:] + dataMatrix = self.extract_packet(onepackage) + try: + with self.RingBufferLock: + if self.interval_inited: + self.epoch_finished = self.detect_event(dataMatrix) + if self.pack_contain_event: + self.__ringBuffer.resetAllPara() # 检测到当前pack含有event,清除ringbuffer中之前的数据 + self.__ringBuffer.appendBuffer(dataMatrix) + # self.plotBuffer.appendBuffer(dataMatrix) + if self.epoch_finished: + time.sleep(0.005) + print('epoch_finished: ', datetime.datetime.now().strftime('%H:%M:%S.%f')[:-3]) + else: + self.__ringBuffer.appendBuffer(dataMatrix) + except Exception as e: + print("锁:写入异常",e) + # self.RingBufferLock.release() + except ConnectionResetError: + self.status_code = 0 # 状态异常 + print("Connection was reset by the peer. 正在重新连接...") + self.sock.close() + # 退出循环后,run() 开头的重连循环会自动接管 + break + # 如果 running=True,重连循环会接管,不会执行到这里 + + # 检测是否含有标签 + def detect_event(self, samples): + self.pack_contain_event = False + events = np.array(samples[-2])[0].tolist() + for idx, event in enumerate(events): + if int(event) in self.events: + new_key = "".join( + [ + str(event), + datetime.datetime.now().strftime("%Y-%m-%d \ + -%H-%M-%S"), + ] + ) + if event == self.predict_event: + self.count_events[new_key] = self.latency + 1 + else: + self.count_events[new_key] = self.train_latency + 1 + self.event_inner_idx = idx + self.pack_contain_event = True + drop_items = [] + for key, value in self.count_events.items(): + value = value - 1 + if value == 0: + drop_items.append(key) + self.count_events[key] = value + for key in drop_items: + del self.count_events[key] + if drop_items: + return True + return False + + # --- 新增:心跳检测线程 --- + def heartbeat_checker(self): + """ + 定期检查是否在最近2秒内收到 eegData + 如果超过2秒未收到,则设置 status_code = 0 + """ + while self.running: + time.sleep(0.5) # 每0.5秒检查一次 + with self.last_called_lock: + now = time.time() + # 只有收到过一次数据后才开始判断超时 + if self.last_called > 0 and (now - self.last_called) > 30: + if self.status_code != 0: + print("EEG data timeout: disconnected") + self.status_code = 0 + + def getDataViaSSVEP(self,count): + ''' + ssvep的视觉通道,共8个通道 + @param count: 每通道读取的数值数量 + @return: 返回最新的数值 + ''' + data=self.getData(count) + # PO5,POZ,PO6,O2,PO8,OZ,O1,PO7 64是event导联 + rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55,64] + row_to_select=np.array(rows_to_extract) + data=data[row_to_select,:] + return data + def get_MIData(self): + ''' + 取出当前所有数值 + :return: + ''' + data = self.getData(self.__ringBuffer.nUpdate) + #MI选取导联:FC3,FC1,FCZ,FC2,FC4,C5,C3,C1,CZ,C2,C4,C6,CP3,CP1,CP2,CP4,P3,P1,PZ,P2,P4,event1,event2 + rows_to_extract = [8, 15, 12, 14, 18, 23, 16,59,50,58,17,45,29,11,10,19,20,61,51,60,21,64,65] + row_to_select = np.array(rows_to_extract) + data = data[row_to_select,:] + return data + def get_SSMVEPData(self): + ''' + 取出当前所有数值 + :return: + ''' + data = self.getData(self.__ringBuffer.nUpdate) + # PO5,POZ,PO6,O2,PO8,OZ,O1,PO7 64是event导联 + rows_to_extract = [13, 3, 2, 46, 9, 54, 47, 55, 64,65] + row_to_select = np.array(rows_to_extract) + data = data[row_to_select, :] + return data + + def get_concentrateData(self,count): + ''' + @param count: 每通道读取的数值数量 + @return: 返回最新的数值 + ''' + data=self.getData(count) + rows_to_extract = [0, 1] + row_to_select = np.array(rows_to_extract) + data = data[row_to_select, :] + return data + + def get_blinkData(self,count): + ''' + @param count: 每通道读取的数值数量 + @return: 返回最新的数值 + ''' + data=self.getData(count) + rows_to_extract = [0,1] + row_to_select = np.array(rows_to_extract) + data = data[row_to_select, :] + return data + + def getImpedance(self, data,decoder_class): + ''' + 获取阻抗值,已经放大100倍,单位是kΩ + @param data: 准备计算的通道数据,每通道200个值,注意不要把信号打标的通道传进来 + @return: 返回各个通道的阻抗值 + ''' + impedanceList = [] + for channelindex in range(data.shape[0]): + if len(data[channelindex]) > 0: + data_list = [] + # 设计陷波滤波器,去除50Hz成分 + is50filter = True + if is50filter: + b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽,1000是采样频率 + data_list = signal.lfilter(b, a, data[channelindex].tolist()) + + else: + data_list.extend(data[channelindex].tolist()) + + data_list = data_list[-1000:] + # 执行FFT + fft_result = np.fft.fft(data_list) + fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果 + freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴 + + # y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())], + # fft_magnitude[1:-1] * 2 / len(t[0].tolist()), + # [fft_magnitude[-1] / len(t[0].tolist())])) + + # 找到幅值最大的频率成分的索引(忽略直流分量,即索引0) + max_index = np.argmax(fft_magnitude[1:]) + + # 获取最大幅值的频率索引(加上1,因为索引0是直流分量) + freq_index = max_index + 1 + + # 获取最大幅值 + max_magnitude = fft_magnitude[freq_index] + + # 阻抗 + import math + result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4 + result *= 0.44 * 100 # 统一放大100倍 + impedanceList.append(int(result)) + # print(max_magnitude, result) + else: + impedanceList.append(0) + impedances = np.array(impedanceList) + if decoder_class in ('mi', 'ma'): + impedances = impedances[np.array([8, 15, 12, 14, 18, 23, 16, 59, 50, 58, 17, 45, 29, 11, 10, 19, 20, 61, 51, 60, 21])] + elif decoder_class == 'blink': + impedances = impedances[np.array([0, 1])] + elif decoder_class == 'concentration': + impedances = impedances[np.array([0, 1])] + else: + impedances = impedances[np.array([13, 3, 2, 46, 9, 54, 47, 55])] + return impedances + def getData(self,count): + ''' + 获取最新的数据 + @param count: 每通道返回的最数值数目 + @return: 所有通道的最新count个数值 + ''' + data=None + try: + with self.RingBufferLock: + data = self.__ringBuffer.getData(count) + except: + print("锁:读取异常") + # self.RingBufferLock.release() + + + return data + def GetDataLenCount(self): + ''' + 获取最新缓存中每个通道的数量 + @return: + ''' + return self.__ringBuffer.nUpdate + + def ResetAll(self): + ''' + 清空缓存 + @return: + ''' + with self.RingBufferLock: + self.__ringBuffer.resetAllPara() + def stop(self): + self.running = False + +class SunnyLinker8(Thread, ): + receiveData = '' + t_buffer = 10 + n_chan = 9 + srate = 1000 + receiveData = b'' + toUv=False#转为uV + RingBufferLock = threading.Lock() + def __init__(self, host, port, srate=1000, n_chan=9,method = 'tcp'): + Thread.__init__(self) + self.daemon = True + self.host = host + self.port = port + self.srate = srate + self.n_chan = n_chan + self.method = method #传输方式,'tcp'表示tcp传输,'serial'表示串口传输 + self.__ringBuffer = RingBuffer(self.n_chan + 2, + int(np.round(self.t_buffer * self.srate))) + self.energy = 0 #电量 + self.status_code = 0 #与采集设备通信的状态码,0为异常,1为正常 + self.gain_value = 6 # 增益倍数 + + def push_trigger(self,label): + ''' + 数据打标 + @param label:标签类别 + ''' + function_code = None + label = [label] + packed_data = ProtocolFrame.pack(function_code, label) + if self.method == 'tcp': + self.sock.send(packed_data) + elif self.method == 'serial': + self.sock.write(packed_data) + + def Impedance(self, On): + ''' + 阻抗检测开关 + :param On:True为开启,False为关闭 + :return: 组好的协议帧 + ''' + function_code = None + if On: + data = [0xA1] + self.gain_value = 24 + else: + data = [0xA0] + self.gain_value = 6 + packed_data = ProtocolFrame.pack(function_code, data) + if self.method == 'tcp': + self.sock.send(packed_data) + elif self.method == 'serial': + self.sock.write(packed_data) + + def connect(self): + try: + if self.method == 'serial': + # 开启com口,波特率115200,超时5 + self.sock = serial.Serial(self.host, self.port, timeout=5) + self.sock.flushInput() # 清空缓冲区 + count = self.sock.inWaiting() # 获取串口缓冲区数据 + while not count: + count = self.sock.inWaiting() # 获取串口缓冲区数据 + # # 接收和存储数据 + data = (self.sock.read(count)) + self.receiveData = self.receiveData + data # 将接收数据存储在字符串中 + print("connected") + elif self.method == 'tcp': + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.connect((self.host, int(self.port))) + print("connected") + except Exception as e: + print("请打开头环") + print(e) + + print("connected") + + def extract_packet(self, packet): + # 存储一个点的八通道数据 + dataList = [] + # 存储116个点的八通道数据 + dataMatrix = [] + + # index = (packet[1] << 24) | (packet[2] << 16) | (packet[3] << 8) | packet[4] + # print(index) + + for j in range(5): + for i in range(self.n_chan): + if not self.toUv:#原始数据直接输出 + val = (packet[26 * j + 25 + i * 3] << 16) | (packet[26 * j + 25 + 1 + i * 3] << 8) | packet[ + 26 * j + 25 + 2 + i * 3] + + else:#转为uV + val = (packet[26 * j + 25 + i * 3] << 16) | (packet[26 * j + 25 + 1 + i * 3] << 8) | packet[ + 26 * j + 25 + 2 + i * 3] + if val < 8388608: + val = val * 4.5 / self.gain_value / 8388608 * 1000000; + else: + val = (val - 16777216) * 4.5 / self.gain_value / 8388608 * 1000000; + dataList.append(val) + #同步触发源 + val = packet[26 * j + 25 + (i+1) * 3] + dataList.append(val) + #同步触发序号 + val = packet[26 * j + 25 + (i+1) * 3+1] + dataList.append(val) + + + # 将数据矩阵进行拼接 + if len(dataMatrix) == 0: + dataMatrix = np.asmatrix(dataList) + else: + dataMatrix = np.concatenate((dataMatrix, np.asmatrix(dataList)), axis=0) + dataList.clear() + return np.transpose(dataMatrix) + + def run(self): + self.connect() + self.running = True + self.PackageLength = 158 + start_time = time.time() + + try: + while self.running: + if self.method == 'serial': + end_time = time.time() + if end_time-start_time > 2: #超过2s未收到数据 + self.status_code = 0 #状态异常 + count = self.sock.inWaiting() # 获取串口缓冲区数据 + if count: + start_time = time.time() + self.status_code = 1 # 收到数据,状态正常 + # 接收和存储数据 + data = (self.sock.read(count)) + self.receiveData = self.receiveData + data # 将接收数据存储在字符串中 + elif self.method == 'tcp': + data = self.sock.recv(100) + if not data: + break + self.receiveData += data + if len(self.receiveData) >= self.PackageLength and self.receiveData.rfind( + b'\x55\x55') >= self.PackageLength - 2: + + index = self.receiveData.index(b'\xaa') + self.receiveData = self.receiveData[index:] + if len(self.receiveData) >= self.PackageLength: + onepackage = self.receiveData[:self.PackageLength] + if onepackage[7] != 0: + self.energy = onepackage[7] # 电量 + self.receiveData = self.receiveData[self.PackageLength:] + dataMatrix = self.extract_packet(onepackage) + try: + with self.RingBufferLock: + self.__ringBuffer.appendBuffer(dataMatrix) + except: + print("锁:写入异常") + self.sock.close() + + except ConnectionResetError: + self.status_code = 0 # 状态异常 + print("Connection was reset by the peer.") + except SerialException as Se: + self.status_code = 0 + print('串口通信异常!请检查适配器') + + + def process_packet(self): + if self.circular_buffer.buffer_length > 158: + packet = self.circular_buffer.extract_packet() + + if packet: + # Here you would parse the packet according to the protocol + # print("Received packet:%s,index:%s", len(packet),str(integer_value)) + return packet + else: + print("Received Nothing") + return None + + def getDataViaSSVEP(self,count): + ''' + ssvep的视觉通道,共8个通道 + @param count: 每通道读取的数值数量 + @return: 返回最新的数值 + ''' + data=self.getData(count) + data=data[:8,:] + return data + + def getImpedance(self, data): + ''' + 获取阻抗值,已经放大100倍,单位是kΩ + @param data: 准备计算的通道数据,每通道200个值,注意不要把信号打标的通道传进来 + @return: 返回各个通道的阻抗值 + ''' + impedanceList = [] + for channelindex in range(data.shape[0]): + if len(data[channelindex]) > 0: + data_list = [] + # 设计陷波滤波器,去除50Hz成分 + is50filter = True + if is50filter: + b, a = signal.iirnotch(50, 30, self.srate) # 30是带宽,1000是采样频率 + data_list = signal.lfilter(b, a, data[channelindex].tolist()) + + else: + data_list.extend(data[channelindex].tolist()) + + data_list = data_list[-1000:] + # 执行FFT + fft_result = np.fft.fft(data_list) + fft_magnitude = np.abs(fft_result / len(data_list)) # 归一化FFT结果 + freqs = np.fft.fftfreq(len(data_list), d=1 / self.srate) # 频率轴 + + # y_amp_modified = np.concatenate(([fft_magnitude[0] / len(t[0].tolist())], + # fft_magnitude[1:-1] * 2 / len(t[0].tolist()), + # [fft_magnitude[-1] / len(t[0].tolist())])) + + # 找到幅值最大的频率成分的索引(忽略直流分量,即索引0) + max_index = np.argmax(fft_magnitude[1:]) + + # 获取最大幅值的频率索引(加上1,因为索引0是直流分量) + freq_index = max_index + 1 + + # 获取最大幅值 + max_magnitude = fft_magnitude[freq_index] + + # 阻抗 + import math + result = math.sqrt(2) * math.pi * max_magnitude / 6 / 4 + result *= 0.44 * 100 # 统一放大100倍 + impedanceList.append(int(result)) + # print(max_magnitude, result) + else: + impedanceList.append(0) + # impedances = ":".join(map(str, impedanceList)) + impedances = np.array(impedanceList) + impedances = impedances[:8] + return impedances + def getData(self,count): + ''' + 获取最新的数据 + @param count: 每通道返回的最数值数目 + @return: 所有通道的最新count个数值 + ''' + data=None + try: + with self.RingBufferLock: + data = self.__ringBuffer.getData(count) + except: + print("锁:读取异常") + # self.RingBufferLock.release() + + + return data + def GetDataLenCount(self): + ''' + 获取最新缓存中每个通道的数量 + @return: + ''' + return self.__ringBuffer.nUpdate + + def ResetAll(self): + ''' + 清空缓存 + @return: + ''' + with self.RingBufferLock: + self.__ringBuffer.resetAllPara() + def stop(self): + self.running = False + + +if __name__ == "__main__": + # Usage + Linker = SunnyLinker64('127.0.0.1', 5086, 1000, 65) + Linker.start() + + try: + while True: + time.sleep(0.005) + if(Linker.count()>0): + # print(Linker.ringBuffer.nUpdate) + t = Linker.getData() + print(t.shape[1], Linker.count()) + # Linker.ringBuffer.nUpdate=0 + # time.sleep(0.2) + except KeyboardInterrupt: + Linker.stop() diff --git a/Debug_64ch_Decoder_Optimize/Device/protocol.py b/Debug_64ch_Decoder_Optimize/Device/protocol.py new file mode 100644 index 0000000..62b274b --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/Device/protocol.py @@ -0,0 +1,193 @@ +from typing import List, Tuple, Union, Optional + + +class ProtocolFrame: + # 协议常量 + FRAME_HEADER = 0xAA + FRAME_TAIL1 = 0x55 + FRAME_TAIL2 = 0x55 + RESERVED_SIZE = 6 + MIN_FRAME_SIZE = 13 # 帧头1 + 功能1 + 长度2 + 预留6 + CRC1 + 包尾2 + MAX_DATA_LENGTH = 0xFFFF # 最大数据长度 (2字节能表示的最大值) + + @staticmethod + def calculate_crc8(data: bytes) -> bytes: + """ + 计算CRC8校验值 + Args: + data: 需要计算CRC的数据 + Returns: + 一个字节的CRC值(bytes类型) + """ + crc = 0 + for byte in data: + crc ^= byte + for _ in range(8): + crc = ((crc << 1) ^ 0x07 if crc & 0x80 else crc << 1) & 0xFF + return bytes([crc]) + + @classmethod + def pack(cls, function, data: Union[bytes, bytearray, List[int]], + reserved: Optional[Union[bytes, bytearray, List[int]]] = None) -> bytes: + """ + 协议打包函数 + + Args: + function: 功能码 (1字节) + data: 数据块 + reserved: 预留字节(6字节,可选) + + Returns: + 打包后的字节数据 + """ + # 检查功能码 + if function != None: + if not 0 <= function <= 0xFF: + raise ValueError("功能码必须是1字节") + + # 转换数据为bytearray + if isinstance(data, list): + data = bytearray(data) + elif isinstance(data, bytes): + data = bytearray(data) + + # 检查数据长度 + data_length = len(data) + if data_length > cls.MAX_DATA_LENGTH: + raise ValueError(f"数据长度超过最大值 {cls.MAX_DATA_LENGTH}") + + # 处理预留字节 + if reserved is None: + reserved = bytearray([0] * cls.RESERVED_SIZE) + else: + if isinstance(reserved, list): + reserved = bytearray(reserved) + elif isinstance(reserved, bytes): + reserved = bytearray(reserved) + if len(reserved) != cls.RESERVED_SIZE: + raise ValueError(f"预留字节必须是{cls.RESERVED_SIZE}字节") + + # 构建帧 + frame = bytearray([cls.FRAME_HEADER]) # 帧头 (1字节) + if function != None: + frame.append(function) # 功能码 (1字节) + data_length+=6 + + # 数据长度 (2字节,大端序) + frame.append((data_length >> 8) & 0xFF) # 高字节 + frame.append(data_length & 0xFF) # 低字节 + + if function != None: + frame.extend(reserved) # 预留字节 (6字节) + frame.extend(data) # 数据块 (变长) + + # 计算CRC (从功能码开始到数据块结束) + crc = cls.calculate_crc8(frame[1:]) # 不包含帧头 + frame.extend(crc) # CRC校验 (1字节) + + # 添加帧尾 + frame.extend([cls.FRAME_TAIL1, cls.FRAME_TAIL2]) # 帧尾 (2字节) + + return bytes(frame) + + @classmethod + def unpack(cls, data: Union[bytes, bytearray]) -> Tuple[int, bytearray, bytearray]: + """ + 协议解包函数 + + Args: + data: 待解析的字节数据 + + Returns: + (功能码, 数据块, 预留字节) + + Raises: + ValueError: 当数据格式不正确时 + """ + # 检查数据长度 + if len(data) < cls.MIN_FRAME_SIZE: + raise ValueError("数据长度不足") + + # 检查帧头 + if data[0] != cls.FRAME_HEADER: + raise ValueError("帧头错误") + + # 检查帧尾 + if data[-2:] != bytes([cls.FRAME_TAIL1, cls.FRAME_TAIL2]): + raise ValueError("帧尾错误") + + # 解析基本信息 + function = data[1] # 功能码 (1字节) + + # 数据长度 (2字节,大端序) + data_length = (data[2] << 8) | data[3] + + reserved = data[4:10] # 预留字节 (6字节) + + # 检查数据长度 + expected_length = cls.MIN_FRAME_SIZE + data_length + if len(data) != expected_length: + raise ValueError(f"数据长度不匹配: 期望{expected_length}字节,实际{len(data)}字节") + + # 提取数据块 + payload = data[10:10 + data_length] + + # 验证CRC (从功能码开始到数据块结束) + received_crc = data[-3] + calculated_crc = cls.calculate_crc8(data[1:-3])[0] # 获取字节值 + + if received_crc != calculated_crc: + raise ValueError(f"CRC校验失败: 期望{calculated_crc:02X},实际{received_crc:02X}") + + return function, bytearray(payload), bytearray(reserved) + + + +def print_hex(data: bytes, label: str = ""): + """打印十六进制数据,并按字节添加空格""" + hex_str = ' '.join([f"{b:02X}" for b in data]) + if label: + print(f"{label}: {hex_str}") + else: + print(hex_str) + + +def print_frame_details(data: bytes): + """打印帧的详细信息""" + print("帧详细信息:") + print(f"帧头: {data[0]:02X}") + print(f"功能码: {data[1]:02X}") + print(f"数据长度: {data[2]:02X} {data[3]:02X} ({(data[2] << 8) | data[3]}字节)") + print(f"预留字节: {' '.join([f'{b:02X}' for b in data[4:10]])}") + data_length = (data[2] << 8) | data[3] + print(f"数据块: {' '.join([f'{b:02X}' for b in data[10:10 + data_length]])}") + print(f"CRC校验: {data[-3]:02X}") + print(f"帧尾: {data[-2]:02X} {data[-1]:02X}") + + +# 使用示例 +def example_usage(): + try: + + + # 示例1:简单数据打包 + function_code = 0x01 + data = [0x1] + packed_data = ProtocolFrame.pack(function_code, data) + print_hex(packed_data, "示例1 - 完整帧") + print_frame_details(packed_data) + print() + + # 示例3:解包验证 + function, payload, reserved = ProtocolFrame.unpack(packed_data) + print("解包结果:") + print(f"功能码: 0x{function:02X}") + print_hex(payload, "数据块") + print_hex(reserved, "预留字节") + + except ValueError as e: + print(f"错误: {e}") + + +if __name__ == "__main__": + example_usage() \ No newline at end of file diff --git a/Debug_64ch_Decoder_Optimize/MI/Algorithm/conformer_2class.py b/Debug_64ch_Decoder_Optimize/MI/Algorithm/conformer_2class.py new file mode 100644 index 0000000..8148b68 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/MI/Algorithm/conformer_2class.py @@ -0,0 +1,409 @@ +""" +EEG Conformer + +Convolutional Transformer for EEG decoding + +Couple CNN and Transformer in a concise manner with amazing results +""" +# remember to change paths + +import os +gpus = [0] +os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' +os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus)) +import numpy as np +import math +import random +import time +import datetime + +from torch.utils.data import DataLoader +from torch.autograd import Variable + +import torch +import torch.nn.functional as F +from torch import nn +from torch import Tensor +from einops import rearrange +from einops.layers.torch import Rearrange, Reduce +# from common_spatial_pattern import csp + +# from torch.utils.tensorboard import SummaryWriter +from torch.backends import cudnn +cudnn.benchmark = True +cudnn.deterministic = True +from sklearn.model_selection import train_test_split +# writer = SummaryWriter('./TensorBoardX/') + + +# Convolution module +# use conv to capture local features, instead of postion embedding. +class PatchEmbedding(nn.Module): + def __init__(self, emb_size=40,n_chan=8): + # self.patch_size = patch_size + super().__init__() + + self.shallownet = nn.Sequential( + nn.Conv2d(1, 40, (1, 25), (1, 1)), + nn.Conv2d(40, 40, (n_chan, 1), (1, 1)), + nn.BatchNorm2d(40), + nn.ELU(), + nn.AvgPool2d((1, 75), (1, 15)), # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT + nn.Dropout(0.5), + ) + + self.projection = nn.Sequential( + nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)), # transpose, conv could enhance fiting ability slightly + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + + def forward(self, x: Tensor) -> Tensor: + b, _, _, _ = x.shape + x = self.shallownet(x) + x = self.projection(x) + return x + + +class MultiHeadAttention(nn.Module): + def __init__(self, emb_size, num_heads, dropout): + super().__init__() + self.emb_size = emb_size + self.num_heads = num_heads + self.keys = nn.Linear(emb_size, emb_size) + self.queries = nn.Linear(emb_size, emb_size) + self.values = nn.Linear(emb_size, emb_size) + self.att_drop = nn.Dropout(dropout) + self.projection = nn.Linear(emb_size, emb_size) + + def forward(self, x: Tensor, mask: Tensor = None) -> Tensor: + queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads) + keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads) + values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads) + energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) + if mask is not None: + fill_value = torch.finfo(torch.float32).min + energy.mask_fill(~mask, fill_value) + + scaling = self.emb_size ** (1 / 2) + att = F.softmax(energy / scaling, dim=-1) + att = self.att_drop(att) + out = torch.einsum('bhal, bhlv -> bhav ', att, values) + out = rearrange(out, "b h n d -> b n (h d)") + out = self.projection(out) + return out + + +class ResidualAdd(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, **kwargs): + res = x + x = self.fn(x, **kwargs) + x += res + return x + + +class FeedForwardBlock(nn.Sequential): + def __init__(self, emb_size, expansion, drop_p): + super().__init__( + nn.Linear(emb_size, expansion * emb_size), + nn.GELU(), + nn.Dropout(drop_p), + nn.Linear(expansion * emb_size, emb_size), + ) + + +class GELU(nn.Module): + def forward(self, input: Tensor) -> Tensor: + return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0))) + + +class TransformerEncoderBlock(nn.Sequential): + def __init__(self, + emb_size, + num_heads=10, + drop_p=0.5, + forward_expansion=4, + forward_drop_p=0.5): + super().__init__( + ResidualAdd(nn.Sequential( + nn.LayerNorm(emb_size), + MultiHeadAttention(emb_size, num_heads, drop_p), + nn.Dropout(drop_p) + )), + ResidualAdd(nn.Sequential( + nn.LayerNorm(emb_size), + FeedForwardBlock( + emb_size, expansion=forward_expansion, drop_p=forward_drop_p), + nn.Dropout(drop_p) + ) + )) + + +class TransformerEncoder(nn.Sequential): + def __init__(self, depth, emb_size): + super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)]) + + +class ClassificationHead(nn.Sequential): + def __init__(self, emb_size, n_classes): + super().__init__() + + # global average pooling + self.clshead = nn.Sequential( + Reduce('b n e -> b e', reduction='mean'), + nn.LayerNorm(emb_size), + nn.Linear(emb_size, n_classes) + ) + self.fc = nn.Sequential( + nn.Linear(2440, 256), + nn.ELU(), + nn.Dropout(0.5), + nn.Linear(256, 32), + nn.ELU(), + nn.Dropout(0.3), + nn.Linear(32, 2) + ) + + def forward(self, x): + x = x.contiguous().view(x.size(0), -1) + out = self.fc(x) + return out + + +class Conformer(nn.Sequential): + def __init__(self, emb_size=40, depth=6, n_classes=2,n_chan=8, **kwargs): + super().__init__( + + PatchEmbedding(emb_size,n_chan), + TransformerEncoder(depth, emb_size), + ClassificationHead(emb_size, n_classes) + ) + + +class ExP(): + def __init__(self,n_chan): + super(ExP, self).__init__() + self.n_chan = n_chan + self.batch_size = 24 + self.n_epochs = 250 + self.c_dim = 4 + self.lr = 0.0002 + self.b1 = 0.5 + self.b2 = 0.999 + + self.start_epoch = 0 + # 创建目录 + os.makedirs("online_Models", exist_ok=True) + self.log_write = open("./online_Models/log_result.txt", "w") + + + self.Tensor = torch.cuda.FloatTensor + self.LongTensor = torch.cuda.LongTensor + + self.criterion_cls = torch.nn.CrossEntropyLoss().cuda() + + self.model = Conformer(n_chan=self.n_chan).cuda() + self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))]) + self.model = self.model.cuda() + + # self.model = EEGNet().cuda() + # self.model = nn.DataParallel(self.model,device_ids=[i for i in range(len(gpus))]) + # self.model = self.model.cuda() + # summary(self.model, (1, 8, 1000)) + + + # Segmentation and Reconstruction (S&R) data augmentation + def interaug(self, timg, label): + # 确保输入是 numpy 数组(CPU) + if isinstance(timg, torch.Tensor): + timg = timg.cpu().numpy() + if isinstance(label, torch.Tensor): + label = label.cpu().numpy() + + aug_data = [] + aug_label = [] + for cls4aug in range(2): + cls_idx = np.where(label == cls4aug + 1) + tmp_data = timg[cls_idx] + tmp_label = label[cls_idx] + tmp_aug_data = np.zeros((int(self.batch_size / 2), 1, self.n_chan, 1000)) + for ri in range(int(self.batch_size / 2)): + for rj in range(8): + rand_idx = np.random.randint(0, tmp_data.shape[0], 8) + tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :, + rj * 125:(rj + 1) * 125] + + aug_data.append(tmp_aug_data) + aug_label.append(tmp_label[:int(self.batch_size / 2)]) + aug_data = np.concatenate(aug_data) + aug_label = np.concatenate(aug_label) + aug_shuffle = np.random.permutation(len(aug_data)) + aug_data = aug_data[aug_shuffle, :, :] + aug_label = aug_label[aug_shuffle] + + # 返回 numpy 数组,由调用方决定是否移到 GPU + return aug_data, aug_label + + def train(self,all_data,all_label,model_path): + all_data = np.array(all_data);all_label = np.array(all_label) + all_data = np.expand_dims(all_data, axis=1) + train_data, test_data, train_label, test_label = train_test_split(all_data, all_label, test_size=0.2, + random_state=42, stratify=all_label,shuffle=True) + + # === 优化:一次性预生成增强数据,避免每个 batch 都重复计算 === + aug_data, aug_label = self.interaug(train_data, train_label) + # 将原始数据和增强数据合并,再一起打乱 + train_data_full = np.concatenate([train_data, aug_data], axis=0) + train_label_full = np.concatenate([train_label, aug_label], axis=0) + shuffle_idx = np.random.permutation(len(train_data_full)) + train_data_full = train_data_full[shuffle_idx] + train_label_full = train_label_full[shuffle_idx] + + img = torch.from_numpy(train_data_full) + label = torch.from_numpy(train_label_full-1) + + dataset = torch.utils.data.TensorDataset(img, label) + self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True) + + test_data = torch.from_numpy(test_data) + test_label = torch.from_numpy(test_label-1) + test_dataset = torch.utils.data.TensorDataset(test_data, test_label) + self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True) + + # Optimizers + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2)) + + test_data = Variable(test_data.type(self.Tensor)) + test_label = Variable(test_label.type(self.LongTensor)) + + bestAcc = 0 + averAcc = 0 + num = 0 + Y_true = 0 + Y_pred = 0 + + # Train the cnn model + for e in range(self.n_epochs): + # in_epoch = time.time() + self.model.train() + for i, (img, label) in enumerate(self.dataloader): + + img = Variable(img.cuda().type(self.Tensor)) + label = Variable(label.cuda().type(self.LongTensor)) + + outputs = self.model(img) + + loss = self.criterion_cls(outputs, label) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + + # out_epoch = time.time() + + + # test process + if (e + 1) % 1 == 0: + self.model.eval() + Cls = self.model(test_data) + + loss_test = self.criterion_cls(Cls, test_label) + y_pred = torch.max(Cls, 1)[1] + acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0)) + train_pred = torch.max(outputs, 1)[1] + train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0)) + + print('Epoch:', e, + ' Train loss: %.6f' % loss.detach().cpu().numpy(), + ' Test loss: %.6f' % loss_test.detach().cpu().numpy(), + ' Train accuracy %.6f' % train_acc, + ' Test accuracy is %.6f' % acc) + + self.log_write.write(str(e) + " " + str(acc) + "\n") + num = num + 1 + averAcc = averAcc + acc + if acc > bestAcc: + bestAcc = acc + Y_true = test_label + Y_pred = y_pred + + + torch.save(self.model, model_path) + averAcc = averAcc / num + print('The average accuracy is:', averAcc) + print('The best accuracy is:', bestAcc) + self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n") + self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n") + + return bestAcc, averAcc, Y_true, Y_pred + # writer.close() + + +def onlineTrain(data_queue,result_queue): + import torch + print(f"[DEBUG] torch.__version__ = {torch.__version__}") + print(f"[DEBUG] torch.cuda.is_available() = {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f"[DEBUG] GPU = {torch.cuda.get_device_name(0)}") + try: + starttime = datetime.datetime.now() + + # seed_n = np.random.randint(2025) + seed_n = 1877 + random.seed(seed_n) + np.random.seed(seed_n) + torch.manual_seed(seed_n) + torch.cuda.manual_seed(seed_n) + torch.cuda.manual_seed_all(seed_n) + + + # 从队列获取训练数据 + data = data_queue.get(timeout=30) + all_data, all_label,model_path,n_chan = data['data'], data['label'],data['modelPath'],data['n_chan'] + exp = ExP(n_chan) + print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path) + bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path) + print('THE BEST ACCURACY IS ' + str(bestAcc)) + + endtime = datetime.datetime.now() + print('train duration: ',str(endtime - starttime)) + + # 将模型或参数传回 + result_queue.put({ + 'status': 'success', + 'model_state': model_path, # 或保存路径 + 'timestamp': time.time() + }) + except Exception as e: + result_queue.put({'status': 'error', 'msg': str(e)}) + +def offlineTrain(all_data,all_label,modelPath): + starttime = datetime.datetime.now() + + # seed_n = np.random.randint(2025) + seed_n = 1877 + print('seed is ' + str(seed_n)) + random.seed(seed_n) + np.random.seed(seed_n) + torch.manual_seed(seed_n) + torch.cuda.manual_seed(seed_n) + torch.cuda.manual_seed_all(seed_n) + + exp = ExP() + + bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath) + print('THE BEST ACCURACY IS ' + str(bestAcc)) + + endtime = datetime.datetime.now() + print('train duration: ',str(endtime - starttime)) + + + +if __name__ == "__main__": + print(time.asctime(time.localtime(time.time()))) + print(time.asctime(time.localtime(time.time()))) diff --git a/Debug_64ch_Decoder_Optimize/MI/Algorithm/conformer_2class_cpu.py b/Debug_64ch_Decoder_Optimize/MI/Algorithm/conformer_2class_cpu.py new file mode 100644 index 0000000..6e29bc3 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/MI/Algorithm/conformer_2class_cpu.py @@ -0,0 +1,382 @@ +""" +EEG Conformer + +Convolutional Transformer for EEG decoding + +Couple CNN and Transformer in a concise manner with amazing results +""" +# remember to change paths +import os +import numpy as np +import math +import random +import time +import datetime + +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch import nn +from torch import Tensor +from einops import rearrange +from einops.layers.torch import Rearrange, Reduce +from torch.backends import cudnn +from sklearn.model_selection import train_test_split +# writer = SummaryWriter('./TensorBoardX/') + + +# Convolution module +# use conv to capture local features, instead of postion embedding. +class PatchEmbedding(nn.Module): + def __init__(self, emb_size=40): + # self.patch_size = patch_size + super().__init__() + + self.shallownet = nn.Sequential( + nn.Conv2d(1, 40, (1, 25), (1, 1)), + nn.Conv2d(40, 40, (8, 1), (1, 1)), + nn.BatchNorm2d(40), + nn.ELU(), + nn.AvgPool2d((1, 75), (1, 15)), # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT + nn.Dropout(0.5), + ) + + self.projection = nn.Sequential( + nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)), # transpose, conv could enhance fiting ability slightly + Rearrange('b e (h) (w) -> b (h w) e'), + ) + + + def forward(self, x: Tensor) -> Tensor: + b, _, _, _ = x.shape + x = self.shallownet(x) + x = self.projection(x) + return x + + +class MultiHeadAttention(nn.Module): + def __init__(self, emb_size, num_heads, dropout): + super().__init__() + self.emb_size = emb_size + self.num_heads = num_heads + self.keys = nn.Linear(emb_size, emb_size) + self.queries = nn.Linear(emb_size, emb_size) + self.values = nn.Linear(emb_size, emb_size) + self.att_drop = nn.Dropout(dropout) + self.projection = nn.Linear(emb_size, emb_size) + + def forward(self, x: Tensor, mask: Tensor = None) -> Tensor: + queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads) + keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads) + values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads) + energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) + if mask is not None: + fill_value = torch.finfo(torch.float32).min + energy.mask_fill(~mask, fill_value) + + scaling = self.emb_size ** (1 / 2) + att = F.softmax(energy / scaling, dim=-1) + att = self.att_drop(att) + out = torch.einsum('bhal, bhlv -> bhav ', att, values) + out = rearrange(out, "b h n d -> b n (h d)") + out = self.projection(out) + return out + + +class ResidualAdd(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, **kwargs): + res = x + x = self.fn(x, **kwargs) + x += res + return x + + +class FeedForwardBlock(nn.Sequential): + def __init__(self, emb_size, expansion, drop_p): + super().__init__( + nn.Linear(emb_size, expansion * emb_size), + nn.GELU(), + nn.Dropout(drop_p), + nn.Linear(expansion * emb_size, emb_size), + ) + + +class GELU(nn.Module): + def forward(self, input: Tensor) -> Tensor: + return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0))) + + +class TransformerEncoderBlock(nn.Sequential): + def __init__(self, + emb_size, + num_heads=10, + drop_p=0.5, + forward_expansion=4, + forward_drop_p=0.5): + super().__init__( + ResidualAdd(nn.Sequential( + nn.LayerNorm(emb_size), + MultiHeadAttention(emb_size, num_heads, drop_p), + nn.Dropout(drop_p) + )), + ResidualAdd(nn.Sequential( + nn.LayerNorm(emb_size), + FeedForwardBlock( + emb_size, expansion=forward_expansion, drop_p=forward_drop_p), + nn.Dropout(drop_p) + ) + )) + + +class TransformerEncoder(nn.Sequential): + def __init__(self, depth, emb_size): + super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)]) + + +class ClassificationHead(nn.Sequential): + def __init__(self, emb_size, n_classes): + super().__init__() + + # global average pooling + self.clshead = nn.Sequential( + Reduce('b n e -> b e', reduction='mean'), + nn.LayerNorm(emb_size), + nn.Linear(emb_size, n_classes) + ) + self.fc = nn.Sequential( + nn.Linear(2440, 256), + nn.ELU(), + nn.Dropout(0.5), + nn.Linear(256, 32), + nn.ELU(), + nn.Dropout(0.3), + nn.Linear(32, 2) + ) + + def forward(self, x): + x = x.contiguous().view(x.size(0), -1) + out = self.fc(x) + return out + + +class Conformer(nn.Sequential): + def __init__(self, emb_size=40, depth=6, n_classes=2, **kwargs): + super().__init__( + + PatchEmbedding(emb_size), + TransformerEncoder(depth, emb_size), + ClassificationHead(emb_size, n_classes) + ) + + +class ExP(): + def __init__(self): + super(ExP, self).__init__() + self.batch_size = 24 + self.n_epochs = 250 + self.c_dim = 4 + self.lr = 0.0002 + self.b1 = 0.5 + self.b2 = 0.999 + + self.start_epoch = 0 + + self.log_write = open("./online_Models/log_result.txt", "w") + + # 自动选择设备:有 GPU 用 GPU,否则用 CPU + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # self.device = torch.device("cpu") + print(f"Using device: {self.device}") + + # 定义张量类型(不再强制使用 cuda) + self.Tensor = torch.FloatTensor + self.LongTensor = torch.LongTensor + + # 将模型移到指定设备 + self.model = Conformer().to(self.device) + + # 损失函数也移到设备 + self.criterion_cls = torch.nn.CrossEntropyLoss().to(self.device) + + # self.model = EEGNet().cuda() + # self.model = nn.DataParallel(self.model,device_ids=[i for i in range(len(gpus))]) + # self.model = self.model.cuda() + # summary(self.model, (1, 8, 1000)) + + + # Segmentation and Reconstruction (S&R) data augmentation + def interaug(self, timg, label): + aug_data = [] + aug_label = [] + for cls4aug in range(2): + cls_idx = np.where(label == cls4aug + 1) + tmp_data = timg[cls_idx] + tmp_label = label[cls_idx] + tmp_aug_data = np.zeros((int(self.batch_size / 2), 1, 8, 1000)) + for ri in range(int(self.batch_size / 2)): + for rj in range(8): + rand_idx = np.random.randint(0, tmp_data.shape[0], 8) + tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :, + rj * 125:(rj + 1) * 125] + + aug_data.append(tmp_aug_data) + aug_label.append(tmp_label[:int(self.batch_size / 2)]) + aug_data = np.concatenate(aug_data) + aug_label = np.concatenate(aug_label) + aug_shuffle = np.random.permutation(len(aug_data)) + aug_data = aug_data[aug_shuffle, :, :] + aug_label = aug_label[aug_shuffle] + + aug_data = torch.from_numpy(aug_data).float().to(self.device) + aug_label = torch.from_numpy(aug_label - 1).long().to(self.device) + return aug_data, aug_label + + def train(self,all_data,all_label,model_path): + all_data = np.array(all_data);all_label = np.array(all_label) + all_data = np.expand_dims(all_data, axis=1) + train_data, test_data, train_label, test_label = train_test_split(all_data, all_label, test_size=0.2, + random_state=42, stratify=all_label,shuffle=True) + # 转为 Tensor + img = torch.from_numpy(train_data).float().to(self.device) + label = torch.from_numpy(train_label - 1).long().to(self.device) + + dataset = torch.utils.data.TensorDataset(img, label) + self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True) + + test_data = torch.from_numpy(test_data).float().to(self.device) + test_label = torch.from_numpy(test_label - 1).long().to(self.device) + test_dataset = torch.utils.data.TensorDataset(test_data, test_label) + self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True) + + # Optimizers + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2)) + + bestAcc = 0 + averAcc = 0 + num = 0 + Y_true = 0 + Y_pred = 0 + + # Train the cnn model + for e in range(self.n_epochs): + # in_epoch = time.time() + self.model.train() + for i, (img, label) in enumerate(self.dataloader): + + # data augmentation + aug_data, aug_label = self.interaug(train_data, train_label) + img = torch.cat((img, aug_data)) + label = torch.cat((label, aug_label)) + + + outputs = self.model(img) + + loss = self.criterion_cls(outputs, label) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + + # out_epoch = time.time() + + + # test process + if (e + 1) % 1 == 0: + self.model.eval() + with torch.no_grad(): + Cls = self.model(test_data) + + loss_test = self.criterion_cls(Cls, test_label) + y_pred = torch.max(Cls, 1)[1] + acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0)) + train_pred = torch.max(outputs, 1)[1] + train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0)) + + print('Epoch:', e, + ' Train loss: %.6f' % loss.detach().cpu().numpy(), + ' Test loss: %.6f' % loss_test.detach().cpu().numpy(), + ' Train accuracy %.6f' % train_acc, + ' Test accuracy is %.6f' % acc) + + self.log_write.write(str(e) + " " + str(acc) + "\n") + num = num + 1 + averAcc = averAcc + acc + if acc > bestAcc: + bestAcc = acc + Y_true = test_label + Y_pred = y_pred + + + torch.save(self.model, model_path) + averAcc = averAcc / num + print('The average accuracy is:', averAcc) + print('The best accuracy is:', bestAcc) + self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n") + self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n") + + return bestAcc, averAcc, Y_true, Y_pred + # writer.close() + + +def onlineTrain(data_queue,result_queue): + try: + starttime = datetime.datetime.now() + + # seed_n = np.random.randint(2025) + seed_n = 1877 + random.seed(seed_n) + np.random.seed(seed_n) + torch.manual_seed(seed_n) + torch.cuda.manual_seed(seed_n) + torch.cuda.manual_seed_all(seed_n) + + exp = ExP() + # 从队列获取训练数据 + data = data_queue.get(timeout=30) + all_data, all_label,model_path = data['data'], data['label'],data['modelPath'] + print('训练参数: ',np.shape(all_data),np.shape(all_label),model_path) + bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,model_path) + print('THE BEST ACCURACY IS ' + str(bestAcc)) + + endtime = datetime.datetime.now() + print('train duration: ',str(endtime - starttime)) + + # 将模型或参数传回 + result_queue.put({ + 'status': 'success', + 'model_state': model_path, # 或保存路径 + 'timestamp': time.time() + }) + except Exception as e: + result_queue.put({'status': 'error', 'msg': str(e)}) + +def offlineTrain(all_data,all_label,modelPath): + starttime = datetime.datetime.now() + + # seed_n = np.random.randint(2025) + seed_n = 1877 + print('seed is ' + str(seed_n)) + random.seed(seed_n) + np.random.seed(seed_n) + torch.manual_seed(seed_n) + torch.cuda.manual_seed(seed_n) + torch.cuda.manual_seed_all(seed_n) + + exp = ExP() + + bestAcc, averAcc, Y_true, Y_pred = exp.train(all_data,all_label,modelPath) + print('THE BEST ACCURACY IS ' + str(bestAcc)) + + endtime = datetime.datetime.now() + print('train duration: ',str(endtime - starttime)) + + + +if __name__ == "__main__": + print(time.asctime(time.localtime(time.time()))) + print(time.asctime(time.localtime(time.time()))) diff --git a/Debug_64ch_Decoder_Optimize/MI/Algorithm/otherModels.py b/Debug_64ch_Decoder_Optimize/MI/Algorithm/otherModels.py new file mode 100644 index 0000000..be03ac2 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/MI/Algorithm/otherModels.py @@ -0,0 +1,184 @@ +from torchsummary import summary +import torch +import torch.nn as nn + + +def weights_init(m): + if isinstance(m, nn.Conv2d): + nn.init.xavier_uniform_(m.weight) + # nn.init.constant(m.bias, 0) # bias may be none + + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.constant_(m.bias, 0) + + + +def square_activation(x): + return torch.square(x) + + +def safe_log(x): + return torch.clip(torch.log(x), min=1e-7, max=1e7) + + +class ShallowConvNet(nn.Module): + def __init__(self, num_classes=3, chans=19, samples=768): + super(ShallowConvNet, self).__init__() + self.conv_nums = 40 + self.features = nn.Sequential( + nn.Conv2d(1, self.conv_nums, (1, 25)), + nn.Conv2d(self.conv_nums, self.conv_nums, (chans, 1), bias=False), + nn.BatchNorm2d(self.conv_nums) + ) + self.avgpool = nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15)) + self.dropout = nn.Dropout() + + out = torch.ones((1, 1, chans, samples)) + out = self.features(out) + out = self.avgpool(out) + n_out_time = out.cpu().data.numpy().shape + self.classifier = nn.Linear(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes) + + def forward(self, x): + x = self.features(x) + x = square_activation(x) + x = self.avgpool(x) + x = safe_log(x) + x = self.dropout(x) + + features = torch.flatten(x, 1) + cls = self.classifier(features) + return cls + + +class EEGNet(nn.Module): + def __init__(self, num_classes=2, chans=8, samples=1000, dropout_rate=0.5, kernel_length=64, F1=8, + F2=16,): + super(EEGNet, self).__init__() + + self.features = nn.Sequential( + nn.Conv2d(1, F1, kernel_size=(1, kernel_length), bias=False), + nn.BatchNorm2d(F1), + nn.Conv2d(F1, F1, kernel_size=(chans, 1), groups=F1, bias=False), # groups=F1 for depthWiseConv + nn.BatchNorm2d(F1), + nn.ELU(inplace=True), + # nn.ReLU(), + nn.AvgPool2d((1, 4)), + nn.Dropout(dropout_rate), + # for SeparableCon2D + # SeparableConv2D(F1, F2, kernel1_size=(1, 16), bias=False), + nn.Conv2d(F1, F1, kernel_size=(1, 16), groups=F1, bias=False), # groups=F1 for depthWiseConv + nn.BatchNorm2d(F1), + nn.ELU(inplace=True), + # nn.ReLU(), + nn.Conv2d(F1, F2, kernel_size=(1, 1), groups=1, bias=False), # point-wise cnn + nn.BatchNorm2d(F2), + # nn.ReLU(), + nn.ELU(inplace=True), + nn.AvgPool2d((1, 8)), + nn.Dropout(p=dropout_rate), + # nn.Dropout(p=0.5), + ) + out = torch.ones((1, 1, chans, samples)) + out = self.features(out) + n_out_time = out.cpu().data.numpy().shape + self.classifier = nn.Linear(n_out_time[-1] * n_out_time[-2] * n_out_time[-3], num_classes) + + def forward(self, x): + conv_features = self.features(x) + features = torch.flatten(conv_features, 1) + cls = self.classifier(features) + return cls + + +class LMDA(nn.Module): + """ + LMDA-Net for the paper + """ + def __init__(self, chans=19, samples=768, num_classes=3, depth=9, kernel=75, channel_depth1=24, channel_depth2=9, + ave_depth=1, avepool=5): + super(LMDA, self).__init__() + self.ave_depth = ave_depth + self.channel_weight = nn.Parameter(torch.randn(depth, 1, chans), requires_grad=True) + nn.init.xavier_uniform_(self.channel_weight.data) + + + self.time_conv = nn.Sequential( + nn.Conv2d(depth, channel_depth1, kernel_size=(1, 1), groups=1, bias=False), + nn.BatchNorm2d(channel_depth1), + nn.Conv2d(channel_depth1, channel_depth1, kernel_size=(1, kernel), + groups=channel_depth1, bias=False), + nn.BatchNorm2d(channel_depth1), + nn.GELU(), + ) + # self.avgPool1 = nn.AvgPool2d((1, 24)) + self.chanel_conv = nn.Sequential( + nn.Conv2d(channel_depth1, channel_depth2, kernel_size=(1, 1), groups=1, bias=False), + nn.BatchNorm2d(channel_depth2), + nn.Conv2d(channel_depth2, channel_depth2, kernel_size=(chans, 1), groups=channel_depth2, bias=False), + nn.BatchNorm2d(channel_depth2), + nn.GELU(), + ) + + self.norm = nn.Sequential( + nn.AvgPool3d(kernel_size=(1, 1, avepool)), + # nn.AdaptiveAvgPool3d((9, 1, 35)), + nn.Dropout(p=0.65), + ) + + # 定义自动填充模块 + out = torch.ones((1, 1, chans, samples)) + out = torch.einsum('bdcw, hdc->bhcw', out, self.channel_weight) + out = self.time_conv(out) + out = self.chanel_conv(out) + out = self.norm(out) + n_out_time = out.cpu().data.numpy().shape + print('In ShallowNet, n_out_time shape: ', n_out_time) + self.classifier = nn.Linear(n_out_time[-1]*n_out_time[-2]*n_out_time[-3], num_classes) + + def EEGDepthAttention(self, x): + # x: input features with shape [N, C, H, W] + + N, C, H, W = x.size() + # K = W if W % 2 else W + 1 + k = 7 + adaptive_pool = nn.AdaptiveAvgPool2d((1, W)) + conv = nn.Conv2d(1, 1, kernel_size=(k, 1), padding=(k//2, 0), bias=True).to(x.device) # original kernel k + nn.init.xavier_uniform_(conv.weight) + nn.init.constant_(conv.bias, 0) + softmax = nn.Softmax(dim=-2) + x_pool = adaptive_pool(x) + x_transpose = x_pool.transpose(-2, -3) + y = conv(x_transpose) + y = softmax(y) + y = y.transpose(-2, -3) + return y * C * x + + def forward(self, x): + x = torch.einsum('bdcw, hdc->bhcw', x, self.channel_weight) + + x_time = self.time_conv(x) # batch, depth1, channel, samples_ + x_time = self.EEGDepthAttention(x_time) # DA1 + + x = self.chanel_conv(x_time) # batch, depth2, 1, samples_ + x = self.norm(x) + + features = torch.flatten(x, 1) + cls = self.classifier(features) + return cls + + +if __name__ == '__main__': + model = ShallowConvNet(num_classes=4, chans=22, samples=1125).cuda() + a = torch.randn(12, 1, 3, 875).cuda().float() + l2 = model(a) + model_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2) + summary(model, show_input=True) + + print(l2.shape) + diff --git a/Debug_64ch_Decoder_Optimize/PubLibrary/InifileHelper.py b/Debug_64ch_Decoder_Optimize/PubLibrary/InifileHelper.py new file mode 100644 index 0000000..e647da3 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/PubLibrary/InifileHelper.py @@ -0,0 +1,30 @@ +# -*-coding:utf-8 -*- +import configparser +import os +import sys +from audioop import error + +BASE_DIR = os.getcwd() +IniFileName = os.path.join(BASE_DIR, 'config.ini') +# IniFileName=os.path.join( 'config.ini') + +def IniWrite(section,keyname,value): + # 创建ConfigParser对象 + config = configparser.ConfigParser() + config.read(IniFileName,encoding='utf-8') + with open(IniFileName, 'w') as configfile: + if not config.has_section(section): + config.add_section(section) + config[section][keyname]=str(value) + config.write(configfile) + +def IniRead(section,key): + + try: + config = configparser.ConfigParser() + config.read(IniFileName,encoding='utf-8') + return config[section][key] + except error as e: + print(e) + # 读取特定section和键的值 + return '5' \ No newline at end of file diff --git a/Debug_64ch_Decoder_Optimize/PubLibrary/RunOnce.py b/Debug_64ch_Decoder_Optimize/PubLibrary/RunOnce.py new file mode 100644 index 0000000..4201773 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/PubLibrary/RunOnce.py @@ -0,0 +1,15 @@ +import ctypes +import sys + + +def is_program_running(name='Global\\Decoder'): + # 创建互斥体 + mutex_name =name + h_mutex = ctypes.windll.kernel32.CreateMutexW(None, False, mutex_name) + + # 检查互斥体是否已经存在 + if ctypes.windll.kernel32.GetLastError() == 183: # ERROR_ALREADY_EXISTS + print("程序已经在运行.") + return True + + return False \ No newline at end of file diff --git a/Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/base.py b/Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/base.py new file mode 100644 index 0000000..dc00835 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/base.py @@ -0,0 +1,418 @@ +# -*- coding: utf-8 -*- +# +# Authors: Swolf +# Date: 2021/1/07 +# License: MIT License + + +from typing import Optional, List, Tuple, Union +import warnings +import numpy as np +from numpy import ndarray +from numpy.linalg import linalg +from scipy.linalg import solve, qr +from scipy.signal import sosfiltfilt, cheby1, cheb1ord +from sklearn.base import BaseEstimator, TransformerMixin, clone + + +def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray: + """Transform spatial filters to spatial patterns based on paper [1]_. + Referring to the method mentioned in article [1],the constructed spatial filter only shows how to combine + information from different channels to extract signals of interest from EEG signals, but if our goal is + neurophysiological interpretation or visualization of weights, activation patterns need to be constructed + from the obtained spatial filters. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + W : ndarray + Spatial filters, shape (n_channels, n_filters). + Cx : ndarray + Covariance matrix of eeg data, shape (n_channels, n_channels). + Cs : ndarray + Covariance matrix of source data, shape (n_channels, n_channels). + + Returns + ------- + A : ndarray + Spatial patterns, shape (n_channels, n_patterns), each column is a spatial pattern. + + References + ---------- + .. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate neuroimaging. + Neuroimage 87 (2014): 96-110. + """ + # use linalg.solve instead of inv, makes it more stable + # see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py + # and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html + A = solve(Cs.T, np.dot(Cx, W).T).T + return A + + +class FilterBank(BaseEstimator, TransformerMixin): + """ + Filter bank decomposition is a bandpass filter array that divides the input signal into + multiple subband components and obtains the eigenvalues of each subband component. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + base_estimator : class + Estimator for model training and feature extraction. + filterbank : list[ndarray] + A bandpass filter bank used to divide the input signal into multiple subband components. + n_jobs : int + Sets the number of CPU working cores. The default is None. + + References + ---------- + .. [1] Chen X, Wang Y, Nakanishi M, et al. High-speed spelling with a noninvasive brain-computer interface[J]. + Proceedings of the national academy of sciences, 2015, 112(44): E6058-E6067. + """ + def __init__( + self, + base_estimator: BaseEstimator, + filterbank: List[ndarray], + n_jobs: Optional[int] = None, + ): + self.base_estimator = base_estimator + self.filterbank = filterbank + self.n_jobs = n_jobs + + def fit(self, X: ndarray, y: Optional[ndarray] = None, **kwargs): + """ + Training model + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + X : None + Training signal (parameters can be ignored, only used to maintain code structure). + y : None + Label data (ibid., ignorable). + Yf : None + Reference signal (ibid., ignorable). + """ + self.estimators_ = [ + clone(self.base_estimator) for _ in range(len(self.filterbank)) + ] + X = self.transform_filterbank(X) + for i, est in enumerate(self.estimators_): + est.fit(X[i], y, **kwargs) + # def wrapper(est, X, y, kwargs): + # est.fit(X, y, **kwargs) + # return est + # self.estimators_ = Parallel(n_jobs=self.n_jobs)( + # delayed(wrapper)(est, X[i], y, kwargs) for i, est in enumerate(self.estimators_)) + return self + + def transform(self, X: ndarray, **kwargs): + """ + The parameters stored in self are used to convert X into features, and X is filtered through the filter bank to + obtain the eigenvalues of each subband component. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + X : ndarray, shape(n_trials, n_channels, n_samples) + Test the signal. + + Returns + ------- + feat : ndarray, shape(n_trials, n_fre) + Feature array. + """ + X = self.transform_filterbank(X) + feat = [est.transform(X[i], **kwargs) for i, est in enumerate(self.estimators_)] + # def wrapper(est, X, kwargs): + # retval = est.transform(X, **kwargs) + # return retval + # feat = Parallel(n_jobs=self.n_jobs)( + # delayed(wrapper)(est, X[i], kwargs) for i, est in enumerate(self.estimators_)) + feat = np.concatenate(feat, axis=-1) + return feat + + def transform_filterbank(self, X: ndarray): + """ + The input signal is filtered through a filter bank. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + X : ndarray, shape(n_trials, n_channels, n_samples) + Input signal. + + Returns + ------- + Xs: ndarray, shape(Nfb, n_trials, n_channels, n_samples) + Individual subband components of the input signal. + """ + Xs = np.stack([sosfiltfilt(sos, X, axis=-1) for sos in self.filterbank]) + return Xs + + +class FilterBankSSVEP(FilterBank): + """ + Filter bank analysis for SSVEP. + The SSVEP is analyzed using filter banks, that is, multiple filters are combined to decompose the SSVEP signal + into specific segments (subbands containing the original data) and obtain its characteristic data. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + filterbank : list[ndarray] + The filter bank. + base_estimator : class + Estimator for model training and feature extraction. + filterweights : ndarray + Filter weight, default is None. + n_jobs : int + Sets the number of CPU working cores. The default is None. + """ + + def __init__( + self, + filterbank: List[ndarray], + base_estimator: BaseEstimator, + filterweights: Optional[ndarray] = None, + n_jobs: Optional[int] = None, + ): + self.filterweights = filterweights + super().__init__(base_estimator, filterbank, n_jobs=n_jobs) + + def transform(self, X: ndarray): # type: ignore[override] + """ + X is converted into features by using the parameters stored in self, and the eigenvalues of each subband + component are obtained after the input signal is filtered by the filter bank. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + X : ndarray, shape(n_trials, n_channels, n_samples) + Test the signal. + + Returns + ------- + features : ndarray, shape(n_trials, n_fre) + Feature array. + """ + features = super().transform(X) + if self.filterweights is None: + return features + else: + features = np.reshape( + features, (features.shape[0], len(self.filterbank), -1) + ) + return np.sum( + features * self.filterweights[np.newaxis, :, np.newaxis], axis=1 + ) + + + +def generate_filterbank( + passbands: List[Tuple[float, float]], + stopbands: List[Tuple[float, float]], + srate: int, + order: Optional[int] = None, + rp: float = 0.5, +): + """ + Create a filter bank, that is, obtain a bandpass filter coefficient that can divide the input signal into multiple + subband components. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + passbands : list or tuple(float, float) + Passband parameters. + stopbands : list or tuple(float, float) + Stopband parameters. + srate : float + Sampling rate. + order : int + Filter order. + rp : float + The maximum ripple allowed in the passband below the unit gain is 0.5 by default. + + Returns + ------- + Filterbank:ndarray, shape(len(passbands), N, 6) + Filter bank coefficient. + """ + filterbank = [] + for wp, ws in zip(passbands, stopbands): + if order is None: + N, wn = cheb1ord(wp, ws, 3, 40, fs=srate) + sos = cheby1(N, rp, wn, btype="bandpass", output="sos", fs=srate) + else: + sos = cheby1(order, rp, wp, btype="bandpass", output="sos", fs=srate) + + filterbank.append(sos) + return filterbank + +def process(data): + # 白化操作 + meanValue = np.mat(data.mean(axis=1)) + meanData = np.repeat(meanValue, data.shape[1], axis=1) + whiteTemp = data - meanData + # QR 分解 + rankWhiteTemp = whiteTemp.shape[0] + whiteTemp = np.transpose(whiteTemp) + Q, R = qr(whiteTemp.A, mode='economic') + # 计算矩阵的秩 + rankQ = linalg.matrix_rank(R) + if rankQ == 0: + raise ValueError('stats:canoncorr:badData') + elif rankQ <= rankWhiteTemp: + # warnings.warn('stats:canoncorr:NotFullRank') + Q = Q[:, 0:rankQ] + return Q, rankQ + +def reference(listFreqs,fs, numberSmples, num_harms): + numberFrequence = len(listFreqs) + timeIndex = np.arange(1, numberSmples + 1) / fs # time index + referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples)) + for frequenceIndex in range(numberFrequence): + temp = [] + for harmIndex in range(1, num_harms + 1): + stimFrequence = listFreqs[frequenceIndex] # in HZ + # Sin and Cos + temp.extend([np.sin(2 * np.pi * timeIndex * harmIndex * stimFrequence), + np.cos(2 * np.pi * timeIndex * harmIndex * stimFrequence)]) + referenceTemp = np.mat(temp) + # 白化操作和QR分解 + Q, rankQ = process(referenceTemp) + referenceData[frequenceIndex] = np.transpose(Q) + return referenceData + +def generate_cca_references( + freqs: Union[ndarray, int, float], + srate, + T, + phases: Optional[Union[ndarray, int, float]] = None, + n_harmonics: int = 1, +): + """ + Construct a sine-cosine reference signal for canonical correlation analysis (CCA). + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + freqs : int or float + Frequency. + srate : int + Sampling rate. + T : int + Sampling time. + phases : int or float + Phase, default is None. + n_harmonics : int + The number of harmonics. The default value is 1. + + Returns + ------- + Yf:ndarray, shape(srate*T, n_harmonics*2) + Sine and cosine reference signal. + """ + if isinstance(freqs, int) or isinstance(freqs, float): + freqs = np.array([freqs]) + freqs = np.array(freqs)[:, np.newaxis] + if phases is None: + phases = 0 + if isinstance(phases, int) or isinstance(phases, float): + phases = np.array([phases]) + phases = np.array(phases)[:, np.newaxis] + t = np.linspace(0, T, int(T * srate)) + + Yf = [] + for i in range(n_harmonics): + Yf.append( + np.stack( + [ + np.sin(2 * np.pi * (i + 1) * freqs * t + np.pi * phases), + np.cos(2 * np.pi * (i + 1) * freqs * t + np.pi * phases), + ], + axis=1, + ) + ) + Yf = np.concatenate(Yf, axis=1) + return Yf + + +def sign_flip(u, s, vh=None): + """Flip signs of SVD or EIG using the method in paper [1]_. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + u: ndarray + left singular vectors, shape (M, K). + s: ndarray + singular values, shape (K,). + vh: ndarray or None + transpose of right singular vectors, shape (K, N). + + Returns + ------- + u: ndarray + corrected left singular vectors. + s: ndarray + singular values. + vh: ndarray + transpose of corrected right singular vectors. + + References + ---------- + .. [1] https://www.sandia.gov/~tgkolda/pubs/pubfiles/SAND2007-6422.pdf + """ + if vh is None: + total_proj = np.sum(u * s, axis=0) + signs = np.sign(total_proj) + + random_idx = signs == 0 + if np.any(random_idx): + signs[random_idx] = 1 + warnings.warn( + "The magnitude is close to zero, the sign will become arbitrary." + ) + + u = u * signs + + return u, s + else: + left_proj = np.sum(s[:, np.newaxis] * vh, axis=-1) + right_proj = np.sum(u * s, axis=0) + total_proj = left_proj + right_proj + signs = np.sign(total_proj) + + random_idx = signs == 0 + if np.any(random_idx): + signs[random_idx] = 1 + warnings.warn( + "The magnitude is close to zero, the sign will become arbitrary." + ) + + u = u * signs + vh = signs[:, np.newaxis] * vh + + return u, s, vh diff --git a/Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/dsp.py b/Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/dsp.py new file mode 100644 index 0000000..a2ae853 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/dsp.py @@ -0,0 +1,436 @@ +# -*- coding: utf-8 -*- +# DSP: Discriminal Spatial Patterns +# Authors: Swolf +# Junyang Wang <2144755928@qq.com> +# Last update date: 2022-8-11 +# License: MIT License + +from typing import Optional, List, Tuple +from itertools import combinations +import numpy as np +from scipy.linalg import eigh +from numpy import ndarray +from scipy.linalg import solve +from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin + + + +def robust_pattern(W : ndarray, Cx: ndarray, Cs: ndarray) -> ndarray: + """Transform spatial filters to spatial patterns based on paper [1]_. + Referring to the method mentioned in article [1],the constructed spatial filter only shows how to combine + information from different channels to extract signals of interest from EEG signals, but if our goal is + neurophysiological interpretation or visualization of weights, activation patterns need to be constructed + from the obtained spatial filters. + + update log: + 2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation + + Parameters + ---------- + W : ndarray + Spatial filters, shape (n_channels, n_filters). + Cx : ndarray + Covariance matrix of eeg data, shape (n_channels, n_channels). + Cs : ndarray + Covariance matrix of source data, shape (n_channels, n_channels). + + Returns + ------- + A : ndarray + Spatial patterns, shape (n_channels, n_patterns), each column is a spatial pattern. + + References + ---------- + .. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of linear models in multivariate neuroimaging. + Neuroimage 87 (2014): 96-110. + """ + # use linalg.solve instead of inv, makes it more stable + # see https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py + # and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html + A = solve(Cs.T, np.dot(Cx, W).T).T + return A + +def isPD(B: ndarray) -> bool: + """Returns true when input matrix is positive-definite, via Cholesky decompositon method. + + Parameters + ---------- + B : ndarray + Any matrix, shape (N, N) + + Returns + ------- + bool + True if B is positve-definite. + + Notes + ----- + Use numpy.linalg rather than scipy.linalg. In this case, scipy.linalg has unpredictable behaviors. + """ + + try: + _ = np.linalg.cholesky(B) + return True + except np.linalg.LinAlgError: + return False + +def nearestPD(A: ndarray) -> ndarray: + """Find the nearest positive-definite matrix to input. + + Parameters + ---------- + A : ndarray + Any square matrxi, shape (N, N) + + Returns + ------- + A3 : ndarray + positive-definite matrix to A + + Notes + ----- + A Python/Numpy port of John D'Errico's `nearestSPD` MATLAB code [1]_, which + origins at [2]_. + + References + ---------- + .. [1] https://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd + .. [2] N.J. Higham, "Computing a nearest symmetric positive semidefinite matrix" (1988): + https://doi.org/10.1016/0024-3795(88)90223-6 + """ + + B = (A + A.T) / 2 + _, s, V = np.linalg.svd(B) + + H = np.dot(V.T, np.dot(np.diag(s), V)) + + A2 = (B + H) / 2 + + A3 = (A2 + A2.T) / 2 + + if isPD(A3): + return A3 + + print("Replace current matrix with the nearest positive-definite matrix.") + + spacing = np.spacing(np.linalg.norm(A)) + # The above is different from [1]. It appears that MATLAB's `chol` Cholesky + # decomposition will accept matrixes with exactly 0-eigenvalue, whereas + # Numpy's will not. So where [1] uses `eps(mineig)` (where `eps` is Matlab + # for `numpy.spacing`), we use the above definition. CAVEAT: our `spacing` + # will be much larger than [1]'s `eps(mineig)`, since `mineig` is usually on + # the order of 1e-16, and `eps(1e-16)` is on the order of 1e-34, whereas + # `spacing` will, for Gaussian random matrixes of small dimension, be on + # othe order of 1e-16. In practice, both ways converge, as the unit test + # below suggests. + eye = np.eye(A.shape[0]) + k = 1 + while not isPD(A3): + mineig = np.min(np.real(np.linalg.eigvals(A3))) + A3 += eye * (-mineig * k**2 + spacing) + k += 1 + + return A3 + +def xiang_dsp_kernel( + X: ndarray, y: ndarray +) -> Tuple[ndarray, ndarray, ndarray, ndarray]: + """ + DSP: Discriminal Spatial Patterns, only for two classes[1]_. + Import train data to solve spatial filters with DSP, + finds a projection matrix that maximize the between-class scatter matrix and + minimize the within-class scatter matrix. Currently only support for two types of data. + + Author: Swolf + + Created on: 2021-1-07 + + Update log: + + Parameters + ---------- + X : ndarray + EEG train data assuming removing mean, shape (n_trials, n_channels, n_samples) + y : ndarray + labels of EEG data, shape (n_trials, ) + + Returns + ------- + W : ndarray + spatial filters, shape (n_channels, n_filters) + D : ndarray + eigenvalues in descending order + M : ndarray + mean value of all classes and trials, i.e. common mode signals, shape (n_channel, n_samples) + A : ndarray + spatial patterns, shape (n_channels, n_filters) + + Notes + ----- + the implementation removes regularization on within-class scatter matrix Sw. + + References + ---------- + .. [1] Liao, Xiang, et al. "Combining spatial filters for the classification of single-trial EEG in + a finger movement task." IEEE Transactions on Biomedical Engineering 54.5 (2007): 821-831. + """ + X, y = np.copy(X), np.copy(y) + labels = np.unique(y) + X = np.reshape(X, (-1, *X.shape[-2:])) + X = X - np.mean(X, axis=-1, keepdims=True) + # the number of each label + n_labels = np.array([np.sum(y == label) for label in labels]) + # average template of all trials + M = np.mean(X, axis=0) + # class conditional template + Ms, Ss = zip( + *[ + ( + np.mean(X[y == label], axis=0), + np.sum( + np.matmul(X[y == label], np.swapaxes(X[y == label], -1, -2)), axis=0 + ), + ) + for label in labels + ] + ) + Ms, Ss = np.stack(Ms), np.stack(Ss) + # within-class scatter matrix + Sw = np.sum( + Ss + - n_labels[:, np.newaxis, np.newaxis] * np.matmul(Ms, np.swapaxes(Ms, -1, -2)), + axis=0, + ) + Ms = Ms - M + # between-class scatter matrix + Sb = np.sum( + n_labels[:, np.newaxis, np.newaxis] * np.matmul(Ms, np.swapaxes(Ms, -1, -2)), + axis=0, + ) + + D, W = eigh(nearestPD(Sb), nearestPD(Sw)) + ix = np.argsort(D)[::-1] # in descending order + D, W = D[ix], W[:, ix] + A = robust_pattern(W, Sb, W.T @ Sb @ W) + + return W, D, M, A + + +def xiang_dsp_feature( + W: ndarray, M: ndarray, X: ndarray, n_components: int = 1 +) -> ndarray: + """ + Return DSP features in paper [1]_. + + Author: Swolf + + Created on: 2021-1-07 + + Update log: + + Parameters + ---------- + W : ndarray + spatial filters from csp_kernel, shape (n_channels, n_filters) + M : ndarray + common template for all classes, shape (n_channel, n_samples) + X : ndarray + eeg test data, shape (n_trials, n_channels, n_samples) + n_components : int, optional + length of the spatial filters, first k components to use, by default 1 + + Returns + ------- + features: ndarray + features, shape (n_trials, n_components, n_samples) + + Raises + ------ + ValueError + n_components should less than half of the number of channels + + Notes + ----- + 1. instead of meaning of filtered signals in paper [1]_., we directly return filtered signals. + + References + ---------- + .. [1] Liao, Xiang, et al. "Combining spatial filters for the classification of single-trial EEG in + a finger movement task." IEEE Transactions on Biomedical Engineering 54.5 (2007): 821-831. + """ + W, M, X = np.copy(W), np.copy(M), np.copy(X) + max_components = W.shape[1] + if n_components > max_components: + raise ValueError("n_components should less than the number of channels") + X = np.reshape(X, (-1, *X.shape[-2:])) + X = X - np.mean(X, axis=-1, keepdims=True) + # print('************: ',np.shape(W),np.shape(X),np.shape(M)) + features = np.matmul(W[:, :n_components].T, X - M) + return features + + +class DSP(BaseEstimator, TransformerMixin, ClassifierMixin): + """ + DSP: Discriminal Spatial Patterns + + Author: Swolf + + Created on: 2021-1-07 + + Update log: + + Parameters + ---------- + n_components : int + length of the spatial filter, first k components to use, by default 1 + transform_method : str + method of template matching, by default ’corr‘ (pearson correlation coefficient) + classes_ : int + number of the EEG classes + + Attributes + ---------- + n_components : int + length of the spatial filter, first k components to use, by default 1 + transform_method : str + method of template matching, by default ’corr‘ (pearson correlation coefficient) + classes_ : int + number of the EEG classes + W_ : ndarray, shape(n_channels, n_filters) + Spatial filters, shape(n_channels, n_filters), in which n_channels = n_filters + D_ : ndarray, shape(n_filters, ) + eigenvalues in descending order, shape(n_filters, ) + M_ : ndarray, shape(n_channels, n_samples) + mean value of all classes and trials, i.e. common mode signals, shape(n_channels, n_samples) + A_ : ndarray, shape(n_channels, n_filters) + spatial patterns, shape(n_channels, n_filters) + templates_: ndarray, shape(n_classes, n_filters, n_samples) + templates of train data, shape(n_classes, n_filters, n_samples) + + """ + + def __init__(self, n_components: int = 1, transform_method: str = "corr"): + self.n_components = n_components + self.transform_method = transform_method + + def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None): + """ + Import the train data to get a model. + + Parameters + ---------- + X : ndarray + train data, shape(n_trials, n_channels, n_samples) + y : ndarray + labels of train data, shape (n_trials, ) + Yf : ndarray + optional parameter + + Returns + ------- + W_ : ndarray + spatial filters, shape (n_channels, n_filters), in which n_channels = n_filters + D_ : ndarray + eigenvalues in descending order, shape (n_filters, ) + M_ : ndarray + template for all classes, shape (n_channel, n_samples) + A_ : ndarray + spatial patterns, shape (n_channels, n_filters) + templates_ : ndarray + templates of train data, shape (n_channels, n_filters, n_samples) + """ + X -= np.mean(X, axis=-1, keepdims=True) + self.classes_ = np.unique(y) + self.W_, self.D_, self.M_, self.A_ = xiang_dsp_kernel(X, y) + + self.templates_ = np.stack( + [ + np.mean( + xiang_dsp_feature( + self.W_, self.M_, X[y == label], n_components=self.W_.shape[1] + ), + axis=0, + ) + for label in self.classes_ + ] + ) + return self + + def transform(self, X: ndarray): + """ + Import the test data to get features. + + Parameters + ---------- + X : ndarray + test data, shape(n_trials, n_channels, n_samples) + + Returns + ------- + feature : ndarray, shape(n_trials,n_classes) + correlation coefficients of templates of train data and features of test data, shape(n_trials, n_classes) + """ + n_components = self.n_components + X -= np.mean(X, axis=-1, keepdims=True) + features = xiang_dsp_feature(self.W_, self.M_, X, n_components=n_components) + if self.transform_method is None: + return features.reshape((features.shape[0], -1)) + elif self.transform_method == "mean": + return np.mean(features, axis=-1) + elif self.transform_method == "corr": + return self._pearson_features( + features, self.templates_[:, :n_components, :] + ) + else: + raise ValueError("non-supported transform method") + + def _pearson_features(self, X: ndarray, templates: ndarray): + """ + Calculate pearson correlation coefficient. + + Parameters + ---------- + X : ndarray + features of test data after spatial filters, shape(n_trials, n_components, n_samples) + templates : ndarray + templates of train data, shape(n_classes, n_components, n_samples) + + Returns + ------- + corr : ndarray + pearson correlation coefficient, shape(n_trials, n_classes) + """ + X = np.reshape(X, (-1, *X.shape[-2:])) + templates = np.reshape(templates, (-1, *templates.shape[-2:])) + X = X - np.mean(X, axis=-1, keepdims=True) + templates = templates - np.mean(templates, axis=-1, keepdims=True) + X = np.reshape(X, (X.shape[0], -1)) + templates = np.reshape(templates, (templates.shape[0], -1)) + istd_X = 1 / np.std(X, axis=-1, keepdims=True) + istd_templates = 1 / np.std(templates, axis=-1, keepdims=True) + corr = (X @ templates.T) / (templates.shape[1] - 1) + corr = istd_X * corr * istd_templates.T + return corr + + def predict(self, X: ndarray): + """ + Import the templates and the test data to get prediction labels. + + Parameters + ---------- + X : ndarray + test data, shape(n_trials, n_channels, n_samples) + + Returns + ------- + labels : ndarray + prediction labels of test data, shape(n_trials,) + """ + feat = self.transform(X) + if self.transform_method == "corr": + labels = self.classes_[np.argmax(feat, axis=-1)] + else: + raise NotImplementedError() + return labels + + diff --git a/Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/tdca.py b/Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/tdca.py new file mode 100644 index 0000000..7b7247b --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/SSMVEP/algorithm/tdca.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- +# +# Authors: Swolf +# Date: 2021/10/10 +# License: MIT License +""" +Task Decomposition Component Analysis. +""" +from typing import List + +import numpy as np +from scipy.linalg import qr +from scipy.stats import pearsonr +from numpy import ndarray +from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin +from typing import Optional, List +from SSMVEP.algorithm.base import FilterBankSSVEP +from SSMVEP.algorithm.dsp import xiang_dsp_kernel, xiang_dsp_feature + + +def proj_ref(Yf: ndarray): + Q, R = qr(Yf.T, mode="economic") + P = Q @ Q.T + return P + + +def aug_2(X: ndarray, n_samples: int, padding_len: int, P: ndarray, training: bool = True): + X = X.reshape((-1, *X.shape[-2:])) + n_trials, n_channels, n_points = X.shape + # if n_points < padding_len + n_samples: + # raise ValueError("the length of X should be larger than l+n_samples.") + aug_X = np.zeros((n_trials, (padding_len + 1) * n_channels, n_samples)) + if training: + for i in range(padding_len + 1): + aug_X[:, i * n_channels : (i + 1) * n_channels, :] = X[ + ..., i : i + n_samples + ] + else: + for i in range(padding_len + 1): + aug_X[:, i * n_channels : (i + 1) * n_channels, : n_samples - i] = X[ + ..., i:n_samples + ] + aug_Xp = aug_X @ P + aug_X = np.concatenate([aug_X, aug_Xp], axis=-1) + return aug_X + + +def tdca_feature( + X: ndarray, + templates: ndarray, + W: ndarray, + M: ndarray, + Ps: List[ndarray], + padding_len: int, + n_components: int = 1, + training=False, +): + rhos = [] + for Xk, P in zip(templates, Ps): + a = xiang_dsp_feature( + W, + M, + aug_2(X, P.shape[0], padding_len, P, training=training), + n_components=n_components, + ) + b = Xk[:n_components, :] + a = np.reshape(a, (-1)) + b = np.reshape(b, (-1)) + rhos.append(pearsonr(a, b)[0]) + return rhos + + +class TDCA(BaseEstimator, TransformerMixin, ClassifierMixin): + def __init__(self, padding_len: int, n_components: int = 1): + self.padding_len = padding_len + self.n_components = n_components + + def fit(self, X: ndarray, y: ndarray, Yf: ndarray): + X -= np.mean(X, axis=-1, keepdims=True) + self.classes_ = np.unique(y) + self.Ps_ = [proj_ref(Yf[i]) for i in range(len(self.classes_))] + # print(np.shape(self.Ps_)) + + aug_X_list, aug_Y_list = [], [] + for i, label in enumerate(self.classes_): + aug_X_list.append( + aug_2( + X[y == label], + self.Ps_[i].shape[0], + self.padding_len, + self.Ps_[i], + training=True, + ) + ) + aug_Y_list.append(y[y == label]) + + aug_X = np.concatenate(aug_X_list, axis=0) + aug_Y = np.concatenate(aug_Y_list, axis=0) + self.W_, _, self.M_, _ = xiang_dsp_kernel(aug_X, aug_Y) + + self.templates_ = np.stack( + [ + np.mean( + xiang_dsp_feature( + self.W_, + self.M_, + aug_X[aug_Y == label], + n_components=self.W_.shape[1], + ), + axis=0, + ) + for label in self.classes_ + ] + ) + return self + + def transform(self, X: ndarray): + n_components = self.n_components + X -= np.mean(X, axis=-1, keepdims=True) + X = X.reshape((-1, *X.shape[-2:])) + rhos = [ + tdca_feature( + tmp, + self.templates_, + self.W_, + self.M_, + self.Ps_, + self.padding_len, + n_components=n_components, + ) + for tmp in X + ] + rhos = np.stack(rhos) + return rhos + + def predict(self, X: ndarray): + feat = self.transform(X) + labels = self.classes_[np.argmax(feat, axis=-1)] + return labels,feat + + +class FBTDCA(FilterBankSSVEP, ClassifierMixin): + def __init__( + self, + filterbank: List[ndarray], + padding_len: int, + n_components: int = 1, + filterweights: Optional[ndarray] = None, + n_jobs: Optional[int] = None, + ): + self.padding_len = padding_len + self.n_components = n_components + self.filterweights = filterweights + self.n_jobs = n_jobs + super().__init__( + filterbank, + TDCA(padding_len, n_components=n_components), + filterweights=filterweights, + n_jobs=n_jobs, + ) + + def fit(self, X: ndarray, y: ndarray, Yf: Optional[ndarray] = None): # type: ignore[override] + self.classes_ = np.unique(y) + super().fit(X, y, Yf=Yf) + return self + + def predict(self, X: ndarray): + features = self.transform(X) + if self.filterweights is None: + features = np.reshape( + features, (features.shape[0], len(self.filterbank), -1) + ) + features = np.mean(features, axis=1) + labels = self.classes_[np.argmax(features, axis=-1)] + return labels,features diff --git a/Debug_64ch_Decoder_Optimize/SSVEP/dwfbcca.py b/Debug_64ch_Decoder_Optimize/SSVEP/dwfbcca.py new file mode 100644 index 0000000..3fe0739 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/SSVEP/dwfbcca.py @@ -0,0 +1,529 @@ + + +# -*- coding: utf-8 -*- +import os +import time +import warnings +from os import error +import numpy as np +import scipy +from numpy.linalg import linalg +from scipy.io import loadmat +from scipy.linalg import qr +from scipy.signal import filtfilt, lfilter +# from numpy.linalg import _umath_linalg + + +class FbccaDw: + def __init__(self, fs, num_target, num_chans, num_filter, num_harms, stimTime, parameter, width, winNum,method): + print('******************************************') + print('parameter list') + print('target:', num_target) + print('number of filter bank:', num_filter) + print('parameter:', parameter) + print('width:', width) + self.phase = 0 + self.bandWidth = width + self.winNum = winNum + self.num_harms = num_harms + self.num_target = num_target + self.num_chans = num_chans + self.winTimeDelay = stimTime + self.fs = fs + self.winMaxSampleNum = self.winTimeDelay * self.winNum * self.fs + self.winDelayNum = round(self.winTimeDelay * self.fs) + self.num_fbs = num_filter + parameterValue = np.power(np.arange(1, self.num_fbs + 1), -(parameter[0])) + parameter[1] + self.weightValue = parameterValue / (sum(parameterValue)) + + self.dataUseLen = [0] * self.winNum + self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans]) + self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans]) + self.rhoNum = 2 + self.notchZh = [0] + self.filterZf = [0] * self.num_fbs + self.north_b = [] + self.north_a = [] + self.filterBank_A = [] + self.filterBank_B = [] + self.winStep = 1 + self.DW_cost_method = 'DW11' if method==1 else 'DW1' + + ''' + filterFrequenceBank:根据刺激频率生成的通带和阻带,用于滤波器组频带分解 + ''' + + def filterFrequenceBank(self): + # 阻带的最高频率 + lastFrequence = 90 + freqBandWidth = self.bandWidth[1] + fStep = self.bandWidth[0] + bandFrequence = np.zeros((5, 4)) + # 第二列频率带 + band = list(range(freqBandWidth, lastFrequence, fStep)) + band[:] = [x - 2 for x in band] + colValue = np.maximum(np.asmatrix(band), 1) + bandFrequence[:, 1] = colValue[0, 0:5] + # 第一列频率带 + bandFrequence[:, 0] = np.maximum(bandFrequence[:, 1] - 4, 1) + # 第三列频率带 + bandFrequence[:, 2] = lastFrequence + 2 + # 第四列频率带 + bandFrequence[:, 3] = bandFrequence[:, 2] + 10 + # bandFrequence = np.array([[30,33,77,82], + # [62,68,77,82]]) + for idx_fb in range(self.num_fbs): + Nq = self.fs / 2 + Wp = [bandFrequence[idx_fb, 1] / Nq, bandFrequence[idx_fb, 2] / Nq] + Ws = [bandFrequence[idx_fb, 0] / Nq, bandFrequence[idx_fb, 3] / Nq] + [N, Wn] = scipy.signal.cheb1ord(Wp, Ws, 3, + 40) # band pass filter StopBand=[Ws(1)~Ws(2)] PassBand=[Wp(1)~Wp(2)] + [B, A] = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass') # Wn passband edge frequency + self.filterBank_A.append(A) + self.filterBank_B.append(B) + # def filterFrequenceBank(self): + # # 阻带的最高频率 + # lastFrequence = 90 + # freqBandWidth = self.bandWidth[1] + # fStep = self.bandWidth[0] + # bandFrequence = np.zeros((5, 4)) + # # 第二列频率带 + # band = list(range(freqBandWidth, lastFrequence, fStep)) + # band[:] = [x - 2 for x in band] + # colValue = np.maximum(np.asmatrix(band), 1) + # bandFrequence[:, 1] = colValue[0, 0:5] + # # 第一列频率带 + # bandFrequence[:, 0] = np.maximum(bandFrequence[:, 1] - 4, 1) + # # 第三列频率带 + # bandFrequence[:, 2] = lastFrequence + 2 + # # 第四列频率带 + # bandFrequence[:, 3] = bandFrequence[:, 2] + 10 + # for idx_fb in range(self.num_fbs): + # Nq = self.fs / 2 + # Wp = [bandFrequence[idx_fb, 1] / Nq, bandFrequence[idx_fb, 2] / Nq] + # Ws = [bandFrequence[idx_fb, 0] / Nq, bandFrequence[idx_fb, 3] / Nq] + # [N, Wn] = scipy.signal.cheb1ord(Wp, Ws, 3, + # 40) # band pass filter StopBand=[Ws(1)~Ws(2)] PassBand=[Wp(1)~Wp(2)] + # [B, A] = scipy.signal.cheby1(N, 0.5, Wn, 'bandpass') # Wn passband edge frequency + # self.filterBank_A.append(A) + # self.filterBank_B.append(B) + + ''' + Filter bank analysis + Input: + eeg : Input eeg data (# of targets, # of channels, Data length [sample]) + Output: + filterData : Generated filter Data + ''' + + def filterbank(self, eeg): + filterData = np.zeros((self.num_fbs, eeg.shape[1], eeg.shape[0])) + for filterIndex in range(self.num_fbs): + if np.all(self.filterZf[filterIndex] == 0): + zi = np.zeros( + [max(len(self.filterBank_A[filterIndex]), len(self.filterBank_B[filterIndex])) - 1, self.num_chans]) + _, self.filterZf[filterIndex] = lfilter(self.filterBank_B[filterIndex], self.filterBank_A[filterIndex], + eeg, zi=zi.T) + Data = lfilter(self.filterBank_B[filterIndex], self.filterBank_A[filterIndex], eeg) + else: + Data, self.filterZf[filterIndex] = lfilter(self.filterBank_B[filterIndex], + self.filterBank_A[filterIndex], eeg, + zi=self.filterZf[filterIndex]) + filterData[filterIndex, :, :] = Data.T + return filterData + + ''' + process + 矩阵的白化和QR正则化分解,降低矩阵的维度,加速计算时间 + Input: + data : 输入的二维脑电信号 + Output: + Q : 降维后的矩阵 + rankQ :正则矩阵的秩 + ''' + + def process(self, data): + # 白化操作 + meanValue = np.asmatrix(data.mean(axis=1)) + meanData = np.repeat(meanValue, data.shape[1], axis=1) + whiteTemp = data - meanData + # QR 分解 + rankWhiteTemp = whiteTemp.shape[0] + whiteTemp = np.transpose(whiteTemp) + Q, R = qr(whiteTemp.A, mode='economic') + # 计算矩阵的秩 + rankQ = linalg.matrix_rank(R) + if rankQ == 0: + raise ValueError('stats:canoncorr:badData') + elif rankQ <= rankWhiteTemp: + # warnings.warn('stats:canoncorr:NotFullRank') + Q = Q[:, 0:rankQ] + return Q, rankQ + + ''' + reference + Input: + listFreqs : 刺激频率列表 + numberSmples : 用于分类的脑电信号采样点个数 + num_harms : 谐波数 + Output: + y_ref : 生成的参考信号 (刺激目标数, 2 * 谐波数, 数据长度/采样点数) + ''' + + def reference(self, listFreqs, numberSmples, num_harms): + numberFrequence = len(listFreqs) + timeIndex = np.arange(1, numberSmples + 1) / self.fs # time index + referenceData = np.zeros((numberFrequence, 2 * num_harms, numberSmples)) + for frequenceIndex in range(numberFrequence): + temp = [] + for harmIndex in range(1, num_harms + 1): + stimFrequence = listFreqs[frequenceIndex] # in HZ + # Sin and Cos + temp.extend([np.sin(2 * np.pi * timeIndex * harmIndex * stimFrequence), + np.cos(2 * np.pi * timeIndex * harmIndex * stimFrequence)]) + referenceTemp = np.asmatrix(temp) + # 白化操作和QR分解 + Q, rankQ = self.process(referenceTemp) + referenceData[frequenceIndex] = np.transpose(Q) + return referenceData + + ''' + setNorthFilterPara + 陷波器的参数初始化 + self.north_b, self.north_a : 陷波器的参数设计 + ''' + + def setNotchFilterPara(self): + # notchFilterNum = 3 + # northFreq = 50 + # bwDen = 35 + # wo = northFreq / (self.fs / 2) + # bw = wo / bwDen + # self.north_b, self.north_a = iirnotch(wo, Q=35) # self.north_b, self.north_a = iircomb(northFreq, bwDen, 'notch') + # # n倍零极点,相当于重复滤波n次 + # if notchFilterNum > 1: + # z, p, k = tf2zpk(self.north_b, self.north_a) + # zNew = np.repeat(z, notchFilterNum, axis=0) + # zNew[1], zNew[4] = zNew[4], zNew[1] + # pNew = np.repeat(p, notchFilterNum, axis=0) + # pNew[1], pNew[4] = pNew[4], pNew[1] + # kNew = np.power(k, notchFilterNum) + # self.north_b, self.north_a = zpk2tf(zNew, pNew, kNew) + self.north_b = [0.94801603944125245604368501517456, -1.7577184027642647201616910024313, + 3.9303778338832491279219993884908, -3.7392330345967859095424046245171, + 3.9303778338832482397435796883656, -1.7577184027642638319832713023061, + 0.94801603944125156786526531504933] + + self.north_a = [1, -1.8214007435820627200939725298667, 4.0000101767406484043476666556671, + -3.7380998614928691026193519064691, 3.8589119784285759173769747576443, + -1.6951692350503837491970671180752, 0.89786559147978006745205448169145] + + ''' + northFilter + 进行信号的50hz陷波处理 + Input: + data :输入脑电数据 + Output: + dataFiltered : 陷波处理后的脑电数据 + ''' + + def northFilter(self, data): + try: + if np.all(self.notchZh[0] == 0): + zi = np.zeros([max(len(self.north_a), len(self.north_b)) - 1, self.num_chans]) + _, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=zi.T) + dataFiltered = lfilter(self.north_b, self.north_a, data) + else: + dataFiltered, self.notchZh[0] = lfilter(self.north_b, self.north_a, data, zi=self.notchZh[0]) + return np.asmatrix(dataFiltered) + except Exception: + print(Exception) + + ''' + getDataQ + Inputs: + data:脑电数据 + Rbuffer:待更新的中间系数 + Output: + Qs1 : 脑电特征1 + Qs2 : 脑电特征2 + Rbuffer : 单窗口更新后的系数 + + ''' + + def getDataQ(self, data, Rbuffer): + Qs1 = [0] * self.num_fbs + Qs2 = [0] * self.num_fbs + nulldata = np.zeros([self.num_chans, self.num_chans]) + Rnum = self.num_chans + for fb_num in range(self.num_fbs): + fb_data = np.squeeze(data[fb_num, :, :]) + if np.all(Rbuffer[fb_num] == 0): + whiteTemp = fb_data + Q, R = qr(whiteTemp, mode='economic') + Qs1[fb_num] = nulldata + Qs2[fb_num] = Q + Rbuffer[fb_num] = R + else: + whiteTemp = np.concatenate((Rbuffer[fb_num], fb_data), axis=0) + Q, R = qr(whiteTemp, mode='economic') + Qs1[fb_num] = Q[0:Rnum, :] + Qs2[fb_num] = Q[Rnum:, :] + Rbuffer[fb_num] = R + return Qs1, Qs2, Rbuffer + + ''' + myCCA:根据脑电特征和参考信号计算相关系数 + Inputs: + dataQ:脑电特征 + Qc2y:参考信号 + d : 相关系数取值数 + Output: + rho : 相关系数 + + ''' + + def myCCA(self, dataQ, Qc2y, d): + if len(Qc2y) == 0: + Cov = dataQ + else: + Cov = np.dot(Qc2y, dataQ) + # U, S, V = scipy.linalg.svd(Cov, 0) + # rho = np.minimum(np.maximum(np.diag(S[0: d]).T, 0), 1) + # gufunc = _umath_linalg.svd_n + # rho = gufunc(Cov) + rho = np.linalg.svd(Cov, compute_uv=False) + return rho[0:d] + + ''' + weightCCA:计算分类标签 + Inputs: + Qs1:脑电特征1 + Qs2:脑电特征2 + ref : 正余弦参考信号 + Cxy : 协方差中间参数 + Output: + result : 分类标签 + rho : 相关系数 + Cxy : 更新后的协方差中间参数 + ''' + + def weightCCA(self, Qs1, Qs2, ref, Cxy): + rMax = np.zeros([self.num_fbs, self.num_target]) + for fi in range(self.num_fbs): + for si in range(self.num_target): + Qc2y = np.squeeze(ref[si, :, :]) + # 更新协方差矩阵 + if np.all(Cxy[fi][si] == 0): + Cxy[fi, si] = np.dot(Qc2y, Qs2[fi]) + else: + Cxy[fi, si] = np.dot(Cxy[fi, si], Qs1[fi]) + np.dot(Qc2y, Qs2[fi]) + r = self.myCCA(Cxy[fi, si], [], self.rhoNum) + rMax[fi, si] = r[0] + rho = np.dot(self.weightValue, np.power(rMax, 2)) # weighted sum of r from all different filter banks' result + result = np.argmax(rho) + return result, rho, Cxy + + ''' + costF:损失函数,根据计算的相关系数,生成决策值,用于和阈值进行比较 + Inputs: + rho:相关系数 + method:相关系数计算参数 + C : 参数 + Output: + decideValue : 决策阈值 + ''' + + def costF(self, rho, method, C): + rho = rho.tolist() + rho.sort(reverse=True) + if method == 'DW1': + decideValue = (rho[0] - rho[1]) / (sum(rho) - self.num_target * np.log(sum(np.exp(rho)))) + elif method == 'DW11': + decideValue = -(rho[0] - rho[1]) + elif method == 'DW2': + decideValue = (rho[0] - C) / (rho[1] - rho[0]) + return decideValue + + ''' + onlineInit:将窗口长度,相位值、中间参数初始化 + ''' + + def onlineInit(self): + self.dataUseLen = [0] * self.winNum + self.Rbuffer = np.zeros([self.winNum, self.num_fbs, self.num_chans, self.num_chans]) + self.Cxy = np.zeros([self.winNum, self.num_fbs, self.num_target, 2 * self.num_harms, self.num_chans]) + self.phase = 0 + + ''' + filterInit:重置陷波器和滤波器的滤波参数 + ''' + + def filterInit(self): + self.notchZh = [0] + self.filterZf = [0] * self.num_fbs + + ''' + warmFilter:预热滤波器,去刺激前的4S数据对陷波器和滤波器参数进行初始化迭代,去除过渡带的效果 + Inputs: + data:预处理脑电数据 + ''' + + def warmFilter(self, data): + # 降采样在采集前完成 + temp = self.preprocessFilter(data) #预热陷波滤波器 + # 滤波器组频带分解 + filterData = self.filterbank(temp) #预热滤波器组 + + ''' + myDownSample:数据降采样 + Inputs: + data:脑电数据 + n:降采样的倍数 + Output: + eegData2 : 降采样后的数据 + ''' + + def myDownSample(self, data, n): + data = data[:8, self.phase:] + dataNum = data.shape[1] + remainNum = (dataNum - 1) % n + self.phase = n - 1 - remainNum + dataDowmSample = [] + for value in data: + value = value[0:value.size:n] + dataDowmSample.append(value) + eegData2 = np.array(dataDowmSample).reshape([8, int(np.round(data.shape[1] / n))]) + return eegData2 + + ''' + preprocessFilter:预处理,调用函数降采样和陷波处理 + Inputs: + data:脑电数据 + Output: + filterData : 降采样和陷波后的数据 + ''' + + def preprocessFilter(self, data): + # data = self.myDownSample(data, 4) + # filterData = self.northFilter(data[:8, :]) + filterData = self.northFilter(data[:, :]) + return filterData + + ''' + fbccaDWMW:分类函数,对输入的脑电信号进行识别,输出决策标签 + Inputs: + testdata:脑电数据 + referenceData:参考信号 + tValue:出决策阈值 + Output: + res : 决策标签 + rho_new:相关系数 + minEps:得到的决策阈值 + ''' + + # 动态窗算法主函数 + def fbccaDWMW(self, testdata, referenceData, tValue,calculateCount): + t1 = time.time() + # try: + # 初始参数 + res = -1 + minEps = float("inf") + # 降采样和陷波器处理 + northData = self.preprocessFilter(testdata) + newSampleNum = northData.shape[1] + # 数据大于延迟长度,则无法根据后面的规则更新窗口 + if newSampleNum > self.winDelayNum: + error('need add window delay time') + + # 防止秩小于导联数 + if newSampleNum < self.num_chans: + warnings.warn('data shape is [%d %d] need more data' % (newSampleNum, northData.shape[0])) + # 滤波器组频带分解 + filterData = self.filterbank(northData) + winMinTime = 0 + # 计算每个窗口的结果 + for wi in range(0, self.winNum, self.winStep): + # print('dataUseLen:',wi,calculateCount, self.dataUseLen) + if wi == 0: + self.dataUseLen[wi] = self.dataUseLen[wi] + newSampleNum + else: + if self.dataUseLen[wi] == 0: + # 判断当前窗是否为新的窗口(因为每一次新的窗口进来时,都会使上一个窗口datauseLen>50) + if self.dataUseLen[wi - self.winStep] > self.winDelayNum*self.winStep: + self.dataUseLen[wi] = newSampleNum + else: + # print('中断: ',wi,calculateCount) + break + else: + self.dataUseLen[wi] = self.dataUseLen[wi] + newSampleNum + + if self.dataUseLen[wi] > self.winMaxSampleNum: + self.dataUseLen[wi] = newSampleNum + self.Rbuffer[wi, :, :, :] = 0 + self.Cxy[wi, :, :, :, :] = 0 + Qs1, Qs2, self.Rbuffer[wi, :, :, :] = self.getDataQ(filterData, self.Rbuffer[wi, :, :, :]) + si = self.dataUseLen[wi] - newSampleNum + ei = self.dataUseLen[wi] + ref = referenceData[:, :, si:ei] + # 更新协方差 + predLabel_new, rho_new, self.Cxy[wi, :, :, :] = self.weightCCA(Qs1, Qs2, ref, self.Cxy[wi, :, :, :]) + # 增加限制,数据长度不能太短 + if self.dataUseLen[wi] > winMinTime * self.fs: + epsilon = self.costF(rho_new, self.DW_cost_method, C=0) + if epsilon < minEps: + minEps = epsilon + predLabel = predLabel_new + xxx = rho_new + if minEps < tValue: + res = predLabel + + if time.time() - t1 > 0.2 and self.winStep < 16: + self.winStep = self.winStep * 2 + # print(self.winStep, " ", time.time() - t1) + # if res != -1: + # print('--------------------- ',res,xxx,' --------------------------') + return res + + +if __name__ == '__main__': + # The number of sub-bands in filter bank analysis + fs = 250 + num_chans = 8 + num_target = 40 + num_filterBank = 3 + num_harm = 5 + stimTime = 0.2 # 多窗口窗长 + winNum = 50 # 窗口的个数 + trials = 1 + step = 50 + res = -1 + list_freqs = [8.0, 8.2, 8.4, 8.6, 8.8, 9.0, 9.2, 9.4, 9.6, 9.8, 10.0, 10.2, 10.4, 10.6, 10.8, 11.0, 11.2, 11.4, + 11.6, 11.8, 12.0, 12.2, 12.4, 12.6, 12.8, 13., 13.2, 13.4, 13.6, 13.8, 14., 14.2, 14.4, 14.6, 14.8, + 15.0, 15.2, 15.4, 15.6, 15.8] + # 初始化对象 + dw = FbccaDw(fs, num_target, num_chans, num_filterBank, num_harm, stimTime, [1.0, 0.3], [8, 8], winNum) + # frequenceband + dw.filterFrequenceBank() + referenceData = dw.reference(list_freqs, int(winNum * stimTime * fs), num_harm) + dw.setNotchFilterPara() + + prelabels = np.zeros((1, 40)) + coefficient = np.zeros([1, 1]) + path = "D:\\工作相关\\项目代码\\SDK封装与测试\\八神BCI脑电信号SDK20230105\\双鹰SDK对比\\offline_data250\\" + for index in range(1, trials + 1): + D = loadmat(os.path.join(path + str(1) + '-warmData.mat')) + warmData = D['warmData'] + dw.onlineInit() + dw.filterInit() + dw.warmFilter(warmData.T) + + tagget_i = 0 + for tagget_i in range(1, step + 1): + D = loadmat(os.path.join(path + str(1) + '-' + str(tagget_i) + '.mat')) + dataSlice = D['dataTemp'] + res = dw.fbccaDWMW(dataSlice.T, referenceData, tValue=-0.2) + if res != -1: + break + prelabels[0, index - 1] = res + 1 + print(index, '--', res + 1," 计算轮数", tagget_i) diff --git a/Debug_64ch_Decoder_Optimize/Tools/plot_MI_EEG.py b/Debug_64ch_Decoder_Optimize/Tools/plot_MI_EEG.py new file mode 100644 index 0000000..b3e940e --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/Tools/plot_MI_EEG.py @@ -0,0 +1,851 @@ +import matplotlib +matplotlib.use('Agg') +import os +import io +import numpy as np +import pandas as pd +import matplotlib +import matplotlib.pyplot as plt +from matplotlib.patches import Ellipse +import matplotlib.cm as cm +import matplotlib.colors as mcolors +from scipy.spatial import Delaunay +from scipy.interpolate import Rbf +from scipy.signal import welch +from scipy.stats import sem +from scipy.signal import butter, filtfilt, hilbert +import base64 + +# 位置坐标 +def read_ch_pos(file_path=r'xy_64.xlsx'): + """ + 将电极位置信息转换为Dict + + 参数: + file_path: 电极位置存储文件, 必须包含'channel', 'x', 'y', 'z'列 + + """ + script_dir = os.path.dirname(os.path.abspath(__file__)) + file_path = os.path.join(script_dir,file_path ) + df = pd.read_excel(file_path) + # 确保列名正确 + if not all(col in df.columns for col in ['channel', 'x', 'y', 'z']): + raise ValueError("DataFrame必须包含'channel', 'x', 'y', 'z'列") + # 创建电极位置字典 + ch_pos = {} + for _, row in df.iterrows(): + ch_pos[row['channel']] = [row['x'], row['y'], row['z']] + return ch_pos +# 头部轮廓 +def draw_head(ax, center=(0, 0), radius=1.0, zorder=4): + """ + 绘制头部轮廓、鼻子和耳朵。 + + 参数: + - ax : matplotlib Axes 对象 + - center : (x, y) 头中心坐标 + - radius : float, 头半径 + - zorder : 绘制层级 + """ + + # 头圆 + head = plt.Circle(center, radius, fill=False, color='k', linewidth=1, zorder=zorder) + ax.add_artist(head) + + # 鼻子(参考 _make_head_outlines) + dx = np.exp(np.arccos(np.deg2rad(12)) * 1j) + dx_real, dx_imag = dx.real, dx.imag + nose_x = np.array([-dx_real, 0, dx_real]) * radius + center[0] + nose_y = np.array([dx_imag, 1.15, dx_imag]) * radius + center[1] + ax.plot(nose_x, nose_y, color='k', linewidth=1, zorder=zorder) + + # 耳朵(参考 _make_head_outlines 手动标定) + ear_radius = radius * 0.12 + ear_scale = radius * 2 # 根据半径缩放 + theta = np.linspace(np.pi / 2, 3 * np.pi / 2, 30) + + # 左耳 + left_ear_x_array = np.array([0.497, 0.510, 0.518, 0.5299, 0.5419, + 0.54, 0.547, 0.532, 0.510, 0.489]) * ear_scale + left_ear_y_array = np.array([0.0555, 0.0775, 0.0783, 0.0746, 0.0555, + -0.0055, -0.0932, -0.1313, -0.1384, -0.1199]) * ear_scale + center[1] + ax.plot(center[0] - left_ear_x_array, left_ear_y_array, color='k', linewidth=1, zorder=zorder) + + # 右耳 + right_ear_x_array = np.array([0.497, 0.510, 0.518, 0.5299, 0.5419, + 0.54, 0.547, 0.532, 0.510, 0.489]) * ear_scale + right_ear_y_array = np.array([0.0555, 0.0775, 0.0783, 0.0746, 0.0555, + -0.0055, -0.0932, -0.1313, -0.1384, -0.1199]) * ear_scale + center[1] + ax.plot(center[0] + right_ear_x_array, right_ear_y_array, color='k', linewidth=1, zorder=zorder) +# 地形图 插值 +def rbf_D_interpolate(xy, v, center=(0, 0), radius=1.1, grid_res=300, + n_extra=32, rbf_func='multiquadric', smooth=0, + border='mean', border_scale=1.0001, n_ngb=4): + """ + 使用 RBF + Delaunay 邻域均值方式生成平滑的 EEG topomap 插值表面。 + + 参数 + ---- + xy : (N,2) array + 电极二维坐标(与绘图坐标系一致) + v : (N,) array + 每个电极对应的值(e.g. PSD) + center : tuple (x0, y0) + 头部圆心(默认 (0,0)) + radius : float + 头部半径(用于生成边界点与网格范围) + grid_res : int + 网格分辨率(每轴点数) + n_extra : int + 边界虚拟点数量 + rbf_func : str + RBF 内核名称('multiquadric','thin_plate','gaussian',...) + smooth : float + RBF 平滑参数 + border : 'mean' or float + 若 'mean':边界点用邻近真实通道均值赋值(推荐) + 若 float:边界点赋相同常数值 + border_scale : float + 边界点半径相对 radius 的缩放(略微 >1 用以外推) + n_ngb : int + 为每个边界点取值时使用的最近真实通道数 + + 返回 + ---- + zi : (grid_res, grid_res) ndarray + 插值结果(与 grid_x, grid_y 对齐) + grid_x, grid_y : ndarrays + meshgrid(由 np.meshgrid 生成) + """ + xy = np.asarray(xy) + v = np.asarray(v) + if xy.ndim != 2 or xy.shape[1] != 2: + raise ValueError("xy must be shape (n_channels, 2)") + + n_points = xy.shape[0] + + # --- 1. 生成边界虚拟点(圆周) --- + theta = np.linspace(0.0, 2 * np.pi, n_extra, endpoint=False) + r_border = radius * border_scale + border_xy = np.column_stack([center[0] + r_border * np.cos(theta), + center[1] + r_border * np.sin(theta)]) + + # --- 2. 用 Delaunay 建图以便找到邻居(对边界点取邻居均值) --- + # 合并用于三角化的位置(真实点 + 边界点) + tri_xy = np.vstack([xy, border_xy]) + tri = Delaunay(tri_xy) + + # --- 3. 为边界点赋值 --- + if isinstance(border, str) and border == 'mean': + # 使用 Delaunay 的 vertex_neighbor_vertices 索引 + # 注意:tri.vertex_neighbor_vertices 给出 vertices -> neighbor indptr + indices, indptr = tri.vertex_neighbor_vertices + v_extra = np.zeros(n_extra) + used = np.zeros(n_extra, dtype=bool) + # 边界点在 tri_xy 中的索引范围 + rng = range(n_points, n_points + n_extra) + for idx, extra_idx in enumerate(rng): + neigh = indptr[indices[extra_idx]:indices[extra_idx + 1]] + # 仅保留原始点索引(小于 n_points) + neigh = neigh[neigh < n_points] + if neigh.size > 0: + used[idx] = True + # 使用最近 n_ngb 个邻居的均值(若邻居多则取最近的 n_ngb) + if neigh.size > n_ngb: + # 计算距离并选取最近 n_ngb + d = np.linalg.norm(xy[neigh] - tri_xy[extra_idx], axis=1) + order = np.argsort(d)[:n_ngb] + sel = neigh[order] + else: + sel = neigh + v_extra[idx] = v[sel].mean() + if not used.all() and used.any(): + v_extra[~used] = np.mean(v_extra[used]) + elif not used.any(): + v_extra[:] = np.mean(v) + else: + # border 是数值 + v_extra = np.full(n_extra, float(border)) + + # --- 4. 合并所有已知点并构建 RBF --- + all_xy = np.vstack([xy, border_xy]) + all_v = np.concatenate([v, v_extra]) + + rbf = Rbf(all_xy[:, 0], all_xy[:, 1], all_v, + function=rbf_func, smooth=smooth) + + # --- 5. 生成网格(使用 meshgrid,与主函数保持一致) --- + xmin, xmax = center[0] - radius, center[0] + radius + ymin, ymax = center[1] - radius, center[1] + radius + xi = np.linspace(xmin, xmax, grid_res) + yi = np.linspace(ymin, ymax, grid_res) + grid_x, grid_y = np.meshgrid(xi, yi) # meshgrid 与 imshow 对齐 + + # --- 6. 评估 RBF,返回与 grid 对齐的 zi --- + zi = rbf(grid_x, grid_y) + + return zi, grid_x, grid_y +# plv矩阵计算 +def calculate_plv(data): + """ + 计算相位锁定值(PLV)矩阵。 + + Parameters + ---------- + data : ndarray, shape (num_channels, num_samples) + EEG 数据,通道数为 num_channels,样本数为 num_samples。 + + Returns + ------- + plv_matrix : ndarray, shape (num_channels, num_channels) + 计算得到的 PLV 矩阵,表示各通道间的相位同步。 + """ + num_channels, num_samples = data.shape + plv_matrix = np.zeros((num_channels, num_channels)) + + # 计算每个通道的解析信号 + analytic_signals = np.apply_along_axis(hilbert, axis=1, arr=data) + + for i in range(num_channels): + for j in range(i + 1, num_channels): # 只计算上三角矩阵,避免重复计算 + # 计算 phase difference + phase_diff = np.angle(analytic_signals[i] * np.conj(analytic_signals[j])) + plv = np.abs(np.mean(np.exp(1j * phase_diff))) + plv_matrix[i, j] = plv + plv_matrix[j, i] = plv # 对称矩阵 + + return plv_matrix +# 矩阵阈值化 +def threshold_proportional(adj, prop=0.2): + """ + Apply a proportional threshold to retain the top proportion of strongest edges. + + Parameters + ---------- + adj : ndarray, shape (n_channels, n_channels) + Adjacency matrix to threshold. + prop : float + Proportion of edges to retain (0 < prop <= 1). + + Returns + ------- + bin_adj : ndarray, shape (n_channels, n_channels) + Binary adjacency matrix after thresholding. + """ + n = adj.shape[0] + triu_idx = np.triu_indices(n, k=1) + weights = adj[triu_idx] + k = int(np.floor(len(weights) * prop)) + + # Ensure that at least one edge is retained + k = max(k, 1) + + # Get the threshold value + thr = np.sort(weights)[-k] + + # Apply the threshold to create a binary adjacency matrix + bin_adj = np.where(adj >= thr, adj, 0.0) + + return bin_adj +# 单个脑网络 +def plot_single_network(ch_names,adj,ax=None, + node_size=20, node_color='orange',highlight_nodes=[], show_names=True, + edge_color='gray', weighted=True, + radius=1.1, figsize=(6, 6),cmap='RdYlBu_r'): + # 若 ax 未传入,则自己创建 + own_fig = False + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + own_fig = True + else: + fig = ax.figure + + # 坐标归一化 + pos3d = read_ch_pos() + all_chs_xy = np.array([pos3d[ch][:2] for ch in pos3d.keys()]) + all_chs_xy -= all_chs_xy.mean(axis=0) + all_chs_xy /= np.sqrt((all_chs_xy ** 2).sum(axis=1)).max() + xy_dict = dict(zip(pos3d.keys(), all_chs_xy)) + xy = np.array([xy_dict[ch] for ch in ch_names]) + center = xy_dict.get('CZ', np.mean(list(xy_dict.values()), axis=0)) + + # ===== 初始化绘图窗口 ===== + ax.set_aspect('equal') + ax.axis('off') + # 设置边界(与原类保持一致) + ear_radius = radius * 0.12 + nose_height = radius * 0.15 + margin_x = radius * 0.12 + 0.05 + ax.set_xlim(center[0] - radius - margin_x, center[0] + radius + margin_x) + ax.set_ylim(center[1] - radius - ear_radius, center[1] + radius + nose_height + ear_radius) + + # 绘制头部轮廓 + draw_head(ax, center=center, radius=radius) + + # 节点 + for ch in ch_names: + color = 'red' if ch in highlight_nodes else node_color + ax.scatter(*xy_dict[ch], s=node_size, color=color, edgecolor='k', zorder=4) + if show_names: + ax.text(xy_dict[ch][0], xy_dict[ch][1] + 0.03, ch, + ha='center', va='bottom', fontsize=8, zorder=5) + + # colorbar + norm = mcolors.Normalize(vmin=0, vmax=1) + color_map = matplotlib.colormaps.get_cmap(cmap) + # ========= 边 ========== + N = len(ch_names) + for i in range(N): + for j in range(i + 1, N): + w = adj[i, j] + if w > 0: + x = [xy[i, 0], xy[j, 0]] + y = [xy[i, 1], xy[j, 1]] + lw = 1.5 + if weighted: + ax.plot(x, y, + color=color_map(norm(w)), + linewidth=lw, + alpha=0.7, + zorder=3) + else: + ax.plot(x, y, + color=edge_color, + linewidth=lw, + alpha=0.7, + zorder=3) + + if own_fig: + # 不回传 添加颜色条 + sm = cm.ScalarMappable(norm=norm, cmap=color_map) + cbar = plt.colorbar(sm, ax=ax, fraction=0.035) + cbar.set_label('Connection Strength', fontsize=10) + cbar.ax.tick_params(direction='in', labelsize=10) + plt.show() + return fig + else: + + return ax +# 脑网络对比 +def plot_multiband_network(ch_names, adj_MI, adj_Rest,cmap='RdYlBu_r'): + + fig, axes = plt.subplots(1, 2, figsize=(8, 4)) + fontsize = 16 + fig.text(0.285, 0.08, 'MI', fontsize=fontsize, ha='center', va='center', rotation=0) + fig.text(0.68, 0.08, 'Rest', fontsize=fontsize, ha='center', va='center', rotation=0) + + im1 = plot_single_network(ch_names,adj_MI,ax=axes[0], show_names=True,cmap=cmap) + # Rest 行 + im2 = plot_single_network(ch_names,adj_Rest,ax=axes[1],show_names=True,cmap=cmap) + + # --- 合并 colorbar(右侧一个) --- + norm = mcolors.Normalize(vmin=0, vmax=1) + color_map = matplotlib.colormaps.get_cmap(cmap) + sm = cm.ScalarMappable(norm=norm, cmap=color_map) + cbar = plt.colorbar(sm, ax=axes.ravel().tolist(), fraction=0.02) + cbar.set_label('Connection Strength', fontsize=10) + cbar.ax.tick_params(direction='in', labelsize=10) + + # 将图像保存到内存字节流(PNG 格式) + buf = io.BytesIO() + fig.savefig(buf, format='png', dpi=300, bbox_inches='tight') + plt.close(fig) # 释放内存 + buf.seek(0) + image_bytes = buf.read() + buf.close() + + return image_bytes + +# 多个频带psd +def compute_band_psd(eeg, fs, bands, labels, trial_idx=0,MI_label=1, Rest_label=2,avg = True): + """ + eeg: (n_trials, n_channels, n_samples) + """ + n_trials, n_channels, n_samples = eeg.shape + band_names = list(bands.keys()) + n_bands = len(band_names) + + psd_MI = np.zeros((n_bands, n_channels)) + psd_Rest = np.zeros((n_bands, n_channels)) + + # 先计算所有 trial 的功率谱 + f, Pxx = welch(eeg, fs=fs, axis=-1, nperseg=fs,noverlap = fs // 2) + + + for bi, (bname, (f1, f2)) in enumerate(bands.items()): + idx = np.logical_and(f >= f1, f <= f2) + band_power = Pxx[:, :, idx].mean(axis=-1) + + band_power_flat = band_power.flatten() + power_min = band_power_flat.min() + power_max = band_power_flat.max() + if power_max - power_min > 1e-12: + band_power_norm = (band_power - power_min) / (power_max - power_min) + else: + band_power_norm = band_power + + if avg: + psd_MI[bi] = band_power_norm[labels == MI_label].mean(axis=0) + psd_Rest[bi] = band_power_norm[labels == Rest_label].mean(axis=0) + else: + psd_MI[bi] = band_power_norm[labels == MI_label][trial_idx] + psd_Rest[bi] = band_power_norm[labels == Rest_label][trial_idx] + return band_names, psd_MI, psd_Rest +# 单个脑地形图 +def plot_single_topomap(ch_names, psd_values, cmap='RdYlBu_r', vlim=(0, 1), + show_names=True, node_size=3, radius=1.1, grid_res=300, + n_contours=None, contour_color='k', + ax=None,figsize=(6,6)): + # 若 ax 未传入,则自己创建 + own_fig = False + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + own_fig = True + else: + fig = ax.figure + + # ===== 初始化绘图窗口 ===== + ax.set_aspect('equal') + ax.axis('off') + # ax.set_title("EEG topomap (MNE-like)") + + # 坐标归一化 + pos3d = read_ch_pos() + all_chs_xy = np.array([pos3d[ch][:2] for ch in pos3d.keys()]) + all_chs_xy -= all_chs_xy.mean(axis=0) + all_chs_xy /= np.sqrt((all_chs_xy ** 2).sum(axis=1)).max() + pos2d_dict = dict(zip(pos3d.keys(), all_chs_xy)) + xy = np.array([pos2d_dict[ch] for ch in ch_names]) + center = pos2d_dict.get('CZ', np.mean(list(pos2d_dict.values()), axis=0)) + + # 绘制头部轮廓 + draw_head(ax, center=center, radius=radius) + # 绘制电极 + fontsize = 4 + ax.scatter(xy[:, 0], xy[:, 1], c='k', s=node_size, zorder=5) + if show_names: + for i, ch in enumerate(ch_names): + ax.text(xy[i, 0], xy[i, 1] + 0.03, ch, + ha='center', va='bottom', fontsize=fontsize, zorder=6) + + # 数据插值 + zi, grid_x, grid_y = rbf_D_interpolate( + xy, psd_values, radius=radius, + grid_res=grid_res + ) + xmin, xmax = center[0] - radius, center[0] + radius + ymin, ymax = center[1] - radius, center[1] + radius + extent = (xmin, xmax, ymin, ymax) + im = ax.imshow(zi, extent=extent, origin='lower', + cmap=cmap, vmin=vlim[0], vmax=vlim[1], + interpolation='bicubic', zorder=0) + # 裁剪路径 + patch_ = Ellipse(center, 2 * radius, 2 * radius, clip_on=True, transform=ax.transData) + im.set_clip_path(patch_) + # 初始等高线 + linewidths = 0.5 + if n_contours is None: + cset = ax.contour(grid_x, grid_y, zi, + colors=contour_color, linewidths=linewidths, zorder=2) + else: + cset = ax.contour(grid_x, grid_y, zi, levels=n_contours, + colors=contour_color, linewidths=linewidths, zorder=2) + cset.set_clip_path(patch_) + + + + if own_fig: + # 不回传 添加颜色条 + plt.colorbar(im, ax=ax, fraction=0.035) + plt.show() + return fig + else: + # plt.colorbar(im, ax=ax, fraction=0.035) + return im +# 脑地形图对比 +def plot_multiband_topomaps(ch_names, psd_MI, psd_Rest, bands): + band_names = list(bands.keys()) # 改动 1:新增这行 + n_bands = len(band_names) + fig, axes = plt.subplots(2, n_bands, figsize=(3*n_bands, 6)) + + fontsize = 16 + + axes[0, 0].text(-0.1, 0.5, 'MI', transform=axes[0, 0].transAxes, rotation=0, va='center', ha='center', fontsize=fontsize-2) + axes[1, 0].text(-0.1, 0.5, 'Rest', transform=axes[1, 0].transAxes, rotation=0, va='center', ha='center', fontsize=fontsize-2) + + imgs = [] + for i, bname in enumerate(band_names): + axes[0, i].set_title(bname, fontsize=fontsize, pad=0) + # MI 行 + im1 = plot_single_topomap(ch_names,psd_MI[i],ax=axes[0, i], show_names=True) + # Rest 行 + im2 = plot_single_topomap(ch_names,psd_Rest[i],ax=axes[1, i],show_names=True) + imgs.append(im1) + + # --- 单个右侧合并 colorbar --- + cbar = fig.colorbar(imgs[0], ax=axes,fraction=0.02) + # cbar.set_label("PSD Power",fontsize=fontsize-4) + cbar.ax.tick_params(direction='in', labelsize=10) + + # 将图像保存到内存字节流(PNG 格式) + buf = io.BytesIO() + fig.savefig(buf, format='png', dpi=300, bbox_inches='tight') + plt.close(fig) # 释放内存 + buf.seek(0) + image_bytes = buf.read() + buf.close() + + return image_bytes + +# 小波 +def morlet_wavelet(f, fs, n_cycles=7): + """ + 创建 Morlet 小波 + f: 频率 + fs: 采样率 + """ + sigma_t = n_cycles / (2 * np.pi * f) + t = np.arange(-3*sigma_t, 3*sigma_t, 1/fs) + wavelet = (np.pi**-0.25) * np.exp(2j*np.pi*f*t) * np.exp(-(t**2)/(2*sigma_t**2)) + return wavelet + + +# 希尔伯特变换 计算ERDS 效果不佳 +def bandpass_filter(data, fs, band, order=4): + nyq = fs / 2 + b, a = butter(order, [band[0]/nyq, band[1]/nyq], btype='band') + return filtfilt(b, a, data, axis=-1) +def compute_power_hilbert(filtered_data,is_dB =True): + analytic = hilbert(filtered_data, axis=-1) + power = np.abs(analytic) ** 2 + if is_dB: + power = 10 * np.log10(power) + return power +def compute_power(data, fs=250, + bands={"mu": (8,12), "beta": (13,30)}): + """ + 返回: + power_dict[band] = (n_trials, n_ch, n_samples) + """ + power_dict = {} + for band_name, band_range in bands.items(): + filt = bandpass_filter(data, fs, band_range) + power = compute_power_hilbert(filt) + power_dict[band_name] = power + + return power_dict + +def compute_erds(power_MI, power_Rest, baseline_period=None): + """ + 计算事件相关去同步/同步 (ERDS) + + Parameters + ---------- + power_MI, power_Rest: (n_trials, n_ch, n_samples) + 功率数据,单位为 µV² 或 dB(取决于 compute_power_hilbert 的 is_dB 参数) + baseline_period: tuple (start_idx, end_idx) or None + 基线时间段索引。如果为None,使用 Rest 状态的平均值作为基线 + + 返回: + MI_erds_mean, MI_erds_sem + Rest_erds_mean, Rest_erds_sem + 所有返回值的形状为 (n_ch, n_samples) + """ + + if baseline_period is not None: + start_idx, end_idx = baseline_period + baseline = np.concatenate([power_MI[:, :, start_idx:end_idx], + power_Rest[:, :, start_idx:end_idx]], axis=0) + baseline = baseline.mean(axis=(0, 2), keepdims=True) + else: + baseline = power_Rest.mean(axis=(0,2), keepdims=True) + + # === ERDS (%) === + MI_erds = (power_MI - baseline) / baseline * 100 + Rest_erds = (power_Rest - baseline) / baseline * 100 + + return ( + MI_erds.mean(axis=0), sem(MI_erds, axis=0), + Rest_erds.mean(axis=0), sem(Rest_erds, axis=0), + ) + +def compute_all_erds(MI_power_dict, Rest_power_dict): + """ + 对多个频带同时计算 ERDS。 + + 输入: + MI_power_dict[band] = (n_trials, n_ch, n_samples) + Rest_power_dict[band] = (n_trials, n_ch, n_samples) + + 输出: + erds_MI[band] = (mean, sem) + erds_Rest[band] = (mean, sem) + """ + + erds_MI = {} + erds_Rest = {} + + for band in MI_power_dict.keys(): + MI_power = MI_power_dict[band] + Rest_power = Rest_power_dict[band] + + MI_mean, MI_sem, Rest_mean, Rest_sem = compute_erds(MI_power, Rest_power) + + erds_MI[band] = (MI_mean, MI_sem) + erds_Rest[band] = (Rest_mean, Rest_sem) + + return erds_MI, erds_Rest + +def plot_compare_erds(data_MI, data_Rest, mode="power", + ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'], + compare_names=['C3', 'CZ', 'C4'], bands=['mu', 'beta'], + fs=250, t=None, figsize=(12,6)): + + n_bands = len(bands) + n_chs = len(compare_names) + + # 自动添加单位 + if mode == "power": + # y_unit = "Power (µV²)" + y_unit = "Power (dB)" + elif mode == "erds": + y_unit = "ERDS (%)" + else: + y_unit = "" + + if t is None: + n_samples = next(iter(data_MI.values())).shape[-1] \ + if mode=="power" else next(iter(data_MI.values()))[0].shape[-1] + t = np.arange(n_samples) / fs + + fig, axes = plt.subplots(n_bands, n_chs, figsize=figsize, sharex=True, sharey=True) + + for i, band in enumerate(bands): + + # 选择数据结构 + if mode == "power": + MI_band = data_MI[band] # (trials, ch, samples) + Rest_band = data_Rest[band] + + avg_MI = MI_band.mean(axis=0) + sem_MI = MI_band.std(axis=0)/np.sqrt(MI_band.shape[0]) + + avg_Rest = Rest_band.mean(axis=0) + sem_Rest = Rest_band.std(axis=0)/np.sqrt(Rest_band.shape[0]) + + elif mode == "erds": + avg_MI, sem_MI = data_MI[band] + avg_Rest, sem_Rest = data_Rest[band] + + for j, ch in enumerate(compare_names): + ax = axes[i, j] if n_bands > 1 else axes[j] + + ch_idx = ch_names.index(ch) + + # 绘制 MI + ax.plot(t, avg_MI[ch_idx], color="C0", label="MI") + ax.fill_between(t, + avg_MI[ch_idx]-sem_MI[ch_idx], + avg_MI[ch_idx]+sem_MI[ch_idx], + alpha=0.3, color="C0") + + # 绘制 Rest + ax.plot(t, avg_Rest[ch_idx], color="C1", label="Rest") + ax.fill_between(t, + avg_Rest[ch_idx]-sem_Rest[ch_idx], + avg_Rest[ch_idx]+sem_Rest[ch_idx], + alpha=0.3, color="C1") + + if i == 0: + ax.set_title(ch) + + # ← Y 轴加单位 + if j == 0: + ax.set_ylabel(f"{band}\n{y_unit}") + + if i == n_bands - 1: + ax.set_xlabel("Time (s)") + + ax.grid(alpha=0.3) + + if i == 0 and j == n_chs - 1: + ax.legend() + + plt.tight_layout() + + # 将图像保存到内存字节流(PNG 格式) + buf = io.BytesIO() + fig.savefig(buf, format='png', dpi=300, bbox_inches='tight') + plt.close(fig) # 释放内存 + buf.seek(0) + image_bytes = buf.read() + buf.close() + + return image_bytes + +# 对比 MI vs Rest 的功率谱密度 PSD +def plot_psd_compare(MI_data, Rest_data, ch_names, compare_names=['C3', 'CZ', 'C4'], + fs=250, nperseg=None, average=True, show_sem=True, + figsize=(12, 3), save_dir=None, filename="psd.png"): + """ + 对比 MI vs Rest 的功率谱密度 PSD + + MI_data, Rest_data: (n_trials, n_ch, n_samples) + channels: 需要绘制的通道 + average: 是否对所有试次平均 + show_sem: 是否绘制 SEM 阴影 + """ + + n_trials, n_ch, n_samples = MI_data.shape + n_trials = min(len(MI_data), len(Rest_data)) + # assert Rest_data.shape == MI_data.shape, "MI 和 Rest 数据维度必须一致" + + if nperseg is None: + nperseg = fs # 每 1 秒窗长度 + + # 计算 MI PSD + psd_MI_all = [] + for trial in range(n_trials): + psd_trial = [] + for ch in range(n_ch): + f, Pxx = welch(MI_data[trial, ch], fs=fs, nperseg=nperseg) + psd_trial.append(Pxx) + psd_MI_all.append(psd_trial) + psd_MI_all = np.array(psd_MI_all) + + # 计算 Rest PSD + psd_Rest_all = [] + for trial in range(n_trials): + psd_trial = [] + for ch in range(n_ch): + _, Pxx = welch(Rest_data[trial, ch], fs=fs, nperseg=nperseg) + psd_trial.append(Pxx) + psd_Rest_all.append(psd_trial) + psd_Rest_all = np.array(psd_Rest_all) + + # ---- Plot ---- + fig, ax = plt.subplots(1, len(compare_names), figsize=figsize) + if len(compare_names) == 1: + ax = [ax] + + for i, ch in enumerate(compare_names): + ch_idx = ch_names.index(ch) + psd_MI_ch = psd_MI_all[:, ch_idx, :] + psd_Rest_ch = psd_Rest_all[:, ch_idx, :] + + if average: + mean_MI = psd_MI_ch.mean(axis=0) + mean_Rest = psd_Rest_ch.mean(axis=0) + + ax[i].plot(f, mean_MI, color='C0', label='MI') + ax[i].plot(f, mean_Rest, color='C1', label='Rest') + + if show_sem: + ax[i].fill_between(f, mean_MI - sem(psd_MI_ch, axis=0), + mean_MI + sem(psd_MI_ch, axis=0), color='C0', alpha=0.3) + ax[i].fill_between(f, mean_Rest - sem(psd_Rest_ch, axis=0), + mean_Rest + sem(psd_Rest_ch, axis=0), color='C1', alpha=0.3) + else: + ax[i].plot(f, psd_MI_ch.T, color='C0', alpha=0.3) + ax[i].plot(f, psd_Rest_ch.T, color='C1', alpha=0.3) + + ax[i].set_title(ch) + ax[i].set_xlabel("Frequency (Hz)") + ax[i].set_ylabel("PSD (μV²/Hz)") + ax[i].grid(alpha=0.3) + if i == 0: + ax[i].legend() + + plt.tight_layout() + + # 将图像保存到内存字节流(PNG 格式) + buf = io.BytesIO() + fig.savefig(buf, format='png', dpi=300, bbox_inches='tight') + plt.close(fig) # 释放内存 + buf.seek(0) + image_bytes = buf.read() + buf.close() + + return image_bytes + + +def plotMain( + ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'], + compare_names = [ 'C3','CZ','C4'], + Data = None,labels = None,MI_label = None,Rest_label = None, + fs = 250): + + trial_idx = 0 + + # 数据划分 + if not MI_label: + label_ = np.unique(labels) + else: + label_ = (MI_label,Rest_label) + MI_data = Data[labels == label_[0]] + Rest_data = Data[labels == label_[1]] + + # 典型 EEG 频带 + FREQ_BANDS = { + "Delta (0.8-4Hz)": (0.8, 4), + "Theta (4-8Hz)": (4, 8), + "Alpha (8-12Hz)": (8, 12), + "Beta (12-30Hz)": (12, 30), + "All (0.8-30Hz)": (0.8, 30) + } + # 利用welch估算PSD + band_names, psd_MI, psd_Rest= compute_band_psd( + eeg=Data, + fs=fs, + bands=FREQ_BANDS, + labels=labels, + trial_idx=trial_idx, + MI_label=MI_label, + Rest_label=Rest_label, + avg= True + ) + # 绘制地形图 + topomaps_imgBytes = plot_multiband_topomaps( + ch_names=ch_names, + psd_MI=psd_MI, + psd_Rest=psd_Rest, + bands=FREQ_BANDS + ) + + # 绘制脑网络 + mi_plv_matrix = calculate_plv(MI_data[trial_idx]) + mi_BI_matrix = threshold_proportional(mi_plv_matrix, prop=0.3) + rest_plv_matrix = calculate_plv(Rest_data[trial_idx]) + rest_BI_matrix = threshold_proportional(rest_plv_matrix, prop=0.3) + network_imgBytes = plot_multiband_network(ch_names, mi_BI_matrix, rest_BI_matrix) + + # ERDS 先计算erds,后平均 + MI_power = compute_power(MI_data) + Rest_power = compute_power(Rest_data) + erds_dict_MI, erds_dict_Rest = compute_all_erds(MI_power, Rest_power) + erds_imgBytes = plot_compare_erds(erds_dict_MI, erds_dict_Rest, ch_names=ch_names, + compare_names=compare_names, bands=['mu', 'beta'], + fs=fs, mode="erds") + + # 绘制PSD + psd_imgBytes = plot_psd_compare(MI_data, Rest_data, ch_names = ch_names, compare_names=compare_names, + fs=fs, nperseg=None, average=True, show_sem=True, + figsize=(12, 3)) + return {'topomaps_imgBytes':base64.b64encode(topomaps_imgBytes).decode(),'network_imgBytes':base64.b64encode(network_imgBytes).decode(), + 'erds_imgBytes':base64.b64encode(erds_imgBytes).decode(),'psd_imgBytes':base64.b64encode(psd_imgBytes).decode()} + +if __name__ == '__main__': + allData = np.random.uniform(-50,50,size=(80,21,1000)) + allLabel = np.random.randint(1,3,size=(80,)) + allData = allData[:len(allLabel)] + ch_names = ['FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'CP3', 'CP1', + 'CP2', 'CP4', 'P3', 'P1', 'PZ', 'P2', 'P4'] + compare_names = ['C3', 'CZ', 'C4'] + ret = plotMain(ch_names=ch_names, compare_names=compare_names, Data=allData, labels=allLabel, MI_label=1, Rest_label=2, + fs=250) + print('计算完成,开始发送') + from Zmq.zmqClient import zmqClient + + zmqClient = zmqClient('192.168.76.101', 8088) + zmqClient.connect() + zmqClient.send_to_all('miReport', ret) diff --git a/Debug_64ch_Decoder_Optimize/Zmq/zmqClient.py b/Debug_64ch_Decoder_Optimize/Zmq/zmqClient.py new file mode 100644 index 0000000..b265e71 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/Zmq/zmqClient.py @@ -0,0 +1,68 @@ +import threading +import time +import json +import zmq + + +class zmqClient: + def __init__(self, host, port): + self.host = host + self.port = port + self.client_socket = None + self.running = False + self.zmq_server = None # Reference to zmqServer for Unity communication + # 记录客户端连接前的状态 + self.state = { + 'status_code': None, + 'energy': None + } + + def set_zmq_server(self, server): + """Set the zmqServer instance to forward messages to Unity""" + self.zmq_server = server + + def connect(self): + # 创建 ZeroMQ 上下文 + self.context = zmq.Context() + # 创建 REQ 套接字(请求端) + self.client_socket = self.context.socket(zmq.DEALER) + # client_id = b'client1' + # self.client_socket.setsockopt(zmq.IDENTITY,client_id) + self.client_socket.connect(f"tcp://{self.host}:{self.port}") # 连接到服务器 + self.running = True + + def send_to_all(self, method,params): + if method in self.state.keys(): + self.state[method] = params + + # Also send to Unity via zmqServer if connected + if self.zmq_server: + self.zmq_server.broadcast_message(method, params) + + try: + if self.running and self.client_socket != None: + msg = {'method': method, 'params': params} + if method in ['single_trial_plot', 'miReport']: + print(f"{{'method': '{method}', 'params': }}") + else: + print(msg) + self.client_socket.send_multipart([b'', json.dumps(msg).encode('utf-8')]) + else: + if method in self.state.keys(): + self.state[method] = params + except ConnectionResetError: + print("Connection lost.") + self.running = False + except Exception as e: + print(f"An error occurred: {e}") + + def close_connection(self): + self.running = False + self.client_socket.close() + self.context.term() + print("Client closed explicitly.") +# 使用TCP客户端 +if __name__ == "__main__": + client = zmqClient('127.0.0.1', 8099) + client.connect() + # client.close_connection() \ No newline at end of file diff --git a/Debug_64ch_Decoder_Optimize/Zmq/zmqServer.py b/Debug_64ch_Decoder_Optimize/Zmq/zmqServer.py new file mode 100644 index 0000000..41425d4 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/Zmq/zmqServer.py @@ -0,0 +1,149 @@ +import numpy as np +import zmq +import threading +import json +import queue +from Device.SunnyLinker import SunnyLinker64 + +class zmqServer(threading.Thread): + def __init__(self, host='0.0.0.0', port=8099): + threading.Thread.__init__(self) + self.host = host + self.port = port + self.running = False + self.get_Impedance = False # 是否返回阻抗值 + self.open_Impedance = None # 是否开启阻抗检测功能 + self.StartDecode = False # false 停止解码,true=开始解码 + self.StartTrain = False # False未进入训练状态,True处于训练状态 + self.state_mode = None # 'train'为训练状态,’rest'为休息状态,'test'为测试状态 + self.currentLabel = -1 # 接收刺激端消息,了解刺激端当前的训练标签 + self.IsExitApp = False # 当socket收到2的时候,就置为True,代表要退出系统了。 + self.getReport = False # 获取训练报告内容 + self.daemon = True + # 创建 ZeroMQ 上下文 + self.context = zmq.Context() + # 创建 REP 套接字(响应端) + self.socket = self.context.socket(zmq.ROUTER) + self.socket.bind(f"tcp://{self.host}:{self.port}") # 绑定到端口 8099 + self.targetFreqs = [] + self.changeTarget = False # 更换目标频率 + self.sunnyLinker = SunnyLinker64(None, None, None, None,None) #单例模式类,已在Decoder实例化 + self.labels = [0x01, 0x02,0x03] + + self.decoder_switch = False #更换解码器 + self.decoder_class = None #解码器类别 'ssvep','ssmvep','mi' + # Client Management (e.g. Unity, Other listeners) + self.clients = set() # 维护客户端ID + self.send_queue = queue.Queue() # 发送队列,安全信箱,维护socket线程 + + def broadcast_message(self, method, params): + """Put message into queue to be sent to all connected clients""" + self.send_queue.put((method, params)) + + def run(self): + self.running = True + print(f"Server is running on {self.host}:{self.port}") + # Use Poller for non-blocking receive + poller = zmq.Poller() + poller.register(self.socket, zmq.POLLIN) + try: + while self.running: + # 1. Process Send Queue (Send to all clients) + while not self.send_queue.empty(): + method, params = self.send_queue.get() + if self.clients: + try: + msg = {'method': method, 'params': params} + msg_bytes = json.dumps(msg).encode('utf-8') + if method in ['single_trial_plot', 'single_trial_plot', 'miReport']: + print(f"{{'method': '{method}', 'params': }}") + else: + print(f"Sending message: {msg}") + # Broadcast to all maintained clients + for client_id in list(self.clients): + try: + # Send: [ID, Empty, JSON] + self.socket.send_multipart([client_id, b'', msg_bytes]) + except Exception as e: + print(f"Error sending to client {client_id}: {e}") + except Exception as e: + print(f"Error preparing broadcast: {e}") + + # 2. Process Receive (Commands) + socks = dict(poller.poll(10)) # 100ms timeout + if self.socket in socks and socks[self.socket] == zmq.POLLIN: + frames = self.socket.recv_multipart() + if len(frames) < 3: + continue + ident, _, message_bytes = frames[:3] + if ident not in self.clients: # register client ID + self.clients.add(ident) + print(f"New Client Detected: {ident}") + try: + message = json.loads(message_bytes.decode('utf-8')) + except json.JSONDecodeError: + continue + print(f"Received request: {message}") + + method = message.get("method") # process request + params = message.get("params") + + if method == "sync": + self.state_mode = 'sync' + if method == "targetFreqs": + if not isinstance(params,list): + print('targetFreqs must be a list') + continue + if params != self.targetFreqs: + self.targetFreqs = params + self.changeTarget = True + if method == "decoderClass": + if not isinstance(params,str): + print('decoderClass must be a str') + continue + if params != self.decoder_class: + self.decoder_class = params + self.decoder_switch = True + if method == "getReport": + self.getReport = True + if method == "train":#训练状态 + self.state_mode = 'train' + self.StartTrain = True + self.currentLabel = params # 当前刺激端的训练标签 + self.sunnyLinker.push_trigger(self.labels[self.currentLabel]) + elif method == "predict":#预测状态 + self.state_mode = 'predict' + if params == 1: #开始解码 + self.StartDecode = True + self.sunnyLinker.push_trigger(0x63) + elif params == 2: #停止解码 + self.IsExitApp = True + self.running = False + elif method == "rest": #休息状态 + self.state_mode = 'rest' + elif method == "impedance": + if params == 1: + self.open_Impedance = True # 开启阻抗 + self.get_Impedance = True # 返回阻抗 + elif params == 2: + self.open_Impedance = False # 关闭阻抗 + self.get_Impedance = False # 停止返回阻抗 + + except Exception as e: + print(f"An socket error occurred: {e}") + finally: + self.running = False + # 关闭套接字和上下文 + self.socket.close() + self.context.term() + print("Server socket and context closed.") + def stop(self): + """显式关闭服务器""" + self.running = False + self.socket.close() + self.context.term() + print("Server closed explicitly.") + +if __name__ == '__main__': + server = zmqServer() + server.start() \ No newline at end of file diff --git a/Debug_64ch_Decoder_Optimize/blinkdetection/algorithm/eye_detection.py b/Debug_64ch_Decoder_Optimize/blinkdetection/algorithm/eye_detection.py new file mode 100644 index 0000000..6b89ccd --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/blinkdetection/algorithm/eye_detection.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +""" +Created on Mon Sep 29 16:14:17 2025 + +@author: 23749 +""" + + +import numpy as np +from scipy.signal import butter, filtfilt + +## 1.Bandpass Filter +def butter_bandpass(lowcut, highcut, fs, order=4): + # 滤波器 + nyq = 0.5 * fs #ny:Nyquist频率,即能表示的最大有效频率 + low = lowcut / nyq + high = highcut / nyq + b, a = butter(order, [low, high], btype='band') #巴特沃斯滤波器,order=4阶 + return b, a + +def bandpass_filter(data, lowcut, highcut, fs, order=4): + b, a = butter_bandpass(lowcut, highcut, fs, order) + return filtfilt(b, a, data) + +## 2.Eye Blink Dectection +def blink_detection(F, fs, Dmin, Dmax, Emin, Emax): + """ + 波形检测 + 输入: 差分特征向量 F, 采样率 fs + 输出: b (0/1), 以及计算出的 d, e + """ + + if F is None or len(F) < 3: + return 0, None, None + + # 找最大时间(peak) & 最小时间(valley) + t_peak = np.argmax(F) + t_valley = np.argmin(F) + + # 要求 peak 在 valley 之前(符合 blink 形态),否则交换 + if t_valley < t_peak: + t_peak, t_valley = t_valley, t_peak + + # 计算持续时间 d (ms) + d = (t_valley - t_peak) * 1000.0 / fs + + # 计算能量 e (差分平方和) + e = np.sum(F[t_peak:t_valley + 1] ** 2) + + # 阈值判定 + if Dmin <= d <= Dmax and Emin <= e <= Emax: + b = 1 # 检测到眨眼 + else: + b = 0 # 否则 no blink + return b, d, e + +if __name__ == '__main__': + import matplotlib.pyplot as plt + + fs = 250 # 采样率 + t = np.arange(0, 5, 1/fs) + eog = 0.01 * np.random.randn(len(t)) # 基线+噪声 + + # 模拟眨眼(在 2.0s 注入脉冲) + center = int(2.0 * fs) + eog[center:center+5] += 0.5 + eog[center+5:center+15] -= 0.4 + + # 测试 blink_detection + F = np.diff(eog) + b, d, e = blink_detection(F, fs, 70, 500, 0.1, 10) + print(f"Detected: {b}, Duration: {d}ms, Energy: {e}") diff --git a/Debug_64ch_Decoder_Optimize/build_algorithm.spec b/Debug_64ch_Decoder_Optimize/build_algorithm.spec new file mode 100644 index 0000000..02f241c --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/build_algorithm.spec @@ -0,0 +1,98 @@ +# -*- mode: python ; coding: utf-8 -*- + +import sys +import os +from PyInstaller.utils.hooks import collect_submodules, collect_data_files + +# ======================================================== +# 1. 工程配置区 (Project Config) +# ======================================================== +block_cipher = None +ENTRY_POINT = 'runDecoder.py' +APP_NAME = 'runDecoder' + +# ======================================================== +# 2. 依赖分析 (Dependency Analysis) +# ======================================================== +# 收集 sklearn, scipy 可能遗漏的隐藏导入 +hidden_imports = [ + 'sklearn.utils._cython_blas', + 'sklearn.neighbors.typedefs', + 'sklearn.neighbors.quad_tree', + 'sklearn.tree', + 'sklearn.tree._utils', + 'einops', # 必须显式添加 +] + +# 收集 torch 相关的隐式导入 +hidden_imports += ['torch', 'torchvision'] + +# 收集 pandas 相关的隐式导入 +hidden_imports += ['pandas'] + +# ======================================================== +# 3. 资源锚定 (Data Anchoring) +# ======================================================== +# Analysis 中的 datas 用于将文件嵌入到内部 +datas = [] + +# ======================================================== +# 4. 构建流程 (Build Process) +# ======================================================== +a = Analysis( + [ENTRY_POINT], + pathex=[], + binaries=[], + datas=datas, + hiddenimports=hidden_imports, + hookspath=[], + hooksconfig={}, + runtime_hooks=['rthook.py'], # 添加运行时钩子,处理路径和多进程 + excludes=['tkinter', 'PyQt5', 'PySide2', 'IPython', 'notebook'], # 排除 GUI 和交互式库减小体积 + win_no_prefer_redirects=False, + win_private_assemblies=False, + cipher=block_cipher, + noarchive=False, +) + +pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) + +exe = EXE( + pyz, + a.scripts, + [], + exclude_binaries=True, + name=APP_NAME, + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=False, + console=True, # 保持 True 以便查看日志,部署时可改为 False + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, +) + +# ======================================================== +# 5. 打包模式: OneDir (单文件夹) + 资源旁路 +# ======================================================== +# 使用 Tree 将文件夹原样复制到 dist/APP_NAME/ 下 +# 格式: Tree('源路径', prefix='目标子目录') + +coll = COLLECT( + exe, + a.binaries, + a.zipfiles, + a.datas, + # 显式复制资源文件夹到 exe 同级目录 + Tree('online_Models', prefix='online_Models', excludes=['*.pyc']), + Tree('Tools', prefix='Tools', excludes=['*.pyc']), + # config.ini 作为单独文件 + [('config.ini', 'config.ini', 'DATA')], + strip=False, + upx=False, + upx_exclude=[], + name=APP_NAME, +) diff --git a/Debug_64ch_Decoder_Optimize/build_with_copy.py b/Debug_64ch_Decoder_Optimize/build_with_copy.py new file mode 100644 index 0000000..e8f86d6 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/build_with_copy.py @@ -0,0 +1,88 @@ +import os +import shutil +import subprocess +import sys + +def main(): + # 1. 定义路径 + BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + DIST_DIR = os.path.join(BASE_DIR, 'dist') + APP_NAME = 'runDecoder' + TARGET_DIR = os.path.join(DIST_DIR, APP_NAME) + + # 定义需要复制的资源 {源路径: 目标子路径} + # 目标子路径相对于 TARGET_DIR + RESOURCES = { + 'config.ini': 'config.ini', + 'online_Models': 'online_Models', + 'Tools': 'Tools', + } + + # 2. 清理旧构建 + print("[1/3] Cleaning up old builds...") + if os.path.exists(DIST_DIR): + try: + shutil.rmtree(DIST_DIR) + print(" Cleaned dist/") + except Exception as e: + print(f" Warning: Could not clean dist/: {e}") + + BUILD_DIR = os.path.join(BASE_DIR, 'build') + if os.path.exists(BUILD_DIR): + try: + shutil.rmtree(BUILD_DIR) + print(" Cleaned build/") + except Exception as e: + print(f" Warning: Could not clean build/: {e}") + + # 3. 运行 PyInstaller + print("[2/3] Running PyInstaller...") + # 注意:我们这里不传 --noupx,因为已经在 spec 文件里把 upx=False 写死了 + cmd = [ + "pyinstaller", + "build_algorithm.spec", + "--clean" + ] + + try: + subprocess.check_call(cmd, shell=True) + except subprocess.CalledProcessError: + print("Error: PyInstaller failed.") + sys.exit(1) + + # 4. 复制外部资源文件夹 + print("[3/3] Verifying and Copying external resources...") + + for src_name, dst_name in RESOURCES.items(): + src_path = os.path.join(BASE_DIR, src_name) + dst_path = os.path.join(TARGET_DIR, dst_name) + + if os.path.exists(src_path): + if os.path.isfile(src_path): + # 如果是文件 + try: + shutil.copy2(src_path, dst_path) + print(f" Copied file: {src_name} -> {dst_name}") + except Exception as e: + print(f" Error copying file {src_name}: {e}") + else: + # 如果是文件夹 + if os.path.exists(dst_path): + try: + shutil.rmtree(dst_path) # 先删除 spec 生成的旧文件夹 (如果有) + except Exception as e: + print(f" Warning: Could not remove existing dir {dst_path}: {e}") + try: + shutil.copytree(src_path, dst_path, ignore=shutil.ignore_patterns('*.pyc', '__pycache__')) + print(f" Copied dir: {src_name} -> {dst_name}") + except Exception as e: + print(f" Error copying dir {src_name}: {e}") + else: + print(f" Warning: Source resource not found at {src_path}") + + print("\n" + "="*50) + print(f"SUCCESS! Build artifacts are in: {TARGET_DIR}") + print("="*50) + +if __name__ == "__main__": + main() diff --git a/Debug_64ch_Decoder_Optimize/concentration/algorithm/calculate_focus.py b/Debug_64ch_Decoder_Optimize/concentration/algorithm/calculate_focus.py new file mode 100644 index 0000000..b042398 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/concentration/algorithm/calculate_focus.py @@ -0,0 +1,396 @@ +import numpy as np +from scipy.signal import welch +from scipy.fft import fft +from scipy import signal +from collections import deque +import time +import os +# import logging +import base64 +import io + +# logger = logging.getLogger(__name__) +# +# try: +# import matplotlib +# matplotlib.use('Agg') +# import matplotlib.pyplot as plt +# MATPLOTLIB_AVAILABLE = True +# except ImportError: +# MATPLOTLIB_AVAILABLE = False +# logger.warning("matplotlib未安装,报告图表功能不可用") + + +class Calculate(): + def __init__(self, Threshold_value_low, Threshold_value_high, fs=250, win_len=10): + self.Threshold_value_low = Threshold_value_low + self.Threshold_value_high = Threshold_value_high + self.fs = fs + self.focus_result = [] + self.CLI_result = [] + self.EVI_result = [] + self.eegQueue = deque(maxlen=win_len) + + # # 存储历史数据用于绘图 + # self.beta_history = [] + # self.alpha_history = [] + # self.theta_history = [] + # self.focus_history = [] + # self.timestamp_history = [] + # + # # 记录开始时间 + # self.start_time = None + # self.recording = False + # + # # 图表保存路径 + # self.chart_dir = "reports" + # if not os.path.exists(self.chart_dir): + # os.makedirs(self.chart_dir) + # print(f"[调试] 创建目录: {self.chart_dir}") + + # 初始化滤波器 + self.b_notch, self.a_notch = signal.iirnotch(50 / (self.fs/2), 30) + self.b_design = signal.firwin(65, [2 / (self.fs/2), 40 / (self.fs/2)], pass_zero=False) + + print("[调试] Calculate 类初始化完成") + + def calculate_focus(self, beta, alpha, theta): + """ + 专注度计算 - 固定映射版本 + """ + # 原始比值 + raw = beta / (alpha + theta + 1e-10) + + # Sigmoid 映射:让 raw 在 0.3-1.5 区间敏感 + # 参数可调: + # k = 12 (斜率,越大越陡) + # x0 = 0.6 (中心点,raw=0.6时focus≈50) + k = 12.0 + x0 = 0.6 + focus = 100.0 / (1.0 + np.exp(-k * (raw - x0))) + + # 可选:添加滑动平均平滑 + return int(focus) + + def calculate_all(self, data, fs, nperseg=1000): + mean_x = np.mean(data, axis=-1, keepdims=True) + data = data - mean_x + freqs, psd = self.compute_psd_multichannel(data, fs, nperseg) + beta_psd = np.sum(self.band_psd(freqs, psd, (13, 30))) + alpha_psd = np.sum(self.band_psd(freqs, psd, (8, 13))) + theta_psd = np.sum(self.band_psd(freqs, psd, (4, 8))) + + print(f"[功率] β={beta_psd:.2f} | α={alpha_psd:.2f} | θ={theta_psd:.2f}") + + + focus_score = self.calculate_focus(beta_psd, alpha_psd, theta_psd) + focus_score = max(0, min(100, focus_score)) + + self.focus_result.append(focus_score) + if len(self.focus_result) > 3: + self.focus_result.pop(0) + final_focus = int(self.simple_moving_average(self.focus_result, window_size=5)) + + cli_denom = alpha_psd + beta_psd + CLI_score = np.log(theta_psd / (cli_denom + 1e-10)) if cli_denom > 0 else 0 + self.CLI_result.append(CLI_score) + if len(self.CLI_result) > 5: + self.CLI_result.pop(0) + final_CLI = round(self.simple_moving_average(self.CLI_result, window_size=5), 2) + + return final_focus, final_CLI, beta_psd, alpha_psd, theta_psd + + def compute_psd_multichannel(self, data, fs=250, nperseg=1000): + n_samples = data.shape[-1] + if n_samples < nperseg: + nperseg = n_samples + + noverlap = 500 + if noverlap >= nperseg: + noverlap = int(nperseg / 2) + + if nperseg == 0: + return np.array([]), np.zeros((data.shape[0], 0)) + + freqs, psd = welch(data, fs=fs, nperseg=nperseg, noverlap=noverlap, axis=-1) + return freqs, psd + + def band_psd(self, freqs, psd, band): + idx = np.logical_and(freqs >= band[0], freqs <= band[1]) + return np.sum(psd[:, idx], axis=-1) + + def simple_moving_average(self, data, window_size=5): + if len(data) == 0: + return 30 + window = data[-window_size:] + return sum(window) / len(window) + + def reset_queue(self): + self.eegQueue.clear() + + # def start_recording(self): + # """开始记录数据""" + # self.recording = True + # self.start_time = time.time() + # self.beta_history = [] + # self.alpha_history = [] + # self.theta_history = [] + # self.focus_history = [] + # self.timestamp_history = [] + # print("[调试] ========== 开始记录专注度数据 ==========") + + # def stop_recording(self): + # """停止记录并生成图表""" + # print(f"[调试] stop_recording被调用, recording={self.recording}, focus_history长度={len(self.focus_history)}") + # self.recording = False + # if len(self.focus_history) > 0: + # print("[调试] 数据非空,开始生成图表...") + # # 保存到本地文件 + # chart_path = self.save_chart_to_file() + # if chart_path: + # print(f"[调试] 本地文件保存成功: {chart_path}") + # else: + # print("[调试] 本地文件保存失败") + # # 生成base64编码 + # base64_data = self.generate_chart_base64() + # return base64_data + # else: + # print("[调试] 没有数据可保存,focus_history为空") + # return None + + # def add_data_point(self, focus, beta, alpha, theta): + # if not self.recording: + # return + # current_time = time.time() + # elapsed = current_time - self.start_time + # + # self.beta_history.append(beta) + # self.alpha_history.append(alpha) + # self.theta_history.append(theta) + # self.focus_history.append(focus) + # self.timestamp_history.append(elapsed) + # print(f"[调试] 记录数据点: time={elapsed:.1f}s, focus={focus}, beta={beta:.2f}") + + # def save_chart_to_file(self): + # """ + # 保存图表到本地文件(唯一实现) + # """ + # print(f"[调试] save_chart_to_file被调用, MATPLOTLIB_AVAILABLE={MATPLOTLIB_AVAILABLE}") + # + # if not MATPLOTLIB_AVAILABLE: + # print("[调试] matplotlib不可用,无法保存") + # return None + # + # if len(self.focus_history) < 2: + # print(f"[调试] 数据点不足,需要至少2个点,当前{len(self.focus_history)}个点") + # return None + # + # print(f"[调试] 开始保存图表到本地文件...") + # + # # 确保所有列表长度一致 + # min_len = min(len(self.beta_history), len(self.alpha_history), + # len(self.theta_history), len(self.focus_history), + # len(self.timestamp_history)) + # + # print(f"[调试] 数据长度: min_len={min_len}") + # + # beta_list = self.beta_history[:min_len] + # alpha_list = self.alpha_history[:min_len] + # theta_list = self.theta_history[:min_len] + # focus_list = self.focus_history[:min_len] + # times = self.timestamp_history[:min_len] + # + # # 生成文件名 + # timestamp = time.strftime("%Y%m%d_%H%M%S") + # chart_path = os.path.join(self.chart_dir, f"concentration_report_{timestamp}.png") + # print(f"[调试] 保存路径: {chart_path}") + # + # try: + # # 创建图表 + # fig, ax1 = plt.subplots(figsize=(14, 8)) + # + # # 左Y轴:功率数据 + # ax1.plot(times, beta_list, 'b-', linewidth=1.5, alpha=0.8, label='Beta Power') + # ax1.plot(times, alpha_list, 'g-', linewidth=1.5, alpha=0.8, label='Alpha Power') + # ax1.plot(times, theta_list, 'orange', linewidth=1.5, alpha=0.8, label='Theta Power') + # ax1.set_xlabel('Time (s)', fontsize=12) + # ax1.set_ylabel('Band Power', fontsize=12, color='black') + # ax1.tick_params(axis='y', labelcolor='black') + # ax1.legend(loc='upper left') + # ax1.grid(True, alpha=0.3) + # + # # 右Y轴:专注度 + # ax2 = ax1.twinx() + # ax2.plot(times, focus_list, 'r-', linewidth=2, alpha=0.9, label='Focus (%)') + # ax2.set_ylabel('Focus (%)', fontsize=12, color='red') + # ax2.tick_params(axis='y', labelcolor='red') + # ax2.set_ylim(0, 105) + # ax2.legend(loc='upper right') + # + # # 标题 + # duration = times[-1] if times else 0 + # avg_focus = np.mean(focus_list) if focus_list else 0 + # plt.title(f'Concentration and EEG Band Power Trend\nDuration: {duration:.1f}s, Avg Focus: {avg_focus:.1f}%', + # fontsize=14) + # + # plt.tight_layout() + # plt.savefig(chart_path, dpi=150, bbox_inches='tight') + # plt.close() + # + # print(f"\n{'='*60}") + # print(f"专注度报告图片已保存到本地:") + # print(f" 文件路径: {chart_path}") + # print(f" 数据点数: {min_len}") + # print(f" 时长: {duration:.1f}秒") + # print(f" 平均专注度: {avg_focus:.1f}%") + # print(f"{'='*60}\n") + # + # return chart_path + # + # except Exception as e: + # print(f"[调试] 保存文件时出错: {e}") + # import traceback + # traceback.print_exc() + # return None + # + # def generate_chart_base64(self): + # """ + # 生成图表的base64编码(用于网络传输) + # """ + # if not MATPLOTLIB_AVAILABLE: + # return None + # + # if len(self.focus_history) < 2: + # return None + # + # min_len = min(len(self.beta_history), len(self.alpha_history), + # len(self.theta_history), len(self.focus_history), + # len(self.timestamp_history)) + # + # beta_list = self.beta_history[:min_len] + # alpha_list = self.alpha_history[:min_len] + # theta_list = self.theta_history[:min_len] + # focus_list = self.focus_history[:min_len] + # times = self.timestamp_history[:min_len] + # + # fig, ax1 = plt.subplots(figsize=(14, 8)) + # + # ax1.plot(times, beta_list, 'b-', linewidth=1.5, alpha=0.8, label='Beta Power') + # ax1.plot(times, alpha_list, 'g-', linewidth=1.5, alpha=0.8, label='Alpha Power') + # ax1.plot(times, theta_list, 'orange', linewidth=1.5, alpha=0.8, label='Theta Power') + # ax1.set_xlabel('Time (s)', fontsize=12) + # ax1.set_ylabel('Band Power', fontsize=12, color='black') + # ax1.tick_params(axis='y', labelcolor='black') + # ax1.legend(loc='upper left') + # ax1.grid(True, alpha=0.3) + # + # ax2 = ax1.twinx() + # ax2.plot(times, focus_list, 'r-', linewidth=2, alpha=0.9, label='Focus (%)') + # ax2.set_ylabel('Focus (%)', fontsize=12, color='red') + # ax2.tick_params(axis='y', labelcolor='red') + # ax2.set_ylim(0, 105) + # ax2.legend(loc='upper right') + # + # duration = times[-1] if times else 0 + # avg_focus = np.mean(focus_list) if focus_list else 0 + # plt.title(f'Concentration and EEG Band Power Trend\nDuration: {duration:.1f}s, Avg Focus: {avg_focus:.1f}%', + # fontsize=14) + # + # plt.tight_layout() + # + # buffer = io.BytesIO() + # plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight') + # buffer.seek(0) + # image_base64 = base64.b64encode(buffer.read()).decode('utf-8') + # plt.close() + # + # return image_base64 + + def queueOpt(self, data): + if data is None or data.size == 0: + return None + if len(self.eegQueue) < self.eegQueue.maxlen: + self.eegQueue.append(data) + else: + self.eegQueue.append(data) + + if len(self.eegQueue) == self.eegQueue.maxlen: + eegData = np.hstack([self.eegQueue[i] for i in range(len(self.eegQueue))]) + if eegData.size == 0: + return None + eegData -= np.mean(eegData, axis=-1, keepdims=True) + eegData = signal.lfilter(self.b_notch, self.a_notch, eegData) + eegData = signal.lfilter(self.b_design, 1, eegData) + focus_score, CLI_score, beta, alpha, theta = self.calculate_all(eegData, fs=self.fs, nperseg=1000) + + # self.add_data_point(focus_score, beta, alpha, theta) + + return focus_score + return None + + +class Calculate2(): + def __init__(self, Threshold_value_low, Threshold_value_high): + self.Threshold_value_low = Threshold_value_low + self.Threshold_value_high = Threshold_value_high + self.focus_result = [] + self.theta_result = [] + self.alpha_result = [] + self.flow_result = [] + + def calculate_all(self, data, fs, L=2500): + mean_x = np.mean(data, axis=-1, keepdims=True) + data = data - mean_x + + Y = fft(data, axis=-1) + P2 = np.abs(Y / L) + P1 = P2[:, :L // 2 + 1] + P1[:, 1:-1] = 2 * P1[:, 1:-1] + + beta_power = self.PSD(P1, L, fs, 13, 30) + alpha_power = self.PSD(P1, L, fs, 8, 13) + theta_power = self.PSD(P1, L, fs, 4, 8) + gamma_power = self.PSD(P1, L, fs, 30, 100) + + focus_score = beta_power / (alpha_power + theta_power) + print('focus score:', focus_score) + focus_score = ((focus_score - self.Threshold_value_low) * 100) / (self.Threshold_value_high - self.Threshold_value_low) + self.focus_result.append(focus_score) + if len(self.focus_result) > 3: + self.focus_result.pop(0) + final_focus = int(self.simple_moving_average(self.focus_result, window_size=3)) + + self.theta_result.append(theta_power) + if len(self.theta_result) > 30: + self.theta_result.pop(0) + self.alpha_result.append(alpha_power) + if len(self.alpha_result) > 30: + self.alpha_result.pop(0) + rest_theta = self.simple_moving_average(self.theta_result, window_size=30) + rest_alpha = self.simple_moving_average(self.alpha_result, window_size=30) + distraction_score = (theta_power / rest_theta) * (1 - (alpha_power / rest_alpha)) + + flow_score = gamma_power / beta_power + flow_score = (flow_score / self.Threshold_value_high) * 100 + self.flow_result.append(flow_score) + if len(self.flow_result) > 3: + self.flow_result.pop(0) + final_flow = int(self.simple_moving_average(self.flow_result, window_size=3)) + + return final_focus, distraction_score, final_flow + + def PSD(self, P1, L, Fs, s_freq, e_freq): + s_point = round(s_freq * L / Fs) + e_point = round(e_freq * L / Fs) + x, y = P1.shape + band_PSD = 0 + for i in range(x): + for j in range(s_point, e_point): + band_PSD += P1[i, j] ** 2 + return band_PSD + + def simple_moving_average(self, data, window_size=3): + if len(data) == 0: + return [] + window = data[-window_size:] + return sum(window) / len(window) \ No newline at end of file diff --git a/Debug_64ch_Decoder_Optimize/config.ini b/Debug_64ch_Decoder_Optimize/config.ini new file mode 100644 index 0000000..4fc2c56 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/config.ini @@ -0,0 +1,161 @@ +[system] +SSVEP_ThresholdValue = [1,-0.023] +;SSVEP_ThresholdValue = [2,-0.00200] +SSMVEP_IntervalEpoch = [0.2,2.2] +list_freqs = [8, 9] +phase = [0, 0] +concentration_ThresholdValue = [0.1, 0.8] +MI_IntervalEpoch = [0.5,4.5] +blink = [70,500,100,500,800,3,2] +Right_rehabilitation = 5 +Fault_rehabilitation = 5 +Num_blocks = 1 +Num_trials = 10 +Audio_device = 0 +Rest_time = 2 +Device_type = 1 +Device_Host = 127.0.0.1 +Device_Port = 5086 +Upper_Host = 127.0.0.1 +Upper_Port = 8088 +Serial_port = COM44 + + +[Layout] +main_splitter_left = 993 +main_splitter_right = 922 +right_splitter_left = 233 +right_splitter_right = 771 +left_splitter_left = 503 +left_splitter_right = 501q + +[channel] +channel_x_fp1 = 419 +channel_y_fp1 = 124 +channel_x_fc1 = 439 +channel_y_fc1 = 296 +channel_x_fp2 = 576 +channel_y_fp2 = 124 +channel_x_fc2 = 556 +channel_y_fc2 = 299 +channel_x_f3 = 397 +channel_y_f3 = 231 +channel_x_cp1 = 439 +channel_y_cp1 = 426 +channel_x_f4 = 601 +channel_y_f4 = 232 +channel_x_cp2 = 559 +channel_y_cp2 = 425 +channel_x_fc3 = 379 +channel_y_fc3 = 295 +channel_x_af4 = 571 +channel_y_af4 = 171 +channel_x_po8 = 645 +channel_y_po8 = 564 +channel_x_fpz = 499 +channel_y_fpz = 112 +channel_x_fcz = 499 +channel_y_fcz = 300 +channel_x_poz = 500 +channel_y_poz = 554 +channel_x_po5 = 387 +channel_y_po5 = 551 +channel_x_po6 = 611 +channel_y_po6 = 551 +channel_x_c3 = 373 +channel_y_c3 = 363 +channel_x_fc5 = 319 +channel_y_fc5 = 292 +channel_x_c4 = 620 +channel_y_c4 = 363 +channel_x_fc6 = 676 +channel_y_fc6 = 288 +channel_x_p3 = 398 +channel_y_p3 = 491 +channel_x_cp5 = 322 +channel_y_cp5 = 430 +channel_x_p4 = 600 +channel_y_p4 = 489 +channel_x_cp6 = 678 +channel_y_cp6 = 430 +channel_x_c5 = 313 +channel_y_c5 = 361 +channel_x_f6 = 650 +channel_y_f6 = 223 +channel_x_f5 = 349 +channel_y_f5 = 224 +channel_x_po4 = 573 +channel_y_po4 = 551 +channel_x_po3 = 429 +channel_y_po3 = 550 +channel_x_cp4 = 619 +channel_y_cp4 = 424 +channel_x_cp3 = 381 +channel_y_cp3 = 426 +channel_x_fc4 = 619 +channel_y_fc4 = 295 +channel_x_o1 = 423 +channel_y_o1 = 598 +channel_x_ft9 = 252 +channel_y_ft9 = 168 +channel_x_o2 = 576 +channel_y_o2 = 597 +channel_x_ft10 = 798 +channel_y_ft10 = 277 +channel_x_f7 = 295 +channel_y_f7 = 214 +channel_x_tp9 = 202 +channel_y_tp9 = 445 +channel_x_f8 = 701 +channel_y_f8 = 215 +channel_x_t7 = 252 +channel_y_t7 = 362 +channel_x_tp7 = 261 +channel_y_tp7 = 436 +channel_x_ft8 = 734 +channel_y_ft8 = 283 +channel_x_ft7 = 264 +channel_y_ft7 = 286 +channel_x_af8 = 645 +channel_y_af8 = 159 +channel_x_af7 = 351 +channel_y_af7 = 160 +channel_x_p6 = 652 +channel_y_p6 = 499 +channel_x_p5 = 348 +channel_y_p5 = 499 +channel_x_c6 = 683 +channel_y_c6 = 362 +channel_x_f1 = 447 +channel_y_f1 = 236 +channel_x_t8 = 745 +channel_y_t8 = 361 +channel_x_f2 = 549 +channel_y_f2 = 235 +channel_x_p7 = 300 +channel_y_p7 = 505 +channel_x_c1 = 435 +channel_y_c1 = 363 +channel_x_p8 = 698 +channel_y_p8 = 508 +channel_x_c2 = 559 +channel_y_c2 = 359 +channel_x_fz = 499 +channel_y_fz = 238 +channel_x_po7 = 354 +channel_y_po7 = 562 +channel_x_tp8 = 735 +channel_y_tp8 = 438 +channel_x_oz = 498 +channel_y_oz = 609 +channel_x_af3 = 428 +channel_y_af3 = 170 +channel_x_pz = 501 +channel_y_pz = 486 +channel_x_p2 = 551 +channel_y_p2 = 483 +channel_x_cz = 499 +channel_y_cz = 361 +channel_x_p1 = 448 +channel_y_p1 = 488 + diff --git a/Debug_64ch_Decoder_Optimize/online_Models/Model_2025-11-15-11-11-50.pth b/Debug_64ch_Decoder_Optimize/online_Models/Model_2025-11-15-11-11-50.pth new file mode 100644 index 0000000..5df5491 Binary files /dev/null and b/Debug_64ch_Decoder_Optimize/online_Models/Model_2025-11-15-11-11-50.pth differ diff --git a/Debug_64ch_Decoder_Optimize/online_Models/Model_2025-11-17-16-55-25.pth b/Debug_64ch_Decoder_Optimize/online_Models/Model_2025-11-17-16-55-25.pth new file mode 100644 index 0000000..efe3af6 Binary files /dev/null and b/Debug_64ch_Decoder_Optimize/online_Models/Model_2025-11-17-16-55-25.pth differ diff --git a/Debug_64ch_Decoder_Optimize/online_Models/Model_2025-11-18-10-15-35.pth b/Debug_64ch_Decoder_Optimize/online_Models/Model_2025-11-18-10-15-35.pth new file mode 100644 index 0000000..b527c6b Binary files /dev/null and b/Debug_64ch_Decoder_Optimize/online_Models/Model_2025-11-18-10-15-35.pth differ diff --git a/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-01-08-14-55-10.pth b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-01-08-14-55-10.pth new file mode 100644 index 0000000..a713669 Binary files /dev/null and b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-01-08-14-55-10.pth differ diff --git a/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-14-21-38.pth b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-14-21-38.pth new file mode 100644 index 0000000..7b4dd25 Binary files /dev/null and b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-14-21-38.pth differ diff --git a/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-15-26-09.pth b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-15-26-09.pth new file mode 100644 index 0000000..293ce95 Binary files /dev/null and b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-15-26-09.pth differ diff --git a/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-15-44-21.pth b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-15-44-21.pth new file mode 100644 index 0000000..801adc3 Binary files /dev/null and b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-15-44-21.pth differ diff --git a/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-16-06-24.pth b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-16-06-24.pth new file mode 100644 index 0000000..3198b66 Binary files /dev/null and b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-16-06-24.pth differ diff --git a/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-16-30-12.pth b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-16-30-12.pth new file mode 100644 index 0000000..fb55250 Binary files /dev/null and b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-16-30-12.pth differ diff --git a/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-16-44-52.pth b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-16-44-52.pth new file mode 100644 index 0000000..08350b8 Binary files /dev/null and b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-26-16-44-52.pth differ diff --git a/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-30-13-08-50.pth b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-30-13-08-50.pth new file mode 100644 index 0000000..a86f47f Binary files /dev/null and b/Debug_64ch_Decoder_Optimize/online_Models/Model_2026-05-30-13-08-50.pth differ diff --git a/Debug_64ch_Decoder_Optimize/online_Models/log_result.txt b/Debug_64ch_Decoder_Optimize/online_Models/log_result.txt new file mode 100644 index 0000000..9dae636 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/online_Models/log_result.txt @@ -0,0 +1,252 @@ +0 0.5 +1 0.5 +2 0.375 +3 0.5 +4 0.4375 +5 0.375 +6 0.5 +7 0.5 +8 0.375 +9 0.375 +10 0.375 +11 0.375 +12 0.5 +13 0.5625 +14 0.5625 +15 0.5 +16 0.5 +17 0.5 +18 0.5 +19 0.5625 +20 0.4375 +21 0.5 +22 0.5 +23 0.375 +24 0.375 +25 0.375 +26 0.375 +27 0.375 +28 0.3125 +29 0.375 +30 0.5625 +31 0.5 +32 0.5 +33 0.5625 +34 0.5625 +35 0.3125 +36 0.3125 +37 0.3125 +38 0.375 +39 0.5625 +40 0.3125 +41 0.5625 +42 0.3125 +43 0.375 +44 0.5625 +45 0.5 +46 0.375 +47 0.375 +48 0.3125 +49 0.375 +50 0.375 +51 0.5 +52 0.5625 +53 0.375 +54 0.5625 +55 0.5625 +56 0.375 +57 0.375 +58 0.375 +59 0.5 +60 0.3125 +61 0.375 +62 0.375 +63 0.375 +64 0.375 +65 0.375 +66 0.3125 +67 0.375 +68 0.5625 +69 0.5625 +70 0.5625 +71 0.5 +72 0.5625 +73 0.375 +74 0.375 +75 0.375 +76 0.375 +77 0.375 +78 0.5 +79 0.375 +80 0.375 +81 0.5 +82 0.375 +83 0.375 +84 0.375 +85 0.375 +86 0.3125 +87 0.375 +88 0.375 +89 0.5 +90 0.375 +91 0.4375 +92 0.3125 +93 0.3125 +94 0.375 +95 0.375 +96 0.375 +97 0.375 +98 0.3125 +99 0.4375 +100 0.375 +101 0.375 +102 0.375 +103 0.3125 +104 0.5625 +105 0.5 +106 0.5625 +107 0.5625 +108 0.5 +109 0.3125 +110 0.5625 +111 0.5625 +112 0.5 +113 0.3125 +114 0.5 +115 0.3125 +116 0.375 +117 0.3125 +118 0.3125 +119 0.3125 +120 0.3125 +121 0.375 +122 0.375 +123 0.375 +124 0.375 +125 0.3125 +126 0.375 +127 0.375 +128 0.375 +129 0.375 +130 0.5625 +131 0.375 +132 0.5 +133 0.3125 +134 0.3125 +135 0.3125 +136 0.375 +137 0.5 +138 0.3125 +139 0.375 +140 0.3125 +141 0.3125 +142 0.3125 +143 0.5625 +144 0.3125 +145 0.375 +146 0.5 +147 0.5 +148 0.375 +149 0.4375 +150 0.5 +151 0.3125 +152 0.375 +153 0.375 +154 0.375 +155 0.3125 +156 0.375 +157 0.4375 +158 0.4375 +159 0.375 +160 0.375 +161 0.3125 +162 0.375 +163 0.375 +164 0.375 +165 0.3125 +166 0.3125 +167 0.3125 +168 0.375 +169 0.3125 +170 0.3125 +171 0.3125 +172 0.375 +173 0.3125 +174 0.3125 +175 0.5 +176 0.3125 +177 0.375 +178 0.375 +179 0.3125 +180 0.3125 +181 0.3125 +182 0.3125 +183 0.5625 +184 0.5625 +185 0.3125 +186 0.5 +187 0.5 +188 0.5625 +189 0.5 +190 0.5625 +191 0.5625 +192 0.5625 +193 0.5 +194 0.5 +195 0.5625 +196 0.5625 +197 0.5625 +198 0.5625 +199 0.5 +200 0.5625 +201 0.5625 +202 0.375 +203 0.375 +204 0.375 +205 0.375 +206 0.375 +207 0.5 +208 0.5 +209 0.5625 +210 0.5625 +211 0.5625 +212 0.3125 +213 0.5 +214 0.5 +215 0.5625 +216 0.5 +217 0.5 +218 0.5 +219 0.5625 +220 0.5 +221 0.4375 +222 0.5 +223 0.5 +224 0.4375 +225 0.5 +226 0.4375 +227 0.5 +228 0.5 +229 0.375 +230 0.375 +231 0.3125 +232 0.375 +233 0.375 +234 0.375 +235 0.5625 +236 0.5625 +237 0.5625 +238 0.5625 +239 0.5625 +240 0.5 +241 0.5 +242 0.5 +243 0.5625 +244 0.5625 +245 0.375 +246 0.375 +247 0.375 +248 0.3125 +249 0.375 +The average accuracy is: 0.42675 +The best accuracy is: 0.5625 diff --git a/Debug_64ch_Decoder_Optimize/rthook.py b/Debug_64ch_Decoder_Optimize/rthook.py new file mode 100644 index 0000000..3a5424b --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/rthook.py @@ -0,0 +1,13 @@ +import sys +import os +import multiprocessing + +# 1. 路径自适应:在 Frozen 模式下,将当前工作目录切换到可执行文件所在目录 +# 这样代码中使用的相对路径(如 './config.ini')就能正确指向 exe 旁边的文件 +if getattr(sys, 'frozen', False): + os.chdir(os.path.dirname(sys.executable)) + +# 2. 多进程保护:防止 Windows 下的无限递归炸弹 +# Windows 下 multiprocessing 需要 freeze_support() +if sys.platform.startswith('win'): + multiprocessing.freeze_support() diff --git a/Debug_64ch_Decoder_Optimize/runDecoder.py b/Debug_64ch_Decoder_Optimize/runDecoder.py new file mode 100644 index 0000000..9199d14 --- /dev/null +++ b/Debug_64ch_Decoder_Optimize/runDecoder.py @@ -0,0 +1,35 @@ +import matplotlib +matplotlib.use('Agg') +import argparse +import sys +import time +from Decoder import Decoder_main +from PubLibrary.RunOnce import is_program_running + +if __name__ == "__main__": + if not is_program_running(): + # 解析命令行参数 + parser = argparse.ArgumentParser(description="EEG Decoder Application") + parser.add_argument('-dt', '--device-type', type=int, default=None, help="Device Type") + parser.add_argument('-dh', '--device-host', type=str, default=None, help="Device Host IP") + parser.add_argument('-dp', '--device-port', type=int, default=None, help="Device Port") + parser.add_argument('-uh', '--upper-host', type=str, default=None, help="Upper Computer Host IP") + parser.add_argument('-up', '--upper-port', type=int, default=None, help="Upper Computer Port") + + args = parser.parse_args() + + decoder = Decoder_main() + decoder.connect( + device_type=args.device_type, + device_host=args.device_host, + device_port=args.device_port, + upper_host=args.upper_host, + upper_port=args.upper_port + ) + + try: + decoder.start() + while not decoder.zmqServer.IsExitApp: + time.sleep(1) + except KeyboardInterrupt: + decoder.stop() diff --git a/Debug_64ch_Decoder_Optimize/ssvep&ssmvep&mi&concentrate&pvs&blink merge.txt b/Debug_64ch_Decoder_Optimize/ssvep&ssmvep&mi&concentrate&pvs&blink merge.txt new file mode 100644 index 0000000..e69de29