import ast
import glob
import os
import sys
import threading
from datetime import datetime
import multiprocessing as mp
import numpy as np
import time
import torch
from queue import Empty
from scipy import signal
from torch.autograd import Variable
# from Device.SunnyLinker import SunnyLinker64
from SSMVEP.algorithm.tdca import TDCA
from SSMVEP.algorithm.base import generate_cca_references
# from concentration.algorithm.calculate_focus import Calculate
# from blinkdetection.algorithm.eye_detection import blink_detection
from Zmq.zmqServer import zmqServer
from Zmq.zmqClient import zmqClient
from MI.Algorithm.conformer_2class import onlineTrain
from PubLibrary.InifileHelper import IniRead
from logs.log import algo_log
from SSVEP.dwfbcca import FbccaDw
# from Tools.plot_MI_EEG import plotMain
from collections import deque
from Zmq.filterProcess import SlidingFilter

# 1 = dump every collected training set to "<timestamp>.npz" in the working directory.
save_train_data = int(IniRead('system', 'save_train_data', 0))


def get_root_path():
    """Return the program's root directory.

    Written for Nuitka builds: when frozen (``sys.frozen`` is set) this is the
    directory containing the executable, otherwise the directory of this
    ``.py`` file.
    """
    if getattr(sys, 'frozen', False):
        # Packaged build: directory of the .exe
        return os.path.dirname(sys.executable)
    # Development run: directory of this source file
    return os.path.dirname(os.path.abspath(__file__))


MODEL_FOLDER = "online_Models"


class Decoder_main(threading.Thread):
    """Background thread that decodes EEG data delivered by ``zmqServer``.

    The active paradigm is taken from ``zmqServer.decoder_class``:
    'ssvep'/'pvs' use the FbccaDw decoder (no training), 'ssmvep' fits a TDCA
    model online, and 'mi'/'ma' train a network via ``onlineTrain`` in a
    subprocess. Any other value keeps the paradigm buffer drained without
    decoding.
    """

    def __init__(self, device_info=None):
        threading.Thread.__init__(self)
        self.device_info = device_info
        self.Runing = True  # main-loop flag, cleared by stop()
        self.decoder = None
        self.decoder_class = None  # active decoder category
        self.decodingSteps = 0  # 0=idle 1=warm-up 2=decoding 3=decoded, result pending send
        self.zmqServer = zmqServer(device_info=self.device_info)
        self.zmqServer.start()  # start the ZMQ receiving thread
        if self.device_info is None:
            # The rest of this class indexes self.device_info directly; reuse the
            # server's settings instead of failing on first access.
            self.device_info = self.zmqServer.device_info
        self.sliding_filter = SlidingFilter(
            ring_buffer=self.zmqServer.filterBuffer,
            n_chan=self.zmqServer.device_info['channel_nums'],
            srate=self.zmqServer.device_info['sample_rate']
        )
        # A thread object can only be start()ed once; remember whether we did.
        self._filter_started = False
        # Forward filtered results to the server
        self.sliding_filter.filter_result_callback = self.zmqServer.send_filtered_data
        # beta_psd broadcast callback (per the original note: sent to the host
        # once per second over port 8099)
        self.sliding_filter.set_beta_broadcast_callback(lambda v: self.zmqServer.broadcast_message('beta_psd', v))

    def is_valid_signal(self, data, threshold=1e5):
        """Return False when the mean per-channel variance of *data* exceeds *threshold*.

        :param data: array shaped (channels, samples).
        :param threshold: variance ceiling above which the window is rejected.
        """
        energy = np.mean(np.var(data, axis=1))  # mean of the per-channel variances
        return not energy > threshold

    def init_Decoder(self, decoder_class):
        '''
        Initialise the decoder for a paradigm.

        :param decoder_class: 'ssvep'/'pvs', 'ssmvep' or 'mi'/'ma'. Any other
            value (e.g. '') only records the class, so run() just drains data.
        :return: None
        '''
        self.decoder_class = decoder_class
        if decoder_class in ('ssvep', 'pvs'):
            self.n_chan = 8
            # Never reuse a decoder built for a previous target set: if there are
            # no targets below, decoder_SSVEP must see no 'dw' and stay idle.
            if hasattr(self, 'dw'):
                del self.dw
            # self.thread_data_server.interval_inited = False
            DW_cost_method, self.DW_cost_tv = ast.literal_eval(IniRead('system', 'SSVEP_ThresholdValue'))
            self.ListFreq = self.zmqServer.targetFreqs
            self.num_target = len(self.ListFreq)
            if self.num_target == 0:
                return
            # Second-generation decoder object
            self.dw = FbccaDw(self.device_info['sample_rate'], self.num_target, self.n_chan, 5, 5, 0.2,
                              [2.0, 0.1], [8, 7], 50, DW_cost_method)
            self.dw.filterFrequenceBank()  # frequency band bank
            self.dw.setNotchFilterPara()
            self.calculateCount = 0
            self.referenceData = self.dw.reference(self.ListFreq, int(50 * 0.2 * self.device_info['sample_rate']), 5)
            self.dw.filterInit()
            self.dw.onlineInit()  # reset during the first second of stimulus flicker (online acquisition)
        elif decoder_class == 'ssmvep':
            self.zmqServer.interval_init(decoder_class)
            self.n_chan = 8
            self.interval_epoch = ast.literal_eval(IniRead('system', 'SSMVEP_IntervalEpoch'))  # e.g. [0.2, 2.2] s
            # Decoding window length in seconds, rounded to 6 decimals
            self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6)
            self.single_train = 10  # training trials per class
            self.num_target = 2  # number of classes
            self.list_freqs = np.array([8, 9])  # stimulus frequencies
            self.list_phase = np.array([0, 0])  # stimulus phases
            self.tdca = TDCA(padding_len=5, n_components=1)
            self.Yf = generate_cca_references(self.list_freqs, srate=self.device_info['sample_rate'],
                                              T=self.sample_length, phases=self.list_phase, n_harmonics=5)
            self.parameter_init(5, 45)
        elif decoder_class in ('mi', 'ma'):
            self.zmqServer.interval_init(decoder_class)
            self.n_chan = 21
            self.interval_epoch = ast.literal_eval(IniRead('system', 'MI_IntervalEpoch'))  # e.g. [0.5, 4.5] s
            # Decoding window length in seconds, rounded to 6 decimals
            self.sample_length = round(self.interval_epoch[1] - self.interval_epoch[0], 6)
            self.single_train = 40  # training trials per class
            self.num_target = 2  # number of classes
            self.parameter_init(8, 30)
        # elif decoder_class == 'concentration':
        #     self.thread_data_server.interval_inited = False
        #     self.n_chan = 6
        #     self.win_len = 10
        #     self.win_step = 1
        #     self.low_threshold, self.high_threshold = ast.literal_eval(IniRead('system', 'concentration_ThresholdValue'))
        #     self.calculate = Calculate(self.low_threshold, self.high_threshold, self.device_info['sample_rate'], self.win_len)
        #     self.interval_epoch = [0, 1]
        #     self.parameter_init(2, 40)
        #     # self.eegQueue moved to Calculate class
        # elif decoder_class == 'blink':
        #     self.n_chan = 2
        #     self.l_freq = 0.1  # band-pass low cut-off
        #     self.h_freq = 8.0  # band-pass high cut-off
        #     self.total_samples = 0  # total samples seen
        #     self.window_ms = 600  # detection window (ms)
        #     self.step_ms = 100  # sliding step (ms)
        #     self.window_samples = int(self.window_ms * self.device_info['sample_rate'] / 1000)  # 150 samples
        #     self.step_samples = int(self.step_ms * self.device_info['sample_rate'] / 1000)  # 25 samples
        #     self.buffer_size = self.window_samples + self.step_samples * 5
        #     self.fp1_buffer = deque(maxlen=self.buffer_size)
        #     self.fp2_buffer = deque(maxlen=self.buffer_size)
        #     self.sample_counter = 0
        #     # pre-compute filter coefficients instead of redesigning them inside the loop
        #     self.Dmin, self.Dmax, self.EMin, self.EMax, self.jitterwin, self.double_blink_interval, self.double_blink_jitter = ast.literal_eval(IniRead('system', 'blink'))
        #     self.blink_count = 0  # number of single blinks
        #     self.last_blink_time = 0  # time of the last single blink (sample index)
        #     self.blink_timestamps = deque(maxlen=10)  # timestamps of the last 10 single blinks
        #     self.double_blink_count = 0  # number of double blinks
        #     self.double_blink_events = []  # double-blink event log
        #     self.last_double_blink_time = 0  # timestamp of the last double blink
        #     self.blink_events = []
        #     self.blink_b, self.blink_a = signal.butter(4, [self.l_freq / (self.device_info['sample_rate'] / 2), self.h_freq / (self.device_info['sample_rate'] / 2)], btype='band')

    def parameter_init(self, bandPass_low, bandPass_high):
        """Shared setup for the trainable paradigms (SSMVEP / MI).

        Converts ``self.interval_epoch`` from seconds to samples, resets the
        training/report buffers, designs the 50 Hz notch and the
        ``bandPass_low``-``bandPass_high`` Hz FIR band-pass used by
        preprocess(), removes old ``*.pth`` models, and prepares a fresh model
        path plus the queues used to talk to the training subprocess.
        """
        srate = self.device_info['sample_rate']
        nyquist = srate / 2
        self.interval_epoch = [int(i * srate) for i in self.interval_epoch]  # decoding epoch, in samples
        # Training epochs extend 0.1 s past the decoding epoch
        self.train_epoch = [int(self.interval_epoch[0]), int(self.interval_epoch[1] + 0.1 * srate)]
        self.trainData = []  # training trials
        self.trainLabel = []  # training labels
        self.plotData = []  # raw trials kept for the analysis report
        self.plotLabel = []  # labels for the analysis report
        self.currentLabel = -1  # training label currently shown by the stimulus UI
        self.train_started = False  # model training has been launched
        self.load_model = False  # model trained/loaded and ready for prediction
        # 50 Hz power-line notch, quality factor 30 (frequency normalised to Nyquist)
        self.b_notch, self.a_notch = signal.iirnotch(50 / nyquist, 30)
        # 65-tap FIR band-pass between bandPass_low and bandPass_high Hz
        self.b_design = signal.firwin(65, [bandPass_low / nyquist, bandPass_high / nyquist], pass_zero=False)

        filePath = os.path.join(get_root_path(), MODEL_FOLDER) + os.sep
        try:
            # The training subprocess needs an existing folder to save the model into
            os.makedirs(filePath, exist_ok=True)
        except OSError as e:
            algo_log(f"Cannot create model folder {filePath}: {e}")
        for old_pth in glob.glob(os.path.join(filePath, '*.pth')):
            try:
                os.remove(old_pth)
            except OSError as e:
                # e.g. a file still locked on Windows; must not abort initialisation
                algo_log(f"Cannot remove old model {old_pth}: {e}")
        fileName = 'Model_' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        self.modelPath = ''.join([filePath, fileName, '.pth'])
        self.mp_data_queue = mp.Queue()
        self.mp_result_queue = mp.Queue()

    def preprocess(self, signal_data):
        """Remove each channel's mean, then apply the notch and band-pass filters along the last axis."""
        centred = signal_data - np.mean(signal_data, axis=-1, keepdims=True)
        notched = signal.lfilter(self.b_notch, self.a_notch, centred, axis=-1)  # power-line notch
        return signal.lfilter(self.b_design, 1, notched, axis=-1)  # band-pass

    def run(self):
        """Main loop: start the sliding filter, apply decoder switches, dispatch to the active decoder."""
        while self.Runing:
            # Start the filter thread once more than 5 s of data is buffered. A thread
            # can only be started once, and is_alive() is also False after it exits,
            # so track the start explicitly.
            if (not self._filter_started and not self.sliding_filter.is_alive()
                    and self.zmqServer.filterBuffer.GetDataLenCount() > self.device_info['sample_rate'] * 5):
                algo_log("启动滤波线程", level="DEBUG")
                self._filter_started = True
                self.sliding_filter.start()
            if self.zmqServer.decoder_switch or self.zmqServer.changeTarget:
                algo_log(f"Decoder_class Switch Detected: {self.zmqServer.decoder_class}", level="DEBUG")
                self.zmqServer.decoder_switch = False
                self.zmqServer.changeTarget = False
                try:
                    self.reset_state()  # clear all old state before switching
                    self.init_Decoder(self.zmqServer.decoder_class)
                except Exception as e:
                    # A bad configuration must not kill the decoder thread;
                    # fall back to idle draining until the next switch.
                    algo_log(f"Decoder Init Error: {e}")
                    self.decoder_class = None
            # Sync request: forwarding to clients is disabled, only the mode is reset
            if self.zmqServer.state_mode == 'sync':
                # self.zmqClient.send_to_all('sync', self.zmqClient.state)
                self.zmqServer.state_mode = 'rest'
            try:
                if self.zmqServer.open_Impedance:  # no decoding during impedance checks
                    time.sleep(0.005)
                    continue
                if self.decoder_class in ('ssvep', 'pvs'):
                    self.decoder_SSVEP()
                elif self.decoder_class == 'ssmvep':
                    self.decoder_SSMVEP()
                elif self.decoder_class in ('mi', 'ma'):
                    # init_Decoder configures 'ma' exactly like 'mi'
                    self.decoder_MI()
                else:
                    self._drain_rest()
            except Exception as e:
                algo_log(f"Decoder Loop Error: {e}")
                time.sleep(0.1)  # Prevent CPU spin if error is persistent

    def _drain_rest(self):
        """Read and discard 25 samples from the paradigm buffer, or sleep briefly if fewer are buffered."""
        if self.zmqServer.paradigmBuffer.GetDataLenCount() < 25:
            time.sleep(0.005)
            return
        self.zmqServer.paradigmBuffer.getData(25)

    def _epoch(self, data, epoch):
        """Return data[:, event + epoch[0] : event + epoch[1]], relative to zmqServer.event_inner_idx."""
        onset = self.zmqServer.event_inner_idx
        return data[:, onset + epoch[0]:onset + epoch[1]]

    def _save_train_set(self):
        """When save_train_data == 1, dump the training set to '<YYYYmmdd_HHMMSS>.npz' in the working directory."""
        if save_train_data == 1:
            save_path = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}.npz"
            np.savez(save_path, array1=self.trainData, array2=self.trainLabel)

    @staticmethod
    def _timestamp_ms():
        """Return the current wall-clock time as 'HH:MM:SS.mmm'."""
        return datetime.now().strftime('%H:%M:%S.%f')[:-3]

    def decoder_SSVEP(self):
        """Decode SSVEP online in 50-sample chunks.

        After StartDecode the first chunk warms the filters up. Following chunks
        are decoded until fbccaDWMW returns a target (not -1) on a chunk that
        passes is_valid_signal. That target is broadcast as 'result'.
        """
        if self.zmqServer.StartDecode:
            self.zmqServer.StartDecode = False
            self.decodingSteps = 1
            self.zmqServer.paradigmBuffer.resetAllPara()
            algo_log('启动SSVEP预测', level="DEBUG")
        if self.zmqServer.paradigmBuffer.GetDataLenCount() < 50:
            time.sleep(0.005)
            return
        if self.zmqServer.open_Impedance:  # no decoding during impedance checks
            return
        data = self.zmqServer.paradigmBuffer.getDataViaSSVEP(50)
        # algo_log(f"SSVEP取出的:{data.shape}, data = {data[:, :10]}", level="DEBUG")
        data = data[:self.n_chan, :]
        if not hasattr(self, 'dw'):
            return  # no decoder (e.g. no target frequencies): only consume data
        if self.decodingSteps == 1:  # warm-up
            self.dw.onlineInit()  # reset during the first second of stimulus flicker (online acquisition)
            self.dw.warmFilter(data)
            self.decodingSteps = 2
            algo_log('SSVEP预热数据完成。开始预测', level="DEBUG")
            return
        if self.decodingSteps != 2:
            return
        choosenNum = self.dw.fbccaDWMW(data, self.referenceData, self.DW_cost_tv, self.calculateCount)
        self.calculateCount += 1
        if choosenNum == -1 or not self.is_valid_signal(data):
            return  # no decision yet, keep decoding
        self.decodingSteps = 3
        algo_log('SSVEP预测结果:' + str(choosenNum) + ',计算次数:' + str(self.calculateCount), level="DEBUG")
        self.calculateCount = 0
        # Send the decoded result to the UI
        self.zmqServer.broadcast_message('result', int(choosenNum))
        self.decodingSteps = 0
        algo_log('SSVEP发送给界面完成。', level="DEBUG")

    def decoder_SSMVEP(self):
        """SSMVEP paradigm: collect labelled epochs, fit TDCA, then classify each new epoch."""
        # Model training: once labels 1 and 2 each have single_train trials
        if not self.load_model and all(self.trainLabel.count(i) >= self.single_train for i in [1, 2]):
            self.trainData = np.array(self.trainData)
            self.trainLabel = np.array(self.trainLabel)
            algo_log(f"开始SSMVEP模型训练,数据形状:{np.shape(self.trainData)},标签形状:{self.trainLabel.shape}", level="DEBUG")
            self._save_train_set()
            self.decoder = self.tdca.fit(self.trainData, self.trainLabel, Yf=self.Yf)
            algo_log(f"SSMVEP模型训练完成,时间:{self._timestamp_ms()}", level="DEBUG")
            self.load_model = True
            self.zmqServer.broadcast_message('paradigm', 1)

        if self.zmqServer.state_mode == 'train':  # training phase: collect epochs
            if not (self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >=
                    self.train_epoch[1] + self.zmqServer.event_inner_idx):
                time.sleep(0.0001)
                return
            self.currentLabel = self.zmqServer.currentLabel
            trainTrial = self.zmqServer.paradigmBuffer.get_SSMVEPData()  # take everything buffered
            algo_log(f"取出的:{trainTrial.shape},event:{trainTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
            trainTrial = self._epoch(self.preprocess(trainTrial[:self.n_chan, :]), self.train_epoch)
            if (trainTrial.shape[1] == (self.train_epoch[1] - self.train_epoch[0])
                    and isinstance(self.trainLabel, list)
                    and self.trainLabel.count(self.currentLabel) < self.single_train):
                algo_log(f"SSMVEP训练集:{np.shape(self.trainData)}", level="DEBUG")
                self.trainData.append(trainTrial)
                self.trainLabel.append(self.currentLabel)
        elif self.zmqServer.state_mode == 'predict':  # test phase
            if not self.load_model:  # model not trained yet
                time.sleep(0.01)
                return
            if self.zmqServer.StartDecode:
                self.zmqServer.StartDecode = False
                algo_log(f"SSMVEP模型启动预测 {self._timestamp_ms()}", level="DEBUG")
            if not self.zmqServer.epoch_finished or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
                    self.interval_epoch[1] + self.zmqServer.event_inner_idx:
                time.sleep(0.0001)
                return
            data = self.zmqServer.paradigmBuffer.get_SSMVEPData()  # read everything buffered
            algo_log(f"取出的:{data.shape}, event: {data[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
            data = self._epoch(self.preprocess(data[:self.n_chan, :]), self.interval_epoch)
            # Zero-pad to the training-epoch width so the input has exactly the
            # length TDCA was fitted on (recomputing it from float seconds can be off by one).
            pad_eeg_test = np.zeros((data.shape[0], self.train_epoch[1] - self.train_epoch[0]))
            pad_eeg_test[:, :data.shape[1]] = data
            choosenNum, features_2 = self.decoder.predict(pad_eeg_test)
            if isinstance(choosenNum, np.ndarray):
                choosenNum = choosenNum[0]
            algo_log(f"结果:{choosenNum}, rho: {sorted(features_2[0])[-1] - sorted(features_2[0])[-2]}", level="DEBUG")
            self.zmqServer.broadcast_message('result', int(choosenNum))
            algo_log("SSMVEP发送给界面完成。", level="DEBUG")
        else:  # rest state
            self._drain_rest()

    def _poll_trained_model(self):
        """Check the training subprocess without blocking; on success load, warm up and announce the model."""
        try:
            result = self.mp_result_queue.get_nowait()
            if result['status'] == 'success':
                algo_log("MI模型训练完成,加载新模型", level="DEBUG")
                self.model = torch.load(self.modelPath, weights_only=False)
                self.model.eval()
                # Warm-up pass shaped like a real prediction input:
                # (1, 1, n_chan, epoch samples). A hard-coded length only fits one
                # sample-rate/epoch configuration.
                n_samples = self.interval_epoch[1] - self.interval_epoch[0]
                warmup_data = np.random.uniform(-1, 1, (1, 1, self.n_chan, n_samples))
                warmup_data = Variable(torch.from_numpy(warmup_data).type(torch.cuda.FloatTensor))
                with torch.no_grad():
                    _ = self.model(warmup_data)
                self.load_model = True
                self.zmqServer.broadcast_message('paradigm', 1)  # model ready: notify the host
            else:
                algo_log("MI训练失败: " + str(result.get('msg')), level="DEBUG")
        except Empty:
            pass  # training still running
        except Exception as e:
            algo_log("MI模型训练失败: " + str(e), level="DEBUG")

    def decoder_MI(self):
        """MI paradigm: collect labelled epochs, train in a subprocess, load the model, then classify new epochs."""
        # Launch training once labels 1 and 2 each have single_train trials
        if not self.train_started and all(self.trainLabel.count(i) >= self.single_train for i in [1, 2]):
            self.zmqServer.broadcast_message('paradigm', 2)  # training set complete: notify the host
            self.train_started = True
            self.trainData = np.array(self.trainData)
            self.trainLabel = np.array(self.trainLabel)
            algo_log(f"MI开始训练,训练集:{np.shape(self.trainData)},标签shape:{np.shape(self.trainLabel)}", level="DEBUG")
            self._save_train_set()
            # Train the model in a child process
            p = mp.Process(target=onlineTrain, args=(self.mp_data_queue, self.mp_result_queue))
            p.start()
            self.mp_data_queue.put({'data': self.trainData, 'label': self.trainLabel,
                                    'modelPath': self.modelPath, 'n_chan': self.n_chan})
        # Check whether training finished and load the model
        if not self.load_model and self.train_started:
            self._poll_trained_model()

        if self.zmqServer.state_mode == 'train' and not self.train_started:  # training phase: collect epochs
            if not (self.zmqServer.epoch_finished and self.zmqServer.paradigmBuffer.GetDataLenCount() >=
                    self.zmqServer.train_epoch[1] + self.zmqServer.event_inner_idx):
                time.sleep(0.0001)
                return
            self.currentLabel = self.zmqServer.currentLabel  # sync the current label
            algo_log(f"训练队列数据:{self.zmqServer.paradigmBuffer.GetDataLenCount()}", level="DEBUG")
            originalTrial = self.zmqServer.paradigmBuffer.get_MIData()  # fetch the MI-channel data
            algo_log(f"取出的:{originalTrial.shape},event: {originalTrial[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
            trainTrial = self._epoch(self.preprocess(originalTrial[:self.n_chan, :]), self.interval_epoch)
            if (trainTrial.shape[1] == (self.interval_epoch[1] - self.interval_epoch[0])
                    and isinstance(self.trainLabel, list)
                    and self.trainLabel.count(self.currentLabel) < self.single_train):
                self.trainData.append(trainTrial)
                self.trainLabel.append(self.currentLabel)
                algo_log(f"训练集:{np.shape(self.trainData)}", level="DEBUG")
                # Keep the unfiltered epoch for the analysis report
                self.plotData.append(self._epoch(originalTrial[:self.n_chan, :], self.interval_epoch))
                self.plotLabel.append(self.currentLabel)
        elif self.zmqServer.state_mode == 'predict' and self.load_model:  # test phase
            if self.zmqServer.StartDecode:
                self.zmqServer.StartDecode = False
                algo_log(f"MI启动预测 {self._timestamp_ms()}", level="DEBUG")
            if not self.zmqServer.epoch_finished or self.zmqServer.paradigmBuffer.GetDataLenCount() < \
                    self.interval_epoch[1] + self.zmqServer.event_inner_idx:
                time.sleep(0.0001)
                return
            originalData = self.zmqServer.paradigmBuffer.get_MIData()  # read everything buffered
            algo_log(f"取出的:{originalData.shape},event: {originalData[-2, self.zmqServer.event_inner_idx]}", level="DEBUG")
            start = time.time()
            data = self._epoch(self.preprocess(originalData[:self.n_chan, :]), self.interval_epoch)
            self.plotData.append(self._epoch(originalData[:self.n_chan, :], self.interval_epoch))
            test_data = torch.from_numpy(data[np.newaxis, np.newaxis, :, :])
            test_data = Variable(test_data.type(torch.cuda.FloatTensor))
            with torch.no_grad():
                Cls = self.model(test_data)
            y_pred = torch.max(Cls, 1)[1]
            self.plotLabel.append(int(y_pred.item()))
            algo_log(f"MI运动意图识别: {y_pred}")
            self.zmqServer.broadcast_message('result', int(y_pred.item()))
            end = time.time()
            algo_log(f'MI发送给界面完成,耗时{end - start:.3f}s。')
        else:  # rest state
            self._drain_rest()

    # def decoder_concentration(self):
    #     if self.zmqServer.state_mode == 'predict':
    #         if self.zmqServer.StartDecode:
    #             self.zmqServer.StartDecode = False
    #             self.thread_data_server.ResetAll()
    #             now = datetime.now()
    #             formatted_time = now.strftime('%H:%M:%S.%f')[:-3]
    #             print('启动专注力预测 ', formatted_time)
    #         if self.thread_data_server.GetDataLenCount() < int(self.win_step * self.device_info['sample_rate']):  # one result per win_step
    #             time.sleep(0.005)
    #             return
    #         if self.zmqServer.get_Impedance != False:  # no decoding during impedance checks
    #             return
    #         data = self.thread_data_server.get_concentrateData(int(self.win_step * self.device_info['sample_rate']))  # read one step of data
    #         result = self.calculate.queueOpt(data)
    #         if result is not None:
    #             self.zmqClient.send_to_all('result', int(result))
    #     else:  # rest state
    #         if self.zmqServer.get_Impedance == False:  # not checking impedance
    #             if self.thread_data_server.GetDataLenCount() < 25:
    #                 time.sleep(0.005)
    #                 return
    #             self.thread_data_server.getData(25)

    def stop(self):
        '''
        Stop the decoder: end the main loop, then stop the ZMQ server and the sliding filter.
        @return: None
        '''
        # Clear the flag first so run() does not act on components being torn down
        self.Runing = False
        self.zmqServer.stop()
        self.sliding_filter.stop()

    @staticmethod
    def _drain_queue(q):
        """Discard every item that can currently be read from *q* without blocking."""
        while True:
            try:
                q.get_nowait()
            except Empty:
                return

    def reset_state(self):
        """Clear decoder state and buffered data (run() calls this before a decoder switch)."""
        self.zmqServer.reset_state()  # device-layer buffers
        # Decoding state
        self.decodingSteps = 0
        self.calculateCount = 0
        # Training data
        self.plotData = []
        self.plotLabel = []
        self.trainData = []
        self.trainLabel = []
        self.currentLabel = -1
        self.train_started = False
        self.load_model = False
        # Empty the multiprocessing queues so stale data cannot leak into the next decoder
        for queue_name in ('mp_data_queue', 'mp_result_queue'):
            q = getattr(self, queue_name, None)
            if q is not None:
                self._drain_queue(q)